未验证 提交 5949f2d7 编写于 作者: W Weilong Wu 提交者: GitHub

[Eager] optimize same python api logic (#49473)

* [Eager] optimize same python api logic

* optimize full api

* optimize logic

* optimize logic
上级 1221307b
......@@ -1440,12 +1440,12 @@ def fft_c2c(x, n, axis, norm, forward, name):
_check_fft_n(n)
s = [n]
x = _resize_fft_input(x, s, axes)
op_type = 'fft_c2c'
check_variable_and_dtype(x, 'x', ['complex64', 'complex128'], op_type)
if in_dygraph_mode():
out = _C_ops.fft_c2c(x, axes, norm, forward)
else:
op_type = 'fft_c2c'
check_variable_and_dtype(x, 'x', ['complex64', 'complex128'], op_type)
inputs = {
'X': [x],
}
......@@ -1472,12 +1472,13 @@ def fft_r2c(x, n, axis, norm, forward, onesided, name):
_check_fft_n(n)
s = [n]
x = _resize_fft_input(x, s, axes)
op_type = 'fft_r2c'
check_variable_and_dtype(x, 'x', ['float16', 'float32', 'float64'], op_type)
if in_dygraph_mode():
out = _C_ops.fft_r2c(x, axes, norm, forward, onesided)
else:
op_type = 'fft_r2c'
check_variable_and_dtype(
x, 'x', ['float16', 'float32', 'float64'], op_type
)
inputs = {
'X': [x],
}
......@@ -1513,8 +1514,6 @@ def fft_c2r(x, n, axis, norm, forward, name):
_check_fft_n(n)
s = [n // 2 + 1]
x = _resize_fft_input(x, s, axes)
op_type = 'fft_c2r'
check_variable_and_dtype(x, 'x', ['complex64', 'complex128'], op_type)
if in_dygraph_mode():
if n is not None:
......@@ -1522,6 +1521,8 @@ def fft_c2r(x, n, axis, norm, forward, name):
else:
out = _C_ops.fft_c2r(x, axes, norm, forward, 0)
else:
op_type = 'fft_c2r'
check_variable_and_dtype(x, 'x', ['complex64', 'complex128'], op_type)
inputs = {
'X': [x],
}
......@@ -1572,12 +1573,12 @@ def fftn_c2c(x, s, axes, norm, forward, name):
if s is not None:
x = _resize_fft_input(x, s, axes)
op_type = 'fft_c2c'
check_variable_and_dtype(x, 'x', ['complex64', 'complex128'], op_type)
if in_dygraph_mode():
out = _C_ops.fft_c2c(x, axes, norm, forward)
else:
op_type = 'fft_c2c'
check_variable_and_dtype(x, 'x', ['complex64', 'complex128'], op_type)
inputs = {
'X': [x],
}
......@@ -1623,12 +1624,13 @@ def fftn_r2c(x, s, axes, norm, forward, onesided, name):
if s is not None:
x = _resize_fft_input(x, s, axes)
op_type = 'fft_r2c'
check_variable_and_dtype(x, 'x', ['float16', 'float32', 'float64'], op_type)
if in_dygraph_mode():
out = _C_ops.fft_r2c(x, axes, norm, forward, onesided)
else:
op_type = 'fft_r2c'
check_variable_and_dtype(
x, 'x', ['float16', 'float32', 'float64'], op_type
)
inputs = {
'X': [x],
}
......@@ -1686,15 +1688,14 @@ def fftn_c2r(x, s, axes, norm, forward, name):
fft_input_shape[-1] = fft_input_shape[-1] // 2 + 1
x = _resize_fft_input(x, fft_input_shape, axes)
op_type = 'fft_c2r'
check_variable_and_dtype(x, 'x', ['complex64', 'complex128'], op_type)
if in_dygraph_mode():
if s is not None:
out = _C_ops.fft_c2r(x, axes, norm, forward, s[-1])
else:
out = _C_ops.fft_c2r(x, axes, norm, forward, 0)
else:
op_type = 'fft_c2r'
check_variable_and_dtype(x, 'x', ['complex64', 'complex128'], op_type)
inputs = {
'X': [x],
}
......
......@@ -530,16 +530,6 @@ def fill_constant(shape, dtype, value, force_cpu=False, out=None, name=None):
data5 = fluid.layers.fill_constant(shape=[2,1], value=val, dtype='float32') #data5=[[2.0],[2.0]]
"""
attrs = {'force_cpu': force_cpu}
dtype = convert_dtype(dtype)
if not isinstance(value, Variable):
if dtype in ['uint8', 'int16', 'int32', 'int64']:
attrs['str_value'] = str(int(value))
attrs['value'] = int(value)
else:
attrs['str_value'] = str(float(value))
attrs['value'] = float(value)
if in_dygraph_mode():
place = _current_expected_place()
if force_cpu:
......@@ -561,6 +551,16 @@ def fill_constant(shape, dtype, value, force_cpu=False, out=None, name=None):
out.stop_gradient = True
return out
else:
attrs = {'force_cpu': force_cpu}
dtype = convert_dtype(dtype)
if not isinstance(value, Variable):
if dtype in ['uint8', 'int16', 'int32', 'int64']:
attrs['str_value'] = str(int(value))
attrs['value'] = int(value)
else:
attrs['str_value'] = str(float(value))
attrs['value'] = float(value)
helper = LayerHelper("fill_constant", **locals())
inputs = {}
if isinstance(value, Variable):
......
......@@ -1026,8 +1026,8 @@ def eye(num_rows, num_columns=None, dtype=None, name=None):
_check_attr(num_rows, "num_rows")
if dtype is None:
dtype = 'float32'
if not isinstance(dtype, core.VarDesc.VarType):
dtype = core.VarDesc.VarType.FP32
elif not isinstance(dtype, core.VarDesc.VarType):
dtype = convert_np_dtype_to_dtype_(dtype)
if num_columns is not None:
_check_attr(num_columns, "num_columns")
......@@ -1181,14 +1181,6 @@ def arange(start=0, end=None, step=1, dtype=None, name=None):
end = start
start = 0
out_shape = None
if (
not isinstance(start, Variable)
and not isinstance(end, Variable)
and not isinstance(step, Variable)
):
out_shape = [int(math.ceil((end - start) / step))]
if not isinstance(dtype, core.VarDesc.VarType):
dtype = convert_np_dtype_to_dtype_(dtype)
......@@ -1220,6 +1212,13 @@ def arange(start=0, end=None, step=1, dtype=None, name=None):
'range/arange',
)
helper = LayerHelper('range', **locals())
out_shape = None
if (
not isinstance(start, Variable)
and not isinstance(end, Variable)
and not isinstance(step, Variable)
):
out_shape = [int(math.ceil((end - start) / step))]
out = helper.create_variable_for_type_inference(dtype, shape=out_shape)
helper.append_op(
type='range',
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册