未验证 提交 c52fe48f 编写于 作者: W wangchaochaohu 提交者: GitHub

fix the GetKernelTypeForVar of input for fluid.gather (#28534)

上级 621b31c5
...@@ -69,7 +69,11 @@ class GatherOp : public framework::OperatorWithKernel { ...@@ -69,7 +69,11 @@ class GatherOp : public framework::OperatorWithKernel {
framework::OpKernelType GetKernelTypeForVar( framework::OpKernelType GetKernelTypeForVar(
const std::string& var_name, const framework::Tensor& tensor, const std::string& var_name, const framework::Tensor& tensor,
const framework::OpKernelType& expected_kernel_type) const override { const framework::OpKernelType& expected_kernel_type) const override {
return expected_kernel_type; if (var_name == "Axis") {
return expected_kernel_type;
}
return framework::OpKernelType(expected_kernel_type.data_type_,
tensor.place(), tensor.layout());
} }
}; };
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册