Unverified commit 9415a6af authored by zhangbo9674, committed by GitHub

Add build strategy for infer program of dy2st (#49641)

* add build strategy for infer program of dy2st

* refine code

* fix bug
Parent 52638c1f
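For context, a minimal sketch of how the build strategy reaches this code path from the user side. The network and the strategy flag below are illustrative, and the sketch assumes a Paddle release in which `paddle.jit.to_static` accepts a `build_strategy` argument; that object is what `PartialProgramLayer` holds as `self._build_strategy`, and with this commit it is also applied to the inference program used in eval mode.

```python
import paddle
from paddle.static import BuildStrategy

# Illustrative network; any Layer converted by to_static would do.
class SimpleNet(paddle.nn.Layer):
    def __init__(self):
        super().__init__()
        self.fc = paddle.nn.Linear(8, 8)

    def forward(self, x):
        return paddle.nn.functional.relu(self.fc(x))

# The strategy object that PartialProgramLayer keeps as self._build_strategy.
build_strategy = BuildStrategy()
build_strategy.fuse_elewise_add_act_ops = True  # one example graph-level pass

net = paddle.jit.to_static(SimpleNet(), build_strategy=build_strategy)
net.eval()  # eval mode takes the infer_program path changed below
out = net(paddle.randn([4, 8]))
```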
@@ -347,12 +347,7 @@ class PartialProgramLayer:
                 program = self._train_forward_backward_program
             return program[0]
         else:
-            if _in_amp_guard():
-                return self._infer_amp_program
-            elif _in_pure_fp16_guard():
-                return self._infer_pure_fp16_program
-            else:
-                return self._infer_program
+            return self.infer_program
 
     @property
     def backward_program(self):
@@ -637,7 +632,7 @@ class PartialProgramLayer:
         elif _in_pure_fp16_guard():
             infer_program = self._infer_pure_fp16_program
         else:
-            infer_program = self.infer_program
+            infer_program = self._infer_program
         return infer_program.desc.block(0).op_size()
 
     def __call__(self, inputs):
@@ -750,11 +745,27 @@ class PartialProgramLayer:
     @property
     def infer_program(self):
         if _in_amp_guard():
-            return self._infer_amp_program
+            program = self._infer_amp_program
         elif _in_pure_fp16_guard():
-            return self._infer_pure_fp16_program
+            program = self._infer_pure_fp16_program
         else:
-            return self._infer_program
+            program = self._infer_program
+        return self._build_infer_program(
+            program, program.desc.block(0).op_size()
+        )
+
+    @switch_to_static_graph
+    def _build_infer_program(self, infer_program, forward_end_op_index):
+        forward_skip_vars = self._parse_skip_gc_vars(infer_program)
+        builded_infer_program = add_build_strategy_for(
+            infer_program,
+            0,
+            forward_end_op_index,
+            self._build_strategy,
+            forward_skip_vars,
+        )
+        self._apply_inplace_pass(builded_infer_program, None)
+        return builded_infer_program
 
     @switch_to_static_graph
     def _get_forward_backward_program_form(
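The hunk above is the core of the change: `infer_program` now selects the raw AMP/FP16/FP32 program and routes it through `_build_infer_program`, which applies `self._build_strategy` to the whole block (op range `[0, op_size)`) via `add_build_strategy_for` and then runs the inplace pass with no backward program. A small standalone sketch of where that end index comes from; the toy program below is illustrative, not the dy2st-generated one:

```python
import paddle

paddle.enable_static()

# A tiny static program, only to show what op_size() counts.
main = paddle.static.Program()
with paddle.static.program_guard(main):
    x = paddle.static.data(name="x", shape=[None, 8], dtype="float32")
    y = paddle.static.nn.fc(x, size=4)

# An inference-only program has no backward section, so the
# forward_end_op_index handed to _build_infer_program is simply
# the total op count of block 0.
forward_end_op_index = main.desc.block(0).op_size()
print(forward_end_op_index)
```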
@@ -808,30 +819,32 @@ class PartialProgramLayer:
             forward_program, backward_program
         )
         backward_mem_opt_skip_vars = self._parse_skip_gc_vars(forward_program)
-        attrs = {
-            "use_cuda": use_cuda,
-            "mem_opt_skip_vars": forward_mem_opt_skip_vars,
-            "for_partial_block": True,
-        }
-        _apply_pass(
-            forward_program,
-            empty_startup_program,
-            "buffer_shared_inplace_pass",
-            attrs,
-            attr_types,
-        )
-        attrs = {
-            "use_cuda": use_cuda,
-            "mem_opt_skip_vars": backward_mem_opt_skip_vars,
-            "for_partial_block": True,
-        }
-        _apply_pass(
-            backward_program,
-            empty_startup_program,
-            "buffer_shared_inplace_pass",
-            attrs,
-            attr_types,
-        )
+        if forward_program:
+            attrs = {
+                "use_cuda": use_cuda,
+                "mem_opt_skip_vars": forward_mem_opt_skip_vars,
+                "for_partial_block": True,
+            }
+            _apply_pass(
+                forward_program,
+                empty_startup_program,
+                "buffer_shared_inplace_pass",
+                attrs,
+                attr_types,
+            )
+        if backward_program:
+            attrs = {
+                "use_cuda": use_cuda,
+                "mem_opt_skip_vars": backward_mem_opt_skip_vars,
+                "for_partial_block": True,
+            }
+            _apply_pass(
+                backward_program,
+                empty_startup_program,
+                "buffer_shared_inplace_pass",
+                attrs,
+                attr_types,
+            )
 
     @LazyInitialized
     def _inout_var_names(self):
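The two new guards are needed because the inference path added above calls `self._apply_inplace_pass(builded_infer_program, None)`: there is no backward program in that case, so each application of `buffer_shared_inplace_pass` must be skipped when its program is `None`. A minimal sketch of the guard pattern, with `apply_pass` standing in for the real `_apply_pass` helper:

```python
from typing import Callable, Optional

def apply_inplace(
    forward_program: Optional[object],
    backward_program: Optional[object],
    apply_pass: Callable[[object], None],
) -> None:
    # The infer-only call site passes backward_program=None, so each
    # pass application is guarded instead of being unconditional.
    if forward_program:
        apply_pass(forward_program)
    if backward_program:
        apply_pass(backward_program)
```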