diff --git a/python/paddle/jit/dy2static/partial_program.py b/python/paddle/jit/dy2static/partial_program.py
index ea41f5626f392de7e1e16a6e105a31c3900c6c80..1e4ab1b3110019d94b58b017b1dfb9bc1e6e8aa6 100644
--- a/python/paddle/jit/dy2static/partial_program.py
+++ b/python/paddle/jit/dy2static/partial_program.py
@@ -347,12 +347,7 @@ class PartialProgramLayer:
             program = self._train_forward_backward_program
             return program[0]
         else:
-            if _in_amp_guard():
-                return self._infer_amp_program
-            elif _in_pure_fp16_guard():
-                return self._infer_pure_fp16_program
-            else:
-                return self._infer_program
+            return self.infer_program
 
     @property
     def backward_program(self):
@@ -637,7 +632,7 @@ class PartialProgramLayer:
         elif _in_pure_fp16_guard():
             infer_program = self._infer_pure_fp16_program
         else:
-            infer_program = self.infer_program
+            infer_program = self._infer_program
         return infer_program.desc.block(0).op_size()
 
     def __call__(self, inputs):
@@ -750,11 +745,27 @@ class PartialProgramLayer:
     @property
     def infer_program(self):
         if _in_amp_guard():
-            return self._infer_amp_program
+            program = self._infer_amp_program
         elif _in_pure_fp16_guard():
-            return self._infer_pure_fp16_program
+            program = self._infer_pure_fp16_program
         else:
-            return self._infer_program
+            program = self._infer_program
+        return self._build_infer_program(
+            program, program.desc.block(0).op_size()
+        )
+
+    @switch_to_static_graph
+    def _build_infer_program(self, infer_program, forward_end_op_index):
+        forward_skip_vars = self._parse_skip_gc_vars(infer_program)
+        builded_infer_program = add_build_strategy_for(
+            infer_program,
+            0,
+            forward_end_op_index,
+            self._build_strategy,
+            forward_skip_vars,
+        )
+        self._apply_inplace_pass(builded_infer_program, None)
+        return builded_infer_program
 
     @switch_to_static_graph
     def _get_forward_backward_program_form(
@@ -808,30 +819,32 @@ class PartialProgramLayer:
             forward_program, backward_program
         )
         backward_mem_opt_skip_vars = self._parse_skip_gc_vars(forward_program)
-        attrs = {
-            "use_cuda": use_cuda,
-            "mem_opt_skip_vars": forward_mem_opt_skip_vars,
-            "for_partial_block": True,
-        }
-        _apply_pass(
-            forward_program,
-            empty_startup_program,
-            "buffer_shared_inplace_pass",
-            attrs,
-            attr_types,
-        )
-        attrs = {
-            "use_cuda": use_cuda,
-            "mem_opt_skip_vars": backward_mem_opt_skip_vars,
-            "for_partial_block": True,
-        }
-        _apply_pass(
-            backward_program,
-            empty_startup_program,
-            "buffer_shared_inplace_pass",
-            attrs,
-            attr_types,
-        )
+        if forward_program:
+            attrs = {
+                "use_cuda": use_cuda,
+                "mem_opt_skip_vars": forward_mem_opt_skip_vars,
+                "for_partial_block": True,
+            }
+            _apply_pass(
+                forward_program,
+                empty_startup_program,
+                "buffer_shared_inplace_pass",
+                attrs,
+                attr_types,
+            )
+        if backward_program:
+            attrs = {
+                "use_cuda": use_cuda,
+                "mem_opt_skip_vars": backward_mem_opt_skip_vars,
+                "for_partial_block": True,
+            }
+            _apply_pass(
+                backward_program,
+                empty_startup_program,
+                "buffer_shared_inplace_pass",
+                attrs,
+                attr_types,
+            )
 
     @LazyInitialized
     def _inout_var_names(self):
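
For orientation, here is a minimal, runnable sketch of the two patterns this diff introduces: the infer_program property now dispatches on the active precision guard and then routes every variant through one shared build step (_build_infer_program), and the inplace pass is applied only to programs that actually exist. All names below (SketchLayer, Program, run_inplace_pass, the "fp32"/"amp" mode keys) are hypothetical stand-ins, not Paddle's real API.

from dataclasses import dataclass, field


@dataclass
class Program:
    # Stand-in for a static-graph program; ops is a flat list of op names.
    ops: list = field(default_factory=list)


def run_inplace_pass(program):
    # Stand-in for _apply_pass(..., "buffer_shared_inplace_pass", ...).
    print(f"inplace pass over {len(program.ops)} ops")


class SketchLayer:
    def __init__(self, programs):
        # programs: precision mode -> Program, e.g. {"fp32": ..., "amp": ...}
        self._programs = programs

    def _current_mode(self):
        # Stand-in for the _in_amp_guard() / _in_pure_fp16_guard() checks.
        return "fp32"

    @property
    def infer_program(self):
        # Dispatch on the guard first, then apply one shared build step,
        # mirroring the refactored infer_program property in the diff.
        program = self._programs[self._current_mode()]
        return self._build_infer_program(program, len(program.ops))

    def _build_infer_program(self, program, end_op_index):
        # Stand-in for _build_infer_program: take the forward slice of ops
        # [0, end_op_index) and run the inplace pass on the built result.
        built = Program(ops=program.ops[:end_op_index])
        self._apply_inplace_pass(built, None)
        return built

    @staticmethod
    def _apply_inplace_pass(forward_program, backward_program):
        # Mirrors the None-guards added in _apply_inplace_pass: each program
        # is handed to the pass only when it is present, so an infer-only
        # call with backward_program=None is now safe.
        if forward_program:
            run_inplace_pass(forward_program)
        if backward_program:
            run_inplace_pass(backward_program)


if __name__ == "__main__":
    layer = SketchLayer({"fp32": Program(ops=["matmul", "relu"])})
    layer.infer_program  # runs the pass over the fp32 program only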