最近在学习tensorflow2.0的时候看到一些特别好用的高级函数,这里来记录一下它们的用法
1.tf.gather()
tf.gather(params,indices,validate_indices=None,name=None,axis=0)简单的理解一下,首先传入一个需要处理的张量,然后传入对他的选择操作,也就是一个索引张量。
下面举个例子:
考虑班级成绩册的例子,共有 4 个班级,每个班级 35 个学生,8 门科目,保存成绩册的张量 shape 为[4,35,8]。
#创建成绩册 record=tf.random.uniform([4,35,8],maxval=100) record.numpy 复制代码
如果现在需要收集第 1,2 两个班级的成绩册,我们可以通过切片操作
record1_2=record[0:2] record1_2.numpy
也可以使用tf.gather()得到一样的结果
#从第一个维度(班级)选择前两个班级 record1_2=tf.gather(record,[0,1],axis=0) record1_2.numpy 复制代码
但是换个要求,需要抽查所有班级的第 1,4,9,12,13,27 号同学的成绩,这时候用切片就不好得到结果了,用gather还是很容易的
#从第二个维度(学生)抽取 score=tf.gather(record,[0,3,8,11,12,26],axis=1) score.numpy 复制代码
2.tf.gather_nd()
通过 tf.gather_nd(),可以通过指定每次采样的坐标来实现采样多个点的目的 例子:得到班级 1,学生 1 的科目 2;班级 2,学生 2 的科目 3;班级 3,学生 3 的科目 4 的成绩
score=tf.gather_nd(record,[[0,0,1],[1,1,2],[2,2,3]]) score.numpy
3.tf.scatter_nd()
通过 tf.scatter_nd(indices, updates, shape)可以高效地刷新张量的部分数据,但是只能在全 0 张量的白板上面刷新,因此可能需要结合其他操作来实现现有张量的数据刷新功能。
#需要刷新的位置 indices = tf.constant([[4], [3], [1], [7]]) # 构造需要写入的数据 updates = tf.constant([4.4, 3.3, 1.1, 7.7]) # 在长度为 8 的全 0 向量上根据 indices 写入 updates tf.scatter_nd(indices, updates, [8])
4.tf.meshgrid()
通过 tf.meshgrid 可以方便地生成二维网格采样点坐标,或者可以理解成为了满足矩阵相乘,把x按行重复y的列次,y按列重复x的行次(广播机制)
例子:实现
import matplotlib.pyplot as plt from mpl_toolkits.mplot3d import Axes3D plt.rcParams['axes.unicode_minus']=False x = tf.linspace(-8.,8,100) # 设置 x 坐标的间隔 y = tf.linspace(-8.,8,100) # 设置 y 坐标的间隔 x,y = tf.meshgrid(x,y) # 生成网格点,并拆分后返回 print(x.shape,y.shape) # 打印拆分后的所有点的 x,y 坐标张量 shape z = tf.sqrt(x**2+y**2) z = tf.sin(z)/z # sinc 函数实现 fig = plt.figure() ax = Axes3D(fig) # 根据网格点绘制 sinc 函数 3D 曲面 ax.contour3D(x.numpy(), y.numpy(), z.numpy(), 50) plt.show()
或者来个简单的例子更能体现它的变换
x=tf.constant([1,2,3]) y=tf.constant([3,4,5]) x,y = tf.meshgrid(x,y) print(x.numpy,y.numpy) 复制代码
这样meshgrid的作用就一目了然了