Unverified commit 1712e212 · Author: Weilong Wu · Committer: GitHub

[Eager] polish several ops (#49612)

* [Eager] polish several ops

* rm useless code
Parent a6bd6957
@@ -68,9 +68,6 @@ def pairwise_distance(x, y, p=2.0, epsilon=1e-6, keepdim=False, name=None):
             #        [4.99999860, 4.99999860])
     """
-    check_type(p, 'porder', (float, int), 'PairwiseDistance')
-    check_type(epsilon, 'epsilon', (float), 'PairwiseDistance')
-    check_type(keepdim, 'keepdim', (bool), 'PairwiseDistance')
     if in_dygraph_mode():
         sub = _C_ops.subtract(x, y)
         # p_norm op has not used epsilon, so change it to the following.
@@ -82,6 +79,10 @@ def pairwise_distance(x, y, p=2.0, epsilon=1e-6, keepdim=False, name=None):
         return _C_ops.p_norm(sub, p, -1, 0.0, keepdim, False)
     else:
+        check_type(p, 'porder', (float, int), 'PairwiseDistance')
+        check_type(epsilon, 'epsilon', (float), 'PairwiseDistance')
+        check_type(keepdim, 'keepdim', (bool), 'PairwiseDistance')
         check_variable_and_dtype(
             x, 'x', ['float32', 'float64'], 'PairwiseDistance'
         )
......
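The hunk above moves the three `check_type` calls onto the static-graph branch only, so an eager-mode call goes straight to `_C_ops.subtract` and `_C_ops.p_norm`. A minimal usage sketch in dynamic-graph mode, with inputs and the expected output taken from the function's own docstring example (exact tensor values are approximate):

```python
import paddle
import paddle.nn.functional as F

x = paddle.to_tensor([[1.0, 3.0], [3.0, 5.0]], dtype='float64')
y = paddle.to_tensor([[5.0, 6.0], [7.0, 8.0]], dtype='float64')
# After this change, the eager path runs no Python-level check_type calls.
print(F.pairwise_distance(x, y, p=2.0))  # ~[4.99999860, 4.99999860]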
@@ -307,9 +307,6 @@ def stft(
         y1 = stft(x, n_fft=512, center=False, onesided=False)  # [8, 512, 372]
     """
-    check_variable_and_dtype(
-        x, 'x', ['float32', 'float64', 'complex64', 'complex128'], 'stft'
-    )
     x_rank = len(x.shape)
     assert x_rank in [
......
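Here the dtype whitelist check is dropped from the shared prologue of `stft` (the rest of this file's diff is collapsed above, so where the check lands in the static-graph path is not visible here). A hedged usage sketch, with shapes following the docstring example for this function:

```python
import paddle
from paddle.signal import stft

x = paddle.randn([8, 48000], dtype=paddle.float64)  # 8 signals, 48000 samples
y1 = stft(x, n_fft=512)
print(y1.shape)  # expected [8, 257, 376] per the docstring (onesided output)
```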
@@ -873,7 +873,7 @@ def ones(shape, dtype=None, name=None):
             #  [1. 1.]]
     """
     if dtype is None:
-        dtype = 'float32'
+        dtype = core.VarDesc.VarType.FP32
     return fill_constant(value=1.0, shape=shape, dtype=dtype, name=name)
......
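Defaulting to the `core.VarDesc.VarType.FP32` enum instead of the string `'float32'` skips a per-call string-to-enum conversion downstream; user-visible behavior is unchanged. A quick sanity check:

```python
import paddle

a = paddle.ones([2, 2])           # dtype still defaults to float32
b = paddle.ones([2, 2], 'int64')  # string dtypes are converted as before
print(a.dtype, b.dtype)           # paddle.float32 paddle.int64
```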
@@ -1541,14 +1541,6 @@ def flatten(x, start_axis=0, stop_axis=-1, name=None):
     if not (isinstance(x, Variable)):
         raise ValueError("The input x should be a Tensor")
-    if not paddle.in_dynamic_mode():
-        check_variable_and_dtype(
-            x,
-            'x',
-            ['float32', 'float64', 'int8', 'int16', 'int32', 'int64', 'uint8'],
-            'flatten',
-        )
     x_dim = len(x.shape)
     if x_dim == 0:
         if not (isinstance(start_axis, int)) or start_axis not in [0, -1]:
@@ -1586,6 +1578,12 @@ def flatten(x, start_axis=0, stop_axis=-1, name=None):
     if in_dygraph_mode():
         return _C_ops.flatten(x, start_axis, stop_axis)
     else:
+        check_variable_and_dtype(
+            x,
+            'x',
+            ['float32', 'float64', 'int8', 'int16', 'int32', 'int64', 'uint8'],
+            'flatten',
+        )
         helper = LayerHelper('flatten', **locals())
         out = helper.create_variable_for_type_inference(x.dtype)
         x_shape = helper.create_variable_for_type_inference(x.dtype)
......
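The two `flatten` hunks apply the same relocation: the dtype check now runs only on the static-graph branch, and an eager call dispatches directly to `_C_ops.flatten`. A small sanity check of the unchanged semantics:

```python
import paddle

x = paddle.rand([2, 3, 4])
y = paddle.flatten(x, start_axis=1, stop_axis=2)  # merge axes 1..2
print(y.shape)  # [2, 12]
```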
@@ -213,6 +213,8 @@ def scale(x, scale=1.0, bias=0.0, bias_after_scale=True, act=None, name=None):
     """
     if in_dygraph_mode():
+        if act is None:
+            return _C_ops.scale(x, scale, float(bias), bias_after_scale)
         out = _C_ops.scale(x, scale, float(bias), bias_after_scale)
         return dygraph_utils._append_activation_in_dygraph(out, act)
     else:
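With `act=None` (the default and by far the most common case), `scale` now returns the C++ op result immediately instead of routing through `dygraph_utils._append_activation_in_dygraph`. Usage is unaffected:

```python
import paddle

x = paddle.to_tensor([1.0, 2.0, 3.0])
# bias_after_scale=True, so out = scale * x + bias
print(paddle.scale(x, scale=2.0, bias=1.0))  # [3., 5., 7.]
```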
@@ -495,6 +497,9 @@ def _elementwise_op_in_dygraph(
         OP_NAMEMAPPING[op_name] if not is_inplace(op_name) else op_name,
     )
     out = op(x, y)
-    return dygraph_utils._append_activation_in_dygraph(
-        out, act, use_mkldnn=use_mkldnn
-    )
+    if act is None:
+        return out
+    else:
+        return dygraph_utils._append_activation_in_dygraph(
+            out, act, use_mkldnn=use_mkldnn
+        )
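This is the same early-return pattern as in `scale`, applied to the shared elementwise dispatcher. A standalone sketch of the idea (all names below are hypothetical, not Paddle APIs):

```python
# Skip the activation helper entirely when no activation was requested.
def op_with_optional_act(op, x, y, act=None):
    out = op(x, y)
    if act is None:
        return out  # fast path: no extra Python dispatch per call
    return act(out)

print(op_with_optional_act(lambda a, b: a + b, 2, 3))           # 5
print(op_with_optional_act(lambda a, b: a - b, 2, 3, act=abs))  # 1
```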
@@ -4209,7 +4214,6 @@ def lerp(x, y, weight, name=None):
     """
     if in_dygraph_mode():
-        check_type(weight, 'weight', (float, paddle.Tensor, Variable), 'lerp')
         if isinstance(weight, float):
             weight = paddle.to_tensor(weight, dtype=x.dtype)
......
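Dropping the `check_type` on `weight` removes one isinstance chain from every eager `lerp` call; a float weight is still promoted to a tensor by the `isinstance` branch that remains. A small sketch (inputs chosen here for illustration):

```python
import paddle

x = paddle.arange(1.0, 5.0, dtype='float32')   # [1., 2., 3., 4.]
y = paddle.full([4], 10.0, dtype='float32')
# out = x + weight * (y - x)
print(paddle.lerp(x, y, 0.5))  # [5.5, 6.0, 6.5, 7.0]
```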
@@ -789,8 +789,8 @@ def randint(low=0, high=None, shape=[1], dtype=None, name=None):
         high = low
         low = 0
     if dtype is None:
-        dtype = 'int64'
-    if not isinstance(dtype, core.VarDesc.VarType):
+        dtype = core.VarDesc.VarType.INT64
+    elif not isinstance(dtype, core.VarDesc.VarType):
         dtype = convert_np_dtype_to_dtype_(dtype)
     if in_dygraph_mode():
......
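As with `ones`, the default dtype is now the `VarType` enum directly, and turning the second `if` into `elif` avoids re-checking a dtype that was just assigned. The default remains int64:

```python
import paddle

out = paddle.randint(low=0, high=10, shape=[2, 3])
print(out.dtype)  # paddle.int64
```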