未验证 提交 03b1e4be 编写于 作者: F fengjiayi 提交者: GitHub

Merge pull request #13047 from JiayiFeng/dev_mem_analyse

refine memory usage calc
......@@ -70,23 +70,37 @@ def memory_usage(program, batch_size):
if not isinstance(program, Program):
raise TypeError(
"Calculating Memory Usage requires Program as its Parameter."
"But you passed in %s" % (type(prgram)))
"But you passed in %s" % (type(program)))
if batch_size <= 0:
raise ValueError("The batch size need to be positive.")
# Get the var_name list of first block and calculate
total_memory = 0.0
for var in six.itervalues(program.global_block().vars):
data_count = 1
for x in var.shape:
if x == -1:
data_count *= batch_size
else:
data_count *= x
var_memory = data_count * dtype_to_size[var.dtype]
if DEBUG:
print("%s memory usage: %d" % (var.name, var_memory))
total_memory += var_memory
processed_var_names = set()
for op in program.global_block().ops:
for var_name in op.output_arg_names:
if var_name in processed_var_names:
continue
processed_var_names.add(var_name)
var = program.global_block().vars[var_name]
if var.desc.type() != core.VarDesc.VarType.LOD_TENSOR:
continue
data_count = 1
neg_dim_count = 0
for x in var.shape:
if x < 0:
if neg_dim_count >= 1:
raise ValueError("Var %s has more than one negtive dim."
% (var_name))
neg_dim_count += 1
data_count *= batch_size * (-x)
else:
data_count *= x
var_memory = data_count * dtype_to_size[var.dtype]
if DEBUG:
print("%s memory usage: %d" % (var.name, var_memory))
total_memory += var_memory
if DEBUG:
print("total memory usage: %.2f" % (total_memory))
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册