Commit 12ee5014 authored by dongzhihong

"fix operator grad config"

Parent 43ba24e0
@@ -17,6 +17,8 @@
 namespace paddle {
 namespace operators {
 
+using framework::Tensor;
+
 class RowwiseAddOp : public framework::OperatorWithKernel {
  public:
  using framework::OperatorWithKernel::OperatorWithKernel;
@@ -50,14 +52,23 @@ for i in xrange(X.shape[0]):
   }
 };
 
 class RowwiseAddGradOp : public framework::OperatorWithKernel {
  public:
   using framework::OperatorWithKernel::OperatorWithKernel;
 
  protected:
   void InferShape(const framework::InferShapeContext &ctx) const override {
-    PADDLE_ENFORCE(ctx.InputSize() == 4UL,
-                   "RowwiseAddGrad inputs is I, O, OG, size must be 4");
-    PADDLE_ENFORCE(ctx.OutputSize() == 2,
-                   "RowwiseAddGrad output is IG, size must be 2");
-    ctx.Output<Tensor>(0)->Resize(ctx.Input<Tensor>(0)->dims());
-    ctx.Output<Tensor>(1)->Resize(ctx.Input<Tensor>(1)->dims());
+    // PADDLE_ENFORCE(ctx.InputSize() == 4UL,
+    //                "RowwiseAddGrad inputs is I, O, OG, size must be 4");
+    // PADDLE_ENFORCE(ctx.OutputSize() == 2,
+    //                "RowwiseAddGrad output is IG, size must be 2");
+    PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X"), "X should not be null");
+    PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("b"), "b should not be null");
+    PADDLE_ENFORCE_NOT_NULL(ctx.InputVar(framework::GradVarName("Out")),
+                            "Input(Out@GRAD) should not be null");
+    auto dims0 = ctx.Input<Tensor>("X")->dims();
+    auto dims1 = ctx.Input<Tensor>("b")->dims();
+    ctx.Output<Tensor>(framework::GradVarName("X"))->Resize(dims0);
+    ctx.Output<Tensor>(framework::GradVarName("b"))->Resize(dims1);
   }
 };
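The rewritten InferShape resolves the grad op's variables by name rather than by positional index, using framework::GradVarName to derive the gradient variable names. As a rough illustration only (the fixed "@GRAD" suffix and the standalone helper below are assumptions for this sketch, not code from this commit), the naming convention it relies on looks like this:

```cpp
#include <iostream>
#include <string>

// Assumed convention: the gradient of a variable is named after the variable
// plus a fixed "@GRAD" suffix; framework::GradVarName is treated here as
// doing exactly that.
std::string GradVarName(const std::string& name) { return name + "@GRAD"; }

int main() {
  // Variables the grad op reads: X, b, and the gradient of the forward output.
  std::cout << GradVarName("Out") << "\n";  // Out@GRAD (input to the grad op)
  // Variables it writes, resized in InferShape to match X and b respectively.
  std::cout << GradVarName("X") << "\n";    // X@GRAD
  std::cout << GradVarName("b") << "\n";    // b@GRAD
  return 0;
}
```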
@@ -51,19 +51,20 @@ template <typename Place, typename T>
 class RowwiseAddGradKernel : public framework::OpKernel {
  public:
   void Compute(const framework::ExecutionContext& context) const override {
-    auto* XGrad = context.Output<Tensor>(0);
-    auto* bGrad = context.Output<Tensor>(1);
-    XGrad->mutable_data<T>(context.GetPlace());
-    bGrad->mutable_data<T>(context.GetPlace());
+    auto* dX = context.Output<Tensor>(framework::GradVarName("X"));
+    auto* db = context.Output<Tensor>(framework::GradVarName("b"));
+    auto* dOut = context.Input<Tensor>(framework::GradVarName("Out"));
+    dX->mutable_data<T>(context.GetPlace());
+    db->mutable_data<T>(context.GetPlace());
 
-    // I, O, OG => [X, b], [Out], [OutGrad]
-    auto OutGrad = EigenMatrix<T>::From(*context.Input<Tensor>(3));
-    EigenMatrix<T>::From(*XGrad).device(context.GetEigenDevice<Place>()) =
-        OutGrad;
+    auto OutGrad = EigenMatrix<T>::From(*dOut);
+    auto place = context.GetEigenDevice<Place>();
+    EigenMatrix<T>::From(*dX).device(place) = OutGrad;
+
     // https://eigen.tuxfamily.org/dox/unsupported/TensorBase_8h_source.html
-    EigenVector<T>::Flatten(*bGrad).device(context.GetEigenDevice<Place>()) =
-        OutGrad.cumsum(1);  // colwise add
+    // colwise add
+    Eigen::array<int, 1> dims{{0}}; /* dimension to reduce */
+    EigenVector<T>::Flatten(*db).device(place) = OutGrad.sum(dims);
   }
 };
 
 }  // namespace operators
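For the kernel itself, the gradient of a rowwise add Out = X + b (with b broadcast over the rows of X) is dX = dOut and db = the per-column sum of dOut, i.e. a reduction over the row dimension. Below is a minimal standalone Eigen sketch, not taken from the repository and using made-up shapes and values, of the same reduction the new kernel expresses with OutGrad.sum(dims):

```cpp
#include <unsupported/Eigen/CXX11/Tensor>
#include <iostream>

int main() {
  // dOut: gradient flowing back from Out; shape [2 rows, 3 cols] chosen arbitrarily.
  Eigen::Tensor<float, 2> dOut(2, 3);
  dOut.setValues({{1, 2, 3}, {4, 5, 6}});

  // dX has the same shape as X and is simply a copy of dOut.
  Eigen::Tensor<float, 2> dX = dOut;

  // db has the same shape as b: sum dOut over dimension 0 (the rows).
  Eigen::array<int, 1> dims{{0}};
  Eigen::Tensor<float, 1> db = dOut.sum(dims);

  for (int j = 0; j < db.dimension(0); ++j) std::cout << db(j) << " ";
  std::cout << "\n";  // prints: 5 7 9
  return 0;
}
```

Dropping cumsum(1) in favor of an explicit sum reduction is the substantive part of this hunk: cumsum yields a running total with the same shape as OutGrad, while the bias gradient needs a single per-column total with the shape of b.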