Unverified commit 4d7d6612, authored by LielinJiang, committed by GitHub

Fix kl and summary bug (#27132)

* fix summary rnn

* fix kl_div bug when input shape is [1] and reduction is batchmean
Parent 13804ed8
@@ -72,7 +72,11 @@ class KLDivLossKernel : public framework::OpKernel<T> {
       loss_t.device(place) = output;
     } else if ("batchmean" == reduction) {
       auto output_sum = output.sum();
-      loss_t.device(place) = output_sum / output_sum.constant(n);
+      if (n > 0) {
+        loss_t.device(place) = output_sum / output_sum.constant(n);
+      } else {
+        loss_t.device(place) = output_sum;
+      }
     } else if ("mean" == reduction) {
       loss_t.device(place) = output.mean();
     } else if ("sum" == reduction) {
......
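For reference, here are the guarded 'batchmean' semantics as a minimal NumPy sketch (the function and variable names are illustrative, not the kernel's):

```python
import numpy as np

def batchmean_kl(x, target, n):
    # Elementwise KL term; entries with target < 0 are zeroed,
    # matching the kernel's select behavior.
    output = target * (np.log(target) - x)
    loss = np.where(target >= 0, output, np.zeros_like(x))
    # The fix: only divide by the batch size n when it is positive;
    # for an empty batch dimension, fall back to the plain sum
    # instead of dividing by zero.
    return loss.sum() / n if n > 0 else loss.sum()

x = np.random.uniform(-10, 10, (5, 20)).astype('float32')
t = np.random.uniform(0.1, 10, (5, 20)).astype('float32')
print(batchmean_kl(x, t, n=x.shape[0]))  # sum / batch size
```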
@@ -24,7 +24,10 @@ def kldiv_loss(x, target, reduction):
     loss = np.where(target >= 0, output, np.zeros_like(x))

     if reduction == "batchmean":
-        return loss.sum() / x.shape[0]
+        if len(x.shape) > 0:
+            return loss.sum() / x.shape[0]
+        else:
+            return loss.sum()
     if reduction == "mean":
         return loss.mean()
     if reduction == "sum":
@@ -93,6 +96,9 @@ class TestKLDivLossDygraph(unittest.TestCase):
     def test_kl_loss_batchmean(self):
         self.run_kl_loss('batchmean')

+    def test_kl_loss_batchmean_shape(self):
+        self.run_kl_loss('batchmean', ())
+
     def test_kl_loss_mean(self):
         self.run_kl_loss('mean')
......
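A quick way to see the fix in action — a sketch assuming a 0-D input maps to an empty batch dimension on the kernel side, as the new test_kl_loss_batchmean_shape case exercises:

```python
import numpy as np
import paddle
import paddle.nn.functional as F

x = np.random.uniform(1.0, 5.0, ()).astype('float32')       # 0-D input
target = np.random.uniform(1.0, 5.0, ()).astype('float32')

loss = F.kl_div(paddle.to_tensor(x), paddle.to_tensor(target),
                reduction='batchmean')
# Previously this divided by zero; now it falls back to the plain sum.
print(loss.numpy())
```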
@@ -12,7 +12,9 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+import warnings
 import numpy as np
+import numbers

 import paddle
 import paddle.nn as nn
@@ -107,6 +109,11 @@ def summary(net, input_size, batch_size=None, dtypes=None):
     if batch_size is None:
         batch_size = -1

+    if not paddle.in_dynamic_mode():
+        warnings.warn(
+            "Your model was created in static mode, this may not get correct summary information!"
+        )
+
     result, params_info = summary_string(net, _input_size, batch_size, dtypes)
     print(result)
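The warning only fires under static graph mode; a minimal sketch of when it surfaces, mirroring the new test_summary_nlp test below:

```python
import paddle

paddle.enable_static()
nlp_net = paddle.nn.GRU(input_size=2, hidden_size=3, num_layers=3)
# Warns: "Your model was created in static mode, this may not get
# correct summary information!"
paddle.summary(nlp_net, (1, 2))
paddle.disable_static()
```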
@@ -121,16 +128,16 @@ def summary_string(model, input_size, batch_size=-1, dtypes=None):
     depth = len(list(model.sublayers()))

-    def register_hook(module):
-        def hook(module, input, output):
-            class_name = str(module.__class__).split(".")[-1].split("'")[0]
+    def register_hook(layer):
+        def hook(layer, input, output):
+            class_name = str(layer.__class__).split(".")[-1].split("'")[0]

             try:
-                module_idx = int(module._full_name.split('_')[-1])
+                layer_idx = int(layer._full_name.split('_')[-1])
             except:
-                module_idx = len(summary)
+                layer_idx = len(summary)

-            m_key = "%s-%i" % (class_name, module_idx + 1)
+            m_key = "%s-%i" % (class_name, layer_idx + 1)
             summary[m_key] = OrderedDict()
             summary[m_key]["input_shape"] = list(input[0].shape)
             summary[m_key]["input_shape"][0] = batch_size
@@ -142,23 +149,50 @@ def summary_string(model, input_size, batch_size=-1, dtypes=None):
             summary[m_key]["output_shape"][0] = batch_size

             params = 0
-            if hasattr(module, "weight") and hasattr(module.weight, "shape"):
-                params += np.prod(module.weight.shape)
-                summary[m_key]["trainable"] = module.weight.trainable or (
-                    not module.weight.stop_gradient)
-            if hasattr(module, "bias") and hasattr(module.bias, "shape"):
-                params += np.prod(module.bias.shape)
+
+            if paddle.in_dynamic_mode():
+                layer_state_dict = layer._parameters
+            else:
+                layer_state_dict = layer.state_dict()
+
+            for k, v in layer_state_dict.items():
+                params += np.prod(v.shape)
+
+                try:
+                    if (getattr(getattr(layer, k), 'trainable')) and (
+                            not getattr(getattr(layer, k), 'stop_gradient')):
+                        summary[m_key]["trainable"] = True
+                    else:
+                        summary[m_key]["trainable"] = False
+                except:
+                    summary[m_key]["trainable"] = True
+
             summary[m_key]["nb_params"] = params

-        if (not isinstance(module, nn.Sequential) and
-                not isinstance(module, nn.LayerList) and
-                (not (module == model) or depth < 1)):
-            hooks.append(module.register_forward_post_hook(hook))
+        if (not isinstance(layer, nn.Sequential) and
+                not isinstance(layer, nn.LayerList) and
+                (not (layer == model) or depth < 1)):
+            hooks.append(layer.register_forward_post_hook(hook))
+
+    def _check_input_size(input_sizes):
+        for input_size in input_sizes:
+            for item in input_size:
+                if not isinstance(item, numbers.Number):
+                    raise TypeError(
+                        "Expected item in input size be a number, but got {}".
+                        format(type(item)))
+
+                if item <= 0:
+                    raise ValueError(
+                        "Expected item in input size greater than zero, but got {}".
+                        format(item))

     if isinstance(input_size, tuple):
         input_size = [input_size]

+    _check_input_size(input_size)
+
     x = [
         paddle.rand(
             [2] + list(in_size), dtype=dtype)
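With _check_input_size in place, malformed input_size values fail fast before any forward pass; a usage sketch (toy Linear model, assumed dygraph mode):

```python
import paddle

net = paddle.nn.Linear(4, 2)
paddle.summary(net, (4,))            # OK: every item is a positive number

try:
    paddle.summary(net, (4, '2'))    # non-numeric item
except TypeError as e:
    print(e)

try:
    paddle.summary(net, (-1, 4))     # non-positive item
except ValueError as e:
    print(e)
```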
@@ -197,7 +231,12 @@ def summary_string(model, input_size, batch_size=-1, dtypes=None):
             "{0:,}".format(summary[layer]["nb_params"]), )
         total_params += summary[layer]["nb_params"]

-        total_output += np.prod(summary[layer]["output_shape"])
+        try:
+            total_output += np.prod(summary[layer]["output_shape"])
+        except:
+            for output_shape in summary[layer]["output_shape"]:
+                total_output += np.prod(output_shape)
+
         if "trainable" in summary[layer]:
             if summary[layer]["trainable"] == True:
                 trainable_params += summary[layer]["nb_params"]
......
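Counting through layer._parameters (in dygraph) instead of only weight and bias means multi-parameter layers such as RNNs are covered too; the per-layer accounting is roughly equivalent to this sketch:

```python
import numpy as np
import paddle

layer = paddle.nn.Linear(4, 2)
# Sum the element counts of every registered parameter of this layer.
params = sum(int(np.prod(p.shape)) for p in layer._parameters.values())
# A parameter counts as trainable only if it is marked trainable
# and gradients are not stopped.
trainable = all(p.trainable and not p.stop_gradient
                for p in layer._parameters.values())
print(params, trainable)  # 10 True (4*2 weights + 2 biases)
```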
@@ -780,10 +780,10 @@ def kl_div(input, label, reduction='mean', name=None):
             input = np.random.uniform(-10, 10, shape).astype('float32')
             target = np.random.uniform(-10, 10, shape).astype('float32')

-            # 'batchmean' reduction, loss shape will be [N]
+            # 'batchmean' reduction, loss shape will be [1]
             pred_loss = F.kl_div(paddle.to_tensor(input),
                                  paddle.to_tensor(target), reduction='batchmean')
-            # shape=[5]
+            # shape=[1]

             # 'mean' reduction, loss shape will be [1]
             pred_loss = F.kl_div(paddle.to_tensor(input),
......
@@ -627,10 +627,13 @@ class KLDivLoss(fluid.dygraph.Layer):
     $$l(x, y) = y * (\log(y) - x)$$

     Parameters:
         reduction (str, optional): Indicate how to average the loss,
-             the candidates are ``'none'`` | ``'mean'`` | ``'sum'``.
-             If :attr:`reduction` is ``'mean'``, the reduced mean loss is returned;
-             Default is ``'mean'``.
+             the candidates are ``'none'`` | ``'batchmean'`` | ``'mean'`` | ``'sum'``.
+             If `reduction` is ``'mean'``, the reduced mean loss is returned;
+             if `reduction` is ``'batchmean'``, the sum loss divided by batch size is returned;
+             if `reduction` is ``'sum'``, the reduced sum loss is returned;
+             if `reduction` is ``'none'``, no reduction will be applied.
+             Default is ``'mean'``.

     Shape:
@@ -654,11 +657,11 @@ class KLDivLoss(fluid.dygraph.Layer):
             x = np.random.uniform(-10, 10, shape).astype('float32')
             target = np.random.uniform(-10, 10, shape).astype('float32')

-            # 'batchmean' reduction, loss shape will be [N]
+            # 'batchmean' reduction, loss shape will be [1]
            kldiv_criterion = nn.KLDivLoss(reduction='batchmean')
             pred_loss = kldiv_criterion(paddle.to_tensor(x),
                                         paddle.to_tensor(target))
-            # shape=[5]
+            # shape=[1]

             # 'mean' reduction, loss shape will be [1]
             kldiv_criterion = nn.KLDivLoss(reduction='mean')
@@ -684,7 +687,7 @@ class KLDivLoss(fluid.dygraph.Layer):
         self.reduction = reduction

     def forward(self, input, label):
-        out = paddle.nn.functional.kl_div(input, label, self.reduction)
+        out = F.kl_div(input, label, self.reduction)
         return out
......
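A sanity-check sketch of the corrected docs: 'batchmean' yields a single value equal to the summed loss divided by the batch size, so its shape is [1], not [N].

```python
import numpy as np
import paddle
import paddle.nn as nn

shape = (5, 20)
x = paddle.to_tensor(np.random.uniform(-10, 10, shape).astype('float32'))
target = paddle.to_tensor(np.random.uniform(0.1, 10, shape).astype('float32'))

batchmean = nn.KLDivLoss(reduction='batchmean')(x, target)
summed = nn.KLDivLoss(reduction='sum')(x, target)
# batchmean == sum / batch size, returned as a single value
assert np.allclose(batchmean.numpy(), summed.numpy() / shape[0])
```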
@@ -523,6 +523,24 @@ class TestModelFunction(unittest.TestCase):
         model.summary(input_size=[(20)])
         model.summary(input_size=(20), batch_size=2)

+    def test_summary_nlp(self):
+        paddle.enable_static()
+        nlp_net = paddle.nn.GRU(input_size=2, hidden_size=3, num_layers=3)
+        paddle.summary(nlp_net, (1, 2))
+
+    def test_summary_error(self):
+        with self.assertRaises(TypeError):
+            nlp_net = paddle.nn.GRU(input_size=2, hidden_size=3, num_layers=3)
+            paddle.summary(nlp_net, (1, '2'))
+
+        with self.assertRaises(ValueError):
+            nlp_net = paddle.nn.GRU(input_size=2, hidden_size=3, num_layers=3)
+            paddle.summary(nlp_net, (-1, -1))
+
+        paddle.disable_static()
+        nlp_net = paddle.nn.GRU(input_size=2, hidden_size=3, num_layers=3)
+        paddle.summary(nlp_net, (1, 2))
+
     def test_export_deploy_model(self):
         for dynamic in [True, False]:
             fluid.enable_dygraph() if dynamic else None