diff --git a/paddle/operators/norm_op.cc b/paddle/operators/norm_op.cc index 990a1504ea34e7422518fb708a1ee920f30daadc..1d9b55d887ed1bd980dc208dc0154c370925890f 100644 --- a/paddle/operators/norm_op.cc +++ b/paddle/operators/norm_op.cc @@ -16,6 +16,7 @@ limitations under the License. */ namespace paddle { namespace operators { +template class NormOpMaker : public framework::OpProtoAndCheckerMaker { public: NormOpMaker(OpProto* proto, OpAttrChecker* op_checker) @@ -28,9 +29,9 @@ class NormOpMaker : public framework::OpProtoAndCheckerMaker { AddInput("Scale", "(Tensor) The input tensor of norm operator. " "The format of input tensor is C * 1."); - AddAttr("epsilon", - "(float, default 1e-10) Constant " - "for numerical stability.") + AddAttr("epsilon", + "(float, default 1e-10) Constant " + "for numerical stability.") .SetDefault(1.0e-10f); AddOutput("Out", "(Tensor) The output tensor of norm operator." @@ -100,7 +101,8 @@ class NormOpGrad : public framework::OperatorWithKernel { } // namespace paddle namespace ops = paddle::operators; -REGISTER_OP(norm, ops::NormOp, ops::NormOpMaker, norm_grad, ops::NormOpGrad); +REGISTER_OP(norm, ops::NormOp, ops::NormOpMaker, norm_grad, + ops::NormOpGrad); REGISTER_OP_CPU_KERNEL( norm, ops::NormKernel, ops::NormKernel); diff --git a/paddle/operators/norm_op.h b/paddle/operators/norm_op.h index d3dcf4834162f7a57c9a33b51736e8f4d4446134..b22df373af1e96bf4ba25fa7e3663b4fa963ab71 100644 --- a/paddle/operators/norm_op.h +++ b/paddle/operators/norm_op.h @@ -26,14 +26,14 @@ template using EigenMatrix = framework::EigenMatrix; -template +template class NormKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& context) const override { const framework::Tensor* in_x = context.Input("X"); const framework::Tensor* scale = context.Input("Scale"); auto* out = context.Output("Out"); - T epsilon = context.Attr("epsilon"); + auto epsilon = static_cast(context.Attr("epsilon")); out->mutable_data(context.GetPlace()); int batch_size = in_x->dims()[0]; int channels = in_x->dims()[1]; @@ -82,7 +82,7 @@ class NormKernel : public framework::OpKernel { } } }; -template +template class NormGradKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& context) const override { @@ -90,7 +90,7 @@ class NormGradKernel : public framework::OpKernel { const framework::Tensor* scale = context.Input("Scale"); const framework::Tensor* out_grad = context.Input(framework::GradVarName("Out")); - T epsilon = context.Attr("epsilon"); + auto epsilon = static_cast(context.Attr("epsilon")); framework::Tensor* in_x_grad = context.Output(framework::GradVarName("X")); in_x_grad->mutable_data(context.GetPlace());