tf.split()函数的用法

时间:2023-02-06 16:58:42


from PIL import  Image
import numpy as np
import tensorflow as tf

'''
split 对维度进行分割
tf.split(
data, 数据图片 ( 300*600*3)
num_or_size_splits , 分割的数组 传个数
axis, 代表维度,当前的维度为 0 1 2
)

tf.split(data,3,2) 得到数据的维度为 [(300,600,1),(300,600,1),(300,600,1)]
tf.split(data,[100.200,300],1) 得到数据的维度为 [(300,100,3),(300,100,3),(300,200,3)]

'''


img = Image.open('./test_data/tabby_cat.png')
img = np.array(img)

a = tf.split(img,3,2)


with tf.Session() as sess:
sess.run(tf.global_variables_initializer())

a1 = sess.run([a ])

# for i in range(len(a1[0])):
# print(a1[i].shape )

for i in range(len(a1[0])):
# print(type(a1[0][i]))

print(a1[0][i].shape)













# print(img.shape)