提交 9b501736 编写于 作者: Y Yihua Xu 提交者: Tao Luo

Fix the format issue when 'X' is not nchw. (#17833)

test=develop
上级 b888a4c5
...@@ -16,6 +16,7 @@ limitations under the License. */ ...@@ -16,6 +16,7 @@ limitations under the License. */
#include "paddle/fluid/operators/elementwise/elementwise_add_op.h" #include "paddle/fluid/operators/elementwise/elementwise_add_op.h"
#include "paddle/fluid/operators/elementwise/elementwise_op_function.h" #include "paddle/fluid/operators/elementwise/elementwise_op_function.h"
#include "paddle/fluid/framework/data_layout_transform.h"
#include "paddle/fluid/platform/mkldnn_helper.h" #include "paddle/fluid/platform/mkldnn_helper.h"
namespace paddle { namespace paddle {
...@@ -53,12 +54,45 @@ class EltwiseAddMKLDNNKernel : public framework::OpKernel<T> { ...@@ -53,12 +54,45 @@ class EltwiseAddMKLDNNKernel : public framework::OpKernel<T> {
// Execute default elementwise_add operator when // Execute default elementwise_add operator when
// broadcast operations need to performed. // broadcast operations need to performed.
if (x_dims != y_dims_untrimed) { if (x_dims != y_dims_untrimed) {
Tensor _x;
mkldnn::memory::format format;
std::vector<int> src_x_tz = framework::vectorize2int(x_dims);
if ((src_x_tz.size() == 3 &&
x->format() != (format = memory::format::ncw)) ||
(src_x_tz.size() == 4 &&
x->format() != (format = memory::format::nchw)) ||
(src_x_tz.size() == 5 &&
x->format() != (format = memory::format::ncdhw))) {
_x.Resize(x_dims);
auto user_x_memory_pd = memory::primitive_desc(
{{src_x_tz}, memory::data_type::f32, x->format()}, mkldnn_engine);
auto x_memory_pd = memory::primitive_desc(
{{src_x_tz}, memory::data_type::f32, format}, mkldnn_engine);
auto size = x_memory_pd.get_size();
_x.mutable_data<T>(ctx.GetPlace(), paddle::memory::Allocator::kDefault,
size);
auto user_x_memory =
memory(user_x_memory_pd, paddle::platform::to_void_cast<T>(x_data));
auto x_memory = memory(x_memory_pd,
paddle::platform::to_void_cast<T>(_x.data<T>()));
auto x_reorder = reorder(user_x_memory, x_memory);
std::vector<primitive> pipeline;
pipeline.push_back(x_reorder);
stream(stream::kind::eager).submit(pipeline).wait();
} else {
format = x->format();
_x.ShareDataWith(*x);
}
auto sum_func = [](T a, T b) -> T { return a + b; }; auto sum_func = [](T a, T b) -> T { return a + b; };
TransformFunctor<decltype(sum_func), T, TransformFunctor<decltype(sum_func), T,
paddle::platform::CPUDeviceContext, T> paddle::platform::CPUDeviceContext, T>
functor( functor(
x, y, z, &_x, y, z,
ctx.template device_context<paddle::platform::CPUDeviceContext>(), ctx.template device_context<paddle::platform::CPUDeviceContext>(),
sum_func); sum_func);
...@@ -78,7 +112,7 @@ class EltwiseAddMKLDNNKernel : public framework::OpKernel<T> { ...@@ -78,7 +112,7 @@ class EltwiseAddMKLDNNKernel : public framework::OpKernel<T> {
functor.RunMidWise(n, pre, post); functor.RunMidWise(n, pre, post);
} }
z->set_layout(DataLayout::kMKLDNN); z->set_layout(DataLayout::kMKLDNN);
z->set_format(x->format()); z->set_format(format);
} else { } else {
PADDLE_ENFORCE(x->layout() == DataLayout::kMKLDNN && PADDLE_ENFORCE(x->layout() == DataLayout::kMKLDNN &&
x->format() != memory::format::format_undef, x->format() != memory::format::format_undef,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册