提交 ed681d52 编写于 作者: A Abhinav Arora

Fix conv_mkldnn_op.cc which is causing CI failure

上级 6f831423
......@@ -72,10 +72,12 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
auto dst_md = platform::MKLDNNMemDesc(
dst_tz, mkldnn::memory::data_type::f32, mkldnn::memory::format::nchw);
auto src_memory = mkldnn::memory({src_md, mkldnn_engine},
reinterpret_cast<void*>(input_data));
auto weights_memory = mkldnn::memory({weights_md, mkldnn_engine},
reinterpret_cast<void*>(filter_data));
auto src_memory =
mkldnn::memory({src_md, mkldnn_engine},
reinterpret_cast<void*>(const_cast<T*>(input_data)));
auto weights_memory =
mkldnn::memory({weights_md, mkldnn_engine},
reinterpret_cast<void*>(const_cast<T*>(filter_data)));
auto dst_memory = mkldnn::memory({dst_md, mkldnn_engine}, output_data);
std::shared_ptr<mkldnn::convolution_forward::primitive_desc> conv_pd =
......@@ -180,9 +182,9 @@ class ConvMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
dst_tz, mkldnn::memory::data_type::f32, mkldnn::memory::format::nchw);
// create memory
auto diff_dst_memory =
mkldnn::memory({diff_weights_md, mkldnn_engine},
reinterpret_cast<void*>(output_grad_data));
auto diff_dst_memory = mkldnn::memory(
{diff_weights_md, mkldnn_engine},
reinterpret_cast<void*>(const_cast<T*>(output_grad_data)));
// Retrieve conv_pd from device context
auto conv_pd =
std::static_pointer_cast<mkldnn::convolution_forward::primitive_desc>(
......@@ -202,8 +204,9 @@ class ConvMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
auto diff_weights_memory =
mkldnn::memory({diff_weights_md, mkldnn_engine},
reinterpret_cast<void*>(filter_grad_data));
auto src_memory = mkldnn::memory({src_md, mkldnn_engine},
reinterpret_cast<void*>(input_data));
auto src_memory =
mkldnn::memory({src_md, mkldnn_engine},
reinterpret_cast<void*>(const_cast<T*>(input_data)));
// create backward conv primitive for weights
auto conv_bwd_weights_prim = mkldnn::convolution_backward_weights(
......@@ -222,11 +225,12 @@ class ConvMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
strides, paddings, *conv_pd, mkldnn_engine);
// create memory
auto diff_src_memory =
mkldnn::memory({diff_src_md, mkldnn_engine},
reinterpret_cast<void*>(input_grad_data));
auto weights_memory = mkldnn::memory(
{weights_md, mkldnn_engine}, reinterpret_cast<void*>(filter_data));
auto diff_src_memory = mkldnn::memory(
{diff_src_md, mkldnn_engine},
reinterpret_cast<void*>(const_cast<T*>(input_grad_data)));
auto weights_memory =
mkldnn::memory({weights_md, mkldnn_engine},
reinterpret_cast<void*>(const_cast<T*>(filter_data)));
// create backward conv primitive for data
auto conv_bwd_data_prim = mkldnn::convolution_backward_data(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册