tf.one_hot() 函数是将input转化为one-hot类型数据输出
如果我们有一个5类分类问题,我们有数据 (Xi,Yi)(X_i,Y_i)(Xi,Yi),其中类别YiY_iYi有5种取值(因为是5分类问题),所以如果YjY_jYj为第1类那么其独热编码为: [1,0,0,0,0][1,0,0,0,0][1,0,0,0,0],如果是第2类那么独热编码为:[0,1,0,0,0][0,1,0,0,0][0,1,0,0,0],也就是说只对存在有该类别数的位置上进行标记为1,其他皆为0。
其定义如下:
one_hot
(indices,depth,on_value=None,off_value=None,axis=None,dtype=None,name=None
)
参数说明:
indices - 输入的多个数值,通常是一维矩阵形式。
depth - 输出张量的尺寸,indices中元素默认不超过(depth-1),如果超过,输出为[0,0,···,0]
on_value - 定义在 indices[j] = i 时填充输出的值的标量.(默认:1)
off_value - 定义在 indices[j] != i 时填充输出的值的标量.(默认:0)
axis - 要填充的轴(默认:-1,一个新的最内层轴).
dtype - 输出张量的数据类型.
import tensorflow as tfvar0 = tf.one_hot(indices=[1, 2, 3], depth=3, axis=0)
var1 = tf.one_hot(indices=[1, 2, 3], depth=4, axis=0)
var2 = tf.one_hot(indices=[1, 2, 3], depth=4, axis=1)
var3 = tf.one_hot(indices=[1, 2, 3], depth=4, axis=-1)print("var0(axis=0 depth=3):\n",var0)
print("var1(axis=0 depth=4P):\n",var1)
print("var2(axis=1):\n",var2)
print("var3(axis=-1):\n",var3)