tf.gather

2022-09-04 21:03:12 浏览数 (1)

代码语言:javascript复制
tf.gather(
    params,
    indices,
    validate_indices=None,
    name=None,
    axis=0
)

根据索引从params坐标轴中收集切片。标必须是任何维度(通常是0-D或1-D)的整数张量。产生一个带有形状参数的输出张量,其中: params.shape[:axis] indices.shape params.shape[axis 1:]。

代码语言:javascript复制
# Scalar indices (output is rank(params) - 1).
output[a_0, ..., a_n, b_0, ..., b_n] =
   params[a_0, ..., a_n, indices, b_0, ..., b_n]

# Vector indices (output is rank(params)).
output[a_0, ..., a_n, i, b_0, ..., b_n] =
   params[a_0, ..., a_n, indices[i], b_0, ..., b_n]

# Higher rank indices (output is rank(params)   rank(indices) - 1).
output[a_0, ..., a_n, i, ..., j, b_0, ... b_n] =
   params[a_0, ..., a_n, indices[i, ..., j], b_0, ..., b_n]

注意,在CPU上,如果发现一个out of bound索引,将返回一个错误。在GPU上,如果发现一个out of bound索引,则在相应的输出值中存储一个0。

参数:

  • params: 一个张量。用来收集值的张量。必须至少是秩轴 1。
  • indices: 一个张量。必须是下列类型之一:int32、int64。指数张量。必须在range [0, params.shape[axis]]中。
  • axis: 张量。必须是下列类型之一:int32、int64。以参数为单位的轴,用来收集指标。默认为第一个维度。支持负索引。
  • name: 操作的名称(可选)。

返回值:

  • 一个张量。具有与params相同的类型。

原链接: https://tensorflow.google.cn/versions/r1.9/api_docs/python/tf/gather?hl=en

0 人点赞