Unverified · Commit 5949f2d7, authored by Weilong Wu, committed by GitHub

[Eager] optimize same python api logic (#49473)

* [Eager] optimize same python api logic

* optimize full api

* optimize logic

* optimize logic
Parent 1221307b
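
The change applied across all of the hunks below follows one pattern: Python-side input validation and attribute preparation that only matter when building a static-graph op are moved from the shared prologue into the `else` (static-graph) branch, so the eager path calls `_C_ops` directly without paying for them. The following is a minimal, self-contained sketch of that before/after shape; the functions below are local stand-ins written for illustration, not the real Paddle implementations.

    # Minimal sketch of the refactor pattern; the helpers here are local
    # stand-ins for the real Paddle symbols, not the actual implementations.

    def in_dygraph_mode():
        # Stand-in for Paddle's eager-mode check.
        return True

    def check_variable_and_dtype(x, name, dtypes, op_type):
        # Stand-in for the validation helper that is only needed when
        # constructing a static-graph op.
        print(f"check {name} of {op_type} against {dtypes}")

    def op_before(x):
        # Old layout: the check runs on every call, eager or not.
        check_variable_and_dtype(x, 'x', ['float32', 'float64'], 'my_op')
        if in_dygraph_mode():
            return x                 # stand-in for the _C_ops fast path
        return ('static_graph', x)   # stand-in for LayerHelper/append_op

    def op_after(x):
        # New layout: the eager path skips the Python-side check entirely.
        if in_dygraph_mode():
            return x
        check_variable_and_dtype(x, 'x', ['float32', 'float64'], 'my_op')
        return ('static_graph', x)
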
@@ -1440,12 +1440,12 @@ def fft_c2c(x, n, axis, norm, forward, name):
         _check_fft_n(n)
         s = [n]
         x = _resize_fft_input(x, s, axes)
-    op_type = 'fft_c2c'
-    check_variable_and_dtype(x, 'x', ['complex64', 'complex128'], op_type)
     if in_dygraph_mode():
         out = _C_ops.fft_c2c(x, axes, norm, forward)
     else:
+        op_type = 'fft_c2c'
+        check_variable_and_dtype(x, 'x', ['complex64', 'complex128'], op_type)
         inputs = {
             'X': [x],
         }
@@ -1472,12 +1472,13 @@ def fft_r2c(x, n, axis, norm, forward, onesided, name):
         _check_fft_n(n)
         s = [n]
         x = _resize_fft_input(x, s, axes)
-    op_type = 'fft_r2c'
-    check_variable_and_dtype(x, 'x', ['float16', 'float32', 'float64'], op_type)
     if in_dygraph_mode():
         out = _C_ops.fft_r2c(x, axes, norm, forward, onesided)
     else:
+        op_type = 'fft_r2c'
+        check_variable_and_dtype(
+            x, 'x', ['float16', 'float32', 'float64'], op_type
+        )
         inputs = {
             'X': [x],
         }
@@ -1513,8 +1514,6 @@ def fft_c2r(x, n, axis, norm, forward, name):
         _check_fft_n(n)
         s = [n // 2 + 1]
         x = _resize_fft_input(x, s, axes)
-    op_type = 'fft_c2r'
-    check_variable_and_dtype(x, 'x', ['complex64', 'complex128'], op_type)
     if in_dygraph_mode():
         if n is not None:
@@ -1522,6 +1521,8 @@ def fft_c2r(x, n, axis, norm, forward, name):
         else:
             out = _C_ops.fft_c2r(x, axes, norm, forward, 0)
     else:
+        op_type = 'fft_c2r'
+        check_variable_and_dtype(x, 'x', ['complex64', 'complex128'], op_type)
         inputs = {
             'X': [x],
         }
@@ -1572,12 +1573,12 @@ def fftn_c2c(x, s, axes, norm, forward, name):
     if s is not None:
         x = _resize_fft_input(x, s, axes)
-    op_type = 'fft_c2c'
-    check_variable_and_dtype(x, 'x', ['complex64', 'complex128'], op_type)
     if in_dygraph_mode():
         out = _C_ops.fft_c2c(x, axes, norm, forward)
     else:
+        op_type = 'fft_c2c'
+        check_variable_and_dtype(x, 'x', ['complex64', 'complex128'], op_type)
         inputs = {
             'X': [x],
         }
@@ -1623,12 +1624,13 @@ def fftn_r2c(x, s, axes, norm, forward, onesided, name):
     if s is not None:
         x = _resize_fft_input(x, s, axes)
-    op_type = 'fft_r2c'
-    check_variable_and_dtype(x, 'x', ['float16', 'float32', 'float64'], op_type)
     if in_dygraph_mode():
         out = _C_ops.fft_r2c(x, axes, norm, forward, onesided)
     else:
+        op_type = 'fft_r2c'
+        check_variable_and_dtype(
+            x, 'x', ['float16', 'float32', 'float64'], op_type
+        )
         inputs = {
             'X': [x],
         }
@@ -1686,15 +1688,14 @@ def fftn_c2r(x, s, axes, norm, forward, name):
         fft_input_shape[-1] = fft_input_shape[-1] // 2 + 1
         x = _resize_fft_input(x, fft_input_shape, axes)
-    op_type = 'fft_c2r'
-    check_variable_and_dtype(x, 'x', ['complex64', 'complex128'], op_type)
     if in_dygraph_mode():
         if s is not None:
             out = _C_ops.fft_c2r(x, axes, norm, forward, s[-1])
         else:
             out = _C_ops.fft_c2r(x, axes, norm, forward, 0)
     else:
+        op_type = 'fft_c2r'
+        check_variable_and_dtype(x, 'x', ['complex64', 'complex128'], op_type)
         inputs = {
             'X': [x],
         }
......
@@ -530,16 +530,6 @@ def fill_constant(shape, dtype, value, force_cpu=False, out=None, name=None):
            data5 = fluid.layers.fill_constant(shape=[2,1], value=val, dtype='float32') #data5=[[2.0],[2.0]]
     """
-    attrs = {'force_cpu': force_cpu}
-    dtype = convert_dtype(dtype)
-    if not isinstance(value, Variable):
-        if dtype in ['uint8', 'int16', 'int32', 'int64']:
-            attrs['str_value'] = str(int(value))
-            attrs['value'] = int(value)
-        else:
-            attrs['str_value'] = str(float(value))
-            attrs['value'] = float(value)
     if in_dygraph_mode():
         place = _current_expected_place()
         if force_cpu:
@@ -561,6 +551,16 @@ def fill_constant(shape, dtype, value, force_cpu=False, out=None, name=None):
         out.stop_gradient = True
         return out
     else:
+        attrs = {'force_cpu': force_cpu}
+        dtype = convert_dtype(dtype)
+        if not isinstance(value, Variable):
+            if dtype in ['uint8', 'int16', 'int32', 'int64']:
+                attrs['str_value'] = str(int(value))
+                attrs['value'] = int(value)
+            else:
+                attrs['str_value'] = str(float(value))
+                attrs['value'] = float(value)
         helper = LayerHelper("fill_constant", **locals())
         inputs = {}
         if isinstance(value, Variable):
......
@@ -1026,8 +1026,8 @@ def eye(num_rows, num_columns=None, dtype=None, name=None):
     _check_attr(num_rows, "num_rows")
     if dtype is None:
-        dtype = 'float32'
-    if not isinstance(dtype, core.VarDesc.VarType):
+        dtype = core.VarDesc.VarType.FP32
+    elif not isinstance(dtype, core.VarDesc.VarType):
         dtype = convert_np_dtype_to_dtype_(dtype)
     if num_columns is not None:
         _check_attr(num_columns, "num_columns")
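
The eye hunk above is the one change here that is not a pure code move: the default dtype is now assigned as the framework enum (core.VarDesc.VarType.FP32) instead of the string 'float32', and the follow-up `if` becomes `elif`, so the conversion only runs for user-supplied non-enum dtypes. A rough standalone sketch of that control-flow difference, using placeholder types rather than Paddle's:

    # Placeholder enum and converter standing in for core.VarDesc.VarType and
    # convert_np_dtype_to_dtype_; illustrative only.
    from enum import Enum

    class VarType(Enum):
        FP32 = 5

    def convert(dtype_str):
        return {'float32': VarType.FP32}[dtype_str]

    def normalize_dtype_before(dtype=None):
        if dtype is None:
            dtype = 'float32'
        if not isinstance(dtype, VarType):    # default path converts the string too
            dtype = convert(dtype)
        return dtype

    def normalize_dtype_after(dtype=None):
        if dtype is None:
            dtype = VarType.FP32              # enum assigned directly, no conversion
        elif not isinstance(dtype, VarType):  # convert only user-supplied strings
            dtype = convert(dtype)
        return dtype

    assert normalize_dtype_before() == normalize_dtype_after() == VarType.FP32
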
@@ -1181,14 +1181,6 @@ def arange(start=0, end=None, step=1, dtype=None, name=None):
         end = start
         start = 0
-    out_shape = None
-    if (
-        not isinstance(start, Variable)
-        and not isinstance(end, Variable)
-        and not isinstance(step, Variable)
-    ):
-        out_shape = [int(math.ceil((end - start) / step))]
     if not isinstance(dtype, core.VarDesc.VarType):
         dtype = convert_np_dtype_to_dtype_(dtype)
@@ -1220,6 +1212,13 @@ def arange(start=0, end=None, step=1, dtype=None, name=None):
             'range/arange',
         )
         helper = LayerHelper('range', **locals())
+        out_shape = None
+        if (
+            not isinstance(start, Variable)
+            and not isinstance(end, Variable)
+            and not isinstance(step, Variable)
+        ):
+            out_shape = [int(math.ceil((end - start) / step))]
         out = helper.create_variable_for_type_inference(dtype, shape=out_shape)
         helper.append_op(
             type='range',
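
The out_shape computation relocated in the arange hunks is only consumed by create_variable_for_type_inference, so the eager path no longer evaluates it. As a quick sanity check of the formula with concrete, illustrative numbers:

    import math

    # arange(start=0, end=10, step=3) yields [0, 3, 6, 9], i.e. 4 elements,
    # which matches the relocated shape formula.
    start, end, step = 0, 10, 3
    out_shape = [int(math.ceil((end - start) / step))]
    assert out_shape == [4]
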
......