Unverified commit 0b96793e, authored by xiongkun, committed by GitHub

[Dygraph Tests Fix] Fix some tests in the new dygraph final_state mode. (#41363)

* fix less_than

* fix some tests

* fix 3 additional unittest cases
Parent: e0d12b8d
@@ -132,6 +132,7 @@ PD_REGISTER_KERNEL(full_like,
                    phi::FullLikeKernel,
                    float,
                    double,
+                   uint8_t,
                    int16_t,
                    int,
                    int64_t,
...
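The hunk above registers uint8_t with the phi full_like kernel, so full-like creation ops can produce uint8 results. A minimal sanity check at the Python level (a sketch, assuming a build containing this commit):

```python
import paddle

# uint8 input tensor; before this change full_like had no uint8 kernel registered
x = paddle.to_tensor([1, 2, 3], dtype='uint8')
y = paddle.full_like(x, 7)
print(y.dtype)  # paddle.uint8
```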
@@ -452,11 +452,12 @@ def spawn(func, args=(), nprocs=-1, join=True, daemon=False, **options):
         def train(print_result=False):
             # 1. initialize parallel environment
-            dist.init_parallel_env()
+            group = dist.init_parallel_env()
+            process_group = group.process_group if group else None

             # 2. create data parallel layer & optimizer
             layer = LinearNet()
-            dp_layer = paddle.DataParallel(layer)
+            dp_layer = paddle.DataParallel(layer, process_group=process_group)

             loss_fn = nn.MSELoss()
             adam = opt.Adam(
...
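The docstring now threads the Group returned by dist.init_parallel_env into paddle.DataParallel via process_group, which the final_state (eager) mode needs. A condensed, runnable version of the docstring example (LinearNet, the random data, and the hyperparameters are stand-ins consistent with the surrounding docstring):

```python
import paddle
import paddle.nn as nn
import paddle.optimizer as opt
import paddle.distributed as dist

class LinearNet(nn.Layer):
    def __init__(self):
        super(LinearNet, self).__init__()
        self._linear = nn.Linear(10, 1)

    def forward(self, x):
        return self._linear(x)

def train(print_result=False):
    # 1. initialize parallel environment; eager mode returns a Group
    group = dist.init_parallel_env()
    process_group = group.process_group if group else None

    # 2. create data parallel layer & optimizer
    layer = LinearNet()
    dp_layer = paddle.DataParallel(layer, process_group=process_group)
    loss_fn = nn.MSELoss()
    adam = opt.Adam(learning_rate=0.001, parameters=dp_layer.parameters())

    # 3. run one training step
    inputs = paddle.randn([10, 10], 'float32')
    labels = paddle.randn([10, 1], 'float32')
    loss = loss_fn(dp_layer(inputs), labels)
    if print_result:
        print(loss.numpy())
    loss.backward()
    adam.step()
    adam.clear_grad()

if __name__ == '__main__':
    # spawn two worker processes, each running train()
    dist.spawn(train, args=(True,), nprocs=2)
```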
@@ -182,8 +182,8 @@ def equal(x, y, name=None):
         y = full(shape=[1], dtype=x.dtype, fill_value=y)

     if in_dygraph_mode():
-        axis = -1
-        return _C_ops.final_state_equal(x, y, axis)
+        default_axis = -1
+        return _C_ops.final_state_equal(x, y, default_axis)
     else:
         if _in_legacy_dygraph():
             return _C_ops.equal(x, y)
@@ -232,8 +232,8 @@ def greater_equal(x, y, name=None):
             print(result1) # result1 = [True False True]
     """
     if in_dygraph_mode():
-        axis = -1
-        return _C_ops.final_state_greater_equal(x, y, axis)
+        default_axis = -1
+        return _C_ops.final_state_greater_equal(x, y, default_axis)
     else:
         if _in_legacy_dygraph():
             return _C_ops.greater_equal(x, y)
@@ -383,8 +383,8 @@ def less_than(x, y, name=None):
             print(result1) # result1 = [False True False]
     """
     if in_dygraph_mode():
-        axis = -1
-        return _C_ops.final_state_less_than(x, y, axis)
+        default_axis = -1
+        return _C_ops.final_state_less_than(x, y, default_axis)
     else:
         if _in_legacy_dygraph():
             return _C_ops.less_than(x, y)
...
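All three comparison ops get the same mechanical fix: the final_state branch binds a local default_axis instead of reusing the name axis, and passes it to the generated eager op. Public behavior is unchanged, as the docstring examples already pin down:

```python
import paddle

x = paddle.to_tensor([1, 2, 3])
y = paddle.to_tensor([1, 3, 2])
print(paddle.equal(x, y))          # [True , False, False]
print(paddle.greater_equal(x, y))  # [True , False, True ]
print(paddle.less_than(x, y))      # [False, True , False]
```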
@@ -2668,6 +2668,7 @@ def cumsum(x, axis=None, dtype=None, name=None):
         x = cast(x, dtype)

     if in_dygraph_mode():
+        if axis is None: axis = -1
         return _C_ops.final_state_cumsum(x, axis, flatten, False, False)
     if _in_legacy_dygraph():
         if axis is None:
...
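In the final_state branch, axis=None is now normalized to -1 before calling the eager op (flatten has already been derived from axis being None earlier in the function), matching the legacy path. A quick check of both call shapes:

```python
import paddle

x = paddle.to_tensor([[1, 2], [3, 4]])
print(paddle.cumsum(x))          # axis=None: flatten, then accumulate -> [1, 3, 6, 10]
print(paddle.cumsum(x, axis=0))  # accumulate down the rows -> [[1, 2], [4, 6]]
```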
@@ -419,14 +419,14 @@
     func : cumprod
   backward : cumprod_grad

-# cumsum
 - api : cumsum
   args : (Tensor x, int axis, bool flatten, bool exclusive, bool reverse)
-  output : Tensor
+  output : Tensor(out)
   infer_meta :
     func : CumsumInferMeta
   kernel :
     func : cumsum
+  backward : cumsum_grad

 - api : depthwise_conv2d_transpose
   args : (Tensor x, Tensor filter, int[] strides, int[] paddings, int[] output_padding, int[] output_size, str padding_algorithm, int groups, int[] dilations, str data_format)
...
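This yaml entry drives the eager op code generator: naming the output Tensor(out) gives the backward spec something to refer to, and backward : cumsum_grad (defined below) attaches the gradient. For illustration, the args line implies the generated op takes the five attributes positionally; _C_ops.final_state_cumsum is an internal, version-specific symbol, so this sketch only applies to a development build containing this commit:

```python
import paddle
from paddle import _C_ops

x = paddle.to_tensor([1.0, 2.0, 3.0])
# (x, axis, flatten, exclusive, reverse) per the yaml args line; internal API
out = _C_ops.final_state_cumsum(x, -1, False, False, True)
print(out)  # reverse cumsum: [6., 5., 3.]
```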
@@ -286,6 +286,15 @@
   kernel :
     func : cumprod_grad

+- backward_api : cumsum_grad
+  forward : cumsum(Tensor x, int axis, bool flatten, bool exclusive, bool reverse) -> Tensor(out)
+  infer_meta :
+    func : UnchangedInferMeta
+    param : [x]
+  args : (Tensor out_grad, int axis, bool flatten, bool exclusive, bool reverse)
+  output : Tensor(x_grad)
+  invoke : cumsum(out_grad, axis, flatten, exclusive, !reverse)
+
 - backward_api : depthwise_conv2d_transpose_grad
   forward : depthwise_conv2d_transpose(Tensor x, Tensor filter, int[] strides, int[] paddings, int[] output_padding, int[] output_size, str padding_algorithm, int groups, int[] dilations, str data_format) -> Tensor(out)
   args : (Tensor x, Tensor filter, Tensor out_grad, int[] strides, int[] paddings, int[] output_padding, int[] output_size, str padding_algorithm, int groups, int[] dilations, str data_format)
...
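The invoke line is the entire gradient definition: cumsum is linear, so its backward is just another cumsum over out_grad running in the opposite direction, which !reverse expresses by flipping the direction flag. A small check through the public API:

```python
import paddle

x = paddle.to_tensor([1.0, 2.0, 3.0], stop_gradient=False)
y = paddle.cumsum(x)  # [1., 3., 6.]
y.sum().backward()
# x[0] contributes to all three prefix sums, x[2] only to the last,
# so the gradient of sum(cumsum(x)) is the reversed cumsum of ones
print(x.grad)  # [3., 2., 1.]
```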