BaiXuePrincess / Paddle (forked from PaddlePaddle / Paddle)
Unverified commit 897e5746
Authored Sep 04, 2020 by wawltor; committed via GitHub on Sep 04, 2020
cherry-pick from the develop PR#26792, fix the argmin, argmax
Parent: 2c298d62
Showing 5 changed files with 108 additions and 56 deletions (+108 -56)
paddle/fluid/operators/arg_max_op.cc (+18 -0)
paddle/fluid/operators/arg_min_max_op_base.h (+28 -7)
paddle/fluid/operators/arg_min_op.cc (+18 -0)
python/paddle/fluid/tests/unittests/test_arg_min_max_v2_op.py (+18 -4)
python/paddle/tensor/search.py (+26 -45)
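For context, a minimal usage sketch of the Python API this commit ends up with (my own example data, assuming a paddle 2.0 build that contains this fix; not part of the diff):

import numpy as np
import paddle

paddle.disable_static()
x = paddle.to_tensor(
    np.array([[5, 8, 9, 5], [0, 0, 1, 7], [6, 9, 2, 4]]).astype('float32'))

# axis=None: the 'flatten' attr is set and the index is searched over all elements.
out1 = paddle.argmax(x)
# axis=-1: indices of the max element along the last axis.
out2 = paddle.argmax(x, axis=-1)
# keepdim=True keeps the reduced axis with size 1.
out3 = paddle.argmax(x, axis=0, keepdim=True)

print(out2.numpy())                        # [2 3 1]
print(paddle.argmin(x, axis=-1).numpy())   # [0 0 2]
# Indices are int64 by default after this change; dtype='int32' is also accepted.
out4 = paddle.argmax(x, axis=-1, dtype='int32')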
paddle/fluid/operators/arg_max_op.cc

@@ -12,6 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License. */
 
+#include "paddle/fluid/framework/op_version_registry.h"
 #include "paddle/fluid/operators/arg_min_max_op_base.h"
 
 REGISTER_OPERATOR(
@@ -31,3 +32,20 @@ REGISTER_OP_CPU_KERNEL(
                                     int16_t>,
     paddle::operators::ArgMaxKernel<paddle::platform::CPUDeviceContext,
                                     uint8_t>);
+REGISTER_OP_VERSION(arg_max)
+    .AddCheckpoint(
+        R"ROC(
+              Upgrade argmax add a new attribute [flatten] and modify the attribute of dtype)ROC",
+        paddle::framework::compatible::OpVersionDesc()
+            .NewAttr("flatten",
+                     "In order to compute the argmax over the flattened array "
+                     "when the "
+                     "argument `axis` in python API is None.",
+                     false)
+            .ModifyAttr(
+                "dtype",
+                "change the default value of dtype, the older version "
+                "is -1, means return the int64 indices."
+                "The new version is 3, return the int64 indices directly."
+                "And supporting the dtype of -1 in new version.",
+                3));
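A note on the ModifyAttr default above: the old default of -1 meant "unset, fall back to int64 indices", while 3 is the integer value of framework::proto::VarType::INT64. The sketch below only illustrates that mapping (the helper name and constants are mine, not Paddle's):

# Illustrative only: integer values of the relevant framework.proto VarType entries.
VARTYPE_INT32 = 2   # framework::proto::VarType::INT32
VARTYPE_INT64 = 3   # framework::proto::VarType::INT64, the new default for 'dtype'

def index_dtype_for_attr(dtype_attr):
    """Resolve the arg_max/arg_min 'dtype' attribute to a numpy dtype name."""
    if dtype_attr in (-1, VARTYPE_INT64):   # -1 is still accepted per the checkpoint note
        return 'int64'
    if dtype_attr == VARTYPE_INT32:
        return 'int32'
    raise ValueError("indices dtype must be int32 or int64")

print(index_dtype_for_attr(3))    # int64
print(index_dtype_for_attr(-1))   # int64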
paddle/fluid/operators/arg_min_max_op_base.h

@@ -70,6 +70,8 @@ struct VisitDataArgMinMaxFunctor {
     auto axis = ctx.Attr<int64_t>("axis");
     auto keepdims = ctx.Attr<bool>("keepdims");
     const bool& flatten = ctx.Attr<bool>("flatten");
+    // paddle do not have the scalar tensor, just return the shape [1] tensor
+    if (flatten) keepdims = true;
 
     // if flatten, will construct the new dims for the cacluate
     framework::DDim x_dims;
@@ -164,15 +166,30 @@ class ArgMinMaxOp : public framework::OperatorWithKernel {
                       platform::errors::InvalidArgument(
                           "'axis'(%d) must be less than Rank(X)(%d).", axis,
                           x_dims.size()));
 
+    auto x_rank = x_dims.size();
+    if (axis < 0) axis += x_rank;
+    if (ctx->IsRuntime()) {
+      const int& dtype = ctx->Attrs().Get<int>("dtype");
+      if (dtype == framework::proto::VarType::INT32) {
+        int64_t all_element_num = 0;
+        if (flatten) {
+          all_element_num = framework::product(x_dims);
+        } else {
+          all_element_num = x_dims[axis];
+        }
+        PADDLE_ENFORCE_LE(
+            all_element_num, INT_MAX,
+            "The element num of the argmin/argmax input at axis is "
+            "%d, is larger than int32 maximum value:%d, you must "
+            "set the dtype of argmin/argmax to 'int64'.",
+            all_element_num, INT_MAX);
+      }
+    }
     std::vector<int64_t> vec;
     if (flatten) {
-      // if is flatten, will return the only on element
-      if (keepdims) {
-        vec.emplace_back(static_cast<int64_t>(1));
-      }
       vec.emplace_back(static_cast<int64_t>(1));
     } else {
-      auto x_rank = x_dims.size();
-      if (axis < 0) axis += x_rank;
       for (int64_t i = 0; i < axis; i++) vec.emplace_back(x_dims[i]);
       if (keepdims) {
         vec.emplace_back(static_cast<int64_t>(1));
       }
@@ -194,10 +211,14 @@ class BaseArgMinMaxOpMaker : public framework::OpProtoAndCheckerMaker {
     AddOutput("Out", "Output tensor.");
     AddAttr<int64_t>("axis", "The axis in which to compute the arg indics.");
     AddAttr<bool>("keepdims", "Keep the dim that to reduce.").SetDefault(false);
-    AddAttr<int>("dtype", "Keep the dim that to reduce.").SetDefault(-1);
     AddAttr<bool>("flatten",
                   "Flatten the input value, and search the min or max indices")
         .SetDefault(false);
+    AddAttr<int>("dtype",
+                 "(int, 3), the dtype of indices, the indices dtype must be "
+                 "int32, int64."
+                 "default dtype is int64, and proto value is 3.")
+        .SetDefault(3);
     AddComment(string::Sprintf(R"DOC(
 %s Operator.
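The InferShape changes above encode two things: the output shape rule (a shape-[1] result when flatten is set, otherwise the reduced shape with an optional kept axis) and a runtime guard that int32 indices only fit when the searched extent stays within INT_MAX. A small NumPy sketch of the same rules (function names are mine, for illustration only):

import numpy as np

INT32_MAX = 2**31 - 1  # INT_MAX in the PADDLE_ENFORCE_LE check above

def arg_out_shape(x_shape, axis=None, keepdims=False):
    """Mimic the output shape of arg_max/arg_min after this change."""
    if axis is None:                      # flatten=True in the op attrs
        return [1]                        # no scalar tensors: a shape-[1] result
    rank = len(x_shape)
    if axis < 0:
        axis += rank
    out = list(x_shape[:axis])
    if keepdims:
        out.append(1)
    out += list(x_shape[axis + 1:])
    return out

def int32_indices_ok(x_shape, axis=None):
    """The runtime guard: int32 indices only fit if the searched extent fits in int32."""
    num = int(np.prod(x_shape)) if axis is None else x_shape[axis]
    return num <= INT32_MAX

print(arg_out_shape((3, 4), axis=-1, keepdims=True))   # [3, 1]
print(arg_out_shape((3, 4)))                           # [1]
print(int32_indices_ok((3, 4), axis=-1))               # True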
paddle/fluid/operators/arg_min_op.cc

@@ -12,6 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License. */
 
+#include "paddle/fluid/framework/op_version_registry.h"
 #include "paddle/fluid/operators/arg_min_max_op_base.h"
 
 REGISTER_OPERATOR(
@@ -31,3 +32,20 @@ REGISTER_OP_CPU_KERNEL(
                                     int16_t>,
     paddle::operators::ArgMinKernel<paddle::platform::CPUDeviceContext,
                                     uint8_t>);
+REGISTER_OP_VERSION(arg_min)
+    .AddCheckpoint(
+        R"ROC(
+              Upgrade argmin add a new attribute [flatten] and modify the attribute of dtype)ROC",
+        paddle::framework::compatible::OpVersionDesc()
+            .NewAttr("flatten",
+                     "In order to compute the argmin over the flattened array "
+                     "when the "
+                     "argument `axis` in python API is None.",
+                     false)
+            .ModifyAttr(
+                "dtype",
+                "change the default value of dtype, the older version "
+                "is -1, means return the int64 indices."
+                "The new version is 3, return the int64 indices directly."
+                "And supporting the dtype of -1 in new version.",
+                3));
python/paddle/fluid/tests/unittests/test_arg_min_max_v2_op.py

@@ -218,7 +218,7 @@ def create_test_case(op_type):
             self.assertTrue("test_arg_api" in result.name)
 
         def run_dygraph(self, place):
-            paddle.disable_static()
+            paddle.disable_static(place)
             op = eval("paddle.%s" % (op_type))
             data_tensor = paddle.to_tensor(self.input_data)
@@ -240,7 +240,7 @@ def create_test_case(op_type):
             #case 4
             result_data = op(data_tensor, axis=-1, keepdim=True)
             excepted_data = self.numpy_op(self.input_data, axis=-1)
-            excepted_data = excepted_data.reshape((10))
+            excepted_data = excepted_data.reshape((10, 1))
             self.assertTrue((result_data.numpy() == excepted_data).all(), True)
 
             #case 5
@@ -299,14 +299,28 @@ class TestArgMinMaxOpError(unittest.TestCase):
                 name="test_argmax", shape=[10], dtype="float32")
             output = paddle.argmax(x=data, dtype="float32")
 
-        self.assertRaises(ValueError, test_argmax_attr_type)
+        self.assertRaises(TypeError, test_argmax_attr_type)
 
         def test_argmin_attr_type():
             data = paddle.static.data(
                 name="test_argmax", shape=[10], dtype="float32")
             output = paddle.argmin(x=data, dtype="float32")
 
-        self.assertRaises(ValueError, test_argmin_attr_type)
+        self.assertRaises(TypeError, test_argmin_attr_type)
 
+        def test_argmax_axis_type():
+            data = paddle.static.data(
+                name="test_argmax", shape=[10], dtype="float32")
+            output = paddle.argmax(x=data, axis=1.2)
+
+        self.assertRaises(TypeError, test_argmax_axis_type)
+
+        def test_argmin_axis_type():
+            data = paddle.static.data(
+                name="test_argmin", shape=[10], dtype="float32")
+            output = paddle.argmin(x=data, axis=1.2)
+
+        self.assertRaises(TypeError, test_argmin_axis_type)
 
 if __name__ == '__main__':
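The added cases above pin down the type checks introduced in python/paddle/tensor/search.py below; a minimal sketch of the expected behaviour (assuming a build that contains this fix):

import paddle

paddle.enable_static()
data = paddle.static.data(name="x", shape=[10], dtype="float32")

try:
    paddle.argmax(x=data, dtype="float32")   # indices dtype must be int32 or int64
except TypeError as e:
    print("dtype rejected:", e)

try:
    paddle.argmax(x=data, axis=1.2)          # axis must be an int (or None)
except TypeError as e:
    print("axis rejected:", e)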
python/paddle/tensor/search.py

@@ -18,7 +18,6 @@ from ..fluid.data_feeder import check_variable_and_dtype, check_type, check_dtype
 from ..fluid import core, layers
 
 # TODO: define searching & indexing functions of a tensor
-from ..fluid.layers import argmin  #DEFINE_ALIAS
 from ..fluid.layers import has_inf  #DEFINE_ALIAS
 from ..fluid.layers import has_nan  #DEFINE_ALIAS
@@ -123,7 +122,7 @@ def argsort(x, axis=-1, descending=False, name=None):
     return ids
 
 
-def argmax(x, axis=None, dtype=None, keepdim=False, name=None):
+def argmax(x, axis=None, keepdim=False, dtype="int64", name=None):
     """
     This OP computes the indices of the max elements of the input tensor's
     element along the provided axis.
@@ -134,10 +133,10 @@ def argmax(x, axis=None, dtype=None, keepdim=False, name=None):
         axis(int, optional): Axis to compute indices along. The effective range
             is [-R, R), where R is x.ndim. when axis < 0, it works the same way
            as axis + R. Default is None, the input `x` will be into the flatten tensor, and selecting the min value index.
-        dtype(str): Data type of the output tensor which can
-            be int32, int64. The default value is None, and it will
-            return the int64 indices.
         keepdim(bool, optional): Keep the axis that selecting max. The defalut value is False.
+        dtype(str|np.dtype, optional): Data type of the output tensor which can
+            be int32, int64. The default value is 'int64', and it will
+            return the int64 indices.
         name(str, optional): The default value is None. Normally there is no
             need for user to set this property. For more information, please
             refer to :ref:`api_guide_Name`.
@@ -163,48 +162,39 @@ def argmax(x, axis=None, dtype=None, keepdim=False, name=None):
             print(out3.numpy())
             # [2 3 1]
     """
+    if axis is not None and not isinstance(axis, int):
+        raise TypeError(
+            "The type of 'axis' must be int or None in argmax, but received %s."
+            % (type(axis)))
+    var_dtype = convert_np_dtype_to_dtype_(dtype)
+    check_dtype(var_dtype, 'dtype', ['int32', 'int64'], 'argmin')
     flatten = False
     if axis is None:
         flatten = True
         axis = 0
 
     if in_dygraph_mode():
-        if dtype != None:
-            var_dtype = convert_np_dtype_to_dtype_(dtype)
-            out = core.ops.arg_max(x, 'axis', axis, 'dtype', var_dtype,
-                                   'keepdim', keepdim, 'flatten', flatten)
-        else:
-            out = core.ops.arg_max(x, 'axis', axis, 'keepdim', keepdim,
-                                   'flatten', flatten)
+        out = core.ops.arg_max(x, 'axis', axis, 'dtype', var_dtype, 'keepdims',
+                               keepdim, 'flatten', flatten)
         return out
 
     helper = LayerHelper("argmax", **locals())
     check_variable_and_dtype(
         x, 'x', ['float32', 'float64', 'int16', 'int32', 'int64', 'uint8'],
         'paddle.argmax')
-    var_dtype = None
     attrs = {}
-    if dtype is not None:
-        if dtype not in ['int32', 'int64']:
-            raise ValueError(
-                "The value of 'dtype' in argmax op must be int32, int64, but received of {}".
-                format(dtype))
-        var_dtype = convert_np_dtype_to_dtype_(dtype)
-        attrs["dtype"] = var_dtype
-    else:
-        var_dtype = VarDesc.VarType.INT64
     out = helper.create_variable_for_type_inference(var_dtype)
     attrs['keepdims'] = keepdim
     attrs['axis'] = axis
     attrs['flatten'] = flatten
+    attrs['dtype'] = var_dtype
     helper.append_op(
         type='arg_max', inputs={'X': x}, outputs={'Out': [out]}, attrs=attrs)
     out.stop_gradient = True
     return out
 
 
-def argmin(x, axis=None, dtype=None, keepdim=False, name=None):
+def argmin(x, axis=None, keepdim=False, dtype="int64", name=None):
     """
     This OP computes the indices of the min elements of the input tensor's
     element along the provided axis.
@@ -215,10 +205,10 @@ def argmin(x, axis=None, dtype=None, keepdim=False, name=None):
         axis(int, optional): Axis to compute indices along. The effective range
             is [-R, R), where R is x.ndim. when axis < 0, it works the same way
            as axis + R. Default is None, the input `x` will be into the flatten tensor, and selecting the min value index.
+        keepdim(bool, optional): Keep the axis that selecting min. The defalut value is False.
         dtype(str): Data type of the output tensor which can
-            be int32, int64. The default value is None, and it will
+            be int32, int64. The default value is 'int64', and it will
             return the int64 indices.
-        keepdim(bool, optional): Keep the axis that selecting min. The defalut value is False.
         name(str, optional): The default value is None. Normally there is no
             need for user to set this property. For more information, please
             refer to :ref:`api_guide_Name`.
@@ -244,41 +234,32 @@ def argmin(x, axis=None, dtype=None, keepdim=False, name=None):
             print(out3.numpy())
             # [0 0 2]
     """
+    if axis is not None and not isinstance(axis, int):
+        raise TypeError(
+            "The type of 'axis' must be int or None in argmin, but received %s."
+            % (type(axis)))
+    var_dtype = convert_np_dtype_to_dtype_(dtype)
+    check_dtype(var_dtype, 'dtype', ['int32', 'int64'], 'argmin')
     flatten = False
     if axis is None:
         flatten = True
         axis = 0
 
     if in_dygraph_mode():
-        if dtype != None:
-            var_dtype = convert_np_dtype_to_dtype_(dtype)
-            out = core.ops.arg_min(x, 'axis', axis, 'dtype', var_dtype,
-                                   'keepdim', keepdim, 'flatten', flatten)
-        else:
-            out = core.ops.arg_min(x, 'axis', axis, 'keepdim', keepdim,
-                                   'flatten', flatten)
+        out = core.ops.arg_min(x, 'axis', axis, 'dtype', var_dtype, 'keepdims',
+                               keepdim, 'flatten', flatten)
         return out
 
     helper = LayerHelper("argmin", **locals())
     check_variable_and_dtype(
         x, 'x', ['float32', 'float64', 'int16', 'int32', 'int64', 'uint8'],
         'paddle.argmin')
-    var_dtype = None
-    attrs = {}
-    if dtype is not None:
-        if dtype not in ['int32', 'int64']:
-            raise ValueError(
-                "The value of 'dtype' in argmin op must be int32, int64, but received of {}".
-                format(dtype))
-        var_dtype = convert_np_dtype_to_dtype_(dtype)
-        attrs["dtype"] = var_dtype
-    else:
-        var_dtype = VarDesc.VarType.INT64
     out = helper.create_variable_for_type_inference(var_dtype)
+    attrs = {}
     attrs['keepdims'] = keepdim
     attrs['axis'] = axis
     attrs['flatten'] = flatten
+    attrs['dtype'] = var_dtype
     helper.append_op(
         type='arg_min', inputs={'X': x}, outputs={'Out': [out]}, attrs=attrs)
     out.stop_gradient = True
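Two user-visible consequences of the rewritten bodies above: the indices dtype attribute is always passed to the op (so the int64 default actually applies), and the dygraph call now uses the C++ attribute name 'keepdims', which matches case 4 of the updated unit test. A short dygraph sketch (my own shapes, assuming a build that contains this fix):

import numpy as np
import paddle

paddle.disable_static()
x = paddle.to_tensor(np.random.rand(10, 20).astype('float32'))

out = paddle.argmax(x, axis=-1, keepdim=True)
print(out.shape)                          # [10, 1], the reduced axis is kept
print(paddle.argmax(x, axis=-1).shape)    # [10]
print(paddle.argmin(x).shape)             # [1], axis=None searches the flattened tensor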