作者:众神痴梦_325 | 来源:互联网 | 2023-09-02 11:55
正如标题中所述,我正在尝试使用张量流概率包来创建多元正态分布的混合。
在我的原始项目中,我从神经网络的输出中获取类别,位置和方差的权重。但是,在创建图形时,出现以下错误:
components [0]批处理形状必须与cat形状和其他组件批处理形状兼容
我使用占位符重新创建了相同的问题:
import tensorflow as tf
import tensorflow_probability as tfp # dist= tfp.distributions
tf.compat.v1.disable_eager_execution()
sess = tf.compat.v1.InteractiveSession()
l1 = tf.compat.v1.placeholder(dtype=tf.float32,shape=[None,2],name='observations_1')
l2 = tf.compat.v1.placeholder(dtype=tf.float32,name='observations_2')
log_std = tf.compat.v1.get_variable('log_std',[1,dtype=tf.float32,initializer=tf.constant_initializer(1.0),trainable=True)
mix = tf.compat.v1.placeholder(dtype=tf.float32,1],name='weights')
cat = tfp.distributions.Categorical(probs=[mix,1.-mix])
compOnents= [
tfp.distributions.MultivariateNormalDiag(loc=l1,scale_diag=tf.exp(log_std)),tfp.distributions.MultivariateNormalDiag(loc=l2,]
bimix_gauss = tfp.distributions.Mixture(
cat=cat,compOnents=components)
所以,我的问题是,我做错了什么?我调查了错误,似乎是tensorshape_util.is_compatible_with
引发了错误,但我不明白为什么。
谢谢!
似乎您为tfp.distributions.Categorical
提供了错误的输入。参数probs
的形状应为[batch_size,cat_size]
,而您提供的参数应为[cat_size,batch_size,1]
。因此,也许尝试使用probs
参数化tf.concat([mix,1-mix],1)
。
您的log_std
可能也有问题,其形状与l1
和l2
不同。如果MultivariateNormalDiag
不能正确广播,请尝试将其形状指定为(None,2)
或对其进行平铺,以使其第一维与您的位置参数相对应。
,
当组件是相同类型时,MixtureSameFamily应该性能更高。
仅传递一个分类实例(带有.batch_shape [b1,b2,...,bn])和单个MVNDiag实例(带有.batch_shape [b1,b2,...,bn,numcats])
对于只有两节课,我想知道伯努利是否会上课?