未验证 提交 1dfd857b 编写于 作者: J joanna.wozna.intel 提交者: GitHub

Fix format in requantize mkldnn op (#34137)

上级 9bc59673
...@@ -84,12 +84,10 @@ class ReQuantOpKernel : public framework::OpKernel<T> { ...@@ -84,12 +84,10 @@ class ReQuantOpKernel : public framework::OpKernel<T> {
auto src_dt = framework::ToMKLDNNDataType(input->type()); auto src_dt = framework::ToMKLDNNDataType(input->type());
auto dst_dt = with_shift ? framework::MKLDNNDataType::u8 : src_dt; auto dst_dt = with_shift ? framework::MKLDNNDataType::u8 : src_dt;
auto src_md = auto src_md = platform::MKLDNNMemDesc({src_tz}, src_dt, input->format());
platform::MKLDNNMemDesc({src_tz}, src_dt, MKLDNNMemoryFormat::nhwc);
src_memory = std::make_shared<dnnl::memory>(src_md, engine, src_memory = std::make_shared<dnnl::memory>(src_md, engine,
to_void_cast<T>(input_data)); to_void_cast<T>(input_data));
auto dst_md = auto dst_md = platform::MKLDNNMemDesc({dst_tz}, dst_dt, input->format());
platform::MKLDNNMemDesc({dst_tz}, dst_dt, MKLDNNMemoryFormat::nhwc);
dnnl::primitive_attr attri; dnnl::primitive_attr attri;
int mask = 0; int mask = 0;
......
...@@ -23,14 +23,18 @@ from mkldnn_op_test import format_reorder ...@@ -23,14 +23,18 @@ from mkldnn_op_test import format_reorder
class TestReQuantizeOp(OpTest): class TestReQuantizeOp(OpTest):
def set_input_size(self):
self.input_size = [1, 1, 10, 10]
self.format_reorder = format_reorder
def setUp(self): def setUp(self):
self.op_type = 'requantize' self.op_type = 'requantize'
self.scale_in = 127.0 self.scale_in = 127.0
self.shift_in = 0.0 self.shift_in = 0.0
self.scale_out = 100.0 self.scale_out = 100.0
self.shift_out = 0.0 self.shift_out = 0.0
self.input_size = [1, 1, 10, 10]
self.input_data_type = 'int8' self.input_data_type = 'int8'
self.set_input_size()
self.set_scales() self.set_scales()
self.set_shifts() self.set_shifts()
self.set_input_data_type() self.set_input_data_type()
...@@ -76,7 +80,7 @@ class TestReQuantizeOp(OpTest): ...@@ -76,7 +80,7 @@ class TestReQuantizeOp(OpTest):
np.rint(self.input.astype('float32') * scale_ratio + new_shift), np.rint(self.input.astype('float32') * scale_ratio + new_shift),
type_min, type_max).astype(dst_type) type_min, type_max).astype(dst_type)
self.output = format_reorder(output_tmp, self.input_size) self.output = self.format_reorder(output_tmp, self.input_size)
self.outputs = {'Output': self.output} self.outputs = {'Output': self.output}
def test_check_output(self): def test_check_output(self):
...@@ -266,6 +270,18 @@ class TestReQuantizeOp_U8_DifferentScales_2_DifferentShift_2( ...@@ -266,6 +270,18 @@ class TestReQuantizeOp_U8_DifferentScales_2_DifferentShift_2(
self.shift_out = 128.0 self.shift_out = 128.0
# ---------------test non-four dimentional formats--------------------------
class TestReQuantizeOp_2DimFormat(TestReQuantizeOp):
def format_reorder_2Dim(self, out, size):
return out
def set_input_size(self):
self.input_size = [10, 20]
self.format_reorder = self.format_reorder_2Dim
# ---------------test reused requantize op, no shift------------------------ # ---------------test reused requantize op, no shift------------------------
...@@ -274,6 +290,7 @@ class TestReQuantizeOpReused(TestReQuantizeOp): ...@@ -274,6 +290,7 @@ class TestReQuantizeOpReused(TestReQuantizeOp):
# self.input_size = [1, 1, 10, 10] # self.input_size = [1, 1, 10, 10]
self.input_size = [1, 1, 2, 2] self.input_size = [1, 1, 2, 2]
self.input_data_type = 'int8' self.input_data_type = 'int8'
self.format_reorder = format_reorder
self.set_scales() self.set_scales()
self.set_shifts() self.set_shifts()
self.set_input_data_type() self.set_input_data_type()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册