提交 fb08e163 编写于 作者: F fengjiayi

refine memory usage calc

上级 a22309af
......@@ -70,17 +70,26 @@ def memory_usage(program, batch_size):
if not isinstance(program, Program):
raise TypeError(
"Calculating Memory Usage requires Program as its Parameter."
"But you passed in %s" % (type(prgram)))
"But you passed in %s" % (type(program)))
if batch_size <= 0:
raise ValueError("The batch size need to be positive.")
# Get the var_name list of first block and calculate
total_memory = 0.0
for var in six.itervalues(program.global_block().vars):
processed_var_names = set()
for op in program.global_block().ops:
for var_name in op.output_arg_names:
if var_name in processed_var_names:
continue
processed_var_names.add(var_name)
var = program.global_block().vars[var_name]
if var.desc.type() != core.VarDesc.VarType.LOD_TENSOR:
continue
data_count = 1
for x in var.shape:
if x == -1:
data_count *= batch_size
if x < 0:
data_count *= batch_size * (-x)
else:
data_count *= x
var_memory = data_count * dtype_to_size[var.dtype]
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册