未验证 提交 1dfd857b 编写于 作者: J joanna.wozna.intel 提交者: GitHub

Fix format in requantize mkldnn op (#34137)

上级 9bc59673
......@@ -84,12 +84,10 @@ class ReQuantOpKernel : public framework::OpKernel<T> {
auto src_dt = framework::ToMKLDNNDataType(input->type());
auto dst_dt = with_shift ? framework::MKLDNNDataType::u8 : src_dt;
auto src_md =
platform::MKLDNNMemDesc({src_tz}, src_dt, MKLDNNMemoryFormat::nhwc);
auto src_md = platform::MKLDNNMemDesc({src_tz}, src_dt, input->format());
src_memory = std::make_shared<dnnl::memory>(src_md, engine,
to_void_cast<T>(input_data));
auto dst_md =
platform::MKLDNNMemDesc({dst_tz}, dst_dt, MKLDNNMemoryFormat::nhwc);
auto dst_md = platform::MKLDNNMemDesc({dst_tz}, dst_dt, input->format());
dnnl::primitive_attr attri;
int mask = 0;
......
......@@ -23,14 +23,18 @@ from mkldnn_op_test import format_reorder
class TestReQuantizeOp(OpTest):
def set_input_size(self):
self.input_size = [1, 1, 10, 10]
self.format_reorder = format_reorder
def setUp(self):
self.op_type = 'requantize'
self.scale_in = 127.0
self.shift_in = 0.0
self.scale_out = 100.0
self.shift_out = 0.0
self.input_size = [1, 1, 10, 10]
self.input_data_type = 'int8'
self.set_input_size()
self.set_scales()
self.set_shifts()
self.set_input_data_type()
......@@ -76,7 +80,7 @@ class TestReQuantizeOp(OpTest):
np.rint(self.input.astype('float32') * scale_ratio + new_shift),
type_min, type_max).astype(dst_type)
self.output = format_reorder(output_tmp, self.input_size)
self.output = self.format_reorder(output_tmp, self.input_size)
self.outputs = {'Output': self.output}
def test_check_output(self):
......@@ -266,6 +270,18 @@ class TestReQuantizeOp_U8_DifferentScales_2_DifferentShift_2(
self.shift_out = 128.0
# ---------------test non-four dimentional formats--------------------------
class TestReQuantizeOp_2DimFormat(TestReQuantizeOp):
def format_reorder_2Dim(self, out, size):
return out
def set_input_size(self):
self.input_size = [10, 20]
self.format_reorder = self.format_reorder_2Dim
# ---------------test reused requantize op, no shift------------------------
......@@ -274,6 +290,7 @@ class TestReQuantizeOpReused(TestReQuantizeOp):
# self.input_size = [1, 1, 10, 10]
self.input_size = [1, 1, 2, 2]
self.input_data_type = 'int8'
self.format_reorder = format_reorder
self.set_scales()
self.set_shifts()
self.set_input_data_type()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册