Unverified commit 6fc15986, authored by WangZhen, committed by GitHub

[OpAttr]Adapt tensor axis for argmin/max (#45453)

* Adapt tensor axis for argmin/max

* Add UT

* Polish UT
Parent 5f1a8e46
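
Net effect of this commit: in static graphs, the `axis` argument of `paddle.argmax`/`paddle.argmin` may now be a Tensor rather than only a Python int. A minimal sketch distilled from the new unit tests below (assuming a Paddle build that includes this change):

    import numpy as np
    import paddle

    paddle.enable_static()
    main_prog = paddle.static.Program()
    startup_prog = paddle.static.Program()
    with paddle.static.program_guard(main_prog, startup_prog):
        x = paddle.randn([2, 3, 4])
        axis = paddle.assign(0)       # axis is a Tensor, not a Python int
        out = paddle.argmax(x, axis)  # accepted after this change

    exe = paddle.static.Executor()
    exe.run(startup_prog)
    x_val, out_val = exe.run(main_prog, fetch_list=[x, out])
    np.testing.assert_allclose(out_val, np.argmax(x_val, 0))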
@@ -31,6 +31,13 @@ namespace operators {
 class ArgMinMaxOp : public framework::OperatorWithKernel {
  public:
   using framework::OperatorWithKernel::OperatorWithKernel;
+
+  framework::OpKernelType GetExpectedKernelType(
+      const framework::ExecutionContext& ctx) const override {
+    auto input_data_type =
+        framework::OperatorWithKernel::IndicateVarDataType(ctx, "X");
+    return framework::OpKernelType(input_data_type, ctx.GetPlace());
+  }
 };

 class BaseArgMinMaxOpMaker : public framework::OpProtoAndCheckerMaker {
@@ -42,7 +49,8 @@ class BaseArgMinMaxOpMaker : public framework::OpProtoAndCheckerMaker {
   void Make() override {
     AddInput("X", "Input tensor.");
     AddOutput("Out", "Output tensor.");
-    AddAttr<int64_t>("axis", "The axis in which to compute the arg indics.");
+    AddAttr<int64_t>("axis", "The axis in which to compute the arg indics.")
+        .SupportTensor();
     AddAttr<bool>("keepdims", "Keep the dim that to reduce.").SetDefault(false);
     AddAttr<bool>("flatten",
                   "Flatten the input value, and search the min or max indices")
......
@@ -197,7 +197,7 @@
     support_trans_dtype : start, end, step

 - api : argmax
-  args : (Tensor x, int64_t axis, bool keepdims, bool flatten, int dtype)
+  args : (Tensor x, Scalar axis, bool keepdims, bool flatten, int dtype)
   output : Tensor(out)
   infer_meta :
     func : ArgMinMaxInferMeta
@@ -205,7 +205,7 @@
     func : arg_max

 - api : argmin
-  args : (Tensor x, int64_t axis, bool keepdims, bool flatten, int dtype)
+  args : (Tensor x, Scalar axis, bool keepdims, bool flatten, int dtype)
   output : Tensor(out)
   infer_meta :
     func : ArgMinMaxInferMeta
......
@@ -121,28 +121,12 @@ void AffineGridInferMeta(const MetaTensor& input,
 }

 void ArgMinMaxInferMeta(const MetaTensor& x,
-                        int64_t axis,
+                        const Scalar& axis,
                         bool keepdims,
                         bool flatten,
                         int dtype,
                         MetaTensor* out,
                         MetaConfig config) {
-  const auto& x_dims = x.dims();
-
-  PADDLE_ENFORCE_GE(
-      axis,
-      -x_dims.size(),
-      phi::errors::InvalidArgument("'axis'(%d) must be greater than or equal to"
-                                   " -Rank(X)(%d).",
-                                   axis,
-                                   -x_dims.size()));
-  PADDLE_ENFORCE_LT(axis,
-                    x_dims.size(),
-                    phi::errors::InvalidArgument(
-                        "'axis'(%d) must be less than Rank(X)(%d) of Input(X).",
-                        axis,
-                        x_dims.size()));
-
   PADDLE_ENFORCE_EQ(
       (dtype < 0 || dtype == 2 || dtype == 3),
       true,
@@ -156,8 +140,45 @@ void ArgMinMaxInferMeta(const MetaTensor& x,
           paddle::framework::DataTypeToString(
               static_cast<paddle::framework::proto::VarType::Type>(dtype))));

+  if (!config.is_runtime && axis.FromTensor()) {
+    std::vector<int64_t> vec;
+    if (flatten) {
+      vec = {1};
+    } else {
+      if (keepdims) {
+        vec = std::vector<int64_t>(x.dims().size(), -1);
+      } else {
+        vec = std::vector<int64_t>(x.dims().size() - 1, -1);
+      }
+    }
+    out->set_dims(phi::make_ddim(vec));
+    if (dtype == 2) {
+      out->set_dtype(DataType::INT32);
+    } else if (dtype == 3) {
+      out->set_dtype(DataType::INT64);
+    }
+    return;
+  }
+
+  auto int_axis = axis.to<int64_t>();
+  const auto& x_dims = x.dims();
+
+  PADDLE_ENFORCE_GE(
+      int_axis,
+      -x_dims.size(),
+      phi::errors::InvalidArgument("'axis'(%d) must be greater than or equal to"
+                                   " -Rank(X)(%d).",
+                                   int_axis,
+                                   -x_dims.size()));
+  PADDLE_ENFORCE_LT(int_axis,
+                    x_dims.size(),
+                    phi::errors::InvalidArgument(
+                        "'axis'(%d) must be less than Rank(X)(%d) of Input(X).",
+                        int_axis,
+                        x_dims.size()));
+
   auto x_rank = x_dims.size();
-  if (axis < 0) axis += x_rank;
+  if (int_axis < 0) int_axis += x_rank;
   if (config.is_runtime) {
     if (dtype == paddle::framework::proto::VarType::INT32) {
       int64_t all_element_num = 0;
@@ -165,7 +186,7 @@ void ArgMinMaxInferMeta(const MetaTensor& x,
         all_element_num = phi::product(x_dims);
       } else {
-        all_element_num = x_dims[axis];
+        all_element_num = x_dims[int_axis];
       }
       PADDLE_ENFORCE_LE(
           all_element_num,
@@ -182,11 +203,11 @@ void ArgMinMaxInferMeta(const MetaTensor& x,
   if (flatten) {
     vec.emplace_back(static_cast<int64_t>(1));
   } else {
-    for (int64_t i = 0; i < axis; i++) vec.emplace_back(x_dims[i]);
+    for (int64_t i = 0; i < int_axis; i++) vec.emplace_back(x_dims[i]);
     if (keepdims) {
       vec.emplace_back(static_cast<int64_t>(1));
     }
-    for (int64_t i = axis + 1; i < x_rank; i++) vec.emplace_back(x_dims[i]);
+    for (int64_t i = int_axis + 1; i < x_rank; i++) vec.emplace_back(x_dims[i]);
   }
   out->set_dims(phi::make_ddim(vec));
   if (dtype == 2) {
......
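
A consequence of the new early-return branch in ArgMinMaxInferMeta: when axis comes from a Tensor and shape inference runs at compile time, the concrete axis is unknown, so the output dims are filled with -1 placeholders (rank reduced by one unless keepdims or flatten applies). A hedged illustration (assuming a build with this change; the exact printed form may vary):

    import paddle

    paddle.enable_static()
    x = paddle.static.data('x', [2, 3, 4])
    out_static = paddle.argmax(x, axis=1)             # axis known at compile time
    out_dynamic = paddle.argmax(x, paddle.assign(1))  # axis known only at run time
    print(out_static.shape)   # (2, 4): fully inferred
    print(out_dynamic.shape)  # (-1, -1): placeholders until the axis tensor is evaluated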
@@ -40,7 +40,7 @@ void AffineGridInferMeta(const MetaTensor& input,
                          MetaTensor* output);

 void ArgMinMaxInferMeta(const MetaTensor& x,
-                        int64_t axis,
+                        const Scalar& axis,
                         bool keepdims,
                         bool flatten,
                         int dtype,
......
@@ -14,6 +14,7 @@ limitations under the License. */

 #pragma once

+#include "paddle/phi/common/scalar.h"
 #include "paddle/phi/core/dense_tensor.h"

 namespace phi {
@@ -21,7 +22,7 @@ namespace phi {
 template <typename T, typename Context>
 void ArgMinKernel(const Context& dev_ctx,
                   const DenseTensor& x,
-                  int64_t axis,
+                  const Scalar& axis,
                   bool keepdims,
                   bool flatten,
                   int dtype,
@@ -30,7 +31,7 @@ void ArgMinKernel(const Context& dev_ctx,
 template <typename T, typename Context>
 void ArgMaxKernel(const Context& dev_ctx,
                   const DenseTensor& x,
-                  int64_t axis,
+                  const Scalar& axis,
                   bool keepdims,
                   bool flatten,
                   int dtype,
......
@@ -135,7 +135,7 @@ struct VisitDataArgMinMaxFunctor {
 template <typename Context, typename T, ArgMinMaxType EnumArgMinMaxValue>
 void ArgMinMaxKernel(const Context& dev_ctx,
                      const DenseTensor& x,
-                     int64_t axis,
+                     const Scalar& axis,
                      bool keepdims,
                      bool flatten,
                      int dtype,
@@ -145,19 +145,19 @@ void ArgMinMaxKernel(const Context& dev_ctx,
         static_cast<paddle::framework::proto::VarType::Type>(
             paddle::framework::proto::VarType::INT64),
         VisitDataArgMinMaxFunctor<Context, T, EnumArgMinMaxValue>(
-            dev_ctx, x, axis, keepdims, flatten, out));
+            dev_ctx, x, axis.to<int64_t>(), keepdims, flatten, out));
     return;
   }
   paddle::framework::VisitDataTypeTiny(
       static_cast<paddle::framework::proto::VarType::Type>(dtype),
       VisitDataArgMinMaxFunctor<Context, T, EnumArgMinMaxValue>(
-          dev_ctx, x, axis, keepdims, flatten, out));
+          dev_ctx, x, axis.to<int64_t>(), keepdims, flatten, out));
 }

 template <typename T, typename Context>
 void ArgMinKernel(const Context& dev_ctx,
                   const DenseTensor& x,
-                  int64_t axis,
+                  const Scalar& axis,
                   bool keepdims,
                   bool flatten,
                   int dtype,
@@ -169,7 +169,7 @@ void ArgMinKernel(const Context& dev_ctx,
 template <typename T, typename Context>
 void ArgMaxKernel(const Context& dev_ctx,
                   const DenseTensor& x,
-                  int64_t axis,
+                  const Scalar& axis,
                   bool keepdims,
                   bool flatten,
                   int dtype,
......
@@ -203,7 +203,7 @@ struct VisitDataCudaArgMinMaxFunctor {
 template <typename Context, typename T, class Reducer>
 void ArgMinMaxOpCUDAKernel(const Context& dev_ctx,
                            const DenseTensor& x,
-                           int64_t axis,
+                           const Scalar& axis,
                            bool keepdims,
                            bool flatten,
                            int dtype,
@@ -213,19 +213,19 @@ void ArgMinMaxOpCUDAKernel(const Context& dev_ctx,
         static_cast<paddle::framework::proto::VarType::Type>(
             paddle::framework::proto::VarType::INT64),
         VisitDataCudaArgMinMaxFunctor<Context, T, Reducer>(
-            dev_ctx, x, axis, keepdims, flatten, out));
+            dev_ctx, x, axis.to<int64_t>(), keepdims, flatten, out));
     return;
   }
   paddle::framework::VisitDataTypeTiny(
       static_cast<paddle::framework::proto::VarType::Type>(dtype),
       VisitDataCudaArgMinMaxFunctor<Context, T, Reducer>(
-          dev_ctx, x, axis, keepdims, flatten, out));
+          dev_ctx, x, axis.to<int64_t>(), keepdims, flatten, out));
 }

 template <typename T, typename Context>
 void ArgMinKernel(const Context& dev_ctx,
                   const DenseTensor& x,
-                  int64_t axis,
+                  const Scalar& axis,
                   bool keepdims,
                   bool flatten,
                   int dtype,
@@ -237,7 +237,7 @@ void ArgMinKernel(const Context& dev_ctx,
 template <typename T, typename Context>
 void ArgMaxKernel(const Context& dev_ctx,
                   const DenseTensor& x,
-                  int64_t axis,
+                  const Scalar& axis,
                   bool keepdims,
                   bool flatten,
                   int dtype,
......
@@ -14,6 +14,7 @@
 from __future__ import print_function

+import os
 import unittest
 import numpy as np
 from op_test import OpTest
@@ -21,6 +22,7 @@ import paddle
 import paddle.fluid as fluid
 import paddle.fluid.core as core
 from paddle.fluid import Program, program_guard
+from test_attribute_var import UnittestBase


 class BaseTestCase(OpTest):
@@ -235,6 +237,92 @@ class BaseTestComplex2_2(OpTest):
         }


+class TestArgMaxTensorAxis(UnittestBase):
+
+    def init_info(self):
+        self.shapes = [[2, 3, 4]]
+        self.x = [np.random.randn(*shape) for shape in self.shapes]
+        self.save_path = os.path.join(self.temp_dir.name, self.path_prefix())
+
+    def test_static(self):
+        main_prog = Program()
+        starup_prog = Program()
+        with program_guard(main_prog, starup_prog):
+            fc = paddle.nn.Linear(4, 10)
+            x = paddle.randn([2, 3, 4])
+            x.stop_gradient = False
+            feat = fc(x)
+            out = self.call_func(feat)
+            sgd = paddle.optimizer.SGD()
+            sgd.minimize(paddle.mean(paddle.cast(out, 'float32')))
+            self.assertTrue(self.var_prefix() in str(main_prog))
+
+            exe = paddle.static.Executor()
+            exe.run(starup_prog)
+            res = exe.run(fetch_list=[feat, out])
+            paddle.static.save_inference_model(self.save_path, [x], [feat, out],
+                                               exe)
+            gt = np.argmax(res[0], 0)
+            np.testing.assert_allclose(res[1], gt)
+
+            # Test for Inference Predictor
+            infer_outs = self.infer_prog()
+            gt = np.argmax(infer_outs[0], 0)
+            np.testing.assert_allclose(infer_outs[1], gt)
+
+    def path_prefix(self):
+        return 'argmax_tensor_axis'
+
+    def var_prefix(self):
+        return "Var["
+
+    def call_func(self, x):
+        axis = paddle.assign(0)
+        out = paddle.argmax(x, axis)
+        return out
+
+
+class TestArgMinTensorAxis(TestArgMaxTensorAxis):
+
+    def test_static(self):
+        main_prog = Program()
+        starup_prog = Program()
+        with program_guard(main_prog, starup_prog):
+            fc = paddle.nn.Linear(4, 10)
+            x = paddle.randn([2, 3, 4])
+            x.stop_gradient = False
+            feat = fc(x)
+            feat = paddle.cast(feat, 'int32')
+            out = self.call_func(feat)
+            sgd = paddle.optimizer.SGD()
+            sgd.minimize(paddle.mean(paddle.cast(out, 'float32')))
+            self.assertTrue(self.var_prefix() in str(main_prog))
+
+            exe = paddle.static.Executor()
+            exe.run(starup_prog)
+            res = exe.run(fetch_list=[feat, out])
+            paddle.static.save_inference_model(self.save_path, [x], [feat, out],
+                                               exe)
+            gt = np.argmin(res[0], 1)
+            np.testing.assert_allclose(np.squeeze(res[1]), gt)
+
+            # Test for Inference Predictor
+            infer_outs = self.infer_prog()
+            gt = np.argmin(infer_outs[0], 1)
+            np.testing.assert_allclose(np.squeeze(infer_outs[1]), gt)
+
+    def path_prefix(self):
+        return 'argmin_tensor_axis'
+
+    def call_func(self, x):
+        axis = paddle.assign(1)
+        out = paddle.argmin(x, axis, keepdim=True)
+        return out
+
+
 if __name__ == '__main__':
     paddle.enable_static()
     unittest.main()
...@@ -162,9 +162,9 @@ def argmax(x, axis=None, keepdim=False, dtype="int64", name=None): ...@@ -162,9 +162,9 @@ def argmax(x, axis=None, keepdim=False, dtype="int64", name=None):
print(out4) print(out4)
# [[2, 2, 0, 1]] # [[2, 2, 0, 1]]
""" """
if axis is not None and not isinstance(axis, int): if axis is not None and not isinstance(axis, (int, Variable)):
raise TypeError( raise TypeError(
"The type of 'axis' must be int or None in argmax, but received %s." "The type of 'axis' must be int or Tensor or None in argmax, but received %s."
% (type(axis))) % (type(axis)))
if dtype is None: if dtype is None:
...@@ -244,9 +244,9 @@ def argmin(x, axis=None, keepdim=False, dtype="int64", name=None): ...@@ -244,9 +244,9 @@ def argmin(x, axis=None, keepdim=False, dtype="int64", name=None):
print(out4) print(out4)
# [[1, 1, 1, 2]] # [[1, 1, 1, 2]]
""" """
if axis is not None and not isinstance(axis, int): if axis is not None and not isinstance(axis, (int, Variable)):
raise TypeError( raise TypeError(
"The type of 'axis' must be int or None in argmin, but received %s." "The type of 'axis' must be int or Tensor or None in argmin, but received %s."
% (type(axis))) % (type(axis)))
if dtype is None: if dtype is None:
......
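
The relaxed check still rejects anything other than an int, a Tensor/Variable, or None; a quick sanity check (a sketch, assuming a build with this change):

    import paddle

    try:
        paddle.argmax(paddle.randn([2, 3]), axis=2.0)  # a float axis is rejected up front
    except TypeError as e:
        print(e)
        # The type of 'axis' must be int or Tensor or None in argmax,
        # but received <class 'float'>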