热门标签 | HotTags
当前位置:  开发笔记 > 编程语言 > 正文

tensorflow2.1中独热编码函数tf.one_hot()的用法

tf.one_hot()函数是将input转化为one-hot类型数据输出如果我们有一个5类分类问题,我们有数据(Xi,Yi)(X_i,Y_i)(Xi​,Yi​)&

tf.one_hot() 函数是将input转化为one-hot类型数据输出

如果我们有一个5类分类问题,我们有数据 (Xi,Yi)(X_i,Y_i)(Xi,Yi),其中类别YiY_iYi有5种取值(因为是5分类问题),所以如果YjY_jYj为第1类那么其独热编码为: [1,0,0,0,0][1,0,0,0,0][1,0,0,0,0],如果是第2类那么独热编码为:[0,1,0,0,0][0,1,0,0,0][0,1,0,0,0],也就是说只对存在有该类别数的位置上进行标记为1,其他皆为0。

其定义如下:

one_hot
(indices,#输入,这里是一维的depth,# one hot dimension.on_value=None,#output 默认1off_value=None,#output 默认0axis=None,dtype=None,name=None
)

参数说明:

indices - 输入的多个数值,通常是一维矩阵形式。

depth - 输出张量的尺寸,indices中元素默认不超过(depth-1),如果超过,输出为[0,0,···,0]

on_value - 定义在 indices[j] = i 时填充输出的值的标量.(默认:1)

off_value - 定义在 indices[j] != i 时填充输出的值的标量.(默认:0)

axis - 要填充的轴(默认:-1,一个新的最内层轴).

dtype - 输出张量的数据类型.

import tensorflow as tfvar0 = tf.one_hot(indices=[1, 2, 3], depth=3, axis=0)
var1 = tf.one_hot(indices=[1, 2, 3], depth=4, axis=0)
var2 = tf.one_hot(indices=[1, 2, 3], depth=4, axis=1)
# axis=1 按行排
var3 = tf.one_hot(indices=[1, 2, 3], depth=4, axis=-1)print("var0(axis=0 depth=3):\n",var0)
print("var1(axis=0 depth=4P):\n",var1)
print("var2(axis=1):\n",var2)
print("var3(axis=-1):\n",var3)
# var0(axis=0 depth=3):
# tf.Tensor(
# [[0. 0. 0.]
# [1. 0. 0.]
# [0. 1. 0.]], shape=(3, 3), dtype=float32)
# var1(axis=0 depth=4P):
# tf.Tensor(
# [[0. 0. 0.]
# [1. 0. 0.]
# [0. 1. 0.]
# [0. 0. 1.]], shape=(4, 3), dtype=float32)
# var2(axis=1):
# tf.Tensor(
# [[0. 1. 0. 0.]
# [0. 0. 1. 0.]
# [0. 0. 0. 1.]], shape=(3, 4), dtype=float32)
# var3(axis=-1):
# tf.Tensor(
# [[0. 1. 0. 0.]
# [0. 0. 1. 0.]
# [0. 0. 0. 1.]], shape=(3, 4), dtype=float32)


推荐阅读
author-avatar
cc_vb8
这个家伙很懒,什么也没留下!
PHP1.CN | 中国最专业的PHP中文社区 | DevBox开发工具箱 | json解析格式化 |PHP资讯 | PHP教程 | 数据库技术 | 服务器技术 | 前端开发技术 | PHP框架 | 开发工具 | 在线工具
Copyright © 1998 - 2020 PHP1.CN. All Rights Reserved | 京公网安备 11010802041100号 | 京ICP备19059560号-4 | PHP1.CN 第一PHP社区 版权所有