代码语言:javascript复制
tf.where(
condition,
x=None,
y=None,
name=None
)
根据条件返回元素(x或y)。 如果x和y都为空,那么这个操作返回条件的真元素的坐标。坐标在二维张量中返回,其中第一个维度(行)表示真实元素的数量,第二个维度(列)表示真实元素的坐标。记住,输出张量的形状可以根据输入中有多少个真值而变化。索引按行主顺序输出。如果两者都是非零,则x和y必须具有相同的形状。如果x和y是标量,条件张量必须是标量。如果x和y是更高秩的向量,那么条件必须是大小与x的第一个维度匹配的向量,或者必须具有与x相同的形状。条件张量充当一个掩码,它根据每个元素的值选择输出中对应的元素/行是来自x(如果为真)还是来自y(如果为假)。如果条件是一个向量,x和y是高秩矩阵,那么它选择从x和y复制哪一行(外维),如果条件与x和y形状相同,那么它选择从x和y复制哪一个元素。
参数:
- condition: bool类型的张量
- x: 一个张量,它的形状可能和条件相同。如果条件为秩1,x的秩可能更高,但是它的第一个维度必须与条件的大小匹配
- y: 与x形状和类型相同的张量
- name: 操作的名称(可选)
返回值:
- 一个与x, y相同类型和形状的张量,如果它们是非零的话。一个带形状(num_true, dim_size(condition))的张量。
异常:
ValueError
: When exactly one ofx
ory
is non-None.
原链接: https://tensorflow.google.cn/versions/r1.9/api_docs/python/tf/where?hl=en