WhereKernel的一些疑问
Created by: miemie2013
大佬们好,我在看WhereKernel源码时,遇到3个问题:
template <typename DeviceContext, typename T>
class WhereKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
auto* condition = context.Input<framework::Tensor>("Condition");
auto* X = context.Input<framework::Tensor>("X");
auto* Y = context.Input<framework::Tensor>("Y");
auto* out = context.Output<framework::Tensor>("Out");
const bool* cond_data = condition->data<bool>();
const T* x_data = X->data<T>();
const T* y_data = Y->data<T>();
T* out_data = out->mutable_data<T>(context.GetPlace());
auto x_numel = X->numel();
for (int i = 0; i < x_numel; i++) {
// 条件为True,取x_data[i]。否则取y_data[i]
out_data[i] = cond_data[i] ? x_data[i] : y_data[i];
}
}
};
根据官方文档, paddle.fluid.layers.where(condition) 该OP计算输入元素中为True的元素在输入中的坐标(index)。 参数: condition (Variable)– 输入秩至少为1的多维Tensor,数据类型是bool类型。 返回: 输出condition元素为True的坐标(index),将所有的坐标(index)组成一个2-D的Tensor。
第一个问题,paddle.fluid.layers.where()参数里没有X和Y,那X和Y表示什么呢? 第二个问题,返回的是坐标,但我感觉怎么好像是out_data[i]被写入了x_data[i]或者y_data[i],难道它们是坐标? 第三个问题,返回的是一个2-D的Tensor,形状是(M, N),其中M是符合条件的元素的个数,N是坐标的维数。但我感觉out_data的长度好像是固定了一样,只是不符合条件的位置被写入了 y_data[i],符合条件的位置被写入了x_data[i],就变成坐标这一步是在哪完成了呢?谢谢!