From e811c865677c80a5fce13d4bac7178bf0fa20d7b Mon Sep 17 00:00:00 2001 From: sweetsky0901 Date: Fri, 22 Dec 2017 18:36:40 +0800 Subject: [PATCH] for epsilon dataType --- paddle/operators/norm_op.cc | 10 ++++++---- paddle/operators/norm_op.h | 8 ++++---- 2 files changed, 10 insertions(+), 8 deletions(-) diff --git a/paddle/operators/norm_op.cc b/paddle/operators/norm_op.cc index 990a1504e..1d9b55d88 100644 --- a/paddle/operators/norm_op.cc +++ b/paddle/operators/norm_op.cc @@ -16,6 +16,7 @@ limitations under the License. */ namespace paddle { namespace operators { +template class NormOpMaker : public framework::OpProtoAndCheckerMaker { public: NormOpMaker(OpProto* proto, OpAttrChecker* op_checker) @@ -28,9 +29,9 @@ class NormOpMaker : public framework::OpProtoAndCheckerMaker { AddInput("Scale", "(Tensor) The input tensor of norm operator. " "The format of input tensor is C * 1."); - AddAttr("epsilon", - "(float, default 1e-10) Constant " - "for numerical stability.") + AddAttr("epsilon", + "(float, default 1e-10) Constant " + "for numerical stability.") .SetDefault(1.0e-10f); AddOutput("Out", "(Tensor) The output tensor of norm operator." @@ -100,7 +101,8 @@ class NormOpGrad : public framework::OperatorWithKernel { } // namespace paddle namespace ops = paddle::operators; -REGISTER_OP(norm, ops::NormOp, ops::NormOpMaker, norm_grad, ops::NormOpGrad); +REGISTER_OP(norm, ops::NormOp, ops::NormOpMaker, norm_grad, + ops::NormOpGrad); REGISTER_OP_CPU_KERNEL( norm, ops::NormKernel, ops::NormKernel); diff --git a/paddle/operators/norm_op.h b/paddle/operators/norm_op.h index d3dcf4834..b22df373a 100644 --- a/paddle/operators/norm_op.h +++ b/paddle/operators/norm_op.h @@ -26,14 +26,14 @@ template using EigenMatrix = framework::EigenMatrix; -template +template class NormKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& context) const override { const framework::Tensor* in_x = context.Input("X"); const framework::Tensor* scale = context.Input("Scale"); auto* out = context.Output("Out"); - T epsilon = context.Attr("epsilon"); + auto epsilon = static_cast(context.Attr("epsilon")); out->mutable_data(context.GetPlace()); int batch_size = in_x->dims()[0]; int channels = in_x->dims()[1]; @@ -82,7 +82,7 @@ class NormKernel : public framework::OpKernel { } } }; -template +template class NormGradKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& context) const override { @@ -90,7 +90,7 @@ class NormGradKernel : public framework::OpKernel { const framework::Tensor* scale = context.Input("Scale"); const framework::Tensor* out_grad = context.Input(framework::GradVarName("Out")); - T epsilon = context.Attr("epsilon"); + auto epsilon = static_cast(context.Attr("epsilon")); framework::Tensor* in_x_grad = context.Output(framework::GradVarName("X")); in_x_grad->mutable_data(context.GetPlace()); -- GitLab