Unverified commit 272ed912, authored by Charles-hit, committed by GitHub

support some prim ops zero dim part3 (#54919)

Parent 98debaa8
......@@ -937,7 +937,12 @@ void topk_grad(const Tensor& x,
const bool& sorted,
Tensor* x_grad) {
if (x_grad) {
auto zero_tensor = full<T>(phi::vectorize(x.dims()), 0.0, x.dtype());
// put_along_axis doesn't support zero dim
if (x.dims().size() == 0) {
by_pass<T>(out_grad, x_grad);
return;
}
auto zero_tensor = full<T>(phi::vectorize(x.dims()), 0, x.dtype());
auto x_grad_tmp = put_along_axis<T>(zero_tensor, indices, out_grad, axis);
set_output<T>(x_grad_tmp, x_grad);
}
......
......@@ -150,6 +150,16 @@ class TestSumOp1(OpTest):
self.out = self.x.cumsum(axis=2)
class TestSumOp1_ZeroDim(TestSumOp1):
    """Cumsum test variant using a zero-dimensional (scalar) input."""

    def set_attrs_input_output(self):
        """Use a 0-d input; cumsum of a scalar along axis 0 is the scalar itself."""
        scalar = np.random.random(()).astype(self.dtype_)
        self.attrs = {'axis': 0}
        self.x = scalar
        self.out = scalar

    def if_enable_cinn(self):
        # Disable the CINN backend for this 0-d case.
        self.enable_cinn = False
class TestSumOp2(TestSumOp1):
def set_attrs_input_output(self):
self.attrs = {'axis': -1, 'reverse': True}
......
......@@ -1031,6 +1031,17 @@ class Test1DReduce(OpTest):
self.check_grad(['X'], 'Out', check_prim=True)
class TestReduceSum_ZeroDim(Test1DReduce):
    """reduce_sum over a zero-dimensional (scalar) float64 input."""

    def setUp(self):
        """Configure the op under test with a 0-d input and its expected sum."""
        x = np.random.random(()).astype("float64")
        self.op_type = "reduce_sum"
        self.prim_op_type = "prim"
        self.python_api = paddle.sum
        self.public_python_api = paddle.sum
        self.inputs = {'X': x}
        # NOTE(review): relies on numpy accepting axis=0 reductions on 0-d
        # arrays, which yields the scalar itself — confirm against the
        # installed numpy version.
        self.outputs = {'Out': x.sum(axis=0)}
        self.if_enable_cinn()
class Test2DReduce0(Test1DReduce):
def setUp(self):
self.op_type = "reduce_sum"
......
......@@ -412,6 +412,48 @@ class TestSliceOp_starts_OneTensor_ends_ListTensor(OpTest):
self.check_grad(['Input'], 'Out', max_relative_error=0.006)
class TestSliceOp_ZeroDim(OpTest):
    """Slice op test where starts/ends are supplied as 0-d (scalar) int32
    tensors through StartsTensorList / EndsTensorList.

    The "zero dim" aspect under test is each list entry being a 0-d array
    (``np.array(1)`` / ``np.array(3)``), not the input tensor itself.
    """

    def setUp(self):
        self.op_type = "slice"
        self.python_api = slice_wrapper
        self.config()

        # One 0-d tensor per configured start/end. The values are fixed at
        # 1 and 3 to match self.starts / self.ends set by config().
        # (Idiom fix: the original looped with an unused element variable;
        # only the index is needed.)
        starts_tensor = [
            ("x" + str(index), np.array(1).astype('int32'))
            for index, _ in enumerate(self.starts)
        ]
        ends_tensor = [
            ("y" + str(index), np.array(3).astype('int32'))
            for index, _ in enumerate(self.ends)
        ]
        self.inputs = {
            'Input': self.input,
            "StartsTensorList": starts_tensor,
            'EndsTensorList': ends_tensor,
        }
        self.outputs = {'Out': self.out}
        self.attrs = {
            'axes': self.axes,
            'infer_flags': self.infer_flags,
        }

    def config(self):
        """Define the input tensor, slice parameters and expected output."""
        self.input = np.random.random([20, 3, 3]).astype("float64")
        self.starts = [1, 1]
        self.ends = [3, 3]
        self.axes = [1, 2]
        # presumably -1 marks a dimension to be inferred at runtime from the
        # tensor lists — TODO confirm against the slice op's attribute spec.
        self.infer_flags = [-1, -1]
        self.out = self.input[0:20, 1:3, 1:3]

    def test_check_output(self):
        self.check_output()

    def test_check_grad_normal(self):
        self.check_grad(['Input'], 'Out')
# Test CUDA float16
@unittest.skipIf(
not core.is_compiled_with_cuda(), "core is not compiled with CUDA"
......
......@@ -73,6 +73,30 @@ class TestTopkOp(OpTest):
self.check_grad(['X'], 'Out', check_prim=True)
class TestTopkOp_ZeroDim(TestTopkOp):
    """top_k_v2 with a zero-dimensional input: the top-1 of a scalar is the
    scalar itself, with a 0-d int64 index of 0."""

    def init_args(self):
        self.k = 1
        self.axis = 0
        self.largest = True

    def setUp(self):
        """Configure the op under test with a 0-d input and its expected output."""
        self.op_type = "top_k_v2"
        self.prim_op_type = "prim"
        self.python_api = paddle.topk
        self.public_python_api = paddle.topk
        self.dtype = np.float64
        self.input_data = np.random.random(())
        self.init_args()
        self.if_enable_cinn()
        self.inputs = {'X': self.input_data}
        self.attrs = {'k': self.k, 'largest': self.largest}
        # Expected value is the input itself; expected index is a 0-d zero.
        expected_value = self.input_data
        expected_index = np.array(0).astype('int64')
        self.outputs = {'Out': expected_value, 'Indices': expected_index}

    def if_enable_cinn(self):
        # Base case: leave CINN settings untouched; subclasses may override.
        pass
class TestTopkOp1(TestTopkOp):
def init_args(self):
self.k = 3
......
......@@ -42,6 +42,7 @@ class TestTransposeOp(OpTest):
'XShape': np.random.random(self.shape).astype("float64"),
'Out': self.inputs['X'].transpose(self.axis),
}
self.if_enable_cinn()
def init_op_type(self):
self.op_type = "transpose2"
......@@ -53,11 +54,23 @@ class TestTransposeOp(OpTest):
def test_check_grad(self):
    """Gradient check of Out w.r.t. X, also exercising the prim-op path."""
    self.check_grad(['X'], 'Out', check_prim=True)
def if_enable_cinn(self):
    """Hook for subclasses to toggle CINN; the base case changes nothing."""
    pass
def initTestCase(self):
    """Default case: a 2-D (3, 40) input transposed with axes (1, 0)."""
    self.shape, self.axis = (3, 40), (1, 0)
class TestTransposeOp_ZeroDim(TestTransposeOp):
    """Transpose of a zero-dimensional tensor (empty shape, empty axis)."""

    def initTestCase(self):
        self.shape, self.axis = (), ()

    def if_enable_cinn(self):
        # CINN is switched off for the 0-d case.
        self.enable_cinn = False
class TestCase0(TestTransposeOp):
def initTestCase(self):
self.shape = (100,)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
To comment, please register