diff --git a/paddle/fluid/operators/kldiv_loss_op.h b/paddle/fluid/operators/kldiv_loss_op.h
index 369fdb4872b4184d706a5264b58f70f63051fca1..857ecda303c2607b1b6fb9a5d2ec132b335d6c29 100644
--- a/paddle/fluid/operators/kldiv_loss_op.h
+++ b/paddle/fluid/operators/kldiv_loss_op.h
@@ -72,7 +72,11 @@ class KLDivLossKernel : public framework::OpKernel {
       loss_t.device(place) = output;
     } else if ("batchmean" == reduction) {
       auto output_sum = output.sum();
-      loss_t.device(place) = output_sum / output_sum.constant(n);
+      if (n > 0) {
+        loss_t.device(place) = output_sum / output_sum.constant(n);
+      } else {
+        loss_t.device(place) = output_sum;
+      }
     } else if ("mean" == reduction) {
       loss_t.device(place) = output.mean();
     } else if ("sum" == reduction) {
diff --git a/python/paddle/fluid/tests/unittests/test_kldiv_loss_op.py b/python/paddle/fluid/tests/unittests/test_kldiv_loss_op.py
index 8780727e4cb276a989a8d04d05c6419a4874e7f5..041fe4e9043d60852fcaab42bc233b63b39609ce 100644
--- a/python/paddle/fluid/tests/unittests/test_kldiv_loss_op.py
+++ b/python/paddle/fluid/tests/unittests/test_kldiv_loss_op.py
@@ -24,7 +24,10 @@ def kldiv_loss(x, target, reduction):
     loss = np.where(target >= 0, output, np.zeros_like(x))
 
     if reduction == "batchmean":
-        return loss.sum() / x.shape[0]
+        if len(x.shape) > 0:
+            return loss.sum() / x.shape[0]
+        else:
+            return loss.sum()
     if reduction == "mean":
         return loss.mean()
     if reduction == "sum":
@@ -93,6 +96,9 @@ class TestKLDivLossDygraph(unittest.TestCase):
     def test_kl_loss_batchmean(self):
         self.run_kl_loss('batchmean')
 
+    def test_kl_loss_batchmean_shape(self):
+        self.run_kl_loss('batchmean', ())
+
     def test_kl_loss_mean(self):
         self.run_kl_loss('mean')
 
diff --git a/python/paddle/hapi/model.py b/python/paddle/hapi/model.py
index 019e526dfbb3107aa730ad93b82faa9a98042e9b..2836a151ec35698a31f3814d573828853349a151 100644
--- a/python/paddle/hapi/model.py
+++ b/python/paddle/hapi/model.py
@@ -1868,8 +1868,13 @@ class Model(object):
             print(params_info)
 
         """
-
-        return summary(self.network, self._inputs, batch_size, dtype)
+        assert (input_size is not None or self._inputs is not None
+                ), "'input_size' or 'self._input' must be set"
+        if input_size is not None:
+            _input_size = input_size
+        else:
+            _input_size = self._inputs
+        return summary(self.network, _input_size, batch_size, dtype)
 
     def _verify_spec(self, specs, is_input=False):
         out_specs = []
diff --git a/python/paddle/hapi/model_summary.py b/python/paddle/hapi/model_summary.py
index ddafcbed8ec87ed01c21cd43053101236459dc90..d388ba62f2a244f84497810739e5fd6b50f669d2 100644
--- a/python/paddle/hapi/model_summary.py
+++ b/python/paddle/hapi/model_summary.py
@@ -12,7 +12,9 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import warnings
 import numpy as np
+import numbers
 
 import paddle
 import paddle.nn as nn
@@ -86,8 +88,10 @@ def summary(net, input_size, batch_size=None, dtypes=None):
     elif isinstance(input_size, list):
         _input_size = []
         for item in input_size:
+            if isinstance(item, int):
+                item = (item, )
             assert isinstance(item,
-                              (list, InputSpec)), 'When input_size is list, \
+                              (tuple, InputSpec)), 'When input_size is list, \
             expect item in input_size is a tuple or InputSpec, but got {}'.format(
                 type(item))
 
@@ -97,12 +101,19 @@ def summary(net, input_size, batch_size=None, dtypes=None):
                 batch_size = item.shape[0]
             else:
                 _input_size.append(item)
+    elif isinstance(input_size, int):
+        _input_size = (input_size, )
     else:
         _input_size = input_size
 
     if batch_size is None:
         batch_size = -1
 
+    if not paddle.in_dynamic_mode():
+        warnings.warn(
+            "Your model was created in static mode, this may not get correct summary information!"
+        )
+
     result, params_info = summary_string(net, _input_size, batch_size, dtypes)
     print(result)
 
@@ -117,16 +128,16 @@ def summary_string(model, input_size, batch_size=-1, dtypes=None):
 
     depth = len(list(model.sublayers()))
 
-    def register_hook(module):
-        def hook(module, input, output):
-            class_name = str(module.__class__).split(".")[-1].split("'")[0]
+    def register_hook(layer):
+        def hook(layer, input, output):
+            class_name = str(layer.__class__).split(".")[-1].split("'")[0]
 
             try:
-                module_idx = int(module._full_name.split('_')[-1])
+                layer_idx = int(layer._full_name.split('_')[-1])
             except:
-                module_idx = len(summary)
+                layer_idx = len(summary)
 
-            m_key = "%s-%i" % (class_name, module_idx + 1)
+            m_key = "%s-%i" % (class_name, layer_idx + 1)
             summary[m_key] = OrderedDict()
             summary[m_key]["input_shape"] = list(input[0].shape)
             summary[m_key]["input_shape"][0] = batch_size
@@ -138,23 +149,50 @@ def summary_string(model, input_size, batch_size=-1, dtypes=None):
                 summary[m_key]["output_shape"][0] = batch_size
 
             params = 0
-            if hasattr(module, "weight"):
-                params += np.prod(module.weight.shape)
-                summary[m_key]["trainable"] = module.weight.trainable or (
-                    not module.weight.stop_gradient)
-            if hasattr(module, "bias"):
-                params += np.prod(module.bias.shape)
+
+            if paddle.in_dynamic_mode():
+                layer_state_dict = layer._parameters
+            else:
+                layer_state_dict = layer.state_dict()
+
+            for k, v in layer_state_dict.items():
+                params += np.prod(v.shape)
+
+                try:
+                    if (getattr(getattr(layer, k), 'trainable')) and (
+                            not getattr(getattr(layer, k), 'stop_gradient')):
+                        summary[m_key]["trainable"] = True
+                    else:
+                        summary[m_key]["trainable"] = False
+                except:
+                    summary[m_key]["trainable"] = True
+
             summary[m_key]["nb_params"] = params
 
-        if (not isinstance(module, nn.Sequential) and
-                not isinstance(module, nn.LayerList) and
-                (not (module == model) or depth < 1)):
+        if (not isinstance(layer, nn.Sequential) and
+                not isinstance(layer, nn.LayerList) and
+                (not (layer == model) or depth < 1)):
+
+            hooks.append(layer.register_forward_post_hook(hook))
+
+    def _check_input_size(input_sizes):
+        for input_size in input_sizes:
+            for item in input_size:
+                if not isinstance(item, numbers.Number):
+                    raise TypeError(
+                        "Expected item in input size be a number, but got {}".
+                        format(type(item)))
 
-            hooks.append(module.register_forward_post_hook(hook))
+                if item <= 0:
+                    raise ValueError(
+                        "Expected item in input size greater than zero, but got {}".
+                        format(item))
 
     if isinstance(input_size, tuple):
         input_size = [input_size]
 
+    _check_input_size(input_size)
+
     x = [
         paddle.rand(
             [2] + list(in_size), dtype=dtype)
@@ -193,7 +231,12 @@ def summary_string(model, input_size, batch_size=-1, dtypes=None):
             "{0:,}".format(summary[layer]["nb_params"]), )
         total_params += summary[layer]["nb_params"]
 
-        total_output += np.prod(summary[layer]["output_shape"])
+        try:
+            total_output += np.prod(summary[layer]["output_shape"])
+        except:
+            for output_shape in summary[layer]["output_shape"]:
+                total_output += np.prod(output_shape)
+
         if "trainable" in summary[layer]:
             if summary[layer]["trainable"] == True:
                 trainable_params += summary[layer]["nb_params"]
diff --git a/python/paddle/nn/functional/loss.py b/python/paddle/nn/functional/loss.py
index 3d5894064c44cb72259472fc638d46b67c5703fc..6c139b0ddbbb996145e3a611839bf5e2e113f3cd 100644
--- a/python/paddle/nn/functional/loss.py
+++ b/python/paddle/nn/functional/loss.py
@@ -780,10 +780,10 @@ def kl_div(input, label, reduction='mean', name=None):
            input = np.random.uniform(-10, 10, shape).astype('float32')
            target = np.random.uniform(-10, 10, shape).astype('float32')
 
-           # 'batchmean' reduction, loss shape will be [N]
+           # 'batchmean' reduction, loss shape will be [1]
            pred_loss = F.kl_div(paddle.to_tensor(input),
                                 paddle.to_tensor(target), reduction='batchmean')
-           # shape=[5]
+           # shape=[1]
 
            # 'mean' reduction, loss shape will be [1]
            pred_loss = F.kl_div(paddle.to_tensor(input),
diff --git a/python/paddle/nn/layer/conv.py b/python/paddle/nn/layer/conv.py
index f3985781adb6267780cc974cef7dc3fa8ae46b38..a610693a0a46b7e21d2c6d83716a7bc029677583 100644
--- a/python/paddle/nn/layer/conv.py
+++ b/python/paddle/nn/layer/conv.py
@@ -1084,7 +1084,7 @@ class ConvTranspose3d(_ConvNd):
             bias_attr=bias_attr,
             data_format=data_format)
 
-    def forward(self, x, output_size):
+    def forward(self, x, output_size=None):
         if output_size is None:
             output_padding = self.output_padding
         else:
diff --git a/python/paddle/nn/layer/loss.py b/python/paddle/nn/layer/loss.py
index a60e615d5064bf4ef2229dd67193774030383888..271dc9b4e685ce06cdb12ccdcb6bb0704a5ef2a1 100644
--- a/python/paddle/nn/layer/loss.py
+++ b/python/paddle/nn/layer/loss.py
@@ -627,10 +627,13 @@ class KLDivLoss(fluid.dygraph.Layer):
     $$l(x, y) = y * (\log(y) - x)$$
 
     Parameters:
-        reduction (str, optional): Indicate how to average the loss,
-            the candicates are ``'none'`` | ``'mean'`` | ``'sum'``.
-            If :attr:`reduction` is ``'mean'``, the reduced mean loss is returned;
-            Default is ``'mean'``.
+        reduction (str, optional): Indicate how to average the loss,
+            the candidates are ``'none'`` | ``'batchmean'`` | ``'mean'`` | ``'sum'``.
+            If `reduction` is ``'mean'``, the reduced mean loss is returned;
+            If `reduction` is ``'batchmean'``, the sum loss divided by batch size is returned;
+            if `reduction` is ``'sum'``, the reduced sum loss is returned;
+            if `reduction` is ``'none'``, no reduction will be applied.
+            Default is ``'mean'``.
 
     Shape:
 
@@ -654,11 +657,11 @@ class KLDivLoss(fluid.dygraph.Layer):
            x = np.random.uniform(-10, 10, shape).astype('float32')
            target = np.random.uniform(-10, 10, shape).astype('float32')
 
-           # 'batchmean' reduction, loss shape will be [N]
+           # 'batchmean' reduction, loss shape will be [1]
            kldiv_criterion = nn.KLDivLoss(reduction='batchmean')
            pred_loss = kldiv_criterion(paddle.to_tensor(x),
                                        paddle.to_tensor(target))
-           # shape=[5]
+           # shape=[1]
 
            # 'mean' reduction, loss shape will be [1]
            kldiv_criterion = nn.KLDivLoss(reduction='mean')
@@ -684,7 +687,7 @@ class KLDivLoss(fluid.dygraph.Layer):
        self.reduction = reduction
 
    def forward(self, input, label):
-        out = paddle.nn.functional.kl_div(input, label, self.reduction)
+        out = F.kl_div(input, label, self.reduction)
        return out
 
 
diff --git a/python/paddle/tests/test_model.py b/python/paddle/tests/test_model.py
index 0267cd5dbc13c7ecc801eeb4d09d42fd84b7a6d0..5c4e98feaa686217bc78ad3915423593ad4fcdce 100644
--- a/python/paddle/tests/test_model.py
+++ b/python/paddle/tests/test_model.py
@@ -519,6 +519,28 @@ class TestModelFunction(unittest.TestCase):
         np.testing.assert_allclose(params_info['total_params'], gt_params)
         print(params_info)
 
+        model.summary(input_size=(20))
+        model.summary(input_size=[(20)])
+        model.summary(input_size=(20), batch_size=2)
+
+    def test_summary_nlp(self):
+        paddle.enable_static()
+        nlp_net = paddle.nn.GRU(input_size=2, hidden_size=3, num_layers=3)
+        paddle.summary(nlp_net, (1, 2))
+
+    def test_summary_error(self):
+        with self.assertRaises(TypeError):
+            nlp_net = paddle.nn.GRU(input_size=2, hidden_size=3, num_layers=3)
+            paddle.summary(nlp_net, (1, '2'))
+
+        with self.assertRaises(ValueError):
+            nlp_net = paddle.nn.GRU(input_size=2, hidden_size=3, num_layers=3)
+            paddle.summary(nlp_net, (-1, -1))
+
+        paddle.disable_static()
+        nlp_net = paddle.nn.GRU(input_size=2, hidden_size=3, num_layers=3)
+        paddle.summary(nlp_net, (1, 2))
+
     def test_export_deploy_model(self):
         for dynamic in [True, False]:
             fluid.enable_dygraph() if dynamic else None
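
Reviewer notes (not part of the patch). A minimal sketch of the new `'batchmean'` behavior, assuming a Paddle build that includes this change; the inputs are random and only the documented output shape matters here:

```python
import numpy as np
import paddle
import paddle.nn.functional as F

shape = (5, 20)
x = np.random.uniform(-10, 10, shape).astype('float32')
target = np.random.uniform(-10, 10, shape).astype('float32')

# With this change, 'batchmean' returns sum(loss) / batch_size, so the
# documented result shape is [1] rather than the previously documented [N].
loss = F.kl_div(paddle.to_tensor(x), paddle.to_tensor(target),
                reduction='batchmean')
print(loss.shape)  # expected: [1]
```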
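A similar sketch of the widened `input_size` handling in `paddle.summary` (a bare int, a tuple, or a list of tuples are all accepted now); the network below is arbitrary and only illustrates the call forms:

```python
import paddle
import paddle.nn as nn

net = nn.Sequential(nn.Linear(20, 10), nn.ReLU())

# Ints are normalized to 1-tuples inside summary(), so all three calls
# describe the same [batch, 20] input.
params_info = paddle.summary(net, (20, ))
params_info = paddle.summary(net, 20)
params_info = paddle.summary(net, [(20, )])
print(params_info)
```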
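And a sketch of the now-optional `output_size` argument on `ConvTranspose3d.forward`, which lets the layer be called with the input alone; the channel sizes and the expected shape are illustrative and assume the default stride and padding:

```python
import paddle
import paddle.nn as nn

x = paddle.rand([1, 4, 8, 8, 8])  # [N, C, D, H, W]
conv = nn.ConvTranspose3d(4, 6, kernel_size=3)

# forward(x) no longer requires output_size; the layer's default
# output_padding is used when it is omitted.
y = conv(x)
print(y.shape)  # expected: [1, 6, 10, 10, 10] (kernel 3, stride 1, no padding)
```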