未验证 提交 ca68b13f 编写于 作者: J Jacek Czaja 提交者: GitHub

[oneDNN] Fix to elementwise_add grad (#24639)

上级 824572c1
...@@ -104,7 +104,7 @@ class ElementwiseOp : public framework::OperatorWithKernel { ...@@ -104,7 +104,7 @@ class ElementwiseOp : public framework::OperatorWithKernel {
int axis = ctx.Attr<int>("axis"); int axis = ctx.Attr<int>("axis");
int rankdiff = ctx.Input<Tensor>("X")->dims().size() - int rankdiff = ctx.Input<Tensor>("X")->dims().size() -
ctx.Input<Tensor>("Y")->dims().size(); ctx.Input<Tensor>("Y")->dims().size();
return (axis == -1) || (axis == rankdiff); return (rankdiff == 0) || (axis == -1) || (axis == rankdiff);
}; };
if (platform::CanMKLDNNBeUsed(ctx) && if (platform::CanMKLDNNBeUsed(ctx) &&
...@@ -243,9 +243,7 @@ class ElementwiseOpGrad : public framework::OperatorWithKernel { ...@@ -243,9 +243,7 @@ class ElementwiseOpGrad : public framework::OperatorWithKernel {
#ifdef PADDLE_WITH_MKLDNN #ifdef PADDLE_WITH_MKLDNN
// If broadcasting is needed, use native implementation // If broadcasting is needed, use native implementation
auto CanMKLDNNElementwiseAddGradBeUsed = [&]() { auto CanMKLDNNElementwiseAddGradBeUsed = [&]() {
auto dx = ctx.Output<Tensor>(framework::GradVarName("X")); return (ctx.Input<Tensor>("X")->dims() == ctx.Input<Tensor>("Y")->dims());
auto dy = ctx.Output<Tensor>(framework::GradVarName("Y"));
return (dx != nullptr && dy != nullptr && dx->dims() == dy->dims());
}; };
if (platform::CanMKLDNNBeUsed(ctx) && if (platform::CanMKLDNNBeUsed(ctx) &&
......
...@@ -85,6 +85,7 @@ class EltwiseAddMKLDNNGradKernel : public ElemwiseGradKernel<T> { ...@@ -85,6 +85,7 @@ class EltwiseAddMKLDNNGradKernel : public ElemwiseGradKernel<T> {
in->set_format(out->format()); in->set_format(out->format());
}; };
// TODO(jczaja): Double check if vcopy works for blocked data
auto blas = math::GetBlas<paddle::platform::CPUDeviceContext, T>(ctx); auto blas = math::GetBlas<paddle::platform::CPUDeviceContext, T>(ctx);
if (dx) { if (dx) {
blas.VCOPY(dout->numel(), dout->data<T>(), blas.VCOPY(dout->numel(), dout->data<T>(),
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册