未验证 提交 1797f3db 编写于 作者: Q QI JUN 提交者: GitHub

Refine memory optimization transpiler (#7394)

* add an update-graph method to the memory optimization transpiler to avoid rebuilding the graph every time

* clean code

* reset the var desc on a cache hit
上级 0e544775
......@@ -236,6 +236,9 @@ class Variable(object):
__repr__ = __str__
# Rebind this Variable's `desc` attribute to the object passed as `input`.
# NOTE(review): `input` shadows the Python builtin of the same name; it is
# kept because the parameter name is part of the caller-visible signature.
def set_desc(self, input):
self.desc = input
# Read-only view of the persistable flag: forwards to `self.desc.persistable()`.
@property
def persistable(self):
return self.desc.persistable()
......
......@@ -3,6 +3,17 @@ import framework
from framework import Program, default_main_program, Parameter, Variable
import backward
from backward import _rename_arg_
from . import core
# Lookup table from a `core.DataType` enum value to its size.
# NOTE(review): the values look like per-element widths in bytes
# (FP16 -> 2, FP32 -> 4, FP64 -> 8) -- confirm the unit before relying on it.
dtype_to_size = {
core.DataType.FP16: 2,
core.DataType.FP32: 4,
core.DataType.FP64: 8,
core.DataType.INT16: 2,
core.DataType.INT32: 4,
core.DataType.INT64: 8,
core.DataType.BOOL: 1
}
class ControlFlowGraph(object):
......@@ -28,18 +39,33 @@ class ControlFlowGraph(object):
block_size = program_desc.num_blocks()
# TODO(qijun) handle Program with if/while operators
self.global_block = program_desc.block(0)
self.op_size = self.global_block.op_size()
self.global_block_desc = program_desc.block(0)
self.op_size = self.global_block_desc.op_size()
op_node_connections = [(i, i + 1) for i in range(self.op_size - 1)]
self._add_connections(op_node_connections)
self.ops = [self.global_block.op(i) for i in range(self.op_size)]
self.ops = [self.global_block_desc.op(i) for i in range(self.op_size)]
for i in range(self.op_size):
self._uses[i].update(self.ops[i].input_arg_names())
self._defs[i].update(self.ops[i].output_arg_names())
def _update_graph(self, old_name, new_name, begin_idx=0):
    """Rename a variable in the cached dataflow sets without rebuilding them.

    For every op index in ``[begin_idx, self.op_size)``, any occurrence of
    ``old_name`` in the per-op ``_uses``, ``_defs``, ``_live_in`` and
    ``_live_out`` sets is replaced by ``new_name``. Ops before ``begin_idx``
    are left untouched.

    Args:
        old_name: variable name to replace.
        new_name: variable name that takes its place.
        begin_idx: index of the first op whose sets are updated.
    """
    for i in range(begin_idx, self.op_size):
        # Apply the same rename to each of the four per-op sets. The earlier
        # hand-unrolled version added ``new_name`` to ``_live_out[i]`` when
        # renaming inside ``_live_in[i]``, which dropped the variable from
        # live-in and could add a spurious live-out entry.
        for per_op_sets in (self._uses, self._defs, self._live_in,
                            self._live_out):
            names = per_op_sets[i]
            if old_name in names:
                names.remove(old_name)
                names.add(new_name)
def _reach_fixed_point(self, live_in, live_out):
if len(live_in) != len(self._live_in):
return False
......@@ -79,30 +105,45 @@ class ControlFlowGraph(object):
self.pool = []
for i in range(self.op_size):
if self.pool:
out_pair = [(x, self.global_block.var(str(x)).shape())
out_pair = [(x, self.global_block_desc.var(str(x)).shape())
for x in self._defs[i]]
for x, x_shape in out_pair:
if not self.global_block_desc.var(str(x)).persistable():
for index, cache_pair in enumerate(self.pool):
cache_var = cache_pair[0]
cache_shape = cache_pair[1]
if x_shape == cache_shape:
x_dtype = self.global_block_desc.var(str(
x)).dtype()
cache_dtype = self.global_block_desc.var(
str(cache_var)).dtype()
# TODO(qijun): actually, we should compare dtype_to_size[x_dtype]
# and dtype_to_size[cache_dtype]
if x_dtype == cache_dtype:
print(
"Hit Cache !!!! cache pool index is %d, var name is %s, cached var name is %s, var shape is %s "
% (index, x, cache_var, str(cache_shape)))
%
(index, x, cache_var, str(cache_shape)))
self.pool.pop(index)
_rename_arg_(self.ops, x, cache_var, begin_idx=i)
self._dataflow_analyze()
_rename_arg_(
self.ops, x, cache_var, begin_idx=i)
self._program.current_block().var(str(
x)).desc = self.global_block_desc.var(
str(cache_var))
self._update_graph(
x, cache_var, begin_idx=i)
break
in_diff, out_diff = self._get_diff(self._live_in[i],
self._live_out[i])
can_optimize = filter(
lambda x: not self.global_block.var(str(x)).persistable(),
lambda x: not self.global_block_desc.var(str(x)).persistable(),
in_diff)
if can_optimize:
for var_name in can_optimize:
self.pool.append((
var_name, self.global_block.var(str(var_name)).shape()))
self.pool.append(
(var_name,
self.global_block_desc.var(str(var_name)).shape()))
# Accessor: returns the object stored in `self._program`.
def get_program(self):
return self._program
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册