未验证 提交 ca68b13f 编写于 作者: Jacek Czaja 提交者: GitHub

[oneDNN] Fix to elementwise_add grad (#24639)

上级 824572c1
......@@ -104,7 +104,7 @@ class ElementwiseOp : public framework::OperatorWithKernel {
int axis = ctx.Attr<int>("axis");
int rankdiff = ctx.Input<Tensor>("X")->dims().size() -
ctx.Input<Tensor>("Y")->dims().size();
return (axis == -1) || (axis == rankdiff);
return (rankdiff == 0) || (axis == -1) || (axis == rankdiff);
};
if (platform::CanMKLDNNBeUsed(ctx) &&
......@@ -243,9 +243,7 @@ class ElementwiseOpGrad : public framework::OperatorWithKernel {
#ifdef PADDLE_WITH_MKLDNN
// If broadcasting is needed, use native implementation
auto CanMKLDNNElementwiseAddGradBeUsed = [&]() {
auto dx = ctx.Output<Tensor>(framework::GradVarName("X"));
auto dy = ctx.Output<Tensor>(framework::GradVarName("Y"));
return (dx != nullptr && dy != nullptr && dx->dims() == dy->dims());
return (ctx.Input<Tensor>("X")->dims() == ctx.Input<Tensor>("Y")->dims());
};
if (platform::CanMKLDNNBeUsed(ctx) &&
......
......@@ -85,6 +85,7 @@ class EltwiseAddMKLDNNGradKernel : public ElemwiseGradKernel<T> {
in->set_format(out->format());
};
// TODO(jczaja): Double check if vcopy works for blocked data
auto blas = math::GetBlas<paddle::platform::CPUDeviceContext, T>(ctx);
if (dx) {
blas.VCOPY(dout->numel(), dout->data<T>(),
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册