Unverified commit 6b28456e, authored by wawltor, committed by GitHub

add the argmax, argmin for the api2.0

* add the new api and op for the argmax, argmin

Parent: d26ae9ad
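A quick sketch of the Python API surface this commit adds, with signatures taken from the diff below (axis=None maps onto the new `flatten` attribute of the arg_max/arg_min operators):

# paddle 2.0-style signatures introduced by this change (see the Python diff at the end):
# paddle.argmax(x, axis=None, dtype=None, keepdim=False, name=None)
# paddle.argmin(x, axis=None, dtype=None, keepdim=False, name=None)
# axis=None  -> reduce over the flattened input and return a single index
# keepdim    -> keep the reduced axis with size 1 in the output
# dtype      -> 'int32' or 'int64' for the returned indices (default: int64)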
@@ -53,9 +53,9 @@ using Tensor = framework::Tensor;
  FIXED_BLOCK_DIM_CASE_BASE(3, ##__VA_ARGS__);

template <typename T, typename IndType, class Reducer, size_t BlockDim>
-__global__ void ArgCUDAKernel(const IndType height,     // n * h
-                              const IndType width,      // c
-                              const IndType post_size,  // h
+__global__ void ArgCUDAKernel(const int64_t height,     // n * h
+                              const int64_t width,      // c
+                              const int64_t post_size,  // h
                              const Reducer reducer, const T init, const T* in,
                              IndType* out) {
  typedef cub::BlockReduce<KeyValuePair<int, T>, BlockDim> BlockReduce;

@@ -79,10 +79,10 @@ __global__ void ArgCUDAKernel(const IndType height,  // n * h

template <typename T, typename IndType, class Reducer>
void ComputeFullArg(const platform::CUDADeviceContext& ctx, const Tensor& input,
-                    Tensor* indices, const IndType pre, const IndType post,
-                    const IndType n) {
+                    Tensor* indices, const int64_t pre, const int64_t post,
+                    const int64_t n) {
  auto cu_stream = ctx.stream();
-  auto ComputeBlockSize = [](IndType col) {
+  auto ComputeBlockSize = [](int64_t col) {
    if (col > 512)
      return 1024;
    else if (col > 256)

@@ -101,10 +101,10 @@ void ComputeFullArg(const platform::CUDADeviceContext& ctx, const Tensor& input,
      return 8;
  };

-  int max_grid_dimx = ctx.GetCUDAMaxGridDimSize().x;
-  int height = pre * post;
-  int width = n;
-  int grid_size = height < max_grid_dimx ? height : max_grid_dimx;
+  int64_t max_grid_dimx = ctx.GetCUDAMaxGridDimSize().x;
+  int64_t height = pre * post;
+  int64_t width = n;
+  int64_t grid_size = height < max_grid_dimx ? height : max_grid_dimx;

  const T* in_data = input.data<T>();
  IndType* out_data = indices->mutable_data<IndType>(ctx.GetPlace());

@@ -129,31 +129,60 @@ void ComputeFullArg(const platform::CUDADeviceContext& ctx, const Tensor& input,
}

template <typename T, class Reducer>
-class ArgMinMaxOpCUDAKernel : public framework::OpKernel<T> {
- public:
-  void Compute(const framework::ExecutionContext& ctx) const override {
+struct VisitDataCudaArgMinMaxFunctor {
+  const framework::ExecutionContext& ctx;
+
+  explicit VisitDataCudaArgMinMaxFunctor(const framework::ExecutionContext& ctx)
+      : ctx(ctx) {}
+
+  template <typename IndType>
+  void apply() const {
    auto* input = ctx.Input<Tensor>("X");
    auto* output = ctx.Output<Tensor>("Out");
    int axis = ctx.Attr<int64_t>("axis");
-    auto in_dims = input->dims();
-    axis = (axis < 0) ? (in_dims.size() + axis) : axis;
+    const bool& flatten = ctx.Attr<bool>("flatten");
+
+    framework::DDim input_dims;
+    if (flatten) {
+      input_dims = framework::make_ddim({input->numel()});
+      // when flatten is set, reduce over the flattened input, i.e. axis 0
+      axis = 0;
+    } else {
+      input_dims = input->dims();
+      if (axis < 0) axis += input->dims().size();
+    }

    int64_t numel = input->numel();
-    int64_t groups = numel / in_dims[axis];
+    int64_t groups = numel / input_dims[axis];
    int64_t pre = 1;
    int64_t post = 1;
-    int64_t n = in_dims[axis];
+    int64_t n = input_dims[axis];
    for (int i = 0; i < axis; i++) {
-      pre *= in_dims[i];
+      pre *= input_dims[i];
    }
-    for (int i = axis + 1; i < in_dims.size(); i++) {
-      post *= in_dims[i];
+    for (int i = axis + 1; i < input_dims.size(); i++) {
+      post *= input_dims[i];
    }
    const auto& dev_ctx = ctx.cuda_device_context();
-    ComputeFullArg<T, int64_t, Reducer>(dev_ctx, *input, output, pre, post, n);
+    ComputeFullArg<T, IndType, Reducer>(dev_ctx, *input, output, pre, post, n);
+  }
+};
+
+template <typename T, class Reducer>
+class ArgMinMaxOpCUDAKernel : public framework::OpKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext& ctx) const override {
+    auto& dtype = ctx.Attr<int>("dtype");
+    if (dtype < 0) {
+      framework::VisitDataType(static_cast<framework::proto::VarType::Type>(
+                                   framework::proto::VarType::INT64),
+                               VisitDataCudaArgMinMaxFunctor<T, Reducer>(ctx));
+      return;
+    }
+    framework::VisitDataType(
+        static_cast<framework::proto::VarType::Type>(dtype),
+        VisitDataCudaArgMinMaxFunctor<T, Reducer>(ctx));
  }
};
......
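Note: ComputeFullArg above views the input as a (pre, n, post) block, where n is the reduced axis, and runs one reduction per row of the resulting height = pre * post rows. A minimal NumPy sketch of the same decomposition (the helper name is illustrative, not part of this patch):

import numpy as np

def argmax_by_pre_n_post(x, axis):
    # View the tensor as (pre, n, post): pre is the product of the dims before
    # `axis`, n is the reduced axis, post is the product of the dims after it
    # (axis is assumed non-negative here, as in the kernel after normalization).
    dims = x.shape
    pre = int(np.prod(dims[:axis], dtype=np.int64))
    n = dims[axis]
    post = int(np.prod(dims[axis + 1:], dtype=np.int64))
    # height = pre * post independent rows, each holding n candidates.
    idx = x.reshape(pre, n, post).argmax(axis=1)
    return idx.reshape(dims[:axis] + dims[axis + 1:])

x = np.random.rand(2, 3, 4)
assert (argmax_by_pre_n_post(x, 1) == x.argmax(axis=1)).all()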
@@ -38,8 +38,9 @@ struct ArgMinMaxFunctor {};
  struct ArgMinMaxFunctor<DeviceContext, T, Tout, Rank,                       \
                          enum_argminmax_value> {                             \
    void operator()(const DeviceContext& ctx, const framework::LoDTensor& in, \
-                    framework::LoDTensor* out, int64_t axis, bool keepdims) { \
-      auto in_eigen = framework::EigenTensor<T, Rank>::From(in);              \
+                    framework::LoDTensor* out, framework::DDim x_dims,        \
+                    int64_t axis, bool keepdims) {                            \
+      auto in_eigen = framework::EigenTensor<T, Rank>::From(in, x_dims);      \
      if (keepdims) {                                                          \
        auto out_eigen = framework::EigenTensor<Tout, Rank>::From(*out);       \
        out_eigen.device(*(ctx.eigen_device())) =                              \

@@ -68,16 +69,26 @@ struct VisitDataArgMinMaxFunctor {
    out.template mutable_data<Tout>(ctx.GetPlace());
    auto axis = ctx.Attr<int64_t>("axis");
    auto keepdims = ctx.Attr<bool>("keepdims");
-    auto x_rank = x.dims().size();
-    if (axis < 0) axis += x_rank;
+    const bool& flatten = ctx.Attr<bool>("flatten");
+
+    // if flatten, construct new dims for the calculation
+    framework::DDim x_dims;
+    if (flatten) {
+      x_dims = framework::make_ddim({x.numel()});
+      // when flatten is set, reduce over the flattened input, i.e. axis 0
+      axis = 0;
+    } else {
+      x_dims = x.dims();
+      if (axis < 0) axis += x_dims.size();
+    }
    auto& dev_ctx = ctx.template device_context<DeviceContext>();

#define CALL_ARG_MINMAX_FUNCTOR(rank)                                \
  ArgMinMaxFunctor<DeviceContext, T, Tout, rank, EnumArgMinMaxValue> \
      functor##rank;                                                 \
-  functor##rank(dev_ctx, x, &out, axis, keepdims)
+  functor##rank(dev_ctx, x, &out, x_dims, axis, keepdims)

-    switch (x.dims().size()) {
+    switch (x_dims.size()) {
      case 1:
        CALL_ARG_MINMAX_FUNCTOR(1);
        break;

@@ -141,6 +152,7 @@ class ArgMinMaxOp : public framework::OperatorWithKernel {
    const auto& x_dims = ctx->GetInputDim("X");
    int64_t axis = ctx->Attrs().Get<int64_t>("axis");
    bool keepdims = ctx->Attrs().Get<bool>("keepdims");
+    const bool& flatten = ctx->Attrs().Get<bool>("flatten");

    PADDLE_ENFORCE_GE(axis, -x_dims.size(),
                      platform::errors::InvalidArgument(

@@ -152,14 +164,21 @@ class ArgMinMaxOp : public framework::OperatorWithKernel {
                      platform::errors::InvalidArgument(
                          "'axis'(%d) must be less than Rank(X)(%d).", axis, x_dims.size()));

-    auto x_rank = x_dims.size();
-    if (axis < 0) axis += x_rank;
    std::vector<int64_t> vec;
-    for (int64_t i = 0; i < axis; i++) vec.push_back(x_dims[i]);
-    if (keepdims) {
-      vec.push_back(static_cast<int64_t>(1));
+    if (flatten) {
+      // when flatten is set, the op returns a single index into the flattened input
+      if (keepdims) {
+        vec.emplace_back(static_cast<int64_t>(1));
+      }
+    } else {
+      auto x_rank = x_dims.size();
+      if (axis < 0) axis += x_rank;
+      for (int64_t i = 0; i < axis; i++) vec.emplace_back(x_dims[i]);
+      if (keepdims) {
+        vec.emplace_back(static_cast<int64_t>(1));
+      }
+      for (int64_t i = axis + 1; i < x_rank; i++) vec.emplace_back(x_dims[i]);
    }
-    for (int64_t i = axis + 1; i < x_rank; i++) vec.push_back(x_dims[i]);
    ctx->SetOutputDim("Out", framework::make_ddim(vec));
  }
};

@@ -176,6 +195,9 @@ class BaseArgMinMaxOpMaker : public framework::OpProtoAndCheckerMaker {
    AddAttr<int64_t>("axis", "The axis in which to compute the arg indics.");
    AddAttr<bool>("keepdims", "Keep the dim that to reduce.").SetDefault(false);
    AddAttr<int>("dtype", "Keep the dim that to reduce.").SetDefault(-1);
+    AddAttr<bool>("flatten",
+                  "Flatten the input value, and search the min or max indices")
+        .SetDefault(false);
    AddComment(string::Sprintf(R"DOC(
%s Operator.
......
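For reference, the output-shape rule that the InferShape change above implements, including the new flatten attribute, written out as a small Python sketch (the helper name is illustrative only):

def infer_arg_min_max_out_shape(x_dims, axis, keepdims=False, flatten=False):
    # flatten: reduce over the whole tensor, so the output is a single index,
    # kept as shape [1] when keepdims is set, otherwise a 0-D shape.
    if flatten:
        return [1] if keepdims else []
    if axis < 0:
        axis += len(x_dims)
    out = list(x_dims[:axis])
    if keepdims:
        out.append(1)  # keep the reduced axis with size 1
    out.extend(x_dims[axis + 1:])
    return out

print(infer_arg_min_max_out_shape([3, 4, 5], axis=1))                  # [3, 5]
print(infer_arg_min_max_out_shape([3, 4, 5], axis=-1, keepdims=True))  # [3, 4, 1]
print(infer_arg_min_max_out_shape([3, 4, 5], axis=0, flatten=True))    # []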
@@ -201,107 +201,5 @@ class BaseTestComplex2_2(OpTest):
        }


-class APT_ArgMaxTest(unittest.TestCase):
-    def test_output_result(self):
-        with fluid.program_guard(fluid.Program()):
-            data1 = fluid.data(name="X", shape=[3, 4], dtype="float32")
-            data2 = fluid.data(name="Y", shape=[3], dtype="int64")
-            out = paddle.argmax(input=data1, out=data2)
-            place = fluid.CPUPlace()
-            exe = fluid.Executor(place)
-            result = exe.run(
-                feed={"X": np.random.rand(3, 4).astype("float32")},
-                fetch_list=[data2, out])
-            self.assertEqual((result[0] == result[1]).all(), True)
-
-    def test_basic(self):
-        with fluid.program_guard(fluid.Program()):
-            data = fluid.data(name="X", shape=[3, 4], dtype="float32")
-            out = paddle.argmax(input=data)
-            place = fluid.CPUPlace()
-            exe = fluid.Executor(place)
-            np_input = np.random.rand(3, 4).astype("float32")
-            expected_result = np.argmax(np_input, axis=1)
-            result, = exe.run(feed={"X": np_input}, fetch_list=[out])
-            self.assertEqual((result == expected_result).all(), True)
-
-        with fluid.program_guard(fluid.Program()):
-            data = fluid.data(name="X", shape=[3, 4], dtype="float32")
-            out = paddle.argmax(input=data, axis=0)
-            place = fluid.CPUPlace()
-            exe = fluid.Executor(place)
-            np_input = np.random.rand(3, 4).astype("float32")
-            expected_result = np.argmax(np_input, axis=0)
-            result = exe.run(feed={"X": np_input}, fetch_list=[out])
-            self.assertEqual((result == expected_result).all(), True)
-
-        with fluid.program_guard(fluid.Program()):
-            data = fluid.data(name="X", shape=[3, 4], dtype="float32")
-            out = paddle.argmax(input=data, dtype="int32")
-            place = fluid.CPUPlace()
-            exe = fluid.Executor(place)
-            np_input = np.random.rand(3, 4).astype("float32")
-            expected_result = np.argmax(np_input, axis=1).astype(np.int32)
-            result = exe.run(feed={"X": np_input}, fetch_list=[out])
-            self.assertEqual((result == expected_result).all(), True)
-
-        with fluid.program_guard(fluid.Program()):
-            data1 = fluid.data(name="X", shape=[3, 4], dtype="float32")
-            data2 = fluid.data(name="Y", shape=[3], dtype="int64")
-            out = paddle.argmax(input=data, out=data2)
-            place = fluid.CPUPlace()
-            exe = fluid.Executor(place)
-            result = exe.run(
-                feed={"X": np.random.rand(3, 4).astype("float32")},
-                fetch_list=[data2, out])
-            self.assertEqual((result[0] == result[1]).all(), True)
-
-    def test_name(self):
-        with fluid.program_guard(fluid.Program()):
-            x = fluid.data(name="x", shape=[100], dtype="float32")
-            y_1 = paddle.argmax(x, name='arg_max_res')
-            self.assertEqual(('arg_max_res' in y_1.name), True)
-
-    def test_errors(self):
-        def test_dtype1():
-            with fluid.program_guard(fluid.Program(), fluid.Program()):
-                data = fluid.data(name="data", shape=[10], dtype="float32")
-                paddle.argmax(data, dtype="float32")
-
-        self.assertRaises(TypeError, test_dtype1)
-
-        def test_dtype2():
-            with fluid.program_guard(fluid.Program(), fluid.Program()):
-                data = fluid.data(name="data", shape=[10], dtype="float64")
-                paddle.argmax(data, dtype="float32")
-
-        self.assertRaises(TypeError, test_dtype2)
-
-
-class TestArgMinMaxOpError(unittest.TestCase):
-    def test_errors(self):
-        with program_guard(Program(), Program()):
-
-            def test_argmax_x_type():
-                x1 = [1, 2, 3]
-                output = fluid.layers.argmax(x=x1)
-
-            self.assertRaises(TypeError, test_argmax_x_type)
-
-            def test_argmin_x_type():
-                x2 = [1, 2, 3]
-                output = fluid.layers.argmin(x=x2)
-
-            self.assertRaises(TypeError, test_argmin_x_type)
-
if __name__ == '__main__':
    unittest.main()
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import unittest
import numpy as np
from op_test import OpTest
import paddle
import paddle.fluid as fluid
import paddle.fluid.core as core
from paddle.fluid import Program, program_guard
def create_kernel_case(op_type, numpy_op_type):
class ArgMinMaxKernelBaseCase(OpTest):
def initTestCase(self):
self.op_type = op_type
self.numpy_op_type = numpy_op_type
self.axis = 0
def setUp(self):
np.random.seed(123)
self.initTestCase()
self.dims = (4, 5, 6)
self.dtype = "float64"
self.x = (1000 * np.random.random(self.dims).astype(self.dtype))
self.inputs = {'X': self.x}
self.attrs = {"axis": self.axis}
self.numpy_op = eval("np.%s" % (numpy_op_type))
self.outputs = {'Out': self.numpy_op(self.x, axis=self.axis)}
def test_check_output(self):
paddle.enable_static()
self.check_output()
class ArgMinMaxKernelCase0(ArgMinMaxKernelBaseCase):
def initTestCase(self):
self.op_type = op_type
self.numpy_op_type = numpy_op_type
self.axis = 1
class ArgMinMaxKernelCase1(ArgMinMaxKernelBaseCase):
def initTestCase(self):
self.op_type = op_type
self.numpy_op_type = numpy_op_type
self.axis = 2
class ArgMinMaxKernelCase2(ArgMinMaxKernelBaseCase):
def initTestCase(self):
self.op_type = op_type
self.numpy_op_type = numpy_op_type
self.axis = -1
class ArgMinMaxKernelCase3(ArgMinMaxKernelBaseCase):
def initTestCase(self):
self.op_type = op_type
self.numpy_op_type = numpy_op_type
self.axis = -2
class ArgMinMaxKernelCase4(ArgMinMaxKernelBaseCase):
def setUp(self):
self.initTestCase()
self.dims = (4, 5, 6)
self.dtype = "float64"
self.x = (1000 * np.random.random(self.dims).astype(self.dtype))
self.inputs = {'X': self.x}
self.attrs = {"axis": self.axis, "keepdims": True}
self.numpy_op = eval("np.%s" % (numpy_op_type))
self.outputs = {
'Out': self.numpy_op(
self.x, axis=self.axis).reshape((1, 5, 6))
}
class ArgMinMaxKernelCase5(ArgMinMaxKernelBaseCase):
def setUp(self):
self.initTestCase()
self.dims = (4)
self.dtype = "float64"
self.x = (1000 * np.random.random(self.dims).astype(self.dtype))
self.inputs = {'X': self.x}
self.attrs = {"axis": self.axis, "flatten": True}
self.numpy_op = eval("np.%s" % (numpy_op_type))
self.outputs = {
'Out': self.numpy_op(
self.x.flatten(), axis=self.axis)
}
class ArgMinMaxKernelCase6(ArgMinMaxKernelBaseCase):
def setUp(self):
self.initTestCase()
self.dims = (4)
self.dtype = "float64"
self.x = (1000 * np.random.random(self.dims).astype(self.dtype))
self.inputs = {'X': self.x}
self.attrs = {"axis": self.axis, "flatten": True, "keepdims": True}
self.numpy_op = eval("np.%s" % (numpy_op_type))
self.outputs = {
'Out':
np.array(self.numpy_op(
self.x.flatten(), axis=self.axis))
}
cls_name = "ArgMinMaxKernelBaseCase_%s" % (op_type)
ArgMinMaxKernelBaseCase.__name__ = cls_name
globals()[cls_name] = ArgMinMaxKernelBaseCase
cls_name = "ArgMinMaxKernelCase0_%s" % (op_type)
ArgMinMaxKernelCase0.__name__ = cls_name
globals()[cls_name] = ArgMinMaxKernelCase0
cls_name = "ArgMinMaxKernelCase1_%s" % (op_type)
ArgMinMaxKernelCase1.__name__ = cls_name
globals()[cls_name] = ArgMinMaxKernelCase1
cls_name = "ArgMinMaxKernelCase2_%s" % (op_type)
ArgMinMaxKernelCase2.__name__ = cls_name
globals()[cls_name] = ArgMinMaxKernelCase2
cls_name = "ArgMinMaxKernelCase3_%s" % (op_type)
ArgMinMaxKernelCase3.__name__ = cls_name
globals()[cls_name] = ArgMinMaxKernelCase3
cls_name = "ArgMinMaxKernelCase4_%s" % (op_type)
ArgMinMaxKernelCase4.__name__ = cls_name
globals()[cls_name] = ArgMinMaxKernelCase4
cls_name = "ArgMinMaxKernelCase5_%s" % (op_type)
ArgMinMaxKernelCase5.__name__ = cls_name
globals()[cls_name] = ArgMinMaxKernelCase5
cls_name = "ArgMinMaxKernelCase6_%s" % (op_type)
ArgMinMaxKernelCase6.__name__ = cls_name
globals()[cls_name] = ArgMinMaxKernelCase6
for op_type, numpy_op_type in zip(['arg_max', 'arg_min'], ['argmax', 'argmin']):
create_kernel_case(op_type, numpy_op_type)
def create_test_case(op_type):
class ArgMaxMinTestCase(unittest.TestCase):
def setUp(self):
np.random.seed(123)
self.input_data = np.random.rand(10, 10).astype("float32")
self.places = []
self.places.append(fluid.CPUPlace())
if core.is_compiled_with_cuda():
self.places.append(paddle.CUDAPlace(0))
self.op = eval("paddle.%s" % (op_type))
self.numpy_op = eval("np.%s" % (op_type))
def run_static(self, place):
paddle.enable_static()
with paddle.static.program_guard(paddle.static.Program()):
data_var = paddle.static.data(
name="data", shape=[10, 10], dtype="float32")
op = eval("paddle.%s" % (op_type))
result = op(data_var)
exe = paddle.static.Executor(place)
result_data = exe.run(feed={"data": self.input_data},
fetch_list=[result])
expected_data = self.numpy_op(self.input_data)
self.assertTrue((result_data == np.array(expected_data)).all(),
True)
with paddle.static.program_guard(paddle.static.Program()):
data_var = paddle.static.data(
name="data", shape=[10, 10], dtype="float32")
op = eval("paddle.%s" % (op_type))
result = op(data_var, axis=1)
exe = paddle.static.Executor(place)
result_data = exe.run(feed={"data": self.input_data},
fetch_list=[result])
expected_data = self.numpy_op(self.input_data, axis=1)
self.assertTrue((result_data == expected_data).all(), True)
with paddle.static.program_guard(paddle.static.Program()):
data_var = paddle.static.data(
name="data", shape=[10, 10], dtype="float32")
op = eval("paddle.%s" % (op_type))
result = op(data_var, axis=-1)
exe = paddle.static.Executor(place)
result_data = exe.run(feed={"data": self.input_data},
fetch_list=[result])
expected_data = self.numpy_op(self.input_data, axis=-1)
self.assertTrue((result_data == expected_data).all(), True)
with paddle.static.program_guard(paddle.static.Program()):
data_var = paddle.static.data(
name="data", shape=[10, 10], dtype="float32")
op = eval("paddle.%s" % (op_type))
result = op(data_var, axis=-1, keepdim=True)
exe = paddle.static.Executor(place)
result_data = exe.run(feed={"data": self.input_data},
fetch_list=[result])
expected_data = self.numpy_op(
self.input_data, axis=-1).reshape((10, 1))
self.assertTrue((result_data == expected_data).all(), True)
with paddle.static.program_guard(paddle.static.Program()):
op = eval("paddle.%s" % (op_type))
data_var = paddle.static.data(
name="data", shape=[10, 10], dtype="float32")
result = op(data_var, axis=-1, name="test_arg_api")
self.assertTrue("test_arg_api" in result.name)
def run_dygraph(self, place):
paddle.disable_static()
op = eval("paddle.%s" % (op_type))
data_tensor = paddle.to_tensor(self.input_data)
#case 1
result_data = op(data_tensor)
excepted_data = self.numpy_op(self.input_data)
self.assertTrue((result_data.numpy() == excepted_data).all(), True)
#case 2
result_data = op(data_tensor, axis=1)
excepted_data = self.numpy_op(self.input_data, axis=1)
self.assertTrue((result_data.numpy() == excepted_data).all(), True)
#case 3
result_data = op(data_tensor, axis=-1)
excepted_data = self.numpy_op(self.input_data, axis=-1)
self.assertTrue((result_data.numpy() == excepted_data).all(), True)
#case 4
result_data = op(data_tensor, axis=-1, keepdim=True)
excepted_data = self.numpy_op(self.input_data, axis=-1)
excepted_data = excepted_data.reshape((10))
self.assertTrue((result_data.numpy() == excepted_data).all(), True)
#case 5
result_data = op(data_tensor, axis=-1, keepdim=True, dtype="int32")
self.assertTrue(result_data.numpy().dtype == np.int32)
# case for dim 4, 5, 6, for test case coverage
input_data = np.random.rand(5, 5, 5, 5)
excepted_data = self.numpy_op(input_data, axis=0)
result_data = op(paddle.to_tensor(input_data), axis=0)
self.assertTrue((result_data.numpy() == excepted_data).all(), True)
input_data = np.random.rand(4, 4, 4, 4, 4)
excepted_data = self.numpy_op(input_data, axis=0)
result_data = op(paddle.to_tensor(input_data), axis=0)
self.assertTrue((result_data.numpy() == excepted_data).all(), True)
input_data = np.random.rand(3, 3, 3, 3, 3, 3)
excepted_data = self.numpy_op(input_data, axis=0)
result_data = op(paddle.to_tensor(input_data), axis=0)
self.assertTrue((result_data.numpy() == excepted_data).all(), True)
def test_case(self):
for place in self.places:
self.run_static(place)
self.run_dygraph(place)
cls_name = "ArgMaxMinTestCase_{}".format(op_type)
ArgMaxMinTestCase.__name__ = cls_name
globals()[cls_name] = ArgMaxMinTestCase
for op_type in ['argmin', 'argmax']:
create_test_case(op_type)
class TestArgMinMaxOpError(unittest.TestCase):
def test_errors(self):
paddle.enable_static()
with program_guard(Program(), Program()):
def test_argmax_x_type():
x1 = [1, 2, 3]
output = paddle.argmax(x=x1)
self.assertRaises(TypeError, test_argmax_x_type)
def test_argmin_x_type():
x2 = [1, 2, 3]
output = paddle.argmin(x=x2)
self.assertRaises(TypeError, test_argmin_x_type)
def test_argmax_attr_type():
data = paddle.static.data(
name="test_argmax", shape=[10], dtype="float32")
output = paddle.argmax(x=data, dtype="float32")
self.assertRaises(ValueError, test_argmax_attr_type)
def test_argmin_attr_type():
data = paddle.static.data(
name="test_argmax", shape=[10], dtype="float32")
output = paddle.argmin(x=data, dtype="float32")
self.assertRaises(ValueError, test_argmin_attr_type)
if __name__ == '__main__':
unittest.main()
@@ -125,95 +125,168 @@ def argsort(x, axis=-1, descending=False, name=None):
    return ids


-def argmax(input, axis=None, dtype=None, out=None, keepdims=False, name=None):
+def argmax(x, axis=None, dtype=None, keepdim=False, name=None):
    """
-	:alias_main: paddle.argmax
-	:alias: paddle.argmax,paddle.tensor.argmax,paddle.tensor.search.argmax
-
    This OP computes the indices of the max elements of the input tensor's
    element along the provided axis.

    Args:
-        input(Variable): An input N-D Tensor with type float32, float64, int16,
+        x(Tensor): An input N-D Tensor with type float32, float64, int16,
            int32, int64, uint8.
        axis(int, optional): Axis to compute indices along. The effective range
-            is [-R, R), where R is Rank(input). when axis<0, it works the same way
-            as axis+R. Default is None, it will use the last dim to select indices of max value.
-        dtype(np.dtype|core.VarDesc.VarType|str): Data type of the output tensor which can
+            is [-R, R), where R is x.ndim. When axis < 0, it works the same way
+            as axis + R. Default is None; in that case `x` is flattened and the
+            index of the max value over the whole tensor is returned.
+        dtype(str, optional): Data type of the output tensor, which can
            be int32, int64. The default value is None, and it will
            return the int64 indices.
-        out(Variable, optional): Optional output which can be any created
-            Variable that meets the requirements to store the result of operation.
-            if out is None, a new Varibale will be create to store the result. Defalut is None.
-        keepdims(bool, optional): Keep the axis that do the select max.
+        keepdim(bool, optional): Whether to keep the reduced axis in the
+            output. The default value is False.
        name(str, optional): The default value is None. Normally there is no
            need for user to set this property. For more information, please
            refer to :ref:`api_guide_Name`.

    Returns:
-        Variable: A Tensor with data type int64.
+        Tensor: the index tensor, with dtype `int32` if :attr:`dtype` is set to
+        `int32`, otherwise `int64`.

    Examples:
        .. code-block:: python

-            import paddle
-            import paddle.fluid as fluid
            import numpy as np
+            import paddle

-            in1 = np.array([[[5,8,9,5],
-                             [0,0,1,7],
-                             [6,9,2,4]],
-                            [[5,2,4,2],
-                             [4,7,7,9],
-                             [1,7,0,6]]])
-            with fluid.dygraph.guard():
-                x = fluid.dygraph.to_variable(in1)
-                out1 = paddle.argmax(input=x, axis=-1)
-                out2 = paddle.argmax(input=x, axis=0)
-                out3 = paddle.argmax(input=x, axis=1)
-                out4 = paddle.argmax(input=x, axis=2)
-                out5 = paddle.argmax(input=x, axis=2, keepdims=True)
-                print(out1.numpy())
-                # [[2 3 1]
-                #  [0 3 1]]
-                print(out2.numpy())
-                # [[0 0 0 0]
-                #  [1 1 1 1]
-                #  [0 0 0 1]]
-                print(out3.numpy())
-                # [[2 2 0 1]
-                #  [0 1 1 1]]
-                print(out4.numpy())
-                # [[2 3 1]
-                #  [0 3 1]]
-                print(out5.numpy())
-                #array([[[2],
-                #        [3],
-                #        [1]],
-                #       [[0],
-                #        [3],
-                #        [1]]])
+            paddle.disable_static()
+            data = np.array([[5,8,9,5],
+                             [0,0,1,7],
+                             [6,9,2,4]])
+            x = paddle.to_variable(data)
+            out1 = paddle.argmax(x)
+            print(out1.numpy())  # 2
+            out2 = paddle.argmax(x, axis=1)
+            print(out2.numpy())
+            # [2 3 1]
+            out3 = paddle.argmax(x, axis=-1)
+            print(out3.numpy())
+            # [2 3 1]
    """
-    helper = LayerHelper("arg_max", **locals())
+    flatten = False
+    if axis is None:
+        flatten = True
+        axis = 0
+
+    if in_dygraph_mode():
+        if dtype != None:
+            var_dtype = convert_np_dtype_to_dtype_(dtype)
+            out = core.ops.arg_max(x, 'axis', axis, 'dtype', var_dtype,
+                                   'keepdim', keepdim, 'flatten', flatten)
+        else:
+            out = core.ops.arg_max(x, 'axis', axis, 'keepdim', keepdim,
+                                   'flatten', flatten)
+        return out
+
+    helper = LayerHelper("argmax", **locals())
+    check_variable_and_dtype(
+        x, 'x', ['float32', 'float64', 'int16', 'int32', 'int64', 'uint8'],
+        'paddle.argmax')
    var_dtype = None
    attrs = {}
    if dtype is not None:
-        check_dtype(dtype, 'create data type', ['int32', 'int64'], 'arg_max')
+        if dtype not in ['int32', 'int64']:
+            raise ValueError(
+                "The value of 'dtype' in argmax op must be int32, int64, but received of {}".
+                format(dtype))
        var_dtype = convert_np_dtype_to_dtype_(dtype)
        attrs["dtype"] = var_dtype
    else:
        var_dtype = VarDesc.VarType.INT64
-    if out is None:
-        out = helper.create_variable_for_type_inference(var_dtype)
+    out = helper.create_variable_for_type_inference(var_dtype)
+    attrs['keepdims'] = keepdim
+    attrs['axis'] = axis
+    attrs['flatten'] = flatten
+    helper.append_op(
+        type='arg_max', inputs={'X': x}, outputs={'Out': [out]}, attrs=attrs)
+    out.stop_gradient = True
+    return out
+
+
+def argmin(x, axis=None, dtype=None, keepdim=False, name=None):
+    """
+    This OP computes the indices of the min elements of the input tensor's
+    element along the provided axis.
+
+    Args:
+        x(Tensor): An input N-D Tensor with type float32, float64, int16,
+            int32, int64, uint8.
+        axis(int, optional): Axis to compute indices along. The effective range
+            is [-R, R), where R is x.ndim. When axis < 0, it works the same way
+            as axis + R. Default is None; in that case `x` is flattened and the
+            index of the min value over the whole tensor is returned.
+        dtype(str, optional): Data type of the output tensor, which can be
+            int32 or int64. The default value is None, and it will return the
+            int64 indices.
+        keepdim(bool, optional): Whether to keep the reduced axis in the
+            output. The default value is False.
+        name(str, optional): The default value is None. Normally there is no
+            need for user to set this property. For more information, please
+            refer to :ref:`api_guide_Name`.
+
+    Returns:
+        Tensor: the index tensor, with dtype `int32` if :attr:`dtype` is set to
+        `int32`, otherwise `int64`.
+
+    Examples:
+        .. code-block:: python
+
+            import numpy as np
+            import paddle
+
+            paddle.disable_static()
+            data = np.array([[5,8,9,5],
+                             [0,0,1,7],
+                             [6,9,2,4]])
+            x = paddle.to_variable(data)
+            out1 = paddle.argmin(x)
+            print(out1.numpy())  # 4
+            out2 = paddle.argmin(x, axis=1)
+            print(out2.numpy())
+            # [0 0 2]
+            out3 = paddle.argmin(x, axis=-1)
+            print(out3.numpy())
+            # [0 0 2]
+    """
+    flatten = False
    if axis is None:
-        axis = -1
-    attrs['keepdims'] = keepdims
+        flatten = True
+        axis = 0
+
+    if in_dygraph_mode():
+        if dtype != None:
+            var_dtype = convert_np_dtype_to_dtype_(dtype)
+            out = core.ops.arg_min(x, 'axis', axis, 'dtype', var_dtype,
+                                   'keepdim', keepdim, 'flatten', flatten)
+        else:
+            out = core.ops.arg_min(x, 'axis', axis, 'keepdim', keepdim,
+                                   'flatten', flatten)
+        return out
+
+    helper = LayerHelper("argmin", **locals())
+    check_variable_and_dtype(
+        x, 'x', ['float32', 'float64', 'int16', 'int32', 'int64', 'uint8'],
+        'paddle.argmin')
+    var_dtype = None
+    attrs = {}
+    if dtype is not None:
+        if dtype not in ['int32', 'int64']:
+            raise ValueError(
+                "The value of 'dtype' in argmin op must be int32, int64, but received of {}".
+                format(dtype))
+        var_dtype = convert_np_dtype_to_dtype_(dtype)
+        attrs["dtype"] = var_dtype
+    else:
+        var_dtype = VarDesc.VarType.INT64
+    out = helper.create_variable_for_type_inference(var_dtype)
+    attrs['keepdims'] = keepdim
    attrs['axis'] = axis
+    attrs['flatten'] = flatten
    helper.append_op(
-        type='arg_max',
-        inputs={'X': input},
-        outputs={'Out': [out]},
-        attrs=attrs)
+        type='arg_min', inputs={'X': x}, outputs={'Out': [out]}, attrs=attrs)
    out.stop_gradient = True
    return out
......
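The docstring examples above cover the default int64 indices; a short dygraph sketch of the keepdim and dtype options added in this patch (values depend on the random input, so only the shapes and dtypes are the point here):

import numpy as np
import paddle

paddle.disable_static()
x = paddle.to_variable(np.random.rand(3, 4).astype("float32"))

# keepdim=True keeps the reduced axis with size 1, so the result shape is [3, 1].
idx_keep = paddle.argmax(x, axis=1, keepdim=True)

# dtype="int32" stores the indices as int32 instead of the default int64.
idx_i32 = paddle.argmin(x, axis=1, dtype="int32")

# axis=None reduces over the flattened tensor and returns a single index.
idx_flat = paddle.argmax(x)

print(idx_keep.shape, idx_i32.numpy().dtype, idx_flat.numpy())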