Unverified commit 18a7e13f authored by Aurelius84, committed by GitHub

[D2SCinn]Fix self.infer_program always build cinn pass without cache (#49696)

* [D2SCinn]Fix self.infer_program always build cinn pass without cache

* fix infer op size
Parent c79befa0
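
In short: before this change, the infer_program property re-ran _build_infer_program (and therefore build strategies such as the CINN pass) on every access, and the forward/backward split helpers re-read the forward op count from that rebuilt program. The diff below instead records the original op count per precision mode in a small ProgramInfo helper when each lazily-initialized _infer_*_program is first constructed, so the built program is cached and the recorded size is reused. The following is a minimal, self-contained sketch of that caching pattern; FakeProgram, apply_build_passes, and the simplified LazyInitialized are illustrative stand-ins, not Paddle's actual implementation.

class LazyInitialized:
    # Non-data descriptor: computes the value on first access, then stores it
    # on the instance so later accesses skip the expensive build entirely.
    def __init__(self, function):
        self.function = function

    def __set_name__(self, owner, name):
        self.name = name

    def __get__(self, instance, owner):
        value = self.function(instance)
        setattr(instance, self.name, value)
        return value


class ProgramInfo:
    """Records the op count of the original infer program per precision mode."""

    def __init__(self, mode='infer'):
        assert mode in ['train', 'infer']
        self.mode = mode
        self.op_size = {'fp32': -1, 'amp': -1, 'fp16': -1}


class FakeProgram:
    """Hypothetical stand-in for a Paddle Program with a known op count."""

    def __init__(self, num_ops):
        self.num_ops = num_ops


def apply_build_passes(program):
    # Hypothetical stand-in for _build_infer_program: build passes (e.g. CINN)
    # may add ops, so the op count must be recorded before this runs.
    return FakeProgram(program.num_ops + 3)


class Layer:
    def __init__(self):
        self._infer_info = ProgramInfo(mode='infer')

    @LazyInitialized
    def _infer_program(self):
        program = FakeProgram(num_ops=10)
        # Record the original op count, then build once; the result is cached.
        self._infer_info.op_size['fp32'] = program.num_ops
        return apply_build_passes(program)

    def forward_end_op_index(self):
        # Training helpers reuse the recorded size instead of rebuilding the
        # infer program just to read its op count (the old behaviour).
        index = self._infer_info.op_size['fp32']
        assert index >= 0
        return index


layer = Layer()
built = layer._infer_program                 # passes run exactly once here
print(layer._infer_program is built)         # True: cached on the instance
print(layer.forward_end_op_index())          # 10, from the ProgramInfo record
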
......@@ -129,6 +129,21 @@ def _change_is_test_status(program, is_test):
return program
class ProgramInfo:
"""
A helper class to record Program information
"""
def __init__(self, mode='infer'):
self.op_size = {
'fp32': -1,
'amp': -1,
'fp16': -1,
}
assert mode in ['train', 'infer']
self.mode = mode
class PartialProgramLayer:
"""
PartialProgramLayer wraps all the ops from layers decorated by `@to_static`
......@@ -167,6 +182,7 @@ class PartialProgramLayer:
self._cuda_graph_pool_id = 0
# Set default mode to train
self.training = True
self._infer_info = ProgramInfo(mode='infer')
custom_white_list, custom_black_list = None, None
tracer = framework._dygraph_tracer()
......@@ -251,7 +267,8 @@ class PartialProgramLayer:
@switch_to_static_graph
def _create_forward_backward_train_program(self):
whole_program = self._train_program
forward_end_op_index = self._infer_program.desc.block(0).op_size()
forward_end_op_index = self._infer_info.op_size['fp32']
assert forward_end_op_index >= 0
return self._get_forward_backward_program_form(
whole_program, forward_end_op_index
)
......@@ -259,7 +276,8 @@ class PartialProgramLayer:
@switch_to_static_graph
def _create_forward_backward_train_amp_program(self):
whole_program = self._train_amp_program
forward_end_op_index = self._infer_amp_program.desc.block(0).op_size()
forward_end_op_index = self._infer_info.op_size['amp']
assert forward_end_op_index >= 0
return self._get_forward_backward_program_form(
whole_program, forward_end_op_index
)
......@@ -267,9 +285,8 @@ class PartialProgramLayer:
@switch_to_static_graph
def _create_forward_backward_train_pure_fp16_program(self):
whole_program = self._train_pure_fp16_program
forward_end_op_index = self._infer_pure_fp16_program.desc.block(
0
).op_size()
forward_end_op_index = self._infer_info.op_size['fp16']
assert forward_end_op_index >= 0
return self._get_forward_backward_program_form(
whole_program, forward_end_op_index
)
......@@ -280,7 +297,11 @@ class PartialProgramLayer:
@LazyInitialized
def _infer_program(self):
return self._create_program(is_infer_mode=True)
program = self._create_program(is_infer_mode=True)
self._infer_info.op_size['fp32'] = program.desc.block(0).op_size()
return self._build_infer_program(
program, self._infer_info.op_size['fp32']
)
@LazyInitialized
def _train_amp_program(self):
......@@ -288,7 +309,11 @@ class PartialProgramLayer:
@LazyInitialized
def _infer_amp_program(self):
return self._create_amp_program(is_infer_mode=True)
program = self._create_amp_program(is_infer_mode=True)
self._infer_info.op_size['amp'] = program.desc.block(0).op_size()
return self._build_infer_program(
program, self._infer_info.op_size['amp']
)
@LazyInitialized
def _train_pure_fp16_program(self):
......@@ -296,7 +321,11 @@ class PartialProgramLayer:
@LazyInitialized
def _infer_pure_fp16_program(self):
return self._create_pure_fp16_program(is_infer_mode=True)
program = self._create_pure_fp16_program(is_infer_mode=True)
self._infer_info.op_size['fp16'] = program.desc.block(0).op_size()
return self._build_infer_program(
program, self._infer_info.op_size['fp16']
)
@LazyInitialized
def _train_forward_backward_program(self):
......@@ -317,62 +346,6 @@ class PartialProgramLayer:
program = self._create_forward_backward_train_pure_fp16_program()
return program
@property
def whole_program(self):
if self.training:
if _in_amp_guard():
return self._train_amp_program
elif _in_pure_fp16_guard():
return self._train_pure_fp16_program
else:
return self._train_program
else:
if _in_amp_guard():
return self._infer_amp_program
elif _in_pure_fp16_guard():
return self._infer_pure_fp16_program
else:
return self._infer_program
@property
def forward_program(self):
if self.training:
if _in_amp_guard():
program = self._train_amp_forward_backward_program
return program[0]
elif _in_pure_fp16_guard():
program = self._train_pure_fp16_forward_backward_program
return program[0]
else:
program = self._train_forward_backward_program
return program[0]
else:
return self.infer_program
@property
def backward_program(self):
if self.training:
if _in_amp_guard():
program = self._train_amp_forward_backward_program
return program[1]
elif _in_pure_fp16_guard():
program = self._train_pure_fp16_forward_backward_program
return program[1]
else:
program = self._train_forward_backward_program
return program[1]
else:
"""
Can't just return paddle.static.Program(), because self.backward_program is a property:
whenever we call it, a temporary Program() object is created and garbage-collected
immediately after the following line in PartialProgramLayer.__call__ executes.
>>> self.backward_program.desc.block(0),
By the time RunProgramAPI is reached, the backward_program address may already be invalid.
"""
return self._empty_backward_program_for_eval
@LazyInitialized
def _train_program_id(self):
program_id = _hash_with_id(self._train_program, self)
......@@ -430,24 +403,43 @@ class PartialProgramLayer:
@LazyInitialized
def _out_grad_names(self):
"""
Parse Out@GRAD names from the original train and infer programs.
"""
names = []
fwd_end_op_index = self._get_end_op_index()
origin_infer_program = self._create_program(is_infer_mode=True)
origin_train_program = self._train_program
fwd_end_op_index = len(origin_infer_program.block(0).ops)
for i in range(
fwd_end_op_index + 1,
min(
fwd_end_op_index + 2 * len(self._outputs.var_ids),
len(self.program.block(0).ops),
len(origin_train_program.block(0).ops),
),
2,
):
op = self.program.block(0).ops[i]
op = origin_train_program.block(0).ops[i]
if op.type == 'fill_constant':
var_name = op.output('Out')[0]
names.append(var_name)
return names
@property
def whole_program_id(self):
def program(self):
"""
Return current train or eval program.
"""
if self.training:
return self.train_program
else:
return self.infer_program
@property
def program_id(self):
"""
Return current train or eval program hash id.
"""
if self.training:
if _in_amp_guard():
return self._train_amp_program_id
......@@ -463,6 +455,59 @@ class PartialProgramLayer:
else:
return self._infer_program_id
@property
def train_program(self):
if _in_amp_guard():
return self._train_amp_program
elif _in_pure_fp16_guard():
return self._train_pure_fp16_program
else:
return self._train_program
@property
def infer_program(self):
if _in_amp_guard():
return self._infer_amp_program
elif _in_pure_fp16_guard():
return self._infer_pure_fp16_program
else:
return self._infer_program
@property
def forward_program(self):
if self.training:
if _in_amp_guard():
progs = self._train_amp_forward_backward_program
elif _in_pure_fp16_guard():
progs = self._train_pure_fp16_forward_backward_program
else:
progs = self._train_forward_backward_program
return progs[0]
else:
return self.infer_program
@property
def backward_program(self):
if self.training:
if _in_amp_guard():
progs = self._train_amp_forward_backward_program
elif _in_pure_fp16_guard():
progs = self._train_pure_fp16_forward_backward_program
else:
progs = self._train_forward_backward_program
return progs[1]
else:
"""
Can't just return paddle.static.Program(), because self.backward_program is a property:
whenever we call it, a temporary Program() object is created and garbage-collected
immediately after the following line in PartialProgramLayer.__call__ executes.
>>> self.backward_program.desc.block(0),
By the time RunProgramAPI is reached, the backward_program address may already be invalid.
"""
return self._empty_backward_program_for_eval
def _verify_program(self, main_program):
"""
Verify that the program parameter is initialized, prune some unused params,
......@@ -725,35 +770,6 @@ class PartialProgramLayer:
in_vars[i] = var.astype('float16')
in_vars[i].name = name
@property
def program(self):
return self.whole_program
@property
def program_id(self):
return self.whole_program_id
@property
def train_program(self):
if _in_amp_guard():
return self._train_amp_program
elif _in_pure_fp16_guard():
return self._train_pure_fp16_program
else:
return self._train_program
@property
def infer_program(self):
if _in_amp_guard():
program = self._infer_amp_program
elif _in_pure_fp16_guard():
program = self._infer_pure_fp16_program
else:
program = self._infer_program
return self._build_infer_program(
program, program.desc.block(0).op_size()
)
@switch_to_static_graph
def _build_infer_program(self, infer_program, forward_end_op_index):
forward_skip_vars = self._parse_skip_gc_vars(infer_program)
......
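
As context for the docstring kept in the eval branch of backward_program above: returning a freshly constructed paddle.static.Program() from a property leaves nothing holding the object alive once the expression self.backward_program.desc.block(0) has been evaluated, so the address later used by RunProgramAPI can dangle; returning the cached _empty_backward_program_for_eval avoids that. Below is a toy illustration of the lifetime difference, assuming CPython reference counting and a hypothetical Holder class standing in for Program.

import weakref


class Holder:
    # Hypothetical stand-in for paddle.static.Program(); `desc` mimics the
    # descriptor whose address would be passed down to RunProgramAPI.
    def __init__(self):
        self.desc = object()


class Risky:
    @property
    def backward_program(self):
        # A fresh temporary on every access: nothing keeps it alive after the
        # enclosing expression finishes evaluating.
        return Holder()


class Safe:
    def __init__(self):
        # Created once and kept alive for the layer's lifetime.
        self._empty_backward_program_for_eval = Holder()

    @property
    def backward_program(self):
        return self._empty_backward_program_for_eval


risky, safe = Risky(), Safe()
dead = weakref.ref(risky.backward_program)
print(dead() is None)                                   # True: temporary already gone
print(safe.backward_program is safe.backward_program)   # True: stable, cached object
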