未验证 提交 cf17ae8a 编写于 作者: J Jiabin Yang 提交者: GitHub

[Eager] Support more final_state code (#44986)

* support more final_state code

* support more final_state code

* fix api error

* fix norm error

* fix pool3d error

* revert pool3d and max_pool_3d_adaptive

* fix code check error

* fix norm problem
上级 7bc57d35
......@@ -2259,7 +2259,11 @@ def pool2d(input,
pool_padding = [0, 0]
pool_padding = update_padding(pool_padding, data_format)
if in_dygraph_mode():
return _C_ops.final_state_pool2d(input, pool_size, pool_stride,
pool_padding, ceil_mode, exclusive,
data_format, pool_type, global_pooling,
False, padding_algorithm)
op_type = 'pool2d'
helper = LayerHelper(op_type, **locals())
dtype = helper.input_dtype()
......
......@@ -1609,12 +1609,12 @@ def linspace(start, stop, num, dtype=None, name=None):
if not isinstance(num, Variable):
with device_guard("cpu"):
tensor_num = fill_constant([1], 'int32', num)
if _in_legacy_dygraph():
return _C_ops.linspace(tensor_start, tensor_stop, tensor_num, 'dtype',
dtype)
if in_dygraph_mode():
return _C_ops.final_state_linspace(tensor_start, tensor_stop,
tensor_num, dtype)
if _in_legacy_dygraph():
return _C_ops.linspace(tensor_start, tensor_stop, tensor_num, 'dtype',
dtype)
helper = LayerHelper("linspace", **locals())
start_dtype = convert_dtype(tensor_start.dtype)
......
......@@ -1615,7 +1615,12 @@ def adaptive_max_pool1d(x, output_size, return_mask=False, name=None):
pool_size = [1] + utils.convert_to_list(output_size, 1, 'pool_size')
x = unsqueeze(x, [2])
if in_dynamic_mode():
if in_dygraph_mode():
pool_out = _C_ops.final_state_max_pool2d_with_index(
x, pool_size, [1, 1], [0, 0], False, True)
return (squeeze(pool_out[0], [2]), squeeze(
pool_out[1], [2])) if return_mask else squeeze(pool_out[0], [2])
if _in_legacy_dygraph():
pool_out = _C_ops.max_pool2d_with_index(x, 'pooling_type', pool_type,
'ksize', pool_size, 'adaptive',
True)
......@@ -1703,8 +1708,11 @@ def adaptive_max_pool2d(x, output_size, return_mask=False, name=None):
output_size[0] = in_h
if output_size[1] == None:
output_size[1] = in_w
if in_dynamic_mode():
if in_dygraph_mode():
pool_out = _C_ops.final_state_max_pool2d_with_index(
x, output_size, [1, 1], [0, 0], False, True)
return pool_out if return_mask else pool_out[0]
if _in_legacy_dygraph():
pool_out = _C_ops.max_pool2d_with_index(x, 'pooling_type', 'max',
'ksize', output_size,
'adaptive', True)
......
......@@ -98,7 +98,10 @@ def linspace(start, stop, num, dtype=None, name=None):
if not isinstance(num, Variable):
with device_guard("cpu"):
tensor_num = fill_constant([1], 'int32', num, force_cpu=True)
if _non_static_mode():
if in_dygraph_mode():
return _C_ops.final_state_linspace(tensor_start, tensor_stop,
tensor_num, dtype)
if _in_legacy_dygraph():
return _C_ops.linspace(tensor_start, tensor_stop, tensor_num, 'dtype',
dtype)
......@@ -1162,7 +1165,14 @@ def diagflat(x, offset=0, name=None):
# [0 0 0 4 0]]
"""
padding_value = 0
if paddle.in_dynamic_mode():
if in_dygraph_mode():
if len(x.shape) == 1:
return _C_ops.final_state_diag(x, offset, padding_value)
else:
y = _C_ops.final_state_flatten(x, 0, -1)
return _C_ops.final_state_diag(y, offset, padding_value)
if _in_legacy_dygraph():
if len(x.shape) == 1:
return _C_ops.diag_v2(x, "offset", offset, "padding_value",
padding_value)
......@@ -1370,7 +1380,14 @@ def empty(shape, dtype=None, name=None):
dtype = convert_dtype(dtype)
if paddle.in_dynamic_mode():
if in_dygraph_mode():
shape = utils.convert_shape_to_list(shape)
out = _C_ops.final_state_empty(shape, convert_np_dtype_to_dtype_(dtype),
_current_expected_place())
out.stop_gradient = True
return out
if _in_legacy_dygraph():
shape = utils.convert_shape_to_list(shape)
out = _C_ops.empty('shape', shape, 'dtype',
convert_np_dtype_to_dtype_(dtype))
......@@ -1437,7 +1454,14 @@ def empty_like(x, dtype=None, name=None):
dtype = x.dtype
dtype = convert_dtype(dtype)
if paddle.in_dynamic_mode():
if in_dygraph_mode():
out = _C_ops.final_state_empty(x.shape,
convert_np_dtype_to_dtype_(dtype),
_current_expected_place())
out.stop_gradient = True
return out
if _in_legacy_dygraph():
out = _C_ops.empty('shape', x.shape, 'dtype',
convert_np_dtype_to_dtype_(dtype))
out.stop_gradient = True
......
......@@ -420,6 +420,17 @@ def norm(x, p='fro', axis=None, keepdim=False, name=None):
keepdim=False,
asvector=False,
name=None):
if in_dygraph_mode():
out = _C_ops.final_state_abs(input)
reduce_all = True if axis == None or axis == [] or asvector == True else False
axis = axis if axis != None and axis != [] else [0]
if reduce_all:
assert (axis == []) or (axis is None)
if porder == np.float64('inf'):
return _C_ops.final_state_max(out, axis, keepdim)
else:
return _C_ops.final_state_min(out, axis, keepdim)
helper = LayerHelper('inf_norm', **locals())
out = helper.create_variable_for_type_inference(
dtype=helper.input_dtype())
......@@ -448,6 +459,13 @@ def norm(x, p='fro', axis=None, keepdim=False, name=None):
NOTE:
This function actually treats the matrix as flattened vector to calculate vector norm instead of matrix norm.
"""
if in_dygraph_mode():
abs_out = _C_ops.final_state_abs(input)
pow_out = _C_ops.final_state_pow(abs_out, porder)
sum_out = _C_ops.final_state_sum(pow_out, axis, None, keepdim)
out = _C_ops.final_state_pow(sum_out, float(1. / porder))
return out
block = LayerHelper('norm', **locals())
out = block.create_variable_for_type_inference(
dtype=block.input_dtype())
......@@ -2588,8 +2606,55 @@ def pinv(x, rcond=1e-15, hermitian=False, name=None):
# one can verify : x * out * x = x ;
# or out * x * out = x ;
"""
if in_dygraph_mode():
if not hermitian:
# combine svd and matmul op
u, s, vt = _C_ops.final_state_svd(x, False)
max_singular_val = _C_ops.final_state_max(s, [-1], True)
rcond = paddle.to_tensor(rcond, dtype=x.dtype)
cutoff = rcond * max_singular_val
y = float('inf')
y = paddle.to_tensor(y, dtype=x.dtype)
if _non_static_mode():
condition = s > cutoff
cond_int = cast(condition, s.dtype)
cond_not_int = cast(logical_not(condition), s.dtype)
out1 = multiply(1 / s, cond_int)
out2 = multiply(1 / y, cond_not_int)
singular = add(out1, out2)
st = _C_ops.final_state_unsqueeze(singular, [-2])
dims = list(range(len(vt.shape)))
perm = dims[:-2] + [dims[-1]] + [dims[-2]]
v = _C_ops.final_state_transpose(vt, perm)
out_1 = v * st
out_2 = _C_ops.final_state_matmul(out_1, u, False, True)
return out_2
else:
# combine eigh and matmul op
s, u = _C_ops.final_state_eigh(x, 'UPLO')
s_abs = paddle.abs(s)
max_singular_val = _C_ops.final_state_max(s_abs, [-1], True)
rcond = paddle.to_tensor(rcond, dtype=s.dtype)
cutoff = rcond * max_singular_val
y = float('inf')
y = paddle.to_tensor(y, dtype=s.dtype)
condition = s_abs > cutoff
cond_int = cast(condition, s.dtype)
cond_not_int = cast(logical_not(condition), s.dtype)
out1 = multiply(1 / s, cond_int)
out2 = multiply(1 / y, cond_not_int)
singular = add(out1, out2)
st = _C_ops.final_state_unsqueeze(singular, [-2])
out_1 = u * st
u_conj = _C_ops.final_state_conj(u)
out_2 = _C_ops.final_state_matmul(out_1, u_conj, False, True)
return out_2
if _in_legacy_dygraph():
if not hermitian:
# combine svd and matmul op
u, s, vt = _C_ops.svd(x, 'full_matrices', False)
......
......@@ -4123,7 +4123,11 @@ def moveaxis(x, source, destination, name=None):
for i in range(len(src_dims)):
perm[dst_dims[i]] = src_dims[i]
if paddle.in_dynamic_mode():
if in_dygraph_mode():
out = _C_ops.final_state_transpose(x, perm)
return out
if _in_legacy_dygraph():
out, _ = _C_ops.transpose2(x, 'axis', perm)
return out
......
......@@ -19,6 +19,8 @@ from ..framework import core
from ..framework import convert_np_dtype_to_dtype_
from ..static import Variable
from ..fluid.data_feeder import convert_dtype, check_variable_and_dtype, check_type, check_dtype
from ..fluid.framework import in_dygraph_mode
from .. import _C_ops
__deprecated_func_name__ = {
'tanh_shrink': 'tanhshrink',
......@@ -510,6 +512,9 @@ _erf_ = generate_layer_fn('erf')
def erf(x, name=None):
if in_dygraph_mode():
return _C_ops.final_state_erf(x)
locals_var = locals().copy()
kwargs = dict()
for name, val in locals_var.items():
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册