Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
6fc15986
P
Paddle
项目概览
PaddlePaddle
/
Paddle
接近 2 年 前同步成功
通知
2323
Star
20933
Fork
5424
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
6fc15986
编写于
8月 30, 2022
作者:
W
WangZhen
提交者:
GitHub
8月 30, 2022
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
[OpAttr]Adapt tensor axis for argmin/max (#45453)
* Adapt tensor axis for argmin/max * Add UT * Polish UT
上级
5f1a8e46
变更
9
隐藏空白更改
内联
并排
Showing
9 changed file
with
159 addition
and
41 deletion
+159
-41
paddle/fluid/operators/arg_min_max_op_base.h
paddle/fluid/operators/arg_min_max_op_base.h
+9
-1
paddle/phi/api/yaml/legacy_api.yaml
paddle/phi/api/yaml/legacy_api.yaml
+2
-2
paddle/phi/infermeta/unary.cc
paddle/phi/infermeta/unary.cc
+42
-21
paddle/phi/infermeta/unary.h
paddle/phi/infermeta/unary.h
+1
-1
paddle/phi/kernels/arg_min_max_kernel.h
paddle/phi/kernels/arg_min_max_kernel.h
+3
-2
paddle/phi/kernels/cpu/arg_min_max_kernel.cc
paddle/phi/kernels/cpu/arg_min_max_kernel.cc
+5
-5
paddle/phi/kernels/gpu/arg_min_max_kernel.cu
paddle/phi/kernels/gpu/arg_min_max_kernel.cu
+5
-5
python/paddle/fluid/tests/unittests/test_arg_min_max_op.py
python/paddle/fluid/tests/unittests/test_arg_min_max_op.py
+88
-0
python/paddle/tensor/search.py
python/paddle/tensor/search.py
+4
-4
未找到文件。
paddle/fluid/operators/arg_min_max_op_base.h
浏览文件 @
6fc15986
...
@@ -31,6 +31,13 @@ namespace operators {
...
@@ -31,6 +31,13 @@ namespace operators {
class
ArgMinMaxOp
:
public
framework
::
OperatorWithKernel
{
class
ArgMinMaxOp
:
public
framework
::
OperatorWithKernel
{
public:
public:
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
framework
::
OpKernelType
GetExpectedKernelType
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
auto
input_data_type
=
framework
::
OperatorWithKernel
::
IndicateVarDataType
(
ctx
,
"X"
);
return
framework
::
OpKernelType
(
input_data_type
,
ctx
.
GetPlace
());
}
};
};
class
BaseArgMinMaxOpMaker
:
public
framework
::
OpProtoAndCheckerMaker
{
class
BaseArgMinMaxOpMaker
:
public
framework
::
OpProtoAndCheckerMaker
{
...
@@ -42,7 +49,8 @@ class BaseArgMinMaxOpMaker : public framework::OpProtoAndCheckerMaker {
...
@@ -42,7 +49,8 @@ class BaseArgMinMaxOpMaker : public framework::OpProtoAndCheckerMaker {
void
Make
()
override
{
void
Make
()
override
{
AddInput
(
"X"
,
"Input tensor."
);
AddInput
(
"X"
,
"Input tensor."
);
AddOutput
(
"Out"
,
"Output tensor."
);
AddOutput
(
"Out"
,
"Output tensor."
);
AddAttr
<
int64_t
>
(
"axis"
,
"The axis in which to compute the arg indics."
);
AddAttr
<
int64_t
>
(
"axis"
,
"The axis in which to compute the arg indics."
)
.
SupportTensor
();
AddAttr
<
bool
>
(
"keepdims"
,
"Keep the dim that to reduce."
).
SetDefault
(
false
);
AddAttr
<
bool
>
(
"keepdims"
,
"Keep the dim that to reduce."
).
SetDefault
(
false
);
AddAttr
<
bool
>
(
"flatten"
,
AddAttr
<
bool
>
(
"flatten"
,
"Flatten the input value, and search the min or max indices"
)
"Flatten the input value, and search the min or max indices"
)
...
...
paddle/phi/api/yaml/legacy_api.yaml
浏览文件 @
6fc15986
...
@@ -197,7 +197,7 @@
...
@@ -197,7 +197,7 @@
support_trans_dtype
:
start, end, step
support_trans_dtype
:
start, end, step
-
api
:
argmax
-
api
:
argmax
args
:
(Tensor x,
int64_t
axis, bool keepdims, bool flatten, int dtype)
args
:
(Tensor x,
Scalar
axis, bool keepdims, bool flatten, int dtype)
output
:
Tensor(out)
output
:
Tensor(out)
infer_meta
:
infer_meta
:
func
:
ArgMinMaxInferMeta
func
:
ArgMinMaxInferMeta
...
@@ -205,7 +205,7 @@
...
@@ -205,7 +205,7 @@
func
:
arg_max
func
:
arg_max
-
api
:
argmin
-
api
:
argmin
args
:
(Tensor x,
int64_t
axis, bool keepdims, bool flatten, int dtype)
args
:
(Tensor x,
Scalar
axis, bool keepdims, bool flatten, int dtype)
output
:
Tensor(out)
output
:
Tensor(out)
infer_meta
:
infer_meta
:
func
:
ArgMinMaxInferMeta
func
:
ArgMinMaxInferMeta
...
...
paddle/phi/infermeta/unary.cc
浏览文件 @
6fc15986
...
@@ -121,28 +121,12 @@ void AffineGridInferMeta(const MetaTensor& input,
...
@@ -121,28 +121,12 @@ void AffineGridInferMeta(const MetaTensor& input,
}
}
void
ArgMinMaxInferMeta
(
const
MetaTensor
&
x
,
void
ArgMinMaxInferMeta
(
const
MetaTensor
&
x
,
int64_t
axis
,
const
Scalar
&
axis
,
bool
keepdims
,
bool
keepdims
,
bool
flatten
,
bool
flatten
,
int
dtype
,
int
dtype
,
MetaTensor
*
out
,
MetaTensor
*
out
,
MetaConfig
config
)
{
MetaConfig
config
)
{
const
auto
&
x_dims
=
x
.
dims
();
PADDLE_ENFORCE_GE
(
axis
,
-
x_dims
.
size
(),
phi
::
errors
::
InvalidArgument
(
"'axis'(%d) must be greater than or equal to"
" -Rank(X)(%d)."
,
axis
,
-
x_dims
.
size
()));
PADDLE_ENFORCE_LT
(
axis
,
x_dims
.
size
(),
phi
::
errors
::
InvalidArgument
(
"'axis'(%d) must be less than Rank(X)(%d) of Input(X)."
,
axis
,
x_dims
.
size
()));
PADDLE_ENFORCE_EQ
(
PADDLE_ENFORCE_EQ
(
(
dtype
<
0
||
dtype
==
2
||
dtype
==
3
),
(
dtype
<
0
||
dtype
==
2
||
dtype
==
3
),
true
,
true
,
...
@@ -156,8 +140,45 @@ void ArgMinMaxInferMeta(const MetaTensor& x,
...
@@ -156,8 +140,45 @@ void ArgMinMaxInferMeta(const MetaTensor& x,
paddle
::
framework
::
DataTypeToString
(
paddle
::
framework
::
DataTypeToString
(
static_cast
<
paddle
::
framework
::
proto
::
VarType
::
Type
>
(
dtype
))));
static_cast
<
paddle
::
framework
::
proto
::
VarType
::
Type
>
(
dtype
))));
if
(
!
config
.
is_runtime
&&
axis
.
FromTensor
())
{
std
::
vector
<
int64_t
>
vec
;
if
(
flatten
)
{
vec
=
{
1
};
}
else
{
if
(
keepdims
)
{
vec
=
std
::
vector
<
int64_t
>
(
x
.
dims
().
size
(),
-
1
);
}
else
{
vec
=
std
::
vector
<
int64_t
>
(
x
.
dims
().
size
()
-
1
,
-
1
);
}
}
out
->
set_dims
(
phi
::
make_ddim
(
vec
));
if
(
dtype
==
2
)
{
out
->
set_dtype
(
DataType
::
INT32
);
}
else
if
(
dtype
==
3
)
{
out
->
set_dtype
(
DataType
::
INT64
);
}
return
;
}
auto
int_axis
=
axis
.
to
<
int64_t
>
();
const
auto
&
x_dims
=
x
.
dims
();
PADDLE_ENFORCE_GE
(
int_axis
,
-
x_dims
.
size
(),
phi
::
errors
::
InvalidArgument
(
"'axis'(%d) must be greater than or equal to"
" -Rank(X)(%d)."
,
int_axis
,
-
x_dims
.
size
()));
PADDLE_ENFORCE_LT
(
int_axis
,
x_dims
.
size
(),
phi
::
errors
::
InvalidArgument
(
"'axis'(%d) must be less than Rank(X)(%d) of Input(X)."
,
int_axis
,
x_dims
.
size
()));
auto
x_rank
=
x_dims
.
size
();
auto
x_rank
=
x_dims
.
size
();
if
(
axis
<
0
)
axis
+=
x_rank
;
if
(
int_axis
<
0
)
int_
axis
+=
x_rank
;
if
(
config
.
is_runtime
)
{
if
(
config
.
is_runtime
)
{
if
(
dtype
==
paddle
::
framework
::
proto
::
VarType
::
INT32
)
{
if
(
dtype
==
paddle
::
framework
::
proto
::
VarType
::
INT32
)
{
int64_t
all_element_num
=
0
;
int64_t
all_element_num
=
0
;
...
@@ -165,7 +186,7 @@ void ArgMinMaxInferMeta(const MetaTensor& x,
...
@@ -165,7 +186,7 @@ void ArgMinMaxInferMeta(const MetaTensor& x,
all_element_num
=
phi
::
product
(
x_dims
);
all_element_num
=
phi
::
product
(
x_dims
);
}
else
{
}
else
{
all_element_num
=
x_dims
[
axis
];
all_element_num
=
x_dims
[
int_
axis
];
}
}
PADDLE_ENFORCE_LE
(
PADDLE_ENFORCE_LE
(
all_element_num
,
all_element_num
,
...
@@ -182,11 +203,11 @@ void ArgMinMaxInferMeta(const MetaTensor& x,
...
@@ -182,11 +203,11 @@ void ArgMinMaxInferMeta(const MetaTensor& x,
if
(
flatten
)
{
if
(
flatten
)
{
vec
.
emplace_back
(
static_cast
<
int64_t
>
(
1
));
vec
.
emplace_back
(
static_cast
<
int64_t
>
(
1
));
}
else
{
}
else
{
for
(
int64_t
i
=
0
;
i
<
axis
;
i
++
)
vec
.
emplace_back
(
x_dims
[
i
]);
for
(
int64_t
i
=
0
;
i
<
int_
axis
;
i
++
)
vec
.
emplace_back
(
x_dims
[
i
]);
if
(
keepdims
)
{
if
(
keepdims
)
{
vec
.
emplace_back
(
static_cast
<
int64_t
>
(
1
));
vec
.
emplace_back
(
static_cast
<
int64_t
>
(
1
));
}
}
for
(
int64_t
i
=
axis
+
1
;
i
<
x_rank
;
i
++
)
vec
.
emplace_back
(
x_dims
[
i
]);
for
(
int64_t
i
=
int_
axis
+
1
;
i
<
x_rank
;
i
++
)
vec
.
emplace_back
(
x_dims
[
i
]);
}
}
out
->
set_dims
(
phi
::
make_ddim
(
vec
));
out
->
set_dims
(
phi
::
make_ddim
(
vec
));
if
(
dtype
==
2
)
{
if
(
dtype
==
2
)
{
...
...
paddle/phi/infermeta/unary.h
浏览文件 @
6fc15986
...
@@ -40,7 +40,7 @@ void AffineGridInferMeta(const MetaTensor& input,
...
@@ -40,7 +40,7 @@ void AffineGridInferMeta(const MetaTensor& input,
MetaTensor
*
output
);
MetaTensor
*
output
);
void
ArgMinMaxInferMeta
(
const
MetaTensor
&
x
,
void
ArgMinMaxInferMeta
(
const
MetaTensor
&
x
,
int64_t
axis
,
const
Scalar
&
axis
,
bool
keepdims
,
bool
keepdims
,
bool
flatten
,
bool
flatten
,
int
dtype
,
int
dtype
,
...
...
paddle/phi/kernels/arg_min_max_kernel.h
浏览文件 @
6fc15986
...
@@ -14,6 +14,7 @@ limitations under the License. */
...
@@ -14,6 +14,7 @@ limitations under the License. */
#pragma once
#pragma once
#include "paddle/phi/common/scalar.h"
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/dense_tensor.h"
namespace
phi
{
namespace
phi
{
...
@@ -21,7 +22,7 @@ namespace phi {
...
@@ -21,7 +22,7 @@ namespace phi {
template
<
typename
T
,
typename
Context
>
template
<
typename
T
,
typename
Context
>
void
ArgMinKernel
(
const
Context
&
dev_ctx
,
void
ArgMinKernel
(
const
Context
&
dev_ctx
,
const
DenseTensor
&
x
,
const
DenseTensor
&
x
,
int64_t
axis
,
const
Scalar
&
axis
,
bool
keepdims
,
bool
keepdims
,
bool
flatten
,
bool
flatten
,
int
dtype
,
int
dtype
,
...
@@ -30,7 +31,7 @@ void ArgMinKernel(const Context& dev_ctx,
...
@@ -30,7 +31,7 @@ void ArgMinKernel(const Context& dev_ctx,
template
<
typename
T
,
typename
Context
>
template
<
typename
T
,
typename
Context
>
void
ArgMaxKernel
(
const
Context
&
dev_ctx
,
void
ArgMaxKernel
(
const
Context
&
dev_ctx
,
const
DenseTensor
&
x
,
const
DenseTensor
&
x
,
int64_t
axis
,
const
Scalar
&
axis
,
bool
keepdims
,
bool
keepdims
,
bool
flatten
,
bool
flatten
,
int
dtype
,
int
dtype
,
...
...
paddle/phi/kernels/cpu/arg_min_max_kernel.cc
浏览文件 @
6fc15986
...
@@ -135,7 +135,7 @@ struct VisitDataArgMinMaxFunctor {
...
@@ -135,7 +135,7 @@ struct VisitDataArgMinMaxFunctor {
template
<
typename
Context
,
typename
T
,
ArgMinMaxType
EnumArgMinMaxValue
>
template
<
typename
Context
,
typename
T
,
ArgMinMaxType
EnumArgMinMaxValue
>
void
ArgMinMaxKernel
(
const
Context
&
dev_ctx
,
void
ArgMinMaxKernel
(
const
Context
&
dev_ctx
,
const
DenseTensor
&
x
,
const
DenseTensor
&
x
,
int64_t
axis
,
const
Scalar
&
axis
,
bool
keepdims
,
bool
keepdims
,
bool
flatten
,
bool
flatten
,
int
dtype
,
int
dtype
,
...
@@ -145,19 +145,19 @@ void ArgMinMaxKernel(const Context& dev_ctx,
...
@@ -145,19 +145,19 @@ void ArgMinMaxKernel(const Context& dev_ctx,
static_cast
<
paddle
::
framework
::
proto
::
VarType
::
Type
>
(
static_cast
<
paddle
::
framework
::
proto
::
VarType
::
Type
>
(
paddle
::
framework
::
proto
::
VarType
::
INT64
),
paddle
::
framework
::
proto
::
VarType
::
INT64
),
VisitDataArgMinMaxFunctor
<
Context
,
T
,
EnumArgMinMaxValue
>
(
VisitDataArgMinMaxFunctor
<
Context
,
T
,
EnumArgMinMaxValue
>
(
dev_ctx
,
x
,
axis
,
keepdims
,
flatten
,
out
));
dev_ctx
,
x
,
axis
.
to
<
int64_t
>
()
,
keepdims
,
flatten
,
out
));
return
;
return
;
}
}
paddle
::
framework
::
VisitDataTypeTiny
(
paddle
::
framework
::
VisitDataTypeTiny
(
static_cast
<
paddle
::
framework
::
proto
::
VarType
::
Type
>
(
dtype
),
static_cast
<
paddle
::
framework
::
proto
::
VarType
::
Type
>
(
dtype
),
VisitDataArgMinMaxFunctor
<
Context
,
T
,
EnumArgMinMaxValue
>
(
VisitDataArgMinMaxFunctor
<
Context
,
T
,
EnumArgMinMaxValue
>
(
dev_ctx
,
x
,
axis
,
keepdims
,
flatten
,
out
));
dev_ctx
,
x
,
axis
.
to
<
int64_t
>
()
,
keepdims
,
flatten
,
out
));
}
}
template
<
typename
T
,
typename
Context
>
template
<
typename
T
,
typename
Context
>
void
ArgMinKernel
(
const
Context
&
dev_ctx
,
void
ArgMinKernel
(
const
Context
&
dev_ctx
,
const
DenseTensor
&
x
,
const
DenseTensor
&
x
,
int64_t
axis
,
const
Scalar
&
axis
,
bool
keepdims
,
bool
keepdims
,
bool
flatten
,
bool
flatten
,
int
dtype
,
int
dtype
,
...
@@ -169,7 +169,7 @@ void ArgMinKernel(const Context& dev_ctx,
...
@@ -169,7 +169,7 @@ void ArgMinKernel(const Context& dev_ctx,
template
<
typename
T
,
typename
Context
>
template
<
typename
T
,
typename
Context
>
void
ArgMaxKernel
(
const
Context
&
dev_ctx
,
void
ArgMaxKernel
(
const
Context
&
dev_ctx
,
const
DenseTensor
&
x
,
const
DenseTensor
&
x
,
int64_t
axis
,
const
Scalar
&
axis
,
bool
keepdims
,
bool
keepdims
,
bool
flatten
,
bool
flatten
,
int
dtype
,
int
dtype
,
...
...
paddle/phi/kernels/gpu/arg_min_max_kernel.cu
浏览文件 @
6fc15986
...
@@ -203,7 +203,7 @@ struct VisitDataCudaArgMinMaxFunctor {
...
@@ -203,7 +203,7 @@ struct VisitDataCudaArgMinMaxFunctor {
template
<
typename
Context
,
typename
T
,
class
Reducer
>
template
<
typename
Context
,
typename
T
,
class
Reducer
>
void
ArgMinMaxOpCUDAKernel
(
const
Context
&
dev_ctx
,
void
ArgMinMaxOpCUDAKernel
(
const
Context
&
dev_ctx
,
const
DenseTensor
&
x
,
const
DenseTensor
&
x
,
int64_t
axis
,
const
Scalar
&
axis
,
bool
keepdims
,
bool
keepdims
,
bool
flatten
,
bool
flatten
,
int
dtype
,
int
dtype
,
...
@@ -213,19 +213,19 @@ void ArgMinMaxOpCUDAKernel(const Context& dev_ctx,
...
@@ -213,19 +213,19 @@ void ArgMinMaxOpCUDAKernel(const Context& dev_ctx,
static_cast
<
paddle
::
framework
::
proto
::
VarType
::
Type
>
(
static_cast
<
paddle
::
framework
::
proto
::
VarType
::
Type
>
(
paddle
::
framework
::
proto
::
VarType
::
INT64
),
paddle
::
framework
::
proto
::
VarType
::
INT64
),
VisitDataCudaArgMinMaxFunctor
<
Context
,
T
,
Reducer
>
(
VisitDataCudaArgMinMaxFunctor
<
Context
,
T
,
Reducer
>
(
dev_ctx
,
x
,
axis
,
keepdims
,
flatten
,
out
));
dev_ctx
,
x
,
axis
.
to
<
int64_t
>
()
,
keepdims
,
flatten
,
out
));
return
;
return
;
}
}
paddle
::
framework
::
VisitDataTypeTiny
(
paddle
::
framework
::
VisitDataTypeTiny
(
static_cast
<
paddle
::
framework
::
proto
::
VarType
::
Type
>
(
dtype
),
static_cast
<
paddle
::
framework
::
proto
::
VarType
::
Type
>
(
dtype
),
VisitDataCudaArgMinMaxFunctor
<
Context
,
T
,
Reducer
>
(
VisitDataCudaArgMinMaxFunctor
<
Context
,
T
,
Reducer
>
(
dev_ctx
,
x
,
axis
,
keepdims
,
flatten
,
out
));
dev_ctx
,
x
,
axis
.
to
<
int64_t
>
()
,
keepdims
,
flatten
,
out
));
}
}
template
<
typename
T
,
typename
Context
>
template
<
typename
T
,
typename
Context
>
void
ArgMinKernel
(
const
Context
&
dev_ctx
,
void
ArgMinKernel
(
const
Context
&
dev_ctx
,
const
DenseTensor
&
x
,
const
DenseTensor
&
x
,
int64_t
axis
,
const
Scalar
&
axis
,
bool
keepdims
,
bool
keepdims
,
bool
flatten
,
bool
flatten
,
int
dtype
,
int
dtype
,
...
@@ -237,7 +237,7 @@ void ArgMinKernel(const Context& dev_ctx,
...
@@ -237,7 +237,7 @@ void ArgMinKernel(const Context& dev_ctx,
template
<
typename
T
,
typename
Context
>
template
<
typename
T
,
typename
Context
>
void
ArgMaxKernel
(
const
Context
&
dev_ctx
,
void
ArgMaxKernel
(
const
Context
&
dev_ctx
,
const
DenseTensor
&
x
,
const
DenseTensor
&
x
,
int64_t
axis
,
const
Scalar
&
axis
,
bool
keepdims
,
bool
keepdims
,
bool
flatten
,
bool
flatten
,
int
dtype
,
int
dtype
,
...
...
python/paddle/fluid/tests/unittests/test_arg_min_max_op.py
浏览文件 @
6fc15986
...
@@ -14,6 +14,7 @@
...
@@ -14,6 +14,7 @@
from
__future__
import
print_function
from
__future__
import
print_function
import
os
import
unittest
import
unittest
import
numpy
as
np
import
numpy
as
np
from
op_test
import
OpTest
from
op_test
import
OpTest
...
@@ -21,6 +22,7 @@ import paddle
...
@@ -21,6 +22,7 @@ import paddle
import
paddle.fluid
as
fluid
import
paddle.fluid
as
fluid
import
paddle.fluid.core
as
core
import
paddle.fluid.core
as
core
from
paddle.fluid
import
Program
,
program_guard
from
paddle.fluid
import
Program
,
program_guard
from
test_attribute_var
import
UnittestBase
class
BaseTestCase
(
OpTest
):
class
BaseTestCase
(
OpTest
):
...
@@ -235,6 +237,92 @@ class BaseTestComplex2_2(OpTest):
...
@@ -235,6 +237,92 @@ class BaseTestComplex2_2(OpTest):
}
}
class
TestArgMaxTensorAxis
(
UnittestBase
):
def
init_info
(
self
):
self
.
shapes
=
[[
2
,
3
,
4
]]
self
.
x
=
[
np
.
random
.
randn
(
*
shape
)
for
shape
in
self
.
shapes
]
self
.
save_path
=
os
.
path
.
join
(
self
.
temp_dir
.
name
,
self
.
path_prefix
())
def
test_static
(
self
):
main_prog
=
Program
()
starup_prog
=
Program
()
with
program_guard
(
main_prog
,
starup_prog
):
fc
=
paddle
.
nn
.
Linear
(
4
,
10
)
x
=
paddle
.
randn
([
2
,
3
,
4
])
x
.
stop_gradient
=
False
feat
=
fc
(
x
)
out
=
self
.
call_func
(
feat
)
sgd
=
paddle
.
optimizer
.
SGD
()
sgd
.
minimize
(
paddle
.
mean
(
paddle
.
cast
(
out
,
'float32'
)))
self
.
assertTrue
(
self
.
var_prefix
()
in
str
(
main_prog
))
exe
=
paddle
.
static
.
Executor
()
exe
.
run
(
starup_prog
)
res
=
exe
.
run
(
fetch_list
=
[
feat
,
out
])
paddle
.
static
.
save_inference_model
(
self
.
save_path
,
[
x
],
[
feat
,
out
],
exe
)
gt
=
np
.
argmax
(
res
[
0
],
0
)
np
.
testing
.
assert_allclose
(
res
[
1
],
gt
)
# Test for Inference Predictor
infer_outs
=
self
.
infer_prog
()
gt
=
np
.
argmax
(
infer_outs
[
0
],
0
)
np
.
testing
.
assert_allclose
(
infer_outs
[
1
],
gt
)
def
path_prefix
(
self
):
return
'argmax_tensor_axis'
def
var_prefix
(
self
):
return
"Var["
def
call_func
(
self
,
x
):
axis
=
paddle
.
assign
(
0
)
out
=
paddle
.
argmax
(
x
,
axis
)
return
out
class
TestArgMinTensorAxis
(
TestArgMaxTensorAxis
):
def
test_static
(
self
):
main_prog
=
Program
()
starup_prog
=
Program
()
with
program_guard
(
main_prog
,
starup_prog
):
fc
=
paddle
.
nn
.
Linear
(
4
,
10
)
x
=
paddle
.
randn
([
2
,
3
,
4
])
x
.
stop_gradient
=
False
feat
=
fc
(
x
)
feat
=
paddle
.
cast
(
feat
,
'int32'
)
out
=
self
.
call_func
(
feat
)
sgd
=
paddle
.
optimizer
.
SGD
()
sgd
.
minimize
(
paddle
.
mean
(
paddle
.
cast
(
out
,
'float32'
)))
self
.
assertTrue
(
self
.
var_prefix
()
in
str
(
main_prog
))
exe
=
paddle
.
static
.
Executor
()
exe
.
run
(
starup_prog
)
res
=
exe
.
run
(
fetch_list
=
[
feat
,
out
])
paddle
.
static
.
save_inference_model
(
self
.
save_path
,
[
x
],
[
feat
,
out
],
exe
)
gt
=
np
.
argmin
(
res
[
0
],
1
)
np
.
testing
.
assert_allclose
(
np
.
squeeze
(
res
[
1
]),
gt
)
# Test for Inference Predictor
infer_outs
=
self
.
infer_prog
()
gt
=
np
.
argmin
(
infer_outs
[
0
],
1
)
np
.
testing
.
assert_allclose
(
np
.
squeeze
(
infer_outs
[
1
]),
gt
)
def
path_prefix
(
self
):
return
'argmin_tensor_axis'
def
call_func
(
self
,
x
):
axis
=
paddle
.
assign
(
1
)
out
=
paddle
.
argmin
(
x
,
axis
,
keepdim
=
True
)
return
out
if
__name__
==
'__main__'
:
if
__name__
==
'__main__'
:
paddle
.
enable_static
()
paddle
.
enable_static
()
unittest
.
main
()
unittest
.
main
()
python/paddle/tensor/search.py
浏览文件 @
6fc15986
...
@@ -162,9 +162,9 @@ def argmax(x, axis=None, keepdim=False, dtype="int64", name=None):
...
@@ -162,9 +162,9 @@ def argmax(x, axis=None, keepdim=False, dtype="int64", name=None):
print(out4)
print(out4)
# [[2, 2, 0, 1]]
# [[2, 2, 0, 1]]
"""
"""
if
axis
is
not
None
and
not
isinstance
(
axis
,
int
):
if
axis
is
not
None
and
not
isinstance
(
axis
,
(
int
,
Variable
)
):
raise
TypeError
(
raise
TypeError
(
"The type of 'axis' must be int or None in argmax, but received %s."
"The type of 'axis' must be int or
Tensor or
None in argmax, but received %s."
%
(
type
(
axis
)))
%
(
type
(
axis
)))
if
dtype
is
None
:
if
dtype
is
None
:
...
@@ -244,9 +244,9 @@ def argmin(x, axis=None, keepdim=False, dtype="int64", name=None):
...
@@ -244,9 +244,9 @@ def argmin(x, axis=None, keepdim=False, dtype="int64", name=None):
print(out4)
print(out4)
# [[1, 1, 1, 2]]
# [[1, 1, 1, 2]]
"""
"""
if
axis
is
not
None
and
not
isinstance
(
axis
,
int
):
if
axis
is
not
None
and
not
isinstance
(
axis
,
(
int
,
Variable
)
):
raise
TypeError
(
raise
TypeError
(
"The type of 'axis' must be int or None in argmin, but received %s."
"The type of 'axis' must be int or
Tensor or
None in argmin, but received %s."
%
(
type
(
axis
)))
%
(
type
(
axis
)))
if
dtype
is
None
:
if
dtype
is
None
:
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录