"""
记:
logits: 为网络最后一层的输出,无激活函数
p = softmax(logits): 概率
log_p = log_softmax(logits) = log(p): log概率
结论:
"""
import tensorflow as tf
import numpy as np

# Fix the graph-level seed so the sampled counts are reproducible across runs.
tf.random.set_random_seed(1111)

# A single row of raw scores for 3 classes: shape (1, 3).
logits = tf.constant([[5., 3, 2]])
p = tf.nn.softmax(logits, axis=1)
log_p = tf.nn.log_softmax(logits, axis=1)

# Draw 10000 class indices from Categorical(probs=p). The distribution has a
# batch of one (probs is (1, 3)), so the sampled tensor has shape (10000, 1).
sample = tf.distributions.Categorical(probs=p).sample(sample_shape=(10000, ))

# Use the session as a context manager so it is closed when we are done.
with tf.Session() as sess:
    s = sess.run(sample)
    prob = sess.run(p)

# Per-class sample counts. Flatten the (10000, 1) result first; minlength keeps
# a slot for every class even if one is never drawn. Cast to float to match the
# original float count array.
num_classes = int(logits.shape[-1])
cnt = np.bincount(s.ravel(), minlength=num_classes).astype(np.float64)
print(cnt)
print(prob)