Unverified · Commit 5f678698 · Authored by: Weilong Wu · Committed by: GitHub

support transpose test on prim and cinn (#51284)

* support transpose test on prim and cinn

* fix transpose logic and polish eager_op_test
Parent f8d0a358
@@ -88,7 +88,11 @@ void transpose_grad(const Tensor& grad_out,
   std::vector<int> reverse_perm(perm);
   // make origin ranks
   for (int i = 0; i < static_cast<int>(perm.size()); ++i) {
-    reverse_perm[perm[i]] = i;
+    if (perm[i] >= 0) {
+      reverse_perm[perm[i]] = i;
+    } else {
+      reverse_perm[perm[i] + perm.size()] = i;
+    }
   }
   auto grad_x_tmp = transpose<T>(grad_out, reverse_perm);
   set_output<T>(grad_x_tmp, grad_x);
......
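The composite backward rule above recovers grad_x by applying the inverse permutation to grad_out; the new branch normalizes negative entries in perm (e.g. -1 meaning the last axis) before inverting, which is what lets the negative-axis test case below pass under prim. A minimal NumPy sketch of the same inversion logic, for illustration only (names mirror the C++ code above):

```python
import numpy as np

def transpose_grad_ref(grad_out, perm):
    # Invert the permutation, normalizing negative axes first
    # (perm[i] == -1 means the last dimension, and so on).
    reverse_perm = [0] * len(perm)
    for i, p in enumerate(perm):
        if p < 0:
            p += len(perm)
        reverse_perm[p] = i
    return np.transpose(grad_out, reverse_perm)

# Example: forward perm (-1, 1, -3) on a (10, 8, 2) input behaves like (2, 1, 0),
# so the gradient of the (2, 8, 10) output is transposed back to (10, 8, 2).
g = np.ones((2, 8, 10))
print(transpose_grad_ref(g, (-1, 1, -3)).shape)  # -> (10, 8, 2)
```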
@@ -1217,7 +1217,8 @@ set(TEST_CINN_OPS
     test_elementwise_sub_op
     test_elementwise_div_op
     test_elementwise_mul_op
-    test_gather_nd_op)
+    test_gather_nd_op
+    test_transpose_op)

 foreach(TEST_CINN_OPS ${TEST_CINN_OPS})
   if(WITH_CINN)
......
@@ -1392,6 +1392,7 @@ class OpTest(unittest.TestCase):
         inplace_atol=None,
     ):
         core._set_prim_all_enabled(False)
+        core.set_prim_eager_enabled(False)

         def find_imperative_actual(target_name, dygraph_outs, place):
             for name in dygraph_outs:
@@ -1982,6 +1983,7 @@ class OpTest(unittest.TestCase):
         numeric_place=None,
     ):
         core._set_prim_all_enabled(False)
+        core.set_prim_eager_enabled(False)
         if check_prim:
             prim_grad_checker = PrimGradChecker(
                 self,
......
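Both check_output and check_grad now reset the eager prim switch in addition to the global one, so a test that turns on composite (prim) kernels cannot leak that state into later tests. A hedged sketch of the intended reset-on-entry pattern (simplified; in the real OpTest the prim comparison is delegated to PrimGradChecker):

```python
from paddle.fluid import core

def run_one_check(check_prim=False):
    # Always start from a known global state, so a previous test that enabled
    # the prim (composite-op) switches cannot affect this check.
    core._set_prim_all_enabled(False)
    core.set_prim_eager_enabled(False)

    if check_prim:
        # A prim-enabled test flips the switch only for the duration of the
        # composite-vs-original comparison, then restores it.
        core.set_prim_eager_enabled(True)
        try:
            pass  # run the composite-op comparison here
        finally:
            core.set_prim_eager_enabled(False)
```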
@@ -16,10 +16,10 @@ import unittest

 import numpy as np

-from paddle.fluid.tests.unittests.test_transpose_op import TestTransposeOp
+from paddle.fluid.tests.unittests.op_test import OpTest


-class TestTransposeMKLDNN(TestTransposeOp):
+class TestTransposeMKLDNN(OpTest):
     def setUp(self):
         self.init_op_type()
         self.initTestCase()
......
@@ -32,6 +32,7 @@ class TestTransposeOp(OpTest):
         self.init_op_type()
         self.initTestCase()
         self.python_api = paddle.transpose
+        self.prim_op_type = "prim"
         self.inputs = {'X': np.random.random(self.shape).astype("float64")}
         self.attrs = {
             'axis': list(self.axis),
@@ -50,7 +51,7 @@ class TestTransposeOp(OpTest):
         self.check_output(no_check_set=['XShape'])

     def test_check_grad(self):
-        self.check_grad(['X'], 'Out')
+        self.check_grad(['X'], 'Out', check_prim=True)

     def initTestCase(self):
         self.shape = (3, 40)
@@ -118,12 +119,44 @@ class TestCase9(TestTransposeOp):

 class TestCase10(TestTransposeOp):
+    def setUp(self):
+        self.init_op_type()
+        self.initTestCase()
+        self.python_api = paddle.transpose
+        self.prim_op_type = "prim"
+        self.enable_cinn = False
+        self.inputs = {'X': np.random.random(self.shape).astype("float64")}
+        self.attrs = {
+            'axis': list(self.axis),
+            'use_mkldnn': self.use_mkldnn,
+        }
+        self.outputs = {
+            'XShape': np.random.random(self.shape).astype("float64"),
+            'Out': self.inputs['X'].transpose(self.axis),
+        }
+
     def initTestCase(self):
         self.shape = (10, 8, 2)
         self.axis = (-1, 1, -3)


 class TestCase_ZeroDim(TestTransposeOp):
+    def setUp(self):
+        self.init_op_type()
+        self.initTestCase()
+        self.python_api = paddle.transpose
+        self.prim_op_type = "prim"
+        self.enable_cinn = False
+        self.inputs = {'X': np.random.random(self.shape).astype("float64")}
+        self.attrs = {
+            'axis': list(self.axis),
+            'use_mkldnn': self.use_mkldnn,
+        }
+        self.outputs = {
+            'XShape': np.random.random(self.shape).astype("float64"),
+            'Out': self.inputs['X'].transpose(self.axis),
+        }
+
     def initTestCase(self):
         self.shape = ()
         self.axis = ()
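TestCase10 (negative axes) and TestCase_ZeroDim (0-d input) get their own setUp so they can keep prim_op_type = "prim" while setting enable_cinn = False: they still run the composite-gradient check but skip the CINN comparison. Their expected outputs rely on NumPy accepting both axis forms, which a quick check confirms:

```python
import numpy as np

x = np.random.random((10, 8, 2))
# Negative axes are normalized, so (-1, 1, -3) behaves like (2, 1, 0).
assert x.transpose((-1, 1, -3)).shape == (2, 8, 10)

z = np.random.random(())             # zero-dim input
assert z.transpose(()).shape == ()   # an empty perm leaves it unchanged
```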
@@ -134,6 +167,7 @@ class TestAutoTuneTransposeOp(OpTest):
         self.init_op_type()
         self.initTestCase()
         self.python_api = paddle.transpose
+        self.prim_op_type = "prim"
         self.inputs = {'X': np.random.random(self.shape).astype("float64")}
         self.attrs = {
             'axis': list(self.axis),
@@ -160,7 +194,7 @@ class TestAutoTuneTransposeOp(OpTest):
         fluid.core.disable_autotune()

     def test_check_grad(self):
-        self.check_grad(['X'], 'Out')
+        self.check_grad(['X'], 'Out', check_prim=True)


 class TestAutoTuneTransposeBF16Op(OpTest):
@@ -169,6 +203,8 @@ class TestAutoTuneTransposeBF16Op(OpTest):
         self.initTestCase()
         self.dtype = np.uint16
         self.python_api = paddle.transpose
+        self.prim_op_type = "prim"
+        self.enable_cinn = False
         x = np.random.random(self.shape).astype("float32")
         self.inputs = {'X': convert_float_to_uint16(x)}
         self.attrs = {
@@ -198,7 +234,7 @@ class TestAutoTuneTransposeBF16Op(OpTest):
         fluid.core.disable_autotune()

     def test_check_grad(self):
-        self.check_grad(['X'], 'Out')
+        self.check_grad(['X'], 'Out', check_prim=True)


 class TestTransposeBF16Op(OpTest):
@@ -206,6 +242,8 @@ class TestTransposeBF16Op(OpTest):
         self.init_op_type()
         self.initTestCase()
         self.dtype = np.uint16
+        self.prim_op_type = "prim"
+        self.enable_cinn = False
         self.python_api = paddle.transpose
         x = np.random.random(self.shape).astype("float32")
......
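The BF16 test variants above store their inputs as bfloat16 bit patterns in uint16 arrays via convert_float_to_uint16. A rough NumPy sketch of that packing, assuming plain truncation of the mantissa (the real helper may round to nearest even):

```python
import numpy as np

def float32_to_bfloat16_bits(x):
    # Keep the upper 16 bits of each float32 value: same sign and exponent,
    # truncated mantissa. The result is the bfloat16 bit pattern as uint16.
    x = np.asarray(x, dtype=np.float32)
    return (x.view(np.uint32) >> 16).astype(np.uint16)

print(float32_to_bfloat16_bits([1.0, -2.5]))  # -> [16256 49184]
```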
@@ -99,6 +99,7 @@ def transpose(x, perm, name=None):
             'float64',
             'int32',
             'int64',
+            'uint16',
             'complex64',
             'complex128',
         ],
......
@@ -482,6 +482,7 @@ def transpose(x, perm, name=None):
             'float64',
             'int32',
             'int64',
+            'uint16',
             'complex64',
             'complex128',
         ],
......
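Both dtype whitelists of paddle.transpose gain 'uint16', Paddle's internal dtype name for bfloat16; without it the new BF16 tests would fail the input type check. A hypothetical usage sketch (assumes a build/device with bfloat16 support; the cast and shapes are illustrative, not taken from the tests):

```python
import paddle

x = paddle.rand([3, 40], dtype='float32')
x_bf16 = paddle.cast(x, 'bfloat16')       # stored internally with dtype uint16
y = paddle.transpose(x_bf16, perm=[1, 0])
print(y.shape)                             # [40, 3]
```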