Unverified commit 6fc15986 authored by: W WangZhen, committed by: GitHub

[OpAttr]Adapt tensor axis for argmin/max (#45453)

* Adapt tensor axis for argmin/max

* Add UT

* Polish UT
Parent 5f1a8e46
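
Before the per-file diffs, a minimal usage sketch of what this change enables, reconstructed from the unit test added in this PR (the shapes and the `axis = paddle.assign(1)` value are illustrative, not part of the commit): in static graph mode, the `axis` argument of paddle.argmax/paddle.argmin may now be a Tensor rather than a plain Python int.

import numpy as np
import paddle

paddle.enable_static()

main_prog = paddle.static.Program()
startup_prog = paddle.static.Program()
with paddle.static.program_guard(main_prog, startup_prog):
    x = paddle.randn([2, 3, 4])
    # axis is supplied as a Tensor attribute, not a compile-time int
    axis = paddle.assign(1)
    out = paddle.argmax(x, axis)

exe = paddle.static.Executor()
exe.run(startup_prog)
x_val, out_val = exe.run(main_prog, fetch_list=[x, out])
# the result matches numpy's argmax along the same axis
np.testing.assert_allclose(out_val, np.argmax(x_val, axis=1))
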
......@@ -31,6 +31,13 @@ namespace operators {
class ArgMinMaxOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
auto input_data_type =
framework::OperatorWithKernel::IndicateVarDataType(ctx, "X");
return framework::OpKernelType(input_data_type, ctx.GetPlace());
}
};
class BaseArgMinMaxOpMaker : public framework::OpProtoAndCheckerMaker {
......@@ -42,7 +49,8 @@ class BaseArgMinMaxOpMaker : public framework::OpProtoAndCheckerMaker {
void Make() override {
AddInput("X", "Input tensor.");
AddOutput("Out", "Output tensor.");
AddAttr<int64_t>("axis", "The axis in which to compute the arg indics.");
AddAttr<int64_t>("axis", "The axis in which to compute the arg indics.")
.SupportTensor();
AddAttr<bool>("keepdims", "Keep the dim that to reduce.").SetDefault(false);
AddAttr<bool>("flatten",
"Flatten the input value, and search the min or max indices")
......
......@@ -197,7 +197,7 @@
support_trans_dtype : start, end, step
- api : argmax
args : (Tensor x, int64_t axis, bool keepdims, bool flatten, int dtype)
args : (Tensor x, Scalar axis, bool keepdims, bool flatten, int dtype)
output : Tensor(out)
infer_meta :
func : ArgMinMaxInferMeta
......@@ -205,7 +205,7 @@
func : arg_max
- api : argmin
args : (Tensor x, int64_t axis, bool keepdims, bool flatten, int dtype)
args : (Tensor x, Scalar axis, bool keepdims, bool flatten, int dtype)
output : Tensor(out)
infer_meta :
func : ArgMinMaxInferMeta
......
......@@ -121,28 +121,12 @@ void AffineGridInferMeta(const MetaTensor& input,
}
void ArgMinMaxInferMeta(const MetaTensor& x,
int64_t axis,
const Scalar& axis,
bool keepdims,
bool flatten,
int dtype,
MetaTensor* out,
MetaConfig config) {
const auto& x_dims = x.dims();
PADDLE_ENFORCE_GE(
axis,
-x_dims.size(),
phi::errors::InvalidArgument("'axis'(%d) must be greater than or equal to"
" -Rank(X)(%d).",
axis,
-x_dims.size()));
PADDLE_ENFORCE_LT(axis,
x_dims.size(),
phi::errors::InvalidArgument(
"'axis'(%d) must be less than Rank(X)(%d) of Input(X).",
axis,
x_dims.size()));
PADDLE_ENFORCE_EQ(
(dtype < 0 || dtype == 2 || dtype == 3),
true,
......@@ -156,8 +140,45 @@ void ArgMinMaxInferMeta(const MetaTensor& x,
paddle::framework::DataTypeToString(
static_cast<paddle::framework::proto::VarType::Type>(dtype))));
if (!config.is_runtime && axis.FromTensor()) {
std::vector<int64_t> vec;
if (flatten) {
vec = {1};
} else {
if (keepdims) {
vec = std::vector<int64_t>(x.dims().size(), -1);
} else {
vec = std::vector<int64_t>(x.dims().size() - 1, -1);
}
}
out->set_dims(phi::make_ddim(vec));
if (dtype == 2) {
out->set_dtype(DataType::INT32);
} else if (dtype == 3) {
out->set_dtype(DataType::INT64);
}
return;
}
auto int_axis = axis.to<int64_t>();
const auto& x_dims = x.dims();
PADDLE_ENFORCE_GE(
int_axis,
-x_dims.size(),
phi::errors::InvalidArgument("'axis'(%d) must be greater than or equal to"
" -Rank(X)(%d).",
int_axis,
-x_dims.size()));
PADDLE_ENFORCE_LT(int_axis,
x_dims.size(),
phi::errors::InvalidArgument(
"'axis'(%d) must be less than Rank(X)(%d) of Input(X).",
int_axis,
x_dims.size()));
auto x_rank = x_dims.size();
if (axis < 0) axis += x_rank;
if (int_axis < 0) int_axis += x_rank;
if (config.is_runtime) {
if (dtype == paddle::framework::proto::VarType::INT32) {
int64_t all_element_num = 0;
......@@ -165,7 +186,7 @@ void ArgMinMaxInferMeta(const MetaTensor& x,
all_element_num = phi::product(x_dims);
} else {
all_element_num = x_dims[axis];
all_element_num = x_dims[int_axis];
}
PADDLE_ENFORCE_LE(
all_element_num,
......@@ -182,11 +203,11 @@ void ArgMinMaxInferMeta(const MetaTensor& x,
if (flatten) {
vec.emplace_back(static_cast<int64_t>(1));
} else {
for (int64_t i = 0; i < axis; i++) vec.emplace_back(x_dims[i]);
for (int64_t i = 0; i < int_axis; i++) vec.emplace_back(x_dims[i]);
if (keepdims) {
vec.emplace_back(static_cast<int64_t>(1));
}
for (int64_t i = axis + 1; i < x_rank; i++) vec.emplace_back(x_dims[i]);
for (int64_t i = int_axis + 1; i < x_rank; i++) vec.emplace_back(x_dims[i]);
}
out->set_dims(phi::make_ddim(vec));
if (dtype == 2) {
......
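
An aside on the ArgMinMaxInferMeta change above: when `axis` comes from a Tensor and shape inference runs at compile time (`!config.is_runtime && axis.FromTensor()`), the output dims are filled with -1 placeholders because the reduction axis is not yet known. A small sketch of the visible effect (illustrative names and shapes, not taken from the PR):

import paddle

paddle.enable_static()

x = paddle.static.data('x', [2, 3, 4], dtype='float32')
axis = paddle.assign(1)          # axis only known at run time
out = paddle.argmax(x, axis)     # keepdim=False drops one dim
# Compile-time shape is all -1: rank is 2 (3 - 1), but the sizes stay
# unknown until the axis Tensor is evaluated at run time.
print(out.shape)
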
......@@ -40,7 +40,7 @@ void AffineGridInferMeta(const MetaTensor& input,
MetaTensor* output);
void ArgMinMaxInferMeta(const MetaTensor& x,
int64_t axis,
const Scalar& axis,
bool keepdims,
bool flatten,
int dtype,
......
......@@ -14,6 +14,7 @@ limitations under the License. */
#pragma once
#include "paddle/phi/common/scalar.h"
#include "paddle/phi/core/dense_tensor.h"
namespace phi {
......@@ -21,7 +22,7 @@ namespace phi {
template <typename T, typename Context>
void ArgMinKernel(const Context& dev_ctx,
const DenseTensor& x,
int64_t axis,
const Scalar& axis,
bool keepdims,
bool flatten,
int dtype,
......@@ -30,7 +31,7 @@ void ArgMinKernel(const Context& dev_ctx,
template <typename T, typename Context>
void ArgMaxKernel(const Context& dev_ctx,
const DenseTensor& x,
int64_t axis,
const Scalar& axis,
bool keepdims,
bool flatten,
int dtype,
......
......@@ -135,7 +135,7 @@ struct VisitDataArgMinMaxFunctor {
template <typename Context, typename T, ArgMinMaxType EnumArgMinMaxValue>
void ArgMinMaxKernel(const Context& dev_ctx,
const DenseTensor& x,
int64_t axis,
const Scalar& axis,
bool keepdims,
bool flatten,
int dtype,
......@@ -145,19 +145,19 @@ void ArgMinMaxKernel(const Context& dev_ctx,
static_cast<paddle::framework::proto::VarType::Type>(
paddle::framework::proto::VarType::INT64),
VisitDataArgMinMaxFunctor<Context, T, EnumArgMinMaxValue>(
dev_ctx, x, axis, keepdims, flatten, out));
dev_ctx, x, axis.to<int64_t>(), keepdims, flatten, out));
return;
}
paddle::framework::VisitDataTypeTiny(
static_cast<paddle::framework::proto::VarType::Type>(dtype),
VisitDataArgMinMaxFunctor<Context, T, EnumArgMinMaxValue>(
dev_ctx, x, axis, keepdims, flatten, out));
dev_ctx, x, axis.to<int64_t>(), keepdims, flatten, out));
}
template <typename T, typename Context>
void ArgMinKernel(const Context& dev_ctx,
const DenseTensor& x,
int64_t axis,
const Scalar& axis,
bool keepdims,
bool flatten,
int dtype,
......@@ -169,7 +169,7 @@ void ArgMinKernel(const Context& dev_ctx,
template <typename T, typename Context>
void ArgMaxKernel(const Context& dev_ctx,
const DenseTensor& x,
int64_t axis,
const Scalar& axis,
bool keepdims,
bool flatten,
int dtype,
......
......@@ -203,7 +203,7 @@ struct VisitDataCudaArgMinMaxFunctor {
template <typename Context, typename T, class Reducer>
void ArgMinMaxOpCUDAKernel(const Context& dev_ctx,
const DenseTensor& x,
int64_t axis,
const Scalar& axis,
bool keepdims,
bool flatten,
int dtype,
......@@ -213,19 +213,19 @@ void ArgMinMaxOpCUDAKernel(const Context& dev_ctx,
static_cast<paddle::framework::proto::VarType::Type>(
paddle::framework::proto::VarType::INT64),
VisitDataCudaArgMinMaxFunctor<Context, T, Reducer>(
dev_ctx, x, axis, keepdims, flatten, out));
dev_ctx, x, axis.to<int64_t>(), keepdims, flatten, out));
return;
}
paddle::framework::VisitDataTypeTiny(
static_cast<paddle::framework::proto::VarType::Type>(dtype),
VisitDataCudaArgMinMaxFunctor<Context, T, Reducer>(
dev_ctx, x, axis, keepdims, flatten, out));
dev_ctx, x, axis.to<int64_t>(), keepdims, flatten, out));
}
template <typename T, typename Context>
void ArgMinKernel(const Context& dev_ctx,
const DenseTensor& x,
int64_t axis,
const Scalar& axis,
bool keepdims,
bool flatten,
int dtype,
......@@ -237,7 +237,7 @@ void ArgMinKernel(const Context& dev_ctx,
template <typename T, typename Context>
void ArgMaxKernel(const Context& dev_ctx,
const DenseTensor& x,
int64_t axis,
const Scalar& axis,
bool keepdims,
bool flatten,
int dtype,
......
......@@ -14,6 +14,7 @@
from __future__ import print_function
import os
import unittest
import numpy as np
from op_test import OpTest
......@@ -21,6 +22,7 @@ import paddle
import paddle.fluid as fluid
import paddle.fluid.core as core
from paddle.fluid import Program, program_guard
from test_attribute_var import UnittestBase
class BaseTestCase(OpTest):
......@@ -235,6 +237,92 @@ class BaseTestComplex2_2(OpTest):
}
class TestArgMaxTensorAxis(UnittestBase):
def init_info(self):
self.shapes = [[2, 3, 4]]
self.x = [np.random.randn(*shape) for shape in self.shapes]
self.save_path = os.path.join(self.temp_dir.name, self.path_prefix())
def test_static(self):
main_prog = Program()
starup_prog = Program()
with program_guard(main_prog, starup_prog):
fc = paddle.nn.Linear(4, 10)
x = paddle.randn([2, 3, 4])
x.stop_gradient = False
feat = fc(x)
out = self.call_func(feat)
sgd = paddle.optimizer.SGD()
sgd.minimize(paddle.mean(paddle.cast(out, 'float32')))
self.assertTrue(self.var_prefix() in str(main_prog))
exe = paddle.static.Executor()
exe.run(starup_prog)
res = exe.run(fetch_list=[feat, out])
paddle.static.save_inference_model(self.save_path, [x], [feat, out],
exe)
gt = np.argmax(res[0], 0)
np.testing.assert_allclose(res[1], gt)
# Test for Inference Predictor
infer_outs = self.infer_prog()
gt = np.argmax(infer_outs[0], 0)
np.testing.assert_allclose(infer_outs[1], gt)
def path_prefix(self):
return 'argmax_tensor_axis'
def var_prefix(self):
return "Var["
def call_func(self, x):
axis = paddle.assign(0)
out = paddle.argmax(x, axis)
return out
class TestArgMinTensorAxis(TestArgMaxTensorAxis):
def test_static(self):
main_prog = Program()
starup_prog = Program()
with program_guard(main_prog, starup_prog):
fc = paddle.nn.Linear(4, 10)
x = paddle.randn([2, 3, 4])
x.stop_gradient = False
feat = fc(x)
feat = paddle.cast(feat, 'int32')
out = self.call_func(feat)
sgd = paddle.optimizer.SGD()
sgd.minimize(paddle.mean(paddle.cast(out, 'float32')))
self.assertTrue(self.var_prefix() in str(main_prog))
exe = paddle.static.Executor()
exe.run(starup_prog)
res = exe.run(fetch_list=[feat, out])
paddle.static.save_inference_model(self.save_path, [x], [feat, out],
exe)
gt = np.argmin(res[0], 1)
np.testing.assert_allclose(np.squeeze(res[1]), gt)
# Test for Inference Predictor
infer_outs = self.infer_prog()
gt = np.argmin(infer_outs[0], 1)
np.testing.assert_allclose(np.squeeze(infer_outs[1]), gt)
def path_prefix(self):
return 'argmin_tensor_axis'
def call_func(self, x):
axis = paddle.assign(1)
out = paddle.argmin(x, axis, keepdim=True)
return out
if __name__ == '__main__':
paddle.enable_static()
unittest.main()
......@@ -162,9 +162,9 @@ def argmax(x, axis=None, keepdim=False, dtype="int64", name=None):
print(out4)
# [[2, 2, 0, 1]]
"""
if axis is not None and not isinstance(axis, int):
if axis is not None and not isinstance(axis, (int, Variable)):
raise TypeError(
"The type of 'axis' must be int or None in argmax, but received %s."
"The type of 'axis' must be int or Tensor or None in argmax, but received %s."
% (type(axis)))
if dtype is None:
......@@ -244,9 +244,9 @@ def argmin(x, axis=None, keepdim=False, dtype="int64", name=None):
print(out4)
# [[1, 1, 1, 2]]
"""
if axis is not None and not isinstance(axis, int):
if axis is not None and not isinstance(axis, (int, Variable)):
raise TypeError(
"The type of 'axis' must be int or None in argmin, but received %s."
"The type of 'axis' must be int or Tensor or None in argmin, but received %s."
% (type(axis)))
if dtype is None:
......