提交 11487de9 编写于 作者: A Abhinav Arora

Resolve conflict

上级 2e2726f1
...@@ -40,20 +40,24 @@ void eltwise_forward(const ExecContext &ctx, mkldnn::algorithm algorithm, ...@@ -40,20 +40,24 @@ void eltwise_forward(const ExecContext &ctx, mkldnn::algorithm algorithm,
const T *dst_data = dst->template mutable_data<T>(ctx.GetPlace()); const T *dst_data = dst->template mutable_data<T>(ctx.GetPlace());
// get memory dim // get memory dim
PADDLE_ENFORCE(src->dims().size() == 4, PADDLE_ENFORCE(src->dims().size() == 2 || src->dims().size() == 4,
"Input dim must be with 4, i.e. NCHW"); "Input dim must be with 2 or 4");
std::vector<int> src_tz = framework::vectorize2int(src->dims()); std::vector<int> src_tz = framework::vectorize2int(src->dims());
// create memory description // create memory description
// TODO(kbinias-intel): support more formats auto data_md = src_tz.size() == 2
auto data_md = platform::MKLDNNMemDesc(src_tz, mkldnn::memory::f32, ? platform::MKLDNNMemDesc(src_tz, mkldnn::memory::f32,
mkldnn::memory::format::nchw); mkldnn::memory::format::nc)
: platform::MKLDNNMemDesc(src_tz, mkldnn::memory::f32,
mkldnn::memory::format::nchw);
// create memory primitives // create memory primitives
auto src_memory = auto src_memory =
mkldnn::memory({data_md, mkldnn_engine}, static_cast<void *>(src_data)); mkldnn::memory({data_md, mkldnn_engine},
static_cast<void *>(const_cast<float *>(src_data)));
auto dst_memory = auto dst_memory =
mkldnn::memory({data_md, mkldnn_engine}, static_cast<void *>(dst_data)); mkldnn::memory({data_md, mkldnn_engine},
static_cast<void *>(const_cast<float *>(dst_data)));
auto forward_desc = mkldnn::eltwise_forward::desc( auto forward_desc = mkldnn::eltwise_forward::desc(
mkldnn::prop_kind::forward_training, algorithm, data_md, alpha, beta); mkldnn::prop_kind::forward_training, algorithm, data_md, alpha, beta);
...@@ -93,16 +97,21 @@ void eltwise_grad(const ExecContext &ctx, mkldnn::algorithm algorithm, ...@@ -93,16 +97,21 @@ void eltwise_grad(const ExecContext &ctx, mkldnn::algorithm algorithm,
std::vector<int> src_tz = framework::vectorize2int(x->dims()); std::vector<int> src_tz = framework::vectorize2int(x->dims());
// create memory description // create memory description
auto data_md = platform::MKLDNNMemDesc(src_tz, mkldnn::memory::f32, auto data_md = src_tz.size() == 2
mkldnn::memory::format::nchw); ? platform::MKLDNNMemDesc(src_tz, mkldnn::memory::f32,
mkldnn::memory::format::nc)
: platform::MKLDNNMemDesc(src_tz, mkldnn::memory::f32,
mkldnn::memory::format::nchw);
// create memory primitives // create memory primitives
auto src_memory = auto src_memory = mkldnn::memory(
mkldnn::memory({data_md, mkldnn_engine}, static_cast<void *>(src)); {data_md, mkldnn_engine}, static_cast<void *>(const_cast<float *>(src)));
auto diff_src_memory = auto diff_src_memory =
mkldnn::memory({data_md, mkldnn_engine}, static_cast<void *>(diff_src)); mkldnn::memory({data_md, mkldnn_engine},
static_cast<void *>(const_cast<float *>(diff_src)));
auto diff_dst_memory = auto diff_dst_memory =
mkldnn::memory({data_md, mkldnn_engine}, static_cast<void *>(diff_dst)); mkldnn::memory({data_md, mkldnn_engine},
static_cast<void *>(const_cast<float *>(diff_dst)));
auto backward_desc = auto backward_desc =
mkldnn::eltwise_backward::desc(algorithm, data_md, data_md, alpha, beta); mkldnn::eltwise_backward::desc(algorithm, data_md, data_md, alpha, beta);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册