提交 ed681d52 编写于 作者: A Abhinav Arora

Fix conv_mkldnn_op.cc which is causing CI failure

上级 6f831423
...@@ -72,10 +72,12 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> { ...@@ -72,10 +72,12 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
auto dst_md = platform::MKLDNNMemDesc( auto dst_md = platform::MKLDNNMemDesc(
dst_tz, mkldnn::memory::data_type::f32, mkldnn::memory::format::nchw); dst_tz, mkldnn::memory::data_type::f32, mkldnn::memory::format::nchw);
auto src_memory = mkldnn::memory({src_md, mkldnn_engine}, auto src_memory =
reinterpret_cast<void*>(input_data)); mkldnn::memory({src_md, mkldnn_engine},
auto weights_memory = mkldnn::memory({weights_md, mkldnn_engine}, reinterpret_cast<void*>(const_cast<T*>(input_data)));
reinterpret_cast<void*>(filter_data)); auto weights_memory =
mkldnn::memory({weights_md, mkldnn_engine},
reinterpret_cast<void*>(const_cast<T*>(filter_data)));
auto dst_memory = mkldnn::memory({dst_md, mkldnn_engine}, output_data); auto dst_memory = mkldnn::memory({dst_md, mkldnn_engine}, output_data);
std::shared_ptr<mkldnn::convolution_forward::primitive_desc> conv_pd = std::shared_ptr<mkldnn::convolution_forward::primitive_desc> conv_pd =
...@@ -180,9 +182,9 @@ class ConvMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> { ...@@ -180,9 +182,9 @@ class ConvMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
dst_tz, mkldnn::memory::data_type::f32, mkldnn::memory::format::nchw); dst_tz, mkldnn::memory::data_type::f32, mkldnn::memory::format::nchw);
// create memory // create memory
auto diff_dst_memory = auto diff_dst_memory = mkldnn::memory(
mkldnn::memory({diff_weights_md, mkldnn_engine}, {diff_weights_md, mkldnn_engine},
reinterpret_cast<void*>(output_grad_data)); reinterpret_cast<void*>(const_cast<T*>(output_grad_data)));
// Retrieve conv_pd from device context // Retrieve conv_pd from device context
auto conv_pd = auto conv_pd =
std::static_pointer_cast<mkldnn::convolution_forward::primitive_desc>( std::static_pointer_cast<mkldnn::convolution_forward::primitive_desc>(
...@@ -202,8 +204,9 @@ class ConvMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> { ...@@ -202,8 +204,9 @@ class ConvMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
auto diff_weights_memory = auto diff_weights_memory =
mkldnn::memory({diff_weights_md, mkldnn_engine}, mkldnn::memory({diff_weights_md, mkldnn_engine},
reinterpret_cast<void*>(filter_grad_data)); reinterpret_cast<void*>(filter_grad_data));
auto src_memory = mkldnn::memory({src_md, mkldnn_engine}, auto src_memory =
reinterpret_cast<void*>(input_data)); mkldnn::memory({src_md, mkldnn_engine},
reinterpret_cast<void*>(const_cast<T*>(input_data)));
// create backward conv primitive for weights // create backward conv primitive for weights
auto conv_bwd_weights_prim = mkldnn::convolution_backward_weights( auto conv_bwd_weights_prim = mkldnn::convolution_backward_weights(
...@@ -222,11 +225,12 @@ class ConvMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> { ...@@ -222,11 +225,12 @@ class ConvMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
strides, paddings, *conv_pd, mkldnn_engine); strides, paddings, *conv_pd, mkldnn_engine);
// create memory // create memory
auto diff_src_memory = auto diff_src_memory = mkldnn::memory(
mkldnn::memory({diff_src_md, mkldnn_engine}, {diff_src_md, mkldnn_engine},
reinterpret_cast<void*>(input_grad_data)); reinterpret_cast<void*>(const_cast<T*>(input_grad_data)));
auto weights_memory = mkldnn::memory( auto weights_memory =
{weights_md, mkldnn_engine}, reinterpret_cast<void*>(filter_data)); mkldnn::memory({weights_md, mkldnn_engine},
reinterpret_cast<void*>(const_cast<T*>(filter_data)));
// create backward conv primitive for data // create backward conv primitive for data
auto conv_bwd_data_prim = mkldnn::convolution_backward_data( auto conv_bwd_data_prim = mkldnn::convolution_backward_data(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册