未验证 提交 5f678698 编写于 作者: W Weilong Wu 提交者: GitHub

support transpose test on prim and cinn (#51284)

* support transpose test on prim and cinn

* fix transpose logic and polish eager_op_test
上级 f8d0a358
......@@ -88,7 +88,11 @@ void transpose_grad(const Tensor& grad_out,
std::vector<int> reverse_perm(perm);
// make origin ranks
for (int i = 0; i < static_cast<int>(perm.size()); ++i) {
reverse_perm[perm[i]] = i;
if (perm[i] >= 0) {
reverse_perm[perm[i]] = i;
} else {
reverse_perm[perm[i] + perm.size()] = i;
}
}
auto grad_x_tmp = transpose<T>(grad_out, reverse_perm);
set_output<T>(grad_x_tmp, grad_x);
......
......@@ -1217,7 +1217,8 @@ set(TEST_CINN_OPS
test_elementwise_sub_op
test_elementwise_div_op
test_elementwise_mul_op
test_gather_nd_op)
test_gather_nd_op
test_transpose_op)
foreach(TEST_CINN_OPS ${TEST_CINN_OPS})
if(WITH_CINN)
......
......@@ -1392,6 +1392,7 @@ class OpTest(unittest.TestCase):
inplace_atol=None,
):
core._set_prim_all_enabled(False)
core.set_prim_eager_enabled(False)
def find_imperative_actual(target_name, dygraph_outs, place):
for name in dygraph_outs:
......@@ -1982,6 +1983,7 @@ class OpTest(unittest.TestCase):
numeric_place=None,
):
core._set_prim_all_enabled(False)
core.set_prim_eager_enabled(False)
if check_prim:
prim_grad_checker = PrimGradChecker(
self,
......
......@@ -16,10 +16,10 @@ import unittest
import numpy as np
from paddle.fluid.tests.unittests.test_transpose_op import TestTransposeOp
from paddle.fluid.tests.unittests.op_test import OpTest
class TestTransposeMKLDNN(TestTransposeOp):
class TestTransposeMKLDNN(OpTest):
def setUp(self):
self.init_op_type()
self.initTestCase()
......
......@@ -32,6 +32,7 @@ class TestTransposeOp(OpTest):
self.init_op_type()
self.initTestCase()
self.python_api = paddle.transpose
self.prim_op_type = "prim"
self.inputs = {'X': np.random.random(self.shape).astype("float64")}
self.attrs = {
'axis': list(self.axis),
......@@ -50,7 +51,7 @@ class TestTransposeOp(OpTest):
self.check_output(no_check_set=['XShape'])
def test_check_grad(self):
self.check_grad(['X'], 'Out')
self.check_grad(['X'], 'Out', check_prim=True)
def initTestCase(self):
self.shape = (3, 40)
......@@ -118,12 +119,44 @@ class TestCase9(TestTransposeOp):
class TestCase10(TestTransposeOp):
def setUp(self):
self.init_op_type()
self.initTestCase()
self.python_api = paddle.transpose
self.prim_op_type = "prim"
self.enable_cinn = False
self.inputs = {'X': np.random.random(self.shape).astype("float64")}
self.attrs = {
'axis': list(self.axis),
'use_mkldnn': self.use_mkldnn,
}
self.outputs = {
'XShape': np.random.random(self.shape).astype("float64"),
'Out': self.inputs['X'].transpose(self.axis),
}
def initTestCase(self):
self.shape = (10, 8, 2)
self.axis = (-1, 1, -3)
class TestCase_ZeroDim(TestTransposeOp):
def setUp(self):
self.init_op_type()
self.initTestCase()
self.python_api = paddle.transpose
self.prim_op_type = "prim"
self.enable_cinn = False
self.inputs = {'X': np.random.random(self.shape).astype("float64")}
self.attrs = {
'axis': list(self.axis),
'use_mkldnn': self.use_mkldnn,
}
self.outputs = {
'XShape': np.random.random(self.shape).astype("float64"),
'Out': self.inputs['X'].transpose(self.axis),
}
def initTestCase(self):
self.shape = ()
self.axis = ()
......@@ -134,6 +167,7 @@ class TestAutoTuneTransposeOp(OpTest):
self.init_op_type()
self.initTestCase()
self.python_api = paddle.transpose
self.prim_op_type = "prim"
self.inputs = {'X': np.random.random(self.shape).astype("float64")}
self.attrs = {
'axis': list(self.axis),
......@@ -160,7 +194,7 @@ class TestAutoTuneTransposeOp(OpTest):
fluid.core.disable_autotune()
def test_check_grad(self):
self.check_grad(['X'], 'Out')
self.check_grad(['X'], 'Out', check_prim=True)
class TestAutoTuneTransposeBF16Op(OpTest):
......@@ -169,6 +203,8 @@ class TestAutoTuneTransposeBF16Op(OpTest):
self.initTestCase()
self.dtype = np.uint16
self.python_api = paddle.transpose
self.prim_op_type = "prim"
self.enable_cinn = False
x = np.random.random(self.shape).astype("float32")
self.inputs = {'X': convert_float_to_uint16(x)}
self.attrs = {
......@@ -198,7 +234,7 @@ class TestAutoTuneTransposeBF16Op(OpTest):
fluid.core.disable_autotune()
def test_check_grad(self):
self.check_grad(['X'], 'Out')
self.check_grad(['X'], 'Out', check_prim=True)
class TestTransposeBF16Op(OpTest):
......@@ -206,6 +242,8 @@ class TestTransposeBF16Op(OpTest):
self.init_op_type()
self.initTestCase()
self.dtype = np.uint16
self.prim_op_type = "prim"
self.enable_cinn = False
self.python_api = paddle.transpose
x = np.random.random(self.shape).astype("float32")
......
......@@ -99,6 +99,7 @@ def transpose(x, perm, name=None):
'float64',
'int32',
'int64',
'uint16',
'complex64',
'complex128',
],
......
......@@ -482,6 +482,7 @@ def transpose(x, perm, name=None):
'float64',
'int32',
'int64',
'uint16',
'complex64',
'complex128',
],
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册