提交 2ce6473f 编写于 作者: W wozna

Fix test_transpose_int8_mkldnn

test=develop
上级 61403f87
...@@ -35,8 +35,8 @@ MkldnnQuantizerConfig::MkldnnQuantizerConfig() { ...@@ -35,8 +35,8 @@ MkldnnQuantizerConfig::MkldnnQuantizerConfig() {
rules_["prior_box"]["Boxes"] = ScaleAlgo::NONE; rules_["prior_box"]["Boxes"] = ScaleAlgo::NONE;
rules_["prior_box"]["Variances"] = ScaleAlgo::NONE; rules_["prior_box"]["Variances"] = ScaleAlgo::NONE;
rules_["transpose"]["X"] = ScaleAlgo::KL; rules_["transpose2"]["X"] = ScaleAlgo::KL;
rules_["transpose"]["Out"] = ScaleAlgo::KL; rules_["transpose2"]["Out"] = ScaleAlgo::KL;
} }
ScaleAlgo MkldnnQuantizerConfig::scale_algo( ScaleAlgo MkldnnQuantizerConfig::scale_algo(
......
...@@ -86,8 +86,9 @@ class DeQuantOpKernel : public framework::OpKernel<T> { ...@@ -86,8 +86,9 @@ class DeQuantOpKernel : public framework::OpKernel<T> {
std::shared_ptr<primitive::at> src_memory_p = std::shared_ptr<primitive::at> src_memory_p =
std::shared_ptr<primitive::at>(new primitive::at(*src_memory)); std::shared_ptr<primitive::at>(new primitive::at(*src_memory));
auto dst_md = platform::MKLDNNMemDesc({dst_tz}, memory::data_type::f32, auto dst_md = platform::MKLDNNMemDesc(
memory::format::nchw); {dst_tz}, memory::data_type::f32,
platform::MKLDNNFormatForSize(dst_tz.size(), memory::format::nchw));
auto dst_pd = mkldnn::memory::primitive_desc(dst_md, engine); auto dst_pd = mkldnn::memory::primitive_desc(dst_md, engine);
dst_memory = std::make_shared<mkldnn::memory>( dst_memory = std::make_shared<mkldnn::memory>(
dst_pd, to_void_cast<float>(output_data)); dst_pd, to_void_cast<float>(output_data));
......
...@@ -16,6 +16,7 @@ from __future__ import print_function ...@@ -16,6 +16,7 @@ from __future__ import print_function
import unittest import unittest
import numpy as np import numpy as np
import paddle.fluid.core as core
from paddle.fluid.tests.unittests.op_test import OpTest from paddle.fluid.tests.unittests.op_test import OpTest
from mkldnn_op_test import format_reorder from mkldnn_op_test import format_reorder
...@@ -26,10 +27,11 @@ class TestTransposeOp(OpTest): ...@@ -26,10 +27,11 @@ class TestTransposeOp(OpTest):
self.initTestCase() self.initTestCase()
self.initInputData() self.initInputData()
self.use_mkldnn = True self.use_mkldnn = True
self._cpu_only = True
self.axis = (0, 2, 3, 1) self.axis = (0, 2, 3, 1)
self.inputs = { self.inputs = {
'X': format_reorder(self.input_data, self.shape) 'X': format_reorder(self.input_data, self.shape).astype(np.int8)
} #transform data format to 'NHWC' for INT8 transpose specially. } #transform data format to 'NHWC' for INT8 transpose specially.
self.attrs = { self.attrs = {
...@@ -38,7 +40,7 @@ class TestTransposeOp(OpTest): ...@@ -38,7 +40,7 @@ class TestTransposeOp(OpTest):
} }
self.outputs = { self.outputs = {
'XShape': np.random.random(self.shape).astype('int8'), 'XShape': np.random.random(self.shape).astype(np.int8),
'Out': self.inputs['X'].transpose(self.axis) 'Out': self.inputs['X'].transpose(self.axis)
} }
...@@ -46,14 +48,15 @@ class TestTransposeOp(OpTest): ...@@ -46,14 +48,15 @@ class TestTransposeOp(OpTest):
self.op_type = "transpose2" self.op_type = "transpose2"
def test_check_output(self): def test_check_output(self):
self.check_output(no_check_set=['XShape']) self.check_output_with_place(
core.CPUPlace(), 1e-5, no_check_set=['XShape'])
def initTestCase(self): def initTestCase(self):
self.shape = (2, 3, 4, 5) self.shape = (2, 3, 4, 5)
def initInputData(self): def initInputData(self):
self.input_data = ( self.input_data = (
np.random.randint(0, 100, self.shape) - 50).astype('int8') np.random.randint(0, 100, self.shape) - 50).astype(np.int8)
class TestINT8Case(TestTransposeOp): class TestINT8Case(TestTransposeOp):
...@@ -62,7 +65,7 @@ class TestINT8Case(TestTransposeOp): ...@@ -62,7 +65,7 @@ class TestINT8Case(TestTransposeOp):
def initInputData(self): def initInputData(self):
self.input_data = ( self.input_data = (
np.random.randint(0, 100, self.shape) - 50).astype('int8') np.random.randint(0, 100, self.shape) - 50).astype(np.int8)
class TestUINT8Case(TestTransposeOp): class TestUINT8Case(TestTransposeOp):
...@@ -71,7 +74,7 @@ class TestUINT8Case(TestTransposeOp): ...@@ -71,7 +74,7 @@ class TestUINT8Case(TestTransposeOp):
def initDataType(self): def initDataType(self):
self.input_data = (np.random.randint(0, 100, self.input_data = (np.random.randint(0, 100,
self.shape)).astype('uint8') self.shape)).astype(np.uint8)
if __name__ == '__main__': if __name__ == '__main__':
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册