diff --git a/paddle/fluid/operators/mkldnn/conv_mkldnn_op.cc b/paddle/fluid/operators/mkldnn/conv_mkldnn_op.cc
index 0526ae52b390305695d6537cb2c161391fc85ad0..44289015bc7c4ba98b75a5fb1444afce98b585dc 100644
--- a/paddle/fluid/operators/mkldnn/conv_mkldnn_op.cc
+++ b/paddle/fluid/operators/mkldnn/conv_mkldnn_op.cc
@@ -613,7 +613,7 @@ class ConvMKLDNNHandlerT
     auto weights_mem_p = this->AcquireMemory("@weights_mem_p_target");
     if (is_test && weights_mem_p) {
       return weights_mem_p;
-    } else {
+    } else if (is_test) {
       const K* filter_data = filter->data<K>();
       auto weights_tz = framework::vectorize(filter->dims());
       platform::GetGroupConvWeightsTz(weights_tz, groups);
@@ -626,6 +626,19 @@ class ConvMKLDNNHandlerT
           user_src_md, this->fwd_pd_->weights_desc(),
           platform::to_void_cast<K>(filter_data), "@weights_mem_p", is_test, {},
           scale_data, mask);
+    } else {
+      const T* filter_data = filter->data<T>();
+      auto weights_tz = framework::vectorize(filter->dims());
+      platform::GetGroupConvWeightsTz(weights_tz, groups);
+
+      auto user_src_md = platform::MKLDNNMemDesc(
+          weights_tz, platform::MKLDNNGetDataType<T>(),
+          GetWeightsFormat(filter->format(), groups, is_conv3d));
+
+      return this->AcquireMemoryWithReorder(
+          user_src_md, this->fwd_pd_->weights_desc(),
+          platform::to_void_cast<T>(filter_data), "@weights_mem_p", is_test, {},
+          scale_data, mask);
     }
   }
 
@@ -1027,7 +1040,8 @@ REGISTER_OP_KERNEL_WITH_CUSTOM_TYPE(conv2d_grad, MKLDNN,
 REGISTER_OP_KERNEL_WITH_CUSTOM_TYPE(
     conv2d_grad, MKLDNN, ::paddle::platform::CPUPlace, BF16,
     ops::kConvMKLDNNFP32,
-    ops::ConvMKLDNNGradOpKernel);
+    ops::ConvMKLDNNGradOpKernel);
 
 REGISTER_OP_KERNEL_WITH_CUSTOM_TYPE(depthwise_conv2d, MKLDNN,
                                     ::paddle::platform::CPUPlace, FP32,
diff --git a/paddle/fluid/platform/mkldnn_reuse.h b/paddle/fluid/platform/mkldnn_reuse.h
index ef216e48416f9da9453bc6456c1bc051615d4435..9aadd36c2e8aca50a9c54cf086e3563255f51f3b 100644
--- a/paddle/fluid/platform/mkldnn_reuse.h
+++ b/paddle/fluid/platform/mkldnn_reuse.h
@@ -377,6 +377,14 @@ class MKLDNNHandlerT {
     if (bwd_pd_ == nullptr) {
       return false;
     } else {
+      if (std::is_same<TBackward_params, mkldnn_dummy_primitive>::value ==
+          false) {
+        const std::string key_bw_w_pd = key_ + "@bwd_w_pd";
+        bwd_w_pd_ =
+            std::static_pointer_cast<typename TBackward_params::primitive_desc>(
+                dev_ctx_.GetBlob(key_bw_w_pd));
+      }
+
       // When BWD is cached then still we need to Get FWD PD
       const std::string key_fpd = key_ + "@fwd_pd";
       fwd_pd_ = std::static_pointer_cast<typename TForward::primitive_desc>(
diff --git a/python/paddle/fluid/tests/unittests/mkldnn/test_conv2d_bf16_mkldnn_op.py b/python/paddle/fluid/tests/unittests/mkldnn/test_conv2d_bf16_mkldnn_op.py
index 4c753da0512f88ace740452123349a49e31f213d..702d26b073b6b5371a7e281a74ffed2df1450a70 100644
--- a/python/paddle/fluid/tests/unittests/mkldnn/test_conv2d_bf16_mkldnn_op.py
+++ b/python/paddle/fluid/tests/unittests/mkldnn/test_conv2d_bf16_mkldnn_op.py
@@ -50,6 +50,7 @@ class TestConv2DBF16Op(TestConv2DOp):
         self.init_fuse_residual()
         self.init_data_type()
         self.init_force_fp32_output()
+        self.init_infer_or_train()
 
         self.conv2d_param = {
             'stride': self.stride,
@@ -83,6 +84,9 @@ class TestConv2DBF16Op(TestConv2DOp):
         if self.input_type is not np.float32:
             self.input = convert_float_to_uint16(self.input)
 
+        if self.weight_type is not np.float32:
+            self.filter = convert_float_to_uint16(self.filter)
+
         self.inputs = {
             'Input': self.input,
             'Filter': OpTest.np_dtype_to_fluid_dtype(
@@ -105,6 +109,8 @@ class TestConv2DBF16Op(TestConv2DOp):
             'fuse_residual_connection': self.fuse_residual
         }
 
+        self.init_additional_attrs()
+
     def test_check_output(self):
         self.check_output_with_place(core.CPUPlace())
 
@@ -141,6 +147,12 @@ class TestConv2DBF16Op(TestConv2DOp):
     def init_fuse_residual(self):
         self.fuse_residual = True
 
+    def init_infer_or_train(self):
+        self.weight_type = np.float32
+
+    def init_additional_attrs(self):
+        self.attrs['is_test'] = True
+
 
 @OpTestTool.skip_if_not_cpu_bf16()
 class TestConv2DWithGradBF16Op(TestConv2DBF16Op):
@@ -150,6 +162,12 @@ class TestConv2DWithGradBF16Op(TestConv2DBF16Op):
     def init_fuse_residual(self):
         self.fuse_residual = None
 
+    def init_additional_attrs(self):
+        self.attrs['is_test'] = False
+
+    def init_infer_or_train(self):
+        self.weight_type = np.uint16
+
     def test_check_grad(self):
         dout = self.conv_output_float
         x = self.inputs_fp32['Input']
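Usage note (not part of the patch): the two hooks added to the test, init_infer_or_train and init_additional_attrs, are what switch a case between the existing inference path (fp32 weights, is_test=True) and the new training path (bf16 weights, is_test=False). The sketch below shows how a further training-mode case could be declared if it were added to test_conv2d_bf16_mkldnn_op.py; the class name is an illustrative assumption and it simply mirrors the overrides of TestConv2DWithGradBF16Op above.

```python
import numpy as np  # already imported at the top of the test file


class TestConv2DTrainModeBF16Sketch(TestConv2DBF16Op):
    # Hypothetical extra case: run the bf16 conv2d op in training mode
    # rather than inference mode, mirroring TestConv2DWithGradBF16Op.

    def init_infer_or_train(self):
        # bf16 weights are stored as uint16, so setUp() converts self.filter
        # with convert_float_to_uint16() before feeding it to the op.
        self.weight_type = np.uint16

    def init_additional_attrs(self):
        # is_test=False selects the training branch added in
        # conv_mkldnn_op.cc, which reorders T-typed (bf16) filter data
        # instead of following the cached K-typed (fp32) inference path.
        self.attrs['is_test'] = False
```

With the defaults kept by TestConv2DBF16Op (weight_type = np.float32, is_test = True) the pre-existing inference behaviour is unchanged, which is why those values sit in the base class.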