Commit 34d68f24 authored by W wwhu

fix doc and code style

Parent 65451b5c
paddle/operators/clip_by_norm_op.cc:

@@ -39,15 +39,14 @@ template <typename AttrType>
 class ClipByNormOpMaker : public framework::OpProtoAndCheckerMaker {
  public:
   ClipByNormOpMaker(framework::OpProto* proto,
                     framework::OpAttrChecker* op_checker)
       : OpProtoAndCheckerMaker(proto, op_checker) {
     AddInput("X",
-             "(Tensor)The input of clip_by_norm op."
+             "(Tensor) The input of clip_by_norm op."
              "The number of dimensions must be between [1, 9].");
     AddOutput("Out",
-              "(Tensor)The output of clip_by_norm op with shape as input(X)");
-    AddAttr<AttrType>(
-        "max_norm", "(float)The maximum norm value.");
+              "(Tensor) The output of clip_by_norm op with shape as input(X)");
+    AddAttr<AttrType>("max_norm", "(float)The maximum norm value.");
     AddComment(R"DOC(
 ClipByNorm operator limits the L2 norm of the input 'X' within 'max_norm'.
 If the L2 norm of 'X' is less than or equal to 'max_norm', 'Out' will be
@@ -62,29 +61,11 @@ where norm('X') represents the L2 norm of 'X'.
   }
 };
 
-class ClipByNormOpGrad : public framework::OperatorWithKernel {
- public:
-  using framework::OperatorWithKernel::OperatorWithKernel;
-
- protected:
-  void InferShape(framework::InferShapeContext* ctx) const override {
-    PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) should not be null");
-    PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")),
-                   "Input(Out@GRAD) should not be null");
-    auto x_dims = ctx->GetInputDim("X");
-    if (ctx->HasOutput(framework::GradVarName("X"))) {
-      ctx->SetOutputDim(framework::GradVarName("X"), x_dims);
-    }
-  }
-};
-
 }  // namespace operators
 }  // namespace paddle
 
 namespace ops = paddle::operators;
-REGISTER_OP_WITHOUT_GRADIENT(clip_by_norm,
-                             ops::ClipByNormOp,
-                             ops::ClipByNormOpMaker<float>);
-REGISTER_OP_CPU_KERNEL(clip_by_norm,
-                       ops::ClipByNormKernel
-                           <paddle::platform::CPUPlace, float>);
+REGISTER_OP_WITHOUT_GRADIENT(clip_by_norm, ops::ClipByNormOp,
+                             ops::ClipByNormOpMaker<float>);
+REGISTER_OP_CPU_KERNEL(
+    clip_by_norm, ops::ClipByNormKernel<paddle::platform::CPUPlace, float>);
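
For readers skimming the diff, the clipping rule stated in the DOC string above (and checked by the Python unit test further below) can be sketched in plain NumPy. This is only an illustrative reference, not code from this commit; the helper name clip_by_norm_reference is hypothetical.

# Reference sketch of clip-by-norm: if the L2 norm of x exceeds max_norm,
# rescale x so its norm equals max_norm; otherwise return x unchanged.
import numpy as np

def clip_by_norm_reference(x, max_norm):
    norm = np.sqrt(np.sum(np.square(x)))  # L2 norm over all elements
    if norm > max_norm:
        return max_norm * x / norm        # rescaled output has norm == max_norm
    return x                              # already within the limit

x = np.random.random((100,)).astype("float32")
out = clip_by_norm_reference(x, max_norm=1.0)
assert np.sqrt(np.sum(out ** 2)) <= 1.0 + 1e-5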
paddle/operators/clip_by_norm_op.cu:

@@ -15,6 +15,5 @@
 #include "paddle/operators/clip_by_norm_op.h"
 
 namespace ops = paddle::operators;
-REGISTER_OP_GPU_KERNEL(clip_by_norm,
-                       ops::ClipByNormKernel
-                           <paddle::platform::GPUPlace, float>);
+REGISTER_OP_GPU_KERNEL(
+    clip_by_norm, ops::ClipByNormKernel<paddle::platform::GPUPlace, float>);
paddle/operators/clip_by_norm_op.h:

@@ -25,9 +25,6 @@ using Tensor = framework::Tensor;
 template <typename T, int MajorType = Eigen::RowMajor,
           typename IndexType = Eigen::DenseIndex>
 using EigenVector = framework::EigenVector<T, MajorType, IndexType>;
-template <typename T, int MajorType = Eigen::RowMajor,
-          typename IndexType = Eigen::DenseIndex>
-using EigenScalar = framework::EigenScalar<T, MajorType, IndexType>;
 
 template <typename Place, typename T>
 class ClipByNormKernel : public framework::OpKernel<T> {
...
test_clip_by_norm_op.py:

@@ -18,21 +18,19 @@ class TestClipByNormOp(OpTest):
             output = self.max_norm * input / norm
         else:
             output = input
-        self.outputs = {
-            'Out': output
-        }
+        self.outputs = {'Out': output}
 
     def test_check_output(self):
         self.check_output()
 
     def initTestCase(self):
-        self.shape = (100,)
+        self.shape = (100, )
         self.max_norm = 1.0
 
 
 class TestCase1(TestClipByNormOp):
     def initTestCase(self):
-        self.shape = (100,)
+        self.shape = (100, )
         self.max_norm = 1e20
...