diff --git a/paddle/operators/clip_by_norm_op.cc b/paddle/operators/clip_by_norm_op.cc index ebb7bdda558e8b951755e9e0417ddcf332c75364..d9fc532e39500fa397be80396b075e866bad9362 100644 --- a/paddle/operators/clip_by_norm_op.cc +++ b/paddle/operators/clip_by_norm_op.cc @@ -27,7 +27,7 @@ class ClipByNormOp : public framework::OperatorWithKernel { "Input(X) of ClipByNormOp should not be null."); PADDLE_ENFORCE(ctx->HasOutput("Out"), "Output(Out) of ClipByNormOp should not be null."); - auto max_norm = Attr("max_norm"); + auto max_norm = ctx->Attrs().Get("max_norm"); PADDLE_ENFORCE_GT(max_norm, 0, "max_norm should be greater than 0."); auto x_dims = ctx->GetInputDim("X"); ctx->SetOutputDim("Out", x_dims); @@ -35,7 +35,6 @@ class ClipByNormOp : public framework::OperatorWithKernel { } }; -template class ClipByNormOpMaker : public framework::OpProtoAndCheckerMaker { public: ClipByNormOpMaker(framework::OpProto* proto, @@ -46,7 +45,7 @@ class ClipByNormOpMaker : public framework::OpProtoAndCheckerMaker { "The number of dimensions must be between [1, 9]."); AddOutput("Out", "(Tensor) The output of clip_by_norm op with shape as input(X)"); - AddAttr("max_norm", "(float) The maximum norm value."); + AddAttr("max_norm", "(float) The maximum norm value."); AddComment(R"DOC( ClipByNorm operator limits the L2 norm of the input 'X' within 'max_norm'. If the L2 norm of 'X' is less than or equal to 'max_norm', 'Out' will be @@ -66,6 +65,6 @@ where norm('X') represents the L2 norm of 'X'. namespace ops = paddle::operators; REGISTER_OP_WITHOUT_GRADIENT(clip_by_norm, ops::ClipByNormOp, - ops::ClipByNormOpMaker); + ops::ClipByNormOpMaker); REGISTER_OP_CPU_KERNEL( clip_by_norm, ops::ClipByNormKernel);