From 11487de9eb242c8216fb70e29617a067cc3df3ed Mon Sep 17 00:00:00 2001 From: Abhinav Arora Date: Sun, 8 Apr 2018 23:18:22 -0700 Subject: [PATCH] Resolve conflict --- .../fluid/operators/activation_mkldnn_op.cc | 35 ++++++++++++------- 1 file changed, 22 insertions(+), 13 deletions(-) diff --git a/paddle/fluid/operators/activation_mkldnn_op.cc b/paddle/fluid/operators/activation_mkldnn_op.cc index 5b6390ab73..ab7c612271 100644 --- a/paddle/fluid/operators/activation_mkldnn_op.cc +++ b/paddle/fluid/operators/activation_mkldnn_op.cc @@ -40,20 +40,24 @@ void eltwise_forward(const ExecContext &ctx, mkldnn::algorithm algorithm, const T *dst_data = dst->template mutable_data(ctx.GetPlace()); // get memory dim - PADDLE_ENFORCE(src->dims().size() == 4, - "Input dim must be with 4, i.e. NCHW"); + PADDLE_ENFORCE(src->dims().size() == 2 || src->dims().size() == 4, + "Input dim must be with 2 or 4"); std::vector src_tz = framework::vectorize2int(src->dims()); // create memory description - // TODO(kbinias-intel): support more formats - auto data_md = platform::MKLDNNMemDesc(src_tz, mkldnn::memory::f32, - mkldnn::memory::format::nchw); + auto data_md = src_tz.size() == 2 + ? platform::MKLDNNMemDesc(src_tz, mkldnn::memory::f32, + mkldnn::memory::format::nc) + : platform::MKLDNNMemDesc(src_tz, mkldnn::memory::f32, + mkldnn::memory::format::nchw); // create memory primitives auto src_memory = - mkldnn::memory({data_md, mkldnn_engine}, static_cast(src_data)); + mkldnn::memory({data_md, mkldnn_engine}, + static_cast(const_cast(src_data))); auto dst_memory = - mkldnn::memory({data_md, mkldnn_engine}, static_cast(dst_data)); + mkldnn::memory({data_md, mkldnn_engine}, + static_cast(const_cast(dst_data))); auto forward_desc = mkldnn::eltwise_forward::desc( mkldnn::prop_kind::forward_training, algorithm, data_md, alpha, beta); @@ -93,16 +97,21 @@ void eltwise_grad(const ExecContext &ctx, mkldnn::algorithm algorithm, std::vector src_tz = framework::vectorize2int(x->dims()); // create memory description - auto data_md = platform::MKLDNNMemDesc(src_tz, mkldnn::memory::f32, - mkldnn::memory::format::nchw); + auto data_md = src_tz.size() == 2 + ? platform::MKLDNNMemDesc(src_tz, mkldnn::memory::f32, + mkldnn::memory::format::nc) + : platform::MKLDNNMemDesc(src_tz, mkldnn::memory::f32, + mkldnn::memory::format::nchw); // create memory primitives - auto src_memory = - mkldnn::memory({data_md, mkldnn_engine}, static_cast(src)); + auto src_memory = mkldnn::memory( + {data_md, mkldnn_engine}, static_cast(const_cast(src))); auto diff_src_memory = - mkldnn::memory({data_md, mkldnn_engine}, static_cast(diff_src)); + mkldnn::memory({data_md, mkldnn_engine}, + static_cast(const_cast(diff_src))); auto diff_dst_memory = - mkldnn::memory({data_md, mkldnn_engine}, static_cast(diff_dst)); + mkldnn::memory({data_md, mkldnn_engine}, + static_cast(const_cast(diff_dst))); auto backward_desc = mkldnn::eltwise_backward::desc(algorithm, data_md, data_md, alpha, beta); -- GitLab