代码语言:javascript复制
tf.gather(
params,
indices,
validate_indices=None,
name=None,
axis=0
)
根据索引从params坐标轴中收集切片。标必须是任何维度(通常是0-D或1-D)的整数张量。产生一个带有形状参数的输出张量,其中: params.shape[:axis] indices.shape params.shape[axis 1:]。
# Scalar indices (output is rank(params) - 1).
output[a_0, ..., a_n, b_0, ..., b_n] =
params[a_0, ..., a_n, indices, b_0, ..., b_n]
# Vector indices (output is rank(params)).
output[a_0, ..., a_n, i, b_0, ..., b_n] =
params[a_0, ..., a_n, indices[i], b_0, ..., b_n]
# Higher rank indices (output is rank(params) rank(indices) - 1).
output[a_0, ..., a_n, i, ..., j, b_0, ... b_n] =
params[a_0, ..., a_n, indices[i, ..., j], b_0, ..., b_n]
注意,在CPU上,如果发现一个out of bound索引,将返回一个错误。在GPU上,如果发现一个out of bound索引,则在相应的输出值中存储一个0。
参数:
params
: 一个张量。用来收集值的张量。必须至少是秩轴 1。indices
: 一个张量。必须是下列类型之一:int32、int64。指数张量。必须在range [0, params.shape[axis]]中。axis
: 张量。必须是下列类型之一:int32、int64。以参数为单位的轴,用来收集指标。默认为第一个维度。支持负索引。- name: 操作的名称(可选)。
返回值:
- 一个张量。具有与params相同的类型。
原链接: https://tensorflow.google.cn/versions/r1.9/api_docs/python/tf/gather?hl=en