未验证 提交 ab8c33b1 编写于 作者: H hong 提交者: GitHub

add final state python api (#41252)

上级 99029dc9
...@@ -18,7 +18,7 @@ from ..wrapped_decorator import signature_safe_contextmanager ...@@ -18,7 +18,7 @@ from ..wrapped_decorator import signature_safe_contextmanager
from .layer_function_generator import autodoc, templatedoc from .layer_function_generator import autodoc, templatedoc
from .tensor import assign, cast, fill_constant from .tensor import assign, cast, fill_constant
from .. import core from .. import core
from ..framework import Program, Variable, Operator, _non_static_mode, static_only from ..framework import Program, Variable, Operator, _non_static_mode, static_only, _in_legacy_dygraph, in_dygraph_mode
from ..layer_helper import LayerHelper, unique_name from ..layer_helper import LayerHelper, unique_name
from .nn import logical_and, logical_not, logical_or from .nn import logical_and, logical_not, logical_or
from .utils import assert_same_structure, map_structure, hold_mutable_vars, copy_mutable_vars from .utils import assert_same_structure, map_structure, hold_mutable_vars, copy_mutable_vars
...@@ -3867,7 +3867,9 @@ def is_empty(x, name=None): ...@@ -3867,7 +3867,9 @@ def is_empty(x, name=None):
# - data: [0]) # - data: [0])
""" """
if _non_static_mode(): if in_dygraph_mode():
return _C_ops.final_state_is_empty(x)
if _in_legacy_dygraph():
return _C_ops.is_empty(x) return _C_ops.is_empty(x)
check_variable_and_dtype(x, 'x', ['float32', 'float64', 'int32', 'int64'], check_variable_and_dtype(x, 'x', ['float32', 'float64', 'int32', 'int64'],
......
...@@ -21,7 +21,7 @@ from paddle.utils import deprecated ...@@ -21,7 +21,7 @@ from paddle.utils import deprecated
from . import nn from . import nn
from .layer_function_generator import templatedoc from .layer_function_generator import templatedoc
from ..layer_helper import LayerHelper from ..layer_helper import LayerHelper
from ..framework import Variable, _non_static_mode, static_only from ..framework import Variable, _non_static_mode, static_only, _in_legacy_dygraph
from .. import core from .. import core
from ..data_feeder import check_variable_and_dtype, check_type from ..data_feeder import check_variable_and_dtype, check_type
from ..param_attr import ParamAttr from ..param_attr import ParamAttr
......
...@@ -20,7 +20,7 @@ from __future__ import print_function ...@@ -20,7 +20,7 @@ from __future__ import print_function
import warnings import warnings
from ..layer_helper import LayerHelper from ..layer_helper import LayerHelper
from ..initializer import Normal, Constant from ..initializer import Normal, Constant
from ..framework import Variable, _non_static_mode, _varbase_creator from ..framework import Variable, _non_static_mode, _varbase_creator, _in_legacy_dygraph, in_dygraph_mode
from .. import core from .. import core
from ..param_attr import ParamAttr from ..param_attr import ParamAttr
from . import nn from . import nn
......
...@@ -11674,8 +11674,12 @@ def size(input): ...@@ -11674,8 +11674,12 @@ def size(input):
rank = layers.size(input) # 300 rank = layers.size(input) # 300
""" """
if _non_static_mode(): if in_dygraph_mode():
return _C_ops.final_state_size(input)
if _in_legacy_dygraph():
return _C_ops.size(input) return _C_ops.size(input)
check_variable_and_dtype( check_variable_and_dtype(
input, 'input', input, 'input',
['bool', 'float16', 'float32', 'float64', 'int32', 'int64'], "size") ['bool', 'float16', 'float32', 'float64', 'int32', 'int64'], "size")
...@@ -13432,6 +13436,9 @@ def log_loss(input, label, epsilon=1e-4, name=None): ...@@ -13432,6 +13436,9 @@ def log_loss(input, label, epsilon=1e-4, name=None):
prob = paddle.randn((10,1)) prob = paddle.randn((10,1))
cost = F.log_loss(input=prob, label=label) cost = F.log_loss(input=prob, label=label)
""" """
if in_dygraph_mode():
return _C_ops.final_state_log_loss(input, label, epsilon)
helper = LayerHelper('log_loss', **locals()) helper = LayerHelper('log_loss', **locals())
check_variable_and_dtype(input, 'input', ['float32'], 'log_loss') check_variable_and_dtype(input, 'input', ['float32'], 'log_loss')
check_variable_and_dtype(label, 'label', ['float32'], 'log_loss') check_variable_and_dtype(label, 'label', ['float32'], 'log_loss')
...@@ -14447,7 +14454,10 @@ def where(condition): ...@@ -14447,7 +14454,10 @@ def where(condition):
out = layers.where(condition) # [[]] out = layers.where(condition) # [[]]
""" """
if _non_static_mode():
if in_dygraph_mode():
return _C_ops.final_state_where_index(condition)
if _in_legacy_dygraph():
return _C_ops.where_index(condition) return _C_ops.where_index(condition)
helper = LayerHelper("where_index", **locals()) helper = LayerHelper("where_index", **locals())
...@@ -14940,6 +14950,10 @@ def unfold(x, kernel_sizes, strides=1, paddings=0, dilations=1, name=None): ...@@ -14940,6 +14950,10 @@ def unfold(x, kernel_sizes, strides=1, paddings=0, dilations=1, name=None):
"Unexpected type of paddings, it should be either an integer or a list" "Unexpected type of paddings, it should be either an integer or a list"
"of 2 or 4 integers") "of 2 or 4 integers")
if in_dygraph_mode():
return _C_ops.final_state_unfold(x, kernel_sizes, strides, paddings,
dilations)
out = helper.create_variable_for_type_inference(dtype=x.dtype) out = helper.create_variable_for_type_inference(dtype=x.dtype)
helper.append_op( helper.append_op(
type="unfold", type="unfold",
...@@ -15167,6 +15181,10 @@ def shard_index(input, index_num, nshards, shard_id, ignore_value=-1): ...@@ -15167,6 +15181,10 @@ def shard_index(input, index_num, nshards, shard_id, ignore_value=-1):
print(shard_label) print(shard_label)
# [[-1], [1]] # [[-1], [1]]
""" """
if in_dygraph_mode():
return _C_ops.final_state_shard_index(input, index_num, nshards,
shard_id, ignore_value)
check_variable_and_dtype(input, 'input', ['int64', 'int32'], 'shard_index') check_variable_and_dtype(input, 'input', ['int64', 'int32'], 'shard_index')
op_type = 'shard_index' op_type = 'shard_index'
helper = LayerHelper(op_type, **locals()) helper = LayerHelper(op_type, **locals())
......
...@@ -16,9 +16,10 @@ from __future__ import print_function ...@@ -16,9 +16,10 @@ from __future__ import print_function
import os import os
from .layer_function_generator import generate_layer_fn, generate_activation_fn, generate_inplace_fn, add_sample_code from .layer_function_generator import generate_layer_fn, generate_activation_fn, generate_inplace_fn, add_sample_code
from .. import core from .. import core
from ..framework import convert_np_dtype_to_dtype_, Variable from ..framework import convert_np_dtype_to_dtype_, Variable, in_dygraph_mode
from ..data_feeder import convert_dtype, check_variable_and_dtype, check_type, check_dtype from ..data_feeder import convert_dtype, check_variable_and_dtype, check_type, check_dtype
from paddle.utils import deprecated from paddle.utils import deprecated
from paddle import _C_ops
__deprecated_func_name__ = { __deprecated_func_name__ = {
'tanh_shrink': 'tanhshrink', 'tanh_shrink': 'tanhshrink',
...@@ -794,6 +795,9 @@ _erf_ = generate_layer_fn('erf') ...@@ -794,6 +795,9 @@ _erf_ = generate_layer_fn('erf')
def erf(x, name=None): def erf(x, name=None):
if in_dygraph_mode():
return _C_ops.final_state_erf(x)
locals_var = locals().copy() locals_var = locals().copy()
kwargs = dict() kwargs = dict()
for name, val in locals_var.items(): for name, val in locals_var.items():
......
...@@ -15,6 +15,7 @@ ...@@ -15,6 +15,7 @@
from paddle.fluid.layer_helper import LayerHelper, _non_static_mode from paddle.fluid.layer_helper import LayerHelper, _non_static_mode
from paddle.fluid.data_feeder import check_variable_and_dtype from paddle.fluid.data_feeder import check_variable_and_dtype
from paddle import _C_ops from paddle import _C_ops
from paddle.fluid.framework import _in_legacy_dygraph, in_dygraph_mode
__all__ = [] __all__ = []
...@@ -50,7 +51,9 @@ def segment_sum(data, segment_ids, name=None): ...@@ -50,7 +51,9 @@ def segment_sum(data, segment_ids, name=None):
#Outputs: [[4., 4., 4.], [4., 5., 6.]] #Outputs: [[4., 4., 4.], [4., 5., 6.]]
""" """
if _non_static_mode(): if in_dygraph_mode():
return _C_ops.final_state_segment_pool(data, segment_idsm, "SUM")[0]
if _in_legacy_dygraph():
out, tmp = _C_ops.segment_pool(data, segment_ids, 'pooltype', "SUM") out, tmp = _C_ops.segment_pool(data, segment_ids, 'pooltype', "SUM")
return out return out
...@@ -104,6 +107,9 @@ def segment_mean(data, segment_ids, name=None): ...@@ -104,6 +107,9 @@ def segment_mean(data, segment_ids, name=None):
#Outputs: [[2., 2., 2.], [4., 5., 6.]] #Outputs: [[2., 2., 2.], [4., 5., 6.]]
""" """
if in_dygraph_mode():
return _C_ops.final_state_segment_pool(data, segment_idsm, "MEAN")[0]
if _non_static_mode(): if _non_static_mode():
out, tmp = _C_ops.segment_pool(data, segment_ids, 'pooltype', "MEAN") out, tmp = _C_ops.segment_pool(data, segment_ids, 'pooltype', "MEAN")
return out return out
...@@ -157,6 +163,10 @@ def segment_min(data, segment_ids, name=None): ...@@ -157,6 +163,10 @@ def segment_min(data, segment_ids, name=None):
#Outputs: [[1., 2., 1.], [4., 5., 6.]] #Outputs: [[1., 2., 1.], [4., 5., 6.]]
""" """
if in_dygraph_mode():
return _C_ops.final_state_segment_pool(data, segment_idsm, "MIN")[0]
if _non_static_mode(): if _non_static_mode():
out, tmp = _C_ops.segment_pool(data, segment_ids, 'pooltype', "MIN") out, tmp = _C_ops.segment_pool(data, segment_ids, 'pooltype', "MIN")
return out return out
...@@ -210,6 +220,11 @@ def segment_max(data, segment_ids, name=None): ...@@ -210,6 +220,11 @@ def segment_max(data, segment_ids, name=None):
#Outputs: [[3., 2., 3.], [4., 5., 6.]] #Outputs: [[3., 2., 3.], [4., 5., 6.]]
""" """
if in_dygraph_mode():
out, tmp = _C_ops.final_state_segment_pool(data, segment_ids, "MAX")[0]
return out
if _non_static_mode(): if _non_static_mode():
out, tmp = _C_ops.segment_pool(data, segment_ids, 'pooltype', "MAX") out, tmp = _C_ops.segment_pool(data, segment_ids, 'pooltype', "MAX")
return out return out
......
...@@ -22,7 +22,7 @@ import numpy as np ...@@ -22,7 +22,7 @@ import numpy as np
from ..fluid.data_feeder import check_variable_and_dtype from ..fluid.data_feeder import check_variable_and_dtype
from ..fluid.layer_helper import LayerHelper from ..fluid.layer_helper import LayerHelper
from ..fluid.framework import core, _varbase_creator, _non_static_mode from ..fluid.framework import core, _varbase_creator, _non_static_mode, _in_legacy_dygraph
import paddle import paddle
from paddle import _C_ops from paddle import _C_ops
...@@ -800,6 +800,7 @@ def accuracy(input, label, k=1, correct=None, total=None, name=None): ...@@ -800,6 +800,7 @@ def accuracy(input, label, k=1, correct=None, total=None, name=None):
topk_out, topk_indices = paddle.topk(input, k=k) topk_out, topk_indices = paddle.topk(input, k=k)
_acc, _, _ = _C_ops.accuracy(topk_out, topk_indices, label, correct, _acc, _, _ = _C_ops.accuracy(topk_out, topk_indices, label, correct,
total) total)
return _acc return _acc
helper = LayerHelper("accuracy", **locals()) helper = LayerHelper("accuracy", **locals())
......
...@@ -784,7 +784,9 @@ def selu(x, ...@@ -784,7 +784,9 @@ def selu(x,
raise ValueError( raise ValueError(
"The alpha must be no less than zero. Received: {}.".format(alpha)) "The alpha must be no less than zero. Received: {}.".format(alpha))
if in_dynamic_mode(): if in_dygraph_mode():
return _C_ops.final_state_selu(x, scale, alpha)
if _in_legacy_dygraph():
return _C_ops.selu(x, 'scale', scale, 'alpha', alpha) return _C_ops.selu(x, 'scale', scale, 'alpha', alpha)
check_variable_and_dtype(x, 'x', ['float16', 'float32', 'float64'], 'selu') check_variable_and_dtype(x, 'x', ['float16', 'float32', 'float64'], 'selu')
...@@ -955,7 +957,12 @@ def softmax(x, axis=-1, dtype=None, name=None): ...@@ -955,7 +957,12 @@ def softmax(x, axis=-1, dtype=None, name=None):
dtype = convert_np_dtype_to_dtype_(dtype) dtype = convert_np_dtype_to_dtype_(dtype)
use_cudnn = True use_cudnn = True
if in_dynamic_mode(): if in_dygraph_mode():
outs_cast = x if dtype is None \
else _C_ops.cast(x, 'in_dtype', x.dtype, 'out_dtype', dtype)
return _C_ops.final_state_softmax(outs_cast, axis)
if _in_legacy_dygraph():
outs_cast = x if dtype is None \ outs_cast = x if dtype is None \
else _C_ops.cast(x, 'in_dtype', x.dtype, 'out_dtype', dtype) else _C_ops.cast(x, 'in_dtype', x.dtype, 'out_dtype', dtype)
return _C_ops.softmax(outs_cast, 'axis', axis, 'use_cudnn', use_cudnn) return _C_ops.softmax(outs_cast, 'axis', axis, 'use_cudnn', use_cudnn)
......
...@@ -1212,8 +1212,12 @@ def cholesky(x, upper=False, name=None): ...@@ -1212,8 +1212,12 @@ def cholesky(x, upper=False, name=None):
# [1.25450498 0.05600871 0.06400121]] # [1.25450498 0.05600871 0.06400121]]
""" """
if paddle.in_dynamic_mode(): if in_dygraph_mode():
return _C_ops.final_state_cholesky(x, upper)
if _in_legacy_dygraph():
return _C_ops.cholesky(x, "upper", upper) return _C_ops.cholesky(x, "upper", upper)
check_variable_and_dtype(x, 'dtype', ['float32', 'float64'], 'cholesky') check_variable_and_dtype(x, 'dtype', ['float32', 'float64'], 'cholesky')
check_type(upper, 'upper', bool, 'cholesky') check_type(upper, 'upper', bool, 'cholesky')
helper = LayerHelper('cholesky', **locals()) helper = LayerHelper('cholesky', **locals())
...@@ -1447,7 +1451,10 @@ def bincount(x, weights=None, minlength=0, name=None): ...@@ -1447,7 +1451,10 @@ def bincount(x, weights=None, minlength=0, name=None):
if x.dtype not in [paddle.int32, paddle.int64]: if x.dtype not in [paddle.int32, paddle.int64]:
raise TypeError("Elements in Input(x) should all be integers") raise TypeError("Elements in Input(x) should all be integers")
if paddle.in_dynamic_mode(): # if in_dygraph_mode():
# return _C_ops.final_state_bincount(x, weights, minlength)
if _in_legacy_dygraph():
return _C_ops.bincount(x, weights, "minlength", minlength) return _C_ops.bincount(x, weights, "minlength", minlength)
helper = LayerHelper('bincount', **locals()) helper = LayerHelper('bincount', **locals())
...@@ -1761,7 +1768,10 @@ def matrix_power(x, n, name=None): ...@@ -1761,7 +1768,10 @@ def matrix_power(x, n, name=None):
# [-7.66666667 , 8. , -1.83333333 ], # [-7.66666667 , 8. , -1.83333333 ],
# [ 1.80555556 , -1.91666667 , 0.44444444 ]] # [ 1.80555556 , -1.91666667 , 0.44444444 ]]
""" """
if paddle.in_dynamic_mode(): if in_dygraph_mode():
return _C_ops.final_state_matrix_power(x, n)
if _in_legacy_dygraph():
return _C_ops.matrix_power(x, "n", n) return _C_ops.matrix_power(x, "n", n)
check_variable_and_dtype(x, 'dtype', ['float32', 'float64'], 'matrix_power') check_variable_and_dtype(x, 'dtype', ['float32', 'float64'], 'matrix_power')
...@@ -2279,7 +2289,10 @@ def eigh(x, UPLO='L', name=None): ...@@ -2279,7 +2289,10 @@ def eigh(x, UPLO='L', name=None):
#[ 0.3826834323650898j , -0.9238795325112867j ]] #[ 0.3826834323650898j , -0.9238795325112867j ]]
""" """
if paddle.in_dynamic_mode(): if in_dygraph_mode():
return _C_ops.final_state_eigh(x, UPLO)
if _in_legacy_dygraph():
return _C_ops.eigh(x, 'UPLO', UPLO) return _C_ops.eigh(x, 'UPLO', UPLO)
def __check_input(x, UPLO): def __check_input(x, UPLO):
...@@ -2749,7 +2762,10 @@ def cholesky_solve(x, y, upper=False, name=None): ...@@ -2749,7 +2762,10 @@ def cholesky_solve(x, y, upper=False, name=None):
print(out) print(out)
# [-2.5, -7, 9.5] # [-2.5, -7, 9.5]
""" """
if paddle.in_dynamic_mode(): if in_dygraph_mode():
return _C_ops.final_state_cholesky_solve(x, y, upper)
if _in_legacy_dygraph():
return _C_ops.cholesky_solve(x, y, 'upper', upper) return _C_ops.cholesky_solve(x, y, 'upper', upper)
helper = LayerHelper("cholesky_solve", **locals()) helper = LayerHelper("cholesky_solve", **locals())
......
...@@ -1762,8 +1762,12 @@ def tile(x, repeat_times, name=None): ...@@ -1762,8 +1762,12 @@ def tile(x, repeat_times, name=None):
np_out = out.numpy() np_out = out.numpy()
# [[1, 2, 3, 1, 2, 3]] # [[1, 2, 3, 1, 2, 3]]
""" """
if paddle.in_dynamic_mode(): if in_dygraph_mode():
return _C_ops.final_state_tile(x, repeat_times)
if _in_legacy_dygraph():
return _C_ops.tile(x, 'repeat_times', repeat_times) return _C_ops.tile(x, 'repeat_times', repeat_times)
check_type(repeat_times, 'repeat_times', (list, tuple, Variable), 'tile') check_type(repeat_times, 'repeat_times', (list, tuple, Variable), 'tile')
if isinstance(repeat_times, Variable): if isinstance(repeat_times, Variable):
assert len(repeat_times.shape) == 1, ( assert len(repeat_times.shape) == 1, (
...@@ -2833,12 +2837,14 @@ def take_along_axis(arr, indices, axis): ...@@ -2833,12 +2837,14 @@ def take_along_axis(arr, indices, axis):
if not broadcast_shape: if not broadcast_shape:
# if indices matrix have larger size than arr, arr should broadcast into indices shape. # if indices matrix have larger size than arr, arr should broadcast into indices shape.
broadcast_shape = indices.shape broadcast_shape = indices.shape
if paddle.in_dynamic_mode(): if _non_static_mode():
indices = paddle.broadcast_to(indices, broadcast_shape) indices = paddle.broadcast_to(indices, broadcast_shape)
broadcast_shape_list = list(broadcast_shape) broadcast_shape_list = list(broadcast_shape)
broadcast_shape_list[axis] = list(arr.shape)[axis] broadcast_shape_list[axis] = list(arr.shape)[axis]
broadcast_shape = tuple(broadcast_shape_list) broadcast_shape = tuple(broadcast_shape_list)
arr = paddle.broadcast_to(arr, broadcast_shape) arr = paddle.broadcast_to(arr, broadcast_shape)
if not _in_legacy_dygraph():
return _C_ops.final_state_take_along_axis(arr, indices, axis)
return _C_ops.take_along_axis(arr, indices, 'Axis', axis) return _C_ops.take_along_axis(arr, indices, 'Axis', axis)
check_variable_and_dtype( check_variable_and_dtype(
arr, 'x', ['float16', 'float32', 'float64', 'int32', 'int64', 'uint8'], arr, 'x', ['float16', 'float32', 'float64', 'int32', 'int64', 'uint8'],
...@@ -2898,12 +2904,15 @@ def put_along_axis(arr, indices, values, axis, reduce='assign'): ...@@ -2898,12 +2904,15 @@ def put_along_axis(arr, indices, values, axis, reduce='assign'):
"`indices` and `arr` must have the same number of dimensions!") "`indices` and `arr` must have the same number of dimensions!")
axis = non_negative_axis(arr, axis) axis = non_negative_axis(arr, axis)
broadcast_shape = infer_broadcast_shape(arr, indices, axis) broadcast_shape = infer_broadcast_shape(arr, indices, axis)
if paddle.in_dynamic_mode(): if _non_static_mode():
values = paddle.to_tensor(values) if not isinstance( values = paddle.to_tensor(values) if not isinstance(
values, paddle.Tensor) else values values, paddle.Tensor) else values
if broadcast_shape: if broadcast_shape:
indices = paddle.broadcast_to(indices, broadcast_shape) indices = paddle.broadcast_to(indices, broadcast_shape)
values = paddle.broadcast_to(values, indices.shape) values = paddle.broadcast_to(values, indices.shape)
if in_dygraph_mode():
return _C_ops.final_state_put_along_axis(arr, indices, values, axis,
reduce)
return _C_ops.put_along_axis(arr, indices, values, "Axis", axis, return _C_ops.put_along_axis(arr, indices, values, "Axis", axis,
"Reduce", reduce) "Reduce", reduce)
......
...@@ -2374,7 +2374,10 @@ def trace(x, offset=0, axis1=0, axis2=1, name=None): ...@@ -2374,7 +2374,10 @@ def trace(x, offset=0, axis1=0, axis2=1, name=None):
"But received axis1 = %d, axis2 = %d\n"%(axis1, axis2) "But received axis1 = %d, axis2 = %d\n"%(axis1, axis2)
__check_input(input, offset, axis1, axis2) __check_input(input, offset, axis1, axis2)
if paddle.in_dynamic_mode(): if in_dygraph_mode():
return _C_ops.final_state_trace( x, offset, axis1, axis2 )
if _in_legacy_dygraph():
return _C_ops.trace(x, 'offset', offset, 'axis1', axis1, 'axis2', axis2) return _C_ops.trace(x, 'offset', offset, 'axis1', axis1, 'axis2', axis2)
inputs = {'Input': [x]} inputs = {'Input': [x]}
...@@ -2597,7 +2600,9 @@ def cumsum(x, axis=None, dtype=None, name=None): ...@@ -2597,7 +2600,9 @@ def cumsum(x, axis=None, dtype=None, name=None):
if dtype is not None and x.dtype != convert_np_dtype_to_dtype_(dtype): if dtype is not None and x.dtype != convert_np_dtype_to_dtype_(dtype):
x = cast(x, dtype) x = cast(x, dtype)
if paddle.in_dynamic_mode(): if in_dygraph_mode():
return _C_ops.final_state_cumsum(x, axis, flatten, False, False)
if _in_legacy_dygraph():
if axis is None: if axis is None:
return _C_ops.cumsum(x, 'flatten', flatten) return _C_ops.cumsum(x, 'flatten', flatten)
else: else:
...@@ -2854,7 +2859,10 @@ def sign(x, name=None): ...@@ -2854,7 +2859,10 @@ def sign(x, name=None):
out = paddle.sign(x=x) out = paddle.sign(x=x)
print(out) # [1.0, 0.0, -1.0, 1.0] print(out) # [1.0, 0.0, -1.0, 1.0]
""" """
if paddle.in_dynamic_mode(): if in_dygraph_mode():
return _C_ops.final_state_sign(x)
if _in_legacy_dygraph():
return _C_ops.sign(x) return _C_ops.sign(x)
check_variable_and_dtype(x, 'x', ['float16', 'float32', 'float64'], 'sign') check_variable_and_dtype(x, 'x', ['float16', 'float32', 'float64'], 'sign')
...@@ -2891,7 +2899,10 @@ def tanh(x, name=None): ...@@ -2891,7 +2899,10 @@ def tanh(x, name=None):
print(out) print(out)
# [-0.37994896 -0.19737532 0.09966799 0.29131261] # [-0.37994896 -0.19737532 0.09966799 0.29131261]
""" """
if paddle.in_dynamic_mode(): if in_dygraph_mode():
return _C_ops.final_state_tanh( x )
if _in_legacy_dygraph():
return _C_ops.tanh(x) return _C_ops.tanh(x)
check_variable_and_dtype(x, 'x', ['float16', 'float32', 'float64'], 'tanh') check_variable_and_dtype(x, 'x', ['float16', 'float32', 'float64'], 'tanh')
...@@ -2933,7 +2944,10 @@ def increment(x, value=1.0, name=None): ...@@ -2933,7 +2944,10 @@ def increment(x, value=1.0, name=None):
# [1.] # [1.]
""" """
if paddle.in_dynamic_mode(): if in_dygraph_mode():
return _C_ops.final_state_increment( x, value)
if _in_legacy_dygraph():
return _C_ops.increment(x, 'step', value) return _C_ops.increment(x, 'step', value)
check_variable_and_dtype(x, 'x', ['float32', 'float64', 'int32', 'int64'], check_variable_and_dtype(x, 'x', ['float32', 'float64', 'int32', 'int64'],
......
...@@ -22,6 +22,7 @@ from ..fluid.layers import utils ...@@ -22,6 +22,7 @@ from ..fluid.layers import utils
import paddle import paddle
from paddle import _C_ops from paddle import _C_ops
from paddle.static import Variable from paddle.static import Variable
from paddle.fluid.framework import in_dygraph_mode, _in_legacy_dygraph
__all__ = [] __all__ = []
...@@ -66,7 +67,10 @@ def bernoulli(x, name=None): ...@@ -66,7 +67,10 @@ def bernoulli(x, name=None):
""" """
if paddle.in_dynamic_mode(): if in_dygraph_mode():
return _C_ops.final_state_bernoulli(x)
if _in_legacy_dygraph():
return _C_ops.bernoulli(x) return _C_ops.bernoulli(x)
check_variable_and_dtype(x, "x", ["float32", "float64"], "bernoulli") check_variable_and_dtype(x, "x", ["float32", "float64"], "bernoulli")
...@@ -174,7 +178,10 @@ def multinomial(x, num_samples=1, replacement=False, name=None): ...@@ -174,7 +178,10 @@ def multinomial(x, num_samples=1, replacement=False, name=None):
assert core.is_compiled_with_rocm() == False, ( assert core.is_compiled_with_rocm() == False, (
"multinomial op is not supported on ROCM yet.") "multinomial op is not supported on ROCM yet.")
if paddle.in_dynamic_mode(): if in_dygraph_mode():
return _C_ops.final_state_multinomial(x, num_samples, replacement)
if _in_legacy_dygraph():
return _C_ops.multinomial(x, 'num_samples', num_samples, 'replacement', return _C_ops.multinomial(x, 'num_samples', num_samples, 'replacement',
replacement) replacement)
......
...@@ -91,7 +91,11 @@ def argsort(x, axis=-1, descending=False, name=None): ...@@ -91,7 +91,11 @@ def argsort(x, axis=-1, descending=False, name=None):
# [1 1 0 2] # [1 1 0 2]
# [0 2 1 1]]] # [0 2 1 1]]]
""" """
if paddle.in_dynamic_mode(): if in_dygraph_mode():
_, ids, = _C_ops.final_state_argsort(x, axis, descending)
return ids
if _in_legacy_dygraph():
_, ids = _C_ops.argsort(x, 'axis', axis, 'descending', descending) _, ids = _C_ops.argsort(x, 'axis', axis, 'descending', descending)
return ids return ids
check_variable_and_dtype( check_variable_and_dtype(
...@@ -171,7 +175,9 @@ def argmax(x, axis=None, keepdim=False, dtype="int64", name=None): ...@@ -171,7 +175,9 @@ def argmax(x, axis=None, keepdim=False, dtype="int64", name=None):
flatten = True flatten = True
axis = 0 axis = 0
if paddle.in_dynamic_mode(): if in_dygraph_mode():
return _C_ops.final_state_argmax(x, axis, keepdim, flatten, var_dtype)
if _in_legacy_dygraph():
out = _C_ops.arg_max(x, 'axis', axis, 'dtype', var_dtype, 'keepdims', out = _C_ops.arg_max(x, 'axis', axis, 'dtype', var_dtype, 'keepdims',
keepdim, 'flatten', flatten) keepdim, 'flatten', flatten)
return out return out
...@@ -251,7 +257,9 @@ def argmin(x, axis=None, keepdim=False, dtype="int64", name=None): ...@@ -251,7 +257,9 @@ def argmin(x, axis=None, keepdim=False, dtype="int64", name=None):
flatten = True flatten = True
axis = 0 axis = 0
if paddle.in_dynamic_mode(): if in_dygraph_mode():
return _C_ops.final_state_argmin(x, axis, keepdim, flatten, var_dtype)
if _in_legacy_dygraph():
out = _C_ops.arg_min(x, 'axis', axis, 'dtype', var_dtype, 'keepdims', out = _C_ops.arg_min(x, 'axis', axis, 'dtype', var_dtype, 'keepdims',
keepdim, 'flatten', flatten) keepdim, 'flatten', flatten)
return out return out
......
...@@ -51,7 +51,6 @@ PD_REGISTER_INFER_META_FN({api.kernel['func'][0]}, phi::{api.infer_meta['func']} ...@@ -51,7 +51,6 @@ PD_REGISTER_INFER_META_FN({api.kernel['func'][0]}, phi::{api.infer_meta['func']}
wrapped_infermeta_name = get_wrapped_infermeta_name(api.api) wrapped_infermeta_name = get_wrapped_infermeta_name(api.api)
args = [] args = []
print("@@@", api.api)
for input_name in api.inputs['names']: for input_name in api.inputs['names']:
if input_name in kernel_params: if input_name in kernel_params:
print("type", api.inputs['input_info']) print("type", api.inputs['input_info'])
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册