tf.where

2022-09-04 20:58:20 浏览数 (1)

代码语言:javascript复制
tf.where(
    condition,
    x=None,
    y=None,
    name=None
)

根据条件返回元素(x或y)。 如果x和y都为空,那么这个操作返回条件的真元素的坐标。坐标在二维张量中返回,其中第一个维度(行)表示真实元素的数量,第二个维度(列)表示真实元素的坐标。记住,输出张量的形状可以根据输入中有多少个真值而变化。索引按行主顺序输出。如果两者都是非零,则x和y必须具有相同的形状。如果x和y是标量,条件张量必须是标量。如果x和y是更高秩的向量,那么条件必须是大小与x的第一个维度匹配的向量,或者必须具有与x相同的形状。条件张量充当一个掩码,它根据每个元素的值选择输出中对应的元素/行是来自x(如果为真)还是来自y(如果为假)。如果条件是一个向量,x和y是高秩矩阵,那么它选择从x和y复制哪一行(外维),如果条件与x和y形状相同,那么它选择从x和y复制哪一个元素。

参数:

  • condition: bool类型的张量
  • x: 一个张量,它的形状可能和条件相同。如果条件为秩1,x的秩可能更高,但是它的第一个维度必须与条件的大小匹配
  • y: 与x形状和类型相同的张量
  • name: 操作的名称(可选)

返回值:

  • 一个与x, y相同类型和形状的张量,如果它们是非零的话。一个带形状(num_true, dim_size(condition))的张量。

异常:

  • ValueError: When exactly one of x or y is non-None.

原链接: https://tensorflow.google.cn/versions/r1.9/api_docs/python/tf/where?hl=en

0 人点赞