From fb08e163cf11e5eec1f44b33168ae7df438d9d32 Mon Sep 17 00:00:00 2001 From: fengjiayi Date: Wed, 29 Aug 2018 16:50:22 +0800 Subject: [PATCH] refine memory usage calc --- .../paddle/fluid/contrib/memory_usage_calc.py | 33 ++++++++++++------- 1 file changed, 21 insertions(+), 12 deletions(-) diff --git a/python/paddle/fluid/contrib/memory_usage_calc.py b/python/paddle/fluid/contrib/memory_usage_calc.py index 09721e430b..5ffdca82bd 100644 --- a/python/paddle/fluid/contrib/memory_usage_calc.py +++ b/python/paddle/fluid/contrib/memory_usage_calc.py @@ -70,23 +70,32 @@ def memory_usage(program, batch_size): if not isinstance(program, Program): raise TypeError( "Calculating Memory Usage requires Program as its Parameter." - "But you passed in %s" % (type(prgram))) + "But you passed in %s" % (type(program))) if batch_size <= 0: raise ValueError("The batch size need to be positive.") # Get the var_name list of first block and calculate total_memory = 0.0 - for var in six.itervalues(program.global_block().vars): - data_count = 1 - for x in var.shape: - if x == -1: - data_count *= batch_size - else: - data_count *= x - var_memory = data_count * dtype_to_size[var.dtype] - if DEBUG: - print("%s memory usage: %d" % (var.name, var_memory)) - total_memory += var_memory + processed_var_names = set() + for op in program.global_block().ops: + for var_name in op.output_arg_names: + if var_name in processed_var_names: + continue + processed_var_names.add(var_name) + var = program.global_block().vars[var_name] + if var.desc.type() != core.VarDesc.VarType.LOD_TENSOR: + continue + + data_count = 1 + for x in var.shape: + if x < 0: + data_count *= batch_size * (-x) + else: + data_count *= x + var_memory = data_count * dtype_to_size[var.dtype] + if DEBUG: + print("%s memory usage: %d" % (var.name, var_memory)) + total_memory += var_memory if DEBUG: print("total memory usage: %.2f" % (total_memory)) -- GitLab