Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
ab8c33b1
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
ab8c33b1
编写于
4月 01, 2022
作者:
H
hong
提交者:
GitHub
4月 01, 2022
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
add final state python api (#41252)
上级
99029dc9
变更
14
隐藏空白更改
内联
并排
Showing
14 changed file
with
130 addition
and
30 deletion
+130
-30
python/paddle/fluid/layers/control_flow.py
python/paddle/fluid/layers/control_flow.py
+4
-2
python/paddle/fluid/layers/loss.py
python/paddle/fluid/layers/loss.py
+1
-1
python/paddle/fluid/layers/metric_op.py
python/paddle/fluid/layers/metric_op.py
+1
-1
python/paddle/fluid/layers/nn.py
python/paddle/fluid/layers/nn.py
+20
-2
python/paddle/fluid/layers/ops.py
python/paddle/fluid/layers/ops.py
+5
-1
python/paddle/incubate/tensor/math.py
python/paddle/incubate/tensor/math.py
+16
-1
python/paddle/metric/metrics.py
python/paddle/metric/metrics.py
+2
-1
python/paddle/nn/functional/activation.py
python/paddle/nn/functional/activation.py
+9
-2
python/paddle/tensor/linalg.py
python/paddle/tensor/linalg.py
+21
-5
python/paddle/tensor/manipulation.py
python/paddle/tensor/manipulation.py
+12
-3
python/paddle/tensor/math.py
python/paddle/tensor/math.py
+19
-5
python/paddle/tensor/random.py
python/paddle/tensor/random.py
+9
-2
python/paddle/tensor/search.py
python/paddle/tensor/search.py
+11
-3
python/paddle/utils/code_gen/wrapped_infermeta_gen.py
python/paddle/utils/code_gen/wrapped_infermeta_gen.py
+0
-1
未找到文件。
python/paddle/fluid/layers/control_flow.py
浏览文件 @
ab8c33b1
...
...
@@ -18,7 +18,7 @@ from ..wrapped_decorator import signature_safe_contextmanager
from
.layer_function_generator
import
autodoc
,
templatedoc
from
.tensor
import
assign
,
cast
,
fill_constant
from
..
import
core
from
..framework
import
Program
,
Variable
,
Operator
,
_non_static_mode
,
static_only
from
..framework
import
Program
,
Variable
,
Operator
,
_non_static_mode
,
static_only
,
_in_legacy_dygraph
,
in_dygraph_mode
from
..layer_helper
import
LayerHelper
,
unique_name
from
.nn
import
logical_and
,
logical_not
,
logical_or
from
.utils
import
assert_same_structure
,
map_structure
,
hold_mutable_vars
,
copy_mutable_vars
...
...
@@ -3867,7 +3867,9 @@ def is_empty(x, name=None):
# - data: [0])
"""
if
_non_static_mode
():
if
in_dygraph_mode
():
return
_C_ops
.
final_state_is_empty
(
x
)
if
_in_legacy_dygraph
():
return
_C_ops
.
is_empty
(
x
)
check_variable_and_dtype
(
x
,
'x'
,
[
'float32'
,
'float64'
,
'int32'
,
'int64'
],
...
...
python/paddle/fluid/layers/loss.py
浏览文件 @
ab8c33b1
...
...
@@ -21,7 +21,7 @@ from paddle.utils import deprecated
from
.
import
nn
from
.layer_function_generator
import
templatedoc
from
..layer_helper
import
LayerHelper
from
..framework
import
Variable
,
_non_static_mode
,
static_only
from
..framework
import
Variable
,
_non_static_mode
,
static_only
,
_in_legacy_dygraph
from
..
import
core
from
..data_feeder
import
check_variable_and_dtype
,
check_type
from
..param_attr
import
ParamAttr
...
...
python/paddle/fluid/layers/metric_op.py
浏览文件 @
ab8c33b1
...
...
@@ -20,7 +20,7 @@ from __future__ import print_function
import
warnings
from
..layer_helper
import
LayerHelper
from
..initializer
import
Normal
,
Constant
from
..framework
import
Variable
,
_non_static_mode
,
_varbase_creator
from
..framework
import
Variable
,
_non_static_mode
,
_varbase_creator
,
_in_legacy_dygraph
,
in_dygraph_mode
from
..
import
core
from
..param_attr
import
ParamAttr
from
.
import
nn
...
...
python/paddle/fluid/layers/nn.py
浏览文件 @
ab8c33b1
...
...
@@ -11674,8 +11674,12 @@ def size(input):
rank = layers.size(input) # 300
"""
if _non_static_mode():
if in_dygraph_mode():
return _C_ops.final_state_size(input)
if _in_legacy_dygraph():
return _C_ops.size(input)
check_variable_and_dtype(
input, 'input',
['bool', 'float16', 'float32', 'float64', 'int32', 'int64'], "size")
...
...
@@ -13432,6 +13436,9 @@ def log_loss(input, label, epsilon=1e-4, name=None):
prob = paddle.randn((10,1))
cost = F.log_loss(input=prob, label=label)
"""
if in_dygraph_mode():
return _C_ops.final_state_log_loss(input, label, epsilon)
helper = LayerHelper('log_loss', **locals())
check_variable_and_dtype(input, 'input', ['float32'], 'log_loss')
check_variable_and_dtype(label, 'label', ['float32'], 'log_loss')
...
...
@@ -14447,7 +14454,10 @@ def where(condition):
out = layers.where(condition) # [[]]
"""
if _non_static_mode():
if in_dygraph_mode():
return _C_ops.final_state_where_index(condition)
if _in_legacy_dygraph():
return _C_ops.where_index(condition)
helper = LayerHelper("where_index", **locals())
...
...
@@ -14940,6 +14950,10 @@ def unfold(x, kernel_sizes, strides=1, paddings=0, dilations=1, name=None):
"Unexpected type of paddings, it should be either an integer or a list"
"of 2 or 4 integers")
if in_dygraph_mode():
return _C_ops.final_state_unfold(x, kernel_sizes, strides, paddings,
dilations)
out = helper.create_variable_for_type_inference(dtype=x.dtype)
helper.append_op(
type="unfold",
...
...
@@ -15167,6 +15181,10 @@ def shard_index(input, index_num, nshards, shard_id, ignore_value=-1):
print(shard_label)
# [[-1], [1]]
"""
if in_dygraph_mode():
return _C_ops.final_state_shard_index(input, index_num, nshards,
shard_id, ignore_value)
check_variable_and_dtype(input, 'input', ['int64', 'int32'], 'shard_index')
op_type = 'shard_index'
helper = LayerHelper(op_type, **locals())
...
...
python/paddle/fluid/layers/ops.py
浏览文件 @
ab8c33b1
...
...
@@ -16,9 +16,10 @@ from __future__ import print_function
import
os
from
.layer_function_generator
import
generate_layer_fn
,
generate_activation_fn
,
generate_inplace_fn
,
add_sample_code
from
..
import
core
from
..framework
import
convert_np_dtype_to_dtype_
,
Variable
from
..framework
import
convert_np_dtype_to_dtype_
,
Variable
,
in_dygraph_mode
from
..data_feeder
import
convert_dtype
,
check_variable_and_dtype
,
check_type
,
check_dtype
from
paddle.utils
import
deprecated
from
paddle
import
_C_ops
__deprecated_func_name__
=
{
'tanh_shrink'
:
'tanhshrink'
,
...
...
@@ -794,6 +795,9 @@ _erf_ = generate_layer_fn('erf')
def
erf
(
x
,
name
=
None
):
if
in_dygraph_mode
():
return
_C_ops
.
final_state_erf
(
x
)
locals_var
=
locals
().
copy
()
kwargs
=
dict
()
for
name
,
val
in
locals_var
.
items
():
...
...
python/paddle/incubate/tensor/math.py
浏览文件 @
ab8c33b1
...
...
@@ -15,6 +15,7 @@
from
paddle.fluid.layer_helper
import
LayerHelper
,
_non_static_mode
from
paddle.fluid.data_feeder
import
check_variable_and_dtype
from
paddle
import
_C_ops
from
paddle.fluid.framework
import
_in_legacy_dygraph
,
in_dygraph_mode
__all__
=
[]
...
...
@@ -50,7 +51,9 @@ def segment_sum(data, segment_ids, name=None):
#Outputs: [[4., 4., 4.], [4., 5., 6.]]
"""
if
_non_static_mode
():
if
in_dygraph_mode
():
return
_C_ops
.
final_state_segment_pool
(
data
,
segment_idsm
,
"SUM"
)[
0
]
if
_in_legacy_dygraph
():
out
,
tmp
=
_C_ops
.
segment_pool
(
data
,
segment_ids
,
'pooltype'
,
"SUM"
)
return
out
...
...
@@ -104,6 +107,9 @@ def segment_mean(data, segment_ids, name=None):
#Outputs: [[2., 2., 2.], [4., 5., 6.]]
"""
if
in_dygraph_mode
():
return
_C_ops
.
final_state_segment_pool
(
data
,
segment_idsm
,
"MEAN"
)[
0
]
if
_non_static_mode
():
out
,
tmp
=
_C_ops
.
segment_pool
(
data
,
segment_ids
,
'pooltype'
,
"MEAN"
)
return
out
...
...
@@ -157,6 +163,10 @@ def segment_min(data, segment_ids, name=None):
#Outputs: [[1., 2., 1.], [4., 5., 6.]]
"""
if
in_dygraph_mode
():
return
_C_ops
.
final_state_segment_pool
(
data
,
segment_idsm
,
"MIN"
)[
0
]
if
_non_static_mode
():
out
,
tmp
=
_C_ops
.
segment_pool
(
data
,
segment_ids
,
'pooltype'
,
"MIN"
)
return
out
...
...
@@ -210,6 +220,11 @@ def segment_max(data, segment_ids, name=None):
#Outputs: [[3., 2., 3.], [4., 5., 6.]]
"""
if
in_dygraph_mode
():
out
,
tmp
=
_C_ops
.
final_state_segment_pool
(
data
,
segment_ids
,
"MAX"
)[
0
]
return
out
if
_non_static_mode
():
out
,
tmp
=
_C_ops
.
segment_pool
(
data
,
segment_ids
,
'pooltype'
,
"MAX"
)
return
out
...
...
python/paddle/metric/metrics.py
浏览文件 @
ab8c33b1
...
...
@@ -22,7 +22,7 @@ import numpy as np
from
..fluid.data_feeder
import
check_variable_and_dtype
from
..fluid.layer_helper
import
LayerHelper
from
..fluid.framework
import
core
,
_varbase_creator
,
_non_static_mode
from
..fluid.framework
import
core
,
_varbase_creator
,
_non_static_mode
,
_in_legacy_dygraph
import
paddle
from
paddle
import
_C_ops
...
...
@@ -800,6 +800,7 @@ def accuracy(input, label, k=1, correct=None, total=None, name=None):
topk_out
,
topk_indices
=
paddle
.
topk
(
input
,
k
=
k
)
_acc
,
_
,
_
=
_C_ops
.
accuracy
(
topk_out
,
topk_indices
,
label
,
correct
,
total
)
return
_acc
helper
=
LayerHelper
(
"accuracy"
,
**
locals
())
...
...
python/paddle/nn/functional/activation.py
浏览文件 @
ab8c33b1
...
...
@@ -784,7 +784,9 @@ def selu(x,
raise
ValueError
(
"The alpha must be no less than zero. Received: {}."
.
format
(
alpha
))
if
in_dynamic_mode
():
if
in_dygraph_mode
():
return
_C_ops
.
final_state_selu
(
x
,
scale
,
alpha
)
if
_in_legacy_dygraph
():
return
_C_ops
.
selu
(
x
,
'scale'
,
scale
,
'alpha'
,
alpha
)
check_variable_and_dtype
(
x
,
'x'
,
[
'float16'
,
'float32'
,
'float64'
],
'selu'
)
...
...
@@ -955,7 +957,12 @@ def softmax(x, axis=-1, dtype=None, name=None):
dtype
=
convert_np_dtype_to_dtype_
(
dtype
)
use_cudnn
=
True
if
in_dynamic_mode
():
if
in_dygraph_mode
():
outs_cast
=
x
if
dtype
is
None
\
else
_C_ops
.
cast
(
x
,
'in_dtype'
,
x
.
dtype
,
'out_dtype'
,
dtype
)
return
_C_ops
.
final_state_softmax
(
outs_cast
,
axis
)
if
_in_legacy_dygraph
():
outs_cast
=
x
if
dtype
is
None
\
else
_C_ops
.
cast
(
x
,
'in_dtype'
,
x
.
dtype
,
'out_dtype'
,
dtype
)
return
_C_ops
.
softmax
(
outs_cast
,
'axis'
,
axis
,
'use_cudnn'
,
use_cudnn
)
...
...
python/paddle/tensor/linalg.py
浏览文件 @
ab8c33b1
...
...
@@ -1212,8 +1212,12 @@ def cholesky(x, upper=False, name=None):
# [1.25450498 0.05600871 0.06400121]]
"""
if
paddle
.
in_dynamic_mode
():
if
in_dygraph_mode
():
return
_C_ops
.
final_state_cholesky
(
x
,
upper
)
if
_in_legacy_dygraph
():
return
_C_ops
.
cholesky
(
x
,
"upper"
,
upper
)
check_variable_and_dtype
(
x
,
'dtype'
,
[
'float32'
,
'float64'
],
'cholesky'
)
check_type
(
upper
,
'upper'
,
bool
,
'cholesky'
)
helper
=
LayerHelper
(
'cholesky'
,
**
locals
())
...
...
@@ -1447,7 +1451,10 @@ def bincount(x, weights=None, minlength=0, name=None):
if
x
.
dtype
not
in
[
paddle
.
int32
,
paddle
.
int64
]:
raise
TypeError
(
"Elements in Input(x) should all be integers"
)
if
paddle
.
in_dynamic_mode
():
# if in_dygraph_mode():
# return _C_ops.final_state_bincount(x, weights, minlength)
if
_in_legacy_dygraph
():
return
_C_ops
.
bincount
(
x
,
weights
,
"minlength"
,
minlength
)
helper
=
LayerHelper
(
'bincount'
,
**
locals
())
...
...
@@ -1761,7 +1768,10 @@ def matrix_power(x, n, name=None):
# [-7.66666667 , 8. , -1.83333333 ],
# [ 1.80555556 , -1.91666667 , 0.44444444 ]]
"""
if
paddle
.
in_dynamic_mode
():
if
in_dygraph_mode
():
return
_C_ops
.
final_state_matrix_power
(
x
,
n
)
if
_in_legacy_dygraph
():
return
_C_ops
.
matrix_power
(
x
,
"n"
,
n
)
check_variable_and_dtype
(
x
,
'dtype'
,
[
'float32'
,
'float64'
],
'matrix_power'
)
...
...
@@ -2279,7 +2289,10 @@ def eigh(x, UPLO='L', name=None):
#[ 0.3826834323650898j , -0.9238795325112867j ]]
"""
if
paddle
.
in_dynamic_mode
():
if
in_dygraph_mode
():
return
_C_ops
.
final_state_eigh
(
x
,
UPLO
)
if
_in_legacy_dygraph
():
return
_C_ops
.
eigh
(
x
,
'UPLO'
,
UPLO
)
def
__check_input
(
x
,
UPLO
):
...
...
@@ -2749,7 +2762,10 @@ def cholesky_solve(x, y, upper=False, name=None):
print(out)
# [-2.5, -7, 9.5]
"""
if
paddle
.
in_dynamic_mode
():
if
in_dygraph_mode
():
return
_C_ops
.
final_state_cholesky_solve
(
x
,
y
,
upper
)
if
_in_legacy_dygraph
():
return
_C_ops
.
cholesky_solve
(
x
,
y
,
'upper'
,
upper
)
helper
=
LayerHelper
(
"cholesky_solve"
,
**
locals
())
...
...
python/paddle/tensor/manipulation.py
浏览文件 @
ab8c33b1
...
...
@@ -1762,8 +1762,12 @@ def tile(x, repeat_times, name=None):
np_out = out.numpy()
# [[1, 2, 3, 1, 2, 3]]
"""
if
paddle
.
in_dynamic_mode
():
if
in_dygraph_mode
():
return
_C_ops
.
final_state_tile
(
x
,
repeat_times
)
if
_in_legacy_dygraph
():
return
_C_ops
.
tile
(
x
,
'repeat_times'
,
repeat_times
)
check_type
(
repeat_times
,
'repeat_times'
,
(
list
,
tuple
,
Variable
),
'tile'
)
if
isinstance
(
repeat_times
,
Variable
):
assert
len
(
repeat_times
.
shape
)
==
1
,
(
...
...
@@ -2833,12 +2837,14 @@ def take_along_axis(arr, indices, axis):
if
not
broadcast_shape
:
# if indices matrix have larger size than arr, arr should broadcast into indices shape.
broadcast_shape
=
indices
.
shape
if
paddle
.
in_dynam
ic_mode
():
if
_non_stat
ic_mode
():
indices
=
paddle
.
broadcast_to
(
indices
,
broadcast_shape
)
broadcast_shape_list
=
list
(
broadcast_shape
)
broadcast_shape_list
[
axis
]
=
list
(
arr
.
shape
)[
axis
]
broadcast_shape
=
tuple
(
broadcast_shape_list
)
arr
=
paddle
.
broadcast_to
(
arr
,
broadcast_shape
)
if
not
_in_legacy_dygraph
():
return
_C_ops
.
final_state_take_along_axis
(
arr
,
indices
,
axis
)
return
_C_ops
.
take_along_axis
(
arr
,
indices
,
'Axis'
,
axis
)
check_variable_and_dtype
(
arr
,
'x'
,
[
'float16'
,
'float32'
,
'float64'
,
'int32'
,
'int64'
,
'uint8'
],
...
...
@@ -2898,12 +2904,15 @@ def put_along_axis(arr, indices, values, axis, reduce='assign'):
"`indices` and `arr` must have the same number of dimensions!"
)
axis
=
non_negative_axis
(
arr
,
axis
)
broadcast_shape
=
infer_broadcast_shape
(
arr
,
indices
,
axis
)
if
paddle
.
in_dynam
ic_mode
():
if
_non_stat
ic_mode
():
values
=
paddle
.
to_tensor
(
values
)
if
not
isinstance
(
values
,
paddle
.
Tensor
)
else
values
if
broadcast_shape
:
indices
=
paddle
.
broadcast_to
(
indices
,
broadcast_shape
)
values
=
paddle
.
broadcast_to
(
values
,
indices
.
shape
)
if
in_dygraph_mode
():
return
_C_ops
.
final_state_put_along_axis
(
arr
,
indices
,
values
,
axis
,
reduce
)
return
_C_ops
.
put_along_axis
(
arr
,
indices
,
values
,
"Axis"
,
axis
,
"Reduce"
,
reduce
)
...
...
python/paddle/tensor/math.py
浏览文件 @
ab8c33b1
...
...
@@ -2374,7 +2374,10 @@ def trace(x, offset=0, axis1=0, axis2=1, name=None):
"But received axis1 = %d, axis2 = %d
\n
"
%
(
axis1
,
axis2
)
__check_input
(
input
,
offset
,
axis1
,
axis2
)
if
paddle
.
in_dynamic_mode
():
if
in_dygraph_mode
():
return
_C_ops
.
final_state_trace
(
x
,
offset
,
axis1
,
axis2
)
if
_in_legacy_dygraph
():
return
_C_ops
.
trace
(
x
,
'offset'
,
offset
,
'axis1'
,
axis1
,
'axis2'
,
axis2
)
inputs
=
{
'Input'
:
[
x
]}
...
...
@@ -2597,7 +2600,9 @@ def cumsum(x, axis=None, dtype=None, name=None):
if
dtype
is
not
None
and
x
.
dtype
!=
convert_np_dtype_to_dtype_
(
dtype
):
x
=
cast
(
x
,
dtype
)
if
paddle
.
in_dynamic_mode
():
if
in_dygraph_mode
():
return
_C_ops
.
final_state_cumsum
(
x
,
axis
,
flatten
,
False
,
False
)
if
_in_legacy_dygraph
():
if
axis
is
None
:
return
_C_ops
.
cumsum
(
x
,
'flatten'
,
flatten
)
else
:
...
...
@@ -2854,7 +2859,10 @@ def sign(x, name=None):
out = paddle.sign(x=x)
print(out) # [1.0, 0.0, -1.0, 1.0]
"""
if
paddle
.
in_dynamic_mode
():
if
in_dygraph_mode
():
return
_C_ops
.
final_state_sign
(
x
)
if
_in_legacy_dygraph
():
return
_C_ops
.
sign
(
x
)
check_variable_and_dtype
(
x
,
'x'
,
[
'float16'
,
'float32'
,
'float64'
],
'sign'
)
...
...
@@ -2891,7 +2899,10 @@ def tanh(x, name=None):
print(out)
# [-0.37994896 -0.19737532 0.09966799 0.29131261]
"""
if
paddle
.
in_dynamic_mode
():
if
in_dygraph_mode
():
return
_C_ops
.
final_state_tanh
(
x
)
if
_in_legacy_dygraph
():
return
_C_ops
.
tanh
(
x
)
check_variable_and_dtype
(
x
,
'x'
,
[
'float16'
,
'float32'
,
'float64'
],
'tanh'
)
...
...
@@ -2933,7 +2944,10 @@ def increment(x, value=1.0, name=None):
# [1.]
"""
if
paddle
.
in_dynamic_mode
():
if
in_dygraph_mode
():
return
_C_ops
.
final_state_increment
(
x
,
value
)
if
_in_legacy_dygraph
():
return
_C_ops
.
increment
(
x
,
'step'
,
value
)
check_variable_and_dtype
(
x
,
'x'
,
[
'float32'
,
'float64'
,
'int32'
,
'int64'
],
...
...
python/paddle/tensor/random.py
浏览文件 @
ab8c33b1
...
...
@@ -22,6 +22,7 @@ from ..fluid.layers import utils
import
paddle
from
paddle
import
_C_ops
from
paddle.static
import
Variable
from
paddle.fluid.framework
import
in_dygraph_mode
,
_in_legacy_dygraph
__all__
=
[]
...
...
@@ -66,7 +67,10 @@ def bernoulli(x, name=None):
"""
if
paddle
.
in_dynamic_mode
():
if
in_dygraph_mode
():
return
_C_ops
.
final_state_bernoulli
(
x
)
if
_in_legacy_dygraph
():
return
_C_ops
.
bernoulli
(
x
)
check_variable_and_dtype
(
x
,
"x"
,
[
"float32"
,
"float64"
],
"bernoulli"
)
...
...
@@ -174,7 +178,10 @@ def multinomial(x, num_samples=1, replacement=False, name=None):
assert
core
.
is_compiled_with_rocm
()
==
False
,
(
"multinomial op is not supported on ROCM yet."
)
if
paddle
.
in_dynamic_mode
():
if
in_dygraph_mode
():
return
_C_ops
.
final_state_multinomial
(
x
,
num_samples
,
replacement
)
if
_in_legacy_dygraph
():
return
_C_ops
.
multinomial
(
x
,
'num_samples'
,
num_samples
,
'replacement'
,
replacement
)
...
...
python/paddle/tensor/search.py
浏览文件 @
ab8c33b1
...
...
@@ -91,7 +91,11 @@ def argsort(x, axis=-1, descending=False, name=None):
# [1 1 0 2]
# [0 2 1 1]]]
"""
if
paddle
.
in_dynamic_mode
():
if
in_dygraph_mode
():
_
,
ids
,
=
_C_ops
.
final_state_argsort
(
x
,
axis
,
descending
)
return
ids
if
_in_legacy_dygraph
():
_
,
ids
=
_C_ops
.
argsort
(
x
,
'axis'
,
axis
,
'descending'
,
descending
)
return
ids
check_variable_and_dtype
(
...
...
@@ -171,7 +175,9 @@ def argmax(x, axis=None, keepdim=False, dtype="int64", name=None):
flatten
=
True
axis
=
0
if
paddle
.
in_dynamic_mode
():
if
in_dygraph_mode
():
return
_C_ops
.
final_state_argmax
(
x
,
axis
,
keepdim
,
flatten
,
var_dtype
)
if
_in_legacy_dygraph
():
out
=
_C_ops
.
arg_max
(
x
,
'axis'
,
axis
,
'dtype'
,
var_dtype
,
'keepdims'
,
keepdim
,
'flatten'
,
flatten
)
return
out
...
...
@@ -251,7 +257,9 @@ def argmin(x, axis=None, keepdim=False, dtype="int64", name=None):
flatten
=
True
axis
=
0
if
paddle
.
in_dynamic_mode
():
if
in_dygraph_mode
():
return
_C_ops
.
final_state_argmin
(
x
,
axis
,
keepdim
,
flatten
,
var_dtype
)
if
_in_legacy_dygraph
():
out
=
_C_ops
.
arg_min
(
x
,
'axis'
,
axis
,
'dtype'
,
var_dtype
,
'keepdims'
,
keepdim
,
'flatten'
,
flatten
)
return
out
...
...
python/paddle/utils/code_gen/wrapped_infermeta_gen.py
浏览文件 @
ab8c33b1
...
...
@@ -51,7 +51,6 @@ PD_REGISTER_INFER_META_FN({api.kernel['func'][0]}, phi::{api.infer_meta['func']}
wrapped_infermeta_name
=
get_wrapped_infermeta_name
(
api
.
api
)
args
=
[]
print
(
"@@@"
,
api
.
api
)
for
input_name
in
api
.
inputs
[
'names'
]:
if
input_name
in
kernel_params
:
print
(
"type"
,
api
.
inputs
[
'input_info'
])
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录