未验证 提交 d9780a22 编写于 作者: A arlesniak 提交者: GitHub

Fix for wrong conditions between forward and backward in elementwise_add_grad op (#38176)

上级 a4afb97a
...@@ -319,16 +319,7 @@ class ElementwiseOpGrad : public framework::OperatorWithKernel { ...@@ -319,16 +319,7 @@ class ElementwiseOpGrad : public framework::OperatorWithKernel {
ctx, framework::GradVarName("Out")); ctx, framework::GradVarName("Out"));
#ifdef PADDLE_WITH_MKLDNN #ifdef PADDLE_WITH_MKLDNN
// If broadcasting is needed, use native implementation if (this->CanMKLDNNBeUsed(ctx, input_data_type)) {
auto CanMKLDNNElementwiseGradBeUsed = [&]() {
auto dx_dims = ctx.Input<Tensor>("X")->dims();
auto dy_dims = ctx.Input<Tensor>("Y")->dims();
// No broadcast or broadcasting of data on inner dims is supported
return (dx_dims[dx_dims.size() - 1] == dy_dims[dy_dims.size() - 1]);
};
if (this->CanMKLDNNBeUsed(ctx, input_data_type) &&
CanMKLDNNElementwiseGradBeUsed()) {
return framework::OpKernelType(input_data_type, ctx.GetPlace(), return framework::OpKernelType(input_data_type, ctx.GetPlace(),
framework::DataLayout::kMKLDNN, framework::DataLayout::kMKLDNN,
framework::LibraryType::kMKLDNN); framework::LibraryType::kMKLDNN);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册