From e4bf1a8ed97ff1e2808f4f19f18d4f19981fd687 Mon Sep 17 00:00:00 2001
From: zhwesky2010 <1183042833@qq.com>
Date: Tue, 9 May 2023 00:23:02 +0800
Subject: [PATCH] [Zero-Dim] Support p_norm/reduce_sum_p output 0D (#53421)

---
 .../operators/prim_ops/reduce_sum_p_op.cc     |  3 --
 paddle/phi/infermeta/unary.cc                 | 36 ++++++++++---------
 .../operators/dist_reduce_sum_p.py            |  2 +-
 .../tests/unittests/test_pairwise_distance.py |  4 ---
 test/auto_parallel/test_dist_pnorm.py         | 12 ++++++-
 test/auto_parallel/test_prim_dist_op.py       |  3 +-
 test/autograd/test_primops.py                 |  2 +-
 7 files changed, 33 insertions(+), 29 deletions(-)

diff --git a/paddle/fluid/operators/prim_ops/reduce_sum_p_op.cc b/paddle/fluid/operators/prim_ops/reduce_sum_p_op.cc
index b31b4934706..89754a7dbfc 100644
--- a/paddle/fluid/operators/prim_ops/reduce_sum_p_op.cc
+++ b/paddle/fluid/operators/prim_ops/reduce_sum_p_op.cc
@@ -79,9 +79,6 @@ class ReduceSumPrimOpShapeInference : public framework::InferShapeBase {
     x_shape.erase(remove(x_shape.begin(), x_shape.end(), kDelFlag),
                   x_shape.end());
   }
-  if (!keepdim && x_shape.size() == 0) {
-    x_shape.push_back(1);
-  }
 
   PADDLE_GET(framework::VarDesc *, y_var_ptr)->SetShape(x_shape);
 }
diff --git a/paddle/phi/infermeta/unary.cc b/paddle/phi/infermeta/unary.cc
index 1d1d8333982..42b99dee7cb 100644
--- a/paddle/phi/infermeta/unary.cc
+++ b/paddle/phi/infermeta/unary.cc
@@ -2750,31 +2750,33 @@ void PNormInferMeta(const MetaTensor& x,
           x_rank,
           x_dim));
 
-  std::vector<int64_t> reduce_dims;
+  std::vector<int64_t> out_dim_vector;
   if (asvector) {
-    reduce_dims.emplace_back(1);
     if (keepdim) {
-      for (int i = 1; i < x_dim.size(); ++i) {
-        reduce_dims.emplace_back(1);
+      for (int i = 0; i < x_rank; ++i) {
+        out_dim_vector.emplace_back(1);
       }
-      x_dim = phi::make_ddim(reduce_dims);
+    } else {
+      out_dim_vector = {};
     }
   } else {
-    if (axis < 0) axis = x_dim.size() + axis;
-    for (int i = 0; i < x_dim.size(); ++i) {
-      if (i != axis) reduce_dims.emplace_back(x_dim[i]);
-    }
-    if (reduce_dims.size() == 0) {
-      reduce_dims.emplace_back(1);
+    if (axis < 0) axis = axis + x_rank;
+    if (keepdim) {
+      for (int i = 0; i < x_dim.size(); ++i) {
+        if (i != axis) {
+          out_dim_vector.emplace_back(x_dim[i]);
+        } else {
+          out_dim_vector.emplace_back(1);
+        }
+      }
+    } else {
+      for (int i = 0; i < x_dim.size(); ++i) {
+        if (i != axis) out_dim_vector.emplace_back(x_dim[i]);
+      }
     }
-    x_dim[axis] = 1;
   }
 
-  if (keepdim) {
-    out->set_dims(x_dim);
-  } else {
-    out->set_dims(phi::make_ddim(reduce_dims));
-  }
+  out->set_dims(phi::make_ddim(out_dim_vector));
   out->set_dtype(x.dtype());
 }
 
diff --git a/python/paddle/distributed/auto_parallel/operators/dist_reduce_sum_p.py b/python/paddle/distributed/auto_parallel/operators/dist_reduce_sum_p.py
index 50a4d3466b0..919cad3d83e 100644
--- a/python/paddle/distributed/auto_parallel/operators/dist_reduce_sum_p.py
+++ b/python/paddle/distributed/auto_parallel/operators/dist_reduce_sum_p.py
@@ -58,7 +58,7 @@ class DistributedReduceSumPrimtiveImpl0(DistributedOperatorImpl):
         output_name = outputs[0]
         output_var = dist_op.serial_op.block._var_recursive(output_name)
 
-        if output_var.shape != (1,):
+        if output_var.shape != ():
             return False
 
         return True
diff --git a/python/paddle/fluid/tests/unittests/test_pairwise_distance.py b/python/paddle/fluid/tests/unittests/test_pairwise_distance.py
index b8e2c51a183..e89a713282d 100644
--- a/python/paddle/fluid/tests/unittests/test_pairwise_distance.py
+++ b/python/paddle/fluid/tests/unittests/test_pairwise_distance.py
@@ -22,10 +22,6 @@ from paddle import fluid
 
 def np_pairwise_distance(x, y, p=2.0, epsilon=1e-6, keepdim=False):
     distance = np.linalg.norm(x - y + epsilon, ord=p, axis=-1, keepdims=keepdim)
-    # Paddle currently has not supported for 0-d Tensors, so even if keep_dim is False,
-    # and neither x nor y is batched, a Tensor of shape (1, ) is returned
-    if distance.ndim == 0:
-        distance = np.expand_dims(distance, axis=0)
     return distance
 
 
diff --git a/test/auto_parallel/test_dist_pnorm.py b/test/auto_parallel/test_dist_pnorm.py
index deb7c8f7156..5ff30d27b6d 100644
--- a/test/auto_parallel/test_dist_pnorm.py
+++ b/test/auto_parallel/test_dist_pnorm.py
@@ -121,9 +121,19 @@ class TestDistPNormDP(TestDistPNorm):
             op_dist_attr = self.dist_context.get_op_dist_attr_for_program(op)
             if op.type == "p_norm":
                 assert op_dist_attr.impl_type == "p_norm"
-            if op.type in ["p_norm", "p_norm_grad"]:
                 for input_attr in op_dist_attr.inputs_dist_attrs.values():
                     assert set(input_attr.dims_mapping) == {-1}
+                for output_attr in op_dist_attr.outputs_dist_attrs.values():
+                    if len(output_attr.dims_mapping) == 0:
+                        assert output_attr.dims_mapping == []
+                    else:
+                        assert set(output_attr.dims_mapping) == {-1}
+            if op.type == "p_norm_grad":
+                for input_attr in op_dist_attr.inputs_dist_attrs.values():
+                    if len(input_attr.dims_mapping) == 0:
+                        assert input_attr.dims_mapping == []
+                    else:
+                        assert set(input_attr.dims_mapping) == {-1}
                 for output_attr in op_dist_attr.outputs_dist_attrs.values():
                     assert set(output_attr.dims_mapping) == {-1}
             if op.type == 'c_allgather':
diff --git a/test/auto_parallel/test_prim_dist_op.py b/test/auto_parallel/test_prim_dist_op.py
index 0fc153a3a33..5a4a1b5a512 100644
--- a/test/auto_parallel/test_prim_dist_op.py
+++ b/test/auto_parallel/test_prim_dist_op.py
@@ -55,7 +55,7 @@ class TestPrimDistOp(unittest.TestCase):
         self.tmp1 = paddle.static.data(name='tmp1', shape=[20], dtype='float')
         self.tmp2 = paddle.static.data(name='tmp2', shape=[20], dtype='float')
         self.batch_reduced = paddle.static.data(
-            name='batch_reduced', shape=[1], dtype='float'
+            name='batch_reduced', shape=[], dtype='float'
         )
 
         self.attrs = {}
@@ -108,7 +108,6 @@ class TestPrimDistOp(unittest.TestCase):
             self.main_program, self.startup_program, [(self.w, self.w_grad)]
         )
         ops = dist_main_prog.global_block().ops
-        self.assertTrue(ops[1].type == "c_allreduce_sum")
         self.assertTrue(ops[3].type == "c_allreduce_sum")
 
 
diff --git a/test/autograd/test_primops.py b/test/autograd/test_primops.py
index 65f818fa054..f8fa82aea37 100644
--- a/test/autograd/test_primops.py
+++ b/test/autograd/test_primops.py
@@ -110,7 +110,7 @@ paddle.enable_static()
         primops.reduce_sum,
         randn(2, 3),
         {'axis': (0, 1)},
-        (),
-        (1,),
+        (),
         'float64',
     ),
     (
-- 
GitLab
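
Note: for readers of this patch, a minimal sketch of the user-visible effect (an illustration under assumptions, not part of the patch: it assumes a Paddle build that includes this change, and that paddle.linalg.norm with axis=None dispatches to the p_norm kernel whose inferred output shape is updated above):

import paddle

x = paddle.rand([3, 4])

# Full reduction (asvector path) with keepdim=False: the norm is now a
# 0-D tensor (shape []) rather than the old shape-[1] tensor.
out = paddle.linalg.norm(x, p=2)
print(out.shape)  # [] with this patch; previously [1]

# keepdim=True still produces one size-1 dim per input axis.
out_keep = paddle.linalg.norm(x, p=2, keepdim=True)
print(out_keep.shape)  # [1, 1]

# Full-axis sum likewise yields a 0-D tensor; the static prim op
# reduce_sum_p above now infers the matching 0-D shape.
print(paddle.sum(x).shape)  # []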