未验证 提交 9415a6af 编写于 作者: Z zhangbo9674 提交者: GitHub

Add build strategy for infer program of dy2st (#49641)

* add build strategy for infer program of dy2st

* refine code

* fix bug
上级 52638c1f
......@@ -347,12 +347,7 @@ class PartialProgramLayer:
program = self._train_forward_backward_program
return program[0]
else:
if _in_amp_guard():
return self._infer_amp_program
elif _in_pure_fp16_guard():
return self._infer_pure_fp16_program
else:
return self._infer_program
return self.infer_program
@property
def backward_program(self):
......@@ -637,7 +632,7 @@ class PartialProgramLayer:
elif _in_pure_fp16_guard():
infer_program = self._infer_pure_fp16_program
else:
infer_program = self.infer_program
infer_program = self._infer_program
return infer_program.desc.block(0).op_size()
def __call__(self, inputs):
......@@ -750,11 +745,27 @@ class PartialProgramLayer:
@property
def infer_program(self):
if _in_amp_guard():
return self._infer_amp_program
program = self._infer_amp_program
elif _in_pure_fp16_guard():
return self._infer_pure_fp16_program
program = self._infer_pure_fp16_program
else:
return self._infer_program
program = self._infer_program
return self._build_infer_program(
program, program.desc.block(0).op_size()
)
@switch_to_static_graph
def _build_infer_program(self, infer_program, forward_end_op_index):
forward_skip_vars = self._parse_skip_gc_vars(infer_program)
builded_infer_program = add_build_strategy_for(
infer_program,
0,
forward_end_op_index,
self._build_strategy,
forward_skip_vars,
)
self._apply_inplace_pass(builded_infer_program, None)
return builded_infer_program
@switch_to_static_graph
def _get_forward_backward_program_form(
......@@ -808,6 +819,7 @@ class PartialProgramLayer:
forward_program, backward_program
)
backward_mem_opt_skip_vars = self._parse_skip_gc_vars(forward_program)
if forward_program:
attrs = {
"use_cuda": use_cuda,
"mem_opt_skip_vars": forward_mem_opt_skip_vars,
......@@ -820,6 +832,7 @@ class PartialProgramLayer:
attrs,
attr_types,
)
if backward_program:
attrs = {
"use_cuda": use_cuda,
"mem_opt_skip_vars": backward_mem_opt_skip_vars,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册