diff --git a/paddle/fluid/eager/to_static/run_program_op_func.h b/paddle/fluid/eager/to_static/run_program_op_func.h index fb8d64e377bc796ca07374d9381399259cf09b68..33da489fd47b168d31a018cc90cd3baade44f43f 100644 --- a/paddle/fluid/eager/to_static/run_program_op_func.h +++ b/paddle/fluid/eager/to_static/run_program_op_func.h @@ -38,6 +38,22 @@ static void clear_no_grad_edges( } } +static void clear_no_grad_edges_with_partial_block( + const std::vector& params, + const paddle::framework::BlockDesc* forward_block_desc, + const paddle::framework::BlockDesc* backward_block_desc, + egr::GradNodeBase* grad_node, + size_t slot_id) { + for (size_t i = 0; i < params.size(); ++i) { + auto p_grad_name = paddle::framework::GradVarName(params[i].name()); + if (!forward_block_desc->HasVar(p_grad_name) && + !backward_block_desc->HasVar(p_grad_name)) { + VLOG(1) << "clear edge of " << p_grad_name; + grad_node->MutableOutputMeta()[slot_id][i].GetMutableEdge().Clear(); + } + } +} + inline void run_program_dygraph_function( const std::vector& x, const std::vector& params, @@ -85,9 +101,26 @@ inline void run_program_dygraph_function( // Set Grad out rank as same as fwd input and set stop gradient to bwd grad_node->SetGradOutMeta(x, /*slot id*/ 0); grad_node->SetGradOutMeta(params, /*slot id*/ 1); - auto* global_block = PADDLE_GET_CONST(paddle::framework::BlockDesc*, - attrs.at("global_block")); - clear_no_grad_edges(params, global_block, grad_node.get(), /*slot id*/ 1); + + bool use_interpretorcore = + PADDLE_GET_CONST(bool, attrs.at("use_interpretorcore")); + VLOG(2) << "clear_no_grad_edges."; + if (use_interpretorcore) { + auto* forward_global_block = PADDLE_GET_CONST( + paddle::framework::BlockDesc*, attrs.at("forward_global_block")); + auto* backward_global_block = PADDLE_GET_CONST( + paddle::framework::BlockDesc*, attrs.at("backward_global_block")); + clear_no_grad_edges_with_partial_block(params, + forward_global_block, + backward_global_block, + grad_node.get(), + /*slot id*/ 1); + + } else { + auto* global_block = PADDLE_GET_CONST(paddle::framework::BlockDesc*, + attrs.at("global_block")); + clear_no_grad_edges(params, global_block, grad_node.get(), /*slot id*/ 1); + } grad_node->SetGradInMeta(deref_out, 0); diff --git a/paddle/fluid/eager/to_static/run_program_op_node.h b/paddle/fluid/eager/to_static/run_program_op_node.h index 50f9d0b58ca32b4a6e1dc9be8933dfaba8c1b205..822e6563ca45343c100623c9f67993479235d560 100644 --- a/paddle/fluid/eager/to_static/run_program_op_node.h +++ b/paddle/fluid/eager/to_static/run_program_op_node.h @@ -17,6 +17,7 @@ #include "paddle/fluid/eager/api/utils/global_utils.h" #include "paddle/fluid/eager/grad_node_info.h" #include "paddle/fluid/eager/tensor_wrapper.h" +#include "paddle/fluid/framework/variable_helper.h" #include "paddle/fluid/operators/run_program_op.h" #include "paddle/fluid/platform/enforce.h" @@ -181,6 +182,70 @@ static void ShareTensorsFromScope( } } +static void ShareTensorsFromScopeWithPartialBlock( + const std::vector &tensors, + const paddle::framework::BlockDesc &forward_global_block, + const paddle::framework::BlockDesc &backward_global_block, + paddle::framework::Scope *scope) { + for (size_t i = 0; i < tensors.size(); ++i) { + auto &name = tensors[i]->name(); + if (name == paddle::framework::kEmptyVarName || name == "Fake_var" || + (!forward_global_block.HasVar(name) && + !backward_global_block.HasVar(name))) { + VLOG(2) << "find tensor name is " << name << ", skip it!"; + continue; + } + auto *var = scope->FindVar(name); + PADDLE_ENFORCE_NOT_NULL( 
+ var, + paddle::platform::errors::NotFound("The output tensor %s is not in " + "RunProgram(Grad)Op'" + "s internal scope.", + name)); + CheckOutputVarStatus(*var, *tensors[i]); + // share tensor + if (var->IsType()) { + auto &src_tensor = var->Get(); + auto *dst_tensor = const_cast( + dynamic_cast(tensors[i]->impl().get())); + VLOG(2) << "share " << name << " from scope"; + *dst_tensor = src_tensor; + } else if (var->IsType()) { + auto &src_tensor = var->Get(); + auto *dst_tensor = const_cast( + dynamic_cast(tensors[i]->impl().get())); + *dst_tensor = src_tensor; + } + } +} + +static void BuildScopeByBlock( + const paddle::framework::InterpreterCore &interpreter_core, + const paddle::framework::BlockDesc &block, + paddle::framework::Scope *scope) { + for (auto &var_desc : block.AllVars()) { + auto var_name = var_desc->Name(); + if (var_name == paddle::framework::kEmptyVarName) { + continue; + } + if (!scope->FindLocalVar(var_name)) { + auto *ptr = scope->Var(var_name); + InitializeVariable(ptr, var_desc->GetType()); + VLOG(2) << "Initialize Block Variable " << var_name; + } + } + auto &data_transfer_added_vars = + interpreter_core.GetVariableScope()->DataTransferAddedVars(); + for (size_t i = 0; i < data_transfer_added_vars.size(); i++) { + auto *ptr = scope->Var(data_transfer_added_vars[i].first); + InitializeVariable(ptr, + static_cast( + data_transfer_added_vars[i].second)); + VLOG(2) << "Initialize Transfer Added Variable " + << data_transfer_added_vars[i].first; + } +} + } // namespace details inline void RunProgramAPI( @@ -191,8 +256,6 @@ inline void RunProgramAPI( std::vector &dout, // NOLINT const paddle::framework::AttributeMap &attrs) { VLOG(2) << "RunProgramOpKernel Compute"; - auto start_op_index = PADDLE_GET_CONST(int64_t, attrs.at("start_op_index")); - auto end_op_index = PADDLE_GET_CONST(int64_t, attrs.at("end_op_index")); // In the original run_program OP, the default value of the is_test // attribute is false, we should check if there is is_test parameter // in attrs @@ -201,6 +264,7 @@ inline void RunProgramAPI( is_test = PADDLE_GET_CONST(bool, attrs.at("is_test")); } auto program_id = PADDLE_GET_CONST(int64_t, attrs.at("program_id")); + auto place = egr::Controller::Instance().GetExpectedPlace(); // NOTE(chenweihang): In order not to add new variable type, use vector // here. Originally, here can use scope directly. @@ -210,7 +274,6 @@ inline void RunProgramAPI( 1, paddle::platform::errors::InvalidArgument( "The OutScope of RunProgramGradOp should only hold one scope.")); - // Step 2. 
prepare executor and init persistable variables // NOTE(Aurelius84): While training some models, forward can be called many // times and then apply backpropagation all at once, such as Reinforcement @@ -222,62 +285,151 @@ inline void RunProgramAPI( << out_scope_vec->front()->kids().size(); paddle::framework::Scope &scope = global_inner_scope->NewScope(); - // share input_vars & parameters into scope - details::ShareTensorsIntoScope(x, &scope); - details::ShareTensorsIntoScope(params, &scope); + bool use_interpretorcore = + PADDLE_GET_CONST(bool, attrs.at("use_interpretorcore")); - auto *global_block = PADDLE_GET_CONST(paddle::framework::BlockDesc *, - attrs.at("global_block")); - const auto &place = egr::Controller::Instance().GetExpectedPlace(); + if (use_interpretorcore) { + VLOG(2) << "RunProgramOp use interpretercore to execute program."; - if (end_op_index > start_op_index) { auto input_names = details::GetTensorsName(x); auto output_names = details::GetTensorsName(out); auto dout_names = details::GetTensorsName(dout); - auto *program = global_block->Program(); - - auto cache_info = - paddle::framework::GetExecutorInfoFromCache(*program, - place, - start_op_index, - end_op_index, - /*is_grad=*/false, - program_id, - &scope); - auto &parallel_executor = cache_info.first; - // all out_vars are skip_eager_var - auto &skip_eager_delete_vars = - paddle::framework::ExecutorInfoCache::Instance().SkipEagerDeleteVars( - program_id, false); - if (cache_info.second /*is_new_created*/) { - parallel_executor->SkipMemoryReuse(/*scope_idx=*/0, input_names); - skip_eager_delete_vars.insert(skip_eager_delete_vars.end(), - output_names.begin(), - output_names.end()); - skip_eager_delete_vars.insert( - skip_eager_delete_vars.end(), dout_names.begin(), dout_names.end()); - paddle::framework::details::ParseSafeEagerDeletionSkipVars( - *program, end_op_index, output_names, &skip_eager_delete_vars); + + auto *forward_global_block = PADDLE_GET_CONST( + paddle::framework::BlockDesc *, attrs.at("forward_global_block")); + auto *backward_global_block = PADDLE_GET_CONST( + paddle::framework::BlockDesc *, attrs.at("backward_global_block")); + auto *forward_program = forward_global_block->Program(); + auto *backward_program = backward_global_block->Program(); + + auto &interpretercore_info_cache = + paddle::framework::InterpreterCoreInfoCache::Instance(); + + if (!interpretercore_info_cache.Has(program_id, /*is_grad=*/false)) { + VLOG(2) << "No interpretercore cache, so create a new interpretercore"; + // Step 1. share input_vars & parameters into scope + details::ShareTensorsIntoScope(x, &scope); + details::ShareTensorsIntoScope(params, &scope); + // Step 2. create new interpretercore + auto interpreter_core = + paddle::framework::CreateInterpreterCoreInfoToCache( + *forward_program, place, /*is_grad=*/false, program_id, &scope); + // Step 3. get all eager gc vars + std::set skip_eager_delete_vars = + paddle::framework::details::ParseSafeEagerDeletionSkipVarsSet( + *backward_program); + // all out_vars are skip_eager_var + skip_eager_delete_vars.insert(output_names.begin(), output_names.end()); + skip_eager_delete_vars.insert(dout_names.begin(), dout_names.end()); + // update interpretercore skip_gc_var + interpreter_core->SetSkipGcVars(skip_eager_delete_vars); + interpretercore_info_cache.UpdateSkipEagerDeleteVars( + program_id, false, skip_eager_delete_vars); + VLOG(2) << "Get skip GC vars size is: " << skip_eager_delete_vars.size(); + // Step 4. 
interpretercore run + if (forward_global_block->OpSize() > 0) { + interpreter_core->Run({}); + } + // Step 5. Get Output + details::ShareTensorsFromScopeWithPartialBlock( + out, *forward_global_block, *backward_global_block, &scope); + details::ShareTensorsFromScopeWithPartialBlock( + dout, *forward_global_block, *backward_global_block, &scope); + } else { + VLOG(2) << "Get interpretercore cache by program:" << program_id; + // Step 1. get cache interpretercore + auto &cached_value = + interpretercore_info_cache.GetMutable(program_id, /*is_grad=*/false); + auto &interpreter_core = cached_value.core_; + // Step 2. update scope for cache interpretercore + details::ShareTensorsIntoScope(x, &scope); + details::ShareTensorsIntoScope(params, &scope); + details::BuildScopeByBlock( + *interpreter_core.get(), *forward_global_block, &scope); + interpreter_core->reset_scope(&scope); + // Step 3. interpretercore run + if (forward_global_block->OpSize() > 0) { + interpreter_core->Run({}); + } + // Step 4. Get Output + details::ShareTensorsFromScopeWithPartialBlock( + out, *forward_global_block, *backward_global_block, &scope); + details::ShareTensorsFromScopeWithPartialBlock( + dout, *forward_global_block, *backward_global_block, &scope); } + VLOG(3) << paddle::framework::GenScopeTreeDebugInfo(out_scope_vec->front()); - // Step 3. run ops - parallel_executor->RunWithoutFetch(skip_eager_delete_vars); - } - // Step 4. Get Output - details::ShareTensorsFromScope(out, *global_block, &scope); - details::ShareTensorsFromScope(dout, *global_block, &scope); - - // Debug info: scope info when run end - VLOG(3) << paddle::framework::GenScopeTreeDebugInfo(out_scope_vec->front()); - // Step 5. Drop all children scopes while testing. - if (is_test) { - out_scope_vec->front()->DropKids(); - } - VLOG(2) << "The number of sub scopes after forward: " - << out_scope_vec->front()->kids().size(); + if (is_test) { + VLOG(1) << "is test, after forward, drop kids"; + out_scope_vec->front()->DropKids(); + } + VLOG(2) << "The number of sub scopes after forward: " + << out_scope_vec->front()->kids().size(); +#ifdef PADDLE_WITH_MKLDNN + if (FLAGS_use_mkldnn) paddle::platform::DontClearMKLDNNCache(place); +#endif + } else { + VLOG(2) << "RunProgramOp execute with parallel_executor."; + // share input_vars & parameters into scope + details::ShareTensorsIntoScope(x, &scope); + details::ShareTensorsIntoScope(params, &scope); + + const auto &place = egr::Controller::Instance().GetExpectedPlace(); + + auto *global_block = PADDLE_GET_CONST(paddle::framework::BlockDesc *, + attrs.at("global_block")); + auto start_op_index = PADDLE_GET_CONST(int64_t, attrs.at("start_op_index")); + auto end_op_index = PADDLE_GET_CONST(int64_t, attrs.at("end_op_index")); + + if (end_op_index > start_op_index) { + auto input_names = details::GetTensorsName(x); + auto output_names = details::GetTensorsName(out); + auto dout_names = details::GetTensorsName(dout); + auto *program = global_block->Program(); + + auto cache_info = + paddle::framework::GetExecutorInfoFromCache(*program, + place, + start_op_index, + end_op_index, + /*is_grad=*/false, + program_id, + &scope); + auto &parallel_executor = cache_info.first; + // all out_vars are skip_eager_var + auto &skip_eager_delete_vars = + paddle::framework::ExecutorInfoCache::Instance().SkipEagerDeleteVars( + program_id, false); + if (cache_info.second /*is_new_created*/) { + parallel_executor->SkipMemoryReuse(/*scope_idx=*/0, input_names); + skip_eager_delete_vars.insert(skip_eager_delete_vars.end(), + 
output_names.begin(), + output_names.end()); + skip_eager_delete_vars.insert( + skip_eager_delete_vars.end(), dout_names.begin(), dout_names.end()); + paddle::framework::details::ParseSafeEagerDeletionSkipVars( + *program, end_op_index, output_names, &skip_eager_delete_vars); + } + + // Step 3. run ops + parallel_executor->RunWithoutFetch(skip_eager_delete_vars); + } + // Step 4. Get Output + details::ShareTensorsFromScope(out, *global_block, &scope); + details::ShareTensorsFromScope(dout, *global_block, &scope); + + // Debug info: scope info when run end + VLOG(3) << paddle::framework::GenScopeTreeDebugInfo(out_scope_vec->front()); + // Step 5. Drop all children scopes while testing. + if (is_test) { + out_scope_vec->front()->DropKids(); + } + VLOG(2) << "The number of sub scopes after forward: " + << out_scope_vec->front()->kids().size(); #ifdef PADDLE_WITH_MKLDNN - if (FLAGS_use_mkldnn) paddle::platform::DontClearMKLDNNCache(place); + if (FLAGS_use_mkldnn) paddle::platform::DontClearMKLDNNCache(place); #endif + } } inline void RunProgramGradAPI( @@ -292,16 +444,9 @@ inline void RunProgramGradAPI( // if all output vars are set to stop_gradient, grad op no need to executed if (x_grad.empty() && params_grad.empty()) return; - auto *global_block = PADDLE_GET_CONST(paddle::framework::BlockDesc *, - attrs.at("global_block")); - auto orig_end_op_index = PADDLE_GET_CONST(int64_t, attrs.at("end_op_index")); - + bool use_interpretorcore = + PADDLE_GET_CONST(bool, attrs.at("use_interpretorcore")); auto program_id = PADDLE_GET_CONST(int64_t, attrs.at("program_id")); - // NOTE: skip `shape` and `fill_constant` op created by - // fluid.backward.gradients, one forward output will generate one `shape` - // and `fill_constant` - int64_t start_op_index = orig_end_op_index + (out_grad.size() * 2); - int64_t end_op_index = global_block->OpSize(); auto *out_scope_vec = &step_scope; PADDLE_ENFORCE_EQ( @@ -309,7 +454,6 @@ inline void RunProgramGradAPI( 1, paddle::platform::errors::InvalidArgument( "The OutScope of RunProgramGradOp should only hold one scope.")); - paddle::framework::Scope *global_inner_scope = out_scope_vec->front(); auto sub_scope_num = global_inner_scope->kids().size(); VLOG(2) << "The number of sub scopes before backward: " << sub_scope_num; @@ -320,13 +464,19 @@ inline void RunProgramGradAPI( "least one sub scope.")); auto &scope = *(global_inner_scope->kids().front()); - const auto &place = egr::Controller::Instance().GetExpectedPlace(); + auto place = egr::Controller::Instance().GetExpectedPlace(); + + if (use_interpretorcore) { + VLOG(2) << "RunProgramGradOp use interpretercore to execute program."; + + auto *forward_global_block = PADDLE_GET_CONST( + paddle::framework::BlockDesc *, attrs.at("forward_global_block")); + auto *backward_global_block = PADDLE_GET_CONST( + paddle::framework::BlockDesc *, attrs.at("backward_global_block")); + auto *backward_program = backward_global_block->Program(); - if (end_op_index > start_op_index) { auto out_grad_names = details::GetTensorsName(out_grad); - // NOTE: after PR22939 [Add double grad] merged, the grad op maker's - // SetOutput will set to None if the input var stop_gradient=True, - // it will cause an NotFound error when ctx.OutputNames() is called + std::vector x_grad_names; std::vector param_grad_names; if (!x_grad.empty()) { @@ -336,48 +486,130 @@ inline void RunProgramGradAPI( param_grad_names = details::GetTensorsName(params_grad); } - // Step 2. 
prepare executor and scope - auto *program = global_block->Program(); - auto cache_info = - paddle::framework::GetExecutorInfoFromCache(*program, - place, - start_op_index, - end_op_index, - /*is_grad*/ true, - program_id, - &scope); - auto &parallel_executor = cache_info.first; - - auto &skip_eager_delete_vars = - paddle::framework::ExecutorInfoCache::Instance().SkipEagerDeleteVars( - program_id, true); - if (cache_info.second /*is_new_created*/) { - parallel_executor->SkipMemoryReuse(/*scope_idx=*/0, out_grad_names); - - skip_eager_delete_vars.insert(skip_eager_delete_vars.end(), - x_grad_names.begin(), - x_grad_names.end()); + auto &interpretercore_info_cache = + paddle::framework::InterpreterCoreInfoCache::Instance(); + if (!interpretercore_info_cache.Has(program_id, /*is_grad=*/true)) { + VLOG(2) << "No interpretercore cache, so create a new interpretercore"; + details::ShareTensorsIntoScope(out_grad, &scope); + auto interpreter_core = + paddle::framework::CreateInterpreterCoreInfoToCache( + *backward_program, place, /*is_grad=*/true, program_id, &scope); + + // get all eager gc vars + std::set skip_eager_delete_vars; + // all out_vars are skip_eager_var + skip_eager_delete_vars.insert(x_grad_names.begin(), x_grad_names.end()); + // initialize skip gc vars by forward_program and backward_program paddle::framework::details::AppendSkipDeletionVars( param_grad_names, &skip_eager_delete_vars); + interpreter_core->SetSkipGcVars(skip_eager_delete_vars); + interpretercore_info_cache.UpdateSkipEagerDeleteVars( + program_id, /*is_grad=*/true, skip_eager_delete_vars); + VLOG(2) << "Get skip GC vars size is: " << skip_eager_delete_vars.size(); + if (backward_global_block->OpSize() > 0) { + // Debug info: scope info when run end + VLOG(3) << paddle::framework::GenScopeTreeDebugInfo( + out_scope_vec->front()); + interpreter_core->Run({}); + } + } else { + VLOG(2) << "Get interpretercore cache by program:" << program_id; + auto &cached_value = + interpretercore_info_cache.GetMutable(program_id, /*is_grad=*/true); + auto &interpreter_core = cached_value.core_; + // update scope + details::ShareTensorsIntoScope(out_grad, &scope); + details::BuildScopeByBlock( + *interpreter_core.get(), *backward_global_block, &scope); + interpreter_core->reset_scope(&scope); + + if (backward_global_block->OpSize() > 0) { + // Debug info: scope info when run end + VLOG(3) << paddle::framework::GenScopeTreeDebugInfo( + out_scope_vec->front()); + interpreter_core->Run({}); + } } + // Step 4. get outputs + details::ShareTensorsFromScopeWithPartialBlock( + x_grad, *forward_global_block, *backward_global_block, &scope); + details::ShareTensorsFromScopeWithPartialBlock( + params_grad, *forward_global_block, *backward_global_block, &scope); + + // Step5. 
drop current scope + global_inner_scope->DeleteScope(&scope); + VLOG(2) << "The number of sub scopes after backward: " + << global_inner_scope->kids().size(); + } else { + auto *global_block = PADDLE_GET_CONST(paddle::framework::BlockDesc *, + attrs.at("global_block")); + auto orig_end_op_index = + PADDLE_GET_CONST(int64_t, attrs.at("end_op_index")); + + // NOTE: skip `shape` and `fill_constant` op created by + // fluid.backward.gradients, one forward output will generate one `shape` + // and `fill_constant` + int64_t start_op_index = orig_end_op_index + (out_grad.size() * 2); + int64_t end_op_index = global_block->OpSize(); + + if (end_op_index > start_op_index) { + auto out_grad_names = details::GetTensorsName(out_grad); + // NOTE: after PR22939 [Add double grad] merged, the grad op maker's + // SetOutput will set to None if the input var stop_gradient=True, + // it will cause an NotFound error when ctx.OutputNames() is called + std::vector x_grad_names; + std::vector param_grad_names; + if (!x_grad.empty()) { + x_grad_names = details::GetTensorsName(x_grad); + } + if (!params_grad.empty()) { + param_grad_names = details::GetTensorsName(params_grad); + } - details::ShareTensorsIntoScope(out_grad, &scope); - // Debug info: scope info when run end - VLOG(3) << paddle::framework::GenScopeTreeDebugInfo(out_scope_vec->front()); + // Step 2. prepare executor and scope + auto *program = global_block->Program(); + auto cache_info = + paddle::framework::GetExecutorInfoFromCache(*program, + place, + start_op_index, + end_op_index, + /*is_grad*/ true, + program_id, + &scope); + auto &parallel_executor = cache_info.first; + + auto &skip_eager_delete_vars = + paddle::framework::ExecutorInfoCache::Instance().SkipEagerDeleteVars( + program_id, true); + if (cache_info.second /*is_new_created*/) { + parallel_executor->SkipMemoryReuse(/*scope_idx=*/0, out_grad_names); + + skip_eager_delete_vars.insert(skip_eager_delete_vars.end(), + x_grad_names.begin(), + x_grad_names.end()); + paddle::framework::details::AppendSkipDeletionVars( + param_grad_names, &skip_eager_delete_vars); + } - // Step 3. run ops - parallel_executor->RunWithoutFetch( - /*skip_eager_delete_vars=*/skip_eager_delete_vars); - } + details::ShareTensorsIntoScope(out_grad, &scope); + // Debug info: scope info when run end + VLOG(3) << paddle::framework::GenScopeTreeDebugInfo( + out_scope_vec->front()); + + // Step 3. run ops + parallel_executor->RunWithoutFetch( + /*skip_eager_delete_vars=*/skip_eager_delete_vars); + } - // Step 4. get outputs - details::ShareTensorsFromScope(x_grad, *global_block, &scope); - details::ShareTensorsFromScope(params_grad, *global_block, &scope); + // Step 4. get outputs + details::ShareTensorsFromScope(x_grad, *global_block, &scope); + details::ShareTensorsFromScope(params_grad, *global_block, &scope); - // Step5. drop current scope - global_inner_scope->DeleteScope(&scope); - VLOG(2) << "The number of sub scopes after backward: " - << global_inner_scope->kids().size(); + // Step5. 
drop current scope + global_inner_scope->DeleteScope(&scope); + VLOG(2) << "The number of sub scopes after backward: " + << global_inner_scope->kids().size(); + } } class GradNodeRunProgram : public egr::GradNodeBase { diff --git a/paddle/fluid/framework/CMakeLists.txt b/paddle/fluid/framework/CMakeLists.txt index 2dd2b5162f74837f6443337b07b86abd38134265..a3eb067426f5636bbdd6da5a9012d7eebb6ae33f 100755 --- a/paddle/fluid/framework/CMakeLists.txt +++ b/paddle/fluid/framework/CMakeLists.txt @@ -1004,7 +1004,7 @@ cc_library( cc_library( executor_cache SRCS executor_cache.cc - DEPS parallel_executor) + DEPS parallel_executor standalone_executor) if(WITH_PSCORE) get_property(RPC_DEPS GLOBAL PROPERTY RPC_DEPS) if(WITH_HETERPS) diff --git a/paddle/fluid/framework/executor_cache.cc b/paddle/fluid/framework/executor_cache.cc index de3ca2131aa50a6edfec8625e526650c1d484899..1ce9db6294050bd9b8ef9f2f088fa9c3ee1d7a2a 100644 --- a/paddle/fluid/framework/executor_cache.cc +++ b/paddle/fluid/framework/executor_cache.cc @@ -137,6 +137,58 @@ void ParseSafeEagerDeletionSkipVars( VLOG(3) << "Found skip_eager_delete_vars: " << skip_eager_delete_vars->size(); } +void AppendSkipDeletionVars(const std::vector &append_vars, + std::set *all_vars) { + for (auto &var : append_vars) { + all_vars->insert(var); + } +} + +std::set ParseSafeEagerDeletionSkipVarsSet( + const ProgramDesc &backward_program) { + std::set skip_eager_delete_vars; + auto backward_ops = backward_program.Block(0).AllOps(); + auto &op_info_map = OpInfoMap::Instance(); + std::unordered_set op_outputs; + std::unordered_set op_inputs; + std::unordered_set no_need_buffer_ins; + for (size_t i = 0; i < backward_ops.size(); ++i) { + framework::OpDesc *op = backward_ops[i]; + if (op->Type() == "share_buffer") { + VLOG(1) << "skip share_buffer op"; + continue; + } + // NOTE: skip NoNeedBufferVars of grad_op and GC its memory in advance. + auto &op_info = op_info_map.Get(op->Type()); + auto &inferer = op_info.NoNeedBufferVarsInferer(); + no_need_buffer_ins.clear(); + if (inferer != nullptr) { + no_need_buffer_ins = + inferer(op->Inputs(), op->Outputs(), op->GetAttrMap()); + } + for (auto &in_names : op->Inputs()) { + if (no_need_buffer_ins.count(in_names.first) == 0) { + for (auto &in_name : in_names.second) { + op_inputs.emplace(in_name); + } + } else { + VLOG(2) << op->Type() << " has no_need_buffer_in: " << in_names.first + << " , skip it."; + } + } + for (const std::string &out_arg_name : op->OutputArgumentNames()) { + op_outputs.emplace(out_arg_name); + } + } + for (const std::string &var_name : op_inputs) { + if (op_outputs.find(var_name) == op_outputs.end()) { + VLOG(1) << "skip eager var: " << var_name; + skip_eager_delete_vars.insert(var_name); + } + } + VLOG(1) << "Found skip_eager_delete_vars: " << skip_eager_delete_vars.size(); + return skip_eager_delete_vars; +} } // namespace details // C++11 removes the need for manual locking. 
Concurrent execution shall wait if @@ -225,5 +277,33 @@ CacheInfo GetExecutorInfoFromCache(const ProgramDesc &program_desc, } } +InterpreterCoreInfoCache &InterpreterCoreInfoCache::Instance() { + static InterpreterCoreInfoCache g_info_cache; + return g_info_cache; +} + +std::shared_ptr CreateInterpreterCoreInfoToCache( + const ProgramDesc &program_desc, + const platform::Place &place, + bool is_grad, + int64_t program_id, + framework::Scope *scope) { + auto &interpretercore_info_cache = + framework::InterpreterCoreInfoCache::Instance(); + if (interpretercore_info_cache.Size() > 4u /* max_cached_size*/) { + interpretercore_info_cache.Finalize(); + } + auto core = std::make_shared( + place, + program_desc.Block(0), + /*skip_gc_vars=*/std::set(), + scope, + /*used_for_jit=*/true); + auto &cached_value = + interpretercore_info_cache.GetMutable(program_id, is_grad); + cached_value.core_ = core; + return core; +} + } // namespace framework } // namespace paddle diff --git a/paddle/fluid/framework/executor_cache.h b/paddle/fluid/framework/executor_cache.h index 26e5771542dde84eb97f607f43986102c129f388..196bfd22b1e3d2c4a1ba634514390d2d59c04873 100644 --- a/paddle/fluid/framework/executor_cache.h +++ b/paddle/fluid/framework/executor_cache.h @@ -23,6 +23,7 @@ #include #include +#include "paddle/fluid/framework/new_executor/interpretercore.h" #include "paddle/fluid/framework/op_proto_maker.h" #include "paddle/fluid/framework/parallel_executor.h" #include "paddle/fluid/framework/program_desc.h" @@ -45,6 +46,12 @@ void ParseSafeEagerDeletionSkipVars( const std::vector& output_var_names, std::vector* skip_eager_delete_vars); +void AppendSkipDeletionVars(const std::vector& append_vars, + std::set* all_vars); + +std::set ParseSafeEagerDeletionSkipVarsSet( + const ProgramDesc& backward_program); + } // namespace details class ExecutorInfo { @@ -147,5 +154,73 @@ PEAndGraphPair CreateFixOrderExecutorInfo(const ProgramDesc& program_desc, int64_t end_op_index, framework::Scope* scope); +class InterpreterCoreInfo { + public: + struct CacheValue { + std::shared_ptr core_{nullptr}; + std::set skip_eager_delete_vars_; + }; + + bool IsAvailable(bool is_grad) { + const auto& core = is_grad ? backward_info_.core_ : forward_info_.core_; + return core != nullptr; + } + + CacheValue& GetMutable(bool is_grad) { + return is_grad ? backward_info_ : forward_info_; + } + + private: + CacheValue forward_info_; + CacheValue backward_info_; +}; + +class InterpreterCoreInfoCache { + public: + static InterpreterCoreInfoCache& Instance(); + + bool Has(int64_t program_id, bool is_grad) { + return info_map_.find(program_id) != info_map_.end() && + info_map_[program_id].IsAvailable(is_grad); + } + + InterpreterCoreInfo::CacheValue& GetMutable(int64_t program_id, + bool is_grad) { + return info_map_[program_id].GetMutable(is_grad); + } + + void UpdateSkipEagerDeleteVars(int64_t program_id, + bool is_grad, + const std::set& skip_vars) { + auto& cached_value = GetMutable(program_id, is_grad); + cached_value.skip_eager_delete_vars_ = std::move(skip_vars); + } + + std::set& GetSkipEagerDeleteVars(int64_t program_id, + bool is_grad) { + auto& cached_value = GetMutable(program_id, is_grad); + return cached_value.skip_eager_delete_vars_; + } + + size_t Size() const { return info_map_.size(); } + + void Finalize() { + // NOTE(Aurelius84): DO NOT perform finalize in destructor + // to avoid problems caused by destructor order of static + // object. 
+ info_map_.clear(); + } + + private: + std::unordered_map info_map_; +}; + +std::shared_ptr CreateInterpreterCoreInfoToCache( + const ProgramDesc& program_desc, + const platform::Place& place, + bool is_grad, + int64_t program_id, + framework::Scope* scope); + } // namespace framework } // namespace paddle diff --git a/paddle/fluid/framework/executor_gc_helper.cc b/paddle/fluid/framework/executor_gc_helper.cc index 2ab57ef77ed32218a5b592a3decf8e4be8242ff4..3c24ad58d7ac1016f5851bbbbd0306baa4790332 100644 --- a/paddle/fluid/framework/executor_gc_helper.cc +++ b/paddle/fluid/framework/executor_gc_helper.cc @@ -202,23 +202,34 @@ static std::vector> CreateOpsFromBlock( } std::vector>> GetEagerDeletionCleanVars( - const ProgramDesc &origin_program, - const std::vector &skip_vars) { + const ProgramDesc &program, const std::vector &skip_vars) { + return GetEagerDeletionCleanVarsForPartial(program, skip_vars, false); +} + +std::vector>> +GetEagerDeletionCleanVarsForPartial(const ProgramDesc &origin_program, + const std::vector &skip_vars, + const bool &for_partial_block) { ProgramDesc program{origin_program}; size_t block_num = program.Size(); PADDLE_ENFORCE_GE(block_num, 1, platform::errors::PermissionDenied( "Program should have at least one block")); - - // prepare safe GCs on sub block ops - auto global_block_ops = CreateOpsFromBlock(program.Block(0)); - operators::PrepareSafeEagerDeletionOnConditionalOpAndConditionalGradOp( - program, 0, global_block_ops); - operators::PrepareSafeEagerDeletionOnWhileOpAndWhileGradOp( - program, 0, global_block_ops); - operators::PrepareSafeEagerDeletionOnRecurrentOpAndRecurrentGradOp( - program, 0, global_block_ops); + // Note(zhangbo): For dygraph2static inplace policy, origin_program is a + // partial program(only include forward or backward), and control flow op's + // attr skip_eager_deletion_vars has been updated at graph->program before + // calling this function. 
+ if (!for_partial_block) { + // prepare safe GCs on sub block ops + auto global_block_ops = CreateOpsFromBlock(program.Block(0)); + operators::PrepareSafeEagerDeletionOnConditionalOpAndConditionalGradOp( + program, 0, global_block_ops); + operators::PrepareSafeEagerDeletionOnWhileOpAndWhileGradOp( + program, 0, global_block_ops); + operators::PrepareSafeEagerDeletionOnRecurrentOpAndRecurrentGradOp( + program, 0, global_block_ops); + } // find the skip vars on each block std::vector> skip_vars_on_each_block(block_num); diff --git a/paddle/fluid/framework/executor_gc_helper.h b/paddle/fluid/framework/executor_gc_helper.h index 11bbbd9723f41ec7aee6183779dd91285162f141..902dd6f3de5f997c98b04162bd7269259fce2a0b 100644 --- a/paddle/fluid/framework/executor_gc_helper.h +++ b/paddle/fluid/framework/executor_gc_helper.h @@ -71,5 +71,11 @@ void DeleteUnusedTensors( std::vector>> GetEagerDeletionCleanVars( const ProgramDesc &program, const std::vector &skip_vars = {}); +std::vector>> +GetEagerDeletionCleanVarsForPartial( + const ProgramDesc &program, + const std::vector &skip_vars = {}, + const bool &for_partial_block = false); + } // namespace framework } // namespace paddle diff --git a/paddle/fluid/framework/ir/graph_helper.cc b/paddle/fluid/framework/ir/graph_helper.cc index a7bf131805dc143ac967f8ac1cd97fbfe3fdc9a0..f800a1eba89e1af29f272c49394952fd085bc0f2 100644 --- a/paddle/fluid/framework/ir/graph_helper.cc +++ b/paddle/fluid/framework/ir/graph_helper.cc @@ -497,8 +497,38 @@ static OpDesc *ReplaceScaleLossGradOp(const Node &node, OpDesc *desc) { return desc; } +void UpdateControlOpSkipEagerDeletionVars(const Node &node, + const Graph &graph, + const size_t graph_idx, + const std::string &control_type) { + // Note(zhangbo): SkipEagerDeletionVars pass policy for control flow class op: + // 1) if op is in main_block: SkipEagerDeletionVars information will be + // written into Graph OpNode which is wrapped by OpHandleBase; 2) if op is in + // sub_block: SkipEagerDeletionVars information will be written into graph's + // OriginProgram OpDesc. 
Please refer to + // FindAllConditionalBlockAndConditionalBlockGradOp in + // "paddle/fluid/operators/controlflow/conditional_block_op_helper.cc" + if (graph_idx != 0) { + auto origin_program = graph.OriginProgram(); + auto &block = origin_program.Block(graph_idx); + for (size_t j = 0; j < block.OpSize(); ++j) { + auto *op = block.Op(j); + if (op->Type() == control_type && + op->HasAttr("skip_eager_deletion_vars")) { + if (op->InputArgumentNames() == node.Op()->InputArgumentNames() && + op->OutputArgumentNames() == node.Op()->OutputArgumentNames()) { + node.Op()->SetAttr("skip_eager_deletion_vars", + op->GetAttr("skip_eager_deletion_vars")); + } + } + } + } +} + static void GetGraphOpDesc(const std::vector &nodes, - std::vector *ops) { + std::vector *ops, + const Graph &graph, + const size_t graph_idx) { auto is_fused_opt = [](Node *n) -> bool { auto op_type = n->Op()->Type(); auto is_opt = @@ -524,7 +554,6 @@ static void GetGraphOpDesc(const std::vector &nodes, ReplaceScaleLossGradOp(*n, &desc); } else if (n->Op()) { VLOG(4) << "convert op node to desc " << n->Op()->Type(); - VLOG(4) << n->ToString(); if (is_fused_opt(n)) { OpDesc depend_desc(n->Op()->Block()); @@ -543,7 +572,15 @@ static void GetGraphOpDesc(const std::vector &nodes, ops->emplace_back(depend_desc); VLOG(4) << "add depend op"; } + if (n->Name() == "while" || n->Name() == "while_grad" || + n->Name() == "conditional_block" || + n->Name() == "conditional_block_grad" || n->Name() == "recurrent" || + n->Name() == "recurrent_grad") { + VLOG(1) << "Update control op attr: skip_eager_deletion_vars"; + UpdateControlOpSkipEagerDeletionVars(*n, graph, graph_idx, n->Name()); + } ops->emplace_back(*n->Op()); + VLOG(4) << n->ToString(); } // delete no OpDesc op } @@ -563,7 +600,8 @@ static void GetGraphVarDesc(const Graph &graph, static void GraphToBlock(const Graph &graph, proto::BlockDesc *block, - const SortKind *sort_kind) { + const SortKind *sort_kind, + const size_t graph_idx) { // Remove the unneeded variables after memory optimization. 
std::unordered_set vars2remove; if (graph.Has(kGraphToProgramVarsToRemove)) { @@ -607,7 +645,7 @@ static void GraphToBlock(const Graph &graph, } std::vector ops; - GetGraphOpDesc(nodes, &ops); + GetGraphOpDesc(nodes, &ops, graph, graph_idx); for (auto &op : ops) { RemoveControlDepInputAndOuput(&op); @@ -633,7 +671,10 @@ void GraphToProgram(const Graph &graph, block->set_idx(kRootBlockIndex); if (FLAGS_convert_all_blocks) { - GraphToBlock(*graph.GetSubGraph(kRootBlockIndex), block, sort_kind); + GraphToBlock(*graph.GetSubGraph(kRootBlockIndex), + block, + sort_kind, + graph.GetSubGraph(kRootBlockIndex)->GetBlockId()); VLOG(3) << "Graph to program need convert " << graph.SubGraphsSize() << " sub graph"; @@ -644,10 +685,13 @@ void GraphToProgram(const Graph &graph, block = program_pb.add_blocks(); block->set_idx(idx); block->set_parent_idx(kRootBlockIndex); - GraphToBlock(*graph.GetSubGraph(idx), block, sort_kind); + GraphToBlock(*graph.GetSubGraph(idx), + block, + sort_kind, + graph.GetSubGraph(idx)->GetBlockId()); } } else { - GraphToBlock(graph, block, sort_kind); + GraphToBlock(graph, block, sort_kind, graph.GetBlockId()); } program->CopyFrom(program_pb); diff --git a/paddle/fluid/framework/ir/memory_optimize_pass/buffer_shared_inplace_op_pass.cc b/paddle/fluid/framework/ir/memory_optimize_pass/buffer_shared_inplace_op_pass.cc index 91c02a9d8628cbe753ccd5e550250419148ce645..e89c98f15f6f828c28be97e0d3e673cb26149976 100644 --- a/paddle/fluid/framework/ir/memory_optimize_pass/buffer_shared_inplace_op_pass.cc +++ b/paddle/fluid/framework/ir/memory_optimize_pass/buffer_shared_inplace_op_pass.cc @@ -167,14 +167,15 @@ static std::string GetFirstVarName(const OpDesc &op, static std::vector>> GetInplaceVars(const BlockDesc &block, bool use_cuda, - const std::vector &skip_vars) { + const std::vector &skip_vars, + const bool &for_partial_block) { PADDLE_ENFORCE_EQ( block.ID(), 0, platform::errors::Unimplemented("Inplace can only perform in block 0.")); // only take block 0 gc_vars - const auto op_gc_vars = - GetEagerDeletionCleanVars(*block.Program(), skip_vars)[0]; + const auto op_gc_vars = GetEagerDeletionCleanVarsForPartial( + *block.Program(), skip_vars, for_partial_block)[0]; const auto all_ops = block.AllOps(); PADDLE_ENFORCE_EQ(op_gc_vars.size(), all_ops.size(), @@ -267,9 +268,14 @@ void BufferSharedInplaceOpPass::ApplyImpl(ProgramDesc *main_program, ProgramDesc *startup_program) const { bool use_cuda = Get(kUseCuda); auto skip_vars = Get>("mem_opt_skip_vars"); + bool for_partial_block = false; + if (Has("for_partial_block")) { + for_partial_block = Get("for_partial_block"); + } auto *block = main_program->MutableBlock(0); - auto inplace_vars = GetInplaceVars(*block, use_cuda, skip_vars); + auto inplace_vars = + GetInplaceVars(*block, use_cuda, skip_vars, for_partial_block); PADDLE_ENFORCE_EQ(inplace_vars.size(), block->OpSize(), platform::errors::PermissionDenied( diff --git a/paddle/fluid/framework/ir/memory_optimize_pass/conditional_block_op_eager_deletion_pass.cc b/paddle/fluid/framework/ir/memory_optimize_pass/conditional_block_op_eager_deletion_pass.cc index f253098edbe07599550d2eb4fb232ceae7b5264d..7dad8a115475b85bc54bbf9d82b12e6f86f3e946 100644 --- a/paddle/fluid/framework/ir/memory_optimize_pass/conditional_block_op_eager_deletion_pass.cc +++ b/paddle/fluid/framework/ir/memory_optimize_pass/conditional_block_op_eager_deletion_pass.cc @@ -75,6 +75,22 @@ class ConditionalOpEagerDeletionPass : public Pass { operators::PrepareSafeEagerDeletionOnConditionalOpAndConditionalGradOp( 
graph->OriginProgram(), ifelse_ops, ifelse_grad_ops); } + + for (auto op_hander : all_ops) { + auto *compute_op = + dynamic_cast(op_hander); + if (compute_op == nullptr) continue; + if (compute_op->Name() == "conditional_block" || + compute_op->Name() == "conditional_block_grad") { + ir::Node *op_node = op_hander->Node(); + auto *op_base = compute_op->GetOp(); + if (op_base->Attrs().count("skip_eager_deletion_vars")) { + op_node->Op()->SetAttr( + "skip_eager_deletion_vars", + op_base->Attrs().at("skip_eager_deletion_vars")); + } + } + } } }; diff --git a/paddle/fluid/framework/ir/memory_optimize_pass/recurrent_op_eager_deletion_pass.cc b/paddle/fluid/framework/ir/memory_optimize_pass/recurrent_op_eager_deletion_pass.cc index b1fdb5e2160e00ebfdca537a62fd76358bab8b09..399ad4a3ca52317c7fbaab2542a5d8ccdd4d1330 100644 --- a/paddle/fluid/framework/ir/memory_optimize_pass/recurrent_op_eager_deletion_pass.cc +++ b/paddle/fluid/framework/ir/memory_optimize_pass/recurrent_op_eager_deletion_pass.cc @@ -43,6 +43,21 @@ void RecurrentOpEagerDeletionPass::ApplyImpl(Graph *graph) const { PrepareSafeEagerDeletionOnRecurrentOpAndRecurrentGradOp( graph->OriginProgram(), &op_pair); } + + auto all_ops = ir::FilterByNodeWrapper(*graph); + for (auto op_hander : all_ops) { + auto *compute_op = dynamic_cast(op_hander); + if (compute_op == nullptr) continue; + if (compute_op->Name() == "recurrent" || + compute_op->Name() == "recurrent_grad") { + ir::Node *op_node = op_hander->Node(); + auto *op_base = compute_op->GetOp(); + if (op_base->Attrs().count("skip_eager_deletion_vars")) { + op_node->Op()->SetAttr("skip_eager_deletion_vars", + op_base->Attrs().at("skip_eager_deletion_vars")); + } + } + } } // Returns a std::unordered_map mapping from the device id to recurrent op and diff --git a/paddle/fluid/framework/ir/memory_optimize_pass/while_op_eager_deletion_pass.cc b/paddle/fluid/framework/ir/memory_optimize_pass/while_op_eager_deletion_pass.cc index c381ef33e74fe418a1d1221c8234f947214d97bf..42f395da7c8a8582f0b1bbdc31f183d4af20c70b 100644 --- a/paddle/fluid/framework/ir/memory_optimize_pass/while_op_eager_deletion_pass.cc +++ b/paddle/fluid/framework/ir/memory_optimize_pass/while_op_eager_deletion_pass.cc @@ -87,6 +87,21 @@ class WhileOpEagerDeletionPass : public ir::Pass { operators::PrepareSafeEagerDeletionOnWhileOpAndWhileGradOp( graph->OriginProgram(), while_ops, while_grad_ops); } + + for (auto op_hander : all_ops) { + auto *compute_op = + dynamic_cast(op_hander); + if (compute_op == nullptr) continue; + if (compute_op->Name() == "while" || compute_op->Name() == "while_grad") { + ir::Node *op_node = op_hander->Node(); + auto *op_base = compute_op->GetOp(); + if (op_base->Attrs().count("skip_eager_deletion_vars")) { + op_node->Op()->SetAttr( + "skip_eager_deletion_vars", + op_base->Attrs().at("skip_eager_deletion_vars")); + } + } + } } }; diff --git a/paddle/fluid/framework/new_executor/data_transfer.cc b/paddle/fluid/framework/new_executor/data_transfer.cc index 6d0f79119459e598597baef25d6060f2cbce5687..f2b424d055e47602816c00ea59207e59a271a0c7 100644 --- a/paddle/fluid/framework/new_executor/data_transfer.cc +++ b/paddle/fluid/framework/new_executor/data_transfer.cc @@ -245,6 +245,8 @@ std::shared_ptr TransferLayout(const std::string& var_name, VLOG(3) << "Create Variable " << *new_var_name << " locally, which pointer is " << ptr << "Variable Type " << var_type; + var_scope->MutableDataTransferAddedVars().push_back( + std::make_pair(*new_var_name, var_type)); var_scope->AddVar(*new_var_name, nullptr); // 2. 
Construct VariableNameMap @@ -288,10 +290,11 @@ std::shared_ptr TransferDtype(const std::string& var_name, auto* ptr = local_scope->Var(*new_var_name); auto var_type = local_scope->FindVar(var_name)->Type(); InitializeVariable(ptr, static_cast(var_type)); - VLOG(3) << "Create Variable " << *new_var_name << " locally, which pointer is " << ptr << "Variable Type " << var_type; + var_scope->MutableDataTransferAddedVars().push_back( + std::make_pair(*new_var_name, var_type)); var_scope->AddVar(*new_var_name, nullptr); // 2. Construct VariableNameMap @@ -328,7 +331,7 @@ std::shared_ptr TransferDevice(const std::string& var_name, *new_var_name = var_name + "_device_" + src_place.DebugString() + "_" + dst_place.DebugString(); - if (local_scope->FindVar(*new_var_name) && + if (var_scope->HasVar(*new_var_name) && IsTensorOfVarInitialized(local_scope->FindVar(*new_var_name))) { // already has same var VLOG(4) << "Use cached variable: " << *new_var_name; @@ -341,6 +344,8 @@ std::shared_ptr TransferDevice(const std::string& var_name, VLOG(3) << "Create Variable " << *new_var_name << " locally, which pointer is " << ptr << "Variable Type " << var_type; + var_scope->MutableDataTransferAddedVars().push_back( + std::make_pair(*new_var_name, var_type)); var_scope->AddVar(*new_var_name, nullptr); // 2. Construct VariableNameMap diff --git a/paddle/fluid/framework/new_executor/interpretercore.cc b/paddle/fluid/framework/new_executor/interpretercore.cc index f57a99e84cce653807a9836c384c20582d5d7717..caf10c62b5bdcbe944b6aa494a277aee606008d5 100644 --- a/paddle/fluid/framework/new_executor/interpretercore.cc +++ b/paddle/fluid/framework/new_executor/interpretercore.cc @@ -53,12 +53,14 @@ static constexpr size_t kDeviceNumThreads = 1; InterpreterCore::InterpreterCore(const platform::Place& place, const BlockDesc& block, const std::set& skip_gc_vars, - framework::Scope* scope) + framework::Scope* scope, + bool used_for_jit) : place_(place), block_(block), skip_gc_vars_(skip_gc_vars), var_scope_(scope), - stream_analyzer_(place) { + stream_analyzer_(place), + used_for_jit_(used_for_jit) { VLOG(4) << "InterpreterCore(): " << this << " on " << place_; is_build_ = false; @@ -67,6 +69,10 @@ InterpreterCore::InterpreterCore(const platform::Place& place, completion_notifier_ = main_thread_blocker_.RegisterEvent(kTaskCompletion); create_local_scope_ = FLAGS_new_executor_use_local_scope; + + if (used_for_jit_) { + create_local_scope_ = false; + } VLOG(4) << "create_local_scope_ is " << create_local_scope_; if (create_local_scope_) { @@ -85,7 +91,6 @@ InterpreterCore::InterpreterCore(const platform::Place& place, InterpreterCore::~InterpreterCore() { // cancle gc's thread gc_.reset(nullptr); - async_work_queue_.reset(); VLOG(4) << "~InterpreterCore(): " << this << " on " << place_; @@ -184,7 +189,8 @@ paddle::framework::FetchList InterpreterCore::Run( platform::AttachPointerHashToMKLDNNKey(this, place_); #endif if (!is_build_) { - paddle::framework::interpreter::build_variable_scope(block_, &var_scope_); + paddle::framework::interpreter::build_variable_scope( + block_, &var_scope_, create_local_scope_); std::vector op_func_nodes; paddle::framework::interpreter::build_op_func_list(place_, @@ -192,12 +198,12 @@ paddle::framework::FetchList InterpreterCore::Run( skip_gc_vars_, &op_func_nodes, &var_scope_, - create_local_scope_); + create_local_scope_, + used_for_jit_); is_build_ = true; SetFeedVarsInplaceSkip(feed_names); // convert vec func_list to graph Convert(&op_func_nodes); - } else { // For the program that only run 
once, it is no need to // create work_queue, so the async_work_queue_ is created @@ -219,7 +225,9 @@ paddle::framework::FetchList InterpreterCore::Run( ClearLoDTensorArrayInLocalScope(); } // return Fetch Tensors - auto* fetch_var = local_scope_->FindVar(interpreter::kFetchVarName); + Scope* inner_scope = + create_local_scope_ ? local_scope_ : var_scope_.GetMutableScope(); + auto* fetch_var = inner_scope->FindVar(interpreter::kFetchVarName); if (fetch_var) { return std::move(*fetch_var->GetMutable()); } else { @@ -231,6 +239,31 @@ void InterpreterCore::SetCopyProgram(std::shared_ptr prog) { copy_program_ = prog; } +void InterpreterCore::SetSkipGcVars(const std::set& skip_gc_vars) { + PADDLE_ENFORCE_EQ( + skip_gc_vars_.empty(), + true, + platform::errors::PreconditionNotMet( + "Skip_gc_vars_ can only be initialized once, now skip_gc_vars_ is " + "not empty, do not call SetSkipGcVars method repeatedly.")); + skip_gc_vars_ = skip_gc_vars; +} + +const VariableScope* InterpreterCore::GetVariableScope() const { + return &var_scope_; +} + +void InterpreterCore::reset_scope(Scope* new_scope) { + var_scope_.SetScope(new_scope); + auto& var_list = var_scope_.MutableVarList(); + for (size_t i = 0; i < var_list.size(); i++) { + var_list[i] = new_scope->FindVar(var_scope_.GetNameById(i)); + } + for (size_t i = 0; i < vec_instruction_.size(); ++i) { + BuildAndCacheInstructionCtx(&vec_instruction_[i]); + } +} + void InterpreterCore::ShareWorkQueueFrom(std::shared_ptr src) { async_work_queue_ = src->GetWorkQueue(); VLOG(8) << "Share AsyncWorkQueue from InterpreterCore(" << &src @@ -262,14 +295,15 @@ std::shared_ptr InterpreterCore::GetWorkQueue() { } void InterpreterCore::BuildAndCacheInstructionCtx(Instruction* instr_node) { + Scope* inner_scope = + create_local_scope_ ? local_scope_ : var_scope_.GetMutableScope(); VariableValueMap ins_map; for (auto& var_name_item : instr_node->Inputs()) { std::vector input_vars; input_vars.reserve(var_name_item.second.size()); for (auto& id : var_name_item.second) { - input_vars.emplace_back( - local_scope_->FindVar(var_scope_.GetNameById(id))); + input_vars.emplace_back(inner_scope->FindVar(var_scope_.GetNameById(id))); } ins_map.emplace(var_name_item.first, std::move(input_vars)); } @@ -280,7 +314,7 @@ void InterpreterCore::BuildAndCacheInstructionCtx(Instruction* instr_node) { out_vars.reserve(var_name_item.second.size()); for (auto& id : var_name_item.second) { - out_vars.emplace_back(local_scope_->FindVar(var_scope_.GetNameById(id))); + out_vars.emplace_back(inner_scope->FindVar(var_scope_.GetNameById(id))); } outs_map.emplace(var_name_item.first, std::move(out_vars)); } @@ -319,6 +353,9 @@ void InterpreterCore::BuildInplace() { } } + Scope* local_scope = create_local_scope_ ? 
var_scope_.GetMutableLocalScope() + : var_scope_.GetMutableScope(); + for (size_t i = 0; i < vec_instruction_.size(); ++i) { auto& instr = vec_instruction_[i]; auto* op_base = instr.OpBase(); @@ -348,8 +385,8 @@ void InterpreterCore::BuildInplace() { var_scope_.GetNameById(iter->second[0]); const std::string& outvar_name = var_scope_.GetNameById(iterout->second[0]); - auto invar = local_scope_->FindVar(invar_name); - auto outvar = local_scope_->FindVar(outvar_name); + auto invar = local_scope->FindVar(invar_name); + auto outvar = local_scope->FindVar(outvar_name); if (invar && outvar && invar->IsType() && outvar->IsType() && @@ -410,15 +447,12 @@ void InterpreterCore::Convert( auto* dev_ctx_ = stream_analyzer_.ParseDeviceContext(op_func_node); vec_instruction_.emplace_back(op_idx, std::move(op_func_node), *dev_ctx_); } - BuildOperatorDependences(); - // calculate last_live_ops_ for (size_t op_idx = 0; op_idx < op_nums; ++op_idx) { auto& instr = vec_instruction_[op_idx]; OpInOutInfo info; std::set gc_check_inputs; - for (auto& item : instr.Inputs()) { for (auto id : item.second) { if (id == kEmptyVarIndex) { @@ -439,10 +473,11 @@ void InterpreterCore::Convert( } } } - for (auto var_id : gc_check_inputs) { + Scope* inner_scope = + create_local_scope_ ? local_scope_ : var_scope_.GetMutableScope(); paddle::framework::Variable* var = - local_scope_->FindVar(var_scope_.GetNameById(var_id)); + inner_scope->FindVar(var_scope_.GetNameById(var_id)); if (var->IsType() || var->IsType() || var->IsType()) { last_live_ops_[var_id].insert(op_idx); @@ -453,7 +488,6 @@ void InterpreterCore::Convert( } } } - for (size_t i = 0; i < vec_instruction_.size(); ++i) { // checkout output for (auto& item : vec_instruction_[i].Outputs()) { @@ -464,7 +498,6 @@ void InterpreterCore::Convert( } } } - // clear the last_live_ops list for all vars in skip_gc_vars for (const std::string& skip_gc_var : skip_gc_vars_) { int var_id = var_scope_.GetIdByName(skip_gc_var); @@ -561,7 +594,6 @@ void InterpreterCore::BuildSkipShareLoDInfo() { void InterpreterCore::RunInstruction(const Instruction& instr_node) { auto* op = instr_node.OpBase(); auto place = instr_node.DeviceContext().GetPlace(); - VLOG(4) << "Start run " << place << " " << op->DebugStringEx(local_scope_); Scope* local_scope = create_local_scope_ ? var_scope_.GetMutableLocalScope() : var_scope_.GetMutableScope(); @@ -602,7 +634,6 @@ void InterpreterCore::RunInstruction(const Instruction& instr_node) { *(instr_node.InnerRuntimeContext())); } } - if (op_with_kernel != nullptr && FLAGS_new_executor_use_inplace) { // TODO(xiongkun03) Does operator base support inplace ? 
for (auto& pair : instr_node.InplaceInfo()) { @@ -1009,7 +1040,6 @@ void InterpreterCore::Prepare( "but received %d != %d", feed_names.size(), feed_tensors.size())); - auto FeedInput = [&] { VLOG(4) << "Feed inputs"; for (size_t i = 0; i < feed_names.size(); ++i) { @@ -1035,7 +1065,8 @@ void InterpreterCore::Prepare( skip_gc_vars_, &op_func_nodes, &var_scope_, - create_local_scope_); + create_local_scope_, + used_for_jit_); is_build_ = true; SetFeedVarsInplaceSkip(feed_names); // convert vec func_list to graph diff --git a/paddle/fluid/framework/new_executor/interpretercore.h b/paddle/fluid/framework/new_executor/interpretercore.h index a7efa1349e8f12e57de481fd3e3d33a1f4bc93aa..2d60b0231a5c08e6ae58f4fdafa779fb01421835 100644 --- a/paddle/fluid/framework/new_executor/interpretercore.h +++ b/paddle/fluid/framework/new_executor/interpretercore.h @@ -41,7 +41,8 @@ class InterpreterCore { InterpreterCore(const platform::Place& place, const BlockDesc& block, const std::set& skip_gc_vars, - Scope* scope); + Scope* scope, + bool used_for_jit = false); ~InterpreterCore(); @@ -59,6 +60,12 @@ class InterpreterCore { void SetCopyProgram(std::shared_ptr prog); + void SetSkipGcVars(const std::set& skip_gc_vars); + + const VariableScope* GetVariableScope() const; + + void reset_scope(Scope* new_scope); + private: bool BuildInplaceCheckVarIsOnlyInput(size_t var_index); @@ -103,9 +110,9 @@ class InterpreterCore { bool is_build_; - const platform::Place& place_; + platform::Place place_; const BlockDesc& block_; // not owned - const std::set skip_gc_vars_; + std::set skip_gc_vars_; interpreter::DependencyBuilder dependency_builder_; @@ -144,6 +151,8 @@ class InterpreterCore { std::future> atomic_deps_; std::future> atomic_var_ref_; + + bool used_for_jit_{false}; }; std::shared_ptr CreateInterpreterCore( diff --git a/paddle/fluid/framework/new_executor/interpretercore_util.cc b/paddle/fluid/framework/new_executor/interpretercore_util.cc index 38ef5fd70ce873f79fe4e495b4192061c24dba3e..27606ca2b0c2daad8d38405de6e7c2493cfd2ce5 100644 --- a/paddle/fluid/framework/new_executor/interpretercore_util.cc +++ b/paddle/fluid/framework/new_executor/interpretercore_util.cc @@ -387,10 +387,8 @@ void deal_operator_base(const platform::Place& place, PADDLE_THROW( platform::errors::Fatal("Unsupported current place %s", place)); } - op_func_node->kernel_func_ = nullptr; op_base->Run(*local_scope, place); // Run without data transformer. - std::unordered_set no_data_transform_index; for (auto& it : op_func_node->input_index) { for (auto& id : it.second) { @@ -407,7 +405,8 @@ void build_op_func_list(const platform::Place& place, const std::set& skip_gc_vars, std::vector* vec_func_list, VariableScope* var_scope, - bool use_local_scope) { + bool use_local_scope, + bool used_for_jit) { Scope* local_scope = use_local_scope ? var_scope->GetMutableLocalScope() : var_scope->GetMutableScope(); std::vector> @@ -415,19 +414,21 @@ void build_op_func_list(const platform::Place& place, bool flag_log_is_printed = false; // Step 1: create all ops for current block. 
create_all_ops(block, &ops_unique); - // If gc is enabled and block size > 1 - const ProgramDesc& main_program = *block.Program(); - operators::PrepareSafeEagerDeletionOnConditionalOpAndConditionalGradOp( - main_program, block.ID(), ops_unique); - operators::PrepareSafeEagerDeletionOnWhileOpAndWhileGradOp( - main_program, block.ID(), ops_unique); - operators::PrepareSafeEagerDeletionOnRecurrentOpAndRecurrentGradOp( - main_program, block.ID(), ops_unique); + + if (!used_for_jit) { + // If gc is enabled and block size > 1 + const ProgramDesc& main_program = *block.Program(); + operators::PrepareSafeEagerDeletionOnConditionalOpAndConditionalGradOp( + main_program, block.ID(), ops_unique); + operators::PrepareSafeEagerDeletionOnWhileOpAndWhileGradOp( + main_program, block.ID(), ops_unique); + operators::PrepareSafeEagerDeletionOnRecurrentOpAndRecurrentGradOp( + main_program, block.ID(), ops_unique); + } #ifdef PADDLE_WITH_MKLDNN platform::RegisterModelLayout(ops_unique, place); #endif - // its elements will be moved to vec_func_list std::vector> ops; for (auto& op_unique : ops_unique) { @@ -484,157 +485,187 @@ void build_op_func_list(const platform::Place& place, } #endif - if (dynamic_cast(op) == nullptr) { - // op is not a operatorwithkernel, so direcly run OperatorBase::Run() - deal_operator_base(place, var_scope, ops[i], &op_func_node, local_scope); - } else { - auto op_with_kernel = const_cast( - static_cast(op)); - // construct RuntimeContext and analysis KernelType - RuntimeContext runtime_context({}, {}); - runtime_context.inputs.swap(ins_map); - runtime_context.outputs.swap(outs_map); - - Scope scope, *runtime_scope = &scope; - // NOTE(Ruibiao): We do not encourage directly using scope in OP kernel. - // But some OPs do have such behavior (e.g., cinn_launch OP). Here special - // treatment for them. - if (op_with_kernel->Type() == "cinn_launch") { - VLOG(6) << "OP(" << op_with_kernel->Type() - << ") use scope in kernel, " - "so pass a real scope to " - "ExecutionContext"; - runtime_scope = local_scope; - } + try { + if (dynamic_cast(op) == nullptr) { + // op is not a operatorwithkernel, so direcly run OperatorBase::Run() + deal_operator_base( + place, var_scope, ops[i], &op_func_node, local_scope); + VLOG(4) << "deal_operator_base"; + } else { + VLOG(4) << "OP is not null"; + auto op_with_kernel = const_cast( + static_cast(op)); + VLOG(4) << "get op_with_kernel"; + // construct RuntimeContext and analysis KernelType + RuntimeContext runtime_context({}, {}); + runtime_context.inputs.swap(ins_map); + runtime_context.outputs.swap(outs_map); + VLOG(4) << "get RuntimeContext"; + + Scope scope, *runtime_scope = &scope; + // NOTE(Ruibiao): We do not encourage directly using scope in OP kernel. + // But some OPs do have such behavior (e.g., cinn_launch OP). Here + // special treatment for them. + if (op_with_kernel->Type() == "cinn_launch") { + VLOG(6) << "OP(" << op_with_kernel->Type() + << ") use scope in kernel, " + "so pass a real scope to " + "ExecutionContext"; + runtime_scope = local_scope; + } - auto& pool = platform::DeviceContextPool::Instance(); - auto* dev_ctx = pool.Get(place); - auto exec_ctx = ExecutionContext( - *op_with_kernel, *runtime_scope, *dev_ctx, runtime_context); - auto expected_kernel_key = - op_with_kernel->GetExpectedKernelType(exec_ctx); - // change device by the device_guard() - apply_device_guard(op, place, &expected_kernel_key); - VLOG(4) << "expected_kernel_key : " << expected_kernel_key; - - // step 2. 
select op kernel - auto run_phi_kernel = false; - if (phi::KernelFactory::Instance().HasCompatiblePhiKernel( - op_with_kernel->Type())) { - auto phi_kernel_key = op_with_kernel->ChoosePhiKernel(exec_ctx); - auto phi_kernel_name = op_with_kernel->PhiKernelSignature()->name; - - if (op_with_kernel->PhiKernel()->IsValid()) { - run_phi_kernel = true; - } else { - if (!op_with_kernel->SupportsKernelType(expected_kernel_key)) { - auto phi_cpu_kernel_key = FallBackToCpu( - expected_kernel_key, phi_kernel_key, *op_with_kernel); - op_with_kernel->ResetPhiKernel( - new phi::Kernel(phi::KernelFactory::Instance().SelectKernel( - phi_kernel_name, phi_cpu_kernel_key))); - if (op_with_kernel->PhiKernel()->IsValid()) { - VLOG(6) << "Static mode PrepareImpl - kernel name: " - << phi_kernel_name - << " | kernel key: " << phi_cpu_kernel_key - << " | kernel: " << *(op_with_kernel->PhiKernel()); - op_with_kernel->ResetKernelType(new OpKernelType( - TransPhiKernelKeyToOpKernelType(phi_cpu_kernel_key))); - run_phi_kernel = true; + auto& pool = platform::DeviceContextPool::Instance(); + auto* dev_ctx = pool.Get(place); + VLOG(4) << "get dev_ctx"; + auto exec_ctx = ExecutionContext( + *op_with_kernel, *runtime_scope, *dev_ctx, runtime_context); + VLOG(4) << "get exec_ctx"; + auto expected_kernel_key = + op_with_kernel->GetExpectedKernelType(exec_ctx); + VLOG(4) << "get expected_kernel_key"; + // change device by the device_guard() + apply_device_guard(op, place, &expected_kernel_key); + VLOG(4) << "expected_kernel_key : " << expected_kernel_key; + + // step 2. select op kernel + auto run_phi_kernel = false; + if (phi::KernelFactory::Instance().HasCompatiblePhiKernel( + op_with_kernel->Type())) { + auto phi_kernel_key = op_with_kernel->ChoosePhiKernel(exec_ctx); + auto phi_kernel_name = op_with_kernel->PhiKernelSignature()->name; + + if (op_with_kernel->PhiKernel()->IsValid()) { + run_phi_kernel = true; + } else { + if (!op_with_kernel->SupportsKernelType(expected_kernel_key)) { + auto phi_cpu_kernel_key = FallBackToCpu( + expected_kernel_key, phi_kernel_key, *op_with_kernel); + op_with_kernel->ResetPhiKernel( + new phi::Kernel(phi::KernelFactory::Instance().SelectKernel( + phi_kernel_name, phi_cpu_kernel_key))); + if (op_with_kernel->PhiKernel()->IsValid()) { + VLOG(6) << "Static mode PrepareImpl - kernel name: " + << phi_kernel_name + << " | kernel key: " << phi_cpu_kernel_key + << " | kernel: " << *(op_with_kernel->PhiKernel()); + op_with_kernel->ResetKernelType(new OpKernelType( + TransPhiKernelKeyToOpKernelType(phi_cpu_kernel_key))); + run_phi_kernel = true; + } } } } - } - if (!run_phi_kernel) { - op_with_kernel->ChooseKernel(exec_ctx); - op_func_node.kernel_func_ = *op_with_kernel->kernel_func(); - } else { - op_func_node.phi_kernel_ = op_with_kernel->PhiKernel(); - } - auto kernel_type = *(op_with_kernel->kernel_type()); - if (kernel_type.place_ != dev_ctx->GetPlace()) { - dev_ctx = pool.Get(kernel_type.place_); - } - op_func_node.dev_ctx_ = dev_ctx; - if (IsSupportedHetePlace(kernel_type.place_)) { - op_func_node.type_ = OpFuncType::kQueueAsync; - } else if (platform::is_cpu_place(kernel_type.place_)) { - op_func_node.type_ = OpFuncType::kQueueSync; - } else { - PADDLE_THROW(platform::errors::Fatal("Unsupported current place %s", - kernel_type.place_)); - } - VLOG(3) << op_with_kernel->Type() - << " : finally selected kernel_key: " << kernel_type; - - // step 3. 
data transform - VariableValueMap& ins_map_temp = runtime_context.inputs; - VariableValueMap& outs_map_temp = runtime_context.outputs; - ApplyDataTransform(kernel_type, - place, - &ins_map_temp, - &outs_map_temp, - var_scope, - &op_func_node, - vec_func_list, - use_local_scope); - - // step 4. infershape, see OperatorWithKernel::RunImpl in operator.cc for - // why. - if (!(op->HasAttr(kAllKernelsMustComputeRuntimeShape) && - op->Attr(kAllKernelsMustComputeRuntimeShape))) { - InterpretercoreInferShapeContext infer_shape_ctx(*op, runtime_context); - // TODO(Aurelius84): In case of control flow ops, they are NOT - // inheritted from OperatorWithKernel. - op_with_kernel->Info().infer_shape_(&infer_shape_ctx); - } + VLOG(4) << "if run phi kernel? : " << run_phi_kernel; + if (!run_phi_kernel) { + op_with_kernel->ChooseKernel(exec_ctx); + op_func_node.kernel_func_ = *op_with_kernel->kernel_func(); + } else { + op_func_node.phi_kernel_ = op_with_kernel->PhiKernel(); + } + auto kernel_type = *(op_with_kernel->kernel_type()); + if (kernel_type.place_ != dev_ctx->GetPlace()) { + dev_ctx = pool.Get(kernel_type.place_); + } + op_func_node.dev_ctx_ = dev_ctx; + if (IsSupportedHetePlace(kernel_type.place_)) { + op_func_node.type_ = OpFuncType::kQueueAsync; + } else if (platform::is_cpu_place(kernel_type.place_)) { + op_func_node.type_ = OpFuncType::kQueueSync; + } else { + PADDLE_THROW(platform::errors::Fatal("Unsupported current place %s", + kernel_type.place_)); + } + VLOG(3) << op_with_kernel->Type() + << " : finally selected kernel_key: " << kernel_type; + + // step 3. data transform + VariableValueMap& ins_map_temp = runtime_context.inputs; + VariableValueMap& outs_map_temp = runtime_context.outputs; + ApplyDataTransform(kernel_type, + place, + &ins_map_temp, + &outs_map_temp, + var_scope, + &op_func_node, + vec_func_list, + use_local_scope); + VLOG(4) << "apply data transform done. "; + // step 4. infershape, see OperatorWithKernel::RunImpl in operator.cc + // for why. + if (!(op->HasAttr(kAllKernelsMustComputeRuntimeShape) && + op->Attr(kAllKernelsMustComputeRuntimeShape))) { + InterpretercoreInferShapeContext infer_shape_ctx(*op, + runtime_context); + // TODO(Aurelius84): In case of control flow ops, they are NOT + // inheritted from OperatorWithKernel. + op_with_kernel->Info().infer_shape_(&infer_shape_ctx); + } - // step 5. run kernel - if (run_phi_kernel) { - phi::KernelContext phi_kernel_context; - op_with_kernel->BuildPhiKernelContext( - runtime_context, dev_ctx, &phi_kernel_context); - (*op_func_node.phi_kernel_)(&phi_kernel_context); - } else { - // the place of exec_ctx maybe has changed. - op_func_node.kernel_func_(ExecutionContext( - *op_with_kernel, *runtime_scope, *dev_ctx, runtime_context)); - } + // step 5. run kernel + if (run_phi_kernel) { + VLOG(1) << "start run phi kernel. "; + phi::KernelContext phi_kernel_context; + op_with_kernel->BuildPhiKernelContext( + runtime_context, dev_ctx, &phi_kernel_context); + (*op_func_node.phi_kernel_)(&phi_kernel_context); + VLOG(1) << "end run phi kernel. "; + } else { + VLOG(4) << "start run kernel. "; + // the place of exec_ctx maybe has changed. + op_func_node.kernel_func_(ExecutionContext( + *op_with_kernel, *runtime_scope, *dev_ctx, runtime_context)); + VLOG(4) << "end run kernel. "; + } - // post-process grad_op.outputs if need cast complex grad into real - // grad. - // NOTE(Aurelius84): insert a transfer_dtype_op inplacely to cast it. 
- if (framework::IsComplexType(kernel_type.data_type_)) { - interpreter::HandleComplexGradToRealGrad(op_func_node, - place, - outputs_names, - &runtime_context.outputs, - var_scope, - vec_func_list, - local_scope); - } - if (!op_func_node.inplace_back_map.empty()) { - auto& m = op_func_node.inplace_back_map; - // NOTE(zhiqiu): same logic as TransferInplaceVarsBack() in - // operator.cc - for (auto& p : m) { - auto* transformed_tensor = - GetMutableLoDTensorOrSelectedRowsValueFromVar( - local_scope->FindVar(var_scope->GetNameById(p.first))); - auto* original_tensor = GetMutableLoDTensorOrSelectedRowsValueFromVar( - local_scope->FindVar(var_scope->GetNameById(p.second))); - original_tensor->ShareDataWith(*transformed_tensor); - VLOG(4) << "Transfer inplace variable back form " - << var_scope->GetNameById(p.first) << " to " - << var_scope->GetNameById(p.second); + // post-process grad_op.outputs if need cast complex grad into real + // grad. + // NOTE(Aurelius84): insert a transfer_dtype_op inplacely to cast it. + if (framework::IsComplexType(kernel_type.data_type_)) { + interpreter::HandleComplexGradToRealGrad(op_func_node, + place, + outputs_names, + &runtime_context.outputs, + var_scope, + vec_func_list, + local_scope); + } + if (!op_func_node.inplace_back_map.empty()) { + auto& m = op_func_node.inplace_back_map; + // NOTE(zhiqiu): same logic as TransferInplaceVarsBack() in + // operator.cc + for (auto& p : m) { + auto* transformed_tensor = + GetMutableLoDTensorOrSelectedRowsValueFromVar( + local_scope->FindVar(var_scope->GetNameById(p.first))); + auto* original_tensor = + GetMutableLoDTensorOrSelectedRowsValueFromVar( + local_scope->FindVar(var_scope->GetNameById(p.second))); + original_tensor->ShareDataWith(*transformed_tensor); + VLOG(4) << "Transfer inplace variable back form " + << var_scope->GetNameById(p.first) << " to " + << var_scope->GetNameById(p.second); + } } - } - // for debug nan/inf - if (FLAGS_check_nan_inf) { - VLOG(4) << "Check nan/inf"; - framework::details::CheckOpHasNanOrInf(*op, *runtime_scope, place); + // for debug nan/inf + if (FLAGS_check_nan_inf) { + VLOG(4) << "Check nan/inf"; + framework::details::CheckOpHasNanOrInf(*op, *runtime_scope, place); + } } + } catch (platform::EnforceNotMet& ex) { + framework::InsertCallStackInfo(op->Type(), op->Attrs(), &ex); + throw std::move(ex); + } catch (platform::EOFException&) { + std::rethrow_exception(std::current_exception()); + } catch (std::exception& ex) { + LOG(WARNING) << op->Type() << " raises an exception " + << platform::demangle(typeid(ex).name()) << ", " + << ex.what(); + std::rethrow_exception(std::current_exception()); + } catch (...) 
{ + LOG(WARNING) << op->Type() << " raises an unknown exception"; + std::rethrow_exception(std::current_exception()); } VLOG(4) << "End run " << place << " " @@ -662,20 +693,6 @@ void build_op_func_list(const platform::Place& place, if (var->IsType()) { garbages->emplace_back( var->GetMutable()->MoveMemoryHolder()); - } else if (var->IsType()) { - garbages->emplace_back(var->GetMutable() - ->mutable_value() - ->MoveMemoryHolder()); - } else if (var->IsType()) { - auto* lod_tensor_arr = var->GetMutable(); - for (auto& t : *lod_tensor_arr) { - garbages->emplace_back(t.MoveMemoryHolder()); - } - } else { - PADDLE_THROW(platform::errors::Unimplemented( - "Type %s of variable %s is not supported eager deletion.", - framework::ToTypeName(var->Type()), - var_name)); } } delete garbages; // free mem diff --git a/paddle/fluid/framework/new_executor/interpretercore_util.h b/paddle/fluid/framework/new_executor/interpretercore_util.h index dfbc493d9dc2165781c5e786c250d9e567553ace..1860b19b1ca4208e7a28c16cf1c0ad98ef0230ed 100644 --- a/paddle/fluid/framework/new_executor/interpretercore_util.h +++ b/paddle/fluid/framework/new_executor/interpretercore_util.h @@ -80,7 +80,8 @@ void build_op_func_list(const platform::Place& place, const std::set& skip_gc_vars, std::vector* vec_func_list, VariableScope* scope, - bool use_local_scope = true); + bool use_local_scope = true, + bool used_for_jit = false); void add_fetch(const std::vector& fetch_names, framework::BlockDesc* block); diff --git a/paddle/fluid/framework/new_executor/new_executor_defs.cc b/paddle/fluid/framework/new_executor/new_executor_defs.cc index f2e39def06968a3a8c904bd1a9644fcd36e87047..6492538c6084d36968c02f1ddf97f57e187a8773 100644 --- a/paddle/fluid/framework/new_executor/new_executor_defs.cc +++ b/paddle/fluid/framework/new_executor/new_executor_defs.cc @@ -577,6 +577,8 @@ Scope* VariableScope::GetMutableScope() const { return scope_; } Scope* VariableScope::GetMutableLocalScope() const { return local_scope_; } +void VariableScope::SetScope(Scope* scope) { scope_ = scope; } + void VariableScope::SetLocalScope(Scope* local_scope) { VLOG(4) << "Set local scope: " << local_scope; local_scope_ = local_scope; @@ -626,7 +628,11 @@ void VariableScope::AddVar(const std::string& name, auto id = VarSize(); name2id_[name] = id; vec_meta_info_.emplace_back(0, var_desc); - var_list_.push_back(local_scope_->FindVar(name)); + if (local_scope_ != nullptr) { + var_list_.push_back(local_scope_->FindVar(name)); + } else { + var_list_.push_back(scope_->FindVar(name)); + } PADDLE_ENFORCE_EQ( var_list_.size(), name2id_.size(), @@ -783,6 +789,8 @@ void Instruction::AddInplace(Variable* in, Variable* out) { vec_inplace_in_to_out_.emplace_back(in, out); } +void Instruction::ClearInplace() { vec_inplace_in_to_out_.clear(); } + const std::vector& Instruction::InputEvents() const { return intput_events_; } diff --git a/paddle/fluid/framework/new_executor/new_executor_defs.h b/paddle/fluid/framework/new_executor/new_executor_defs.h index 711bef3743d24ab83487833f722d1de551a7050f..82eb237e73d18f259c600b1753a6120cc5d5741a 100644 --- a/paddle/fluid/framework/new_executor/new_executor_defs.h +++ b/paddle/fluid/framework/new_executor/new_executor_defs.h @@ -176,6 +176,8 @@ class VariableScope { Scope* GetMutableLocalScope() const; + void SetScope(Scope* scope); + void SetLocalScope(Scope* local_scope); ~VariableScope(); @@ -212,6 +214,17 @@ class VariableScope { return vec_meta_info_; } + const std::vector>& DataTransferAddedVars() + const { + return 
data_transfer_added_vars_; + } + + std::vector>& MutableDataTransferAddedVars() { + return data_transfer_added_vars_; + } + + std::vector& MutableVarList() { return var_list_; } + void SetVarSikpInplace(const std::string& name, bool skip); bool GetVarSikpInplace(int id) const; @@ -228,6 +241,9 @@ class VariableScope { // TODO(zhiqiu): find a better way to support local scope. Scope* local_scope_{nullptr}; // mutable RWLock vars_lock_; + + // var_name -> var_type + std::vector> data_transfer_added_vars_; }; class NextInstruction { @@ -340,6 +356,8 @@ class Instruction { void AddInplace(Variable* in, Variable* out); + void ClearInplace(); + const std::vector& InputEvents() const; const std::vector& OutputEvents() const; diff --git a/paddle/fluid/framework/scope.cc b/paddle/fluid/framework/scope.cc index de1f1812235869526d3ef456d1d36cf07abb715a..87312cbfde2b9539ee731b13d5684fccdb1d1949 100644 --- a/paddle/fluid/framework/scope.cc +++ b/paddle/fluid/framework/scope.cc @@ -115,7 +115,10 @@ const Scope* Scope::FindScope(const std::string& name) const { void Scope::DropKids() { { SCOPE_KIDS_WRITER_LOCK - for (Scope* s : kids_) delete s; + for (Scope* s : kids_) { + delete s; + s = nullptr; + } kids_.clear(); } } diff --git a/paddle/fluid/operators/run_program_op.cc b/paddle/fluid/operators/run_program_op.cc index 94da2b2b35ba0d925a4a6afb9303a9a86ef6f72d..0d384eef8a02c7ff43a7ba277adfba52d566bd50 100644 --- a/paddle/fluid/operators/run_program_op.cc +++ b/paddle/fluid/operators/run_program_op.cc @@ -119,6 +119,17 @@ class RunProgramOpMaker : public framework::OpProtoAndCheckerMaker { AddAttr("cuda_graph_pool_id", "(int64_t, default 0) The CUDA Graph memory pool ID.") .SetDefault(0); + AddAttr("use_interpretorcore", + "(bool, default false) Set to true for use interpretercore.") + .SetDefault(false); + AddAttr("forward_global_block", + "(BlockDesc *)" + "The global block of executed forward program desc.") + .SetDefault(nullptr); + AddAttr("backward_global_block", + "(BlockDesc *)" + "The global block of executed backward program desc.") + .SetDefault(nullptr); AddComment(R"DOC( RunProgram operator. diff --git a/paddle/fluid/pybind/ir.cc b/paddle/fluid/pybind/ir.cc index 73f7e9a098c14779124e7c329109daf50a9a80a6..845222b6ea0714fe602a21de62cc33c81ce2a163 100644 --- a/paddle/fluid/pybind/ir.cc +++ b/paddle/fluid/pybind/ir.cc @@ -67,6 +67,7 @@ void BindGraph(py::module *m) { "The graph is a Directed Acyclic Single Static Assignment Graph, see " "`paddle::ir::Graph` for details.") .def(py::init()) + .def(py::init()) .def("clone", &Graph::Clone) .def("has", &Graph::Has) .def("get_bool", &Graph::Get) diff --git a/paddle/fluid/pybind/pybind.cc b/paddle/fluid/pybind/pybind.cc index 425bdfea72a9745a186ba49e1560a39fc3dc438c..ae4ee11bdc6b76e907854885b5baa50a1c01b5ef 100644 --- a/paddle/fluid/pybind/pybind.cc +++ b/paddle/fluid/pybind/pybind.cc @@ -2147,8 +2147,14 @@ All parameter, weight, gradient are variables in Paddle. 
m.def("set_cudnn_switch", platform::SetAllowTF32Cudnn); m.def("get_cudnn_switch", platform::AllowTF32Cudnn); #endif // PADDLE_WITH_CUDA - m.def("clear_executor_cache", - []() { framework::ExecutorInfoCache::Instance().Finalize(); }); + m.def("clear_executor_cache", []() { + pybind11::gil_scoped_release release; + framework::ExecutorInfoCache::Instance().Finalize(); + framework::InterpreterCoreInfoCache::Instance().Finalize(); + }); + + m.def("parse_safe_eager_deletion_skip_vars", + paddle::framework::details::ParseSafeEagerDeletionSkipVarsSet); #ifdef PADDLE_WITH_IPU py::class_ 0: + return add_build_strategy_for(whole_program, 0, end_op_index) + else: + return whole_program + + @LazyInitialized + def _forward_program_desc(self): + return self._create_forward_train_program().desc + + # backward + @switch_to_static_graph + def _create_backward_train_program(self): + whole_program = _build_program_by_desc(self._train_program_desc) + start_op_index = self._infer_program_desc.block(0).op_size() + 2 * len( + self._output_descs) + end_op_index = whole_program.desc.block(0).op_size() + if (start_op_index < end_op_index): + return add_build_strategy_for(whole_program, start_op_index, + end_op_index) + else: + return paddle.static.Program() + + @LazyInitialized + def _backward_program_desc(self): + return self._create_backward_train_program().desc + @property def infer_program(self): return self._infer_program_desc @@ -341,6 +374,14 @@ class _ProgramHolder(object): def train_program(self): return self._train_program_desc + @property + def forward_program(self): + return self._forward_program_desc + + @property + def backward_program(self): + return self._backward_program_desc + @property def input_descs(self): return self._input_descs @@ -460,7 +501,7 @@ class _ProgramHolder(object): self._output_descs[i] = var.desc @switch_to_static_graph - def _append_backward_desc(self, infer_program_desc): + def _get_train_forward_program(self, infer_program_desc): program_desc_copy = core.ProgramDesc(infer_program_desc) # 1. set all `is_test` attributes to False @@ -488,6 +529,11 @@ class _ProgramHolder(object): persistable=False, stop_gradient=True) op.desc.set_output("ReserveSpace", [reserve_space.name]) + return program + + @switch_to_static_graph + def _append_backward_desc(self, infer_program_desc): + program = self._get_train_forward_program(infer_program_desc) targets = [] for out in self._output_descs: @@ -861,14 +907,29 @@ def _run_dygraph(instance, input, program_holder): # 2. 
run program by op trace_program = program_holder.infer_program if instance._is_test else program_holder.train_program + forward_program = program_holder._infer_program_desc if instance._is_test else program_holder.forward_program end_op_index = program_holder.infer_program.block(0).op_size() - attrs = ('global_block', trace_program.block(0), 'start_op_index', 0, - 'end_op_index', end_op_index, 'is_test', instance._is_test, - 'program_id', _hash_with_id(trace_program, instance)) + + attrs = [ + 'global_block', + trace_program.block(0), 'start_op_index', 0, 'end_op_index', + end_op_index, 'is_test', instance._is_test, 'program_id', + _hash_with_id(trace_program, instance) + ] + + use_interpretorcore = _is_enable_standalone_executor( + ) and _is_dy2st_enable_standalone_executor() + attrs.extend(('use_interpretorcore', use_interpretorcore)) + if use_interpretorcore: + attrs.extend( + ('forward_global_block', forward_program.block(0), + 'backward_global_block', program_holder.backward_program.block(0))) + _legacy_C_ops.run_program(_valid_vars(input_vars), _valid_vars(persistable_vars), _valid_vars(output_vars), tmp_scope_vec, _valid_vars(double_grad_vars), None, *attrs) + # NOTE: [ why need set param's gradient type here ] # if user set sparse gradient mode, the param's gradient # will be SelectedRows, not LoDTensor. But tracer will just diff --git a/python/paddle/fluid/executor.py b/python/paddle/fluid/executor.py index 601c66f5b3007a8558c82c56c8ba1f7c601ec38f..5f96d7f06c5d9a31534035738240a7e7abb5f510 100755 --- a/python/paddle/fluid/executor.py +++ b/python/paddle/fluid/executor.py @@ -399,6 +399,12 @@ def _is_enable_standalone_executor(): ] +def _is_dy2st_enable_standalone_executor(): + return framework._dy2st_enable_standalone_executor_ in [ + 1, '1', True, 'True', 'true' + ] + + def _prepare_fleet_executor(): from ..distributed.fleet.proto import fleet_executor_desc_pb2 trainer_endpoints_str = os.getenv("PADDLE_TRAINER_ENDPOINTS", "") diff --git a/python/paddle/fluid/framework.py b/python/paddle/fluid/framework.py index 3a98f1c5655f5f839874f6377ebabff332e1e50d..6d16dfcb10f19b8af808ff6b2be86ce1c28529b3 100644 --- a/python/paddle/fluid/framework.py +++ b/python/paddle/fluid/framework.py @@ -86,6 +86,8 @@ _current_cuda_graph_mode = None _global_flags_ = core.globals() _enable_standalone_executor_ = (os.environ.get('FLAGS_USE_STANDALONE_EXECUTOR', None)) +_dy2st_enable_standalone_executor_ = (os.environ.get( + 'FLAGS_DY2ST_USE_STANDALONE_EXECUTOR', 1)) # Some explanation of our execution system 2022.03 # For now we have 3 kinds of execution system, since we refactored dygraph mode to @@ -5040,6 +5042,8 @@ class Program(object): all_new_vars = [] block_num = new_desc.num_blocks() for idx in range(block_num): + if (idx > (len(self.blocks) - 1)): + self._create_block() new_block_desc = new_desc.block(idx) all_new_vars.append([]) block_new_vars = all_new_vars[-1] diff --git a/python/paddle/fluid/tests/unittests/dygraph_to_static/CMakeLists.txt b/python/paddle/fluid/tests/unittests/dygraph_to_static/CMakeLists.txt index 1687b277ab5b5206486b7347e89279447eab2068..e1611d524ab8ea9eba88f46c6dcba04e800d401b 100644 --- a/python/paddle/fluid/tests/unittests/dygraph_to_static/CMakeLists.txt +++ b/python/paddle/fluid/tests/unittests/dygraph_to_static/CMakeLists.txt @@ -57,7 +57,7 @@ set_tests_properties(test_se_resnet PROPERTIES TIMEOUT 900) set_tests_properties(test_yolov3 PROPERTIES TIMEOUT 900 LABELS "RUN_TYPE=EXCLUSIVE") set_tests_properties(test_mobile_net PROPERTIES TIMEOUT 120) 
-set_tests_properties(test_seq2seq PROPERTIES TIMEOUT 120) +set_tests_properties(test_seq2seq PROPERTIES TIMEOUT 150) set_tests_properties(test_cycle_gan PROPERTIES TIMEOUT 150) set_tests_properties(test_bert PROPERTIES TIMEOUT 120) set_tests_properties(test_basic_api_transformation PROPERTIES TIMEOUT 120) diff --git a/python/paddle/fluid/tests/unittests/test_eager_run_program.py b/python/paddle/fluid/tests/unittests/test_eager_run_program.py index ba4c9a9452c9b11f25b47212f45eb1be08896158..e67e7c12c7df2fc22977333d1be41d8905789e3b 100644 --- a/python/paddle/fluid/tests/unittests/test_eager_run_program.py +++ b/python/paddle/fluid/tests/unittests/test_eager_run_program.py @@ -18,6 +18,8 @@ from paddle import _C_ops, _legacy_C_ops from paddle.fluid.framework import _test_eager_guard, Variable, _in_legacy_dygraph from paddle.fluid import core from paddle.fluid.layers.utils import _hash_with_id +from paddle.fluid.dygraph.base import switch_to_static_graph +from paddle.fluid.executor import _is_enable_standalone_executor, _is_dy2st_enable_standalone_executor import paddle.compat as cpt import unittest @@ -67,6 +69,18 @@ def _create_out(var): return var_base +@switch_to_static_graph +def _add_build_strategy_for(input_program, start_op_index, end_op_index): + compiled_program = paddle.static.CompiledProgram( + core.Graph(input_program.desc, start_op_index, end_op_index), + build_strategy=paddle.static.BuildStrategy()) + compiled_program._compile(core.Scope(), + paddle.framework._current_expected_place()) + ir_graph = paddle.fluid.framework.IrGraph(compiled_program._graph) + builded_program = ir_graph.to_program() + return builded_program + + class TestRunProgram(unittest.TestCase): def test_eager(self): @@ -81,6 +95,13 @@ class TestRunProgram(unittest.TestCase): main_program = paddle.static.default_main_program() program = _append_backward_desc(main_program, [out]) + forward_program = _add_build_strategy_for( + program, 0, + main_program.desc.block(0).op_size()) + backward_program = _add_build_strategy_for( + program, + main_program.desc.block(0).op_size() + 2, + program.desc.block(0).op_size()) paddle.disable_static('cpu') # step 2: call run_program in eager mode @@ -98,9 +119,21 @@ class TestRunProgram(unittest.TestCase): out_t = _create_out(out) scope = core.Scope() - attrs = ('global_block', program.desc.block(0), 'start_op_index', 0, - 'end_op_index', main_program.desc.block(0).op_size(), - 'is_test', False, 'program_id', _hash_with_id(program)) + attrs = [ + 'global_block', + program.desc.block(0), 'start_op_index', 0, 'end_op_index', + main_program.desc.block(0).op_size(), 'is_test', False, + 'program_id', + _hash_with_id(program) + ] + + use_interpretorcore = _is_enable_standalone_executor( + ) and _is_dy2st_enable_standalone_executor() + attrs.extend(('use_interpretorcore', use_interpretorcore)) + if use_interpretorcore: + attrs.extend( + ('forward_global_block', forward_program.desc.block(0), + 'backward_global_block', backward_program.desc.block(0))) _legacy_C_ops.run_program([x_t, y_t], [fake_var], [out_t], [scope], [fake_var], None, *attrs) diff --git a/python/paddle/fluid/tests/unittests/test_run_program_op.py b/python/paddle/fluid/tests/unittests/test_run_program_op.py index 50fd3c01769f4e95cbfab419689c627686169ee0..2ab54a0da72fe8fd0d45d61247a87cda5b62d913 100644 --- a/python/paddle/fluid/tests/unittests/test_run_program_op.py +++ b/python/paddle/fluid/tests/unittests/test_run_program_op.py @@ -26,6 +26,8 @@ from paddle import compat as cpt from paddle.fluid import core, framework, 
executor from paddle.fluid.layers.utils import _hash_with_id from paddle.fluid.framework import _in_eager_mode_ +from paddle.fluid.executor import _is_enable_standalone_executor, _is_dy2st_enable_standalone_executor +from paddle.fluid.dygraph.base import switch_to_static_graph paddle.enable_static() @@ -41,6 +43,30 @@ def program_scope_guard(): yield +@switch_to_static_graph +def _add_build_strategy_for(input_program, start_op_index, end_op_index): + compiled_program = paddle.static.CompiledProgram( + core.Graph(input_program.desc, start_op_index, end_op_index), + build_strategy=paddle.static.BuildStrategy()) + compiled_program._compile(core.Scope(), + paddle.framework._current_expected_place()) + ir_graph = paddle.fluid.framework.IrGraph(compiled_program._graph) + builded_program = ir_graph.to_program() + return builded_program + + +@switch_to_static_graph +def _build_program_by_desc(program_desc): + prog = framework.Program() + prog.desc = program_desc + prog.blocks = [ + framework.Block(prog, i) + for i in six.moves.range(prog.desc.num_blocks()) + ] + prog._sync_with_cpp() + return prog + + # NOTE: Because RunProgramOp has a special output of type std::vector, # the OpTest cannot be used in RunProgramOp. The variable type cannot be specified # when creating output variables in OpTest, default type is LoDTensor @@ -97,10 +123,22 @@ class RunProgramOpTest(unittest.TestCase): fwd_op_num = self.build_model() return fluid.default_main_program().desc, fwd_op_num + def get_forward_backward_program_desc(self, whole_program_desc, + forward_op_num, output_num): + program = _build_program_by_desc(whole_program_desc) + forward_program = _add_build_strategy_for(program, 0, forward_op_num) + backward_program = _add_build_strategy_for( + program, forward_op_num + 2 * output_num, + program.desc.block(0).op_size()) + return forward_program.desc, backward_program.desc + def prepare_attrs(self): - return ('global_block', self.program_desc.block(0), 'start_op_index', 0, - 'end_op_index', self.fwd_op_num, 'program_id', - _hash_with_id(self.program_desc, self)) + return [ + 'global_block', + self.program_desc.block(0), 'start_op_index', 0, 'end_op_index', + self.fwd_op_num, 'program_id', + _hash_with_id(self.program_desc, self) + ] def get_param_grad_names(self): grad_names = [] @@ -200,9 +238,21 @@ class RunProgramOpTest(unittest.TestCase): inputs = self.prepare_dygraph_input(place) outputs = self.prepare_dygraph_output() + forward_program_desc, backward_program_desc = self.get_forward_backward_program_desc( + self.program_desc, self.fwd_op_num, len(outputs['Out'])) + + use_interpretorcore = _is_enable_standalone_executor( + ) and _is_dy2st_enable_standalone_executor() + self.attrs.extend(('use_interpretorcore', use_interpretorcore)) + if use_interpretorcore: + self.attrs.extend( + ('forward_global_block', forward_program_desc.block(0), + 'backward_global_block', backward_program_desc.block(0))) + _legacy_C_ops.run_program(inputs['X'], inputs['Params'], outputs['Out'], outputs['OutScope'], outputs['DOut'], None, *self.attrs) + return outputs['Out'] def calc_dygraph_grad(self, place): @@ -214,6 +264,17 @@ class RunProgramOpTest(unittest.TestCase): inputs, input_param_list = self.prepare_dygraph_input(place, True) outputs = self.prepare_dygraph_output() + forward_program_desc, backward_program_desc = self.get_forward_backward_program_desc( + self.program_desc, self.fwd_op_num, len(outputs['Out'])) + + use_interpretorcore = _is_enable_standalone_executor( + ) and _is_dy2st_enable_standalone_executor() + 
self.attrs.extend(('use_interpretorcore', use_interpretorcore)) + if use_interpretorcore: + self.attrs.extend( + ('forward_global_block', forward_program_desc.block(0), + 'backward_global_block', backward_program_desc.block(0))) + _legacy_C_ops.run_program(inputs['X'], inputs['Params'], outputs['Out'], outputs['OutScope'], outputs['DOut'], None, *self.attrs)
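For reference, the pattern that the updated tests and dygraph/io.py use to drive the new interpretercore path can be collected into a small standalone sketch. This is not part of the patch: _add_build_strategy_for mirrors the helper added to the tests above, while _build_run_program_attrs is a hypothetical name that only bundles the attribute-assembly steps shown in the hunks; everything else (imports, _hash_with_id, the two executor switches) comes straight from the diff.

    import paddle
    from paddle.fluid import core
    from paddle.fluid.dygraph.base import switch_to_static_graph
    from paddle.fluid.executor import (_is_enable_standalone_executor,
                                       _is_dy2st_enable_standalone_executor)
    from paddle.fluid.layers.utils import _hash_with_id


    @switch_to_static_graph
    def _add_build_strategy_for(input_program, start_op_index, end_op_index):
        # Same helper as added to the tests above: cut the op range
        # [start_op_index, end_op_index) out of the whole program and
        # rebuild it as a new Program via IrGraph.
        compiled_program = paddle.static.CompiledProgram(
            core.Graph(input_program.desc, start_op_index, end_op_index),
            build_strategy=paddle.static.BuildStrategy())
        compiled_program._compile(core.Scope(),
                                  paddle.framework._current_expected_place())
        ir_graph = paddle.fluid.framework.IrGraph(compiled_program._graph)
        return ir_graph.to_program()


    def _build_run_program_attrs(whole_program, forward_program,
                                 backward_program, fwd_op_num, is_test=False):
        # Hypothetical helper (illustration only): assemble the run_program
        # attribute list the way io.py and the updated tests do.
        attrs = [
            'global_block', whole_program.desc.block(0), 'start_op_index', 0,
            'end_op_index', fwd_op_num, 'is_test', is_test, 'program_id',
            _hash_with_id(whole_program)
        ]
        # New in this patch: the partial forward/backward global blocks are
        # only passed when both the standalone executor and the dy2st switch
        # are enabled.
        use_interpretorcore = _is_enable_standalone_executor(
        ) and _is_dy2st_enable_standalone_executor()
        attrs.extend(('use_interpretorcore', use_interpretorcore))
        if use_interpretorcore:
            attrs.extend(
                ('forward_global_block', forward_program.desc.block(0),
                 'backward_global_block', backward_program.desc.block(0)))
        return attrs

The resulting list is then unpacked into _legacy_C_ops.run_program(x, params, out, out_scope, dout, None, *attrs), matching the calls in the hunks above. Note that FLAGS_DY2ST_USE_STANDALONE_EXECUTOR defaults to 1, so the interpretercore branch is taken whenever FLAGS_USE_STANDALONE_EXECUTOR is also enabled for the process.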