Unverified commit 4230bd87, authored by Aurelius84, committed by GitHub

[Dy2St]Remove PE logic in @to_static (#50512)

* [Dy2St]Remove PE logic in @to_static

* fix typo

* fix infer_program

* fix typo

* fix op_size
Parent bc731487
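Note: the snippet below is an illustrative sketch, not part of this commit. It shows the kind of `@to_static` usage that executes through `run_program` / `PartialProgramLayer`, the path this change simplifies (the ParallelExecutor branch is removed and only the interpreter-core branch remains). The decorated function, tensor shapes, and printed output are assumptions for the example.

```python
# Hypothetical usage example; the function and shapes are made up.
import paddle
from paddle.jit import to_static


@to_static
def add_and_scale(x, y):
    # Traced into a static Program on the first call; later calls execute it
    # through the run_program op touched by this commit.
    return (x + y) * 2.0


x = paddle.randn([4, 8])
y = paddle.randn([4, 8])
out = add_and_scale(x, y)
print(out.shape)  # [4, 8]
```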
@@ -96,10 +96,7 @@ inline void run_program_ad_func(
   grad_node->SetGradOutMeta(x, /*slot id*/ 0);
   grad_node->SetGradOutMeta(params, /*slot id*/ 1);
-  bool use_interpretorcore =
-      PADDLE_GET_CONST(bool, attrs.at("use_interpretorcore"));
   VLOG(2) << "clear_no_grad_edges.";
-  if (use_interpretorcore) {
   auto* forward_global_block = PADDLE_GET_CONST(
       paddle::framework::BlockDesc*, attrs.at("forward_global_block"));
   auto* backward_global_block = PADDLE_GET_CONST(
@@ -110,12 +107,6 @@ inline void run_program_ad_func(
       grad_node.get(),
       /*slot id*/ 1);
-  } else {
-    auto* global_block = PADDLE_GET_CONST(paddle::framework::BlockDesc*,
-                                          attrs.at("global_block"));
-    clear_no_grad_edges(params, global_block, grad_node.get(), /*slot id*/ 1);
-  }
   grad_node->SetGradInMeta(deref_out, 0);
   egr::EagerUtils::SetOutRankWithSlot(&p_autograd_outs, 0);
...
@@ -304,10 +304,6 @@ inline void RunProgramAPI(
       paddle::platform::errors::InvalidArgument(
           "The OutScope of RunProgramGradOp should only hold one scope."));
-  bool use_interpretorcore =
-      PADDLE_GET_CONST(bool, attrs.at("use_interpretorcore"));
-  if (use_interpretorcore) {
   VLOG(2) << "RunProgramOp use interpretercore to execute program.";
   paddle::framework::Scope *global_inner_scope = out_scope_vec->front();
@@ -359,8 +355,8 @@ inline void RunProgramAPI(
     details::ShareTensorsIntoScope(x, global_inner_scope);
     details::ShareTensorsIntoScope(params, global_inner_scope);
     // Step 2. create new interpretercore
-    interpreter_core = paddle::framework::CreateInterpreterCoreInfoToCache(
-        *forward_program,
+    interpreter_core =
+        paddle::framework::CreateInterpreterCoreInfoToCache(*forward_program,
         place,
         /*is_grad=*/false,
         program_id,
@@ -425,17 +421,14 @@ inline void RunProgramAPI(
     paddle::platform::RecordEvent record_event(
         "fetch_and_gc", paddle::platform::TracerEventType::UserDefined, 1);
     // Get Output
-    details::ShareTensorsFromScopeWithPartialBlock(out,
-                                                   *forward_global_block,
-                                                   *backward_global_block,
-                                                   global_inner_scope);
+    details::ShareTensorsFromScopeWithPartialBlock(
+        out, *forward_global_block, *backward_global_block, global_inner_scope);
     details::ShareTensorsFromScopeWithPartialBlock(dout,
                                                    *forward_global_block,
                                                    *backward_global_block,
                                                    global_inner_scope);
-    VLOG(3) << paddle::framework::GenScopeTreeDebugInfo(
-        out_scope_vec->front());
+    VLOG(3) << paddle::framework::GenScopeTreeDebugInfo(out_scope_vec->front());
     if (is_test || !egr::Controller::Instance().HasGrad()) {
       VLOG(4) << "is test, set this scope can reused";
@@ -450,80 +443,6 @@ inline void RunProgramAPI(
 #ifdef PADDLE_WITH_MKLDNN
     if (FLAGS_use_mkldnn) paddle::platform::DontClearMKLDNNCache(place);
 #endif
-  } else {
-    VLOG(2) << "RunProgramOp execute with parallel_executor.";
-    // Step 2. prepare executor and init persistable variables
-    // NOTE(Aurelius84): While training some models, forward can be called many
-    // times and then apply backpropagation all at once, such as Reinforcement
-    // Learning. Tensor data in multi-step training should be saved into single
-    // scope separately. Otherwise, the gradients can be miscalculated because
-    // always using the Tensor data of the last step in forward.
-    paddle::framework::Scope *global_inner_scope = out_scope_vec->front();
-    VLOG(2) << "The number of sub scopes before forward: "
-            << out_scope_vec->front()->kids().size();
-    paddle::framework::Scope &scope = global_inner_scope->NewScope();
-    // share input_vars & parameters into scope
-    details::ShareTensorsIntoScope(x, &scope);
-    details::ShareTensorsIntoScope(params, &scope);
-    const auto &place = egr::Controller::Instance().GetExpectedPlace();
-    auto *global_block = PADDLE_GET_CONST(paddle::framework::BlockDesc *,
-                                          attrs.at("global_block"));
-    auto start_op_index = PADDLE_GET_CONST(int64_t, attrs.at("start_op_index"));
-    auto end_op_index = PADDLE_GET_CONST(int64_t, attrs.at("end_op_index"));
-    if (end_op_index > start_op_index) {
-      auto input_names = details::GetTensorsName(x);
-      auto output_names = details::GetTensorsName(out);
-      auto dout_names = details::GetTensorsName(dout);
-      auto *program = global_block->Program();
-      auto cache_info =
-          paddle::framework::GetExecutorInfoFromCache(*program,
-                                                      place,
-                                                      start_op_index,
-                                                      end_op_index,
-                                                      /*is_grad=*/false,
-                                                      program_id,
-                                                      &scope);
-      auto &parallel_executor = cache_info.first;
-      // all out_vars are skip_eager_var
-      auto &skip_eager_delete_vars =
-          paddle::framework::ExecutorInfoCache::Instance().SkipEagerDeleteVars(
-              program_id, false);
-      if (cache_info.second /*is_new_created*/) {
-        parallel_executor->SkipMemoryReuse(/*scope_idx=*/0, input_names);
-        skip_eager_delete_vars.insert(skip_eager_delete_vars.end(),
-                                      output_names.begin(),
-                                      output_names.end());
-        skip_eager_delete_vars.insert(
-            skip_eager_delete_vars.end(), dout_names.begin(), dout_names.end());
-        paddle::framework::details::ParseSafeEagerDeletionSkipVars(
-            *program, end_op_index, output_names, &skip_eager_delete_vars);
-      }
-      // Step 3. run ops
-      parallel_executor->RunWithoutFetch(skip_eager_delete_vars);
-    }
-    // Step 4. Get Output
-    details::ShareTensorsFromScope(out, *global_block, &scope);
-    details::ShareTensorsFromScope(dout, *global_block, &scope);
-    // Debug info: scope info when run end
-    VLOG(3) << paddle::framework::GenScopeTreeDebugInfo(out_scope_vec->front());
-    // Step 5. Drop all children scopes while testing.
-    if (is_test || !egr::Controller::Instance().HasGrad()) {
-      out_scope_vec->front()->DropKids();
-    }
-    VLOG(2) << "The number of sub scopes after forward: "
-            << out_scope_vec->front()->kids().size();
-#ifdef PADDLE_WITH_MKLDNN
-    if (FLAGS_use_mkldnn) paddle::platform::DontClearMKLDNNCache(place);
-#endif
-  }
 }
 inline void RunProgramGradAPI(
@@ -538,8 +457,6 @@ inline void RunProgramGradAPI(
   // if all output vars are set to stop_gradient, grad op no need to executed
   if (x_grad.empty() && params_grad.empty()) return;
-  bool use_interpretorcore =
-      PADDLE_GET_CONST(bool, attrs.at("use_interpretorcore"));
   auto program_id = PADDLE_GET_CONST(int64_t, attrs.at("program_id"));
   auto *out_scope_vec = &step_scope;
@@ -550,8 +467,6 @@ inline void RunProgramGradAPI(
           "The OutScope of RunProgramGradOp should only hold one scope."));
   auto place = egr::Controller::Instance().GetExpectedPlace();
-  if (use_interpretorcore) {
   VLOG(2) << "RunProgramGradOp use interpretercore to execute program.";
   paddle::framework::Scope *global_inner_scope = out_scope_vec->front();
@@ -574,8 +489,8 @@ inline void RunProgramGradAPI(
         1);
     VLOG(2) << "No interpretercore cahce, so create a new interpretercore";
     details::ShareTensorsIntoScope(out_grad, global_inner_scope);
-    interpreter_core = paddle::framework::CreateInterpreterCoreInfoToCache(
-        *backward_program,
+    interpreter_core =
+        paddle::framework::CreateInterpreterCoreInfoToCache(*backward_program,
         place,
         /*is_grad=*/true,
         program_id,
@@ -589,8 +504,8 @@ inline void RunProgramGradAPI(
         interpretercore_info_cache.GetMutable(program_id, /*is_grad=*/false)
             .core_;
     interpreter_core->ShareWorkQueueFrom(fwd_interpreter_core);
-    VLOG(4) << "Share workqueue from " << fwd_interpreter_core.get()
-            << " to " << interpreter_core.get();
+    VLOG(4) << "Share workqueue from " << fwd_interpreter_core.get() << " to "
+            << interpreter_core.get();
   }
   std::vector<std::string> x_grad_names;
@@ -606,8 +521,8 @@ inline void RunProgramGradAPI(
     // all out_vars are skip_eager_var
     skip_eager_delete_vars.insert(x_grad_names.begin(), x_grad_names.end());
     // initialize skip gc vars by forward_program and backward_program
-    paddle::framework::details::AppendSkipDeletionVars(
-        param_grad_names, &skip_eager_delete_vars);
+    paddle::framework::details::AppendSkipDeletionVars(param_grad_names,
+                                                       &skip_eager_delete_vars);
     interpreter_core->SetSkipGcVars(skip_eager_delete_vars);
     interpretercore_info_cache.UpdateSkipEagerDeleteVars(
         program_id, /*is_grad=*/true, skip_eager_delete_vars);
@@ -626,9 +541,8 @@ inline void RunProgramGradAPI(
     details::ShareTensorsIntoScope(out_grad, global_inner_scope);
     if (interpreter_core->GetVariableScope()->GetMutableScope() !=
         global_inner_scope) {
-      details::BuildScopeByBlock(*interpreter_core.get(),
-                                 *backward_global_block,
-                                 global_inner_scope);
+      details::BuildScopeByBlock(
+          *interpreter_core.get(), *backward_global_block, global_inner_scope);
       interpreter_core->reset_scope(global_inner_scope);
     }
   }
@@ -639,8 +553,7 @@ inline void RunProgramGradAPI(
         paddle::platform::TracerEventType::UserDefined,
         1);
     // Debug info: scope info when run end
-    VLOG(3) << paddle::framework::GenScopeTreeDebugInfo(
-        out_scope_vec->front());
+    VLOG(3) << paddle::framework::GenScopeTreeDebugInfo(out_scope_vec->front());
     interpreter_core->Run({});
   }
@@ -660,87 +573,6 @@ inline void RunProgramGradAPI(
     global_inner_scope->SetCanReuesd(true);
     details::GcScope(global_inner_scope);
   }
-  } else {
-    VLOG(2) << "RunProgramGradOp use pe to execute program.";
-    paddle::framework::Scope *global_inner_scope = out_scope_vec->front();
-    auto sub_scope_num = global_inner_scope->kids().size();
-    VLOG(2) << "The number of sub scopes before backward: " << sub_scope_num;
-    PADDLE_ENFORCE_GT(sub_scope_num,
-                      0,
-                      paddle::platform::errors::InvalidArgument(
-                          "The OutScope of RunProgramGradOp should hold at "
-                          "least one sub scope."));
-    auto &scope = *(global_inner_scope->kids().front());
-    auto *global_block = PADDLE_GET_CONST(paddle::framework::BlockDesc *,
-                                          attrs.at("global_block"));
-    auto orig_end_op_index =
-        PADDLE_GET_CONST(int64_t, attrs.at("end_op_index"));
-    // NOTE: skip `shape` and `fill_constant` op created by
-    // fluid.backward.gradients, one forward output will generate one `shape`
-    // and `fill_constant`
-    int64_t start_op_index = orig_end_op_index + (out_grad.size() * 2);
-    int64_t end_op_index = global_block->OpSize();
-    if (end_op_index > start_op_index) {
-      auto out_grad_names = details::GetTensorsName(out_grad);
-      // Step 2. prepare executor and scope
-      auto *program = global_block->Program();
-      auto cache_info =
-          paddle::framework::GetExecutorInfoFromCache(*program,
-                                                      place,
-                                                      start_op_index,
-                                                      end_op_index,
-                                                      /*is_grad*/ true,
-                                                      program_id,
-                                                      &scope);
-      auto &parallel_executor = cache_info.first;
-      auto &skip_eager_delete_vars =
-          paddle::framework::ExecutorInfoCache::Instance().SkipEagerDeleteVars(
-              program_id, true);
-      if (cache_info.second /*is_new_created*/) {
-        parallel_executor->SkipMemoryReuse(/*scope_idx=*/0, out_grad_names);
-        // NOTE: after PR22939 [Add double grad] merged, the grad op maker's
-        // SetOutput will set to None if the input var stop_gradient=True,
-        // it will cause an NotFound error when ctx.OutputNames() is called
-        std::vector<std::string> x_grad_names;
-        std::vector<std::string> param_grad_names;
-        if (!x_grad.empty()) {
-          x_grad_names = details::GetTensorsName(x_grad);
-        }
-        if (!params_grad.empty()) {
-          param_grad_names = details::GetTensorsName(params_grad);
-        }
-        skip_eager_delete_vars.insert(skip_eager_delete_vars.end(),
-                                      x_grad_names.begin(),
-                                      x_grad_names.end());
-        paddle::framework::details::AppendSkipDeletionVars(
-            param_grad_names, &skip_eager_delete_vars);
-      }
-      details::ShareTensorsIntoScope(out_grad, &scope);
-      // Debug info: scope info when run end
-      VLOG(3) << paddle::framework::GenScopeTreeDebugInfo(
-          out_scope_vec->front());
-      // Step 3. run ops
-      parallel_executor->RunWithoutFetch(
-          /*skip_eager_delete_vars=*/skip_eager_delete_vars);
-    }
-    // Step 4. get outputs
-    details::ShareTensorsFromScope(x_grad, *global_block, &scope);
-    details::ShareTensorsFromScope(params_grad, *global_block, &scope);
-    // Step5. drop current scope
-    global_inner_scope->DeleteScope(&scope);
-    VLOG(2) << "The number of sub scopes after backward: "
-            << global_inner_scope->kids().size();
-  }
 }

 class GradNodeRunProgram : public egr::GradNodeBase {
...
@@ -23,10 +23,6 @@ from paddle.fluid import backward, core, framework, program_guard
 from paddle.fluid.compiler import BuildStrategy
 from paddle.fluid.dygraph import layers
 from paddle.fluid.dygraph.base import switch_to_static_graph
-from paddle.fluid.executor import (
-    _is_dy2st_enable_standalone_executor,
-    _is_enable_standalone_executor,
-)
 from paddle.fluid.framework import _apply_pass
 from paddle.fluid.layers.utils import _hash_with_id, flatten, pack_sequence_as
@@ -128,14 +124,26 @@ class ProgramInfo:
     A helper class to recoder Program information
     """

-    def __init__(self, mode='infer'):
+    def __init__(self):
         self.op_size = {
             'fp32': -1,
             'amp': -1,
             'fp16': -1,
         }
-        assert mode in ['train', 'infer']
-        self.mode = mode
+        self.programs = {}
+        self.mode = "infer"
+
+    def __call__(self, key, prog_creator):
+        """
+        Recoder infer program and op size.
+        """
+        assert key in ['fp32', 'amp', 'fp16']
+        if key not in self.programs:
+            infer_prog = prog_creator(is_infer_mode=True)
+            self.programs[key] = infer_prog
+            self.op_size[key] = infer_prog.desc.block(0).op_size()
+        return self.programs[key], self.op_size[key]
 class PartialProgramLayer:
@@ -176,7 +184,7 @@ class PartialProgramLayer:
         self._cuda_graph_pool_id = 0
         # Set default mode to train
         self.training = True
-        self._infer_info = ProgramInfo(mode='infer')
+        self._infer_info = ProgramInfo()

         custom_white_list, custom_black_list = None, None
         tracer = framework._dygraph_tracer()
@@ -191,6 +199,28 @@ class PartialProgramLayer:
         # program_id -> list(scope)
         self._scope_cache = {}

+    def __call__(self, inputs):
+        """
+        Execute static graph by Interpreter and Return dynamic Tensors.
+        """
+        in_vars, out_vars = self._prepare(inputs)
+        self._cast_fp16_if_pure_fp16(in_vars)
+        attrs = self._prepare_attributes()
+
+        _legacy_C_ops.run_program(
+            self._valid_vars(in_vars),
+            self._valid_vars(self._params),
+            self._valid_vars(out_vars),
+            self._create_scope_vec(
+                program_id=self.program_id, use_scope_cache=True
+            ),
+            self._double_grads,
+            self._cuda_graph_vec,
+            *attrs
+        )
+
+        restored_nest_out = self._restore_out(out_vars)
+        return self._remove_no_value(restored_nest_out)
+
     def _get_scope(self, program_id=None, use_scope_cache=False):
         if use_scope_cache:
             if program_id not in self._scope_cache:
@@ -259,8 +289,9 @@ class PartialProgramLayer:
     @switch_to_static_graph
     def _create_forward_backward_train_program(self):
         whole_program = self._train_program
-        forward_end_op_index = self._infer_info.op_size['fp32']
+        _, forward_end_op_index = self._infer_info('fp32', self._create_program)
         assert forward_end_op_index >= 0
+
         return self._get_forward_backward_program_form(
             whole_program, forward_end_op_index
         )
@@ -268,8 +299,11 @@ class PartialProgramLayer:
     @switch_to_static_graph
     def _create_forward_backward_train_amp_program(self):
         whole_program = self._train_amp_program
-        forward_end_op_index = self._infer_info.op_size['amp']
+        _, forward_end_op_index = self._infer_info(
+            'amp', self._create_amp_program
+        )
         assert forward_end_op_index >= 0
+
         return self._get_forward_backward_program_form(
             whole_program, forward_end_op_index
         )
@@ -277,8 +311,11 @@ class PartialProgramLayer:
     @switch_to_static_graph
     def _create_forward_backward_train_pure_fp16_program(self):
         whole_program = self._train_pure_fp16_program
-        forward_end_op_index = self._infer_info.op_size['fp16']
+        _, forward_end_op_index = self._infer_info(
+            'fp16', self._create_pure_fp16_program
+        )
         assert forward_end_op_index >= 0
+
         return self._get_forward_backward_program_form(
             whole_program, forward_end_op_index
         )
@@ -289,11 +326,8 @@ class PartialProgramLayer:
     @LazyInitialized
     def _infer_program(self):
-        program = self._create_program(is_infer_mode=True)
-        self._infer_info.op_size['fp32'] = program.desc.block(0).op_size()
-        return self._build_infer_program(
-            program, self._infer_info.op_size['fp32']
-        )
+        program, op_size = self._infer_info('fp32', self._create_program)
+        return self._build_infer_program(program, op_size)

     @LazyInitialized
     def _train_amp_program(self):
@@ -301,11 +335,8 @@ class PartialProgramLayer:
     @LazyInitialized
     def _infer_amp_program(self):
-        program = self._create_amp_program(is_infer_mode=True)
-        self._infer_info.op_size['amp'] = program.desc.block(0).op_size()
-        return self._build_infer_program(
-            program, self._infer_info.op_size['amp']
-        )
+        program, op_size = self._infer_info('amp', self._create_amp_program)
+        return self._build_infer_program(program, op_size)

     @LazyInitialized
     def _train_pure_fp16_program(self):
@@ -313,11 +344,10 @@ class PartialProgramLayer:
     @LazyInitialized
     def _infer_pure_fp16_program(self):
-        program = self._create_pure_fp16_program(is_infer_mode=True)
-        self._infer_info.op_size['fp16'] = program.desc.block(0).op_size()
-        return self._build_infer_program(
-            program, self._infer_info.op_size['fp16']
-        )
+        program, op_size = self._infer_info(
+            'fp16', self._create_pure_fp16_program
+        )
+        return self._build_infer_program(program, op_size)
     @LazyInitialized
     def _train_forward_backward_program(self):
@@ -632,27 +662,24 @@ class PartialProgramLayer:
             double_grads.append(var_base)
         return self._valid_vars(double_grads)

-    def _get_end_op_index(self):
-        if _in_amp_guard():
-            infer_program = self._infer_amp_program
-        elif _in_pure_fp16_guard():
-            infer_program = self._infer_pure_fp16_program
-        else:
-            infer_program = self._infer_program
-        return infer_program.desc.block(0).op_size()
-
-    def __call__(self, inputs):
-        in_vars, out_vars = self._prepare(inputs)
-        self._cast_fp16_if_pure_fp16(in_vars)
-
+    def _cast_fp16_if_pure_fp16(self, in_vars):
+        if _in_pure_fp16_guard():
+            for i, var in enumerate(in_vars):
+                name = var.name
+                if (
+                    self.program.global_block().has_var(name)
+                    and self.program.global_block().var(name).dtype
+                    == paddle.float16
+                ):
+                    in_vars[i] = var.astype('float16')
+                    in_vars[i].name = name
+
+    def _prepare_attributes(self):
         attrs = [
-            'global_block',
-            self.program.desc.block(0),
-            'start_op_index',
-            0,
-            'end_op_index',
-            self._get_end_op_index(),
+            'forward_global_block',
+            self.forward_program.desc.block(0),
+            'backward_global_block',
+            self.backward_program.desc.block(0),
             'is_test',
             not self.training,
             'program_id',
@@ -679,57 +706,7 @@ class PartialProgramLayer:
                     self._cuda_graph_pool_id,
                 )
             )
+        return attrs
-
-        use_interpretorcore = (
-            _is_enable_standalone_executor()
-            and _is_dy2st_enable_standalone_executor()
-        )
-        attrs.extend(('use_interpretorcore', use_interpretorcore))
-        if use_interpretorcore:
-            attrs.extend(
-                (
-                    'forward_global_block',
-                    self.forward_program.desc.block(0),
-                    'backward_global_block',
-                    self.backward_program.desc.block(0),
-                )
-            )
-
-            _legacy_C_ops.run_program(
-                self._valid_vars(in_vars),
-                self._valid_vars(self._params),
-                self._valid_vars(out_vars),
-                self._create_scope_vec(
-                    program_id=self.program_id, use_scope_cache=True
-                ),
-                self._double_grads,
-                self._cuda_graph_vec,
-                *attrs
-            )
-        else:
-            _legacy_C_ops.run_program(
-                self._valid_vars(in_vars),
-                self._valid_vars(self._params),
-                self._valid_vars(out_vars),
-                self._create_scope_vec(),
-                self._double_grads,
-                self._cuda_graph_vec,
-                *attrs
-            )
-
-        restored_nest_out = self._restore_out(out_vars)
-        return self._remove_no_value(restored_nest_out)
-
-    def _cast_fp16_if_pure_fp16(self, in_vars):
-        if _in_pure_fp16_guard():
-            for i, var in enumerate(in_vars):
-                name = var.name
-                if (
-                    self.program.global_block().has_var(name)
-                    and self.program.global_block().var(name).dtype
-                    == paddle.float16
-                ):
-                    in_vars[i] = var.astype('float16')
-                    in_vars[i].name = name

     @switch_to_static_graph
     def _build_infer_program(self, infer_program, forward_end_op_index):
...
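The forward op count that `_get_end_op_index()` used to recompute from the infer program is now cached by `ProgramInfo.__call__` together with the program itself. Below is a self-contained toy sketch of that memoization pattern; the class name, the lambda-based `prog_creator`, and the list standing in for a `Program` are illustrative assumptions, not Paddle code.

```python
# Toy sketch of the build-once, cache-by-key pattern used by ProgramInfo.
class CachedProgramInfo:
    def __init__(self):
        self.programs = {}
        self.op_size = {'fp32': -1, 'amp': -1, 'fp16': -1}

    def __call__(self, key, prog_creator):
        assert key in ['fp32', 'amp', 'fp16']
        if key not in self.programs:
            # Built only on the first request for this key.
            prog = prog_creator(is_infer_mode=True)
            self.programs[key] = prog
            # Stand-in for program.desc.block(0).op_size().
            self.op_size[key] = len(prog)
        return self.programs[key], self.op_size[key]


info = CachedProgramInfo()
prog, op_size = info('fp32', lambda is_infer_mode: ['matmul', 'relu'])
prog_again, _ = info('fp32', lambda is_infer_mode: ['never', 'rebuilt'])
assert prog is prog_again and op_size == 2
```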