Unverified · Commit 4230bd87 authored by Aurelius84, committed by GitHub

[Dy2St]Remove PE logic in @to_static (#50512)

* [Dy2St]Remove PE logic in @to_static

* fix typo

* fix infer_program

* fix typo

* fix op_size
Parent bc731487
......@@ -96,25 +96,16 @@ inline void run_program_ad_func(
grad_node->SetGradOutMeta(x, /*slot id*/ 0);
grad_node->SetGradOutMeta(params, /*slot id*/ 1);
bool use_interpretorcore =
PADDLE_GET_CONST(bool, attrs.at("use_interpretorcore"));
VLOG(2) << "clear_no_grad_edges.";
if (use_interpretorcore) {
auto* forward_global_block = PADDLE_GET_CONST(
paddle::framework::BlockDesc*, attrs.at("forward_global_block"));
auto* backward_global_block = PADDLE_GET_CONST(
paddle::framework::BlockDesc*, attrs.at("backward_global_block"));
clear_no_grad_edges_with_partial_block(params,
forward_global_block,
backward_global_block,
grad_node.get(),
/*slot id*/ 1);
} else {
auto* global_block = PADDLE_GET_CONST(paddle::framework::BlockDesc*,
attrs.at("global_block"));
clear_no_grad_edges(params, global_block, grad_node.get(), /*slot id*/ 1);
}
auto* forward_global_block = PADDLE_GET_CONST(
paddle::framework::BlockDesc*, attrs.at("forward_global_block"));
auto* backward_global_block = PADDLE_GET_CONST(
paddle::framework::BlockDesc*, attrs.at("backward_global_block"));
clear_no_grad_edges_with_partial_block(params,
forward_global_block,
backward_global_block,
grad_node.get(),
/*slot id*/ 1);
grad_node->SetGradInMeta(deref_out, 0);
......
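With the PE branch gone, the C++ side above always reads `forward_global_block` and `backward_global_block` and prunes grad edges against the partial (backward) block. A rough Python sketch of that pruning idea, under the assumption that a parameter keeps its edge only if its grad var (Paddle's `@GRAD` naming convention) appears in the backward block; all names here are illustrative stand-ins for the eager-autograd internals, not the actual API:

```python
GRAD_SUFFIX = "@GRAD"

def clear_no_grad_edges_with_partial_block(params, backward_block_vars, edges):
    # Illustrative only: drop the slot-1 edge for any param whose grad var
    # never shows up in the backward block.
    for i, name in enumerate(params):
        if name + GRAD_SUFFIX not in backward_block_vars:
            edges[i] = None  # no grad produced for this param: detach the edge

params = ["w", "b"]
edges = ["edge(w)", "edge(b)"]
clear_no_grad_edges_with_partial_block(params, {"w@GRAD"}, edges)
assert edges == ["edge(w)", None]
```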
......@@ -23,10 +23,6 @@ from paddle.fluid import backward, core, framework, program_guard
from paddle.fluid.compiler import BuildStrategy
from paddle.fluid.dygraph import layers
from paddle.fluid.dygraph.base import switch_to_static_graph
from paddle.fluid.executor import (
_is_dy2st_enable_standalone_executor,
_is_enable_standalone_executor,
)
from paddle.fluid.framework import _apply_pass
from paddle.fluid.layers.utils import _hash_with_id, flatten, pack_sequence_as
......@@ -128,14 +124,26 @@ class ProgramInfo:
A helper class to record Program information
"""
def __init__(self, mode='infer'):
def __init__(self):
self.op_size = {
'fp32': -1,
'amp': -1,
'fp16': -1,
}
assert mode in ['train', 'infer']
self.mode = mode
self.programs = {}
self.mode = "infer"
def __call__(self, key, prog_creator):
"""
Record the infer program and op size.
"""
assert key in ['fp32', 'amp', 'fp16']
if key not in self.programs:
infer_prog = prog_creator(is_infer_mode=True)
self.programs[key] = infer_prog
self.op_size[key] = infer_prog.desc.block(0).op_size()
return self.programs[key], self.op_size[key]
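The reworked `ProgramInfo` doubles as a memoizing callable: the first call for a precision key builds the infer program via `prog_creator` and caches both the program and its forward op count; later calls are cache hits. A self-contained sketch of that behaviour, where `DummyProgram` and `make_program` are hypothetical stand-ins for the real Paddle objects (the real code reads `program.desc.block(0).op_size()`):

```python
class DummyProgram:
    # Hypothetical stand-in for a Paddle Program.
    def __init__(self, op_size):
        self.op_size = op_size


def make_program(is_infer_mode=False):
    # Stand-in for _create_program / _create_amp_program / ...
    return DummyProgram(op_size=42)


class ProgramInfoSketch:
    def __init__(self):
        self.op_size = {'fp32': -1, 'amp': -1, 'fp16': -1}
        self.programs = {}

    def __call__(self, key, prog_creator):
        assert key in ('fp32', 'amp', 'fp16')
        if key not in self.programs:
            infer_prog = prog_creator(is_infer_mode=True)
            self.programs[key] = infer_prog
            self.op_size[key] = infer_prog.op_size
        return self.programs[key], self.op_size[key]


info = ProgramInfoSketch()
prog1, n1 = info('fp32', make_program)  # builds and caches on first call
prog2, n2 = info('fp32', make_program)  # cache hit: same program object
assert prog1 is prog2 and n1 == n2 == 42
```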
class PartialProgramLayer:
......@@ -176,7 +184,7 @@ class PartialProgramLayer:
self._cuda_graph_pool_id = 0
# Set default mode to train
self.training = True
self._infer_info = ProgramInfo(mode='infer')
self._infer_info = ProgramInfo()
custom_white_list, custom_black_list = None, None
tracer = framework._dygraph_tracer()
......@@ -191,6 +199,28 @@ class PartialProgramLayer:
# program_id -> list(scope)
self._scope_cache = {}
def __call__(self, inputs):
"""
Execute the static graph with the Interpreter and return dynamic Tensors.
"""
in_vars, out_vars = self._prepare(inputs)
self._cast_fp16_if_pure_fp16(in_vars)
attrs = self._prepare_attributes()
_legacy_C_ops.run_program(
self._valid_vars(in_vars),
self._valid_vars(self._params),
self._valid_vars(out_vars),
self._create_scope_vec(
program_id=self.program_id, use_scope_cache=True
),
self._double_grads,
self._cuda_graph_vec,
*attrs
)
restored_nest_out = self._restore_out(out_vars)
return self._remove_no_value(restored_nest_out)
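The new `__call__` funnels everything through a single interpreter path: `run_program` hands back a flat list of output tensors, and `_restore_out` packs them into the caller's nested structure. A minimal sketch of that flatten/pack round trip, with simplified list-only stand-ins for the imported `flatten`/`pack_sequence_as` utilities:

```python
def flatten(nested):
    # Depth-first flatten of (possibly nested) lists.
    out = []
    for item in nested:
        if isinstance(item, list):
            out.extend(flatten(item))
        else:
            out.append(item)
    return out

def pack_sequence_as(structure, flat):
    # Rebuild `structure`, consuming `flat` left to right.
    it = iter(flat)

    def rebuild(node):
        if isinstance(node, list):
            return [rebuild(child) for child in node]
        return next(it)

    return rebuild(structure)

structure = ['out0', ['out1', 'out2']]
flat_results = [1.0, 2.0, 3.0]          # what run_program hands back
assert pack_sequence_as(structure, flat_results) == [1.0, [2.0, 3.0]]
```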
def _get_scope(self, program_id=None, use_scope_cache=False):
if use_scope_cache:
if program_id not in self._scope_cache:
......@@ -259,8 +289,9 @@ class PartialProgramLayer:
@switch_to_static_graph
def _create_forward_backward_train_program(self):
whole_program = self._train_program
forward_end_op_index = self._infer_info.op_size['fp32']
_, forward_end_op_index = self._infer_info('fp32', self._create_program)
assert forward_end_op_index >= 0
return self._get_forward_backward_program_form(
whole_program, forward_end_op_index
)
......@@ -268,8 +299,11 @@ class PartialProgramLayer:
@switch_to_static_graph
def _create_forward_backward_train_amp_program(self):
whole_program = self._train_amp_program
forward_end_op_index = self._infer_info.op_size['amp']
_, forward_end_op_index = self._infer_info(
'amp', self._create_amp_program
)
assert forward_end_op_index >= 0
return self._get_forward_backward_program_form(
whole_program, forward_end_op_index
)
......@@ -277,8 +311,11 @@ class PartialProgramLayer:
@switch_to_static_graph
def _create_forward_backward_train_pure_fp16_program(self):
whole_program = self._train_pure_fp16_program
forward_end_op_index = self._infer_info.op_size['fp16']
_, forward_end_op_index = self._infer_info(
'fp16', self._create_pure_fp16_program
)
assert forward_end_op_index >= 0
return self._get_forward_backward_program_form(
whole_program, forward_end_op_index
)
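All three train variants now resolve `forward_end_op_index` through `self._infer_info`, so the split point is computed once per precision key and shared with the infer-program path. Conceptually, `_get_forward_backward_program_form` slices the joint train program at that index; a toy sketch with an illustrative op list:

```python
def split_ops(ops, forward_end_op_index):
    # Forward ops come first in the joint train program; everything from
    # the split index onward belongs to the backward program.
    assert 0 <= forward_end_op_index <= len(ops)
    return ops[:forward_end_op_index], ops[forward_end_op_index:]

ops = ['matmul', 'relu', 'relu_grad', 'matmul_grad']
fwd, bwd = split_ops(ops, forward_end_op_index=2)
assert fwd == ['matmul', 'relu']
assert bwd == ['relu_grad', 'matmul_grad']
```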
......@@ -289,11 +326,8 @@ class PartialProgramLayer:
@LazyInitialized
def _infer_program(self):
program = self._create_program(is_infer_mode=True)
self._infer_info.op_size['fp32'] = program.desc.block(0).op_size()
return self._build_infer_program(
program, self._infer_info.op_size['fp32']
)
program, op_size = self._infer_info('fp32', self._create_program)
return self._build_infer_program(program, op_size)
@LazyInitialized
def _train_amp_program(self):
......@@ -301,11 +335,8 @@ class PartialProgramLayer:
@LazyInitialized
def _infer_amp_program(self):
program = self._create_amp_program(is_infer_mode=True)
self._infer_info.op_size['amp'] = program.desc.block(0).op_size()
return self._build_infer_program(
program, self._infer_info.op_size['amp']
)
program, op_size = self._infer_info('amp', self._create_amp_program)
return self._build_infer_program(program, op_size)
@LazyInitialized
def _train_pure_fp16_program(self):
......@@ -313,11 +344,10 @@ class PartialProgramLayer:
@LazyInitialized
def _infer_pure_fp16_program(self):
program = self._create_pure_fp16_program(is_infer_mode=True)
self._infer_info.op_size['fp16'] = program.desc.block(0).op_size()
return self._build_infer_program(
program, self._infer_info.op_size['fp16']
program, op_size = self._infer_info(
'fp16', self._create_pure_fp16_program
)
return self._build_infer_program(program, op_size)
@LazyInitialized
def _train_forward_backward_program(self):
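Each `_infer_*_program` property above is wrapped in `@LazyInitialized`, so the program is built on first access and reused afterwards. A minimal reimplementation of that decorator pattern as a non-data descriptor (illustrative only, not Paddle's actual `LazyInitialized`):

```python
class lazy_initialized:
    """Illustrative non-data descriptor: compute once, cache on the instance."""

    def __init__(self, func):
        self.func = func
        self.name = func.__name__

    def __get__(self, obj, objtype=None):
        if obj is None:
            return self
        value = self.func(obj)
        obj.__dict__[self.name] = value  # instance dict now shadows us
        return value


class Layer:
    @lazy_initialized
    def _infer_program(self):
        return object()  # pretend this builds an expensive Program


layer = Layer()
assert layer._infer_program is layer._infer_program  # built once, cached
```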
......@@ -632,27 +662,24 @@ class PartialProgramLayer:
double_grads.append(var_base)
return self._valid_vars(double_grads)
def _get_end_op_index(self):
if _in_amp_guard():
infer_program = self._infer_amp_program
elif _in_pure_fp16_guard():
infer_program = self._infer_pure_fp16_program
else:
infer_program = self._infer_program
return infer_program.desc.block(0).op_size()
def __call__(self, inputs):
in_vars, out_vars = self._prepare(inputs)
self._cast_fp16_if_pure_fp16(in_vars)
def _cast_fp16_if_pure_fp16(self, in_vars):
if _in_pure_fp16_guard():
for i, var in enumerate(in_vars):
name = var.name
if (
self.program.global_block().has_var(name)
and self.program.global_block().var(name).dtype
== paddle.float16
):
in_vars[i] = var.astype('float16')
in_vars[i].name = name
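`_cast_fp16_if_pure_fp16` downcasts only the inputs that the fp16 program actually declares as float16, and preserves each variable's name across the cast. A numpy-based sketch of the same check-then-cast logic, where the name-to-dtype map stands in for the program's global block (all names hypothetical):

```python
import numpy as np

def cast_fp16_if_pure_fp16(in_vars, block_dtypes, pure_fp16_guard=True):
    # block_dtypes: var name -> dtype recorded in the program (assumption).
    if not pure_fp16_guard:
        return
    for i, (name, arr) in enumerate(in_vars):
        if block_dtypes.get(name) == np.float16:
            in_vars[i] = (name, arr.astype(np.float16))  # keep the name

inputs = [('x', np.ones((2, 2), dtype=np.float32)),
          ('mask', np.zeros(4, dtype=np.int64))]
cast_fp16_if_pure_fp16(inputs, {'x': np.float16, 'mask': np.int64})
assert inputs[0][1].dtype == np.float16
assert inputs[1][1].dtype == np.int64  # non-fp16 vars pass through untouched
```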
def _prepare_attributes(self):
attrs = [
'global_block',
self.program.desc.block(0),
'start_op_index',
0,
'end_op_index',
self._get_end_op_index(),
'forward_global_block',
self.forward_program.desc.block(0),
'backward_global_block',
self.backward_program.desc.block(0),
'is_test',
not self.training,
'program_id',
......@@ -679,57 +706,7 @@ class PartialProgramLayer:
self._cuda_graph_pool_id,
)
)
use_interpretorcore = (
_is_enable_standalone_executor()
and _is_dy2st_enable_standalone_executor()
)
attrs.extend(('use_interpretorcore', use_interpretorcore))
if use_interpretorcore:
attrs.extend(
(
'forward_global_block',
self.forward_program.desc.block(0),
'backward_global_block',
self.backward_program.desc.block(0),
)
)
_legacy_C_ops.run_program(
self._valid_vars(in_vars),
self._valid_vars(self._params),
self._valid_vars(out_vars),
self._create_scope_vec(
program_id=self.program_id, use_scope_cache=True
),
self._double_grads,
self._cuda_graph_vec,
*attrs
)
else:
_legacy_C_ops.run_program(
self._valid_vars(in_vars),
self._valid_vars(self._params),
self._valid_vars(out_vars),
self._create_scope_vec(),
self._double_grads,
self._cuda_graph_vec,
*attrs
)
restored_nest_out = self._restore_out(out_vars)
return self._remove_no_value(restored_nest_out)
def _cast_fp16_if_pure_fp16(self, in_vars):
if _in_pure_fp16_guard():
for i, var in enumerate(in_vars):
name = var.name
if (
self.program.global_block().has_var(name)
and self.program.global_block().var(name).dtype
== paddle.float16
):
in_vars[i] = var.astype('float16')
in_vars[i].name = name
return attrs
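`_prepare_attributes` now returns one flat `[name, value, name, value, ...]` list consumed by `_legacy_C_ops.run_program`; the `forward_global_block`/`backward_global_block` keys are always present instead of being appended behind a flag. A tiny helper, hypothetical and for illustration only, that makes the pairing convention explicit:

```python
def attrs_as_dict(flat_attrs):
    # flat_attrs alternates attribute names and values.
    assert len(flat_attrs) % 2 == 0
    return dict(zip(flat_attrs[0::2], flat_attrs[1::2]))

attrs = ['is_test', True, 'program_id', 42]
assert attrs_as_dict(attrs) == {'is_test': True, 'program_id': 42}
```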
@switch_to_static_graph
def _build_infer_program(self, infer_program, forward_end_op_index):
......