From 6fc15986f70f69fd0af581b01ffa63c26e14c95a Mon Sep 17 00:00:00 2001
From: WangZhen <23097963+0x45f@users.noreply.github.com>
Date: Tue, 30 Aug 2022 15:53:34 +0800
Subject: [PATCH] [OpAttr]Adapt tensor axis for argmin/max (#45453)

* Adapt tensor axis for argmin/max

* Add UT

* Polish UT
---
 paddle/fluid/operators/arg_min_max_op_base.h | 10 ++-
 paddle/phi/api/yaml/legacy_api.yaml          |  4 +-
 paddle/phi/infermeta/unary.cc                | 63 ++++++++-----
 paddle/phi/infermeta/unary.h                 |  2 +-
 paddle/phi/kernels/arg_min_max_kernel.h      |  5 +-
 paddle/phi/kernels/cpu/arg_min_max_kernel.cc | 10 +--
 paddle/phi/kernels/gpu/arg_min_max_kernel.cu | 10 +--
 .../tests/unittests/test_arg_min_max_op.py   | 88 +++++++++++++++++++
 python/paddle/tensor/search.py               |  8 +-
 9 files changed, 159 insertions(+), 41 deletions(-)

diff --git a/paddle/fluid/operators/arg_min_max_op_base.h b/paddle/fluid/operators/arg_min_max_op_base.h
index 577100ae742..0e44fd2fa27 100644
--- a/paddle/fluid/operators/arg_min_max_op_base.h
+++ b/paddle/fluid/operators/arg_min_max_op_base.h
@@ -31,6 +31,13 @@ namespace operators {
 class ArgMinMaxOp : public framework::OperatorWithKernel {
  public:
   using framework::OperatorWithKernel::OperatorWithKernel;
+
+  framework::OpKernelType GetExpectedKernelType(
+      const framework::ExecutionContext& ctx) const override {
+    auto input_data_type =
+        framework::OperatorWithKernel::IndicateVarDataType(ctx, "X");
+    return framework::OpKernelType(input_data_type, ctx.GetPlace());
+  }
 };
 
 class BaseArgMinMaxOpMaker : public framework::OpProtoAndCheckerMaker {
@@ -42,7 +49,8 @@ class BaseArgMinMaxOpMaker : public framework::OpProtoAndCheckerMaker {
   void Make() override {
     AddInput("X", "Input tensor.");
     AddOutput("Out", "Output tensor.");
-    AddAttr<int64_t>("axis", "The axis in which to compute the arg indics.");
+    AddAttr<int64_t>("axis", "The axis in which to compute the arg indics.")
+        .SupportTensor();
     AddAttr<bool>("keepdims", "Keep the dim that to reduce.").SetDefault(false);
     AddAttr<bool>("flatten",
                   "Flatten the input value, and search the min or max indices")
diff --git a/paddle/phi/api/yaml/legacy_api.yaml b/paddle/phi/api/yaml/legacy_api.yaml
index 4dff5c4ac8c..44096f0ef1e 100755
--- a/paddle/phi/api/yaml/legacy_api.yaml
+++ b/paddle/phi/api/yaml/legacy_api.yaml
@@ -197,7 +197,7 @@
     support_trans_dtype : start, end, step
 
 - api : argmax
-  args : (Tensor x, int64_t axis, bool keepdims, bool flatten, int dtype)
+  args : (Tensor x, Scalar axis, bool keepdims, bool flatten, int dtype)
   output : Tensor(out)
   infer_meta :
     func : ArgMinMaxInferMeta
@@ -205,7 +205,7 @@
     func : arg_max
 
 - api : argmin
-  args : (Tensor x, int64_t axis, bool keepdims, bool flatten, int dtype)
+  args : (Tensor x, Scalar axis, bool keepdims, bool flatten, int dtype)
   output : Tensor(out)
   infer_meta :
     func : ArgMinMaxInferMeta
diff --git a/paddle/phi/infermeta/unary.cc b/paddle/phi/infermeta/unary.cc
index 56e524b1bbb..b6e2d9eaf57 100644
--- a/paddle/phi/infermeta/unary.cc
+++ b/paddle/phi/infermeta/unary.cc
@@ -121,28 +121,12 @@ void AffineGridInferMeta(const MetaTensor& input,
 }
 
 void ArgMinMaxInferMeta(const MetaTensor& x,
-                        int64_t axis,
+                        const Scalar& axis,
                         bool keepdims,
                         bool flatten,
                         int dtype,
                         MetaTensor* out,
                         MetaConfig config) {
-  const auto& x_dims = x.dims();
-
-  PADDLE_ENFORCE_GE(
-      axis,
-      -x_dims.size(),
-      phi::errors::InvalidArgument("'axis'(%d) must be greater than or equal to"
-                                   " -Rank(X)(%d).",
-                                   axis,
-                                   -x_dims.size()));
-  PADDLE_ENFORCE_LT(axis,
-                    x_dims.size(),
-                    phi::errors::InvalidArgument(
-                        "'axis'(%d) must be less than Rank(X)(%d) of Input(X).",
-                        axis,
-                        x_dims.size()));
-
   PADDLE_ENFORCE_EQ(
       (dtype < 0 || dtype == 2 || dtype == 3),
       true,
@@ -156,8 +140,45 @@ void ArgMinMaxInferMeta(const MetaTensor& x,
           paddle::framework::DataTypeToString(
               static_cast<paddle::framework::proto::VarType::Type>(dtype))));
 
+  if (!config.is_runtime && axis.FromTensor()) {
+    std::vector<int64_t> vec;
+    if (flatten) {
+      vec = {1};
+    } else {
+      if (keepdims) {
+        vec = std::vector<int64_t>(x.dims().size(), -1);
+      } else {
+        vec = std::vector<int64_t>(x.dims().size() - 1, -1);
+      }
+    }
+    out->set_dims(phi::make_ddim(vec));
+    if (dtype == 2) {
+      out->set_dtype(DataType::INT32);
+    } else if (dtype == 3) {
+      out->set_dtype(DataType::INT64);
+    }
+    return;
+  }
+
+  auto int_axis = axis.to<int64_t>();
+  const auto& x_dims = x.dims();
+
+  PADDLE_ENFORCE_GE(
+      int_axis,
+      -x_dims.size(),
+      phi::errors::InvalidArgument("'axis'(%d) must be greater than or equal to"
+                                   " -Rank(X)(%d).",
+                                   int_axis,
+                                   -x_dims.size()));
+  PADDLE_ENFORCE_LT(int_axis,
+                    x_dims.size(),
+                    phi::errors::InvalidArgument(
+                        "'axis'(%d) must be less than Rank(X)(%d) of Input(X).",
+                        int_axis,
+                        x_dims.size()));
+
   auto x_rank = x_dims.size();
-  if (axis < 0) axis += x_rank;
+  if (int_axis < 0) int_axis += x_rank;
   if (config.is_runtime) {
     if (dtype == paddle::framework::proto::VarType::INT32) {
       int64_t all_element_num = 0;
@@ -165,7 +186,7 @@ void ArgMinMaxInferMeta(const MetaTensor& x,
         all_element_num = phi::product(x_dims);
       } else {
-        all_element_num = x_dims[axis];
+        all_element_num = x_dims[int_axis];
       }
       PADDLE_ENFORCE_LE(
           all_element_num,
           INT_MAX,
@@ -182,11 +203,11 @@ void ArgMinMaxInferMeta(const MetaTensor& x,
   if (flatten) {
     vec.emplace_back(static_cast<int64_t>(1));
   } else {
-    for (int64_t i = 0; i < axis; i++) vec.emplace_back(x_dims[i]);
+    for (int64_t i = 0; i < int_axis; i++) vec.emplace_back(x_dims[i]);
     if (keepdims) {
       vec.emplace_back(static_cast<int64_t>(1));
     }
-    for (int64_t i = axis + 1; i < x_rank; i++) vec.emplace_back(x_dims[i]);
+    for (int64_t i = int_axis + 1; i < x_rank; i++) vec.emplace_back(x_dims[i]);
   }
   out->set_dims(phi::make_ddim(vec));
   if (dtype == 2) {
diff --git a/paddle/phi/infermeta/unary.h b/paddle/phi/infermeta/unary.h
index 9a67066cab2..1e7a65e9be8 100644
--- a/paddle/phi/infermeta/unary.h
+++ b/paddle/phi/infermeta/unary.h
@@ -40,7 +40,7 @@ void AffineGridInferMeta(const MetaTensor& input,
                          MetaTensor* output);
 
 void ArgMinMaxInferMeta(const MetaTensor& x,
-                        int64_t axis,
+                        const Scalar& axis,
                         bool keepdims,
                         bool flatten,
                         int dtype,
diff --git a/paddle/phi/kernels/arg_min_max_kernel.h b/paddle/phi/kernels/arg_min_max_kernel.h
index 917babeef07..258c8f21e05 100644
--- a/paddle/phi/kernels/arg_min_max_kernel.h
+++ b/paddle/phi/kernels/arg_min_max_kernel.h
@@ -14,6 +14,7 @@
 limitations under the License. */
 
 #pragma once
 
+#include "paddle/phi/common/scalar.h"
 #include "paddle/phi/core/dense_tensor.h"
 
@@ -21,7 +22,7 @@ namespace phi {
 template <typename T, typename Context>
 void ArgMinKernel(const Context& dev_ctx,
                   const DenseTensor& x,
-                  int64_t axis,
+                  const Scalar& axis,
                   bool keepdims,
                   bool flatten,
                   int dtype,
@@ -30,7 +31,7 @@ void ArgMinKernel(const Context& dev_ctx,
 template <typename T, typename Context>
 void ArgMaxKernel(const Context& dev_ctx,
                   const DenseTensor& x,
-                  int64_t axis,
+                  const Scalar& axis,
                   bool keepdims,
                   bool flatten,
                   int dtype,
diff --git a/paddle/phi/kernels/cpu/arg_min_max_kernel.cc b/paddle/phi/kernels/cpu/arg_min_max_kernel.cc
index f4ad830e149..13e401b59d6 100644
--- a/paddle/phi/kernels/cpu/arg_min_max_kernel.cc
+++ b/paddle/phi/kernels/cpu/arg_min_max_kernel.cc
@@ -135,7 +135,7 @@ struct VisitDataArgMinMaxFunctor {
 template <typename Context, typename T, ArgMinMaxType EnumArgMinMaxValue>
 void ArgMinMaxKernel(const Context& dev_ctx,
                      const DenseTensor& x,
-                     int64_t axis,
+                     const Scalar& axis,
                      bool keepdims,
                      bool flatten,
                      int dtype,
@@ -145,19 +145,19 @@ void ArgMinMaxKernel(const Context& dev_ctx,
         static_cast<paddle::framework::proto::VarType::Type>(
             paddle::framework::proto::VarType::INT64),
         VisitDataArgMinMaxFunctor<Context, T, EnumArgMinMaxValue>(
-            dev_ctx, x, axis, keepdims, flatten, out));
+            dev_ctx, x, axis.to<int64_t>(), keepdims, flatten, out));
     return;
   }
   paddle::framework::VisitDataTypeTiny(
       static_cast<paddle::framework::proto::VarType::Type>(dtype),
       VisitDataArgMinMaxFunctor<Context, T, EnumArgMinMaxValue>(
-          dev_ctx, x, axis, keepdims, flatten, out));
+          dev_ctx, x, axis.to<int64_t>(), keepdims, flatten, out));
 }
 
 template <typename T, typename Context>
 void ArgMinKernel(const Context& dev_ctx,
                   const DenseTensor& x,
-                  int64_t axis,
+                  const Scalar& axis,
                   bool keepdims,
                   bool flatten,
                   int dtype,
@@ -169,7 +169,7 @@ void ArgMinKernel(const Context& dev_ctx,
 template <typename T, typename Context>
 void ArgMaxKernel(const Context& dev_ctx,
                   const DenseTensor& x,
-                  int64_t axis,
+                  const Scalar& axis,
                   bool keepdims,
                   bool flatten,
                   int dtype,
diff --git a/paddle/phi/kernels/gpu/arg_min_max_kernel.cu b/paddle/phi/kernels/gpu/arg_min_max_kernel.cu
index b1a40f03e61..13db1853495 100644
--- a/paddle/phi/kernels/gpu/arg_min_max_kernel.cu
+++ b/paddle/phi/kernels/gpu/arg_min_max_kernel.cu
@@ -203,7 +203,7 @@ struct VisitDataCudaArgMinMaxFunctor {
 template <typename Context, typename T, class Reducer>
 void ArgMinMaxOpCUDAKernel(const Context& dev_ctx,
                            const DenseTensor& x,
-                           int64_t axis,
+                           const Scalar& axis,
                            bool keepdims,
                            bool flatten,
                            int dtype,
@@ -213,19 +213,19 @@ void ArgMinMaxOpCUDAKernel(const Context& dev_ctx,
         static_cast<paddle::framework::proto::VarType::Type>(
             paddle::framework::proto::VarType::INT64),
         VisitDataCudaArgMinMaxFunctor<Context, T, Reducer>(
-            dev_ctx, x, axis, keepdims, flatten, out));
+            dev_ctx, x, axis.to<int64_t>(), keepdims, flatten, out));
     return;
   }
   paddle::framework::VisitDataTypeTiny(
       static_cast<paddle::framework::proto::VarType::Type>(dtype),
       VisitDataCudaArgMinMaxFunctor<Context, T, Reducer>(
-          dev_ctx, x, axis, keepdims, flatten, out));
+          dev_ctx, x, axis.to<int64_t>(), keepdims, flatten, out));
 }
 
 template <typename T, typename Context>
 void ArgMinKernel(const Context& dev_ctx,
                   const DenseTensor& x,
-                  int64_t axis,
+                  const Scalar& axis,
                   bool keepdims,
                   bool flatten,
                   int dtype,
@@ -237,7 +237,7 @@ void ArgMinKernel(const Context& dev_ctx,
 template <typename T, typename Context>
 void ArgMaxKernel(const Context& dev_ctx,
                   const DenseTensor& x,
-                  int64_t axis,
+                  const Scalar& axis,
                   bool keepdims,
                   bool flatten,
                   int dtype,
diff --git a/python/paddle/fluid/tests/unittests/test_arg_min_max_op.py b/python/paddle/fluid/tests/unittests/test_arg_min_max_op.py
index 6056f8f2106..dc5c0c17a49 100644
--- a/python/paddle/fluid/tests/unittests/test_arg_min_max_op.py
+++ b/python/paddle/fluid/tests/unittests/test_arg_min_max_op.py
@@ -14,6 +14,7 @@
 from __future__ import print_function
 
+import os
 import unittest
 import numpy as np
 from op_test import OpTest
 
@@ -21,6 +22,7 @@ import paddle
 import paddle.fluid as fluid
 import paddle.fluid.core as core
 from paddle.fluid import Program, program_guard
+from test_attribute_var import UnittestBase
 
 
 class BaseTestCase(OpTest):
@@ -235,6 +237,92 @@ class BaseTestComplex2_2(OpTest):
         }
 
 
+class TestArgMaxTensorAxis(UnittestBase):
+
+    def init_info(self):
+        self.shapes = [[2, 3, 4]]
+        self.x = [np.random.randn(*shape) for shape in self.shapes]
+        self.save_path = os.path.join(self.temp_dir.name, self.path_prefix())
+
+    def test_static(self):
+        main_prog = Program()
+        starup_prog = Program()
+        with program_guard(main_prog, starup_prog):
+            fc = paddle.nn.Linear(4, 10)
+            x = paddle.randn([2, 3, 4])
+            x.stop_gradient = False
+            feat = fc(x)
+
+            out = self.call_func(feat)
+
+            sgd = paddle.optimizer.SGD()
+            sgd.minimize(paddle.mean(paddle.cast(out, 'float32')))
+            self.assertTrue(self.var_prefix() in str(main_prog))
+
+            exe = paddle.static.Executor()
+            exe.run(starup_prog)
+            res = exe.run(fetch_list=[feat, out])
+            paddle.static.save_inference_model(self.save_path, [x], [feat, out],
+                                               exe)
+            gt = np.argmax(res[0], 0)
+            np.testing.assert_allclose(res[1], gt)
+
+            # Test for Inference Predictor
+            infer_outs = self.infer_prog()
+            gt = np.argmax(infer_outs[0], 0)
+            np.testing.assert_allclose(infer_outs[1], gt)
+
+    def path_prefix(self):
+        return 'argmax_tensor_axis'
+
+    def var_prefix(self):
+        return "Var["
+
+    def call_func(self, x):
+        axis = paddle.assign(0)
+        out = paddle.argmax(x, axis)
+        return out
+
+
+class TestArgMinTensorAxis(TestArgMaxTensorAxis):
+
+    def test_static(self):
+        main_prog = Program()
+        starup_prog = Program()
+        with program_guard(main_prog, starup_prog):
+            fc = paddle.nn.Linear(4, 10)
+            x = paddle.randn([2, 3, 4])
+            x.stop_gradient = False
+            feat = fc(x)
+            feat = paddle.cast(feat, 'int32')
+            out = self.call_func(feat)
+
+            sgd = paddle.optimizer.SGD()
+            sgd.minimize(paddle.mean(paddle.cast(out, 'float32')))
+            self.assertTrue(self.var_prefix() in str(main_prog))
+
+            exe = paddle.static.Executor()
+            exe.run(starup_prog)
+            res = exe.run(fetch_list=[feat, out])
+            paddle.static.save_inference_model(self.save_path, [x], [feat, out],
+                                               exe)
+            gt = np.argmin(res[0], 1)
+            np.testing.assert_allclose(np.squeeze(res[1]), gt)
+
+            # Test for Inference Predictor
+            infer_outs = self.infer_prog()
+            gt = np.argmin(infer_outs[0], 1)
+            np.testing.assert_allclose(np.squeeze(infer_outs[1]), gt)
+
+    def path_prefix(self):
+        return 'argmin_tensor_axis'
+
+    def call_func(self, x):
+        axis = paddle.assign(1)
+        out = paddle.argmin(x, axis, keepdim=True)
+        return out
+
+
 if __name__ == '__main__':
     paddle.enable_static()
     unittest.main()
diff --git a/python/paddle/tensor/search.py b/python/paddle/tensor/search.py
index b740a100358..fd8beb0f933 100644
--- a/python/paddle/tensor/search.py
+++ b/python/paddle/tensor/search.py
@@ -162,9 +162,9 @@ def argmax(x, axis=None, keepdim=False, dtype="int64", name=None):
             print(out4)
             # [[2, 2, 0, 1]]
     """
-    if axis is not None and not isinstance(axis, int):
+    if axis is not None and not isinstance(axis, (int, Variable)):
         raise TypeError(
-            "The type of 'axis' must be int or None in argmax, but received %s."
+            "The type of 'axis' must be int or Tensor or None in argmax, but received %s."
             % (type(axis)))
 
     if dtype is None:
@@ -244,9 +244,9 @@ def argmin(x, axis=None, keepdim=False, dtype="int64", name=None):
             print(out4)
             # [[1, 1, 1, 2]]
     """
-    if axis is not None and not isinstance(axis, int):
+    if axis is not None and not isinstance(axis, (int, Variable)):
         raise TypeError(
-            "The type of 'axis' must be int or None in argmin, but received %s."
+ "The type of 'axis' must be int or Tensor or None in argmin, but received %s." % (type(axis))) if dtype is None: -- GitLab