热门标签 | HotTags
当前位置:  开发笔记 > 编程语言 > 正文

多元高斯分布张量流概率的混合

正如标题中所述,我正在尝试使用张量流概率包来创建多元正态分布的混合。

正如标题中所述,我正在尝试使用张量流概率包来创建多元正态分布的混合。

在我的原始项目中,我从神经网络的输出中获取类别,位置和方差的权重。但是,在创建图形时,出现以下错误:


  

components [0]批处理形状必须与cat形状和其他组件批处理形状兼容

我使用占位符重新创建了相同的问题:

import tensorflow as tf
import tensorflow_probability as tfp # dist= tfp.distributions
tf.compat.v1.disable_eager_execution()
sess = tf.compat.v1.InteractiveSession()
l1 = tf.compat.v1.placeholder(dtype=tf.float32,shape=[None,2],name='observations_1')
l2 = tf.compat.v1.placeholder(dtype=tf.float32,name='observations_2')
log_std = tf.compat.v1.get_variable('log_std',[1,dtype=tf.float32,initializer=tf.constant_initializer(1.0),trainable=True)
mix = tf.compat.v1.placeholder(dtype=tf.float32,1],name='weights')
cat = tfp.distributions.Categorical(probs=[mix,1.-mix])
compOnents= [
tfp.distributions.MultivariateNormalDiag(loc=l1,scale_diag=tf.exp(log_std)),tfp.distributions.MultivariateNormalDiag(loc=l2,]
bimix_gauss = tfp.distributions.Mixture(
cat=cat,compOnents=components)

所以,我的问题是,我做错了什么?我调查了错误,似乎是tensorshape_util.is_compatible_with引发了错误,但我不明白为什么。

谢谢!


似乎您为tfp.distributions.Categorical提供了错误的输入。参数probs的形状应为[batch_size,cat_size],而您提供的参数应为[cat_size,batch_size,1]。因此,也许尝试使用probs参数化tf.concat([mix,1-mix],1)

您的log_std可能也有问题,其形状与l1l2不同。如果MultivariateNormalDiag不能正确广播,请尝试将其形状指定为(None,2)或对其进行平铺,以使其第一维与您的位置参数相对应。

,

当组件是相同类型时,MixtureSameFamily应该性能更高。

仅传递一个分类实例(带有.batch_shape [b1,b2,...,bn])和单个MVNDiag实例(带有.batch_shape [b1,b2,...,bn,numcats])

对于只有两节课,我想知道伯努利是否会上课?


推荐阅读
author-avatar
众神痴梦_325
这个家伙很懒,什么也没留下!
PHP1.CN | 中国最专业的PHP中文社区 | DevBox开发工具箱 | json解析格式化 |PHP资讯 | PHP教程 | 数据库技术 | 服务器技术 | 前端开发技术 | PHP框架 | 开发工具 | 在线工具
Copyright © 1998 - 2020 PHP1.CN. All Rights Reserved | 京公网安备 11010802041100号 | 京ICP备19059560号-4 | PHP1.CN 第一PHP社区 版权所有