Commit 264b6447 authored by D dongzhihong

"add rowwise add backward op"

Parent 7e60706b
@@ -46,6 +46,17 @@ for i in xrange(X.shape[0]):
)DOC");
}
};
class RowWiseAddGradOp : public OperatorWithKernel {
 protected:
  void InferShape(const InferShapeContext &ctx) const override {
    PADDLE_ENFORCE(ctx.InputSize() == 4UL,
                   "RowWiseAddGrad inputs are [X, b, Out, OutGrad], size must be 4");
    PADDLE_ENFORCE(ctx.OutputSize() == 2UL,
                   "RowWiseAddGrad outputs are [XGrad, bGrad], size must be 2");
    // XGrad takes X's shape; bGrad takes b's shape.
    ctx.Output<Tensor>(0)->Resize(ctx.Input<Tensor>(0)->dims());
    ctx.Output<Tensor>(1)->Resize(ctx.Input<Tensor>(1)->dims());
  }
};
} // namespace operators
} // namespace paddle
@@ -53,3 +64,7 @@ for i in xrange(X.shape[0]):
REGISTER_OP(rowwise_add, ops::RowWiseAddOp, ops::RowWiseAddOpMaker);
REGISTER_OP_CPU_KERNEL(rowwise_add,
ops::RowWiseAddKernel<ops::CPUPlace, float>);
REGISTER_GRADIENT_OP(rowwise_add, rowwise_add_grad, ops::RowWiseAddGradOp);
REGISTER_OP_CPU_KERNEL(rowwise_add_grad,
ops::RowWiseAddGradKernel<ops::CPUPlace, float>);
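
Note: under the gradient-op convention this commit relies on, the backward op receives the forward op's inputs (I), outputs (O), and output gradients (OG), in that order, and produces the input gradients (IG). A minimal sketch of that index layout, using hypothetical enum names that are illustrative only, not framework API:

// Hypothetical names for the index layout RowWiseAddGradOp relies on;
// these enums are illustrative only, not part of the framework.
enum GradInput { kX = 0, kB = 1, kOut = 2, kOutGrad = 3 };  // I, O, OG
enum GradOutput { kXGrad = 0, kBGrad = 1 };                 // IG

// InferShape checks InputSize() == 4 and OutputSize() == 2 against this
// layout, then resizes XGrad to X's dims and bGrad to b's dims.
static_assert(kOutGrad == 3, "OutGrad is the last of the four inputs");
static_assert(kBGrad == 1, "bGrad is the second of the two outputs");
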
@@ -38,5 +38,24 @@ public:
}
};
template <typename Place, typename T>
class RowWiseAddGradKernel : public OpKernel {
 public:
  void Compute(const ExecutionContext& context) const override {
    auto XGrad = context.Output<Tensor>(0);
    auto bGrad = context.Output<Tensor>(1);
    XGrad->mutable_data<T>(context.GetPlace());
    bGrad->mutable_data<T>(context.GetPlace());

    // Gradient op inputs follow [X, b, Out, OutGrad]; index 3 is OutGrad.
    auto OutGrad = EigenMatrix<T>::From(*context.Input<Tensor>(3));

    // dX = dOut: the forward op is identity in X, so the gradient passes
    // straight through.
    EigenMatrix<T>::From(*XGrad).device(*(context.GetEigenDevice<Place>())) =
        OutGrad;

    // db = sum of dOut over rows: b is broadcast across rows in the forward
    // pass, so its gradient reduces over the row dimension. This must be a
    // reduction (sum over dimension 0), not cumsum.
    // https://eigen.tuxfamily.org/dox/unsupported/TensorBase_8h_source.html
    EigenVector<T>::Flatten(*bGrad).device(*(context.GetEigenDevice<Place>())) =
        OutGrad.sum(Eigen::array<int, 1>({{0}}));
  }
};
} // namespace operators
} // namespace paddle
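
For reference, the math the kernel above implements: with Out = X + b, where b is broadcast over the rows of X, backpropagation gives dX = dOut and db = the column-wise sum of dOut. A minimal standalone sketch using plain dense Eigen (outside the operator framework) that checks those values:

#include <Eigen/Dense>
#include <iostream>

int main() {
  // Upstream gradient for a 2x3 output.
  Eigen::MatrixXf dOut(2, 3);
  dOut << 1, 2, 3,
          4, 5, 6;

  // dX = dOut: rowwise add is identity in X.
  Eigen::MatrixXf dX = dOut;

  // db = column-wise sum of dOut: b was broadcast across rows.
  Eigen::RowVectorXf db = dOut.colwise().sum();

  std::cout << "db = " << db << "\n";  // expect: 5 7 9
  return 0;
}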