From e12d1a1ce04275e4d91788a0482f3c0ebcfab609 Mon Sep 17 00:00:00 2001 From: sweetsky0901 Date: Mon, 25 Dec 2017 16:48:54 +0800 Subject: [PATCH] for esp data type --- paddle/operators/norm_op.cc | 20 ++------------------ paddle/operators/norm_op.cu | 4 ++-- 2 files changed, 4 insertions(+), 20 deletions(-) diff --git a/paddle/operators/norm_op.cc b/paddle/operators/norm_op.cc index 1d9b55d887..b198b76cd4 100644 --- a/paddle/operators/norm_op.cc +++ b/paddle/operators/norm_op.cc @@ -55,14 +55,6 @@ class NormOpMaker : public framework::OpProtoAndCheckerMaker { }; class NormOp : public framework::OperatorWithKernel { - protected: - framework::OpKernelType GetKernelType( - const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType( - framework::ToDataType(ctx.Input("X")->type()), - ctx.device_context()); - } - public: using framework::OperatorWithKernel::OperatorWithKernel; void InferShape(framework::InferShapeContext* ctx) const override { @@ -80,14 +72,6 @@ class NormOp : public framework::OperatorWithKernel { }; class NormOpGrad : public framework::OperatorWithKernel { - protected: - framework::OpKernelType GetKernelType( - const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType( - framework::ToDataType(ctx.Input("X")->type()), - ctx.device_context()); - } - public: using framework::OperatorWithKernel::OperatorWithKernel; void InferShape(framework::InferShapeContext* ctx) const override { @@ -105,7 +89,7 @@ REGISTER_OP(norm, ops::NormOp, ops::NormOpMaker, norm_grad, ops::NormOpGrad); REGISTER_OP_CPU_KERNEL( norm, ops::NormKernel, - ops::NormKernel); + ops::NormKernel); REGISTER_OP_CPU_KERNEL( norm_grad, ops::NormGradKernel, - ops::NormGradKernel); + ops::NormGradKernel); diff --git a/paddle/operators/norm_op.cu b/paddle/operators/norm_op.cu index 7d84aaa732..2941c89b93 100644 --- a/paddle/operators/norm_op.cu +++ b/paddle/operators/norm_op.cu @@ -18,7 +18,7 @@ limitations under the License. */ namespace ops = paddle::operators; REGISTER_OP_CUDA_KERNEL( norm, ops::NormKernel, - ops::NormKernel); + ops::NormKernel); REGISTER_OP_CUDA_KERNEL( norm_grad, ops::NormGradKernel, - ops::NormGradKernel); + ops::NormGradKernel); -- GitLab