未验证 提交 1712e212 编写于 作者: W Weilong Wu 提交者: GitHub

[Eager] polish several ops (#49612)

* [Eager] polish several ops

* rm useless code
上级 a6bd6957
...@@ -68,9 +68,6 @@ def pairwise_distance(x, y, p=2.0, epsilon=1e-6, keepdim=False, name=None): ...@@ -68,9 +68,6 @@ def pairwise_distance(x, y, p=2.0, epsilon=1e-6, keepdim=False, name=None):
# [4.99999860, 4.99999860]) # [4.99999860, 4.99999860])
""" """
check_type(p, 'porder', (float, int), 'PairwiseDistance')
check_type(epsilon, 'epsilon', (float), 'PairwiseDistance')
check_type(keepdim, 'keepdim', (bool), 'PairwiseDistance')
if in_dygraph_mode(): if in_dygraph_mode():
sub = _C_ops.subtract(x, y) sub = _C_ops.subtract(x, y)
# p_norm op has not uesd epsilon, so change it to the following. # p_norm op has not uesd epsilon, so change it to the following.
...@@ -82,6 +79,10 @@ def pairwise_distance(x, y, p=2.0, epsilon=1e-6, keepdim=False, name=None): ...@@ -82,6 +79,10 @@ def pairwise_distance(x, y, p=2.0, epsilon=1e-6, keepdim=False, name=None):
return _C_ops.p_norm(sub, p, -1, 0.0, keepdim, False) return _C_ops.p_norm(sub, p, -1, 0.0, keepdim, False)
else: else:
check_type(p, 'porder', (float, int), 'PairwiseDistance')
check_type(epsilon, 'epsilon', (float), 'PairwiseDistance')
check_type(keepdim, 'keepdim', (bool), 'PairwiseDistance')
check_variable_and_dtype( check_variable_and_dtype(
x, 'x', ['float32', 'float64'], 'PairwiseDistance' x, 'x', ['float32', 'float64'], 'PairwiseDistance'
) )
......
...@@ -307,9 +307,6 @@ def stft( ...@@ -307,9 +307,6 @@ def stft(
y1 = stft(x, n_fft=512, center=False, onesided=False) # [8, 512, 372] y1 = stft(x, n_fft=512, center=False, onesided=False) # [8, 512, 372]
""" """
check_variable_and_dtype(
x, 'x', ['float32', 'float64', 'complex64', 'complex128'], 'stft'
)
x_rank = len(x.shape) x_rank = len(x.shape)
assert x_rank in [ assert x_rank in [
......
...@@ -873,7 +873,7 @@ def ones(shape, dtype=None, name=None): ...@@ -873,7 +873,7 @@ def ones(shape, dtype=None, name=None):
# [1. 1.]] # [1. 1.]]
""" """
if dtype is None: if dtype is None:
dtype = 'float32' dtype = core.VarDesc.VarType.FP32
return fill_constant(value=1.0, shape=shape, dtype=dtype, name=name) return fill_constant(value=1.0, shape=shape, dtype=dtype, name=name)
......
...@@ -1541,14 +1541,6 @@ def flatten(x, start_axis=0, stop_axis=-1, name=None): ...@@ -1541,14 +1541,6 @@ def flatten(x, start_axis=0, stop_axis=-1, name=None):
if not (isinstance(x, Variable)): if not (isinstance(x, Variable)):
raise ValueError("The input x should be a Tensor") raise ValueError("The input x should be a Tensor")
if not paddle.in_dynamic_mode():
check_variable_and_dtype(
x,
'x',
['float32', 'float64', 'int8', 'int16', 'int32', 'int64', 'uint8'],
'flatten',
)
x_dim = len(x.shape) x_dim = len(x.shape)
if x_dim == 0: if x_dim == 0:
if not (isinstance(start_axis, int)) or start_axis not in [0, -1]: if not (isinstance(start_axis, int)) or start_axis not in [0, -1]:
...@@ -1586,6 +1578,12 @@ def flatten(x, start_axis=0, stop_axis=-1, name=None): ...@@ -1586,6 +1578,12 @@ def flatten(x, start_axis=0, stop_axis=-1, name=None):
if in_dygraph_mode(): if in_dygraph_mode():
return _C_ops.flatten(x, start_axis, stop_axis) return _C_ops.flatten(x, start_axis, stop_axis)
else: else:
check_variable_and_dtype(
x,
'x',
['float32', 'float64', 'int8', 'int16', 'int32', 'int64', 'uint8'],
'flatten',
)
helper = LayerHelper('flatten', **locals()) helper = LayerHelper('flatten', **locals())
out = helper.create_variable_for_type_inference(x.dtype) out = helper.create_variable_for_type_inference(x.dtype)
x_shape = helper.create_variable_for_type_inference(x.dtype) x_shape = helper.create_variable_for_type_inference(x.dtype)
......
...@@ -213,6 +213,8 @@ def scale(x, scale=1.0, bias=0.0, bias_after_scale=True, act=None, name=None): ...@@ -213,6 +213,8 @@ def scale(x, scale=1.0, bias=0.0, bias_after_scale=True, act=None, name=None):
""" """
if in_dygraph_mode(): if in_dygraph_mode():
if act is None:
return _C_ops.scale(x, scale, float(bias), bias_after_scale)
out = _C_ops.scale(x, scale, float(bias), bias_after_scale) out = _C_ops.scale(x, scale, float(bias), bias_after_scale)
return dygraph_utils._append_activation_in_dygraph(out, act) return dygraph_utils._append_activation_in_dygraph(out, act)
else: else:
...@@ -495,9 +497,12 @@ def _elementwise_op_in_dygraph( ...@@ -495,9 +497,12 @@ def _elementwise_op_in_dygraph(
OP_NAMEMAPPING[op_name] if not is_inplace(op_name) else op_name, OP_NAMEMAPPING[op_name] if not is_inplace(op_name) else op_name,
) )
out = op(x, y) out = op(x, y)
return dygraph_utils._append_activation_in_dygraph( if act is None:
out, act, use_mkldnn=use_mkldnn return out
) else:
return dygraph_utils._append_activation_in_dygraph(
out, act, use_mkldnn=use_mkldnn
)
def _elementwise_op(helper): def _elementwise_op(helper):
...@@ -4209,7 +4214,6 @@ def lerp(x, y, weight, name=None): ...@@ -4209,7 +4214,6 @@ def lerp(x, y, weight, name=None):
""" """
if in_dygraph_mode(): if in_dygraph_mode():
check_type(weight, 'weight', (float, paddle.Tensor, Variable), 'lerp')
if isinstance(weight, float): if isinstance(weight, float):
weight = paddle.to_tensor(weight, dtype=x.dtype) weight = paddle.to_tensor(weight, dtype=x.dtype)
......
...@@ -789,8 +789,8 @@ def randint(low=0, high=None, shape=[1], dtype=None, name=None): ...@@ -789,8 +789,8 @@ def randint(low=0, high=None, shape=[1], dtype=None, name=None):
high = low high = low
low = 0 low = 0
if dtype is None: if dtype is None:
dtype = 'int64' dtype = core.VarDesc.VarType.INT64
if not isinstance(dtype, core.VarDesc.VarType): elif not isinstance(dtype, core.VarDesc.VarType):
dtype = convert_np_dtype_to_dtype_(dtype) dtype = convert_np_dtype_to_dtype_(dtype)
if in_dygraph_mode(): if in_dygraph_mode():
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册