Unverified commit 944ea436, authored by jakpiase, committed by GitHub

fix for conv2D training error (#38938)

Parent 05c98ec7
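In short, from the diff below: the conv weights-reorder path gains an explicit non-is_test branch that handles the kernel's runtime type T during training, the BF16 conv2d_grad kernel is re-registered with bfloat16 as its second template type (previously float), the cached backward-weights primitive descriptor (key suffix "@bwd_w_pd") is now restored together with the forward one, and the BF16 conv2d test gains init_infer_or_train / init_additional_attrs hooks to cover training. A minimal sketch of how a training-mode case uses those hooks; the subclass name and import path are illustrative assumptions, everything else mirrors the test diff below:

```python
import numpy as np

# Assumed import path: the MKL-DNN unit-test module modified below
# (test_conv2d_bf16_mkldnn_op.py in the Paddle test suite).
from test_conv2d_bf16_mkldnn_op import TestConv2DBF16Op


class SketchConv2DTrainBF16Op(TestConv2DBF16Op):  # hypothetical name
    def init_infer_or_train(self):
        # uint16 is the carrier dtype Paddle uses for bfloat16 tensors, so the
        # filter is converted with convert_float_to_uint16 before being fed in.
        self.weight_type = np.uint16

    def init_additional_attrs(self):
        # is_test=False exercises the newly added training branch of the conv
        # weights reorder instead of the cached inference-only path.
        self.attrs['is_test'] = False
```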
@@ -613,7 +613,7 @@ class ConvMKLDNNHandlerT
     auto weights_mem_p = this->AcquireMemory("@weights_mem_p_target");
     if (is_test && weights_mem_p) {
       return weights_mem_p;
-    } else {
+    } else if (is_test) {
       const K* filter_data = filter->data<K>();
       auto weights_tz = framework::vectorize(filter->dims());
       platform::GetGroupConvWeightsTz(weights_tz, groups);
@@ -626,6 +626,19 @@ class ConvMKLDNNHandlerT
           user_src_md, this->fwd_pd_->weights_desc(),
           platform::to_void_cast<K>(filter_data), "@weights_mem_p", is_test, {},
           scale_data, mask);
+    } else {
+      const T* filter_data = filter->data<T>();
+      auto weights_tz = framework::vectorize(filter->dims());
+      platform::GetGroupConvWeightsTz(weights_tz, groups);
+      auto user_src_md = platform::MKLDNNMemDesc(
+          weights_tz, platform::MKLDNNGetDataType<T>(),
+          GetWeightsFormat(filter->format(), groups, is_conv3d));
+      return this->AcquireMemoryWithReorder(
+          user_src_md, this->fwd_pd_->weights_desc(),
+          platform::to_void_cast<T>(filter_data), "@weights_mem_p", is_test, {},
+          scale_data, mask);
     }
   }
@@ -1027,7 +1040,8 @@ REGISTER_OP_KERNEL_WITH_CUSTOM_TYPE(conv2d_grad, MKLDNN,
 REGISTER_OP_KERNEL_WITH_CUSTOM_TYPE(
     conv2d_grad, MKLDNN, ::paddle::platform::CPUPlace, BF16,
     ops::kConvMKLDNNFP32,
-    ops::ConvMKLDNNGradOpKernel<paddle::platform::bfloat16, float>);
+    ops::ConvMKLDNNGradOpKernel<paddle::platform::bfloat16,
+                                paddle::platform::bfloat16>);
 REGISTER_OP_KERNEL_WITH_CUSTOM_TYPE(depthwise_conv2d, MKLDNN,
                                     ::paddle::platform::CPUPlace, FP32,
...
@@ -377,6 +377,14 @@ class MKLDNNHandlerT {
     if (bwd_pd_ == nullptr) {
       return false;
     } else {
+      if (std::is_same<TBackward_params, mkldnn_dummy_primitive>::value ==
+          false) {
+        const std::string key_bw_w_pd = key_ + "@bwd_w_pd";
+        bwd_w_pd_ =
+            std::static_pointer_cast<typename TBackward_params::primitive_desc>(
+                dev_ctx_.GetBlob(key_bw_w_pd));
+      }
       // When BWD is cached then still we need to Get FWD PD
       const std::string key_fpd = key_ + "@fwd_pd";
       fwd_pd_ = std::static_pointer_cast<typename TForward::primitive_desc>(
...
@@ -50,6 +50,7 @@ class TestConv2DBF16Op(TestConv2DOp):
         self.init_fuse_residual()
         self.init_data_type()
         self.init_force_fp32_output()
+        self.init_infer_or_train()
         self.conv2d_param = {
             'stride': self.stride,
@@ -83,6 +84,9 @@ class TestConv2DBF16Op(TestConv2DOp):
         if self.input_type is not np.float32:
             self.input = convert_float_to_uint16(self.input)
+        if self.weight_type is not np.float32:
+            self.filter = convert_float_to_uint16(self.filter)
         self.inputs = {
             'Input': self.input,
             'Filter': OpTest.np_dtype_to_fluid_dtype(
@@ -105,6 +109,8 @@ class TestConv2DBF16Op(TestConv2DOp):
             'fuse_residual_connection': self.fuse_residual
         }
+        self.init_additional_attrs()
     def test_check_output(self):
         self.check_output_with_place(core.CPUPlace())
@@ -141,6 +147,12 @@ class TestConv2DBF16Op(TestConv2DOp):
     def init_fuse_residual(self):
         self.fuse_residual = True
+    def init_infer_or_train(self):
+        self.weight_type = np.float32
+    def init_additional_attrs(self):
+        self.attrs['is_test'] = True
 @OpTestTool.skip_if_not_cpu_bf16()
 class TestConv2DWithGradBF16Op(TestConv2DBF16Op):
@@ -150,6 +162,12 @@ class TestConv2DWithGradBF16Op(TestConv2DBF16Op):
     def init_fuse_residual(self):
         self.fuse_residual = None
+    def init_additional_attrs(self):
+        self.attrs['is_test'] = False
+    def init_infer_or_train(self):
+        self.weight_type = np.uint16
     def test_check_grad(self):
         dout = self.conv_output_float
         x = self.inputs_fp32['Input']
...
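A quick way to run only the training-mode cases added above (a sketch; the file location is assumed from the standard Paddle unittest layout and is not stated in this diff):

```python
import unittest

# Assumed location: the modified test_conv2d_bf16_mkldnn_op.py shown above.
from test_conv2d_bf16_mkldnn_op import TestConv2DWithGradBF16Op

if __name__ == "__main__":
    # Runs test_check_output plus the BF16 test_check_grad from the diff above.
    suite = unittest.TestLoader().loadTestsFromTestCase(TestConv2DWithGradBF16Op)
    unittest.TextTestRunner(verbosity=2).run(suite)
```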