未验证 提交 9415a6af 编写于 作者: Z zhangbo9674 提交者: GitHub

Add build strategy for infer program of dy2st (#49641)

* add build strategy for infer program of dy2st

* refine code

* fix bug
上级 52638c1f
...@@ -347,12 +347,7 @@ class PartialProgramLayer: ...@@ -347,12 +347,7 @@ class PartialProgramLayer:
program = self._train_forward_backward_program program = self._train_forward_backward_program
return program[0] return program[0]
else: else:
if _in_amp_guard(): return self.infer_program
return self._infer_amp_program
elif _in_pure_fp16_guard():
return self._infer_pure_fp16_program
else:
return self._infer_program
@property @property
def backward_program(self): def backward_program(self):
...@@ -637,7 +632,7 @@ class PartialProgramLayer: ...@@ -637,7 +632,7 @@ class PartialProgramLayer:
elif _in_pure_fp16_guard(): elif _in_pure_fp16_guard():
infer_program = self._infer_pure_fp16_program infer_program = self._infer_pure_fp16_program
else: else:
infer_program = self.infer_program infer_program = self._infer_program
return infer_program.desc.block(0).op_size() return infer_program.desc.block(0).op_size()
def __call__(self, inputs): def __call__(self, inputs):
...@@ -750,11 +745,27 @@ class PartialProgramLayer: ...@@ -750,11 +745,27 @@ class PartialProgramLayer:
@property @property
def infer_program(self): def infer_program(self):
if _in_amp_guard(): if _in_amp_guard():
return self._infer_amp_program program = self._infer_amp_program
elif _in_pure_fp16_guard(): elif _in_pure_fp16_guard():
return self._infer_pure_fp16_program program = self._infer_pure_fp16_program
else: else:
return self._infer_program program = self._infer_program
return self._build_infer_program(
program, program.desc.block(0).op_size()
)
@switch_to_static_graph
def _build_infer_program(self, infer_program, forward_end_op_index):
forward_skip_vars = self._parse_skip_gc_vars(infer_program)
builded_infer_program = add_build_strategy_for(
infer_program,
0,
forward_end_op_index,
self._build_strategy,
forward_skip_vars,
)
self._apply_inplace_pass(builded_infer_program, None)
return builded_infer_program
@switch_to_static_graph @switch_to_static_graph
def _get_forward_backward_program_form( def _get_forward_backward_program_form(
...@@ -808,6 +819,7 @@ class PartialProgramLayer: ...@@ -808,6 +819,7 @@ class PartialProgramLayer:
forward_program, backward_program forward_program, backward_program
) )
backward_mem_opt_skip_vars = self._parse_skip_gc_vars(forward_program) backward_mem_opt_skip_vars = self._parse_skip_gc_vars(forward_program)
if forward_program:
attrs = { attrs = {
"use_cuda": use_cuda, "use_cuda": use_cuda,
"mem_opt_skip_vars": forward_mem_opt_skip_vars, "mem_opt_skip_vars": forward_mem_opt_skip_vars,
...@@ -820,6 +832,7 @@ class PartialProgramLayer: ...@@ -820,6 +832,7 @@ class PartialProgramLayer:
attrs, attrs,
attr_types, attr_types,
) )
if backward_program:
attrs = { attrs = {
"use_cuda": use_cuda, "use_cuda": use_cuda,
"mem_opt_skip_vars": backward_mem_opt_skip_vars, "mem_opt_skip_vars": backward_mem_opt_skip_vars,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册