tf.split的作用就是将张量按指定维度和分割个数拆分为子张量列表。
源码解释
tf.split(value, num_or_size_splits, axis=0, num=None, name='split'
)参数:value: 张量
num_or_size_splits:int或者一维张量或者list,将根据axis所指向的维度进行切分数据。可以理解为切这个张量的分割量(注意不是比例)。当是一个整数时,表示按该值均分,比如2,则一分为二,当是列表时,则按列表的元素值来分。比如假如原维数值为10,那么[2,8]则表示切分为2和8.这要看例子就一目了然
axis: 切分的维度,即根据哪个维度来把这个张量给大卸八块。
num: 这个不用管没什么用
name: 该操作的名称
import tensorflow as tfa = tf.random.normal(shape=(10,10,1))
t1,t2,t3 = tf.split(a,num_or_size_splits=[2,3,5],axis=0)
print(t1.shape)
print(t2.shape)
print(t3.shape)
m,n = tf.split(a,num_or_size_splits=2,axis=1)
print(m.shape)
print(n.shape)
x,y,z = tf.split(a,num_or_size_splits=[2,3,5],axis=1)
print(x.shape)
print(y.shape)
print(z.shape)
注意:
num_or_size_splits切分值,无论是均分还是按列表个数分,都必须完完全全被分割完整,不能有小数或者多余维度剩下。
比如:第二维度均分三份,报错!
原因第二维是10的长度,不能被3整除
r1,r2,r3 = tf.split(a,num_or_size_splits=3,axis=1)
再比如:
第二维按[2,2,2]的个数来分配,报错!
原因第二维是10的长度,按[2,2,2]分配后还剩4维没有分配,同样按[2,8,2]来分不用想了肯定报错,本就10个怎么还能分出多2个。
u1,u2,u3 = tf.split(a,num_or_size_splits=[2,2,2],axis=1)
总结:
tf.split(value, num_or_size_splits, axis=0, num=None, name=‘split’)
value: 为原张量
num_or_size_splits:为分割量,如果为单个整数,则表示按该整数均分,如果是list则按元素多少分割张量。分割后的个部分数量必须能完完全全把原维度的数量分割干净,不能有小数或多余更不能少。
axis:指定张量维度,分割需要依据此维度来进行分割
num:不用管,没多大用
name: 不用管,操作名称