diff --git a/paddle/fluid/operators/mkldnn/requantize_mkldnn_op.cc b/paddle/fluid/operators/mkldnn/requantize_mkldnn_op.cc index 4c136a2fc2ce8c811355b322da6ff5539ff062c9..6296654b8bdd9d071fd9484e359a3a5943b1a655 100644 --- a/paddle/fluid/operators/mkldnn/requantize_mkldnn_op.cc +++ b/paddle/fluid/operators/mkldnn/requantize_mkldnn_op.cc @@ -84,12 +84,10 @@ class ReQuantOpKernel : public framework::OpKernel { auto src_dt = framework::ToMKLDNNDataType(input->type()); auto dst_dt = with_shift ? framework::MKLDNNDataType::u8 : src_dt; - auto src_md = - platform::MKLDNNMemDesc({src_tz}, src_dt, MKLDNNMemoryFormat::nhwc); + auto src_md = platform::MKLDNNMemDesc({src_tz}, src_dt, input->format()); src_memory = std::make_shared(src_md, engine, to_void_cast(input_data)); - auto dst_md = - platform::MKLDNNMemDesc({dst_tz}, dst_dt, MKLDNNMemoryFormat::nhwc); + auto dst_md = platform::MKLDNNMemDesc({dst_tz}, dst_dt, input->format()); dnnl::primitive_attr attri; int mask = 0; diff --git a/python/paddle/fluid/tests/unittests/mkldnn/test_requantize_mkldnn_op.py b/python/paddle/fluid/tests/unittests/mkldnn/test_requantize_mkldnn_op.py index ba2fdbab30cdc4981ae25bcdd9ebaa068ba3a616..88aebac42e84b749a360cf84d2300ae4ae50084d 100644 --- a/python/paddle/fluid/tests/unittests/mkldnn/test_requantize_mkldnn_op.py +++ b/python/paddle/fluid/tests/unittests/mkldnn/test_requantize_mkldnn_op.py @@ -23,14 +23,18 @@ from mkldnn_op_test import format_reorder class TestReQuantizeOp(OpTest): + def set_input_size(self): + self.input_size = [1, 1, 10, 10] + self.format_reorder = format_reorder + def setUp(self): self.op_type = 'requantize' self.scale_in = 127.0 self.shift_in = 0.0 self.scale_out = 100.0 self.shift_out = 0.0 - self.input_size = [1, 1, 10, 10] self.input_data_type = 'int8' + self.set_input_size() self.set_scales() self.set_shifts() self.set_input_data_type() @@ -76,7 +80,7 @@ class TestReQuantizeOp(OpTest): np.rint(self.input.astype('float32') * scale_ratio + new_shift), type_min, type_max).astype(dst_type) - self.output = format_reorder(output_tmp, self.input_size) + self.output = self.format_reorder(output_tmp, self.input_size) self.outputs = {'Output': self.output} def test_check_output(self): @@ -266,6 +270,18 @@ class TestReQuantizeOp_U8_DifferentScales_2_DifferentShift_2( self.shift_out = 128.0 +# ---------------test non-four dimentional formats-------------------------- + + +class TestReQuantizeOp_2DimFormat(TestReQuantizeOp): + def format_reorder_2Dim(self, out, size): + return out + + def set_input_size(self): + self.input_size = [10, 20] + self.format_reorder = self.format_reorder_2Dim + + # ---------------test reused requantize op, no shift------------------------ @@ -274,6 +290,7 @@ class TestReQuantizeOpReused(TestReQuantizeOp): # self.input_size = [1, 1, 10, 10] self.input_size = [1, 1, 2, 2] self.input_data_type = 'int8' + self.format_reorder = format_reorder self.set_scales() self.set_shifts() self.set_input_data_type()