未验证 提交 7a73692b 编写于 作者: C Chen Weihang 提交者: GitHub

normalized custom operator impl (#32666)

上级 10c493a8
...@@ -246,7 +246,7 @@ class CustomOperator : public OperatorWithKernel { ...@@ -246,7 +246,7 @@ class CustomOperator : public OperatorWithKernel {
* it can only be determined at runtime. * it can only be determined at runtime.
*/ */
framework::OpKernelType GetExpectedKernelType( framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const { const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(proto::VarType::RAW, ctx.GetPlace()); return framework::OpKernelType(proto::VarType::RAW, ctx.GetPlace());
} }
...@@ -257,7 +257,7 @@ class CustomOperator : public OperatorWithKernel { ...@@ -257,7 +257,7 @@ class CustomOperator : public OperatorWithKernel {
*/ */
framework::OpKernelType GetKernelTypeForVar( framework::OpKernelType GetKernelTypeForVar(
const std::string& var_name, const Tensor& tensor, const std::string& var_name, const Tensor& tensor,
const OpKernelType& expected_kernel_type) { const OpKernelType& expected_kernel_type) const override {
return OpKernelType(expected_kernel_type.data_type_, return OpKernelType(expected_kernel_type.data_type_,
expected_kernel_type.place_, tensor.layout()); expected_kernel_type.place_, tensor.layout());
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册