提交 e57bc4d7 编写于 作者: S sneaxiy

Merge branch 'refine_elementwise_add' of https://github.com/sneaxiy/Paddle...

Merge branch 'refine_elementwise_add' of https://github.com/sneaxiy/Paddle into refine_elementwise_add
......@@ -141,8 +141,6 @@ class ElementwiseAddGradKernel : public framework::OpKernel<T> {
void Compute(const framework::ExecutionContext& ctx) const override {
using Tensor = framework::Tensor;
auto* x = ctx.Input<Tensor>("X");
auto* y = ctx.Input<Tensor>("Y");
auto* dout = ctx.Input<Tensor>(framework::GradVarName("Out"));
auto* dx = ctx.Output<Tensor>(framework::GradVarName("X"));
auto* dy = ctx.Output<Tensor>(framework::GradVarName("Y"));
......@@ -150,13 +148,13 @@ class ElementwiseAddGradKernel : public framework::OpKernel<T> {
if (dx != nullptr) dx->ShareDataWith(*dout);
if (dy == nullptr) return;
if (x->dims() == y->dims()) {
const framework::DDim& x_dim = dout->dims();
framework::DDim y_dim = dy->dims();
if (x_dim == y_dim) {
dy->ShareDataWith(*dout);
} else {
dy->mutable_data<T>(ctx.GetPlace());
// Perform reduction to dout to calculate dy
const framework::DDim& x_dim = x->dims();
framework::DDim y_dim = y->dims();
int axis = ctx.Attr<int>("axis");
axis = (axis == -1 ? x_dim.size() - y_dim.size() : axis);
y_dim = trim_trailing_singular_dims(y_dim);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册