Unverified commit 24b2cc8d, authored by Charles-hit, committed by GitHub

support cast, dropout, gather, mean prim ops zero dim (#54966)

Parent 12a296cb
@@ -150,11 +150,20 @@ void gather_grad(const Tensor& x,
   }
   // transpose out_grad and zero grad to target rank.
-  auto tmp_zero_x_grad = transpose<T>(zero_tensor, tmp_perm);
-  auto tmp_out_grad = transpose<T>(out_grad, tmp_perm);
+  auto tmp_zero_x_grad = zero_tensor;
+  auto tmp_out_grad = out_grad;
+  if (zero_tensor.dims().size() > 0) {
+    tmp_zero_x_grad = transpose<T>(zero_tensor, tmp_perm);
+  }
+  if (out_grad.dims().size() > 0) {
+    tmp_out_grad = transpose<T>(out_grad, tmp_perm);
+  }
   // scatter grad to grad_x
   auto tmp_grad_x = scatter<T>(tmp_zero_x_grad, index, tmp_out_grad, false);
-  auto tmp_grad_x_tranposed = transpose<T>(tmp_grad_x, reverse_perm);
+  auto tmp_grad_x_tranposed = tmp_grad_x;
+  if (tmp_grad_x.dims().size() > 0) {
+    tmp_grad_x_tranposed = transpose<T>(tmp_grad_x, reverse_perm);
+  }
   set_output<T>(tmp_grad_x_tranposed, grad_x);
 }
......
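The new rank checks skip the transposes whenever the tensor involved is 0-D. This is not the prim C++ API itself, just a numpy sketch of the failure those checks avoid: a permutation computed for a rank >= 1 target cannot be applied to a rank-0 gradient, so the only safe option is to pass a 0-D tensor through unchanged.

    import numpy as np

    # 0-D gradient, e.g. what gather produces for a 0-D index
    out_grad = np.array(3.0)
    # permutation computed for a rank-1 target
    perm = [0]
    try:
        np.transpose(out_grad, perm)
    except ValueError as err:
        print("rank-0 array cannot take a rank-1 perm:", err)
    # rank >= 1 gradients are transposed exactly as before
    print(np.transpose(np.zeros((4, 3)), [1, 0]).shape)  # (3, 4)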
@@ -255,9 +255,11 @@ def mean_composite(x, axis, keepdim):
     axes = axis or list(range(0, len(x.shape)))
     axes = [axes] if isinstance(axes, int) else axes
     sum_x = sum(x, axis=axes, keepdim=keepdim)
-    value_to_fill = functools.reduce(
-        operator.mul, [x.shape[axis] for axis in axes]
-    )
+    ele_nums_list = [x.shape[axis] for axis in axes]
+    if ele_nums_list == []:
+        value_to_fill = 1
+    else:
+        value_to_fill = functools.reduce(operator.mul, ele_nums_list)
     norm = fill_constant(
         shape=[],
         value=value_to_fill,
......
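The extra branch covers the 0-D case: for a scalar input the axis list is empty, and functools.reduce over an empty list raises a TypeError when no initial value is supplied, so the divisor falls back to 1 (the mean of a scalar is the scalar itself). A minimal standalone sketch of that edge case:

    import functools
    import operator

    axes = []           # what mean_composite receives for a 0-D input
    ele_nums_list = []  # [x.shape[a] for a in axes] is empty as well
    # functools.reduce with no initial value cannot handle an empty list:
    try:
        functools.reduce(operator.mul, ele_nums_list)
    except TypeError as err:
        print(err)
    # the patched rule falls back to 1, so mean(x) == x for a scalar x
    value_to_fill = (
        1 if ele_nums_list == [] else functools.reduce(operator.mul, ele_nums_list)
    )
    print(value_to_fill)  # 1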
@@ -34,7 +34,8 @@ def cast_wrapper(x, out_dtype=None):
 class TestCastOpFp32ToFp64(OpTest):
     def setUp(self):
-        ipt = np.random.random(size=[10, 10])
+        self.init_shapes()
+        ipt = np.random.random(size=self.input_shape)
         self.inputs = {'X': ipt.astype('float32')}
         self.outputs = {'Out': ipt.astype('float64')}
         self.attrs = {
@@ -46,6 +47,9 @@ class TestCastOpFp32ToFp64(OpTest):
         self.python_api = cast_wrapper
         self.public_python_api = cast_wrapper
 
+    def init_shapes(self):
+        self.input_shape = [10, 10]
+
     def test_check_output(self):
         self.check_output()
@@ -53,6 +57,11 @@ class TestCastOpFp32ToFp64(OpTest):
         self.check_grad(['X'], ['Out'], check_prim=True)
 
 
+class TestCastOpFp32ToFp64_ZeroDim(TestCastOpFp32ToFp64):
+    def init_shapes(self):
+        self.input_shape = ()
+
+
 class TestCastOpFp16ToFp32(OpTest):
     def setUp(self):
         ipt = np.random.random(size=[10, 10])
......
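For reference, a small dygraph sketch of what TestCastOpFp32ToFp64_ZeroDim exercises, assuming a Paddle build with 0-D tensor support (the literal value here is illustrative only): cast changes the dtype while preserving the 0-D shape.

    import paddle

    x = paddle.to_tensor(1.5, dtype='float32')  # 0-D tensor
    y = paddle.cast(x, 'float64')
    print(y.shape, y.dtype)  # expected: [] and float64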
@@ -87,6 +87,24 @@ class TestDropoutOp(OpTest):
         self.check_grad(['X'], 'Out', check_prim=False)
 
 
+class TestDropoutOp_ZeroDim(TestDropoutOp):
+    def setUp(self):
+        self.op_type = "dropout"
+        self.prim_op_type = "comp"
+        self.python_api = dropout_wapper
+        self.public_python_api = prim_dropout_wrapper
+        self.inputs = {'X': np.random.random(()).astype("float32")}
+        self.attrs = {'dropout_prob': 0.0, 'fix_seed': True, 'is_test': False}
+        self.outputs = {
+            'Out': self.inputs['X'],
+            'Mask': np.ones(()).astype('uint8'),
+        }
+        # The prim op compares results with dygraph: when p = 0 the dropout
+        # api returns x, so in dygraph mode x_grad = out_grad,
+        # but in static mode x_grad = [].
+        self.enable_check_static_comp = False
+
+
 class TestDropoutOpInput1d(OpTest):
     def setUp(self):
         self.op_type = "dropout"
@@ -126,6 +144,20 @@ class TestDropoutOp2(TestDropoutOp):
         }
 
 
+class TestDropoutOp2_ZeroDim(TestDropoutOp2):
+    def setUp(self):
+        self.op_type = "dropout"
+        self.python_api = dropout_wapper
+        self.public_python_api = prim_dropout_wrapper
+        self.prim_op_type = "comp"
+        self.inputs = {'X': np.random.random(()).astype("float32")}
+        self.attrs = {'dropout_prob': 1.0, 'fix_seed': True, 'is_test': False}
+        self.outputs = {
+            'Out': np.zeros(()).astype('float32'),
+            'Mask': np.zeros(()).astype('uint8'),
+        }
+
+
 class TestDropoutOp3(TestDropoutOp):
     def setUp(self):
         self.op_type = "dropout"
......
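A rough dygraph sketch of the two new cases, assuming a Paddle build with 0-D tensor support: with p=0.0 dropout is the identity, with p=1.0 it zeroes the input, and the 0-D shape is preserved either way.

    import paddle
    import paddle.nn.functional as F

    x = paddle.to_tensor(0.7, dtype='float32')  # 0-D input
    keep = F.dropout(x, p=0.0, training=True)   # identity, as in TestDropoutOp_ZeroDim
    drop = F.dropout(x, p=1.0, training=True)   # all zeros, as in TestDropoutOp2_ZeroDim
    print(keep.shape, drop.shape)               # expected: [] []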
@@ -70,6 +70,20 @@ class TestGatherOp(OpTest):
         pass
 
 
+class TestGatherOp_ZeroDim(TestGatherOp):
+    def config(self):
+        """
+        For multi-dimension input
+        """
+        self.x_shape = 100
+        self.config_dtype()
+        self.index = 2
+        self.index_type = "int32"
+
+    def if_enable_cinn(self):
+        self.enable_cinn = False
+
+
 class TestGatherOpFP16(TestGatherOp):
     def config_dtype(self):
         self.x_type = "float16"
......
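TestGatherOp_ZeroDim feeds a rank-1 x and a 0-D index. A hedged dygraph sketch of the behaviour it pins down, assuming a build with 0-D tensor support: the gathered output is 0-D, and the gradient still flows back to the full rank-1 x (which is what the guarded transposes in gather_grad above make possible).

    import paddle

    x = paddle.rand([100])
    x.stop_gradient = False
    index = paddle.to_tensor(2, dtype='int32')  # 0-D index, mirroring the test config
    out = paddle.gather(x, index)               # expected to be 0-D
    out.backward()
    print(out.shape, x.grad.shape)              # expected: [] and [100]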
@@ -152,17 +152,20 @@ class TestReduceMeanOp(OpTest):
         self.public_python_api = reduce_mean_wrapper
         self.prim_op_type = "comp"
         self.dtype = 'float64'
-        self.shape = [2, 3, 4, 5]
+        self.init_shapes()
         self.axis = [0]
+        if self.shape == []:
+            self.axis = []
         self.keepdim = False
         self.set_attrs()
         self.if_enable_cinn()
 
         np.random.seed(10)
         x_np = np.random.uniform(-1, 1, self.shape).astype(self.dtype)
-        if not hasattr(self, "reduce_all"):
+        if not hasattr(self, "reduce_all") and not x_np.shape == ():
             self.reduce_all = (not self.axis) or len(self.axis) == len(x_np)
+        if x_np.shape == ():
+            self.reduce_all = True
         out_np = ref_reduce_mean(x_np, self.axis, self.keepdim, self.reduce_all)
         self.inputs = {'X': x_np}
         self.outputs = {'Out': out_np}
@@ -172,6 +175,9 @@ class TestReduceMeanOp(OpTest):
             'reduce_all': self.reduce_all,
         }
 
+    def init_shapes(self):
+        self.shape = [2, 3, 4, 5]
+
     def set_attrs(self):
         pass
@@ -195,6 +201,12 @@ class TestReduceMeanOp(OpTest):
         )
 
 
+class TestReduceMeanOp_ZeroDim(TestReduceMeanOp):
+    def init_shapes(self):
+        self.shape = []
+        self.enable_cinn = False
+
+
 @unittest.skipIf(
     not core.is_compiled_with_cuda()
     or not core.is_bfloat16_supported(core.CUDAPlace(0)),
......
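Finally, a hedged dygraph sketch of the reduce_mean 0-D case covered by TestReduceMeanOp_ZeroDim, assuming a build with 0-D tensor support: with an empty axis list and reduce_all forced on, the mean of a 0-D tensor is the tensor itself, and the gradient is also 0-D.

    import paddle

    x = paddle.to_tensor(2.5, dtype='float64')
    x.stop_gradient = False
    y = paddle.mean(x)  # reduces over "all" axes, i.e. nothing for a 0-D input
    y.backward()
    print(y.shape, x.grad.shape)  # expected: [] and []; y equals x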