提交 fb08e163 编写于 作者: F fengjiayi

refine memory usage calc

上级 a22309af
...@@ -70,23 +70,32 @@ def memory_usage(program, batch_size): ...@@ -70,23 +70,32 @@ def memory_usage(program, batch_size):
if not isinstance(program, Program): if not isinstance(program, Program):
raise TypeError( raise TypeError(
"Calculating Memory Usage requires Program as its Parameter." "Calculating Memory Usage requires Program as its Parameter."
"But you passed in %s" % (type(prgram))) "But you passed in %s" % (type(program)))
if batch_size <= 0: if batch_size <= 0:
raise ValueError("The batch size need to be positive.") raise ValueError("The batch size need to be positive.")
# Get the var_name list of first block and calculate # Get the var_name list of first block and calculate
total_memory = 0.0 total_memory = 0.0
for var in six.itervalues(program.global_block().vars): processed_var_names = set()
data_count = 1 for op in program.global_block().ops:
for x in var.shape: for var_name in op.output_arg_names:
if x == -1: if var_name in processed_var_names:
data_count *= batch_size continue
else: processed_var_names.add(var_name)
data_count *= x var = program.global_block().vars[var_name]
var_memory = data_count * dtype_to_size[var.dtype] if var.desc.type() != core.VarDesc.VarType.LOD_TENSOR:
if DEBUG: continue
print("%s memory usage: %d" % (var.name, var_memory))
total_memory += var_memory data_count = 1
for x in var.shape:
if x < 0:
data_count *= batch_size * (-x)
else:
data_count *= x
var_memory = data_count * dtype_to_size[var.dtype]
if DEBUG:
print("%s memory usage: %d" % (var.name, var_memory))
total_memory += var_memory
if DEBUG: if DEBUG:
print("total memory usage: %.2f" % (total_memory)) print("total memory usage: %.2f" % (total_memory))
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册