Unverified commit ab8c33b1, authored by H hong and committed by GitHub

add final state python api (#41252)

Parent 99029dc9
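Every hunk below applies the same migration: each Python API first tries the new final-state eager op (`in_dygraph_mode()` → `_C_ops.final_state_<op>` with positional attributes), then the legacy dygraph op (`_in_legacy_dygraph()` → `_C_ops.<op>` with `'attr_name', value` string pairs), and otherwise falls through to the static-graph path. A minimal sketch of the pattern, modeled on the `increment` hunk in this diff; the `LayerHelper` tail is the usual static-graph boilerplate assumed for illustration, not copied from the diff:

```python
from paddle import _C_ops
from paddle.fluid.framework import in_dygraph_mode, _in_legacy_dygraph
from paddle.fluid.layer_helper import LayerHelper

def increment(x, value=1.0, name=None):
    # New eager mode: the final-state op takes its attributes positionally.
    if in_dygraph_mode():
        return _C_ops.final_state_increment(x, value)
    # Legacy dygraph mode: the old op takes ('attr_name', value) string pairs.
    if _in_legacy_dygraph():
        return _C_ops.increment(x, 'step', value)
    # Static graph: append the op through LayerHelper, unchanged by this commit.
    helper = LayerHelper('increment', **locals())
    out = helper.create_variable_for_type_inference(dtype=x.dtype)
    helper.append_op(
        type='increment', inputs={'X': [x]}, outputs={'Out': [out]},
        attrs={'step': float(value)})
    return out
```

Where a final-state op returns multiple outputs (e.g. `segment_pool`'s `(out, summed_ids)` pair), the eager branch indexes the tuple, as the segment hunks below do.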
@@ -18,7 +18,7 @@ from ..wrapped_decorator import signature_safe_contextmanager
from .layer_function_generator import autodoc, templatedoc
from .tensor import assign, cast, fill_constant
from .. import core
from ..framework import Program, Variable, Operator, _non_static_mode, static_only
from ..framework import Program, Variable, Operator, _non_static_mode, static_only, _in_legacy_dygraph, in_dygraph_mode
from ..layer_helper import LayerHelper, unique_name
from .nn import logical_and, logical_not, logical_or
from .utils import assert_same_structure, map_structure, hold_mutable_vars, copy_mutable_vars
@@ -3867,7 +3867,9 @@ def is_empty(x, name=None):
# - data: [0])
"""
if _non_static_mode():
if in_dygraph_mode():
return _C_ops.final_state_is_empty(x)
if _in_legacy_dygraph():
return _C_ops.is_empty(x)
check_variable_and_dtype(x, 'x', ['float32', 'float64', 'int32', 'int64'],
......
@@ -21,7 +21,7 @@ from paddle.utils import deprecated
from . import nn
from .layer_function_generator import templatedoc
from ..layer_helper import LayerHelper
from ..framework import Variable, _non_static_mode, static_only
from ..framework import Variable, _non_static_mode, static_only, _in_legacy_dygraph
from .. import core
from ..data_feeder import check_variable_and_dtype, check_type
from ..param_attr import ParamAttr
......
@@ -20,7 +20,7 @@ from __future__ import print_function
import warnings
from ..layer_helper import LayerHelper
from ..initializer import Normal, Constant
from ..framework import Variable, _non_static_mode, _varbase_creator
from ..framework import Variable, _non_static_mode, _varbase_creator, _in_legacy_dygraph, in_dygraph_mode
from .. import core
from ..param_attr import ParamAttr
from . import nn
......
@@ -11674,8 +11674,12 @@ def size(input):
rank = layers.size(input) # 300
"""
if _non_static_mode():
if in_dygraph_mode():
return _C_ops.final_state_size(input)
if _in_legacy_dygraph():
return _C_ops.size(input)
check_variable_and_dtype(
input, 'input',
['bool', 'float16', 'float32', 'float64', 'int32', 'int64'], "size")
@@ -13432,6 +13436,9 @@ def log_loss(input, label, epsilon=1e-4, name=None):
prob = paddle.randn((10,1))
cost = F.log_loss(input=prob, label=label)
"""
if in_dygraph_mode():
return _C_ops.final_state_log_loss(input, label, epsilon)
helper = LayerHelper('log_loss', **locals())
check_variable_and_dtype(input, 'input', ['float32'], 'log_loss')
check_variable_and_dtype(label, 'label', ['float32'], 'log_loss')
@@ -14447,7 +14454,10 @@ def where(condition):
out = layers.where(condition) # [[]]
"""
if _non_static_mode():
if in_dygraph_mode():
return _C_ops.final_state_where_index(condition)
if _in_legacy_dygraph():
return _C_ops.where_index(condition)
helper = LayerHelper("where_index", **locals())
@@ -14940,6 +14950,10 @@ def unfold(x, kernel_sizes, strides=1, paddings=0, dilations=1, name=None):
"Unexpected type of paddings, it should be either an integer or a list"
"of 2 or 4 integers")
if in_dygraph_mode():
return _C_ops.final_state_unfold(x, kernel_sizes, strides, paddings,
dilations)
out = helper.create_variable_for_type_inference(dtype=x.dtype)
helper.append_op(
type="unfold",
@@ -15167,6 +15181,10 @@ def shard_index(input, index_num, nshards, shard_id, ignore_value=-1):
print(shard_label)
# [[-1], [1]]
"""
if in_dygraph_mode():
return _C_ops.final_state_shard_index(input, index_num, nshards,
shard_id, ignore_value)
check_variable_and_dtype(input, 'input', ['int64', 'int32'], 'shard_index')
op_type = 'shard_index'
helper = LayerHelper(op_type, **locals())
......
@@ -16,9 +16,10 @@ from __future__ import print_function
import os
from .layer_function_generator import generate_layer_fn, generate_activation_fn, generate_inplace_fn, add_sample_code
from .. import core
from ..framework import convert_np_dtype_to_dtype_, Variable
from ..framework import convert_np_dtype_to_dtype_, Variable, in_dygraph_mode
from ..data_feeder import convert_dtype, check_variable_and_dtype, check_type, check_dtype
from paddle.utils import deprecated
from paddle import _C_ops
__deprecated_func_name__ = {
'tanh_shrink': 'tanhshrink',
@@ -794,6 +795,9 @@ _erf_ = generate_layer_fn('erf')
def erf(x, name=None):
if in_dygraph_mode():
return _C_ops.final_state_erf(x)
locals_var = locals().copy()
kwargs = dict()
for name, val in locals_var.items():
......
@@ -15,6 +15,7 @@
from paddle.fluid.layer_helper import LayerHelper, _non_static_mode
from paddle.fluid.data_feeder import check_variable_and_dtype
from paddle import _C_ops
from paddle.fluid.framework import _in_legacy_dygraph, in_dygraph_mode
__all__ = []
@@ -50,7 +51,9 @@ def segment_sum(data, segment_ids, name=None):
#Outputs: [[4., 4., 4.], [4., 5., 6.]]
"""
if _non_static_mode():
if in_dygraph_mode():
return _C_ops.final_state_segment_pool(data, segment_idsm, "SUM")[0]
if _in_legacy_dygraph():
out, tmp = _C_ops.segment_pool(data, segment_ids, 'pooltype', "SUM")
return out
@@ -104,6 +107,9 @@ def segment_mean(data, segment_ids, name=None):
#Outputs: [[2., 2., 2.], [4., 5., 6.]]
"""
if in_dygraph_mode():
return _C_ops.final_state_segment_pool(data, segment_idsm, "MEAN")[0]
if _non_static_mode():
out, tmp = _C_ops.segment_pool(data, segment_ids, 'pooltype', "MEAN")
return out
@@ -157,6 +163,10 @@ def segment_min(data, segment_ids, name=None):
#Outputs: [[1., 2., 1.], [4., 5., 6.]]
"""
if in_dygraph_mode():
return _C_ops.final_state_segment_pool(data, segment_idsm, "MIN")[0]
if _non_static_mode():
out, tmp = _C_ops.segment_pool(data, segment_ids, 'pooltype', "MIN")
return out
@@ -210,6 +220,11 @@ def segment_max(data, segment_ids, name=None):
#Outputs: [[3., 2., 3.], [4., 5., 6.]]
"""
if in_dygraph_mode():
        return _C_ops.final_state_segment_pool(data, segment_ids, "MAX")[0]
if _non_static_mode():
out, tmp = _C_ops.segment_pool(data, segment_ids, 'pooltype', "MAX")
return out
......
@@ -22,7 +22,7 @@ import numpy as np
from ..fluid.data_feeder import check_variable_and_dtype
from ..fluid.layer_helper import LayerHelper
from ..fluid.framework import core, _varbase_creator, _non_static_mode
from ..fluid.framework import core, _varbase_creator, _non_static_mode, _in_legacy_dygraph
import paddle
from paddle import _C_ops
@@ -800,6 +800,7 @@ def accuracy(input, label, k=1, correct=None, total=None, name=None):
topk_out, topk_indices = paddle.topk(input, k=k)
_acc, _, _ = _C_ops.accuracy(topk_out, topk_indices, label, correct,
total)
return _acc
helper = LayerHelper("accuracy", **locals())
......
@@ -784,7 +784,9 @@ def selu(x,
raise ValueError(
"The alpha must be no less than zero. Received: {}.".format(alpha))
if in_dynamic_mode():
if in_dygraph_mode():
return _C_ops.final_state_selu(x, scale, alpha)
if _in_legacy_dygraph():
return _C_ops.selu(x, 'scale', scale, 'alpha', alpha)
check_variable_and_dtype(x, 'x', ['float16', 'float32', 'float64'], 'selu')
@@ -955,7 +957,12 @@ def softmax(x, axis=-1, dtype=None, name=None):
dtype = convert_np_dtype_to_dtype_(dtype)
use_cudnn = True
if in_dynamic_mode():
if in_dygraph_mode():
outs_cast = x if dtype is None \
else _C_ops.cast(x, 'in_dtype', x.dtype, 'out_dtype', dtype)
return _C_ops.final_state_softmax(outs_cast, axis)
if _in_legacy_dygraph():
outs_cast = x if dtype is None \
else _C_ops.cast(x, 'in_dtype', x.dtype, 'out_dtype', dtype)
return _C_ops.softmax(outs_cast, 'axis', axis, 'use_cudnn', use_cudnn)
......
@@ -1212,8 +1212,12 @@ def cholesky(x, upper=False, name=None):
# [1.25450498 0.05600871 0.06400121]]
"""
if paddle.in_dynamic_mode():
if in_dygraph_mode():
return _C_ops.final_state_cholesky(x, upper)
if _in_legacy_dygraph():
return _C_ops.cholesky(x, "upper", upper)
check_variable_and_dtype(x, 'dtype', ['float32', 'float64'], 'cholesky')
check_type(upper, 'upper', bool, 'cholesky')
helper = LayerHelper('cholesky', **locals())
@@ -1447,7 +1451,10 @@ def bincount(x, weights=None, minlength=0, name=None):
if x.dtype not in [paddle.int32, paddle.int64]:
raise TypeError("Elements in Input(x) should all be integers")
if paddle.in_dynamic_mode():
# if in_dygraph_mode():
# return _C_ops.final_state_bincount(x, weights, minlength)
if _in_legacy_dygraph():
return _C_ops.bincount(x, weights, "minlength", minlength)
helper = LayerHelper('bincount', **locals())
@@ -1761,7 +1768,10 @@ def matrix_power(x, n, name=None):
# [-7.66666667 , 8. , -1.83333333 ],
# [ 1.80555556 , -1.91666667 , 0.44444444 ]]
"""
if paddle.in_dynamic_mode():
if in_dygraph_mode():
return _C_ops.final_state_matrix_power(x, n)
if _in_legacy_dygraph():
return _C_ops.matrix_power(x, "n", n)
check_variable_and_dtype(x, 'dtype', ['float32', 'float64'], 'matrix_power')
@@ -2279,7 +2289,10 @@ def eigh(x, UPLO='L', name=None):
#[ 0.3826834323650898j , -0.9238795325112867j ]]
"""
if paddle.in_dynamic_mode():
if in_dygraph_mode():
return _C_ops.final_state_eigh(x, UPLO)
if _in_legacy_dygraph():
return _C_ops.eigh(x, 'UPLO', UPLO)
def __check_input(x, UPLO):
@@ -2749,7 +2762,10 @@ def cholesky_solve(x, y, upper=False, name=None):
print(out)
# [-2.5, -7, 9.5]
"""
if paddle.in_dynamic_mode():
if in_dygraph_mode():
return _C_ops.final_state_cholesky_solve(x, y, upper)
if _in_legacy_dygraph():
return _C_ops.cholesky_solve(x, y, 'upper', upper)
helper = LayerHelper("cholesky_solve", **locals())
......
@@ -1762,8 +1762,12 @@ def tile(x, repeat_times, name=None):
np_out = out.numpy()
# [[1, 2, 3, 1, 2, 3]]
"""
if paddle.in_dynamic_mode():
if in_dygraph_mode():
return _C_ops.final_state_tile(x, repeat_times)
if _in_legacy_dygraph():
return _C_ops.tile(x, 'repeat_times', repeat_times)
check_type(repeat_times, 'repeat_times', (list, tuple, Variable), 'tile')
if isinstance(repeat_times, Variable):
assert len(repeat_times.shape) == 1, (
@@ -2833,12 +2837,14 @@ def take_along_axis(arr, indices, axis):
if not broadcast_shape:
# if indices matrix have larger size than arr, arr should broadcast into indices shape.
broadcast_shape = indices.shape
if paddle.in_dynamic_mode():
if _non_static_mode():
indices = paddle.broadcast_to(indices, broadcast_shape)
broadcast_shape_list = list(broadcast_shape)
broadcast_shape_list[axis] = list(arr.shape)[axis]
broadcast_shape = tuple(broadcast_shape_list)
arr = paddle.broadcast_to(arr, broadcast_shape)
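        # inside _non_static_mode(), "not legacy dygraph" means the new eager
        # mode, so dispatch to the final-state op; legacy dygraph falls through.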
if not _in_legacy_dygraph():
return _C_ops.final_state_take_along_axis(arr, indices, axis)
return _C_ops.take_along_axis(arr, indices, 'Axis', axis)
check_variable_and_dtype(
arr, 'x', ['float16', 'float32', 'float64', 'int32', 'int64', 'uint8'],
@@ -2898,12 +2904,15 @@ def put_along_axis(arr, indices, values, axis, reduce='assign'):
"`indices` and `arr` must have the same number of dimensions!")
axis = non_negative_axis(arr, axis)
broadcast_shape = infer_broadcast_shape(arr, indices, axis)
if paddle.in_dynamic_mode():
if _non_static_mode():
values = paddle.to_tensor(values) if not isinstance(
values, paddle.Tensor) else values
if broadcast_shape:
indices = paddle.broadcast_to(indices, broadcast_shape)
values = paddle.broadcast_to(values, indices.shape)
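        # new eager mode uses the final-state op; legacy dygraph falls through
        # to the ('Axis', axis, 'Reduce', reduce) attribute-pair call below.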
if in_dygraph_mode():
return _C_ops.final_state_put_along_axis(arr, indices, values, axis,
reduce)
return _C_ops.put_along_axis(arr, indices, values, "Axis", axis,
"Reduce", reduce)
......
@@ -2374,7 +2374,10 @@ def trace(x, offset=0, axis1=0, axis2=1, name=None):
"But received axis1 = %d, axis2 = %d\n"%(axis1, axis2)
__check_input(input, offset, axis1, axis2)
if paddle.in_dynamic_mode():
if in_dygraph_mode():
        return _C_ops.final_state_trace(x, offset, axis1, axis2)
if _in_legacy_dygraph():
return _C_ops.trace(x, 'offset', offset, 'axis1', axis1, 'axis2', axis2)
inputs = {'Input': [x]}
@@ -2597,7 +2600,9 @@ def cumsum(x, axis=None, dtype=None, name=None):
if dtype is not None and x.dtype != convert_np_dtype_to_dtype_(dtype):
x = cast(x, dtype)
if paddle.in_dynamic_mode():
if in_dygraph_mode():
return _C_ops.final_state_cumsum(x, axis, flatten, False, False)
if _in_legacy_dygraph():
if axis is None:
return _C_ops.cumsum(x, 'flatten', flatten)
else:
@@ -2854,7 +2859,10 @@ def sign(x, name=None):
out = paddle.sign(x=x)
print(out) # [1.0, 0.0, -1.0, 1.0]
"""
if paddle.in_dynamic_mode():
if in_dygraph_mode():
return _C_ops.final_state_sign(x)
if _in_legacy_dygraph():
return _C_ops.sign(x)
check_variable_and_dtype(x, 'x', ['float16', 'float32', 'float64'], 'sign')
@@ -2891,7 +2899,10 @@ def tanh(x, name=None):
print(out)
# [-0.37994896 -0.19737532 0.09966799 0.29131261]
"""
if paddle.in_dynamic_mode():
if in_dygraph_mode():
        return _C_ops.final_state_tanh(x)
if _in_legacy_dygraph():
return _C_ops.tanh(x)
check_variable_and_dtype(x, 'x', ['float16', 'float32', 'float64'], 'tanh')
@@ -2933,7 +2944,10 @@ def increment(x, value=1.0, name=None):
# [1.]
"""
if paddle.in_dynamic_mode():
if in_dygraph_mode():
        return _C_ops.final_state_increment(x, value)
if _in_legacy_dygraph():
return _C_ops.increment(x, 'step', value)
check_variable_and_dtype(x, 'x', ['float32', 'float64', 'int32', 'int64'],
......
@@ -22,6 +22,7 @@ from ..fluid.layers import utils
import paddle
from paddle import _C_ops
from paddle.static import Variable
from paddle.fluid.framework import in_dygraph_mode, _in_legacy_dygraph
__all__ = []
@@ -66,7 +67,10 @@ def bernoulli(x, name=None):
"""
if paddle.in_dynamic_mode():
if in_dygraph_mode():
return _C_ops.final_state_bernoulli(x)
if _in_legacy_dygraph():
return _C_ops.bernoulli(x)
check_variable_and_dtype(x, "x", ["float32", "float64"], "bernoulli")
@@ -174,7 +178,10 @@ def multinomial(x, num_samples=1, replacement=False, name=None):
assert core.is_compiled_with_rocm() == False, (
"multinomial op is not supported on ROCM yet.")
if paddle.in_dynamic_mode():
if in_dygraph_mode():
return _C_ops.final_state_multinomial(x, num_samples, replacement)
if _in_legacy_dygraph():
return _C_ops.multinomial(x, 'num_samples', num_samples, 'replacement',
replacement)
......
@@ -91,7 +91,11 @@ def argsort(x, axis=-1, descending=False, name=None):
# [1 1 0 2]
# [0 2 1 1]]]
"""
if paddle.in_dynamic_mode():
if in_dygraph_mode():
        _, ids = _C_ops.final_state_argsort(x, axis, descending)
return ids
if _in_legacy_dygraph():
_, ids = _C_ops.argsort(x, 'axis', axis, 'descending', descending)
return ids
check_variable_and_dtype(
@@ -171,7 +175,9 @@ def argmax(x, axis=None, keepdim=False, dtype="int64", name=None):
flatten = True
axis = 0
if paddle.in_dynamic_mode():
if in_dygraph_mode():
return _C_ops.final_state_argmax(x, axis, keepdim, flatten, var_dtype)
if _in_legacy_dygraph():
out = _C_ops.arg_max(x, 'axis', axis, 'dtype', var_dtype, 'keepdims',
keepdim, 'flatten', flatten)
return out
@@ -251,7 +257,9 @@ def argmin(x, axis=None, keepdim=False, dtype="int64", name=None):
flatten = True
axis = 0
if paddle.in_dynamic_mode():
if in_dygraph_mode():
return _C_ops.final_state_argmin(x, axis, keepdim, flatten, var_dtype)
if _in_legacy_dygraph():
out = _C_ops.arg_min(x, 'axis', axis, 'dtype', var_dtype, 'keepdims',
keepdim, 'flatten', flatten)
return out
......
@@ -51,7 +51,6 @@ PD_REGISTER_INFER_META_FN({api.kernel['func'][0]}, phi::{api.infer_meta['func']}
wrapped_infermeta_name = get_wrapped_infermeta_name(api.api)
args = []
print("@@@", api.api)
for input_name in api.inputs['names']:
if input_name in kernel_params:
print("type", api.inputs['input_info'])
......