提交 2ce6473f 编写于 作者: W wozna

Fix test_transpose_int8_mkldnn

test=develop
上级 61403f87
......@@ -35,8 +35,8 @@ MkldnnQuantizerConfig::MkldnnQuantizerConfig() {
rules_["prior_box"]["Boxes"] = ScaleAlgo::NONE;
rules_["prior_box"]["Variances"] = ScaleAlgo::NONE;
rules_["transpose"]["X"] = ScaleAlgo::KL;
rules_["transpose"]["Out"] = ScaleAlgo::KL;
rules_["transpose2"]["X"] = ScaleAlgo::KL;
rules_["transpose2"]["Out"] = ScaleAlgo::KL;
}
ScaleAlgo MkldnnQuantizerConfig::scale_algo(
......
......@@ -86,8 +86,9 @@ class DeQuantOpKernel : public framework::OpKernel<T> {
std::shared_ptr<primitive::at> src_memory_p =
std::shared_ptr<primitive::at>(new primitive::at(*src_memory));
auto dst_md = platform::MKLDNNMemDesc({dst_tz}, memory::data_type::f32,
memory::format::nchw);
auto dst_md = platform::MKLDNNMemDesc(
{dst_tz}, memory::data_type::f32,
platform::MKLDNNFormatForSize(dst_tz.size(), memory::format::nchw));
auto dst_pd = mkldnn::memory::primitive_desc(dst_md, engine);
dst_memory = std::make_shared<mkldnn::memory>(
dst_pd, to_void_cast<float>(output_data));
......
......@@ -16,6 +16,7 @@ from __future__ import print_function
import unittest
import numpy as np
import paddle.fluid.core as core
from paddle.fluid.tests.unittests.op_test import OpTest
from mkldnn_op_test import format_reorder
......@@ -26,10 +27,11 @@ class TestTransposeOp(OpTest):
self.initTestCase()
self.initInputData()
self.use_mkldnn = True
self._cpu_only = True
self.axis = (0, 2, 3, 1)
self.inputs = {
'X': format_reorder(self.input_data, self.shape)
'X': format_reorder(self.input_data, self.shape).astype(np.int8)
} #transform data format to 'NHWC' for INT8 transpose specially.
self.attrs = {
......@@ -38,7 +40,7 @@ class TestTransposeOp(OpTest):
}
self.outputs = {
'XShape': np.random.random(self.shape).astype('int8'),
'XShape': np.random.random(self.shape).astype(np.int8),
'Out': self.inputs['X'].transpose(self.axis)
}
......@@ -46,14 +48,15 @@ class TestTransposeOp(OpTest):
self.op_type = "transpose2"
def test_check_output(self):
self.check_output(no_check_set=['XShape'])
self.check_output_with_place(
core.CPUPlace(), 1e-5, no_check_set=['XShape'])
def initTestCase(self):
self.shape = (2, 3, 4, 5)
def initInputData(self):
self.input_data = (
np.random.randint(0, 100, self.shape) - 50).astype('int8')
np.random.randint(0, 100, self.shape) - 50).astype(np.int8)
class TestINT8Case(TestTransposeOp):
......@@ -62,7 +65,7 @@ class TestINT8Case(TestTransposeOp):
def initInputData(self):
self.input_data = (
np.random.randint(0, 100, self.shape) - 50).astype('int8')
np.random.randint(0, 100, self.shape) - 50).astype(np.int8)
class TestUINT8Case(TestTransposeOp):
......@@ -71,7 +74,7 @@ class TestUINT8Case(TestTransposeOp):
def initDataType(self):
self.input_data = (np.random.randint(0, 100,
self.shape)).astype('uint8')
self.shape)).astype(np.uint8)
if __name__ == '__main__':
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册