From d4451cb049b767972106465fc2861075a09248f8 Mon Sep 17 00:00:00 2001
From: zhwesky2010 <1183042833@qq.com>
Date: Thu, 1 Jun 2023 19:34:54 +0800
Subject: [PATCH] [Zero-Dim] OpTest support shape check and fix previous case
 problem (#54117)

---
 paddle/fluid/operators/l1_norm_op.cc | 2 +-
 paddle/fluid/operators/mean_iou_op.cc | 4 +-
 paddle/fluid/operators/mean_iou_op.cu | 4 +-
 paddle/fluid/operators/mean_iou_op.h | 10 +-
 .../operators/reduce_ops/reduce_mean_op.cc | 7 +-
 paddle/phi/api/yaml/legacy_ops.yaml | 2 +-
 paddle/phi/infermeta/binary.cc | 2 +-
 paddle/phi/infermeta/unary.cc | 223 ++++--------------
 paddle/phi/infermeta/unary.h | 13 -
 .../phi/kernels/impl/multi_dot_kernel_impl.h | 5 +
 .../elementwise_multiply_kernel.cc | 11 +-
 .../collective_global_gather_dygraph.py | 2 +-
 .../collective_global_scatter_dygraph.py | 2 +-
 .../fleet/hybrid_parallel_mp_layers.py | 4 +-
 ...parallel_dygraph_no_sync_gradient_check.py | 8 +-
 .../fleet/parallel_margin_cross_entropy.py | 4 +-
 test/legacy_test/eager_op_test.py | 129 +++++-----
 ...allel_dygraph_dataparallel_with_pylayer.py | 8 +-
 .../parallel_dygraph_gradient_check.py | 8 +-
 ...el_dygraph_gradient_check_in_eager_mode.py | 8 +-
 test/legacy_test/test_accuracy_op.py | 12 +-
 test/legacy_test/test_adam_op.py | 4 +-
 test/legacy_test/test_allclose_op.py | 16 +-
 test/legacy_test/test_arg_min_max_v2_op.py | 2 +-
 test/legacy_test/test_auc_op.py | 7 +-
 test/legacy_test/test_auc_single_pred_op.py | 4 +-
 test/legacy_test/test_cross_entropy_op.py | 16 +-
 test/legacy_test/test_ctc_align.py | 2 +-
 test/legacy_test/test_data_norm_op.py | 20 +-
 test/legacy_test/test_edit_distance_op.py | 6 +-
 test/legacy_test/test_elementwise_div_op.py | 2 +-
 test/legacy_test/test_elementwise_sub_op.py | 2 +-
 test/legacy_test/test_fake_quantize_op.py | 27 +--
 test/legacy_test/test_fused_adam_op.py | 8 +-
 test/legacy_test/test_gather_nd_op.py | 4 +-
 .../test_generate_proposal_labels_op.py | 2 +-
 test/legacy_test/test_is_empty_op.py | 4 +-
 test/legacy_test/test_isclose_op.py | 16 +-
 test/legacy_test/test_isfinite_op.py | 12 +-
 test/legacy_test/test_linspace.py | 4 +-
 test/legacy_test/test_logspace.py | 2 +-
 test/legacy_test/test_mean_iou.py | 42 ++--
 test/legacy_test/test_multiclass_nms_op.py | 12 +-
 test/legacy_test/test_nll_loss.py | 12 +-
 test/legacy_test/test_numel_op.py | 4 +-
 .../test_positive_negative_pair_op.py | 6 +-
 test/legacy_test/test_reduce_op.py | 54 +++--
 test/legacy_test/test_scatter_op.py | 6 +-
 test/legacy_test/test_seed_op.py | 4 +-
 test/legacy_test/test_segment_ops.py | 8 +
 test/legacy_test/test_slice_op.py | 4 +-
 test/legacy_test/test_squared_l2_norm_op.py | 2 +-
 test/legacy_test/test_unbind_op.py | 22 +-
 ...allel_dygraph_dataparallel_with_pylayer.py | 8 +-
 test/xpu/parallel_dygraph_gradient_check.py | 8 +-
 ...el_dygraph_gradient_check_in_eager_mode.py | 8 +-
 56 files changed, 366 insertions(+), 462 deletions(-)

diff --git a/paddle/fluid/operators/l1_norm_op.cc b/paddle/fluid/operators/l1_norm_op.cc
index 08a0b894c9f..92f190c0025 100644
--- a/paddle/fluid/operators/l1_norm_op.cc
+++ b/paddle/fluid/operators/l1_norm_op.cc
@@ -27,7 +27,7 @@ class L1NormOp : public framework::OperatorWithKernel {
    OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "L1NormOp");
    OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "L1NormOp");

-    ctx->SetOutputDim("Out", {1});
+    ctx->SetOutputDim("Out", phi::make_ddim({}));
   }
 };

diff --git a/paddle/fluid/operators/mean_iou_op.cc b/paddle/fluid/operators/mean_iou_op.cc
index 27fd86fab08..d87c49187c2 100644
---
a/paddle/fluid/operators/mean_iou_op.cc +++ b/paddle/fluid/operators/mean_iou_op.cc @@ -34,7 +34,7 @@ class MeanIoUOp : public framework::OperatorWithKernel { int64_t num_classes = static_cast(ctx->Attrs().Get("num_classes")); - ctx->SetOutputDim("OutMeanIou", {1}); + ctx->SetOutputDim("OutMeanIou", phi::make_ddim({})); ctx->SetOutputDim("OutWrong", {num_classes}); ctx->SetOutputDim("OutCorrect", {num_classes}); } @@ -78,7 +78,7 @@ class MeanIoUOpMaker : public framework::OpProtoAndCheckerMaker { .AsDispensable(); AddOutput("OutMeanIou", "(vector), A Tensor representing the" - " mean intersection-over-union with shape [1]."); + " mean intersection-over-union with shape []."); AddOutput("OutWrong", "(Tensor), A Tensor with shape [num_classes]. "); AddOutput("OutCorrect", "(Tensor), A Tensor with shape [num_classes]. "); AddAttr("num_classes", "(int), The possible number of labels."); diff --git a/paddle/fluid/operators/mean_iou_op.cu b/paddle/fluid/operators/mean_iou_op.cu index 1dbc9f6fdc8..46abb4b7291 100644 --- a/paddle/fluid/operators/mean_iou_op.cu +++ b/paddle/fluid/operators/mean_iou_op.cu @@ -111,7 +111,7 @@ class MeanIoUCUDAOpKernel : public framework::OpKernel { out_mean_iou->mutable_data(ctx.GetPlace()); // Get Eigen tensor - auto out_mean_iou_t = EigenTensor::From(*out_mean_iou); + auto out_mean_iou_t = EigenScalar::From(*out_mean_iou); auto out_wrong_t = EigenTensor::From(*out_wrong); auto out_correct_t = EigenTensor::From(*out_correct); @@ -131,7 +131,7 @@ class MeanIoUCUDAOpKernel : public framework::OpKernel { auto in_mean_ious = ctx.MultiInput("InMeanIou"); for (int i = 0; i < in_mean_ious.size(); ++i) { out_mean_iou_t.device(place) += - EigenTensor::From(*in_mean_ious[i]); + EigenScalar::From(*in_mean_ious[i]); } auto in_wrongs = ctx.MultiInput("InWrongs"); for (int i = 0; i < in_wrongs.size(); ++i) { diff --git a/paddle/fluid/operators/mean_iou_op.h b/paddle/fluid/operators/mean_iou_op.h index 436cf84a548..8569d567c8f 100644 --- a/paddle/fluid/operators/mean_iou_op.h +++ b/paddle/fluid/operators/mean_iou_op.h @@ -27,6 +27,11 @@ template using EigenTensor = framework::EigenTensor; +template +using EigenScalar = framework::EigenScalar; + template class MeanIoUKernel : public framework::OpKernel { public: @@ -50,7 +55,7 @@ class MeanIoUKernel : public framework::OpKernel { int* out_correct_data = out_correct->mutable_data(ctx.GetPlace()); // get eigen tensor - auto out_mean_iou_t = EigenTensor::From(*out_mean_iou); + auto out_mean_iou_t = EigenScalar::From(*out_mean_iou); auto out_wrong_t = EigenTensor::From(*out_wrong); auto out_correct_t = EigenTensor::From(*out_correct); @@ -79,8 +84,9 @@ class MeanIoUKernel : public framework::OpKernel { auto in_mean_ious = ctx.MultiInput("InMeanIou"); for (size_t i = 0; i < in_mean_ious.size(); ++i) { out_mean_iou_t.device(place) += - EigenTensor::From(*in_mean_ious[i]); + EigenScalar::From(*in_mean_ious[i]); } + auto in_wrongs = ctx.MultiInput("InWrongs"); for (size_t i = 0; i < in_wrongs.size(); ++i) { out_wrong_t.device(place) += EigenTensor::From(*in_wrongs[i]); diff --git a/paddle/fluid/operators/reduce_ops/reduce_mean_op.cc b/paddle/fluid/operators/reduce_ops/reduce_mean_op.cc index 2c95cf8bd1f..0048ec1e724 100644 --- a/paddle/fluid/operators/reduce_ops/reduce_mean_op.cc +++ b/paddle/fluid/operators/reduce_ops/reduce_mean_op.cc @@ -98,9 +98,10 @@ class __reduce_meanMaker__ : public ops::ReduceBaseOpMaker { virtual std::string GetOpType() const { return "Reduce reduce_mean"; } }; -DECLARE_INFER_SHAPE_FUNCTOR(reduce_mean, - 
ReduceMeanInferShapeFunctor, - PD_INFER_META(phi::OriginReduceInferMetaBase)); +DECLARE_INFER_SHAPE_FUNCTOR( + reduce_mean, + ReduceMeanInferShapeFunctor, + PD_INFER_META(phi::ReduceIntArrayAxisInferMetaBase)); REGISTER_OPERATOR(reduce_mean, ops::ReduceBaseOp, diff --git a/paddle/phi/api/yaml/legacy_ops.yaml b/paddle/phi/api/yaml/legacy_ops.yaml index 9ab4103b123..50158c513fd 100755 --- a/paddle/phi/api/yaml/legacy_ops.yaml +++ b/paddle/phi/api/yaml/legacy_ops.yaml @@ -633,7 +633,7 @@ args : (Tensor x, IntArray axis={}, bool keepdim=false) output : Tensor(out) infer_meta : - func : OriginReduceInferMeta + func : ReduceIntArrayAxisInferMeta kernel : func : mean backward : mean_grad diff --git a/paddle/phi/infermeta/binary.cc b/paddle/phi/infermeta/binary.cc index 87e10cb6252..09f199e4be5 100644 --- a/paddle/phi/infermeta/binary.cc +++ b/paddle/phi/infermeta/binary.cc @@ -130,7 +130,7 @@ void KLDivInferMeta(const MetaTensor& x, if ("none" == reduction) { out->set_dims(dim_x); } else { - out->set_dims({1}); + out->set_dims(phi::make_ddim({})); } out->set_dtype(x.dtype()); } diff --git a/paddle/phi/infermeta/unary.cc b/paddle/phi/infermeta/unary.cc index 92cf654aee8..e43e945f375 100644 --- a/paddle/phi/infermeta/unary.cc +++ b/paddle/phi/infermeta/unary.cc @@ -1659,7 +1659,7 @@ void IdentityLossInferMeta(const MetaTensor& x, out->set_dtype(x.dtype()); out->set_dims(x.dims()); } else { - out->set_dims(phi::make_ddim({1})); + out->set_dims(phi::make_ddim({})); out->set_dtype(x.dtype()); } } @@ -3069,6 +3069,28 @@ DDim ReduceInferDim(const MetaTensor& x, return out_dim; } +void ReduceInferMetaBase(const MetaTensor& x, + const std::vector& axis, + bool keep_dim, + bool reduce_all, + MetaTensor* out) { + DDim out_dim = ReduceInferDim(x, axis, keep_dim, reduce_all); + out->set_dims(out_dim); + out->set_dtype(x.dtype()); + out->set_layout(x.layout()); +} + +void ReduceInferMeta(const MetaTensor& x, + const std::vector& axis, + bool keep_dim, + MetaTensor* out) { + bool reduce_all = false; + if (axis.size() == 0) { + reduce_all = true; + } + ReduceInferMetaBase(x, axis, keep_dim, reduce_all, out); +} + DDim ReduceInferDimForIntArrayAxis(const MetaTensor& x, const IntArray& axis, bool keep_dim, @@ -3096,23 +3118,18 @@ DDim ReduceInferDimForIntArrayAxis(const MetaTensor& x, return phi::make_ddim(vec_dim); } -void ReduceInferMeta(const MetaTensor& x, - const std::vector& axis, - bool keep_dim, - MetaTensor* out) { - bool reduce_all = false; - if (axis.size() == 0) { - reduce_all = true; +void ReduceIntArrayAxisInferMetaBase(const MetaTensor& x, + const IntArray& axis, + bool keep_dim, + bool reduce_all, + MetaTensor* out, + MetaConfig config) { + DDim out_dim; + if (config.is_runtime || !axis.FromTensor()) { + out_dim = ReduceInferDim(x, axis.GetData(), keep_dim, reduce_all); + } else { + out_dim = ReduceInferDimForIntArrayAxis(x, axis, keep_dim, reduce_all); } - ReduceInferMetaBase(x, axis, keep_dim, reduce_all, out); -} - -void ReduceInferMetaBase(const MetaTensor& x, - const std::vector& axis, - bool keep_dim, - bool reduce_all, - MetaTensor* out) { - DDim out_dim = ReduceInferDim(x, axis, keep_dim, reduce_all); out->set_dims(out_dim); out->set_dtype(x.dtype()); out->set_layout(x.layout()); @@ -3130,23 +3147,6 @@ void ReduceIntArrayAxisInferMeta(const MetaTensor& x, ReduceIntArrayAxisInferMetaBase(x, axis, keep_dim, reduce_all, out, config); } -void ReduceIntArrayAxisInferMetaBase(const MetaTensor& x, - const IntArray& axis, - bool keep_dim, - bool reduce_all, - MetaTensor* out, - MetaConfig config) { 
- DDim out_dim; - if (config.is_runtime || !axis.FromTensor()) { - out_dim = ReduceInferDim(x, axis.GetData(), keep_dim, reduce_all); - } else { - out_dim = ReduceInferDimForIntArrayAxis(x, axis, keep_dim, reduce_all); - } - out->set_dims(out_dim); - out->set_dtype(x.dtype()); - out->set_layout(x.layout()); -} - void ReduceScatterInferMeta(const MetaTensor& x, int nranks, MetaTensor* out) { auto dim = x.dims(); if (dim[0] > 0 || dim[0] < -1) { @@ -3945,119 +3945,6 @@ void StridedSliceInferMeta(const MetaTensor& x, x, axes, starts, ends, strides, infer_flags, decrease_axis, out, config); } -// TODO(zhouwei): OriginReduceInferDim doesn't support 0D, remove in future -DDim OriginReduceInferDim(const MetaTensor& x, - const std::vector& axis, - bool keep_dim, - bool reduce_all) { - auto x_rank = x.dims().size(); - - std::vector formated_axis = axis; - for (size_t i = 0; i < axis.size(); ++i) { - if (x_rank == 0) { - PADDLE_ENFORCE_EQ( - axis[i] == 0 || axis[i] == -1, - true, - phi::errors::InvalidArgument( - "When input 0D Tensor, the axis can only be -1, 0, None or []")); - } else { - PADDLE_ENFORCE_LT(axis[i], - x_rank, - errors::InvalidArgument( - "The reduce dim index %d should be in the " - "range [ -dimension(X), dimension(X) ) " - "which dimesion = %d. But received dim index = %d.", - i, - x_rank, - axis[i])); - PADDLE_ENFORCE_GE(axis[i], - -x_rank, - errors::InvalidArgument( - "The reduce dim index %d should be in the " - "range [ -dimension(X), dimension(X) ) " - "which dimesion = %d. But received dim index = %d.", - i, - x_rank, - axis[i])); - } - - if (axis[i] < 0) { - formated_axis[i] = axis[i] + x_rank; - } - } - - bool full_dim = true; - std::set dims_set(formated_axis.begin(), formated_axis.end()); - for (int64_t i = 0; i < x_rank; ++i) { - if (dims_set.find(i) == dims_set.end()) { - full_dim = false; - break; - } - } - reduce_all = reduce_all || full_dim; - - std::vector out_dim_vector; - for (int64_t i = 0; i < x_rank; ++i) { - if (reduce_all || dims_set.find(i) != dims_set.end()) { - if (keep_dim) { - out_dim_vector.push_back(1); - } else { - continue; - } - } else { - out_dim_vector.push_back(x.dims().at(i)); - } - } - - DDim out_dim = phi::make_ddim(out_dim_vector); - return out_dim; -} - -// TODO(zhouwei): OriginReduceInferDim doesn't support 0D, remove in future -DDim OriginReduceInferDimForIntArrayAxis(const MetaTensor& x, - const IntArray& axis, - bool keep_dim, - bool reduce_all) { - std::vector vec_axis = axis.GetData(); - std::vector vec_dim; - if (reduce_all) { - if (keep_dim) { - vec_dim = std::vector(x.dims().size(), 1); - } else { - vec_dim = {}; - } - } else { - if (keep_dim) { - vec_dim = std::vector(x.dims().size(), -1); - } else { - auto x_rank = static_cast(x.dims().size()); - if (vec_axis.size() > x_rank) { - vec_dim = {-1}; - } else { - vec_dim = std::vector(x.dims().size() - vec_axis.size(), -1); - } - } - } - return phi::make_ddim(vec_dim); -} - -/* Why not use SumRawInferMeta directly? 
- Because we need make InferMetaFunction's args follow the design of - ops.yaml -*/ -void SumInferMeta(const MetaTensor& x, - const IntArray& axis, - DataType dtype, - bool keep_dim, - MetaTensor* out, - MetaConfig config) { - bool reduce_all = false; - if (axis.size() == 0) { - reduce_all = true; - } - SumRawInferMeta(x, axis, keep_dim, reduce_all, dtype, out, config); -} - void SumRawInferMeta(const MetaTensor& x, const IntArray& axis, bool keep_dim, @@ -4067,10 +3954,9 @@ void SumRawInferMeta(const MetaTensor& x, MetaConfig config) { DDim out_dim; if (config.is_runtime || !axis.FromTensor()) { - out_dim = OriginReduceInferDim(x, axis.GetData(), keep_dim, reduce_all); + out_dim = ReduceInferDim(x, axis.GetData(), keep_dim, reduce_all); } else { - out_dim = - OriginReduceInferDimForIntArrayAxis(x, axis, keep_dim, reduce_all); + out_dim = ReduceInferDimForIntArrayAxis(x, axis, keep_dim, reduce_all); } DataType out_dtype; @@ -4089,36 +3975,21 @@ void SumRawInferMeta(const MetaTensor& x, out->set_layout(x.layout()); } -// TODO(zhouwei): OriginReduce doesn't support 0D, remove in future -void OriginReduceInferMeta(const MetaTensor& x, - const IntArray& axis, - bool keep_dim, - MetaTensor* out, - MetaConfig config) { +/* Why not use SumRawInferMeta directly? + Because we need make InferMetaFunction's args follow the design of + ops.yaml +*/ +void SumInferMeta(const MetaTensor& x, + const IntArray& axis, + DataType dtype, + bool keep_dim, + MetaTensor* out, + MetaConfig config) { bool reduce_all = false; if (axis.size() == 0) { reduce_all = true; } - OriginReduceInferMetaBase(x, axis, keep_dim, reduce_all, out, config); -} - -// TODO(zhouwei): OriginReduce doesn't support 0D, remove in future -void OriginReduceInferMetaBase(const MetaTensor& x, - const IntArray& axis, - bool keep_dim, - bool reduce_all, - MetaTensor* out, - MetaConfig config) { - DDim out_dim; - if (config.is_runtime || !axis.FromTensor()) { - out_dim = OriginReduceInferDim(x, axis.GetData(), keep_dim, reduce_all); - } else { - out_dim = - OriginReduceInferDimForIntArrayAxis(x, axis, keep_dim, reduce_all); - } - out->set_dims(out_dim); - out->set_dtype(x.dtype()); - out->set_layout(x.layout()); + SumRawInferMeta(x, axis, keep_dim, reduce_all, dtype, out, config); } void SvdInferMeta(const MetaTensor& x, diff --git a/paddle/phi/infermeta/unary.h b/paddle/phi/infermeta/unary.h index 7dc922ac9a4..297e6d5648d 100644 --- a/paddle/phi/infermeta/unary.h +++ b/paddle/phi/infermeta/unary.h @@ -572,19 +572,6 @@ void SumRawInferMeta(const MetaTensor& x, MetaTensor* out, MetaConfig config = MetaConfig()); -void OriginReduceInferMeta(const MetaTensor& x, - const IntArray& axis, - bool keep_dim, - MetaTensor* out, - MetaConfig config = MetaConfig()); - -void OriginReduceInferMetaBase(const MetaTensor& x, - const IntArray& axis, - bool keep_dim, - bool reduce_all, - MetaTensor* out, - MetaConfig config = MetaConfig()); - void SvdInferMeta(const MetaTensor& x, bool full_matrices, MetaTensor* u, diff --git a/paddle/phi/kernels/impl/multi_dot_kernel_impl.h b/paddle/phi/kernels/impl/multi_dot_kernel_impl.h index dfb3e04a8c3..e63ee311907 100644 --- a/paddle/phi/kernels/impl/multi_dot_kernel_impl.h +++ b/paddle/phi/kernels/impl/multi_dot_kernel_impl.h @@ -446,9 +446,14 @@ void MultiDotGradKernel(const Context& ctx, } else { MultiDotGradMatChainOrder( ctx, dout, ins, dout_dim, ins_dims, &dx); + // if x's shape is: [3] [3, 4] [4] + // dx's shape will be: [1, 3] [3, 4] [4, 1] if (ins[n - 1]->dims().size() == 1) { dx[n - 1]->Resize({dx[n - 
1]->dims()[0]}); } + if (ins[0]->dims().size() == 1) { + dx[0]->Resize({dx[0]->dims()[1]}); + } } } diff --git a/paddle/phi/kernels/selected_rows/elementwise_multiply_kernel.cc b/paddle/phi/kernels/selected_rows/elementwise_multiply_kernel.cc index 9fe8eef7ec8..dccbba6947a 100644 --- a/paddle/phi/kernels/selected_rows/elementwise_multiply_kernel.cc +++ b/paddle/phi/kernels/selected_rows/elementwise_multiply_kernel.cc @@ -31,12 +31,11 @@ void MultiplyRawKernel(const Context& dev_ctx, const DenseTensor& y, int axis, SelectedRows* out) { - PADDLE_ENFORCE_EQ(y.dims().size() == 1 && y.dims()[0] == 1, - true, - phi::errors::InvalidArgument( - "For MultiplyKernel, if X is Sparse, Y must be " - "scalar. But reveived the size of Y = %s.", - y.dims().size())); + PADDLE_ENFORCE_EQ( + phi::product(y.dims()), + 1, + phi::errors::InvalidArgument("For MultiplyKernel, if X is Sparse, Y must " + "contain only one element.")); out->set_rows(x.rows()); out->set_height(x.height()); auto z = out->mutable_value(); diff --git a/test/collective/collective_global_gather_dygraph.py b/test/collective/collective_global_gather_dygraph.py index f3cc9ab412a..c816132f9ef 100644 --- a/test/collective/collective_global_gather_dygraph.py +++ b/test/collective/collective_global_gather_dygraph.py @@ -60,7 +60,7 @@ class TestCollectiveGlobalGatherAPI(TestCollectiveAPIRunnerBase): c = output * output c.stop_gradient = False c.backward() - return [output.numpy(), local_input_buf.grad.numpy()] + return [output.numpy(False), local_input_buf.grad.numpy(False)] if __name__ == "__main__": diff --git a/test/collective/collective_global_scatter_dygraph.py b/test/collective/collective_global_scatter_dygraph.py index 3c6fa897f19..2e5001371fd 100644 --- a/test/collective/collective_global_scatter_dygraph.py +++ b/test/collective/collective_global_scatter_dygraph.py @@ -58,7 +58,7 @@ class TestCollectiveGlobalScatterAPI(TestCollectiveAPIRunnerBase): output.stop_gradient = False c = output * output c.backward() - return [output.numpy(), local_input_buf.grad.numpy()] + return [output.numpy(False), local_input_buf.grad.numpy(False)] if __name__ == "__main__": diff --git a/test/collective/fleet/hybrid_parallel_mp_layers.py b/test/collective/fleet/hybrid_parallel_mp_layers.py index 0f39cae720f..b8e57a9a11b 100644 --- a/test/collective/fleet/hybrid_parallel_mp_layers.py +++ b/test/collective/fleet/hybrid_parallel_mp_layers.py @@ -343,7 +343,9 @@ class TestDistTraning(unittest.TestCase): integral_grad = paddle.concat(integral_grad, axis=-1) np.testing.assert_allclose( - integral_data.grad.numpy(), integral_grad.numpy(), rtol=1e-6 + integral_data.grad.numpy(False), + integral_grad.numpy(False), + rtol=1e-6, ) diff --git a/test/collective/fleet/parallel_dygraph_no_sync_gradient_check.py b/test/collective/fleet/parallel_dygraph_no_sync_gradient_check.py index 21814832b66..a3f4d5f0b16 100644 --- a/test/collective/fleet/parallel_dygraph_no_sync_gradient_check.py +++ b/test/collective/fleet/parallel_dygraph_no_sync_gradient_check.py @@ -115,8 +115,8 @@ class TestDistTraning(unittest.TestCase): out_b.sum().backward() def check_acc(self, grad, acc_grad): - grad = grad.numpy() if grad is not None else None - acc_grad = acc_grad.numpy() if acc_grad is not None else None + grad = grad.numpy(False) if grad is not None else None + acc_grad = acc_grad.numpy(False) if acc_grad is not None else None return np.testing.assert_allclose(grad, acc_grad, rtol=1e-6) def print_trainer_0(self, *args): @@ -134,7 +134,9 @@ class TestDistTraning(unittest.TestCase): grad = 
param._grad_ivar() other_grad = self.broadcast_param(grad.clone(), root=1) if self.trainer_id == 0: - np.testing.assert_allclose(other_grad.numpy(), grad.numpy()) + np.testing.assert_allclose( + other_grad.numpy(False), grad.numpy(False) + ) if __name__ == '__main__': diff --git a/test/collective/fleet/parallel_margin_cross_entropy.py b/test/collective/fleet/parallel_margin_cross_entropy.py index 355a20860e8..6b8ba6fc4c5 100644 --- a/test/collective/fleet/parallel_margin_cross_entropy.py +++ b/test/collective/fleet/parallel_margin_cross_entropy.py @@ -208,8 +208,8 @@ class TestParallelMarginSoftmaxCrossEntropyOp(unittest.TestCase): ) np.testing.assert_allclose( - integral_data.grad.numpy(), - integral_grad.numpy(), + integral_data.grad.numpy(False), + integral_grad.numpy(False), rtol=1e-5, atol=1e-7, ) diff --git a/test/legacy_test/eager_op_test.py b/test/legacy_test/eager_op_test.py index f5e84f1cc9c..cdd54602135 100644 --- a/test/legacy_test/eager_op_test.py +++ b/test/legacy_test/eager_op_test.py @@ -964,6 +964,11 @@ class OpTest(unittest.TestCase): for name in api_outs: np_api = np.array(api_outs[name]) np_dyg = np.array(dygraph_outs[name]) + assert ( + np_api.shape == np_dyg.shape + ), "Operator ({}) : Output ({}) shape mismatch, expect shape is {}, but actual shape is {}".format( + self.op_type, name, np_dyg.shape, np_api.shape + ) np.testing.assert_allclose( np_api, np_dyg, @@ -1230,6 +1235,11 @@ class OpTest(unittest.TestCase): # to check inplace result instead of numpy.array_equal. expect_out = np.array(expect_outs[i]) actual_out = np.array(actual_outs[i]) + assert ( + actual_out.shape == expect_out.shape + ), "Operator ({}) : Output ({}) shape mismatch, expect shape is {}, but actual shape is {}".format( + self.op_type, name, expect_out.shape, actual_out.shape + ) if inplace_atol is not None: np.testing.assert_allclose( expect_out, @@ -1720,41 +1730,28 @@ class OpTest(unittest.TestCase): raise NotImplementedError("base class, not implement!") def _compare_numpy(self, name, actual_np, expect_np): - if actual_np.shape == expect_np.shape: - np.testing.assert_allclose( - actual_np, - expect_np, - atol=self.atol if hasattr(self, 'atol') else atol, - rtol=self.rtol if hasattr(self, 'rtol') else rtol, - equal_nan=equal_nan, - err_msg=( - "Operator (" - + self.op_type - + ") Output (" - + name - + ") has diff at " - + str(place) - + " in " - + self.checker_name - ), - ) - return - self.op_test.assertTrue( - np.allclose( - actual_np, - expect_np, - atol=self.atol if hasattr(self, 'atol') else atol, - rtol=self.rtol if hasattr(self, 'rtol') else rtol, - equal_nan=equal_nan, + expect_np = np.array(expect_np) + assert ( + actual_np.shape == expect_np.shape + ), "Operator ({}) : Output ({}) shape mismatch, expect shape is {}, but actual shape is {}".format( + self.op_type, name, expect_np.shape, actual_np.shape + ) + np.testing.assert_allclose( + actual_np, + expect_np, + atol=self.atol if hasattr(self, 'atol') else atol, + rtol=self.rtol if hasattr(self, 'rtol') else rtol, + equal_nan=equal_nan, + err_msg=( + "Operator (" + + self.op_type + + ") Output (" + + name + + ") has diff at " + + str(place) + + " in " + + self.checker_name ), - "Operator (" - + self.op_type - + ") Output (" - + name - + ") has diff at " - + str(place) - + " in " - + self.checker_name, ) def _compare_list(self, name, actual, expect): @@ -1775,10 +1772,6 @@ class OpTest(unittest.TestCase): ) # modify there for fp32 check - # NOTE(zhiqiu): np.allclose([], [1.]) returns True - # see details: 
https://stackoverflow.com/questions/38331703/why-does-numpys-broadcasting-sometimes-allow-comparing-arrays-of-different-leng - if expect_np.size == 0: - self.op_test.assertTrue(actual_np.size == 0) self._compare_numpy(name, actual_np, expect_np) if isinstance(expect, tuple): self._compare_list(name, actual, expect) @@ -1905,41 +1898,19 @@ class OpTest(unittest.TestCase): self.op_test.disable_cal_ref_output() def _compare_numpy(self, name, actual_np, expect_np): - if ( - functools.reduce(lambda x, y: x * y, actual_np.shape, 1) - == 0 - and functools.reduce(lambda x, y: x * y, expect_np.shape, 1) - == 0 - ): - pass - else: - if actual_np.shape == expect_np.shape: - np.testing.assert_allclose( - actual_np, - expect_np, - atol=atol, - rtol=self.rtol if hasattr(self, 'rtol') else rtol, - equal_nan=equal_nan, - err_msg=( - "Operator (" - + self.op_type - + ") Output (" - + name - + ") has diff at " - + str(place) - + " in " - + self.checker_name - ), - ) - return - self.op_test.assertTrue( - np.allclose( - actual_np, - expect_np, - atol=atol, - rtol=self.rtol if hasattr(self, 'rtol') else rtol, - equal_nan=equal_nan, - ), + expect_np = np.array(expect_np) + assert ( + actual_np.shape == expect_np.shape + ), "Operator ({}) : Output ({}) shape mismatch, expect shape is {}, but actual shape is {}".format( + self.op_type, name, expect_np.shape, actual_np.shape + ) + np.testing.assert_allclose( + actual_np, + expect_np, + atol=atol, + rtol=self.rtol if hasattr(self, 'rtol') else rtol, + equal_nan=equal_nan, + err_msg=( "Operator (" + self.op_type + ") Output (" @@ -1947,8 +1918,9 @@ class OpTest(unittest.TestCase): + ") has diff at " + str(place) + " in " - + self.checker_name, - ) + + self.checker_name + ), + ) def convert_uint16_to_float_ifneed(self, actual_np, expect_np): if actual_np.dtype == np.uint16: @@ -2240,6 +2212,11 @@ class OpTest(unittest.TestCase): atol=1e-5, ): for a, b, name in zip(numeric_grads, analytic_grads, names): + assert tuple(a.shape) == tuple( + b.shape + ), "Operator ({}) : Output ({}) gradient shape mismatch, expect shape is {}, but actual shape is {}".format( + self.op_type, name, a.shape, b.shape + ) # Used by bfloat16 for now to solve precision problem if self.is_bfloat16_op(): if a.size == 0: @@ -2721,7 +2698,7 @@ class OpTest(unittest.TestCase): inputs=paddle.utils.flatten(inputs), grad_outputs=grad_outputs, ) - return [grad.numpy() for grad in grad_inputs] + return [grad.numpy(False) for grad in grad_inputs] @staticmethod def _numpy_to_lod_tensor(np_value, lod, place): diff --git a/test/legacy_test/parallel_dygraph_dataparallel_with_pylayer.py b/test/legacy_test/parallel_dygraph_dataparallel_with_pylayer.py index ba6aff81fc0..febf4b5f87d 100644 --- a/test/legacy_test/parallel_dygraph_dataparallel_with_pylayer.py +++ b/test/legacy_test/parallel_dygraph_dataparallel_with_pylayer.py @@ -100,8 +100,8 @@ class TestDistTraning(unittest.TestCase): model_b.clear_gradients() def check_acc(self, grad, acc_grad): - grad = grad.numpy() if grad is not None else None - acc_grad = acc_grad.numpy() if acc_grad is not None else None + grad = grad.numpy(False) if grad is not None else None + acc_grad = acc_grad.numpy(False) if acc_grad is not None else None return np.testing.assert_allclose(grad, acc_grad, rtol=1e-6) def broadcast_param(self, param, root): @@ -115,7 +115,9 @@ class TestDistTraning(unittest.TestCase): grad = param._grad_ivar() other_grad = self.broadcast_param(grad.clone(), root=1) if self.trainer_id == 0: - np.testing.assert_allclose(other_grad.numpy(), grad.numpy()) 
+ np.testing.assert_allclose( + other_grad.numpy(False), grad.numpy(False) + ) if __name__ == '__main__': diff --git a/test/legacy_test/parallel_dygraph_gradient_check.py b/test/legacy_test/parallel_dygraph_gradient_check.py index 0ce17bc0848..a6c47b65d8f 100644 --- a/test/legacy_test/parallel_dygraph_gradient_check.py +++ b/test/legacy_test/parallel_dygraph_gradient_check.py @@ -113,8 +113,8 @@ class TestDistTraning(unittest.TestCase): def check_acc(self, grad, grad_sum, acc_grad): if grad is not None: - grad_sum = grad_sum + grad.numpy() - acc_grad = acc_grad.numpy() if acc_grad is not None else None + grad_sum = grad_sum + grad.numpy(False) + acc_grad = acc_grad.numpy(False) if acc_grad is not None else None np.testing.assert_allclose(grad_sum, acc_grad, rtol=1e-6) return grad_sum @@ -133,7 +133,9 @@ class TestDistTraning(unittest.TestCase): grad = param._grad_ivar() other_grad = self.broadcast_param(grad.clone(), root=1) if self.trainer_id == 0: - np.testing.assert_allclose(other_grad.numpy(), grad.numpy()) + np.testing.assert_allclose( + other_grad.numpy(False), grad.numpy(False) + ) if __name__ == '__main__': diff --git a/test/legacy_test/parallel_dygraph_gradient_check_in_eager_mode.py b/test/legacy_test/parallel_dygraph_gradient_check_in_eager_mode.py index 96d1de3bf7d..df66ff7616a 100644 --- a/test/legacy_test/parallel_dygraph_gradient_check_in_eager_mode.py +++ b/test/legacy_test/parallel_dygraph_gradient_check_in_eager_mode.py @@ -121,8 +121,8 @@ class TestDistTraning(unittest.TestCase): def check_acc(self, grad, grad_sum, acc_grad): if grad is not None: - grad_sum = grad_sum + grad.numpy() - acc_grad = acc_grad.numpy() if acc_grad is not None else None + grad_sum = grad_sum + grad.numpy(False) + acc_grad = acc_grad.numpy(False) if acc_grad is not None else None np.testing.assert_allclose(grad_sum, acc_grad, rtol=1e-6) return grad_sum @@ -141,7 +141,9 @@ class TestDistTraning(unittest.TestCase): grad = param.grad other_grad = self.broadcast_param(grad, root=1) if self.trainer_id == 0: - np.testing.assert_allclose(other_grad.numpy(), grad.numpy()) + np.testing.assert_allclose( + other_grad.numpy(False), grad.numpy(False) + ) if __name__ == '__main__': diff --git a/test/legacy_test/test_accuracy_op.py b/test/legacy_test/test_accuracy_op.py index 5c6efe69382..ae60df680e2 100755 --- a/test/legacy_test/test_accuracy_op.py +++ b/test/legacy_test/test_accuracy_op.py @@ -44,9 +44,9 @@ class TestAccuracyOp(OpTest): num_correct += 1 break self.outputs = { - 'Accuracy': np.array([num_correct / float(n)]).astype(self.dtype), - 'Correct': np.array([num_correct]).astype("int32"), - 'Total': np.array([n]).astype("int32"), + 'Accuracy': np.array(num_correct / float(n)).astype(self.dtype), + 'Correct': np.array(num_correct).astype("int32"), + 'Total': np.array(n).astype("int32"), } def init_dtype(self): @@ -91,10 +91,10 @@ class TestAccuracyOpBf16(OpTest): break self.outputs = { 'Accuracy': convert_float_to_uint16( - np.array([num_correct / float(n)]).astype(np.float32) + np.array(num_correct / float(n)).astype(np.float32) ), - 'Correct': np.array([num_correct]).astype("int32"), - 'Total': np.array([n]).astype("int32"), + 'Correct': np.array(num_correct).astype("int32"), + 'Total': np.array(n).astype("int32"), } def init_dtype(self): diff --git a/test/legacy_test/test_adam_op.py b/test/legacy_test/test_adam_op.py index 3196e1b2830..d483b4ebb5a 100644 --- a/test/legacy_test/test_adam_op.py +++ b/test/legacy_test/test_adam_op.py @@ -639,8 +639,8 @@ class TestAdamOpWithSkipUpdate(OpTest): 
'Moment1Out': moment1, 'Moment2Out': moment2, 'ParamOut': param, - 'Beta1PowOut': self.inputs['Beta1Pow'], - 'Beta2PowOut': self.inputs['Beta2Pow'], + 'Beta1PowOut': np.array([]), + 'Beta2PowOut': np.array([]), } def test_check_output(self): diff --git a/test/legacy_test/test_allclose_op.py b/test/legacy_test/test_allclose_op.py index 48e2d89931c..d79d3cd69af 100644 --- a/test/legacy_test/test_allclose_op.py +++ b/test/legacy_test/test_allclose_op.py @@ -42,15 +42,13 @@ class TestAllcloseOp(OpTest): self.attrs = {'equal_nan': self.equal_nan} self.outputs = { 'Out': np.array( - [ - np.allclose( - self.inputs['Input'], - self.inputs['Other'], - rtol=self.rtol, - atol=self.atol, - equal_nan=self.equal_nan, - ) - ] + np.allclose( + self.inputs['Input'], + self.inputs['Other'], + rtol=self.rtol, + atol=self.atol, + equal_nan=self.equal_nan, + ) ) } diff --git a/test/legacy_test/test_arg_min_max_v2_op.py b/test/legacy_test/test_arg_min_max_v2_op.py index ab68f5fc3a5..bca85ebd488 100644 --- a/test/legacy_test/test_arg_min_max_v2_op.py +++ b/test/legacy_test/test_arg_min_max_v2_op.py @@ -117,7 +117,7 @@ def create_kernel_case(op_type, numpy_op_type): self.dtype = "float64" self.x = 1000 * np.random.random(self.dims).astype(self.dtype) self.inputs = {'X': self.x} - self.attrs = {"axis": self.axis, "flatten": True, "keepdims": True} + self.attrs = {"axis": self.axis, "flatten": True, "keepdims": False} self.numpy_op = eval("np.%s" % (numpy_op_type)) self.outputs = { 'Out': np.array(self.numpy_op(self.x.flatten(), axis=self.axis)) diff --git a/test/legacy_test/test_auc_op.py b/test/legacy_test/test_auc_op.py index 19c7120f858..81261ab5b0e 100644 --- a/test/legacy_test/test_auc_op.py +++ b/test/legacy_test/test_auc_op.py @@ -99,8 +99,8 @@ class TestGlobalAucOp(OpTest): neg = python_auc._stat_neg self.outputs = { 'AUC': np.array(python_auc.accumulate()), - 'StatPosOut': np.array(pos), - 'StatNegOut': np.array(neg), + 'StatPosOut': np.array([pos]), + 'StatNegOut': np.array([neg]), } def test_check_output(self): @@ -132,8 +132,9 @@ class TestAucAPI(unittest.TestCase): feed={"input": x, "label": y, "ins_tag_weight": z}, fetch_list=[result[0]], ) - auc_np = np.array([0.66666667]).astype("float32") + auc_np = np.array(0.66666667).astype("float32") np.testing.assert_allclose(output, auc_np, rtol=1e-05) + assert auc_np.shape == auc_np.shape class TestAucOpError(unittest.TestCase): diff --git a/test/legacy_test/test_auc_single_pred_op.py b/test/legacy_test/test_auc_single_pred_op.py index d6dd4e9b0b8..3445b5a0a11 100644 --- a/test/legacy_test/test_auc_single_pred_op.py +++ b/test/legacy_test/test_auc_single_pred_op.py @@ -104,8 +104,8 @@ class TestAucGlobalSinglePredOp(OpTest): neg = python_auc._stat_neg self.outputs = { 'AUC': np.array(python_auc.accumulate()), - 'StatPosOut': np.array(pos), - 'StatNegOut': np.array(neg), + 'StatPosOut': np.array([pos]), + 'StatNegOut': np.array([neg]), } def test_check_output(self): diff --git a/test/legacy_test/test_cross_entropy_op.py b/test/legacy_test/test_cross_entropy_op.py index f7bc9f62d0c..0cf4e0a6c2f 100644 --- a/test/legacy_test/test_cross_entropy_op.py +++ b/test/legacy_test/test_cross_entropy_op.py @@ -58,7 +58,7 @@ class TestCrossEntropyOp(OpTest): ) def get_cross_entropy(self): - self.cross_entropy = np.asmatrix( + self.cross_entropy = np.array( [ [-np.log(self.x[i][self.label[i][0]])] for i in range(self.x.shape[0]) @@ -91,7 +91,7 @@ class TestCrossEntropyOpRemoveLastDim(TestCrossEntropyOp): ) def get_cross_entropy(self): - self.cross_entropy = np.asmatrix( + 
self.cross_entropy = np.array( [-np.log(self.x[i][self.label[i]]) for i in range(self.x.shape[0])], dtype="float64", ) @@ -140,7 +140,7 @@ class TestCrossEntropyOp3(TestCrossEntropyOp): self.label[np.arange(self.batch_size), self.label_index] = 1 def get_cross_entropy(self): - self.cross_entropy = np.asmatrix( + self.cross_entropy = np.array( [ [-np.log(self.x[i][self.label_index[i]])] for i in range(self.x.shape[0]) @@ -181,7 +181,7 @@ class TestCrossEntropyOp4(TestCrossEntropyOp): self.label = self.label_2d.reshape(self.shape + [1]) def get_cross_entropy(self): - cross_entropy_2d = np.asmatrix( + cross_entropy_2d = np.array( [ [-np.log(self.X_2d[i][self.label_2d[i][0]])] for i in range(self.X_2d.shape[0]) @@ -211,7 +211,7 @@ class TestCrossEntropyOp4RemoveLastDim(TestCrossEntropyOp4): self.label = self.label_2d.reshape(self.shape) def get_cross_entropy(self): - cross_entropy_2d = np.asmatrix( + cross_entropy_2d = np.array( [ [-np.log(self.X_2d[i][self.label_2d[i][0]])] for i in range(self.X_2d.shape[0]) @@ -285,7 +285,7 @@ class TestCrossEntropyOp6(TestCrossEntropyOp): ) def get_cross_entropy(self): - cross_entropy_2d = np.asmatrix( + cross_entropy_2d = np.array( [ [-np.log(self.X_2d[i][self.label_index_2d[i]])] for i in range(self.X_2d.shape[0]) @@ -321,7 +321,7 @@ class TestCrossEntropyOp7(TestCrossEntropyOp): ) def get_cross_entropy(self): - self.cross_entropy = np.asmatrix( + self.cross_entropy = np.array( [ [-np.log(self.x[i][self.label[i][0]])] if self.label[i][0] != self.ignore_index @@ -351,7 +351,7 @@ class TestCrossEntropyOp7RemoveLastDim(TestCrossEntropyOp7): ) def get_cross_entropy(self): - self.cross_entropy = np.asmatrix( + self.cross_entropy = np.array( [ [-np.log(self.x[i][self.label[i]])] if self.label[i] != self.ignore_index diff --git a/test/legacy_test/test_ctc_align.py b/test/legacy_test/test_ctc_align.py index a1eb774cc31..b95ac8a3c6e 100644 --- a/test/legacy_test/test_ctc_align.py +++ b/test/legacy_test/test_ctc_align.py @@ -37,7 +37,7 @@ def CTCAlign(input, lod, blank, merge_repeated, padding=0, input_length=None): cur_offset += lod0[i] result = np.array(result).reshape([len(result), 1]).astype("int32") if len(result) == 0: - result = np.array([-1]) + result = np.array([[-1]]) return result else: result = [[] for i in range(len(input))] diff --git a/test/legacy_test/test_data_norm_op.py b/test/legacy_test/test_data_norm_op.py index 9bb518bfd4f..0839a778d30 100644 --- a/test/legacy_test/test_data_norm_op.py +++ b/test/legacy_test/test_data_norm_op.py @@ -252,8 +252,8 @@ class TestDataNormOp(OpTest): y = np.array(x_val) - mean = np.zeros(x_shape).astype(tp) - scale = np.ones(x_shape).astype(tp) + mean = np.zeros(x_shape[1]).astype(tp) + scale = np.ones(x_shape[1]).astype(tp) self.inputs = { "X": x_val, @@ -309,8 +309,8 @@ class TestDataNormOpWithEnableScaleAndShift(OpTest): y = np.array(x_val) - mean = np.zeros(x_shape).astype(tp) - scale = np.ones(x_shape).astype(tp) + mean = np.zeros(x_shape[1]).astype(tp) + scale = np.ones(x_shape[1]).astype(tp) self.inputs = { "X": x_val, @@ -373,8 +373,8 @@ class TestDataNormOpWithoutEnableScaleAndShift(OpTest): y = np.array(x_val) - mean = np.zeros(x_shape).astype(tp) - scale = np.ones(x_shape).astype(tp) + mean = np.zeros(x_shape[1]).astype(tp) + scale = np.ones(x_shape[1]).astype(tp) self.inputs = { "X": x_val, @@ -432,8 +432,8 @@ class TestDataNormOpWithEnableScaleAndShift_1(OpTest): y = np.array(x_val) - mean = np.zeros(x_shape).astype(tp) - scale = np.ones(x_shape).astype(tp) + mean = np.zeros(x_shape[1]).astype(tp) + 
scale = np.ones(x_shape[1]).astype(tp) self.inputs = { "X": x_val, @@ -493,8 +493,8 @@ class TestDataNormOpWithSlotDim(OpTest): y = np.array(x_val) - mean = np.zeros(x_shape).astype(tp) - scale = np.ones(x_shape).astype(tp) + mean = np.zeros(x_shape[1]).astype(tp) + scale = np.ones(x_shape[1]).astype(tp) self.inputs = { "X": x_val, diff --git a/test/legacy_test/test_edit_distance_op.py b/test/legacy_test/test_edit_distance_op.py index 324d83a4e5e..1bcedf8b99d 100644 --- a/test/legacy_test/test_edit_distance_op.py +++ b/test/legacy_test/test_edit_distance_op.py @@ -83,7 +83,7 @@ class TestEditDistanceOp(OpTest): num_strs = len(self.x1_lod) distance = np.zeros((num_strs, 1)).astype("float32") - sequence_num = np.array(2).astype("int64") + sequence_num = np.array([2]).astype("int64") x1_offset = 0 x2_offset = 0 @@ -128,7 +128,7 @@ class TestEditDistanceOpNormalizedCase0(OpTest): num_strs = len(self.x1_lod) distance = np.zeros((num_strs, 1)).astype("float32") - sequence_num = np.array(num_strs).astype("int64") + sequence_num = np.array([num_strs]).astype("int64") x1_offset = 0 x2_offset = 0 @@ -184,7 +184,7 @@ class TestEditDistanceOpNormalizedTensor(OpTest): num_strs = len(self.x1_lod) distance = np.zeros((num_strs, 1)).astype("float32") - sequence_num = np.array(num_strs).astype("int64") + sequence_num = np.array([num_strs]).astype("int64") for i in range(0, num_strs): distance[i] = Levenshtein( diff --git a/test/legacy_test/test_elementwise_div_op.py b/test/legacy_test/test_elementwise_div_op.py index 671ed2b4131..ea8fbac7f09 100644 --- a/test/legacy_test/test_elementwise_div_op.py +++ b/test/legacy_test/test_elementwise_div_op.py @@ -503,7 +503,7 @@ class TestDivideOp(unittest.TestCase): x = paddle.to_tensor(np_x) y = paddle.to_tensor(np_y) z = paddle.divide(x, y) - np_z = z.numpy() + np_z = z.numpy(False) z_expected = np.array([2.0, 0.6, 2.0]) self.assertEqual((np_z == z_expected).all(), True) diff --git a/test/legacy_test/test_elementwise_sub_op.py b/test/legacy_test/test_elementwise_sub_op.py index 59871057f00..880955402e2 100644 --- a/test/legacy_test/test_elementwise_sub_op.py +++ b/test/legacy_test/test_elementwise_sub_op.py @@ -889,7 +889,7 @@ class TestSubtractApi(unittest.TestCase): x = fluid.dygraph.to_variable(np_x) y = fluid.dygraph.to_variable(np_y) z = self._executed_api(x, y) - np_z = z.numpy() + np_z = z.numpy(False) z_expected = np.array([1.0, -2.0, 2.0]) self.assertEqual((np_z == z_expected).all(), True) diff --git a/test/legacy_test/test_fake_quantize_op.py b/test/legacy_test/test_fake_quantize_op.py index d1f8f37d360..9f89763731a 100644 --- a/test/legacy_test/test_fake_quantize_op.py +++ b/test/legacy_test/test_fake_quantize_op.py @@ -48,7 +48,7 @@ class TestFakeQuantizeAbsMaxOp(OpTest): ): input_data = distribution(input_shape).astype(dtype) compute_type = get_compute_type(dtype) - scale = np.max(np.abs(input_data)) + scale = np.max(np.abs(input_data)).flatten() bnt = (1 << (self.attrs['bit_length'] - 1)) - 1 inv_scale = 1.0 / (scale + 1e-6) if scale < 1e-30 else 1.0 / scale if round_type == 'TiesToEven': @@ -195,8 +195,8 @@ class TestFakeQuantizeRangeAbsMaxOp(OpTest): } self.outputs = { 'Out': output_data, - 'OutScale': out_scale[0], - 'OutScales': out_scale, + 'OutScale': np.array([], dtype) if is_test else out_scale, + 'OutScales': np.array([], dtype) if is_test else out_scale, } self.dtype = dtype self.attrs['is_test'] = is_test @@ -231,10 +231,10 @@ class TestMovingAverageAbsMaxScaleOp(OpTest): input_data = distribution(input_shape).astype(dtype) in_accum = 
np.ones(1).astype(dtype) in_state = np.ones(1).astype(dtype) - out_accum = self.attrs['moving_rate'] * in_accum[0] + np.max( + out_accum = self.attrs['moving_rate'] * in_accum + np.max( np.abs(input_data) ) - out_state = self.attrs['moving_rate'] * in_state[0] + 1.0 + out_state = self.attrs['moving_rate'] * in_state + 1.0 out_scale = out_accum / out_state self.inputs = { 'X': input_data, @@ -276,13 +276,10 @@ class TestFakeQuantizeMovingAverageAbsMaxOp(OpTest): in_accum = np.ones(1).astype(dtype) in_state = np.ones(1).astype(dtype) in_scale = np.array([0.001]).astype(dtype) - out_accum = np.zeros(1).astype(dtype) - out_state = np.zeros(1).astype(dtype) - out_scale = np.zeros(1).astype(dtype) - out_accum[0] = self.attrs['moving_rate'] * in_accum[0] + np.max( + out_accum = self.attrs['moving_rate'] * in_accum + np.max( np.abs(input_data) ) - out_state[0] = self.attrs['moving_rate'] * in_state[0] + 1.0 + out_state = self.attrs['moving_rate'] * in_state + 1.0 out_scale = out_accum / out_state if round_type == 'TiesToEven': round_out = np.round( @@ -354,7 +351,7 @@ class TestFakeQuantizeDequantizeAbsMaxOp(OpTest): self, dtype, input_shape, distribution, round_type='TiesAwayFromZero' ): input_data = distribution(input_shape).astype(dtype) - scale = np.max(np.abs(input_data)).astype(dtype) + scale = np.max(np.abs(input_data)).flatten().astype(dtype) bnt = (1 << (self.attrs['bit_length'] - 1)) - 1 if round_type == 'TiesToEven': round_out = np.round(input_data / scale * bnt) @@ -592,12 +589,8 @@ class TestquantizeOpTrain(TestquantizeOp): zero_point = np.zeros(scale.shape, dtype="int32") in_accum = np.ones(1).astype(self.data_type) in_state = np.ones(1).astype(self.data_type) - out_accum = np.zeros(1).astype(self.data_type) - out_state = np.zeros(1).astype(self.data_type) - out_accum[0] = self.attrs['moving_rate'] * in_accum[0] + np.max( - np.abs(x) - ) - out_state[0] = self.attrs['moving_rate'] * in_state[0] + 1.0 + out_accum = self.attrs['moving_rate'] * in_accum + np.max(np.abs(x)) + out_state = self.attrs['moving_rate'] * in_state + 1.0 out_scale = out_accum / out_state round_out = np.round(x / out_scale * self.max_range) diff --git a/test/legacy_test/test_fused_adam_op.py b/test/legacy_test/test_fused_adam_op.py index 55cf5ce3467..951943b1956 100644 --- a/test/legacy_test/test_fused_adam_op.py +++ b/test/legacy_test/test_fused_adam_op.py @@ -64,12 +64,8 @@ def fused_adam_step(inputs, attributes, num): ) for i in range(num): - beta1_pows_out.append( - np.array([beta1_pows[i][1]]).astype("float32") * beta1 - ) - beta2_pows_out.append( - np.array([beta2_pows[i][1]]).astype("float32") * beta2 - ) + beta1_pows_out.append(beta1_pows[i][1] * beta1) + beta2_pows_out.append(beta2_pows[i][1] * beta2) return ( params_out, diff --git a/test/legacy_test/test_gather_nd_op.py b/test/legacy_test/test_gather_nd_op.py index c564d2ae8c3..bc839b0c2db 100644 --- a/test/legacy_test/test_gather_nd_op.py +++ b/test/legacy_test/test_gather_nd_op.py @@ -93,7 +93,7 @@ class TestGatherNdOpWithIndex1(OpTest): target_dtype = "float32" xnp = np.random.random((5, 20)).astype(target_dtype) index = np.array([1]).astype("int32") - output = xnp[index] + output = xnp[index[-1]] if self.dtype == np.uint16: xnp = convert_float_to_uint16(xnp) output = convert_float_to_uint16(output) @@ -150,7 +150,7 @@ class TestGatherNdOpWithLowIndex(OpTest): target_dtype = "float32" xnp = np.random.uniform(0, 100, (10, 10)).astype(target_dtype) index = np.array([[1], [2]]).astype("int64") - output = xnp[tuple(index.T)] # [[14, 25, 1], [76, 22, 3]] 
+ output = xnp[tuple(index.T)] # shape is [2, 10] if self.dtype == np.uint16: xnp = convert_float_to_uint16(xnp) diff --git a/test/legacy_test/test_generate_proposal_labels_op.py b/test/legacy_test/test_generate_proposal_labels_op.py index 67a499dd10e..f8cc3185fd4 100644 --- a/test/legacy_test/test_generate_proposal_labels_op.py +++ b/test/legacy_test/test_generate_proposal_labels_op.py @@ -420,7 +420,7 @@ class TestGenerateProposalLabelsOp(OpTest): self.bbox_targets = np.vstack(self.bbox_targets) self.bbox_inside_weights = np.vstack(self.bbox_inside_weights) self.bbox_outside_weights = np.vstack(self.bbox_outside_weights) - self.max_overlap_with_gt = np.vstack(self.max_overlap_with_gt) + self.max_overlap_with_gt = np.concatenate(self.max_overlap_with_gt) class TestCascade(TestGenerateProposalLabelsOp): diff --git a/test/legacy_test/test_is_empty_op.py b/test/legacy_test/test_is_empty_op.py index f771c33cb67..b1e212be747 100644 --- a/test/legacy_test/test_is_empty_op.py +++ b/test/legacy_test/test_is_empty_op.py @@ -25,7 +25,7 @@ class TestEmpty(OpTest): self.op_type = "is_empty" self.python_api = paddle.is_empty self.inputs = {'X': np.array([1, 2, 3])} - self.outputs = {'Out': np.array([False])} + self.outputs = {'Out': np.array(False)} def test_check_output(self): self.check_output() @@ -36,7 +36,7 @@ class TestNotEmpty(TestEmpty): self.op_type = "is_empty" self.python_api = paddle.is_empty self.inputs = {'X': np.array([])} - self.outputs = {'Out': np.array([True])} + self.outputs = {'Out': np.array(True)} class TestIsEmptyOpError(unittest.TestCase): diff --git a/test/legacy_test/test_isclose_op.py b/test/legacy_test/test_isclose_op.py index ad30997619f..c09d7fd7751 100644 --- a/test/legacy_test/test_isclose_op.py +++ b/test/legacy_test/test_isclose_op.py @@ -42,16 +42,12 @@ class TestIscloseOp(OpTest): } self.attrs = {'equal_nan': self.equal_nan} self.outputs = { - 'Out': np.array( - [ - np.isclose( - self.inputs['Input'], - self.inputs['Other'], - rtol=self.rtol, - atol=self.atol, - equal_nan=self.equal_nan, - ) - ] + 'Out': np.isclose( + self.inputs['Input'], + self.inputs['Other'], + rtol=self.rtol, + atol=self.atol, + equal_nan=self.equal_nan, ) } diff --git a/test/legacy_test/test_isfinite_op.py b/test/legacy_test/test_isfinite_op.py index efda5d502c6..d5a409489d8 100755 --- a/test/legacy_test/test_isfinite_op.py +++ b/test/legacy_test/test_isfinite_op.py @@ -31,7 +31,7 @@ class TestInf(OpTest): x[-1] = np.inf self.inputs = {'X': x} - self.outputs = {'Out': np.array(True).astype(self.dtype)} + self.outputs = {'Out': np.array([True]).astype(self.dtype)} def init_dtype(self): pass @@ -62,7 +62,7 @@ class TestInfBF16(OpTest): x[0] = np.inf x[-1] = np.inf - out = np.array(True) + out = np.array([True]) self.inputs = {'X': convert_float_to_uint16(x)} self.outputs = {'Out': out} @@ -81,7 +81,7 @@ class TestNAN(OpTest): x[-1] = np.nan self.inputs = {'X': x} - self.outputs = {'Out': np.array(True).astype(self.dtype)} + self.outputs = {'Out': np.array([True]).astype(self.dtype)} def init_dtype(self): pass @@ -112,7 +112,7 @@ class TestNANBF16(OpTest): x[0] = np.nan x[-1] = np.nan - out = np.array(True) + out = np.array([True]) self.inputs = {'X': convert_float_to_uint16(x)} self.outputs = {'Out': out} @@ -132,7 +132,7 @@ class TestIsfinite(OpTest): out = np.isinf(x) | np.isnan(x) self.inputs = {'X': x} - self.outputs = {'Out': np.array(False).astype(self.dtype)} + self.outputs = {'Out': np.array([False]).astype(self.dtype)} def init_dtype(self): pass @@ -163,7 +163,7 @@ class 
TestIsfiniteBF16(OpTest): x[0] = np.inf x[-1] = np.nan - out = np.array(False) + out = np.array([False]) self.inputs = {'X': convert_float_to_uint16(x)} self.outputs = {'Out': out} diff --git a/test/legacy_test/test_linspace.py b/test/legacy_test/test_linspace.py index 476e1725689..f36a5e7c8cb 100644 --- a/test/legacy_test/test_linspace.py +++ b/test/legacy_test/test_linspace.py @@ -66,7 +66,7 @@ class TestLinspaceOpNumOneCase(TestLinspaceOpCommonCase): 'Stop': np.array([0]).astype(self.dtype), 'Num': np.array([1]).astype('int32'), } - self.outputs = {'Out': np.array(10, dtype=self.dtype)} + self.outputs = {'Out': np.array([10], dtype=self.dtype)} def test_check_output(self): self.check_output() @@ -136,7 +136,7 @@ class TestLinspaceOpNumOneCaseBF16(TestLinspaceOpCommonCaseBF16): 'Num': np.array([1]).astype('int32'), } self.outputs = { - 'Out': convert_float_to_uint16(np.array(10, dtype="float32")) + 'Out': convert_float_to_uint16(np.array([10], dtype="float32")) } diff --git a/test/legacy_test/test_logspace.py b/test/legacy_test/test_logspace.py index 2982a12968c..e68dba46fef 100644 --- a/test/legacy_test/test_logspace.py +++ b/test/legacy_test/test_logspace.py @@ -113,7 +113,7 @@ class TestLogspaceOpNumOneCase(TestLogspaceOpCommonCase): 'Base': np.array([2]).astype(dtype), } self.attrs = {'dtype': int(paddle.float32)} - self.outputs = {'Out': np.power(2, np.array(10)).astype(dtype)} + self.outputs = {'Out': np.power(2, np.array([10])).astype(dtype)} class TestLogspaceOpMinusBaseCase(TestLogspaceOpCommonCase): diff --git a/test/legacy_test/test_mean_iou.py b/test/legacy_test/test_mean_iou.py index fe94ed9714e..551bfeda624 100644 --- a/test/legacy_test/test_mean_iou.py +++ b/test/legacy_test/test_mean_iou.py @@ -47,7 +47,7 @@ def compute_mean_iou( mean_iou = (out_correct / denominator).sum() / valid_count for _, in_mean_iou in in_mean_ious: - mean_iou += in_mean_iou + mean_iou += float(in_mean_iou) return mean_iou, out_wrong, out_correct @@ -84,21 +84,12 @@ class TestMeanIOUOp(OpTest): ) ) - in_mean_ious = [] - for i in range(self.in_mean_iou_num): - in_mean_ious.append( - ( - "in_mean_iou_%d" % i, - np.random.uniform(0, 1, [1]).astype("float32"), - ) - ) - self.inputs = { 'Predictions': predictions, 'Labels': labels, 'InWrongs': in_wrongs, 'InCorrects': in_corrects, - 'InMeanIou': in_mean_ious, + 'InMeanIou': self.in_mean_ious, } self.attrs = {'num_classes': int(self.num_classes)} mean_iou, out_wrong, out_correct = compute_mean_iou( @@ -107,7 +98,7 @@ class TestMeanIOUOp(OpTest): self.num_classes, in_wrongs, in_corrects, - in_mean_ious, + self.in_mean_ious, ) self.outputs = { 'OutMeanIou': mean_iou, @@ -120,7 +111,7 @@ class TestMeanIOUOp(OpTest): self.image_size = [128, 128] self.in_wrong_num = 0 self.in_correct_num = 0 - self.in_mean_iou_num = 0 + self.in_mean_ious = [] def test_check_output(self): self.check_output() @@ -132,7 +123,14 @@ class TestCase1(TestMeanIOUOp): self.image_size = [100, 128] self.in_wrong_num = 2 self.in_correct_num = 2 - self.in_mean_iou_num = 2 + self.in_mean_ious = [] + for i in range(2): + self.in_mean_ious.append( + ( + "in_mean_iou_%d" % i, + np.random.uniform(0, 1, []).astype("float32"), + ) + ) # NOTE(dev): Skip check_dygraph becuase Python API doesn't expose # in_wrong_num/in_correct_num/in_mean_iou_num argument @@ -140,5 +138,21 @@ class TestCase1(TestMeanIOUOp): self.check_output(check_dygraph=False) +class TestCase2(TestCase1): + def config(self): + self.num_classes = 5 + self.image_size = [100, 128] + self.in_wrong_num = 2 + self.in_correct_num = 2 + 
self.in_mean_ious = [] + for i in range(2): + self.in_mean_ious.append( + ( + "in_mean_iou_%d" % i, + np.random.uniform(0, 1, [1]).astype("float32"), + ) + ) + + if __name__ == '__main__': unittest.main() diff --git a/test/legacy_test/test_multiclass_nms_op.py b/test/legacy_test/test_multiclass_nms_op.py index 29ce4be023f..01ae6d835cf 100644 --- a/test/legacy_test/test_multiclass_nms_op.py +++ b/test/legacy_test/test_multiclass_nms_op.py @@ -613,7 +613,9 @@ class TestMulticlassNMS2Op(TestMulticlassNMSOp): else np.array([], dtype=np.float32).reshape([0, BOX_SIZE + 2]) ) index_outs = ( - det_outs[:, -1:].astype('int') if len(det_outs) else det_outs + det_outs[:, -1:].astype('int') + if len(det_outs) + else np.array([], dtype='int').reshape([0, 1]) ) self.op_type = 'multiclass_nms2' self.inputs = {'BBoxes': boxes, 'Scores': scores} @@ -685,7 +687,9 @@ class TestMulticlassNMS2LoDInput(TestMulticlassNMSLoDInput): else np.array([], dtype=np.float32).reshape([0, BOX_SIZE + 2]) ) index_outs = ( - det_outs[:, -1:].astype('int') if len(det_outs) else det_outs + det_outs[:, -1:].astype('int') + if len(det_outs) + else np.array([], dtype='int').reshape([0, 1]) ) self.op_type = 'multiclass_nms2' self.inputs = { @@ -760,7 +764,9 @@ class TestMulticlassNMS3Op(TestMulticlassNMS2Op): else np.array([], dtype=np.float32).reshape([0, BOX_SIZE + 2]) ) index_outs = ( - det_outs[:, -1:].astype('int') if len(det_outs) else det_outs + det_outs[:, -1:].astype('int') + if len(det_outs) + else np.array([], dtype='int').reshape([0, 1]) ) self.op_type = 'multiclass_nms3' self.inputs = {'BBoxes': boxes, 'Scores': scores} diff --git a/test/legacy_test/test_nll_loss.py b/test/legacy_test/test_nll_loss.py index 3c4c204f036..e0bb3882806 100644 --- a/test/legacy_test/test_nll_loss.py +++ b/test/legacy_test/test_nll_loss.py @@ -38,9 +38,9 @@ def nll_loss_1d( total_weight += cur_weight out[i] = -logs[i][cur_target] * cur_weight if reduction == 'sum': - return np.sum(out), np.array([total_weight]).astype('float64') + return np.sum(out), np.array(total_weight).astype('float64') elif reduction == 'mean': - return out.sum() / total_weight, np.array([total_weight]).astype( + return out.sum() / total_weight, np.array(total_weight).astype( 'float64' ) elif reduction == 'none': @@ -67,9 +67,9 @@ def nll_loss_2d( total_weight += cur_weight out[i][h][w] = -logs[i][cur_target][h][w] * cur_weight if reduction == 'sum': - return np.sum(out), np.array([total_weight]).astype('float64') + return np.sum(out), np.array(total_weight).astype('float64') elif reduction == 'mean': - return out.sum() / total_weight, np.array([total_weight]).astype( + return out.sum() / total_weight, np.array(total_weight).astype( 'float64' ) elif reduction == 'none': @@ -993,7 +993,7 @@ class TestNLLLossOp1DNoReduce(OpTest): 0, self.input_shape[1], self.label_shape ).astype("int64") output_np = nll_loss_1d(input_np, label_np, reduction='none') - total_weight_np = np.array([0]).astype('float64') + total_weight_np = np.array(0).astype('float64') self.inputs = {'X': input_np, 'Label': label_np} if self.with_weight: np.random.seed(200) @@ -1094,7 +1094,7 @@ class TestNLLLossOp2DNoReduce(OpTest): 0, self.input_shape[1], self.label_shape ).astype("int64") output_np = nll_loss_2d(input_np, label_np, reduction='none') - total_weight_np = np.array([0]).astype('float64') + total_weight_np = np.array(0).astype('float64') self.inputs = {'X': input_np, 'Label': label_np} if self.with_weight: np.random.seed(200) diff --git a/test/legacy_test/test_numel_op.py 
b/test/legacy_test/test_numel_op.py index c841cde6cbb..5c8c477877c 100644 --- a/test/legacy_test/test_numel_op.py +++ b/test/legacy_test/test_numel_op.py @@ -31,7 +31,7 @@ class TestNumelOp(OpTest): self.inputs = { 'Input': x, } - self.outputs = {'Out': np.array([np.size(x)])} + self.outputs = {'Out': np.array(np.size(x))} def test_check_output(self): self.check_output() @@ -84,7 +84,7 @@ class TestNumelOpBF16(OpTest): self.init() x = np.random.random(self.shape).astype(np.float32) self.inputs = {'Input': convert_float_to_uint16(x)} - self.outputs = {'Out': np.array([np.size(x)])} + self.outputs = {'Out': np.array(np.size(x))} def test_check_output(self): place = paddle.CUDAPlace(0) diff --git a/test/legacy_test/test_positive_negative_pair_op.py b/test/legacy_test/test_positive_negative_pair_op.py index 692acb5fa58..e01aa7fdbfb 100644 --- a/test/legacy_test/test_positive_negative_pair_op.py +++ b/test/legacy_test/test_positive_negative_pair_op.py @@ -47,9 +47,9 @@ def py_pnpair_op(score, label, query, column=-1, weight=None): neg += w return ( - np.array(pos).astype('float32'), - np.array(neg).astype('float32'), - np.array(neu).astype('float32'), + np.array([pos]).astype('float32'), + np.array([neg]).astype('float32'), + np.array([neu]).astype('float32'), ) diff --git a/test/legacy_test/test_reduce_op.py b/test/legacy_test/test_reduce_op.py index f567f427532..cb9c9f02c6a 100644 --- a/test/legacy_test/test_reduce_op.py +++ b/test/legacy_test/test_reduce_op.py @@ -73,7 +73,7 @@ class TestComplexSumOP(TestSumOp): class TestSumOp_ZeroDim(TestSumOp): def init_attrs(self): - self.attrs = {'dim': [], 'reduce_all': True} + self.attrs = {'dim': []} def init_input(self): self.x = np.random.random([]).astype(self.dtype) @@ -736,10 +736,14 @@ class TestProd8DBFP16OP(TestProd8DOp): self.check_grad_with_place(paddle.CUDAPlace(0), ['X'], 'Out') +def reduce_all_wrapper(x, axis=None, keepdim=False, reduce_all=True, name=None): + return paddle.all(x, axis, keepdim, name) + + class TestAllOp(OpTest): def setUp(self): self.op_type = "reduce_all" - self.python_api = paddle.all + self.python_api = reduce_all_wrapper self.inputs = {'X': np.random.randint(0, 2, (5, 6, 10)).astype("bool")} self.outputs = {'Out': self.inputs['X'].all()} self.attrs = {'reduce_all': True} @@ -754,7 +758,7 @@ class TestAllOp_ZeroDim(OpTest): self.op_type = "reduce_all" self.inputs = {'X': np.random.randint(0, 2, []).astype("bool")} self.outputs = {'Out': self.inputs['X'].all()} - self.attrs = {'dim': [], 'reduce_all': True} + self.attrs = {'dim': []} def test_check_output(self): self.check_output() @@ -769,7 +773,7 @@ class TestAll8DOp(OpTest): "bool" ) } - self.attrs = {'reduce_all': True, 'dim': (2, 3, 4)} + self.attrs = {'dim': (2, 3, 4)} self.outputs = {'Out': self.inputs['X'].all(axis=self.attrs['dim'])} def test_check_output(self): @@ -851,10 +855,14 @@ class TestAllOpError(unittest.TestCase): self.assertRaises(TypeError, paddle.all, input2) +def reduce_any_wrapper(x, axis=None, keepdim=False, reduce_all=True, name=None): + return paddle.any(x, axis, keepdim, name) + + class TestAnyOp(OpTest): def setUp(self): self.op_type = "reduce_any" - self.python_api = paddle.any + self.python_api = reduce_any_wrapper self.inputs = {'X': np.random.randint(0, 2, (5, 6, 10)).astype("bool")} self.outputs = {'Out': self.inputs['X'].any()} self.attrs = {'reduce_all': True} @@ -869,7 +877,7 @@ class TestAnyOp_ZeroDim(OpTest): self.op_type = "reduce_any" self.inputs = {'X': np.random.randint(0, 2, []).astype("bool")} self.outputs = {'Out': 
self.inputs['X'].any()} - self.attrs = {'dim': [], 'reduce_all': True} + self.attrs = {'dim': []} def test_check_output(self): self.check_output() @@ -884,7 +892,7 @@ class TestAny8DOp(OpTest): "bool" ) } - self.attrs = {'reduce_all': True, 'dim': (3, 5, 4)} + self.attrs = {'dim': (3, 5, 4)} self.outputs = {'Out': self.inputs['X'].any(axis=self.attrs['dim'])} def test_check_output(self): @@ -1291,11 +1299,17 @@ class TestReduceSumWithNumelOne(OpTest): self.check_grad(['X'], 'Out', check_prim=False) +def reduce_sum_wrapper( + x, axis=None, keepdim=False, reduce_all=True, out_dtype=None, name=None +): + return paddle.sum(x, axis, out_dtype, keepdim, name) + + class TestReduceAll(OpTest): def setUp(self): self.op_type = "reduce_sum" - self.python_api = paddle.sum - self.public_python_api = paddle.sum + self.python_api = reduce_sum_wrapper + self.public_python_api = reduce_sum_wrapper self.prim_op_type = "prim" self.inputs = {'X': np.random.random((100, 1, 1)).astype("float64")} self.attrs = {'reduce_all': True, 'keep_dim': False} @@ -1312,8 +1326,8 @@ class TestReduceAll(OpTest): class TestReduceAllFp32(OpTest): def setUp(self): self.op_type = "reduce_sum" - self.python_api = paddle.sum - self.public_python_api = paddle.sum + self.python_api = reduce_sum_wrapper + self.public_python_api = reduce_sum_wrapper self.prim_op_type = "prim" self.inputs = {'X': np.random.random((100, 1, 1)).astype("float32")} self.attrs = {'reduce_all': True, 'keep_dim': False} @@ -1345,15 +1359,17 @@ class Test1DReduceWithAxes1(OpTest): self.check_grad(['X'], 'Out', check_prim=True) -def reduce_sum_wrapper(x, axis=None, out_dtype=None, keepdim=False, name=None): - return paddle.sum(x, axis, "float64", keepdim, name) +def reduce_sum_wrapper_fp64( + x, axis=None, keepdim=False, reduce_all=True, out_dtype=None, name=None +): + return paddle.sum(x, axis, 'float64', keepdim, name) class TestReduceWithDtype(OpTest): def setUp(self): self.op_type = "reduce_sum" - self.python_api = reduce_sum_wrapper - self.public_python_api = reduce_sum_wrapper + self.python_api = reduce_sum_wrapper_fp64 + self.public_python_api = reduce_sum_wrapper_fp64 self.prim_op_type = "prim" self.inputs = {'X': np.random.random((6, 2, 10)).astype("float64")} self.outputs = {'Out': self.inputs['X'].sum().astype('float64')} @@ -1375,8 +1391,8 @@ class TestReduceWithDtype(OpTest): class TestReduceWithDtype1(TestReduceWithDtype): def setUp(self): self.op_type = "reduce_sum" - self.python_api = reduce_sum_wrapper - self.public_python_api = reduce_sum_wrapper + self.python_api = paddle.sum + self.public_python_api = paddle.sum self.prim_op_type = "prim" self.inputs = {'X': np.random.random((6, 2, 10)).astype("float64")} self.outputs = {'Out': self.inputs['X'].sum(axis=1)} @@ -1401,8 +1417,8 @@ class TestReduceWithDtype2(TestReduceWithDtype): def setUp(self): self.op_type = "reduce_sum" self.prim_op_type = "prim" - self.python_api = reduce_sum_wrapper - self.public_python_api = reduce_sum_wrapper + self.python_api = paddle.sum + self.public_python_api = paddle.sum self.inputs = {'X': np.random.random((6, 2, 10)).astype("float64")} self.outputs = {'Out': self.inputs['X'].sum(axis=1, keepdims=True)} self.attrs = {'dim': [1], 'keep_dim': True} diff --git a/test/legacy_test/test_scatter_op.py b/test/legacy_test/test_scatter_op.py index 34c30e6591d..2a222c9d96a 100644 --- a/test/legacy_test/test_scatter_op.py +++ b/test/legacy_test/test_scatter_op.py @@ -674,13 +674,13 @@ class TestScatterOpFp16(OpTest): ) ref_grad_updates = self.compute_ref_grad_updates() 
np.testing.assert_allclose( - ref_grad_updates.numpy(), - updates_tensor.grad.numpy(), + ref_grad_updates.numpy(False), + updates_tensor.grad.numpy(False), rtol=1e-5, atol=1e-5, ) np.testing.assert_allclose( - self.ref_dx, x_tensor.grad.numpy(), rtol=1e-5, atol=1e-5 + self.ref_dx, x_tensor.grad.numpy(False), rtol=1e-5, atol=1e-5 ) diff --git a/test/legacy_test/test_seed_op.py b/test/legacy_test/test_seed_op.py index 426dab42148..9127d4cd0ae 100644 --- a/test/legacy_test/test_seed_op.py +++ b/test/legacy_test/test_seed_op.py @@ -28,7 +28,7 @@ class TestSeedOpFixSeed(OpTest): self.op_type = "seed" self.inputs = {} self.attrs = {"seed": 123} - self.outputs = {"Out": np.asarray(123).astype('int')} + self.outputs = {"Out": np.array([123]).astype('int')} def test_check_output(self): self.check_output() @@ -39,7 +39,7 @@ class TestSeedOpDiffSeed(OpTest): self.op_type = "seed" self.inputs = {} self.attrs = {"seed": 0} - self.outputs = {"Out": np.asarray(123).astype('int')} + self.outputs = {"Out": np.array([123]).astype('int')} def test_check_output(self): self.check_output(no_check_set=["Out"]) diff --git a/test/legacy_test/test_segment_ops.py b/test/legacy_test/test_segment_ops.py index ab71a515c4a..d2be362e650 100644 --- a/test/legacy_test/test_segment_ops.py +++ b/test/legacy_test/test_segment_ops.py @@ -218,6 +218,14 @@ class TestSegmentMean(TestSegmentOps): } self.convert_bf16() + def test_check_output(self): + if core.is_compiled_with_cuda(): + self.check_output_with_place(core.CUDAPlace(0)) + # due to CPU kernel not implement calculate 'SummedIds' + # so cannot check 'SummedIds' + del self.outputs['SummedIds'] + self.check_output_with_place(core.CPUPlace()) + class TestSegmentMean2(TestSegmentMean): def prepare(self): diff --git a/test/legacy_test/test_slice_op.py b/test/legacy_test/test_slice_op.py index f43bd4b140d..629ee43be35 100644 --- a/test/legacy_test/test_slice_op.py +++ b/test/legacy_test/test_slice_op.py @@ -148,9 +148,9 @@ class TestSliceOp_decs_dim(OpTest): self.starts = [1, 0, 2] self.ends = [2, 3, 4] self.axes = [0, 1, 2] - self.decrease_axis = [0] + self.decrease_axis = [] self.infer_flags = [1, 1, 1] - self.out = self.input[1, 0:3, 2:4, :] + self.out = self.input[1:2, 0:3, 2:4, :] def test_check_output(self): self.check_output() diff --git a/test/legacy_test/test_squared_l2_norm_op.py b/test/legacy_test/test_squared_l2_norm_op.py index 17586d94f3b..4067acd29c5 100755 --- a/test/legacy_test/test_squared_l2_norm_op.py +++ b/test/legacy_test/test_squared_l2_norm_op.py @@ -81,7 +81,7 @@ class TestL2LossOp(OpTest): X = np.random.uniform(-1, 1, (13, 19)).astype("float32") X[np.abs(X) < self.max_relative_error] = 0.1 self.inputs = {'X': X} - self.outputs = {'Out': np.square(LA.norm(X))} + self.outputs = {'Out': np.array([np.square(LA.norm(X))])} def test_check_output(self): self.check_output() diff --git a/test/legacy_test/test_unbind_op.py b/test/legacy_test/test_unbind_op.py index 763aa2c3f24..80c3db774a7 100644 --- a/test/legacy_test/test_unbind_op.py +++ b/test/legacy_test/test_unbind_op.py @@ -77,7 +77,7 @@ class TestUnbind(unittest.TestCase): np_grad = np.ones(x.shape, np.float32) out.backward() - np.testing.assert_array_equal(x.grad.numpy(), np_grad) + np.testing.assert_array_equal(x.grad.numpy(False), np_grad) class TestLayersUnbind(unittest.TestCase): @@ -105,7 +105,9 @@ class TestUnbindOp(OpTest): pass def outReshape(self): - pass + self.out[0] = self.out[0].reshape((2, 2)) + self.out[1] = self.out[1].reshape((2, 2)) + self.out[2] = self.out[2].reshape((2, 2)) def 
setAxis(self): pass @@ -209,6 +211,7 @@ class TestUnbindFP16Op(OpTest): self.num = 3 x = np.arange(12).reshape(3, 2, 2).astype(self.dtype) self.out = np.split(x, self.num, self.axis) + self.outReshape() self.inputs = {'X': x} self.attrs = {'axis': self.axis} self.outputs = { @@ -216,6 +219,11 @@ class TestUnbindFP16Op(OpTest): } self.python_out_sig = ['out%d' % i for i in range(len(self.out))] + def outReshape(self): + self.out[0] = self.out[0].reshape((2, 2)) + self.out[1] = self.out[1].reshape((2, 2)) + self.out[2] = self.out[2].reshape((2, 2)) + def get_dtype(self): return np.float16 @@ -233,6 +241,7 @@ class TestUnbindBF16Op(OpTest): self.num = 3 x = np.arange(12).reshape(3, 2, 2).astype(self.dtype) self.out = np.split(x, self.num, self.axis) + self.outReshape() self.inputs = {'X': convert_float_to_uint16(x)} self.attrs = {'axis': self.axis} self.outputs = { @@ -243,6 +252,11 @@ class TestUnbindBF16Op(OpTest): } self.python_out_sig = ['out%d' % i for i in range(len(self.out))] + def outReshape(self): + self.out[0] = self.out[0].reshape((2, 2)) + self.out[1] = self.out[1].reshape((2, 2)) + self.out[2] = self.out[2].reshape((2, 2)) + def get_dtype(self): return np.uint16 @@ -277,7 +291,7 @@ class TestUnbindBool(unittest.TestCase): x = paddle.to_tensor([[True, True], [False, False]]) xs = paddle.unbind(x, axis=0) self.assertEqual(len(xs), 2) - np.testing.assert_array_equal(xs[0].numpy(), [True, True]) + np.testing.assert_array_equal(xs[0].numpy(False), [True, True]) class TestUnbindGradOptionalInput(unittest.TestCase): @@ -290,7 +304,7 @@ class TestUnbindGradOptionalInput(unittest.TestCase): a_grad = a.detach() a_grad[:, 0, :] = 1 - np.testing.assert_array_equal(a.grad.numpy(), a_grad.numpy()) + np.testing.assert_array_equal(a.grad.numpy(False), a_grad.numpy(False)) if __name__ == '__main__': diff --git a/test/xpu/parallel_dygraph_dataparallel_with_pylayer.py b/test/xpu/parallel_dygraph_dataparallel_with_pylayer.py index 2171cb24292..6db6e9e62a5 100644 --- a/test/xpu/parallel_dygraph_dataparallel_with_pylayer.py +++ b/test/xpu/parallel_dygraph_dataparallel_with_pylayer.py @@ -100,8 +100,8 @@ class TestDistTraning(unittest.TestCase): model_b.clear_gradients() def check_acc(self, grad, acc_grad): - grad = grad.numpy() if grad is not None else None - acc_grad = acc_grad.numpy() if acc_grad is not None else None + grad = grad.numpy(False) if grad is not None else None + acc_grad = acc_grad.numpy(False) if acc_grad is not None else None return np.testing.assert_allclose(grad, acc_grad, rtol=1e-6) def broadcast_param(self, param, root): @@ -115,7 +115,9 @@ class TestDistTraning(unittest.TestCase): grad = param._grad_ivar() other_grad = self.broadcast_param(grad.clone(), root=1) if self.trainer_id == 0: - np.testing.assert_allclose(other_grad.numpy(), grad.numpy()) + np.testing.assert_allclose( + other_grad.numpy(False), grad.numpy(False) + ) if __name__ == '__main__': diff --git a/test/xpu/parallel_dygraph_gradient_check.py b/test/xpu/parallel_dygraph_gradient_check.py index 1e8f5fadc85..9de687f524a 100644 --- a/test/xpu/parallel_dygraph_gradient_check.py +++ b/test/xpu/parallel_dygraph_gradient_check.py @@ -113,8 +113,8 @@ class TestDistTraning(unittest.TestCase): def check_acc(self, grad, grad_sum, acc_grad): if grad is not None: - grad_sum = grad_sum + grad.numpy() - acc_grad = acc_grad.numpy() if acc_grad is not None else None + grad_sum = grad_sum + grad.numpy(False) + acc_grad = acc_grad.numpy(False) if acc_grad is not None else None np.testing.assert_allclose(grad_sum, acc_grad, 
rtol=1e-6) return grad_sum @@ -133,7 +133,9 @@ class TestDistTraning(unittest.TestCase): grad = param._grad_ivar() other_grad = self.broadcast_param(grad.clone(), root=1) if self.trainer_id == 0: - np.testing.assert_allclose(other_grad.numpy(), grad.numpy()) + np.testing.assert_allclose( + other_grad.numpy(False), grad.numpy(False) + ) if __name__ == '__main__': diff --git a/test/xpu/parallel_dygraph_gradient_check_in_eager_mode.py b/test/xpu/parallel_dygraph_gradient_check_in_eager_mode.py index d193d4e88ad..f0e46b2db27 100644 --- a/test/xpu/parallel_dygraph_gradient_check_in_eager_mode.py +++ b/test/xpu/parallel_dygraph_gradient_check_in_eager_mode.py @@ -121,8 +121,8 @@ class TestDistTraning(unittest.TestCase): def check_acc(self, grad, grad_sum, acc_grad): if grad is not None: - grad_sum = grad_sum + grad.numpy() - acc_grad = acc_grad.numpy() if acc_grad is not None else None + grad_sum = grad_sum + grad.numpy(False) + acc_grad = acc_grad.numpy(False) if acc_grad is not None else None np.testing.assert_allclose(grad_sum, acc_grad, rtol=1e-6) return grad_sum @@ -141,7 +141,9 @@ class TestDistTraning(unittest.TestCase): grad = param.grad other_grad = self.broadcast_param(grad, root=1) if self.trainer_id == 0: - np.testing.assert_allclose(other_grad.numpy(), grad.numpy()) + np.testing.assert_allclose( + other_grad.numpy(False), grad.numpy(False) + ) if __name__ == '__main__': -- GitLab
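
The recurring patterns behind the test changes above can be illustrated with small self-contained sketches; the values and shapes used below are illustrative, not taken from the tests.

In the multiclass_nms tests, the empty-result branch now builds index_outs as an empty [0, 1] integer array rather than reusing the empty det_outs, so both branches of the conditional yield the same rank and dtype as det_outs[:, -1:].astype('int'). A NumPy-only sketch:

    import numpy as np

    det_outs = np.random.random([4, 6]).astype(np.float32)    # stand-in for a non-empty detection result
    index_nonempty = det_outs[:, -1:].astype('int')
    assert index_nonempty.shape == (4, 1)                      # non-empty branch: rank 2

    index_empty = np.array([], dtype='int').reshape([0, 1])    # empty branch: still rank 2, same dtype
    assert index_empty.shape == (0, 1)
    assert index_empty.dtype == index_nonempty.dtype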
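
Several expected outputs switch between 0-D and shape-[1] NumPy arrays: numel's Out and the nll_loss total weight are now built with np.array(value) (shape ()), while ops whose outputs remain one-element vectors (seed, positive_negative_pair, squared_l2_norm) wrap the value as np.array([value]) (shape (1,)). A minimal NumPy sketch of the distinction a shape-sensitive comparison has to make:

    import numpy as np

    scalar_out = np.array(3.0)      # 0-D: shape == (), ndim == 0
    vector_out = np.array([3.0])    # 1-D: shape == (1,), ndim == 1

    assert scalar_out.shape == ()
    assert vector_out.shape == (1,)

    # The values broadcast and compare equal, but the ranks differ, which is
    # exactly what a shape-aware expected-output check must distinguish.
    np.testing.assert_allclose(scalar_out, vector_out)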
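
The reduce_all_wrapper, reduce_any_wrapper and reduce_sum_wrapper helpers in test_reduce_op.py accept (and discard) a reduce_all keyword so the operator attribute can be threaded through a Python-API call that does not take it. The wrapper below is copied from the diff; the two usage lines are a sketch that assumes a working Paddle installation:

    import paddle

    def reduce_all_wrapper(x, axis=None, keepdim=False, reduce_all=True, name=None):
        # `reduce_all` is accepted only to match the op's attribute list;
        # the public API is called without it.
        return paddle.all(x, axis, keepdim, name)

    x = paddle.to_tensor([[True, False], [True, True]])
    out = reduce_all_wrapper(x)     # full reduction over all axes
    print(out.shape)                # expected: [] (a 0-D bool tensor)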
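
The test_slice_op.py change (decrease_axis = [] together with self.input[1:2, 0:3, 2:4, :]) follows the usual indexing rule: an integer index removes an axis, while a length-1 slice keeps it, so an empty decrease_axis has to be paired with slice-style indexing in the reference result. In NumPy, with an arbitrary 4-D shape chosen only for illustration:

    import numpy as np

    x = np.random.random([3, 4, 5, 6]).astype("float64")

    reduced = x[1, 0:3, 2:4, :]     # integer index drops axis 0
    kept = x[1:2, 0:3, 2:4, :]      # length-1 slice keeps axis 0

    assert reduced.shape == (3, 2, 6)
    assert kept.shape == (1, 3, 2, 6)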
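
The outReshape helpers added to the unbind tests compensate for np.split keeping the split axis as a size-1 dimension, whereas paddle.unbind drops it, so the NumPy reference outputs need an explicit reshape before comparison. A NumPy-only sketch using the same (3, 2, 2) input as the tests:

    import numpy as np

    x = np.arange(12).reshape(3, 2, 2)
    parts = np.split(x, 3, axis=0)               # three arrays of shape (1, 2, 2)
    assert parts[0].shape == (1, 2, 2)

    squeezed = [p.reshape(2, 2) for p in parts]  # the shape unbind is expected to return
    assert squeezed[0].shape == (2, 2)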