提交 8ff3590e 编写于 作者: D dongzhihong

fix op name

上级 264b6447
...@@ -16,7 +16,7 @@ ...@@ -16,7 +16,7 @@
namespace paddle { namespace paddle {
namespace operators { namespace operators {
class RowWiseAddOp : public OperatorWithKernel { class RowwiseAddOp : public OperatorWithKernel {
protected: protected:
void InferShape(const InferShapeContext &ctx) const override { void InferShape(const InferShapeContext &ctx) const override {
PADDLE_ENFORCE(ctx.InputSize() == 2UL, PADDLE_ENFORCE(ctx.InputSize() == 2UL,
...@@ -32,9 +32,9 @@ protected: ...@@ -32,9 +32,9 @@ protected:
} }
}; };
class RowWiseAddOpMaker : public OpProtoAndCheckerMaker { class RowwiseAddOpMaker : public OpProtoAndCheckerMaker {
public: public:
RowWiseAddOpMaker(OpProto *proto, OpAttrChecker *op_checker) RowwiseAddOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) { : OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "The left input of row-wise add op, must be matrix"); AddInput("X", "The left input of row-wise add op, must be matrix");
AddInput("b", "The right input of row-wise add op, must be vector"); AddInput("b", "The right input of row-wise add op, must be vector");
...@@ -46,13 +46,13 @@ for i in xrange(X.shape[0]): ...@@ -46,13 +46,13 @@ for i in xrange(X.shape[0]):
)DOC"); )DOC");
} }
}; };
class RowWiseAddGradOp : public OperatorWithKernel { class RowwiseAddGradOp : public OperatorWithKernel {
protected: protected:
void InferShape(const InferShapeContext &ctx) const override { void InferShape(const InferShapeContext &ctx) const override {
PADDLE_ENFORCE(ctx.InputSize() == 4UL, PADDLE_ENFORCE(ctx.InputSize() == 4UL,
"RowWiseAddGrad inputs is I, O, OG, size must be 4"); "RowwiseAddGrad inputs is I, O, OG, size must be 4");
PADDLE_ENFORCE(ctx.OutputSize() == 2, PADDLE_ENFORCE(ctx.OutputSize() == 2,
"RowWiseAddGrad output is IG, size must be 2"); "RowwiseAddGrad output is IG, size must be 2");
ctx.Output<Tensor>(0)->Resize(ctx.Input<Tensor>(0)->dims()); ctx.Output<Tensor>(0)->Resize(ctx.Input<Tensor>(0)->dims());
ctx.Output<Tensor>(1)->Resize(ctx.Input<Tensor>(1)->dims()); ctx.Output<Tensor>(1)->Resize(ctx.Input<Tensor>(1)->dims());
} }
...@@ -61,10 +61,10 @@ protected: ...@@ -61,10 +61,10 @@ protected:
} // namespace operators } // namespace operators
} // namespace paddle } // namespace paddle
REGISTER_OP(rowwise_add, ops::RowWiseAddOp, ops::RowWiseAddOpMaker); REGISTER_OP(rowwise_add, ops::RowwiseAddOp, ops::RowwiseAddOpMaker);
REGISTER_OP_CPU_KERNEL(rowwise_add, REGISTER_OP_CPU_KERNEL(rowwise_add,
ops::RowWiseAddKernel<ops::CPUPlace, float>); ops::RowwiseAddKernel<ops::CPUPlace, float>);
REGISTER_GRADIENT_OP(rowwise_add, rowwise_add_grad, ops::RowWiseAddGradOp); REGISTER_GRADIENT_OP(rowwise_add, rowwise_add_grad, ops::RowwiseAddGradOp);
REGISTER_OP_CPU_KERNEL(rowwise_add_grad, REGISTER_OP_CPU_KERNEL(rowwise_add_grad,
ops::RowWiseAddGradKernel<ops::CPUPlace, float>); ops::RowwiseAddGradKernel<ops::CPUPlace, float>);
#include "paddle/operators/rowwise_add_op.h" #include "paddle/operators/rowwise_add_op.h"
REGISTER_OP_GPU_KERNEL(rowwise_add, REGISTER_OP_GPU_KERNEL(rowwise_add,
ops::RowWiseAddKernel<ops::GPUPlace, float>); ops::RowwiseAddKernel<ops::GPUPlace, float>);
...@@ -19,7 +19,7 @@ namespace paddle { ...@@ -19,7 +19,7 @@ namespace paddle {
namespace operators { namespace operators {
template <typename Place, typename T> template <typename Place, typename T>
class RowWiseAddKernel : public OpKernel { class RowwiseAddKernel : public OpKernel {
public: public:
void Compute(const ExecutionContext& context) const override { void Compute(const ExecutionContext& context) const override {
auto out = context.Output<Tensor>(0); auto out = context.Output<Tensor>(0);
...@@ -39,7 +39,7 @@ public: ...@@ -39,7 +39,7 @@ public:
}; };
template <typename Place, typename T> template <typename Place, typename T>
class RowWiseAddGradKernel : public OpKernel { class RowwiseAddGradKernel : public OpKernel {
public: public:
void Compute(const ExecutionContext& context) const override { void Compute(const ExecutionContext& context) const override {
auto XGrad = context.Output<Tensor>(0); auto XGrad = context.Output<Tensor>(0);
...@@ -51,7 +51,7 @@ public: ...@@ -51,7 +51,7 @@ public:
auto OutGrad = EigenMatrix<T>::From(*context.Input<Tensor>(3)); auto OutGrad = EigenMatrix<T>::From(*context.Input<Tensor>(3));
EigenMatrix<T>::From(*XGrad).device(*(context.GetEigenDevice<Place>())) = EigenMatrix<T>::From(*XGrad).device(*(context.GetEigenDevice<Place>())) =
OutGrad; OutGrad;
// const int dimension = bGrad.dimension(0);
// https://eigen.tuxfamily.org/dox/unsupported/TensorBase_8h_source.html // https://eigen.tuxfamily.org/dox/unsupported/TensorBase_8h_source.html
EigenVector<T>::Flatten(*bGrad).device(*(context.GetEigenDevice<Place>())) = EigenVector<T>::Flatten(*bGrad).device(*(context.GetEigenDevice<Place>())) =
OutGrad.cumsum(1); // colwise add OutGrad.cumsum(1); // colwise add
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册