Unverified · Commit cf17ae8a authored by Jiabin Yang, committed by GitHub

[Eager] Support more final_state code (#44986)

* support more final_state code

* support more final_state code

* fix api error

* fix norm error

* fix pool3d error

* revert pool3d and max_pool_3d_adaptive

* fix code check error

* fix norm problem
Parent 7bc57d35
@@ -2259,7 +2259,11 @@ def pool2d(input,
             pool_padding = [0, 0]
     pool_padding = update_padding(pool_padding, data_format)

+    if in_dygraph_mode():
+        return _C_ops.final_state_pool2d(input, pool_size, pool_stride,
+                                         pool_padding, ceil_mode, exclusive,
+                                         data_format, pool_type, global_pooling,
+                                         False, padding_algorithm)
     op_type = 'pool2d'
     helper = LayerHelper(op_type, **locals())
     dtype = helper.input_dtype()
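A usage sketch of the path this hunk adds (assuming Paddle ~2.3 with eager mode enabled; `fluid.layers.pool2d` is the 1.x-style API being patched here):

```python
# Minimal sketch (assumes Paddle ~2.3 with eager mode on). In dygraph mode,
# pool2d now returns straight from _C_ops.final_state_pool2d instead of
# building a static-graph op through LayerHelper.
import paddle
import paddle.fluid as fluid

x = paddle.rand([1, 3, 32, 32])  # NCHW input
y = fluid.layers.pool2d(x, pool_size=2, pool_type='max', pool_stride=2)
print(y.shape)  # [1, 3, 16, 16]
```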
@@ -1609,12 +1609,12 @@ def linspace(start, stop, num, dtype=None, name=None):
     if not isinstance(num, Variable):
         with device_guard("cpu"):
             tensor_num = fill_constant([1], 'int32', num)
-    if _in_legacy_dygraph():
-        return _C_ops.linspace(tensor_start, tensor_stop, tensor_num, 'dtype',
-                               dtype)
     if in_dygraph_mode():
         return _C_ops.final_state_linspace(tensor_start, tensor_stop,
                                            tensor_num, dtype)
+    if _in_legacy_dygraph():
+        return _C_ops.linspace(tensor_start, tensor_stop, tensor_num, 'dtype',
+                               dtype)
     helper = LayerHelper("linspace", **locals())
     start_dtype = convert_dtype(tensor_start.dtype)
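The two dygraph checks are mutually exclusive in Paddle 2.3, so this reorder does not change behavior; it brings the call site onto the canonical dispatch order the rest of the PR uses (eager first, legacy second, static fallback). A schematic sketch of that pattern, using the real framework checks:

```python
# Schematic of the dispatch order this PR standardizes on (Paddle ~2.3).
# in_dygraph_mode() is True only in the new eager mode and
# _in_legacy_dygraph() only under the old imperative tracer, so at most
# one of the two dygraph branches fires; otherwise we build static graph.
from paddle.fluid.framework import in_dygraph_mode, _in_legacy_dygraph

def dispatch(eager_impl, legacy_impl, static_impl, *args):
    if in_dygraph_mode():        # new eager (final-state) path
        return eager_impl(*args)
    if _in_legacy_dygraph():     # old dygraph tracer path
        return legacy_impl(*args)
    return static_impl(*args)    # static-graph program building
```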
@@ -1615,7 +1615,12 @@ def adaptive_max_pool1d(x, output_size, return_mask=False, name=None):
     pool_size = [1] + utils.convert_to_list(output_size, 1, 'pool_size')

     x = unsqueeze(x, [2])
-    if in_dynamic_mode():
+    if in_dygraph_mode():
+        pool_out = _C_ops.final_state_max_pool2d_with_index(
+            x, pool_size, [1, 1], [0, 0], False, True)
+        return (squeeze(pool_out[0], [2]), squeeze(
+            pool_out[1], [2])) if return_mask else squeeze(pool_out[0], [2])
+    if _in_legacy_dygraph():
         pool_out = _C_ops.max_pool2d_with_index(x, 'pooling_type', pool_type,
                                                 'ksize', pool_size, 'adaptive',
                                                 True)
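Usage sketch of the public API behind this hunk (assuming Paddle ~2.3): the 1-D input is unsqueezed to NCHW, pooled by the 2-D final-state kernel, and both output and mask are squeezed back to NCL.

```python
# adaptive_max_pool1d reuses max_pool2d_with_index on an unsqueezed input;
# with return_mask=True you get the pooled values and the argmax indices.
import paddle
import paddle.nn.functional as F

x = paddle.rand([1, 3, 32])                   # NCL input
out, mask = F.adaptive_max_pool1d(x, 16, return_mask=True)
print(out.shape, mask.shape)                  # [1, 3, 16] [1, 3, 16]
```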
@@ -1703,8 +1708,11 @@ def adaptive_max_pool2d(x, output_size, return_mask=False, name=None):
             output_size[0] = in_h
         if output_size[1] == None:
             output_size[1] = in_w
-    if in_dynamic_mode():
+    if in_dygraph_mode():
+        pool_out = _C_ops.final_state_max_pool2d_with_index(
+            x, output_size, [1, 1], [0, 0], False, True)
+        return pool_out if return_mask else pool_out[0]
+    if _in_legacy_dygraph():
         pool_out = _C_ops.max_pool2d_with_index(x, 'pooling_type', 'max',
                                                 'ksize', output_size,
                                                 'adaptive', True)
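Here the eager branch returns the kernel's `(output, mask)` pair as-is when `return_mask=True`, and just the output tensor otherwise. A usage sketch (assuming Paddle ~2.3):

```python
# adaptive_max_pool2d picks strides/paddings internally; output_size entries
# left as None are filled from the input shape by the lines above.
import paddle
import paddle.nn.functional as F

x = paddle.rand([1, 3, 32, 32])
out = F.adaptive_max_pool2d(x, output_size=[3, 3])
print(out.shape)  # [1, 3, 3, 3]
```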
@@ -98,7 +98,10 @@ def linspace(start, stop, num, dtype=None, name=None):
     if not isinstance(num, Variable):
         with device_guard("cpu"):
             tensor_num = fill_constant([1], 'int32', num, force_cpu=True)
-    if _non_static_mode():
+    if in_dygraph_mode():
+        return _C_ops.final_state_linspace(tensor_start, tensor_stop,
+                                           tensor_num, dtype)
+    if _in_legacy_dygraph():
         return _C_ops.linspace(tensor_start, tensor_stop, tensor_num, 'dtype',
                                dtype)
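Usage sketch for the public entry point this hunk touches (assuming Paddle ~2.3):

```python
# In eager mode paddle.linspace now dispatches to final_state_linspace.
# num is materialized as an int32 CPU tensor first, per the
# device_guard("cpu") block in the context above.
import paddle

out = paddle.linspace(0, 1, 5, dtype='float32')
print(out.numpy())  # [0.   0.25 0.5  0.75 1.  ]
```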
@@ -1162,7 +1165,14 @@ def diagflat(x, offset=0, name=None):
           #  [0 0 0 4 0]]
     """
     padding_value = 0
-    if paddle.in_dynamic_mode():
+    if in_dygraph_mode():
+        if len(x.shape) == 1:
+            return _C_ops.final_state_diag(x, offset, padding_value)
+        else:
+            y = _C_ops.final_state_flatten(x, 0, -1)
+            return _C_ops.final_state_diag(y, offset, padding_value)
+    if _in_legacy_dygraph():
         if len(x.shape) == 1:
             return _C_ops.diag_v2(x, "offset", offset, "padding_value",
                                   padding_value)
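Usage sketch (assuming Paddle ~2.3): for N-D input the eager branch flattens first, then builds the diagonal matrix.

```python
# diagflat = flatten(start_axis=0, stop_axis=-1) followed by diag, which is
# exactly the final_state_flatten + final_state_diag composition above.
import paddle

x = paddle.to_tensor([[1, 2], [3, 4]])
print(paddle.diagflat(x).shape)  # [4, 4]: diag of the flattened [1, 2, 3, 4]
```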
@@ -1370,7 +1380,14 @@ def empty(shape, dtype=None, name=None):
     dtype = convert_dtype(dtype)

-    if paddle.in_dynamic_mode():
+    if in_dygraph_mode():
+        shape = utils.convert_shape_to_list(shape)
+        out = _C_ops.final_state_empty(shape, convert_np_dtype_to_dtype_(dtype),
+                                       _current_expected_place())
+        out.stop_gradient = True
+        return out
+
+    if _in_legacy_dygraph():
         shape = utils.convert_shape_to_list(shape)
         out = _C_ops.empty('shape', shape, 'dtype',
                            convert_np_dtype_to_dtype_(dtype))
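Usage sketch (assuming Paddle ~2.3): the eager branch allocates on `_current_expected_place()` and marks the result non-differentiable.

```python
# paddle.empty returns uninitialized memory, so only shape/dtype/flags are
# meaningful to inspect; stop_gradient is set True by the branch above.
import paddle

out = paddle.empty([2, 3], dtype='float32')
print(out.shape, out.dtype, out.stop_gradient)  # [2, 3] paddle.float32 True
```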
@@ -1437,7 +1454,14 @@ def empty_like(x, dtype=None, name=None):
         dtype = x.dtype
     dtype = convert_dtype(dtype)

-    if paddle.in_dynamic_mode():
+    if in_dygraph_mode():
+        out = _C_ops.final_state_empty(x.shape,
+                                       convert_np_dtype_to_dtype_(dtype),
+                                       _current_expected_place())
+        out.stop_gradient = True
+        return out
+
+    if _in_legacy_dygraph():
         out = _C_ops.empty('shape', x.shape, 'dtype',
                            convert_np_dtype_to_dtype_(dtype))
         out.stop_gradient = True
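Usage sketch (assuming Paddle ~2.3): shape comes from `x`, while dtype falls back to `x.dtype` only when not given, per the context lines above.

```python
# empty_like with an explicit dtype overrides the source tensor's dtype
# but keeps its shape; both go into the same final_state_empty call.
import paddle

x = paddle.rand([2, 3])
out = paddle.empty_like(x, dtype='int64')
print(out.shape, out.dtype)  # [2, 3] paddle.int64
```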
@@ -420,6 +420,17 @@ def norm(x, p='fro', axis=None, keepdim=False, name=None):
                  keepdim=False,
                  asvector=False,
                  name=None):
+        if in_dygraph_mode():
+            out = _C_ops.final_state_abs(input)
+            reduce_all = True if axis == None or axis == [] or asvector == True else False
+            axis = axis if axis != None and axis != [] else [0]
+            if reduce_all:
+                assert (axis == []) or (axis is None)
+            if porder == np.float64('inf'):
+                return _C_ops.final_state_max(out, axis, keepdim)
+            else:
+                return _C_ops.final_state_min(out, axis, keepdim)
+
         helper = LayerHelper('inf_norm', **locals())
         out = helper.create_variable_for_type_inference(
             dtype=helper.input_dtype())
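For reference, the infinity norm is max(|x|) and the negative-infinity norm is min(|x|), which is exactly the abs + max/min composition the eager branch performs. A usage sketch (assuming Paddle ~2.3):

```python
# inf-norm over all entries of x (axis=None treats x as a flattened vector).
import paddle

x = paddle.to_tensor([[-4.0, 1.0], [2.0, -3.0]])
print(paddle.linalg.norm(x, p=float('inf')))   # max(|x|) = 4.0
print(paddle.linalg.norm(x, p=float('-inf')))  # min(|x|) = 1.0
```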
@@ -448,6 +459,13 @@ def norm(x, p='fro', axis=None, keepdim=False, name=None):
         NOTE:
             This function actually treats the matrix as flattened vector to calculate vector norm instead of matrix norm.
         """
+        if in_dygraph_mode():
+            abs_out = _C_ops.final_state_abs(input)
+            pow_out = _C_ops.final_state_pow(abs_out, porder)
+            sum_out = _C_ops.final_state_sum(pow_out, axis, None, keepdim)
+            out = _C_ops.final_state_pow(sum_out, float(1. / porder))
+            return out
+
         block = LayerHelper('norm', **locals())
         out = block.create_variable_for_type_inference(
             dtype=block.input_dtype())
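The eager branch computes (sum(|x|^p))^(1/p), the vector p-norm of the flattened input, matching the NOTE in the context above. A worked sketch:

```python
# Vector 2-norm via the same abs -> pow -> sum -> pow composition.
import paddle

x = paddle.to_tensor([3.0, 4.0])
print(paddle.linalg.norm(x, p=2))  # (|3|^2 + |4|^2) ** (1/2) = 5.0
```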
@@ -2588,8 +2606,55 @@ def pinv(x, rcond=1e-15, hermitian=False, name=None):
         # one can verify : x * out * x = x ;
         # or out * x * out = x ;
     """
-    if _non_static_mode():
+    if in_dygraph_mode():
+        if not hermitian:
+            # combine svd and matmul op
+            u, s, vt = _C_ops.final_state_svd(x, False)
+            max_singular_val = _C_ops.final_state_max(s, [-1], True)
+            rcond = paddle.to_tensor(rcond, dtype=x.dtype)
+            cutoff = rcond * max_singular_val
+            y = float('inf')
+            y = paddle.to_tensor(y, dtype=x.dtype)
+
+            condition = s > cutoff
+            cond_int = cast(condition, s.dtype)
+            cond_not_int = cast(logical_not(condition), s.dtype)
+            out1 = multiply(1 / s, cond_int)
+            out2 = multiply(1 / y, cond_not_int)
+            singular = add(out1, out2)
+
+            st = _C_ops.final_state_unsqueeze(singular, [-2])
+            dims = list(range(len(vt.shape)))
+            perm = dims[:-2] + [dims[-1]] + [dims[-2]]
+            v = _C_ops.final_state_transpose(vt, perm)
+
+            out_1 = v * st
+            out_2 = _C_ops.final_state_matmul(out_1, u, False, True)
+            return out_2
+        else:
+            # combine eigh and matmul op
+            s, u = _C_ops.final_state_eigh(x, 'UPLO')
+            s_abs = paddle.abs(s)
+            max_singular_val = _C_ops.final_state_max(s_abs, [-1], True)
+            rcond = paddle.to_tensor(rcond, dtype=s.dtype)
+            cutoff = rcond * max_singular_val
+            y = float('inf')
+            y = paddle.to_tensor(y, dtype=s.dtype)
+
+            condition = s_abs > cutoff
+            cond_int = cast(condition, s.dtype)
+            cond_not_int = cast(logical_not(condition), s.dtype)
+            out1 = multiply(1 / s, cond_int)
+            out2 = multiply(1 / y, cond_not_int)
+            singular = add(out1, out2)
+            st = _C_ops.final_state_unsqueeze(singular, [-2])
+
+            out_1 = u * st
+            u_conj = _C_ops.final_state_conj(u)
+            out_2 = _C_ops.final_state_matmul(out_1, u_conj, False, True)
+            return out_2
+
+    if _in_legacy_dygraph():
         if not hermitian:
             # combine svd and matmul op
             u, s, vt = _C_ops.svd(x, 'full_matrices', False)
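The eager path mirrors the legacy one: decompose (SVD, or eigh when hermitian), send singular values below `rcond * s_max` to infinity so their reciprocal becomes zero, then recompose the pseudoinverse. A usage sketch (assuming Paddle ~2.3):

```python
# Moore-Penrose pseudoinverse: A+ = V @ diag(1/s) @ U^T with small singular
# values zeroed out via the 1/inf trick used in the branch above.
import paddle

x = paddle.rand([3, 5])
p = paddle.linalg.pinv(x, rcond=1e-15)
# one can verify the defining property x @ p @ x = x (up to tolerance)
print(paddle.allclose(x @ p @ x, x, atol=1e-5).item())  # True
```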
@@ -4123,7 +4123,11 @@ def moveaxis(x, source, destination, name=None):
     for i in range(len(src_dims)):
         perm[dst_dims[i]] = src_dims[i]

-    if paddle.in_dynamic_mode():
+    if in_dygraph_mode():
+        out = _C_ops.final_state_transpose(x, perm)
+        return out
+    if _in_legacy_dygraph():
         out, _ = _C_ops.transpose2(x, 'axis', perm)
         return out
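Once `perm` is built, moveaxis reduces to a single transpose, so the eager branch is just `final_state_transpose(x, perm)`. Usage sketch (assuming Paddle ~2.3):

```python
# Move axis 0 to the last position: [2, 3, 4] -> [3, 4, 2].
import paddle

x = paddle.zeros([2, 3, 4])
print(paddle.moveaxis(x, [0], [-1]).shape)  # [3, 4, 2]
```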
@@ -19,6 +19,8 @@ from ..framework import core
 from ..framework import convert_np_dtype_to_dtype_
 from ..static import Variable
 from ..fluid.data_feeder import convert_dtype, check_variable_and_dtype, check_type, check_dtype
+from ..fluid.framework import in_dygraph_mode
+from .. import _C_ops

 __deprecated_func_name__ = {
     'tanh_shrink': 'tanhshrink',

@@ -510,6 +512,9 @@ _erf_ = generate_layer_fn('erf')


 def erf(x, name=None):
+    if in_dygraph_mode():
+        return _C_ops.final_state_erf(x)
+
     locals_var = locals().copy()
     kwargs = dict()
     for name, val in locals_var.items():
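The eager branch short-circuits to `final_state_erf` before the generated-layer kwargs plumbing below it. Usage sketch (assuming Paddle ~2.3):

```python
# erf(x) = 2/sqrt(pi) * integral_0^x exp(-t^2) dt, applied elementwise.
import paddle

x = paddle.to_tensor([-1.0, 0.0, 1.0])
print(paddle.erf(x).numpy())  # approx [-0.8427, 0., 0.8427]
```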