热门标签 | HotTags
当前位置:  开发笔记 > 编程语言 > 正文

theano——shared,function(outputs,updates,givens)

theano下的functions(fromtheanoimportfunction)的两个重要性质:多个outputs;多

  • theano 下的 functions (from theano import function)的两个重要性质:
    • 多个 outputs;
    • 多个 updates;

1. shared


  • shared 与 theano.tensor.matrix()/vector() 的区别在于,tensor.matrix/vector 符号变量不需要初始化,而 shared 变量则需要初始化赋值;

>>> import theano
>>> import theano.tensor as T >>> state = theano.shared(0)
>>> type(state)
theano.tensor.sharedvar.ScalarSharedVariable
>>> state.get_value(borrow=True)
array(0)# shared 型变量,通过 get_value 取出其值;>>> x = T.iscalar('x')
>>> type(x)
theano.tensor.var.TensorVariable# tensor value(thenao.tensor模块下的变量?)是没有get_value成员函数的>>> type(x+state)
theano.tensor.var.TensorVariable

2. updates

>>> counter = function([x], state, updates=[(state, state+x)])
>>> counter(5)
>>> state.get_value(borrow=True)
5
>>> counter(6)
>>> state.get_value(borrow=True)
11

不使用可选参数updates:

>>> state += x# 此时state的类型已自动转换为tensor value,不再具有get_value的成员
>>> counter = theano.function([x], state)
>>> counter(5)
array(5)
>>> state.get_value()
AttributeError: 'TensorVariable' object has no attribute 'get_value'

theano.function 的 updates, 二元 tuple 构成的 list,

[(a, a+1), (b, b+1)]

其含义其实为,a = a+1, b = b+1;

所以常见与,梯度下降算法中的权值的更新,如下:

updates = [(W, W-eta*delta_W),(b, b-eta*delta_b)]

3. givens

test_model = theano.function(inputs=[idx],outputs=layer3.error(y),givens={x: test_set_x[idx*batch_size:(idx+1)*batch_size],y: test_set_y[idx*batch_size:(idx+1)*batch_size]}
)
valid_model = theano.function(inputs=[idx],outputs=layer3.error(y),givens={x: valid_set_x[idx*batch_size:(idx+1)*batch_size],y: valid_set_y[idx*batch_size:(idx+1)*batch_size]}
)

推荐阅读
author-avatar
东隅海纳堂_684
这个家伙很懒,什么也没留下!
PHP1.CN | 中国最专业的PHP中文社区 | DevBox开发工具箱 | json解析格式化 |PHP资讯 | PHP教程 | 数据库技术 | 服务器技术 | 前端开发技术 | PHP框架 | 开发工具 | 在线工具
Copyright © 1998 - 2020 PHP1.CN. All Rights Reserved | 京公网安备 11010802041100号 | 京ICP备19059560号-4 | PHP1.CN 第一PHP社区 版权所有