Unverified · Commit 91af6df8 authored by Vvsmile, committed by GitHub

[Clean Fluid API] Remove API: log (#47966)

* replace log with paddle.log

* replace log with paddle.nn.functional.log

* fix the code style of remove_log

* fix the ImportError of log

* fix the error introduced by the modification of dist_transformer.py

* fix the Static-Check error
Parent: 7c903ae7
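For reference, a minimal before/after sketch of the call-site migration this commit performs (the sample values come from the removed docstring further down; it assumes Paddle 2.x, where `paddle.log` is the public element-wise natural-log API):

```python
import paddle

x = paddle.to_tensor([[2.0, 3.0, 4.0], [7.0, 8.0, 9.0]], dtype='float32')

# Before: res = fluid.layers.log(x)  (or nn.log(x) inside the framework) -- removed by this commit
# After:  call the paddle.* API directly
res = paddle.log(x)  # element-wise natural logarithm
# [[0.693147, 1.09861, 1.38629], [1.94591, 2.07944, 2.19722]]
```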
@@ -242,7 +242,7 @@ class Normal(distribution.Distribution):
         )
         return paddle.add(
             0.5 + zero_tmp,
-            0.5 * math.log(2 * math.pi) + nn.log((self.scale + zero_tmp)),
+            0.5 * math.log(2 * math.pi) + paddle.log((self.scale + zero_tmp)),
             name=name,
         )
@@ -260,7 +260,7 @@ class Normal(distribution.Distribution):
         value = self._check_values_dtype_in_probs(self.loc, value)
         var = self.scale * self.scale
-        log_scale = nn.log(self.scale)
+        log_scale = paddle.log(self.scale)
         return paddle.subtract(
             -1.0 * ((value - self.loc) * (value - self.loc)) / (2.0 * var),
             log_scale + math.log(math.sqrt(2.0 * math.pi)),
@@ -331,5 +331,5 @@ class Normal(distribution.Distribution):
         t1 = (self.loc - other.loc) / other.scale
         t1 = t1 * t1
         return paddle.add(
-            0.5 * var_ratio, 0.5 * (t1 - 1.0 - nn.log(var_ratio)), name=name
+            0.5 * var_ratio, 0.5 * (t1 - 1.0 - paddle.log(var_ratio)), name=name
         )
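As a reading aid (not part of the patch), the `Normal` hunks above implement the standard closed forms for entropy and KL divergence, with `var_ratio` and `t1` as computed in the code:

```latex
% Differential entropy of N(\mu, \sigma), broadcast against zero_tmp in the code:
H = \tfrac{1}{2} + \tfrac{1}{2}\ln(2\pi) + \ln\sigma

% KL(N(\mu_1,\sigma_1) \,\|\, N(\mu_2,\sigma_2)) with
% var\_ratio = \sigma_1^2 / \sigma_2^2 and t_1 = (\mu_1 - \mu_2)^2 / \sigma_2^2:
\mathrm{KL} = \tfrac{1}{2}\, var\_ratio + \tfrac{1}{2}\bigl(t_1 - 1 - \ln var\_ratio\bigr)
```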
@@ -27,6 +27,8 @@ from paddle.fluid.layers import (
     nn,
     tensor,
 )
+import paddle
 from paddle.tensor import random
@@ -216,7 +218,7 @@ class Uniform(distribution.Distribution):
         if in_dygraph_mode():
             lb = _C_ops.cast(lb_bool, value.dtype)
             ub = _C_ops.cast(ub_bool, value.dtype)
-            return nn.log(lb * ub) - nn.log(self.high - self.low)
+            return paddle.log(lb * ub) - paddle.log(self.high - self.low)
         if _in_legacy_dygraph():
             lb = _legacy_C_ops.cast(
@@ -225,7 +227,7 @@ class Uniform(distribution.Distribution):
             ub = _legacy_C_ops.cast(
                 ub_bool, 'in_dtype', ub_bool.dtype, 'out_dtype', value.dtype
             )
-            return nn.log(lb * ub) - nn.log(self.high - self.low)
+            return paddle.log(lb * ub) - paddle.log(self.high - self.low)
         name = self.name + '_log_prob'
         lb_bool = self.low < value
@@ -233,7 +235,7 @@ class Uniform(distribution.Distribution):
         lb = tensor.cast(lb_bool, dtype=value.dtype)
         ub = tensor.cast(ub_bool, dtype=value.dtype)
         return paddle.subtract(
-            nn.log(lb * ub), nn.log(self.high - self.low), name=name
+            paddle.log(lb * ub), paddle.log(self.high - self.low), name=name
         )
     def probs(self, value):
@@ -286,4 +288,4 @@ class Uniform(distribution.Distribution):
         """
         name = self.name + '_entropy'
-        return nn.log(self.high - self.low, name=name)
+        return paddle.log(self.high - self.low, name=name)
@@ -264,7 +264,7 @@ class Uniform(Distribution):
         ub_bool = control_flow.less_than(value, self.high)
         lb = tensor.cast(lb_bool, dtype=value.dtype)
         ub = tensor.cast(ub_bool, dtype=value.dtype)
-        return nn.log(lb * ub) - nn.log(self.high - self.low)
+        return paddle.log(lb * ub) - paddle.log(self.high - self.low)
     def entropy(self):
         """Shannon entropy in nats.
@@ -273,7 +273,7 @@ class Uniform(Distribution):
             Variable: Shannon entropy of uniform distribution.The data type is float32.
         """
-        return nn.log(self.high - self.low)
+        return paddle.log(self.high - self.low)
 class Normal(Distribution):
@@ -412,7 +412,9 @@ class Normal(Distribution):
             self.loc + self.scale, batch_shape, self.loc.dtype, 0.0
         )
         return (
-            0.5 + 0.5 * math.log(2 * math.pi) + nn.log((self.scale + zero_tmp))
+            0.5
+            + 0.5 * math.log(2 * math.pi)
+            + paddle.log((self.scale + zero_tmp))
         )
     def log_prob(self, value):
@@ -430,7 +432,7 @@ class Normal(Distribution):
         )
         var = self.scale * self.scale
-        log_scale = nn.log(self.scale)
+        log_scale = paddle.log(self.scale)
         return (
             -1.0 * ((value - self.loc) * (value - self.loc)) / (2.0 * var)
             - log_scale
@@ -454,7 +456,7 @@ class Normal(Distribution):
         var_ratio = var_ratio * var_ratio
         t1 = (self.loc - other.loc) / other.scale
         t1 = t1 * t1
-        return 0.5 * (var_ratio + t1 - 1.0 - nn.log(var_ratio))
+        return 0.5 * (var_ratio + t1 - 1.0 - paddle.log(var_ratio))
 class Categorical(Distribution):
@@ -542,7 +544,8 @@ class Categorical(Distribution):
         other_z = paddle.sum(other_e_logits, axis=-1, keepdim=True)
         prob = e_logits / z
         kl = paddle.sum(
-            prob * (logits - nn.log(z) - other_logits + nn.log(other_z)),
+            prob
+            * (logits - paddle.log(z) - other_logits + paddle.log(other_z)),
             axis=-1,
             keepdim=True,
         )
@@ -562,7 +565,7 @@ class Categorical(Distribution):
         prob = e_logits / z
         entropy = -1.0 * paddle.sum(
-            prob * (logits - nn.log(z)), axis=-1, keepdim=True
+            prob * (logits - paddle.log(z)), axis=-1, keepdim=True
         )
         return entropy
@@ -687,7 +690,7 @@ class MultivariateNormalDiag(Distribution):
         """
         entropy = 0.5 * (
             self.scale.shape[0] * (1.0 + math.log(2 * math.pi))
-            + nn.log(self._det(self.scale))
+            + paddle.log(self._det(self.scale))
         )
         return entropy
@@ -710,7 +713,9 @@ class MultivariateNormalDiag(Distribution):
         )
         tri_matmul = nn.matmul(loc_matmul_cov, (other.loc - self.loc))
         k = list(self.scale.shape)[0]
-        ln_cov = nn.log(self._det(other.scale)) - nn.log(self._det(self.scale))
+        ln_cov = paddle.log(self._det(other.scale)) - paddle.log(
+            self._det(self.scale)
+        )
         kl = 0.5 * (tr_cov_matmul + tri_matmul - k + ln_cov)
         return kl
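Similarly (again just a reading aid), the `Categorical` hunks compute the softmax-based KL divergence and entropy directly from logits; writing p_i = e^{l_i}/z with z = Σ_j e^{l_j}, and primes for the other distribution:

```latex
\mathrm{KL}(p \,\|\, q) = \sum_i p_i \bigl[(l_i - \ln z) - (l'_i - \ln z')\bigr]
\qquad
H(p) = -\sum_i p_i \,(l_i - \ln z)
```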
@@ -97,7 +97,6 @@ __all__ = [
     'resize_trilinear',
     'resize_nearest',
     'relu',
-    'log',
     'unique',
     'unique_with_counts',
     'elementwise_add',
@@ -5246,47 +5245,6 @@ def resize_nearest(
     )
-def log(x, name=None):
-    r"""
-    Calculates the natural log of the given input tensor, element-wise.
-    .. math::
-        Out = \\ln(x)
-    Args:
-        x (Tensor): Input Tensor. Must be one of the following types: float32, float64.
-        name (str|None): The default value is None. Normally there is no need for user to set this property. For more information, please refer to :ref:`api_guide_Name`
-    Returns:
-        Tensor: The natural log of the input Tensor computed element-wise.
-    Examples:
-        .. code-block:: python
-            import paddle
-            x = [[2,3,4], [7,8,9]]
-            x = paddle.to_tensor(x, dtype='float32')
-            res = paddle.log(x)
-            # [[0.693147, 1.09861, 1.38629], [1.94591, 2.07944, 2.19722]]
-    """
-    if in_dygraph_mode():
-        return _C_ops.log(x)
-    if _in_legacy_dygraph():
-        return _legacy_C_ops.log(x)
-    check_variable_and_dtype(x, 'x', ['float32', 'float64'], "log")
-    inputs = {'X': [x]}
-    helper = LayerHelper('log', **locals())
-    dtype = helper.input_dtype(input_param_name='x')
-    out = helper.create_variable_for_type_inference(dtype)
-    helper.append_op(type="log", inputs={"X": x}, outputs={"Out": out})
-    return out
 @deprecated(since="2.0.0", update_to="paddle.nn.functional.relu")
 def relu(x, name=None):
     """
......
@@ -1304,7 +1304,7 @@ class BeamSearchDecoder(Decoder):
             self.noend_mask_tensor, "float64"
         )
-        step_log_probs = nn.log(nn.softmax(logits))
+        step_log_probs = paddle.log(nn.softmax(logits))
         step_log_probs = self._mask_probs(step_log_probs, beam_state.finished)
         log_probs = nn.elementwise_add(
             x=step_log_probs, y=beam_state.log_probs, axis=0
@@ -3529,8 +3529,8 @@ def beam_search(
                 name='probs', shape=[None, 10000], dtype='float32')
             topk_scores, topk_indices = fluid.layers.topk(probs, k=beam_size)
             accu_scores = fluid.layers.elementwise_add(
-                x=fluid.layers.log(x=topk_scores),
-                y=fluid.layers.reshape(pre_scores, shape=[-1]),
+                x=paddle.log(x=topk_scores),
+                y=paddle.reshape(pre_scores, shape=[-1]),
                 axis=0)
             selected_ids, selected_scores = fluid.layers.beam_search(
                 pre_ids=pre_ids,
......
@@ -1837,7 +1837,7 @@ def fast_decode(
                 input=layers.softmax(logits), k=beam_size
             )
             accu_scores = layers.elementwise_add(
-                x=layers.log(topk_scores),
+                x=paddle.log(topk_scores),
                 y=paddle.reshape(pre_scores, shape=[-1]),
                 axis=0,
             )
......
@@ -435,9 +435,7 @@ class BaseModel(fluid.dygraph.Layer):
             cell_outputs = self._split_batch_beams(step_input)
             cell_outputs = self.fc(cell_outputs)
-            step_log_probs = fluid.layers.log(
-                fluid.layers.softmax(cell_outputs)
-            )
+            step_log_probs = paddle.log(fluid.layers.softmax(cell_outputs))
             noend_array = [-self.kinf] * self.tar_vocab_size
             noend_array[self.beam_end_token] = 0
             noend_mask_tensor = to_variable(
......
@@ -329,13 +329,11 @@ def bmn_loss_func(
         coef_0 = 0.5 * ratio / (ratio - 1)
         coef_1 = 0.5 * ratio
         epsilon = 0.000001
-        # temp = fluid.layers.log(pred_score + epsilon)
-        loss_pos = paddle.multiply(
-            fluid.layers.log(pred_score + epsilon), pmask
-        )
+        # temp = paddle.log(pred_score + epsilon)
+        loss_pos = paddle.multiply(paddle.log(pred_score + epsilon), pmask)
         loss_pos = coef_1 * fluid.layers.reduce_mean(loss_pos)
         loss_neg = paddle.multiply(
-            fluid.layers.log(1.0 - pred_score + epsilon), (1.0 - pmask)
+            paddle.log(1.0 - pred_score + epsilon), (1.0 - pmask)
         )
         loss_neg = coef_0 * fluid.layers.reduce_mean(loss_neg)
         loss = -1 * (loss_pos + loss_neg)
@@ -400,12 +398,10 @@ def bmn_loss_func(
         coef_0 = 0.5 * ratio / (ratio - 1)
         coef_1 = 0.5 * ratio
         epsilon = 0.000001
-        loss_pos = paddle.multiply(
-            fluid.layers.log(pred_score + epsilon), pmask
-        )
+        loss_pos = paddle.multiply(paddle.log(pred_score + epsilon), pmask)
         loss_pos = coef_1 * paddle.sum(loss_pos)
         loss_neg = paddle.multiply(
-            fluid.layers.log(1.0 - pred_score + epsilon), nmask
+            paddle.log(1.0 - pred_score + epsilon), nmask
        )
         loss_neg = coef_0 * paddle.sum(loss_neg)
         loss = -1 * (loss_pos + loss_neg) / num_entries
......
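The `bmn_loss_func` hunks above only swap the log primitive; the quantity computed is a class-rebalanced binary cross-entropy of roughly this form (the first variant averages, the second sums with a `num_entries` normalization; ε = 1e-6, `pmask` is the positive mask):

```latex
\mathcal{L} = -\Bigl[\, c_1 \cdot \mathrm{mean}\bigl(pmask \cdot \ln(p + \varepsilon)\bigr)
            + c_0 \cdot \mathrm{mean}\bigl((1 - pmask) \cdot \ln(1 - p + \varepsilon)\bigr) \Bigr]
```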
@@ -122,7 +122,7 @@ def train(args, place, to_static):
             mask = to_variable(_mask)
             mask.stop_gradient = True
-            loss_probs = fluid.layers.log(loss_probs)
+            loss_probs = paddle.log(loss_probs)
             loss_probs = paddle.multiply(loss_probs, mask)
             loss_probs = paddle.sum(loss_probs, axis=-1)
......
@@ -845,7 +845,7 @@ class Transformer(Layer):
             )
             caches = map_structure(split_batch_beams, caches)
             step_log_probs = split_batch_beams(
-                fluid.layers.log(fluid.layers.softmax(logits))
+                paddle.log(fluid.layers.softmax(logits))
             )
             step_log_probs = mask_probs(
......
@@ -2417,8 +2417,8 @@ class TestLog(TestActivation):
                 name="in2", shape=[11, 17], append_batch_size=False, dtype="int64"
             )
-            self.assertRaises(TypeError, fluid.layers.log, in1)
-            self.assertRaises(TypeError, fluid.layers.log, in2)
+            self.assertRaises(TypeError, paddle.log, in1)
+            self.assertRaises(TypeError, paddle.log, in2)
 class TestLog_ZeroDim(TestLog):
......
@@ -314,7 +314,7 @@ class TestBeamSearchOpError(unittest.TestCase):
             probs = fluid.data(name='probs', shape=[10000], dtype='float32')
             topk_scores, topk_indices = fluid.layers.topk(probs, k=4)
             accu_scores = fluid.layers.elementwise_add(
-                x=fluid.layers.log(x=topk_scores),
+                x=paddle.log(x=topk_scores),
                 y=paddle.reshape(pre_scores, shape=[-1]),
                 axis=0,
             )
......
@@ -71,7 +71,7 @@ class TestImperativeMnist(unittest.TestCase):
                 dy_mask = fluid.dygraph.base.to_variable(mask)
                 dy_mask.stop_gradient = True
-                loss_probs = fluid.layers.log(loss_probs)
+                loss_probs = paddle.log(loss_probs)
                 loss_probs = fluid.layers.elementwise_mul(loss_probs, dy_mask)
                 loss_probs = paddle.sum(loss_probs, axis=-1)
@@ -139,7 +139,7 @@ class TestImperativeMnist(unittest.TestCase):
             st_loss_probs = policy(st_state)
-            st_loss_probs = fluid.layers.log(st_loss_probs)
+            st_loss_probs = paddle.log(st_loss_probs)
             st_loss_probs = fluid.layers.elementwise_mul(st_loss_probs, st_mask)
             st_loss_probs = paddle.sum(st_loss_probs, axis=-1)
......