热门标签 | HotTags
当前位置:  开发笔记 > 编程语言 > 正文

使用Python的numpy实现随机梯度下降

StochasticgradientdescentimplementationwithPythonsnumpy


Stochastic gradient descent implementation with Python's numpy


我必须使用 python numpy 库来实现随机梯度下降。为此,我给出了以下函数定义:




1
2
3
4
5
6

def compute_stoch_gradient(y, tx, w):
   """Compute a stochastic gradient for batch data."""
def stochastic_gradient_descent(
        y, tx, initial_w, batch_size, max_epochs, gamma):
   """Stochastic gradient descent algorithm."""

我还获得了以下帮助功能:




1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24

def batch_iter(y, tx, batch_size, num_batches=1, shuffle=True):
   """
    Generate a minibatch iterator for a dataset.
    Takes as input two iterables (here the output desired values 'y' and the input data 'tx')
    Outputs an iterator which gives mini-batches of `batch_size` matching elements from `y` and `tx`.
    Data can be randomly shuffled to avoid ordering in the original data messing with the randomness of the minibatches.
    Example of use :
    for minibatch_y, minibatch_tx in batch_iter(y, tx, 32):
       
   """

    data_size = len(y)
    if shuffle:
        shuffle_indices = np.random.permutation(np.arange(data_size))
        shuffled_y = y[shuffle_indices]
        shuffled_tx = tx[shuffle_indices]
    else:
        shuffled_y = y
        shuffled_tx = tx
    for batch_num in range(num_batches):
        start_index = batch_num * batch_size
        end_index = min((batch_num + 1) * batch_size, data_size)
        if start_index != end_index:
            yield shuffled_y[start_index:end_index], shuffled_tx[start_index:end_index]

我实现了以下两个功能:




1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19

def compute_stoch_gradient(y, tx, w):
   """Compute a stochastic gradient for batch data."""
    e = y - tx.dot(w)
    return (-1/y.shape[0])*tx.transpose().dot(e)
def stochastic_gradient_descent(y, tx, initial_w, batch_size, max_epochs, gamma):
   """Stochastic gradient descent algorithm."""
    ws = [initial_w]
    losses = []
    w = initial_w
    for n_iter in range(max_epochs):
        for minibatch_y,minibatch_x in batch_iter(y,tx,batch_size):
            w = ws[n_iter] - gamma * compute_stoch_gradient(minibatch_y,minibatch_x,ws[n_iter])
            ws.append(np.copy(w))
            loss = y - tx.dot(w)
            losses.append(loss)
    return losses, ws

我不确定迭代应该在 range(max_epochs) 还是更大的范围内完成。我这样说是因为我读到一个纪元是"每次我们遍历整个数据集"。所以我认为一个时代包含多个迭代......



相关讨论



  • 对于第二个问题:阅读关于 sgd 的批处理、小批量和 epoch。


  • 您在内部循环中调用 batch_iter ,每次调用时都会实例化一个新的生成器对象。相反,您想在循环之外实例化一个生成器,然后对其进行迭代,例如for minibatch_y, minibatch_x in batch_iter(...)





在典型的实现中,批量大小为 B 的小批量梯度下降应该从数据集中随机选择 B 个数据点,并根据该子集上计算的梯度更新权重。这个过程本身将持续很多次,直到收敛或某个阈值最大迭代。 B=1 的 Mini-batch 是 SGD,有时会很吵。

除了上述评论之外,您可能还想尝试一下批量大小和学习率(步长),因为它们对随机和小批量梯度下降的收敛速度有显着影响。

下图显示了在对亚马逊产品评论数据集进行情感分析时,这两个参数对 SGDlogistic regression 的收敛速度的影响,该作业出现在机器学习 - 分类的课程中华盛顿大学:

enter

enter

有关这方面的更多详细信息,您可以参考 https://sandipanweb.wordpress.com/2017/03/31/online-learning-sentiment-analysis-with-logistic-regression-via-stochastic-gradient-ascent /?frame-nOnce=987e584e16





推荐阅读
  • 本文介绍了绕过WAF的XSS检测机制的方法,包括确定payload结构、测试和混淆。同时提出了一种构建XSS payload的方法,该payload与安全机制使用的正则表达式不匹配。通过清理用户输入、转义输出、使用文档对象模型(DOM)接收器和源、实施适当的跨域资源共享(CORS)策略和其他安全策略,可以有效阻止XSS漏洞。但是,WAF或自定义过滤器仍然被广泛使用来增加安全性。本文的方法可以绕过这种安全机制,构建与正则表达式不匹配的XSS payload。 ... [详细]
  • 开源Keras Faster RCNN模型介绍及代码结构解析
    本文介绍了开源Keras Faster RCNN模型的环境需求和代码结构,包括FasterRCNN源码解析、RPN与classifier定义、data_generators.py文件的功能以及损失计算。同时提供了该模型的开源地址和安装所需的库。 ... [详细]
  • Python爬虫中使用正则表达式的方法和注意事项
    本文介绍了在Python爬虫中使用正则表达式的方法和注意事项。首先解释了爬虫的四个主要步骤,并强调了正则表达式在数据处理中的重要性。然后详细介绍了正则表达式的概念和用法,包括检索、替换和过滤文本的功能。同时提到了re模块是Python内置的用于处理正则表达式的模块,并给出了使用正则表达式时需要注意的特殊字符转义和原始字符串的用法。通过本文的学习,读者可以掌握在Python爬虫中使用正则表达式的技巧和方法。 ... [详细]
  • 本文介绍了RPC框架Thrift的安装环境变量配置与第一个实例,讲解了RPC的概念以及如何解决跨语言、c++客户端、web服务端、远程调用等需求。Thrift开发方便上手快,性能和稳定性也不错,适合初学者学习和使用。 ... [详细]
  • 背景应用安全领域,各类攻击长久以来都危害着互联网上的应用,在web应用安全风险中,各类注入、跨站等攻击仍然占据着较前的位置。WAF(Web应用防火墙)正是为防御和阻断这类攻击而存在 ... [详细]
  • 使用圣杯布局模式实现网站首页的内容布局
    本文介绍了使用圣杯布局模式实现网站首页的内容布局的方法,包括HTML部分代码和实例。同时还提供了公司新闻、最新产品、关于我们、联系我们等页面的布局示例。商品展示区包括了车里子和农家生态土鸡蛋等产品的价格信息。 ... [详细]
  • Python教学练习二Python1-12练习二一、判断季节用户输入月份,判断这个月是哪个季节?3,4,5月----春 ... [详细]
  • 在本教程中,我们将看到如何使用FLASK制作第一个用于机器学习模型的RESTAPI。我们将从创建机器学习模型开始。然后,我们将看到使用Flask创建AP ... [详细]
  • 生成式对抗网络模型综述摘要生成式对抗网络模型(GAN)是基于深度学习的一种强大的生成模型,可以应用于计算机视觉、自然语言处理、半监督学习等重要领域。生成式对抗网络 ... [详细]
  • 本文介绍了九度OnlineJudge中的1002题目“Grading”的解决方法。该题目要求设计一个公平的评分过程,将每个考题分配给3个独立的专家,如果他们的评分不一致,则需要请一位裁判做出最终决定。文章详细描述了评分规则,并给出了解决该问题的程序。 ... [详细]
  • 本文介绍了Perl的测试框架Test::Base,它是一个数据驱动的测试框架,可以自动进行单元测试,省去手工编写测试程序的麻烦。与Test::More完全兼容,使用方法简单。以plural函数为例,展示了Test::Base的使用方法。 ... [详细]
  • 不同优化算法的比较分析及实验验证
    本文介绍了神经网络优化中常用的优化方法,包括学习率调整和梯度估计修正,并通过实验验证了不同优化算法的效果。实验结果表明,Adam算法在综合考虑学习率调整和梯度估计修正方面表现较好。该研究对于优化神经网络的训练过程具有指导意义。 ... [详细]
  • CF:3D City Model(小思维)问题解析和代码实现
    本文通过解析CF:3D City Model问题,介绍了问题的背景和要求,并给出了相应的代码实现。该问题涉及到在一个矩形的网格上建造城市的情景,每个网格单元可以作为建筑的基础,建筑由多个立方体叠加而成。文章详细讲解了问题的解决思路,并给出了相应的代码实现供读者参考。 ... [详细]
  • 本文介绍了在Python张量流中使用make_merged_spec()方法合并设备规格对象的方法和语法,以及参数和返回值的说明,并提供了一个示例代码。 ... [详细]
  • cs231n Lecture 3 线性分类笔记(一)
    内容列表线性分类器简介线性评分函数阐明线性分类器损失函数多类SVMSoftmax分类器SVM和Softmax的比较基于Web的可交互线性分类器原型小结注:中文翻译 ... [详细]
author-avatar
精神还没分裂2011
这个家伙很懒,什么也没留下!
PHP1.CN | 中国最专业的PHP中文社区 | DevBox开发工具箱 | json解析格式化 |PHP资讯 | PHP教程 | 数据库技术 | 服务器技术 | 前端开发技术 | PHP框架 | 开发工具 | 在线工具
Copyright © 1998 - 2020 PHP1.CN. All Rights Reserved | 京公网安备 11010802041100号 | 京ICP备19059560号-4 | PHP1.CN 第一PHP社区 版权所有