提交 12ee5014 作者: dongzhihong

"fix operator grad config"

上级 43ba24e0
...@@ -17,6 +17,8 @@ ...@@ -17,6 +17,8 @@
namespace paddle { namespace paddle {
namespace operators { namespace operators {
using framework::Tensor;
class RowwiseAddOp : public framework::OperatorWithKernel { class RowwiseAddOp : public framework::OperatorWithKernel {
public: public:
using framework::OperatorWithKernel::OperatorWithKernel; using framework::OperatorWithKernel::OperatorWithKernel;
...@@ -50,14 +52,23 @@ for i in xrange(X.shape[0]): ...@@ -50,14 +52,23 @@ for i in xrange(X.shape[0]):
} }
}; };
class RowwiseAddGradOp : public framework::OperatorWithKernel { class RowwiseAddGradOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
protected: protected:
void InferShape(const framework::InferShapeContext &ctx) const override { void InferShape(const framework::InferShapeContext &ctx) const override {
PADDLE_ENFORCE(ctx.InputSize() == 4UL, // PADDLE_ENFORCE(ctx.InputSize() == 4UL,
"RowwiseAddGrad inputs is I, O, OG, size must be 4"); // "RowwiseAddGrad inputs is I, O, OG, size must be 4");
PADDLE_ENFORCE(ctx.OutputSize() == 2, // PADDLE_ENFORCE(ctx.OutputSize() == 2,
"RowwiseAddGrad output is IG, size must be 2"); // "RowwiseAddGrad output is IG, size must be 2");
ctx.Output<Tensor>(0)->Resize(ctx.Input<Tensor>(0)->dims()); PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X"), "X should not be null");
ctx.Output<Tensor>(1)->Resize(ctx.Input<Tensor>(1)->dims()); PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("b"), "b should not be null");
PADDLE_ENFORCE_NOT_NULL(ctx.InputVar(framework::GradVarName("Out")),
"Input(Out@GRAD) should not be null");
auto dims0 = ctx.Input<Tensor>("X")->dims();
auto dims1 = ctx.Input<Tensor>("b")->dims();
ctx.Output<Tensor>(framework::GradVarName("X"))->Resize(dims0);
ctx.Output<Tensor>(framework::GradVarName("b"))->Resize(dims1);
} }
}; };
......
...@@ -51,19 +51,20 @@ template <typename Place, typename T> ...@@ -51,19 +51,20 @@ template <typename Place, typename T>
class RowwiseAddGradKernel : public framework::OpKernel { class RowwiseAddGradKernel : public framework::OpKernel {
public: public:
void Compute(const framework::ExecutionContext& context) const override { void Compute(const framework::ExecutionContext& context) const override {
auto* XGrad = context.Output<Tensor>(0); auto* dX = context.Output<Tensor>(framework::GradVarName("X"));
auto* bGrad = context.Output<Tensor>(1); auto* db = context.Output<Tensor>(framework::GradVarName("b"));
XGrad->mutable_data<T>(context.GetPlace()); auto* dOut = context.Output<Tensor>(framework::GradVarName("Out"));
bGrad->mutable_data<T>(context.GetPlace()); dX->mutable_data<T>(context.GetPlace());
db->mutable_data<T>(context.GetPlace());
// I, O, OG => [X, b], [Out], [OutGrad] auto OutGrad = EigenMatrix<T>::From(*dOut);
auto OutGrad = EigenMatrix<T>::From(*context.Input<Tensor>(3)); auto place = context.GetEigenDevice<Place>();
EigenMatrix<T>::From(*XGrad).device(context.GetEigenDevice<Place>()) = EigenMatrix<T>::From(*dX).device(place) = OutGrad;
OutGrad;
// https://eigen.tuxfamily.org/dox/unsupported/TensorBase_8h_source.html // https://eigen.tuxfamily.org/dox/unsupported/TensorBase_8h_source.html
EigenVector<T>::Flatten(*bGrad).device(context.GetEigenDevice<Place>()) = // colwise add
OutGrad.cumsum(1); // colwise add Eigen::array<int, 1> dims{{1}}; /* dimension to reduce */
EigenVector<T>::Flatten(*db).device(place) = OutGrad.sum(dims);
} }
}; };
} // namespace operators } // namespace operators
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请注册