Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
6b28456e
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
6b28456e
编写于
8月 22, 2020
作者:
W
wawltor
提交者:
GitHub
8月 22, 2020
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
add the argmax, argmin for the api2.0
* add the new api and op for the argmax, argmin
上级
d26ae9ad
变更
5
显示空白变更内容
内联
并排
Showing
5 changed file
with
528 addition
and
193 deletion
+528
-193
paddle/fluid/operators/arg_min_max_op_base.cu.h
paddle/fluid/operators/arg_min_max_op_base.cu.h
+50
-21
paddle/fluid/operators/arg_min_max_op_base.h
paddle/fluid/operators/arg_min_max_op_base.h
+34
-12
python/paddle/fluid/tests/unittests/test_arg_min_max_op.py
python/paddle/fluid/tests/unittests/test_arg_min_max_op.py
+0
-102
python/paddle/fluid/tests/unittests/test_arg_min_max_v2_op.py
...on/paddle/fluid/tests/unittests/test_arg_min_max_v2_op.py
+313
-0
python/paddle/tensor/search.py
python/paddle/tensor/search.py
+131
-58
未找到文件。
paddle/fluid/operators/arg_min_max_op_base.cu.h
浏览文件 @
6b28456e
...
@@ -53,9 +53,9 @@ using Tensor = framework::Tensor;
...
@@ -53,9 +53,9 @@ using Tensor = framework::Tensor;
FIXED_BLOCK_DIM_CASE_BASE
(
3
,
##
__VA_ARGS__
);
FIXED_BLOCK_DIM_CASE_BASE
(
3
,
##
__VA_ARGS__
);
template
<
typename
T
,
typename
IndType
,
class
Reducer
,
size_t
BlockDim
>
template
<
typename
T
,
typename
IndType
,
class
Reducer
,
size_t
BlockDim
>
__global__
void
ArgCUDAKernel
(
const
IndType
height
,
// n * h
__global__
void
ArgCUDAKernel
(
const
int64_t
height
,
// n * h
const
IndType
width
,
// c
const
int64_t
width
,
// c
const
IndType
post_size
,
// h
const
int64_t
post_size
,
// h
const
Reducer
reducer
,
const
T
init
,
const
T
*
in
,
const
Reducer
reducer
,
const
T
init
,
const
T
*
in
,
IndType
*
out
)
{
IndType
*
out
)
{
typedef
cub
::
BlockReduce
<
KeyValuePair
<
int
,
T
>
,
BlockDim
>
BlockReduce
;
typedef
cub
::
BlockReduce
<
KeyValuePair
<
int
,
T
>
,
BlockDim
>
BlockReduce
;
...
@@ -79,10 +79,10 @@ __global__ void ArgCUDAKernel(const IndType height, // n * h
...
@@ -79,10 +79,10 @@ __global__ void ArgCUDAKernel(const IndType height, // n * h
template
<
typename
T
,
typename
IndType
,
class
Reducer
>
template
<
typename
T
,
typename
IndType
,
class
Reducer
>
void
ComputeFullArg
(
const
platform
::
CUDADeviceContext
&
ctx
,
const
Tensor
&
input
,
void
ComputeFullArg
(
const
platform
::
CUDADeviceContext
&
ctx
,
const
Tensor
&
input
,
Tensor
*
indices
,
const
IndType
pre
,
const
IndType
post
,
Tensor
*
indices
,
const
int64_t
pre
,
const
int64_t
post
,
const
IndType
n
)
{
const
int64_t
n
)
{
auto
cu_stream
=
ctx
.
stream
();
auto
cu_stream
=
ctx
.
stream
();
auto
ComputeBlockSize
=
[](
IndType
col
)
{
auto
ComputeBlockSize
=
[](
int64_t
col
)
{
if
(
col
>
512
)
if
(
col
>
512
)
return
1024
;
return
1024
;
else
if
(
col
>
256
)
else
if
(
col
>
256
)
...
@@ -101,10 +101,10 @@ void ComputeFullArg(const platform::CUDADeviceContext& ctx, const Tensor& input,
...
@@ -101,10 +101,10 @@ void ComputeFullArg(const platform::CUDADeviceContext& ctx, const Tensor& input,
return
8
;
return
8
;
};
};
int
max_grid_dimx
=
ctx
.
GetCUDAMaxGridDimSize
().
x
;
int
64_t
max_grid_dimx
=
ctx
.
GetCUDAMaxGridDimSize
().
x
;
int
height
=
pre
*
post
;
int
64_t
height
=
pre
*
post
;
int
width
=
n
;
int
64_t
width
=
n
;
int
grid_size
=
height
<
max_grid_dimx
?
height
:
max_grid_dimx
;
int
64_t
grid_size
=
height
<
max_grid_dimx
?
height
:
max_grid_dimx
;
const
T
*
in_data
=
input
.
data
<
T
>
();
const
T
*
in_data
=
input
.
data
<
T
>
();
IndType
*
out_data
=
indices
->
mutable_data
<
IndType
>
(
ctx
.
GetPlace
());
IndType
*
out_data
=
indices
->
mutable_data
<
IndType
>
(
ctx
.
GetPlace
());
...
@@ -129,31 +129,60 @@ void ComputeFullArg(const platform::CUDADeviceContext& ctx, const Tensor& input,
...
@@ -129,31 +129,60 @@ void ComputeFullArg(const platform::CUDADeviceContext& ctx, const Tensor& input,
}
}
template
<
typename
T
,
class
Reducer
>
template
<
typename
T
,
class
Reducer
>
class
ArgMinMaxOpCUDAKernel
:
public
framework
::
OpKernel
<
T
>
{
struct
VisitDataCudaArgMinMaxFunctor
{
public:
const
framework
::
ExecutionContext
&
ctx
;
void
Compute
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
explicit
VisitDataCudaArgMinMaxFunctor
(
const
framework
::
ExecutionContext
&
ctx
)
:
ctx
(
ctx
)
{}
template
<
typename
IndType
>
void
apply
()
const
{
auto
*
input
=
ctx
.
Input
<
Tensor
>
(
"X"
);
auto
*
input
=
ctx
.
Input
<
Tensor
>
(
"X"
);
auto
*
output
=
ctx
.
Output
<
Tensor
>
(
"Out"
);
auto
*
output
=
ctx
.
Output
<
Tensor
>
(
"Out"
);
int
axis
=
ctx
.
Attr
<
int64_t
>
(
"axis"
);
int
axis
=
ctx
.
Attr
<
int64_t
>
(
"axis"
);
auto
in_dims
=
input
->
dims
();
const
bool
&
flatten
=
ctx
.
Attr
<
bool
>
(
"flatten"
);
axis
=
(
axis
<
0
)
?
(
in_dims
.
size
()
+
axis
)
:
axis
;
framework
::
DDim
input_dims
;
if
(
flatten
)
{
input_dims
=
framework
::
make_ddim
({
input
->
numel
()});
// if flatten, the axis just as 0
axis
=
0
;
}
else
{
input_dims
=
input
->
dims
();
if
(
axis
<
0
)
axis
+=
input
->
dims
().
size
();
}
int64_t
numel
=
input
->
numel
();
int64_t
numel
=
input
->
numel
();
int64_t
groups
=
numel
/
in_dims
[
axis
];
int64_t
groups
=
numel
/
in
put
_dims
[
axis
];
int64_t
pre
=
1
;
int64_t
pre
=
1
;
int64_t
post
=
1
;
int64_t
post
=
1
;
int64_t
n
=
in_dims
[
axis
];
int64_t
n
=
in
put
_dims
[
axis
];
for
(
int
i
=
0
;
i
<
axis
;
i
++
)
{
for
(
int
i
=
0
;
i
<
axis
;
i
++
)
{
pre
*=
in_dims
[
i
];
pre
*=
in
put
_dims
[
i
];
}
}
for
(
int
i
=
axis
+
1
;
i
<
in_dims
.
size
();
i
++
)
{
for
(
int
i
=
axis
+
1
;
i
<
in
put
_dims
.
size
();
i
++
)
{
post
*=
in_dims
[
i
];
post
*=
in
put
_dims
[
i
];
}
}
const
auto
&
dev_ctx
=
ctx
.
cuda_device_context
();
const
auto
&
dev_ctx
=
ctx
.
cuda_device_context
();
ComputeFullArg
<
T
,
int64_t
,
Reducer
>
(
dev_ctx
,
*
input
,
output
,
pre
,
post
,
n
);
ComputeFullArg
<
T
,
IndType
,
Reducer
>
(
dev_ctx
,
*
input
,
output
,
pre
,
post
,
n
);
}
};
template
<
typename
T
,
class
Reducer
>
class
ArgMinMaxOpCUDAKernel
:
public
framework
::
OpKernel
<
T
>
{
public:
void
Compute
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
auto
&
dtype
=
ctx
.
Attr
<
int
>
(
"dtype"
);
if
(
dtype
<
0
)
{
framework
::
VisitDataType
(
static_cast
<
framework
::
proto
::
VarType
::
Type
>
(
framework
::
proto
::
VarType
::
INT64
),
VisitDataCudaArgMinMaxFunctor
<
T
,
Reducer
>
(
ctx
));
return
;
}
framework
::
VisitDataType
(
static_cast
<
framework
::
proto
::
VarType
::
Type
>
(
dtype
),
VisitDataCudaArgMinMaxFunctor
<
T
,
Reducer
>
(
ctx
));
}
}
};
};
...
...
paddle/fluid/operators/arg_min_max_op_base.h
浏览文件 @
6b28456e
...
@@ -38,8 +38,9 @@ struct ArgMinMaxFunctor {};
...
@@ -38,8 +38,9 @@ struct ArgMinMaxFunctor {};
struct ArgMinMaxFunctor<DeviceContext, T, Tout, Rank, \
struct ArgMinMaxFunctor<DeviceContext, T, Tout, Rank, \
enum_argminmax_value> { \
enum_argminmax_value> { \
void operator()(const DeviceContext& ctx, const framework::LoDTensor& in, \
void operator()(const DeviceContext& ctx, const framework::LoDTensor& in, \
framework::LoDTensor* out, int64_t axis, bool keepdims) { \
framework::LoDTensor* out, framework::DDim x_dims, \
auto in_eigen = framework::EigenTensor<T, Rank>::From(in); \
int64_t axis, bool keepdims) { \
auto in_eigen = framework::EigenTensor<T, Rank>::From(in, x_dims); \
if (keepdims) { \
if (keepdims) { \
auto out_eigen = framework::EigenTensor<Tout, Rank>::From(*out); \
auto out_eigen = framework::EigenTensor<Tout, Rank>::From(*out); \
out_eigen.device(*(ctx.eigen_device())) = \
out_eigen.device(*(ctx.eigen_device())) = \
...
@@ -68,16 +69,26 @@ struct VisitDataArgMinMaxFunctor {
...
@@ -68,16 +69,26 @@ struct VisitDataArgMinMaxFunctor {
out
.
template
mutable_data
<
Tout
>(
ctx
.
GetPlace
());
out
.
template
mutable_data
<
Tout
>(
ctx
.
GetPlace
());
auto
axis
=
ctx
.
Attr
<
int64_t
>
(
"axis"
);
auto
axis
=
ctx
.
Attr
<
int64_t
>
(
"axis"
);
auto
keepdims
=
ctx
.
Attr
<
bool
>
(
"keepdims"
);
auto
keepdims
=
ctx
.
Attr
<
bool
>
(
"keepdims"
);
auto
x_rank
=
x
.
dims
().
size
();
const
bool
&
flatten
=
ctx
.
Attr
<
bool
>
(
"flatten"
);
if
(
axis
<
0
)
axis
+=
x_rank
;
// if flatten, will construct the new dims for the cacluate
framework
::
DDim
x_dims
;
if
(
flatten
)
{
x_dims
=
framework
::
make_ddim
({
x
.
numel
()});
// if flatten, the axis just as 0
axis
=
0
;
}
else
{
x_dims
=
x
.
dims
();
if
(
axis
<
0
)
axis
+=
x_dims
.
size
();
}
auto
&
dev_ctx
=
ctx
.
template
device_context
<
DeviceContext
>();
auto
&
dev_ctx
=
ctx
.
template
device_context
<
DeviceContext
>();
#define CALL_ARG_MINMAX_FUNCTOR(rank) \
#define CALL_ARG_MINMAX_FUNCTOR(rank) \
ArgMinMaxFunctor<DeviceContext, T, Tout, rank, EnumArgMinMaxValue> \
ArgMinMaxFunctor<DeviceContext, T, Tout, rank, EnumArgMinMaxValue> \
functor##rank; \
functor##rank; \
functor##rank(dev_ctx, x, &out, axis, keepdims)
functor##rank(dev_ctx, x, &out,
x_dims,
axis, keepdims)
switch
(
x
.
dims
()
.
size
())
{
switch
(
x
_dims
.
size
())
{
case
1
:
case
1
:
CALL_ARG_MINMAX_FUNCTOR
(
1
);
CALL_ARG_MINMAX_FUNCTOR
(
1
);
break
;
break
;
...
@@ -141,6 +152,7 @@ class ArgMinMaxOp : public framework::OperatorWithKernel {
...
@@ -141,6 +152,7 @@ class ArgMinMaxOp : public framework::OperatorWithKernel {
const
auto
&
x_dims
=
ctx
->
GetInputDim
(
"X"
);
const
auto
&
x_dims
=
ctx
->
GetInputDim
(
"X"
);
int64_t
axis
=
ctx
->
Attrs
().
Get
<
int64_t
>
(
"axis"
);
int64_t
axis
=
ctx
->
Attrs
().
Get
<
int64_t
>
(
"axis"
);
bool
keepdims
=
ctx
->
Attrs
().
Get
<
bool
>
(
"keepdims"
);
bool
keepdims
=
ctx
->
Attrs
().
Get
<
bool
>
(
"keepdims"
);
const
bool
&
flatten
=
ctx
->
Attrs
().
Get
<
bool
>
(
"flatten"
);
PADDLE_ENFORCE_GE
(
axis
,
-
x_dims
.
size
(),
PADDLE_ENFORCE_GE
(
axis
,
-
x_dims
.
size
(),
platform
::
errors
::
InvalidArgument
(
platform
::
errors
::
InvalidArgument
(
...
@@ -152,14 +164,21 @@ class ArgMinMaxOp : public framework::OperatorWithKernel {
...
@@ -152,14 +164,21 @@ class ArgMinMaxOp : public framework::OperatorWithKernel {
platform
::
errors
::
InvalidArgument
(
platform
::
errors
::
InvalidArgument
(
"'axis'(%d) must be less than Rank(X)(%d)."
,
axis
,
x_dims
.
size
()));
"'axis'(%d) must be less than Rank(X)(%d)."
,
axis
,
x_dims
.
size
()));
std
::
vector
<
int64_t
>
vec
;
if
(
flatten
)
{
// if is flatten, will return the only on element
if
(
keepdims
)
{
vec
.
emplace_back
(
static_cast
<
int64_t
>
(
1
));
}
}
else
{
auto
x_rank
=
x_dims
.
size
();
auto
x_rank
=
x_dims
.
size
();
if
(
axis
<
0
)
axis
+=
x_rank
;
if
(
axis
<
0
)
axis
+=
x_rank
;
std
::
vector
<
int64_t
>
vec
;
for
(
int64_t
i
=
0
;
i
<
axis
;
i
++
)
vec
.
emplace_back
(
x_dims
[
i
]);
for
(
int64_t
i
=
0
;
i
<
axis
;
i
++
)
vec
.
push_back
(
x_dims
[
i
]);
if
(
keepdims
)
{
if
(
keepdims
)
{
vec
.
push_back
(
static_cast
<
int64_t
>
(
1
));
vec
.
emplace_back
(
static_cast
<
int64_t
>
(
1
));
}
for
(
int64_t
i
=
axis
+
1
;
i
<
x_rank
;
i
++
)
vec
.
emplace_back
(
x_dims
[
i
]);
}
}
for
(
int64_t
i
=
axis
+
1
;
i
<
x_rank
;
i
++
)
vec
.
push_back
(
x_dims
[
i
]);
ctx
->
SetOutputDim
(
"Out"
,
framework
::
make_ddim
(
vec
));
ctx
->
SetOutputDim
(
"Out"
,
framework
::
make_ddim
(
vec
));
}
}
};
};
...
@@ -176,6 +195,9 @@ class BaseArgMinMaxOpMaker : public framework::OpProtoAndCheckerMaker {
...
@@ -176,6 +195,9 @@ class BaseArgMinMaxOpMaker : public framework::OpProtoAndCheckerMaker {
AddAttr
<
int64_t
>
(
"axis"
,
"The axis in which to compute the arg indics."
);
AddAttr
<
int64_t
>
(
"axis"
,
"The axis in which to compute the arg indics."
);
AddAttr
<
bool
>
(
"keepdims"
,
"Keep the dim that to reduce."
).
SetDefault
(
false
);
AddAttr
<
bool
>
(
"keepdims"
,
"Keep the dim that to reduce."
).
SetDefault
(
false
);
AddAttr
<
int
>
(
"dtype"
,
"Keep the dim that to reduce."
).
SetDefault
(
-
1
);
AddAttr
<
int
>
(
"dtype"
,
"Keep the dim that to reduce."
).
SetDefault
(
-
1
);
AddAttr
<
bool
>
(
"flatten"
,
"Flatten the input value, and search the min or max indices"
)
.
SetDefault
(
false
);
AddComment
(
string
::
Sprintf
(
R"DOC(
AddComment
(
string
::
Sprintf
(
R"DOC(
%s Operator.
%s Operator.
...
...
python/paddle/fluid/tests/unittests/test_arg_min_max_op.py
浏览文件 @
6b28456e
...
@@ -201,107 +201,5 @@ class BaseTestComplex2_2(OpTest):
...
@@ -201,107 +201,5 @@ class BaseTestComplex2_2(OpTest):
}
}
class
APT_ArgMaxTest
(
unittest
.
TestCase
):
def
test_output_result
(
self
):
with
fluid
.
program_guard
(
fluid
.
Program
()):
data1
=
fluid
.
data
(
name
=
"X"
,
shape
=
[
3
,
4
],
dtype
=
"float32"
)
data2
=
fluid
.
data
(
name
=
"Y"
,
shape
=
[
3
],
dtype
=
"int64"
)
out
=
paddle
.
argmax
(
input
=
data1
,
out
=
data2
)
place
=
fluid
.
CPUPlace
()
exe
=
fluid
.
Executor
(
place
)
result
=
exe
.
run
(
feed
=
{
"X"
:
np
.
random
.
rand
(
3
,
4
).
astype
(
"float32"
)},
fetch_list
=
[
data2
,
out
])
self
.
assertEqual
((
result
[
0
]
==
result
[
1
]).
all
(),
True
)
def
test_basic
(
self
):
with
fluid
.
program_guard
(
fluid
.
Program
()):
data
=
fluid
.
data
(
name
=
"X"
,
shape
=
[
3
,
4
],
dtype
=
"float32"
)
out
=
paddle
.
argmax
(
input
=
data
)
place
=
fluid
.
CPUPlace
()
exe
=
fluid
.
Executor
(
place
)
np_input
=
np
.
random
.
rand
(
3
,
4
).
astype
(
"float32"
)
expected_result
=
np
.
argmax
(
np_input
,
axis
=
1
)
result
,
=
exe
.
run
(
feed
=
{
"X"
:
np_input
},
fetch_list
=
[
out
])
self
.
assertEqual
((
result
==
expected_result
).
all
(),
True
)
with
fluid
.
program_guard
(
fluid
.
Program
()):
data
=
fluid
.
data
(
name
=
"X"
,
shape
=
[
3
,
4
],
dtype
=
"float32"
)
out
=
paddle
.
argmax
(
input
=
data
,
axis
=
0
)
place
=
fluid
.
CPUPlace
()
exe
=
fluid
.
Executor
(
place
)
np_input
=
np
.
random
.
rand
(
3
,
4
).
astype
(
"float32"
)
expected_result
=
np
.
argmax
(
np_input
,
axis
=
0
)
result
=
exe
.
run
(
feed
=
{
"X"
:
np_input
},
fetch_list
=
[
out
])
self
.
assertEqual
((
result
==
expected_result
).
all
(),
True
)
with
fluid
.
program_guard
(
fluid
.
Program
()):
data
=
fluid
.
data
(
name
=
"X"
,
shape
=
[
3
,
4
],
dtype
=
"float32"
)
out
=
paddle
.
argmax
(
input
=
data
,
dtype
=
"int32"
)
place
=
fluid
.
CPUPlace
()
exe
=
fluid
.
Executor
(
place
)
np_input
=
np
.
random
.
rand
(
3
,
4
).
astype
(
"float32"
)
expected_result
=
np
.
argmax
(
np_input
,
axis
=
1
).
astype
(
np
.
int32
)
result
=
exe
.
run
(
feed
=
{
"X"
:
np_input
},
fetch_list
=
[
out
])
self
.
assertEqual
((
result
==
expected_result
).
all
(),
True
)
with
fluid
.
program_guard
(
fluid
.
Program
()):
data1
=
fluid
.
data
(
name
=
"X"
,
shape
=
[
3
,
4
],
dtype
=
"float32"
)
data2
=
fluid
.
data
(
name
=
"Y"
,
shape
=
[
3
],
dtype
=
"int64"
)
out
=
paddle
.
argmax
(
input
=
data
,
out
=
data2
)
place
=
fluid
.
CPUPlace
()
exe
=
fluid
.
Executor
(
place
)
result
=
exe
.
run
(
feed
=
{
"X"
:
np
.
random
.
rand
(
3
,
4
).
astype
(
"float32"
)},
fetch_list
=
[
data2
,
out
])
self
.
assertEqual
((
result
[
0
]
==
result
[
1
]).
all
(),
True
)
def
test_name
(
self
):
with
fluid
.
program_guard
(
fluid
.
Program
()):
x
=
fluid
.
data
(
name
=
"x"
,
shape
=
[
100
],
dtype
=
"float32"
)
y_1
=
paddle
.
argmax
(
x
,
name
=
'arg_max_res'
)
self
.
assertEqual
((
'arg_max_res'
in
y_1
.
name
),
True
)
def
test_errors
(
self
):
def
test_dtype1
():
with
fluid
.
program_guard
(
fluid
.
Program
(),
fluid
.
Program
()):
data
=
fluid
.
data
(
name
=
"data"
,
shape
=
[
10
],
dtype
=
"float32"
)
paddle
.
argmax
(
data
,
dtype
=
"float32"
)
self
.
assertRaises
(
TypeError
,
test_dtype1
)
def
test_dtype2
():
with
fluid
.
program_guard
(
fluid
.
Program
(),
fluid
.
Program
()):
data
=
fluid
.
data
(
name
=
"data"
,
shape
=
[
10
],
dtype
=
"float64"
)
paddle
.
argmax
(
data
,
dtype
=
"float32"
)
self
.
assertRaises
(
TypeError
,
test_dtype2
)
class
TestArgMinMaxOpError
(
unittest
.
TestCase
):
def
test_errors
(
self
):
with
program_guard
(
Program
(),
Program
()):
def
test_argmax_x_type
():
x1
=
[
1
,
2
,
3
]
output
=
fluid
.
layers
.
argmax
(
x
=
x1
)
self
.
assertRaises
(
TypeError
,
test_argmax_x_type
)
def
test_argmin_x_type
():
x2
=
[
1
,
2
,
3
]
output
=
fluid
.
layers
.
argmin
(
x
=
x2
)
self
.
assertRaises
(
TypeError
,
test_argmin_x_type
)
if
__name__
==
'__main__'
:
if
__name__
==
'__main__'
:
unittest
.
main
()
unittest
.
main
()
python/paddle/fluid/tests/unittests/test_arg_min_max_v2_op.py
0 → 100644
浏览文件 @
6b28456e
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
__future__
import
print_function
import
unittest
import
numpy
as
np
from
op_test
import
OpTest
import
paddle
import
paddle.fluid
as
fluid
import
paddle.fluid.core
as
core
from
paddle.fluid
import
Program
,
program_guard
def
create_kernel_case
(
op_type
,
numpy_op_type
):
class
ArgMinMaxKernelBaseCase
(
OpTest
):
def
initTestCase
(
self
):
self
.
op_type
=
op_type
self
.
numpy_op_type
=
numpy_op_type
self
.
axis
=
0
def
setUp
(
self
):
np
.
random
.
seed
(
123
)
self
.
initTestCase
()
self
.
dims
=
(
4
,
5
,
6
)
self
.
dtype
=
"float64"
self
.
x
=
(
1000
*
np
.
random
.
random
(
self
.
dims
).
astype
(
self
.
dtype
))
self
.
inputs
=
{
'X'
:
self
.
x
}
self
.
attrs
=
{
"axis"
:
self
.
axis
}
self
.
numpy_op
=
eval
(
"np.%s"
%
(
numpy_op_type
))
self
.
outputs
=
{
'Out'
:
self
.
numpy_op
(
self
.
x
,
axis
=
self
.
axis
)}
def
test_check_output
(
self
):
paddle
.
enable_static
()
self
.
check_output
()
class
ArgMinMaxKernelCase0
(
ArgMinMaxKernelBaseCase
):
def
initTestCase
(
self
):
self
.
op_type
=
op_type
self
.
numpy_op_type
=
numpy_op_type
self
.
axis
=
1
class
ArgMinMaxKernelCase1
(
ArgMinMaxKernelBaseCase
):
def
initTestCase
(
self
):
self
.
op_type
=
op_type
self
.
numpy_op_type
=
numpy_op_type
self
.
axis
=
2
class
ArgMinMaxKernelCase2
(
ArgMinMaxKernelBaseCase
):
def
initTestCase
(
self
):
self
.
op_type
=
op_type
self
.
numpy_op_type
=
numpy_op_type
self
.
axis
=
-
1
class
ArgMinMaxKernelCase3
(
ArgMinMaxKernelBaseCase
):
def
initTestCase
(
self
):
self
.
op_type
=
op_type
self
.
numpy_op_type
=
numpy_op_type
self
.
axis
=
-
2
class
ArgMinMaxKernelCase4
(
ArgMinMaxKernelBaseCase
):
def
setUp
(
self
):
self
.
initTestCase
()
self
.
dims
=
(
4
,
5
,
6
)
self
.
dtype
=
"float64"
self
.
x
=
(
1000
*
np
.
random
.
random
(
self
.
dims
).
astype
(
self
.
dtype
))
self
.
inputs
=
{
'X'
:
self
.
x
}
self
.
attrs
=
{
"axis"
:
self
.
axis
,
"keepdims"
:
True
}
self
.
numpy_op
=
eval
(
"np.%s"
%
(
numpy_op_type
))
self
.
outputs
=
{
'Out'
:
self
.
numpy_op
(
self
.
x
,
axis
=
self
.
axis
).
reshape
((
1
,
5
,
6
))
}
class
ArgMinMaxKernelCase5
(
ArgMinMaxKernelBaseCase
):
def
setUp
(
self
):
self
.
initTestCase
()
self
.
dims
=
(
4
)
self
.
dtype
=
"float64"
self
.
x
=
(
1000
*
np
.
random
.
random
(
self
.
dims
).
astype
(
self
.
dtype
))
self
.
inputs
=
{
'X'
:
self
.
x
}
self
.
attrs
=
{
"axis"
:
self
.
axis
,
"flatten"
:
True
}
self
.
numpy_op
=
eval
(
"np.%s"
%
(
numpy_op_type
))
self
.
outputs
=
{
'Out'
:
self
.
numpy_op
(
self
.
x
.
flatten
(),
axis
=
self
.
axis
)
}
class
ArgMinMaxKernelCase6
(
ArgMinMaxKernelBaseCase
):
def
setUp
(
self
):
self
.
initTestCase
()
self
.
dims
=
(
4
)
self
.
dtype
=
"float64"
self
.
x
=
(
1000
*
np
.
random
.
random
(
self
.
dims
).
astype
(
self
.
dtype
))
self
.
inputs
=
{
'X'
:
self
.
x
}
self
.
attrs
=
{
"axis"
:
self
.
axis
,
"flatten"
:
True
,
"keepdims"
:
True
}
self
.
numpy_op
=
eval
(
"np.%s"
%
(
numpy_op_type
))
self
.
outputs
=
{
'Out'
:
np
.
array
(
self
.
numpy_op
(
self
.
x
.
flatten
(),
axis
=
self
.
axis
))
}
cls_name
=
"ArgMinMaxKernelBaseCase_%s"
%
(
op_type
)
ArgMinMaxKernelBaseCase
.
__name__
=
cls_name
globals
()[
cls_name
]
=
ArgMinMaxKernelBaseCase
cls_name
=
"ArgMinMaxKernelCase0_%s"
%
(
op_type
)
ArgMinMaxKernelCase0
.
__name__
=
cls_name
globals
()[
cls_name
]
=
ArgMinMaxKernelCase0
cls_name
=
"ArgMinMaxKernelCase1_%s"
%
(
op_type
)
ArgMinMaxKernelCase1
.
__name__
=
cls_name
globals
()[
cls_name
]
=
ArgMinMaxKernelCase1
cls_name
=
"ArgMinMaxKernelCase2_%s"
%
(
op_type
)
ArgMinMaxKernelCase2
.
__name__
=
cls_name
globals
()[
cls_name
]
=
ArgMinMaxKernelCase2
cls_name
=
"ArgMinMaxKernelCase3_%s"
%
(
op_type
)
ArgMinMaxKernelCase3
.
__name__
=
cls_name
globals
()[
cls_name
]
=
ArgMinMaxKernelCase3
cls_name
=
"ArgMinMaxKernelCase4_%s"
%
(
op_type
)
ArgMinMaxKernelCase4
.
__name__
=
cls_name
globals
()[
cls_name
]
=
ArgMinMaxKernelCase4
cls_name
=
"ArgMinMaxKernelCase5_%s"
%
(
op_type
)
ArgMinMaxKernelCase5
.
__name__
=
cls_name
globals
()[
cls_name
]
=
ArgMinMaxKernelCase5
cls_name
=
"ArgMinMaxKernelCase6_%s"
%
(
op_type
)
ArgMinMaxKernelCase6
.
__name__
=
cls_name
globals
()[
cls_name
]
=
ArgMinMaxKernelCase6
for
op_type
,
numpy_op_type
in
zip
([
'arg_max'
,
'arg_min'
],
[
'argmax'
,
'argmin'
]):
create_kernel_case
(
op_type
,
numpy_op_type
)
def
create_test_case
(
op_type
):
class
ArgMaxMinTestCase
(
unittest
.
TestCase
):
def
setUp
(
self
):
np
.
random
.
seed
(
123
)
self
.
input_data
=
np
.
random
.
rand
(
10
,
10
).
astype
(
"float32"
)
self
.
places
=
[]
self
.
places
.
append
(
fluid
.
CPUPlace
())
if
core
.
is_compiled_with_cuda
():
self
.
places
.
append
(
paddle
.
CUDAPlace
(
0
))
self
.
op
=
eval
(
"paddle.%s"
%
(
op_type
))
self
.
numpy_op
=
eval
(
"np.%s"
%
(
op_type
))
def
run_static
(
self
,
place
):
paddle
.
enable_static
()
with
paddle
.
static
.
program_guard
(
paddle
.
static
.
Program
()):
data_var
=
paddle
.
static
.
data
(
name
=
"data"
,
shape
=
[
10
,
10
],
dtype
=
"float32"
)
op
=
eval
(
"paddle.%s"
%
(
op_type
))
result
=
op
(
data_var
)
exe
=
paddle
.
static
.
Executor
(
place
)
result_data
=
exe
.
run
(
feed
=
{
"data"
:
self
.
input_data
},
fetch_list
=
[
result
])
expected_data
=
self
.
numpy_op
(
self
.
input_data
)
self
.
assertTrue
((
result_data
==
np
.
array
(
expected_data
)).
all
(),
True
)
with
paddle
.
static
.
program_guard
(
paddle
.
static
.
Program
()):
data_var
=
paddle
.
static
.
data
(
name
=
"data"
,
shape
=
[
10
,
10
],
dtype
=
"float32"
)
op
=
eval
(
"paddle.%s"
%
(
op_type
))
result
=
op
(
data_var
,
axis
=
1
)
exe
=
paddle
.
static
.
Executor
(
place
)
result_data
=
exe
.
run
(
feed
=
{
"data"
:
self
.
input_data
},
fetch_list
=
[
result
])
expected_data
=
self
.
numpy_op
(
self
.
input_data
,
axis
=
1
)
self
.
assertTrue
((
result_data
==
expected_data
).
all
(),
True
)
with
paddle
.
static
.
program_guard
(
paddle
.
static
.
Program
()):
data_var
=
paddle
.
static
.
data
(
name
=
"data"
,
shape
=
[
10
,
10
],
dtype
=
"float32"
)
op
=
eval
(
"paddle.%s"
%
(
op_type
))
result
=
op
(
data_var
,
axis
=-
1
)
exe
=
paddle
.
static
.
Executor
(
place
)
result_data
=
exe
.
run
(
feed
=
{
"data"
:
self
.
input_data
},
fetch_list
=
[
result
])
expected_data
=
self
.
numpy_op
(
self
.
input_data
,
axis
=-
1
)
self
.
assertTrue
((
result_data
==
expected_data
).
all
(),
True
)
with
paddle
.
static
.
program_guard
(
paddle
.
static
.
Program
()):
data_var
=
paddle
.
static
.
data
(
name
=
"data"
,
shape
=
[
10
,
10
],
dtype
=
"float32"
)
op
=
eval
(
"paddle.%s"
%
(
op_type
))
result
=
op
(
data_var
,
axis
=-
1
,
keepdim
=
True
)
exe
=
paddle
.
static
.
Executor
(
place
)
result_data
=
exe
.
run
(
feed
=
{
"data"
:
self
.
input_data
},
fetch_list
=
[
result
])
expected_data
=
self
.
numpy_op
(
self
.
input_data
,
axis
=-
1
).
reshape
((
10
,
1
))
self
.
assertTrue
((
result_data
==
expected_data
).
all
(),
True
)
with
paddle
.
static
.
program_guard
(
paddle
.
static
.
Program
()):
op
=
eval
(
"paddle.%s"
%
(
op_type
))
data_var
=
paddle
.
static
.
data
(
name
=
"data"
,
shape
=
[
10
,
10
],
dtype
=
"float32"
)
result
=
op
(
data_var
,
axis
=-
1
,
name
=
"test_arg_api"
)
self
.
assertTrue
(
"test_arg_api"
in
result
.
name
)
def
run_dygraph
(
self
,
place
):
paddle
.
disable_static
()
op
=
eval
(
"paddle.%s"
%
(
op_type
))
data_tensor
=
paddle
.
to_tensor
(
self
.
input_data
)
#case 1
result_data
=
op
(
data_tensor
)
excepted_data
=
self
.
numpy_op
(
self
.
input_data
)
self
.
assertTrue
((
result_data
.
numpy
()
==
excepted_data
).
all
(),
True
)
#case 2
result_data
=
op
(
data_tensor
,
axis
=
1
)
excepted_data
=
self
.
numpy_op
(
self
.
input_data
,
axis
=
1
)
self
.
assertTrue
((
result_data
.
numpy
()
==
excepted_data
).
all
(),
True
)
#case 3
result_data
=
op
(
data_tensor
,
axis
=-
1
)
excepted_data
=
self
.
numpy_op
(
self
.
input_data
,
axis
=-
1
)
self
.
assertTrue
((
result_data
.
numpy
()
==
excepted_data
).
all
(),
True
)
#case 4
result_data
=
op
(
data_tensor
,
axis
=-
1
,
keepdim
=
True
)
excepted_data
=
self
.
numpy_op
(
self
.
input_data
,
axis
=-
1
)
excepted_data
=
excepted_data
.
reshape
((
10
))
self
.
assertTrue
((
result_data
.
numpy
()
==
excepted_data
).
all
(),
True
)
#case 5
result_data
=
op
(
data_tensor
,
axis
=-
1
,
keepdim
=
True
,
dtype
=
"int32"
)
self
.
assertTrue
(
result_data
.
numpy
().
dtype
==
np
.
int32
)
# case for dim 4, 5, 6, for test case coverage
input_data
=
np
.
random
.
rand
(
5
,
5
,
5
,
5
)
excepted_data
=
self
.
numpy_op
(
input_data
,
axis
=
0
)
result_data
=
op
(
paddle
.
to_tensor
(
input_data
),
axis
=
0
)
self
.
assertTrue
((
result_data
.
numpy
()
==
excepted_data
).
all
(),
True
)
input_data
=
np
.
random
.
rand
(
4
,
4
,
4
,
4
,
4
)
excepted_data
=
self
.
numpy_op
(
input_data
,
axis
=
0
)
result_data
=
op
(
paddle
.
to_tensor
(
input_data
),
axis
=
0
)
self
.
assertTrue
((
result_data
.
numpy
()
==
excepted_data
).
all
(),
True
)
input_data
=
np
.
random
.
rand
(
3
,
3
,
3
,
3
,
3
,
3
)
excepted_data
=
self
.
numpy_op
(
input_data
,
axis
=
0
)
result_data
=
op
(
paddle
.
to_tensor
(
input_data
),
axis
=
0
)
self
.
assertTrue
((
result_data
.
numpy
()
==
excepted_data
).
all
(),
True
)
def
test_case
(
self
):
for
place
in
self
.
places
:
self
.
run_static
(
place
)
self
.
run_dygraph
(
place
)
cls_name
=
"ArgMaxMinTestCase_{}"
.
format
(
op_type
)
ArgMaxMinTestCase
.
__name__
=
cls_name
globals
()[
cls_name
]
=
ArgMaxMinTestCase
for
op_type
in
[
'argmin'
,
'argmax'
]:
create_test_case
(
op_type
)
class
TestArgMinMaxOpError
(
unittest
.
TestCase
):
def
test_errors
(
self
):
paddle
.
enable_static
()
with
program_guard
(
Program
(),
Program
()):
def
test_argmax_x_type
():
x1
=
[
1
,
2
,
3
]
output
=
paddle
.
argmax
(
x
=
x1
)
self
.
assertRaises
(
TypeError
,
test_argmax_x_type
)
def
test_argmin_x_type
():
x2
=
[
1
,
2
,
3
]
output
=
paddle
.
argmin
(
x
=
x2
)
self
.
assertRaises
(
TypeError
,
test_argmin_x_type
)
def
test_argmax_attr_type
():
data
=
paddle
.
static
.
data
(
name
=
"test_argmax"
,
shape
=
[
10
],
dtype
=
"float32"
)
output
=
paddle
.
argmax
(
x
=
data
,
dtype
=
"float32"
)
self
.
assertRaises
(
ValueError
,
test_argmax_attr_type
)
def
test_argmin_attr_type
():
data
=
paddle
.
static
.
data
(
name
=
"test_argmax"
,
shape
=
[
10
],
dtype
=
"float32"
)
output
=
paddle
.
argmin
(
x
=
data
,
dtype
=
"float32"
)
self
.
assertRaises
(
ValueError
,
test_argmin_attr_type
)
if
__name__
==
'__main__'
:
unittest
.
main
()
python/paddle/tensor/search.py
浏览文件 @
6b28456e
...
@@ -125,95 +125,168 @@ def argsort(x, axis=-1, descending=False, name=None):
...
@@ -125,95 +125,168 @@ def argsort(x, axis=-1, descending=False, name=None):
return
ids
return
ids
def
argmax
(
input
,
axis
=
None
,
dtype
=
None
,
out
=
None
,
keepdims
=
False
,
name
=
None
):
def
argmax
(
x
,
axis
=
None
,
dtype
=
None
,
keepdim
=
False
,
name
=
None
):
"""
"""
:alias_main: paddle.argmax
:alias: paddle.argmax,paddle.tensor.argmax,paddle.tensor.search.argmax
This OP computes the indices of the max elements of the input tensor's
This OP computes the indices of the max elements of the input tensor's
element along the provided axis.
element along the provided axis.
Args:
Args:
input(Variable
): An input N-D Tensor with type float32, float64, int16,
x(Tensor
): An input N-D Tensor with type float32, float64, int16,
int32, int64, uint8.
int32, int64, uint8.
axis(int, optional): Axis to compute indices along. The effective range
axis(int, optional): Axis to compute indices along. The effective range
is [-R, R), where R is
Rank(input). when axis<
0, it works the same way
is [-R, R), where R is
x.ndim. when axis <
0, it works the same way
as axis
+R. Default is None, it will use the last dim to select indices of max value
.
as axis
+ R. Default is None, the input `x` will be into the flatten tensor, and selecting the min value index
.
dtype(
np.dtype|core.VarDesc.VarType|
str): Data type of the output tensor which can
dtype(str): Data type of the output tensor which can
be int32, int64. The default value is None, and it will
be int32, int64. The default value is None, and it will
return the int64 indices.
return the int64 indices.
out(Variable, optional): Optional output which can be any created
keepdim(bool, optional): Keep the axis that selecting max. The defalut value is False.
Variable that meets the requirements to store the result of operation.
if out is None, a new Varibale will be create to store the result. Defalut is None.
keepdims(bool, optional): Keep the axis that do the select max.
name(str, optional): The default value is None. Normally there is no
name(str, optional): The default value is None. Normally there is no
need for user to set this property. For more information, please
need for user to set this property. For more information, please
refer to :ref:`api_guide_Name`.
refer to :ref:`api_guide_Name`.
Returns:
Returns:
Variable: A Tensor with data type int64.
Tensor, return the tensor of `int32` if set :attr:`dtype` is `int32`, otherwise return the tensor of `int64`
Examples:
Examples:
.. code-block:: python
.. code-block:: python
import paddle
import paddle.fluid as fluid
import numpy as np
import numpy as np
import paddle
in1 = np.array([[[5,8,9,5],
paddle.disable_static()
data = np.array([[5,8,9,5],
[0,0,1,7],
[0,0,1,7],
[6,9,2,4]],
[6,9,2,4]])
[[5,2,4,2],
x = paddle.to_variable(data)
[4,7,7,9],
out1 = paddle.argmax(x)
[1,7,0,6]]])
print(out1.numpy()) # 2
with fluid.dygraph.guard():
out2 = paddle.argmax(x, axis=1)
x = fluid.dygraph.to_variable(in1)
out1 = paddle.argmax(input=x, axis=-1)
out2 = paddle.argmax(input=x, axis=0)
out3 = paddle.argmax(input=x, axis=1)
out4 = paddle.argmax(input=x, axis=2)
out5 = paddle.argmax(input=x, axis=2, keepdims=True)
print(out1.numpy())
# [[2 3 1]
# [0 3 1]]
print(out2.numpy())
print(out2.numpy())
# [[0 0 0 0]
# [2 3 1]
# [1 1 1 1]
out3 = paddle.argmax(x, axis=-1)
# [0 0 0 1]]
print(out3.numpy())
print(out3.numpy())
# [[2 2 0 1]
# [2 3 1]
# [0 1 1 1]]
print(out4.numpy())
# [[2 3 1]
# [0 3 1]]
print(out5.numpy())
#array([[[2],
# [3],
# [1]],
# [[0],
# [3],
# [1]]])
"""
"""
helper
=
LayerHelper
(
"arg_max"
,
**
locals
())
flatten
=
False
if
axis
is
None
:
flatten
=
True
axis
=
0
if
in_dygraph_mode
():
if
dtype
!=
None
:
var_dtype
=
convert_np_dtype_to_dtype_
(
dtype
)
out
=
core
.
ops
.
arg_max
(
x
,
'axis'
,
axis
,
'dtype'
,
var_dtype
,
'keepdim'
,
keepdim
,
'flatten'
,
flatten
)
else
:
out
=
core
.
ops
.
arg_max
(
x
,
'axis'
,
axis
,
'keepdim'
,
keepdim
,
'flatten'
,
flatten
)
return
out
helper
=
LayerHelper
(
"argmax"
,
**
locals
())
check_variable_and_dtype
(
x
,
'x'
,
[
'float32'
,
'float64'
,
'int16'
,
'int32'
,
'int64'
,
'uint8'
],
'paddle.argmax'
)
var_dtype
=
None
var_dtype
=
None
attrs
=
{}
attrs
=
{}
if
dtype
is
not
None
:
if
dtype
is
not
None
:
check_dtype
(
dtype
,
'create data type'
,
[
'int32'
,
'int64'
],
'arg_max'
)
if
dtype
not
in
[
'int32'
,
'int64'
]:
raise
ValueError
(
"The value of 'dtype' in argmax op must be int32, int64, but received of {}"
.
format
(
dtype
))
var_dtype
=
convert_np_dtype_to_dtype_
(
dtype
)
var_dtype
=
convert_np_dtype_to_dtype_
(
dtype
)
attrs
[
"dtype"
]
=
var_dtype
attrs
[
"dtype"
]
=
var_dtype
else
:
else
:
var_dtype
=
VarDesc
.
VarType
.
INT64
var_dtype
=
VarDesc
.
VarType
.
INT64
if
out
is
None
:
out
=
helper
.
create_variable_for_type_inference
(
var_dtype
)
out
=
helper
.
create_variable_for_type_inference
(
var_dtype
)
attrs
[
'keepdims'
]
=
keepdim
attrs
[
'axis'
]
=
axis
attrs
[
'flatten'
]
=
flatten
helper
.
append_op
(
type
=
'arg_max'
,
inputs
=
{
'X'
:
x
},
outputs
=
{
'Out'
:
[
out
]},
attrs
=
attrs
)
out
.
stop_gradient
=
True
return
out
def
argmin
(
x
,
axis
=
None
,
dtype
=
None
,
keepdim
=
False
,
name
=
None
):
"""
This OP computes the indices of the min elements of the input tensor's
element along the provided axis.
Args:
x(Tensor): An input N-D Tensor with type float32, float64, int16,
int32, int64, uint8.
axis(int, optional): Axis to compute indices along. The effective range
is [-R, R), where R is x.ndim. when axis < 0, it works the same way
as axis + R. Default is None, the input `x` will be into the flatten tensor, and selecting the min value index.
dtype(str): Data type of the output tensor which can
be int32, int64. The default value is None, and it will
return the int64 indices.
keepdim(bool, optional): Keep the axis that selecting min. The defalut value is False.
name(str, optional): The default value is None. Normally there is no
need for user to set this property. For more information, please
refer to :ref:`api_guide_Name`.
Returns:
Tensor, return the tensor of `int32` if set :attr:`dtype` is `int32`, otherwise return the tensor of `int64`
Examples:
.. code-block:: python
import numpy as np
import paddle
paddle.disable_static()
data = np.array([[5,8,9,5],
[0,0,1,7],
[6,9,2,4]])
x = paddle.to_variable(data)
out1 = paddle.argmin(x)
print(out1.numpy()) # 4
out2 = paddle.argmin(x, axis=1)
print(out2.numpy())
# [0 0 2]
out3 = paddle.argmin(x, axis=-1)
print(out3.numpy())
# [0 0 2]
"""
flatten
=
False
if
axis
is
None
:
if
axis
is
None
:
axis
=
-
1
flatten
=
True
attrs
[
'keepdims'
]
=
keepdims
axis
=
0
if
in_dygraph_mode
():
if
dtype
!=
None
:
var_dtype
=
convert_np_dtype_to_dtype_
(
dtype
)
out
=
core
.
ops
.
arg_min
(
x
,
'axis'
,
axis
,
'dtype'
,
var_dtype
,
'keepdim'
,
keepdim
,
'flatten'
,
flatten
)
else
:
out
=
core
.
ops
.
arg_min
(
x
,
'axis'
,
axis
,
'keepdim'
,
keepdim
,
'flatten'
,
flatten
)
return
out
helper
=
LayerHelper
(
"argmin"
,
**
locals
())
check_variable_and_dtype
(
x
,
'x'
,
[
'float32'
,
'float64'
,
'int16'
,
'int32'
,
'int64'
,
'uint8'
],
'paddle.argmin'
)
var_dtype
=
None
attrs
=
{}
if
dtype
is
not
None
:
if
dtype
not
in
[
'int32'
,
'int64'
]:
raise
ValueError
(
"The value of 'dtype' in argmin op must be int32, int64, but received of {}"
.
format
(
dtype
))
var_dtype
=
convert_np_dtype_to_dtype_
(
dtype
)
attrs
[
"dtype"
]
=
var_dtype
else
:
var_dtype
=
VarDesc
.
VarType
.
INT64
out
=
helper
.
create_variable_for_type_inference
(
var_dtype
)
attrs
[
'keepdims'
]
=
keepdim
attrs
[
'axis'
]
=
axis
attrs
[
'axis'
]
=
axis
attrs
[
'flatten'
]
=
flatten
helper
.
append_op
(
helper
.
append_op
(
type
=
'arg_max'
,
type
=
'arg_min'
,
inputs
=
{
'X'
:
x
},
outputs
=
{
'Out'
:
[
out
]},
attrs
=
attrs
)
inputs
=
{
'X'
:
input
},
outputs
=
{
'Out'
:
[
out
]},
attrs
=
attrs
)
out
.
stop_gradient
=
True
out
.
stop_gradient
=
True
return
out
return
out
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录