tf 重排 切分 逆排列
目录
tf 重排 切分
gather按索引提取:
invert_permutation 逆排列
tf 重排 切分
import tensorflow as tf# 模拟原始数据:6个样本,每个样本3维特征
x = tf.constant([[10, 11, 12], # 0[20, 21, 22], # 1[30, 31, 32], # 2[40, 41, 42], # 3[50, 51, 52], # 4[60, 61, 62] # 5
])# permutation:重排索引(如打乱顺序)
permutation = [3, 0, 1, 5, 4, 2]# sizes:切分数量,表示每一段的大小
sizes = [2, 1, 3] # 一共还是6个样本# 执行
result = tf.split(tf.gather(x, permutation, axis=0), sizes)# 打印结果
for i, part in enumerate(result):print(f"Part {i}:\n{part.numpy()}\n")
gather
按索引提取:
tf.gather(params, indices, axis=0)
的作用是:
-
从张量
params
中,按照indices
给出的索引,在指定的维度(默认为第0维)上提取数据。
invert_permutation 逆排列
tf.math.invert_permutation 是 TensorFlow 中的一个函数,用来求一个排列的“逆排列”。