diff --git a/python/paddle/fluid/contrib/memory_usage_calc.py b/python/paddle/fluid/contrib/memory_usage_calc.py index 09721e430b7e5bb6b9891d5272ca54475baf6157..baa14a573fcfdfa943af1e995f687c74e9fb4d07 100644 --- a/python/paddle/fluid/contrib/memory_usage_calc.py +++ b/python/paddle/fluid/contrib/memory_usage_calc.py @@ -70,23 +70,37 @@ def memory_usage(program, batch_size): if not isinstance(program, Program): raise TypeError( "Calculating Memory Usage requires Program as its Parameter." - "But you passed in %s" % (type(prgram))) + "But you passed in %s" % (type(program))) if batch_size <= 0: raise ValueError("The batch size need to be positive.") # Get the var_name list of first block and calculate total_memory = 0.0 - for var in six.itervalues(program.global_block().vars): - data_count = 1 - for x in var.shape: - if x == -1: - data_count *= batch_size - else: - data_count *= x - var_memory = data_count * dtype_to_size[var.dtype] - if DEBUG: - print("%s memory usage: %d" % (var.name, var_memory)) - total_memory += var_memory + processed_var_names = set() + for op in program.global_block().ops: + for var_name in op.output_arg_names: + if var_name in processed_var_names: + continue + processed_var_names.add(var_name) + var = program.global_block().vars[var_name] + if var.desc.type() != core.VarDesc.VarType.LOD_TENSOR: + continue + + data_count = 1 + neg_dim_count = 0 + for x in var.shape: + if x < 0: + if neg_dim_count >= 1: + raise ValueError("Var %s has more than one negtive dim." + % (var_name)) + neg_dim_count += 1 + data_count *= batch_size * (-x) + else: + data_count *= x + var_memory = data_count * dtype_to_size[var.dtype] + if DEBUG: + print("%s memory usage: %d" % (var.name, var_memory)) + total_memory += var_memory if DEBUG: print("total memory usage: %.2f" % (total_memory))