Unverified commit aba1295b, authored by zhangbo9674, committed by GitHub

[new_exe] Dy2Static support new_executor (#44450)

* add interpretercore

* refine backward program id

* add code

* refine program

* refine code

* create forward/backward_program by prog2graph2prog method

* test, do not care

* refine code

* refine code

* refine code

* test, do not care

* add interpretercore

* add scope

* refine scope create method

* add jit for new_exe

* solve conflict

* delete unused code

* polish code

* polish code

* refine scope in inplace

* refine for datatransfer

* refine _rebuild_from_desc

* refine control eager deletion attr

* refine used_for_jit

* refine jit for infer

* op size0 use ori program

* polish code

* refine jit

* refine run_program_op ut

* refine inplace

* refine control

* refine graph helper

* refine control

* refine inplace

* refine buffer_share_inplace_pass

* polish code

* polish code

* refine usage for compilerProgram

* refine control

* test

* test core cache

* refine code

* refine io.py

* increase test_seq2seq timeout

* refine convert program

* refine interpretercore_cache release

* delete buildinplace

* refine partial_program && io

* refine code for io

* test

* test

* test
Parent commit: 02621079
......@@ -38,6 +38,22 @@ static void clear_no_grad_edges(
}
}
static void clear_no_grad_edges_with_partial_block(
const std::vector<paddle::experimental::Tensor>& params,
const paddle::framework::BlockDesc* forward_block_desc,
const paddle::framework::BlockDesc* backward_block_desc,
egr::GradNodeBase* grad_node,
size_t slot_id) {
for (size_t i = 0; i < params.size(); ++i) {
auto p_grad_name = paddle::framework::GradVarName(params[i].name());
if (!forward_block_desc->HasVar(p_grad_name) &&
!backward_block_desc->HasVar(p_grad_name)) {
VLOG(1) << "clear edge of " << p_grad_name;
grad_node->MutableOutputMeta()[slot_id][i].GetMutableEdge().Clear();
}
}
}
inline void run_program_dygraph_function(
const std::vector<paddle::experimental::Tensor>& x,
const std::vector<paddle::experimental::Tensor>& params,
......@@ -85,9 +101,26 @@ inline void run_program_dygraph_function(
// Set Grad out rank as same as fwd input and set stop gradient to bwd
grad_node->SetGradOutMeta(x, /*slot id*/ 0);
grad_node->SetGradOutMeta(params, /*slot id*/ 1);
auto* global_block = PADDLE_GET_CONST(paddle::framework::BlockDesc*,
attrs.at("global_block"));
clear_no_grad_edges(params, global_block, grad_node.get(), /*slot id*/ 1);
bool use_interpretorcore =
PADDLE_GET_CONST(bool, attrs.at("use_interpretorcore"));
VLOG(2) << "clear_no_grad_edges.";
if (use_interpretorcore) {
auto* forward_global_block = PADDLE_GET_CONST(
paddle::framework::BlockDesc*, attrs.at("forward_global_block"));
auto* backward_global_block = PADDLE_GET_CONST(
paddle::framework::BlockDesc*, attrs.at("backward_global_block"));
clear_no_grad_edges_with_partial_block(params,
forward_global_block,
backward_global_block,
grad_node.get(),
/*slot id*/ 1);
} else {
auto* global_block = PADDLE_GET_CONST(paddle::framework::BlockDesc*,
attrs.at("global_block"));
clear_no_grad_edges(params, global_block, grad_node.get(), /*slot id*/ 1);
}
grad_node->SetGradInMeta(deref_out, 0);
......
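The hunk above changes the grad-edge cleanup for the interpreter-core path: a parameter's grad edge is cleared only when its gradient variable exists in neither the forward nor the backward global block, instead of checking a single global block. A minimal, self-contained sketch of that membership test follows; std::set stands in for the two BlockDescs and GradVarName is a simplified stand-in, so none of the names below are Paddle's actual API.

#include <iostream>
#include <set>
#include <string>
#include <vector>

// Simplified stand-in for framework::GradVarName.
std::string GradVarName(const std::string& name) { return name + "@GRAD"; }

int main() {
  std::set<std::string> forward_block_vars = {"x", "w1", "w1@GRAD"};
  std::set<std::string> backward_block_vars = {"x@GRAD", "w2@GRAD"};
  std::vector<std::string> params = {"w1", "w2", "w3"};

  for (const auto& p : params) {
    const std::string g = GradVarName(p);
    const bool in_fwd = forward_block_vars.count(g) > 0;
    const bool in_bwd = backward_block_vars.count(g) > 0;
    if (!in_fwd && !in_bwd) {
      // Mirrors MutableOutputMeta()[slot_id][i].GetMutableEdge().Clear().
      std::cout << "clear edge of " << g << "\n";  // only w3@GRAD here
    }
  }
  return 0;
}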
......@@ -1004,7 +1004,7 @@ cc_library(
cc_library(
executor_cache
SRCS executor_cache.cc
DEPS parallel_executor)
DEPS parallel_executor standalone_executor)
if(WITH_PSCORE)
get_property(RPC_DEPS GLOBAL PROPERTY RPC_DEPS)
if(WITH_HETERPS)
......
......@@ -137,6 +137,58 @@ void ParseSafeEagerDeletionSkipVars(
VLOG(3) << "Found skip_eager_delete_vars: " << skip_eager_delete_vars->size();
}
void AppendSkipDeletionVars(const std::vector<std::string> &append_vars,
std::set<std::string> *all_vars) {
for (auto &var : append_vars) {
all_vars->insert(var);
}
}
std::set<std::string> ParseSafeEagerDeletionSkipVarsSet(
const ProgramDesc &backward_program) {
std::set<std::string> skip_eager_delete_vars;
auto backward_ops = backward_program.Block(0).AllOps();
auto &op_info_map = OpInfoMap::Instance();
std::unordered_set<std::string> op_outputs;
std::unordered_set<std::string> op_inputs;
std::unordered_set<std::string> no_need_buffer_ins;
for (size_t i = 0; i < backward_ops.size(); ++i) {
framework::OpDesc *op = backward_ops[i];
if (op->Type() == "share_buffer") {
VLOG(1) << "skip share_buffer op";
continue;
}
// NOTE: skip NoNeedBufferVars of grad_op and GC its memory in advance.
auto &op_info = op_info_map.Get(op->Type());
auto &inferer = op_info.NoNeedBufferVarsInferer();
no_need_buffer_ins.clear();
if (inferer != nullptr) {
no_need_buffer_ins =
inferer(op->Inputs(), op->Outputs(), op->GetAttrMap());
}
for (auto &in_names : op->Inputs()) {
if (no_need_buffer_ins.count(in_names.first) == 0) {
for (auto &in_name : in_names.second) {
op_inputs.emplace(in_name);
}
} else {
VLOG(2) << op->Type() << " has no_need_buffer_in: " << in_names.first
<< " , skip it.";
}
}
for (const std::string &out_arg_name : op->OutputArgumentNames()) {
op_outputs.emplace(out_arg_name);
}
}
for (const std::string &var_name : op_inputs) {
if (op_outputs.find(var_name) == op_outputs.end()) {
VLOG(1) << "skip eager var: " << var_name;
skip_eager_delete_vars.insert(var_name);
}
}
VLOG(1) << "Found skip_eager_delete_vars: " << skip_eager_delete_vars.size();
return skip_eager_delete_vars;
}
} // namespace details
// C++11 removes the need for manual locking. Concurrent execution shall wait if
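ParseSafeEagerDeletionSkipVarsSet above boils down to a set difference: every input of a backward op that is not produced by any backward op must come from the forward program, so it is excluded from eager garbage collection. A self-contained sketch of that core step (the share_buffer skipping and NoNeedBufferVars filtering are omitted; the types are simplified stand-ins, not Paddle's):

#include <iostream>
#include <string>
#include <unordered_set>
#include <vector>

// Simplified stand-in for an OpDesc's input/output argument names.
struct OpIO {
  std::vector<std::string> inputs;
  std::vector<std::string> outputs;
};

int main() {
  const std::vector<OpIO> backward_ops = {
      {{"x", "out@GRAD"}, {"x@GRAD"}},
      {{"x@GRAD"}, {"w@GRAD"}},
  };

  std::unordered_set<std::string> op_inputs;
  std::unordered_set<std::string> op_outputs;
  for (const auto& op : backward_ops) {
    op_inputs.insert(op.inputs.begin(), op.inputs.end());
    op_outputs.insert(op.outputs.begin(), op.outputs.end());
  }

  std::unordered_set<std::string> skip_eager_delete_vars;
  for (const auto& name : op_inputs) {
    if (op_outputs.count(name) == 0) {  // produced by the forward program
      skip_eager_delete_vars.insert(name);
    }
  }
  // Prints "x" and "out@GRAD" (order unspecified).
  for (const auto& v : skip_eager_delete_vars) std::cout << v << "\n";
  return 0;
}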
......@@ -225,5 +277,33 @@ CacheInfo GetExecutorInfoFromCache(const ProgramDesc &program_desc,
}
}
InterpreterCoreInfoCache &InterpreterCoreInfoCache::Instance() {
static InterpreterCoreInfoCache g_info_cache;
return g_info_cache;
}
std::shared_ptr<InterpreterCore> CreateInterpreterCoreInfoToCache(
const ProgramDesc &program_desc,
const platform::Place &place,
bool is_grad,
int64_t program_id,
framework::Scope *scope) {
auto &interpretercore_info_cache =
framework::InterpreterCoreInfoCache::Instance();
if (interpretercore_info_cache.Size() > 4u /* max_cached_size*/) {
interpretercore_info_cache.Finalize();
}
auto core = std::make_shared<InterpreterCore>(
place,
program_desc.Block(0),
/*skip_gc_vars=*/std::set<std::string>(),
scope,
/*used_for_jit=*/true);
auto &cached_value =
interpretercore_info_cache.GetMutable(program_id, is_grad);
cached_value.core_ = core;
return core;
}
} // namespace framework
} // namespace paddle
......@@ -23,6 +23,7 @@
#include <utility>
#include <vector>
#include "paddle/fluid/framework/new_executor/interpretercore.h"
#include "paddle/fluid/framework/op_proto_maker.h"
#include "paddle/fluid/framework/parallel_executor.h"
#include "paddle/fluid/framework/program_desc.h"
......@@ -45,6 +46,12 @@ void ParseSafeEagerDeletionSkipVars(
const std::vector<std::string>& output_var_names,
std::vector<std::string>* skip_eager_delete_vars);
void AppendSkipDeletionVars(const std::vector<std::string>& append_vars,
std::set<std::string>* all_vars);
std::set<std::string> ParseSafeEagerDeletionSkipVarsSet(
const ProgramDesc& backward_program);
} // namespace details
class ExecutorInfo {
......@@ -147,5 +154,73 @@ PEAndGraphPair CreateFixOrderExecutorInfo(const ProgramDesc& program_desc,
int64_t end_op_index,
framework::Scope* scope);
class InterpreterCoreInfo {
public:
struct CacheValue {
std::shared_ptr<InterpreterCore> core_{nullptr};
std::set<std::string> skip_eager_delete_vars_;
};
bool IsAvailable(bool is_grad) {
const auto& core = is_grad ? backward_info_.core_ : forward_info_.core_;
return core != nullptr;
}
CacheValue& GetMutable(bool is_grad) {
return is_grad ? backward_info_ : forward_info_;
}
private:
CacheValue forward_info_;
CacheValue backward_info_;
};
class InterpreterCoreInfoCache {
public:
static InterpreterCoreInfoCache& Instance();
bool Has(int64_t program_id, bool is_grad) {
return info_map_.find(program_id) != info_map_.end() &&
info_map_[program_id].IsAvailable(is_grad);
}
InterpreterCoreInfo::CacheValue& GetMutable(int64_t program_id,
bool is_grad) {
return info_map_[program_id].GetMutable(is_grad);
}
void UpdateSkipEagerDeleteVars(int64_t program_id,
bool is_grad,
const std::set<std::string>& skip_vars) {
auto& cached_value = GetMutable(program_id, is_grad);
cached_value.skip_eager_delete_vars_ = std::move(skip_vars);
}
std::set<std::string>& GetSkipEagerDeleteVars(int64_t program_id,
bool is_grad) {
auto& cached_value = GetMutable(program_id, is_grad);
return cached_value.skip_eager_delete_vars_;
}
size_t Size() const { return info_map_.size(); }
void Finalize() {
// NOTE(Aurelius84): DO NOT perform finalize in destructor
// to avoid problems caused by destructor order of static
// object.
info_map_.clear();
}
private:
std::unordered_map<int64_t, InterpreterCoreInfo> info_map_;
};
std::shared_ptr<InterpreterCore> CreateInterpreterCoreInfoToCache(
const ProgramDesc& program_desc,
const platform::Place& place,
bool is_grad,
int64_t program_id,
framework::Scope* scope);
} // namespace framework
} // namespace paddle
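For orientation, the following is a self-contained model of the caching discipline this header describes: one forward and one backward core slot per program id, created on demand, with the whole cache dropped once it grows past a small cap (mirroring the Size() > 4u check before Finalize() in CreateInterpreterCoreInfoToCache). The type names and the threshold are illustrative only, not Paddle's.

#include <cstdint>
#include <iostream>
#include <memory>
#include <unordered_map>

// Toy stand-in for InterpreterCore.
struct Core {
  int64_t program_id;
  bool is_grad;
};

// One forward slot and one backward slot per program id.
struct Entry {
  std::shared_ptr<Core> forward;
  std::shared_ptr<Core> backward;
};

class CoreCache {
 public:
  bool Has(int64_t id, bool is_grad) const {
    auto it = map_.find(id);
    if (it == map_.end()) return false;
    return (is_grad ? it->second.backward : it->second.forward) != nullptr;
  }

  std::shared_ptr<Core> GetOrCreate(int64_t id, bool is_grad) {
    if (map_.size() > 4u) map_.clear();  // crude eviction, like Finalize()
    auto& slot = is_grad ? map_[id].backward : map_[id].forward;
    if (!slot) slot = std::make_shared<Core>(Core{id, is_grad});
    return slot;
  }

 private:
  std::unordered_map<int64_t, Entry> map_;
};

int main() {
  CoreCache cache;
  cache.GetOrCreate(/*program_id=*/42, /*is_grad=*/false);
  cache.GetOrCreate(42, true);
  std::cout << cache.Has(42, false) << " " << cache.Has(42, true) << "\n";  // 1 1
  return 0;
}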
......@@ -202,23 +202,34 @@ static std::vector<std::unique_ptr<OperatorBase>> CreateOpsFromBlock(
}
std::vector<std::vector<std::vector<std::string>>> GetEagerDeletionCleanVars(
const ProgramDesc &origin_program,
const std::vector<std::string> &skip_vars) {
const ProgramDesc &program, const std::vector<std::string> &skip_vars) {
return GetEagerDeletionCleanVarsForPartial(program, skip_vars, false);
}
std::vector<std::vector<std::vector<std::string>>>
GetEagerDeletionCleanVarsForPartial(const ProgramDesc &origin_program,
const std::vector<std::string> &skip_vars,
const bool &for_partial_block) {
ProgramDesc program{origin_program};
size_t block_num = program.Size();
PADDLE_ENFORCE_GE(block_num,
1,
platform::errors::PermissionDenied(
"Program should have at least one block"));
// prepare safe GCs on sub block ops
auto global_block_ops = CreateOpsFromBlock(program.Block(0));
operators::PrepareSafeEagerDeletionOnConditionalOpAndConditionalGradOp(
program, 0, global_block_ops);
operators::PrepareSafeEagerDeletionOnWhileOpAndWhileGradOp(
program, 0, global_block_ops);
operators::PrepareSafeEagerDeletionOnRecurrentOpAndRecurrentGradOp(
program, 0, global_block_ops);
// Note(zhangbo): For the dygraph2static inplace policy, origin_program is a
// partial program (containing only the forward or the backward part), and the
// control flow op's attr skip_eager_deletion_vars has already been updated
// during graph->program conversion, before this function is called.
if (!for_partial_block) {
// prepare safe GCs on sub block ops
auto global_block_ops = CreateOpsFromBlock(program.Block(0));
operators::PrepareSafeEagerDeletionOnConditionalOpAndConditionalGradOp(
program, 0, global_block_ops);
operators::PrepareSafeEagerDeletionOnWhileOpAndWhileGradOp(
program, 0, global_block_ops);
operators::PrepareSafeEagerDeletionOnRecurrentOpAndRecurrentGradOp(
program, 0, global_block_ops);
}
// find the skip vars on each block
std::vector<std::vector<std::string>> skip_vars_on_each_block(block_num);
......
......@@ -71,5 +71,11 @@ void DeleteUnusedTensors(
std::vector<std::vector<std::vector<std::string>>> GetEagerDeletionCleanVars(
const ProgramDesc &program, const std::vector<std::string> &skip_vars = {});
std::vector<std::vector<std::vector<std::string>>>
GetEagerDeletionCleanVarsForPartial(
const ProgramDesc &program,
const std::vector<std::string> &skip_vars = {},
const bool &for_partial_block = false);
} // namespace framework
} // namespace paddle
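Judging from the defaults above, a dygraph-to-static caller working on a partial (forward-only or backward-only) program is expected to pass the flag explicitly so the sub-block GC preparation is skipped. A hypothetical call site, where partial_program and skip_vars are placeholders rather than names taken from this diff:

// Hypothetical usage; partial_program and skip_vars are placeholders.
auto clean_vars = paddle::framework::GetEagerDeletionCleanVarsForPartial(
    partial_program, skip_vars, /*for_partial_block=*/true);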
......@@ -497,8 +497,38 @@ static OpDesc *ReplaceScaleLossGradOp(const Node &node, OpDesc *desc) {
return desc;
}
void UpdateControlOpSkipEagerDeletionVars(const Node &node,
const Graph &graph,
const size_t graph_idx,
const std::string &control_type) {
// Note(zhangbo): SkipEagerDeletionVars pass policy for control flow ops:
// 1) if the op is in the main_block: SkipEagerDeletionVars information is
// written into the Graph OpNode wrapped by OpHandleBase; 2) if the op is in a
// sub_block: SkipEagerDeletionVars information is written into the graph's
// OriginProgram OpDesc. Please refer to
// FindAllConditionalBlockAndConditionalBlockGradOp in
// "paddle/fluid/operators/controlflow/conditional_block_op_helper.cc"
if (graph_idx != 0) {
auto origin_program = graph.OriginProgram();
auto &block = origin_program.Block(graph_idx);
for (size_t j = 0; j < block.OpSize(); ++j) {
auto *op = block.Op(j);
if (op->Type() == control_type &&
op->HasAttr("skip_eager_deletion_vars")) {
if (op->InputArgumentNames() == node.Op()->InputArgumentNames() &&
op->OutputArgumentNames() == node.Op()->OutputArgumentNames()) {
node.Op()->SetAttr("skip_eager_deletion_vars",
op->GetAttr("skip_eager_deletion_vars"));
}
}
}
}
}
static void GetGraphOpDesc(const std::vector<Node *> &nodes,
std::vector<OpDesc> *ops) {
std::vector<OpDesc> *ops,
const Graph &graph,
const size_t graph_idx) {
auto is_fused_opt = [](Node *n) -> bool {
auto op_type = n->Op()->Type();
auto is_opt =
......@@ -524,7 +554,6 @@ static void GetGraphOpDesc(const std::vector<Node *> &nodes,
ReplaceScaleLossGradOp(*n, &desc);
} else if (n->Op()) {
VLOG(4) << "convert op node to desc " << n->Op()->Type();
VLOG(4) << n->ToString();
if (is_fused_opt(n)) {
OpDesc depend_desc(n->Op()->Block());
......@@ -543,7 +572,15 @@ static void GetGraphOpDesc(const std::vector<Node *> &nodes,
ops->emplace_back(depend_desc);
VLOG(4) << "add depend op";
}
if (n->Name() == "while" || n->Name() == "while_grad" ||
n->Name() == "conditional_block" ||
n->Name() == "conditional_block_grad" || n->Name() == "recurrent" ||
n->Name() == "recurrent_grad") {
VLOG(1) << "Update control op attr: skip_eager_deletion_vars";
UpdateControlOpSkipEagerDeletionVars(*n, graph, graph_idx, n->Name());
}
ops->emplace_back(*n->Op());
VLOG(4) << n->ToString();
}
// delete no OpDesc op
}
......@@ -563,7 +600,8 @@ static void GetGraphVarDesc(const Graph &graph,
static void GraphToBlock(const Graph &graph,
proto::BlockDesc *block,
const SortKind *sort_kind) {
const SortKind *sort_kind,
const size_t graph_idx) {
// Remove the unneeded variables after memory optimization.
std::unordered_set<std::string> vars2remove;
if (graph.Has(kGraphToProgramVarsToRemove)) {
......@@ -607,7 +645,7 @@ static void GraphToBlock(const Graph &graph,
}
std::vector<OpDesc> ops;
GetGraphOpDesc(nodes, &ops);
GetGraphOpDesc(nodes, &ops, graph, graph_idx);
for (auto &op : ops) {
RemoveControlDepInputAndOuput(&op);
......@@ -633,7 +671,10 @@ void GraphToProgram(const Graph &graph,
block->set_idx(kRootBlockIndex);
if (FLAGS_convert_all_blocks) {
GraphToBlock(*graph.GetSubGraph(kRootBlockIndex), block, sort_kind);
GraphToBlock(*graph.GetSubGraph(kRootBlockIndex),
block,
sort_kind,
graph.GetSubGraph(kRootBlockIndex)->GetBlockId());
VLOG(3) << "Graph to program need convert " << graph.SubGraphsSize()
<< " sub graph";
......@@ -644,10 +685,13 @@ void GraphToProgram(const Graph &graph,
block = program_pb.add_blocks();
block->set_idx(idx);
block->set_parent_idx(kRootBlockIndex);
GraphToBlock(*graph.GetSubGraph(idx), block, sort_kind);
GraphToBlock(*graph.GetSubGraph(idx),
block,
sort_kind,
graph.GetSubGraph(idx)->GetBlockId());
}
} else {
GraphToBlock(graph, block, sort_kind);
GraphToBlock(graph, block, sort_kind, graph.GetBlockId());
}
program->CopyFrom(program_pb);
......
......@@ -167,14 +167,15 @@ static std::string GetFirstVarName(const OpDesc &op,
static std::vector<std::vector<std::pair<std::string, std::string>>>
GetInplaceVars(const BlockDesc &block,
bool use_cuda,
const std::vector<std::string> &skip_vars) {
const std::vector<std::string> &skip_vars,
const bool &for_partial_block) {
PADDLE_ENFORCE_EQ(
block.ID(),
0,
platform::errors::Unimplemented("Inplace can only perform in block 0."));
// only take block 0 gc_vars
const auto op_gc_vars =
GetEagerDeletionCleanVars(*block.Program(), skip_vars)[0];
const auto op_gc_vars = GetEagerDeletionCleanVarsForPartial(
*block.Program(), skip_vars, for_partial_block)[0];
const auto all_ops = block.AllOps();
PADDLE_ENFORCE_EQ(op_gc_vars.size(),
all_ops.size(),
......@@ -267,9 +268,14 @@ void BufferSharedInplaceOpPass::ApplyImpl(ProgramDesc *main_program,
ProgramDesc *startup_program) const {
bool use_cuda = Get<bool>(kUseCuda);
auto skip_vars = Get<std::vector<std::string>>("mem_opt_skip_vars");
bool for_partial_block = false;
if (Has("for_partial_block")) {
for_partial_block = Get<bool>("for_partial_block");
}
auto *block = main_program->MutableBlock(0);
auto inplace_vars = GetInplaceVars(*block, use_cuda, skip_vars);
auto inplace_vars =
GetInplaceVars(*block, use_cuda, skip_vars, for_partial_block);
PADDLE_ENFORCE_EQ(inplace_vars.size(),
block->OpSize(),
platform::errors::PermissionDenied(
......
......@@ -75,6 +75,22 @@ class ConditionalOpEagerDeletionPass : public Pass {
operators::PrepareSafeEagerDeletionOnConditionalOpAndConditionalGradOp(
graph->OriginProgram(), ifelse_ops, ifelse_grad_ops);
}
for (auto op_hander : all_ops) {
auto *compute_op =
dynamic_cast<details::ComputationOpHandle *>(op_hander);
if (compute_op == nullptr) continue;
if (compute_op->Name() == "conditional_block" ||
compute_op->Name() == "conditional_block_grad") {
ir::Node *op_node = op_hander->Node();
auto *op_base = compute_op->GetOp();
if (op_base->Attrs().count("skip_eager_deletion_vars")) {
op_node->Op()->SetAttr(
"skip_eager_deletion_vars",
op_base->Attrs().at("skip_eager_deletion_vars"));
}
}
}
}
};
......
......@@ -43,6 +43,21 @@ void RecurrentOpEagerDeletionPass::ApplyImpl(Graph *graph) const {
PrepareSafeEagerDeletionOnRecurrentOpAndRecurrentGradOp(
graph->OriginProgram(), &op_pair);
}
auto all_ops = ir::FilterByNodeWrapper<details::OpHandleBase>(*graph);
for (auto op_hander : all_ops) {
auto *compute_op = dynamic_cast<details::ComputationOpHandle *>(op_hander);
if (compute_op == nullptr) continue;
if (compute_op->Name() == "recurrent" ||
compute_op->Name() == "recurrent_grad") {
ir::Node *op_node = op_hander->Node();
auto *op_base = compute_op->GetOp();
if (op_base->Attrs().count("skip_eager_deletion_vars")) {
op_node->Op()->SetAttr("skip_eager_deletion_vars",
op_base->Attrs().at("skip_eager_deletion_vars"));
}
}
}
}
// Returns a std::unordered_map mapping from the device id to recurrent op and
......
......@@ -87,6 +87,21 @@ class WhileOpEagerDeletionPass : public ir::Pass {
operators::PrepareSafeEagerDeletionOnWhileOpAndWhileGradOp(
graph->OriginProgram(), while_ops, while_grad_ops);
}
for (auto op_hander : all_ops) {
auto *compute_op =
dynamic_cast<details::ComputationOpHandle *>(op_hander);
if (compute_op == nullptr) continue;
if (compute_op->Name() == "while" || compute_op->Name() == "while_grad") {
ir::Node *op_node = op_hander->Node();
auto *op_base = compute_op->GetOp();
if (op_base->Attrs().count("skip_eager_deletion_vars")) {
op_node->Op()->SetAttr(
"skip_eager_deletion_vars",
op_base->Attrs().at("skip_eager_deletion_vars"));
}
}
}
}
};
......
......@@ -245,6 +245,8 @@ std::shared_ptr<OperatorBase> TransferLayout(const std::string& var_name,
VLOG(3) << "Create Variable " << *new_var_name
<< " locally, which pointer is " << ptr << "Variable Type "
<< var_type;
var_scope->MutableDataTransferAddedVars().push_back(
std::make_pair(*new_var_name, var_type));
var_scope->AddVar(*new_var_name, nullptr);
// 2. Construct VariableNameMap
......@@ -288,10 +290,11 @@ std::shared_ptr<OperatorBase> TransferDtype(const std::string& var_name,
auto* ptr = local_scope->Var(*new_var_name);
auto var_type = local_scope->FindVar(var_name)->Type();
InitializeVariable(ptr, static_cast<proto::VarType::Type>(var_type));
VLOG(3) << "Create Variable " << *new_var_name
<< " locally, which pointer is " << ptr << "Variable Type "
<< var_type;
var_scope->MutableDataTransferAddedVars().push_back(
std::make_pair(*new_var_name, var_type));
var_scope->AddVar(*new_var_name, nullptr);
// 2. Construct VariableNameMap
......@@ -328,7 +331,7 @@ std::shared_ptr<OperatorBase> TransferDevice(const std::string& var_name,
*new_var_name = var_name + "_device_" + src_place.DebugString() + "_" +
dst_place.DebugString();
if (local_scope->FindVar(*new_var_name) &&
if (var_scope->HasVar(*new_var_name) &&
IsTensorOfVarInitialized(local_scope->FindVar(*new_var_name))) {
// already has same var
VLOG(4) << "Use cached variable: " << *new_var_name;
......@@ -341,6 +344,8 @@ std::shared_ptr<OperatorBase> TransferDevice(const std::string& var_name,
VLOG(3) << "Create Variable " << *new_var_name
<< " locally, which pointer is " << ptr << "Variable Type "
<< var_type;
var_scope->MutableDataTransferAddedVars().push_back(
std::make_pair(*new_var_name, var_type));
var_scope->AddVar(*new_var_name, nullptr);
// 2. Construct VariableNameMap
......
......@@ -53,12 +53,14 @@ static constexpr size_t kDeviceNumThreads = 1;
InterpreterCore::InterpreterCore(const platform::Place& place,
const BlockDesc& block,
const std::set<std::string>& skip_gc_vars,
framework::Scope* scope)
framework::Scope* scope,
bool used_for_jit)
: place_(place),
block_(block),
skip_gc_vars_(skip_gc_vars),
var_scope_(scope),
stream_analyzer_(place) {
stream_analyzer_(place),
used_for_jit_(used_for_jit) {
VLOG(4) << "InterpreterCore(): " << this << " on " << place_;
is_build_ = false;
......@@ -67,6 +69,10 @@ InterpreterCore::InterpreterCore(const platform::Place& place,
completion_notifier_ = main_thread_blocker_.RegisterEvent(kTaskCompletion);
create_local_scope_ = FLAGS_new_executor_use_local_scope;
if (used_for_jit_) {
create_local_scope_ = false;
}
VLOG(4) << "create_local_scope_ is " << create_local_scope_;
if (create_local_scope_) {
......@@ -85,7 +91,6 @@ InterpreterCore::InterpreterCore(const platform::Place& place,
InterpreterCore::~InterpreterCore() {
// cancel gc's thread
gc_.reset(nullptr);
async_work_queue_.reset();
VLOG(4) << "~InterpreterCore(): " << this << " on " << place_;
......@@ -184,7 +189,8 @@ paddle::framework::FetchList InterpreterCore::Run(
platform::AttachPointerHashToMKLDNNKey(this, place_);
#endif
if (!is_build_) {
paddle::framework::interpreter::build_variable_scope(block_, &var_scope_);
paddle::framework::interpreter::build_variable_scope(
block_, &var_scope_, create_local_scope_);
std::vector<paddle::framework::OpFuncNode> op_func_nodes;
paddle::framework::interpreter::build_op_func_list(place_,
......@@ -192,12 +198,12 @@ paddle::framework::FetchList InterpreterCore::Run(
skip_gc_vars_,
&op_func_nodes,
&var_scope_,
create_local_scope_);
create_local_scope_,
used_for_jit_);
is_build_ = true;
SetFeedVarsInplaceSkip(feed_names);
// convert vec func_list to graph
Convert(&op_func_nodes);
} else {
// For the program that only run once, it is no need to
// create work_queue, so the async_work_queue_ is created
......@@ -219,7 +225,9 @@ paddle::framework::FetchList InterpreterCore::Run(
ClearLoDTensorArrayInLocalScope();
}
// return Fetch Tensors
auto* fetch_var = local_scope_->FindVar(interpreter::kFetchVarName);
Scope* inner_scope =
create_local_scope_ ? local_scope_ : var_scope_.GetMutableScope();
auto* fetch_var = inner_scope->FindVar(interpreter::kFetchVarName);
if (fetch_var) {
return std::move(*fetch_var->GetMutable<framework::FetchList>());
} else {
......@@ -231,6 +239,31 @@ void InterpreterCore::SetCopyProgram(std::shared_ptr<ProgramDesc> prog) {
copy_program_ = prog;
}
void InterpreterCore::SetSkipGcVars(const std::set<std::string>& skip_gc_vars) {
PADDLE_ENFORCE_EQ(
skip_gc_vars_.empty(),
true,
platform::errors::PreconditionNotMet(
"Skip_gc_vars_ can only be initialized once, now skip_gc_vars_ is "
"not empty, do not call SetSkipGcVars method repeatedly."));
skip_gc_vars_ = skip_gc_vars;
}
const VariableScope* InterpreterCore::GetVariableScope() const {
return &var_scope_;
}
void InterpreterCore::reset_scope(Scope* new_scope) {
var_scope_.SetScope(new_scope);
auto& var_list = var_scope_.MutableVarList();
for (size_t i = 0; i < var_list.size(); i++) {
var_list[i] = new_scope->FindVar(var_scope_.GetNameById(i));
}
for (size_t i = 0; i < vec_instruction_.size(); ++i) {
BuildAndCacheInstructionCtx(&vec_instruction_[i]);
}
}
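reset_scope above swaps the underlying Scope, re-resolves every cached Variable* by name, and then rebuilds the per-instruction contexts that still hold the old pointers. A stripped-down, runnable model of the rebinding step, with a toy scope of named ints standing in for the real Scope/Variable classes:

#include <cassert>
#include <map>
#include <string>
#include <vector>

// Toy stand-in for Scope: name -> variable storage.
struct ToyScope {
  std::map<std::string, int> vars;
  int* FindVar(const std::string& name) {
    auto it = vars.find(name);
    return it == vars.end() ? nullptr : &it->second;
  }
};

int main() {
  ToyScope old_scope{{{"x", 1}, {"y", 2}}};
  ToyScope new_scope{{{"x", 10}, {"y", 20}}};

  std::vector<std::string> names = {"x", "y"};  // like the id -> name table
  std::vector<int*> var_list = {old_scope.FindVar("x"), old_scope.FindVar("y")};

  // Rebind every cached entry to the variable of the same name in the new
  // scope; any instruction-level caches must be rebuilt afterwards.
  for (size_t i = 0; i < var_list.size(); ++i) {
    var_list[i] = new_scope.FindVar(names[i]);
  }
  assert(*var_list[0] == 10 && *var_list[1] == 20);
  return 0;
}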
void InterpreterCore::ShareWorkQueueFrom(std::shared_ptr<InterpreterCore> src) {
async_work_queue_ = src->GetWorkQueue();
VLOG(8) << "Share AsyncWorkQueue from InterpreterCore(" << &src
......@@ -262,14 +295,15 @@ std::shared_ptr<interpreter::AsyncWorkQueue> InterpreterCore::GetWorkQueue() {
}
void InterpreterCore::BuildAndCacheInstructionCtx(Instruction* instr_node) {
Scope* inner_scope =
create_local_scope_ ? local_scope_ : var_scope_.GetMutableScope();
VariableValueMap ins_map;
for (auto& var_name_item : instr_node->Inputs()) {
std::vector<Variable*> input_vars;
input_vars.reserve(var_name_item.second.size());
for (auto& id : var_name_item.second) {
input_vars.emplace_back(
local_scope_->FindVar(var_scope_.GetNameById(id)));
input_vars.emplace_back(inner_scope->FindVar(var_scope_.GetNameById(id)));
}
ins_map.emplace(var_name_item.first, std::move(input_vars));
}
......@@ -280,7 +314,7 @@ void InterpreterCore::BuildAndCacheInstructionCtx(Instruction* instr_node) {
out_vars.reserve(var_name_item.second.size());
for (auto& id : var_name_item.second) {
out_vars.emplace_back(local_scope_->FindVar(var_scope_.GetNameById(id)));
out_vars.emplace_back(inner_scope->FindVar(var_scope_.GetNameById(id)));
}
outs_map.emplace(var_name_item.first, std::move(out_vars));
}
......@@ -319,6 +353,9 @@ void InterpreterCore::BuildInplace() {
}
}
Scope* local_scope = create_local_scope_ ? var_scope_.GetMutableLocalScope()
: var_scope_.GetMutableScope();
for (size_t i = 0; i < vec_instruction_.size(); ++i) {
auto& instr = vec_instruction_[i];
auto* op_base = instr.OpBase();
......@@ -348,8 +385,8 @@ void InterpreterCore::BuildInplace() {
var_scope_.GetNameById(iter->second[0]);
const std::string& outvar_name =
var_scope_.GetNameById(iterout->second[0]);
auto invar = local_scope_->FindVar(invar_name);
auto outvar = local_scope_->FindVar(outvar_name);
auto invar = local_scope->FindVar(invar_name);
auto outvar = local_scope->FindVar(outvar_name);
if (invar && outvar && invar->IsType<LoDTensor>() &&
outvar->IsType<LoDTensor>() &&
......@@ -410,15 +447,12 @@ void InterpreterCore::Convert(
auto* dev_ctx_ = stream_analyzer_.ParseDeviceContext(op_func_node);
vec_instruction_.emplace_back(op_idx, std::move(op_func_node), *dev_ctx_);
}
BuildOperatorDependences();
// calculate last_live_ops_
for (size_t op_idx = 0; op_idx < op_nums; ++op_idx) {
auto& instr = vec_instruction_[op_idx];
OpInOutInfo info;
std::set<size_t> gc_check_inputs;
for (auto& item : instr.Inputs()) {
for (auto id : item.second) {
if (id == kEmptyVarIndex) {
......@@ -439,10 +473,11 @@ void InterpreterCore::Convert(
}
}
}
for (auto var_id : gc_check_inputs) {
Scope* inner_scope =
create_local_scope_ ? local_scope_ : var_scope_.GetMutableScope();
paddle::framework::Variable* var =
local_scope_->FindVar(var_scope_.GetNameById(var_id));
inner_scope->FindVar(var_scope_.GetNameById(var_id));
if (var->IsType<LoDTensor>() || var->IsType<phi::SelectedRows>() ||
var->IsType<LoDTensorArray>()) {
last_live_ops_[var_id].insert(op_idx);
......@@ -453,7 +488,6 @@ void InterpreterCore::Convert(
}
}
}
for (size_t i = 0; i < vec_instruction_.size(); ++i) {
// checkout output
for (auto& item : vec_instruction_[i].Outputs()) {
......@@ -464,7 +498,6 @@ void InterpreterCore::Convert(
}
}
}
// clear the last_live_ops list for all vars in skip_gc_vars
for (const std::string& skip_gc_var : skip_gc_vars_) {
int var_id = var_scope_.GetIdByName(skip_gc_var);
......@@ -561,7 +594,6 @@ void InterpreterCore::BuildSkipShareLoDInfo() {
void InterpreterCore::RunInstruction(const Instruction& instr_node) {
auto* op = instr_node.OpBase();
auto place = instr_node.DeviceContext().GetPlace();
VLOG(4) << "Start run " << place << " " << op->DebugStringEx(local_scope_);
Scope* local_scope = create_local_scope_ ? var_scope_.GetMutableLocalScope()
: var_scope_.GetMutableScope();
......@@ -602,7 +634,6 @@ void InterpreterCore::RunInstruction(const Instruction& instr_node) {
*(instr_node.InnerRuntimeContext()));
}
}
if (op_with_kernel != nullptr && FLAGS_new_executor_use_inplace) {
// TODO(xiongkun03) Does operator base support inplace ?
for (auto& pair : instr_node.InplaceInfo()) {
......@@ -1009,7 +1040,6 @@ void InterpreterCore::Prepare(
"but received %d != %d",
feed_names.size(),
feed_tensors.size()));
auto FeedInput = [&] {
VLOG(4) << "Feed inputs";
for (size_t i = 0; i < feed_names.size(); ++i) {
......@@ -1035,7 +1065,8 @@ void InterpreterCore::Prepare(
skip_gc_vars_,
&op_func_nodes,
&var_scope_,
create_local_scope_);
create_local_scope_,
used_for_jit_);
is_build_ = true;
SetFeedVarsInplaceSkip(feed_names);
// convert vec func_list to graph
......
......@@ -41,7 +41,8 @@ class InterpreterCore {
InterpreterCore(const platform::Place& place,
const BlockDesc& block,
const std::set<std::string>& skip_gc_vars,
Scope* scope);
Scope* scope,
bool used_for_jit = false);
~InterpreterCore();
......@@ -59,6 +60,12 @@ class InterpreterCore {
void SetCopyProgram(std::shared_ptr<ProgramDesc> prog);
void SetSkipGcVars(const std::set<std::string>& skip_gc_vars);
const VariableScope* GetVariableScope() const;
void reset_scope(Scope* new_scope);
private:
bool BuildInplaceCheckVarIsOnlyInput(size_t var_index);
......@@ -103,9 +110,9 @@ class InterpreterCore {
bool is_build_;
const platform::Place& place_;
platform::Place place_;
const BlockDesc& block_; // not owned
const std::set<std::string> skip_gc_vars_;
std::set<std::string> skip_gc_vars_;
interpreter::DependencyBuilder dependency_builder_;
......@@ -144,6 +151,8 @@ class InterpreterCore {
std::future<std::unique_ptr<AtomicVectorSizeT>> atomic_deps_;
std::future<std::unique_ptr<AtomicVectorSizeT>> atomic_var_ref_;
bool used_for_jit_{false};
};
std::shared_ptr<InterpreterCore> CreateInterpreterCore(
......
......@@ -387,10 +387,8 @@ void deal_operator_base(const platform::Place& place,
PADDLE_THROW(
platform::errors::Fatal("Unsupported current place %s", place));
}
op_func_node->kernel_func_ = nullptr;
op_base->Run(*local_scope, place); // Run without data transformer.
std::unordered_set<int> no_data_transform_index;
for (auto& it : op_func_node->input_index) {
for (auto& id : it.second) {
......@@ -407,7 +405,8 @@ void build_op_func_list(const platform::Place& place,
const std::set<std::string>& skip_gc_vars,
std::vector<OpFuncNode>* vec_func_list,
VariableScope* var_scope,
bool use_local_scope) {
bool use_local_scope,
bool used_for_jit) {
Scope* local_scope = use_local_scope ? var_scope->GetMutableLocalScope()
: var_scope->GetMutableScope();
std::vector<std::unique_ptr<OperatorBase>>
......@@ -415,19 +414,21 @@ void build_op_func_list(const platform::Place& place,
bool flag_log_is_printed = false;
// Step 1: create all ops for current block.
create_all_ops(block, &ops_unique);
// If gc is enabled and block size > 1
const ProgramDesc& main_program = *block.Program();
operators::PrepareSafeEagerDeletionOnConditionalOpAndConditionalGradOp(
main_program, block.ID(), ops_unique);
operators::PrepareSafeEagerDeletionOnWhileOpAndWhileGradOp(
main_program, block.ID(), ops_unique);
operators::PrepareSafeEagerDeletionOnRecurrentOpAndRecurrentGradOp(
main_program, block.ID(), ops_unique);
if (!used_for_jit) {
// If gc is enabled and block size > 1
const ProgramDesc& main_program = *block.Program();
operators::PrepareSafeEagerDeletionOnConditionalOpAndConditionalGradOp(
main_program, block.ID(), ops_unique);
operators::PrepareSafeEagerDeletionOnWhileOpAndWhileGradOp(
main_program, block.ID(), ops_unique);
operators::PrepareSafeEagerDeletionOnRecurrentOpAndRecurrentGradOp(
main_program, block.ID(), ops_unique);
}
#ifdef PADDLE_WITH_MKLDNN
platform::RegisterModelLayout(ops_unique, place);
#endif
// its elements will be moved to vec_func_list
std::vector<std::shared_ptr<OperatorBase>> ops;
for (auto& op_unique : ops_unique) {
......@@ -484,157 +485,187 @@ void build_op_func_list(const platform::Place& place,
}
#endif
if (dynamic_cast<framework::OperatorWithKernel*>(op) == nullptr) {
// op is not an OperatorWithKernel, so directly run OperatorBase::Run()
deal_operator_base(place, var_scope, ops[i], &op_func_node, local_scope);
} else {
auto op_with_kernel = const_cast<framework::OperatorWithKernel*>(
static_cast<const framework::OperatorWithKernel*>(op));
// construct RuntimeContext and analysis KernelType
RuntimeContext runtime_context({}, {});
runtime_context.inputs.swap(ins_map);
runtime_context.outputs.swap(outs_map);
Scope scope, *runtime_scope = &scope;
// NOTE(Ruibiao): We do not encourage directly using scope in OP kernel.
// But some OPs do have such behavior (e.g., the cinn_launch OP), so they
// get special treatment here.
if (op_with_kernel->Type() == "cinn_launch") {
VLOG(6) << "OP(" << op_with_kernel->Type()
<< ") use scope in kernel, "
"so pass a real scope to "
"ExecutionContext";
runtime_scope = local_scope;
}
try {
if (dynamic_cast<framework::OperatorWithKernel*>(op) == nullptr) {
// op is not an OperatorWithKernel, so directly run OperatorBase::Run()
deal_operator_base(
place, var_scope, ops[i], &op_func_node, local_scope);
VLOG(4) << "deal_operator_base";
} else {
VLOG(4) << "OP is not null";
auto op_with_kernel = const_cast<framework::OperatorWithKernel*>(
static_cast<const framework::OperatorWithKernel*>(op));
VLOG(4) << "get op_with_kernel";
// construct RuntimeContext and analysis KernelType
RuntimeContext runtime_context({}, {});
runtime_context.inputs.swap(ins_map);
runtime_context.outputs.swap(outs_map);
VLOG(4) << "get RuntimeContext";
Scope scope, *runtime_scope = &scope;
// NOTE(Ruibiao): We do not encourage directly using scope in OP kernel.
// But some OPs do have such behavior (e.g., the cinn_launch OP), so they
// get special treatment here.
if (op_with_kernel->Type() == "cinn_launch") {
VLOG(6) << "OP(" << op_with_kernel->Type()
<< ") use scope in kernel, "
"so pass a real scope to "
"ExecutionContext";
runtime_scope = local_scope;
}
auto& pool = platform::DeviceContextPool::Instance();
auto* dev_ctx = pool.Get(place);
auto exec_ctx = ExecutionContext(
*op_with_kernel, *runtime_scope, *dev_ctx, runtime_context);
auto expected_kernel_key =
op_with_kernel->GetExpectedKernelType(exec_ctx);
// change device by the device_guard()
apply_device_guard(op, place, &expected_kernel_key);
VLOG(4) << "expected_kernel_key : " << expected_kernel_key;
// step 2. select op kernel
auto run_phi_kernel = false;
if (phi::KernelFactory::Instance().HasCompatiblePhiKernel(
op_with_kernel->Type())) {
auto phi_kernel_key = op_with_kernel->ChoosePhiKernel(exec_ctx);
auto phi_kernel_name = op_with_kernel->PhiKernelSignature()->name;
if (op_with_kernel->PhiKernel()->IsValid()) {
run_phi_kernel = true;
} else {
if (!op_with_kernel->SupportsKernelType(expected_kernel_key)) {
auto phi_cpu_kernel_key = FallBackToCpu(
expected_kernel_key, phi_kernel_key, *op_with_kernel);
op_with_kernel->ResetPhiKernel(
new phi::Kernel(phi::KernelFactory::Instance().SelectKernel(
phi_kernel_name, phi_cpu_kernel_key)));
if (op_with_kernel->PhiKernel()->IsValid()) {
VLOG(6) << "Static mode PrepareImpl - kernel name: "
<< phi_kernel_name
<< " | kernel key: " << phi_cpu_kernel_key
<< " | kernel: " << *(op_with_kernel->PhiKernel());
op_with_kernel->ResetKernelType(new OpKernelType(
TransPhiKernelKeyToOpKernelType(phi_cpu_kernel_key)));
run_phi_kernel = true;
auto& pool = platform::DeviceContextPool::Instance();
auto* dev_ctx = pool.Get(place);
VLOG(4) << "get dev_ctx";
auto exec_ctx = ExecutionContext(
*op_with_kernel, *runtime_scope, *dev_ctx, runtime_context);
VLOG(4) << "get exec_ctx";
auto expected_kernel_key =
op_with_kernel->GetExpectedKernelType(exec_ctx);
VLOG(4) << "get expected_kernel_key";
// change device by the device_guard()
apply_device_guard(op, place, &expected_kernel_key);
VLOG(4) << "expected_kernel_key : " << expected_kernel_key;
// step 2. select op kernel
auto run_phi_kernel = false;
if (phi::KernelFactory::Instance().HasCompatiblePhiKernel(
op_with_kernel->Type())) {
auto phi_kernel_key = op_with_kernel->ChoosePhiKernel(exec_ctx);
auto phi_kernel_name = op_with_kernel->PhiKernelSignature()->name;
if (op_with_kernel->PhiKernel()->IsValid()) {
run_phi_kernel = true;
} else {
if (!op_with_kernel->SupportsKernelType(expected_kernel_key)) {
auto phi_cpu_kernel_key = FallBackToCpu(
expected_kernel_key, phi_kernel_key, *op_with_kernel);
op_with_kernel->ResetPhiKernel(
new phi::Kernel(phi::KernelFactory::Instance().SelectKernel(
phi_kernel_name, phi_cpu_kernel_key)));
if (op_with_kernel->PhiKernel()->IsValid()) {
VLOG(6) << "Static mode PrepareImpl - kernel name: "
<< phi_kernel_name
<< " | kernel key: " << phi_cpu_kernel_key
<< " | kernel: " << *(op_with_kernel->PhiKernel());
op_with_kernel->ResetKernelType(new OpKernelType(
TransPhiKernelKeyToOpKernelType(phi_cpu_kernel_key)));
run_phi_kernel = true;
}
}
}
}
}
if (!run_phi_kernel) {
op_with_kernel->ChooseKernel(exec_ctx);
op_func_node.kernel_func_ = *op_with_kernel->kernel_func();
} else {
op_func_node.phi_kernel_ = op_with_kernel->PhiKernel();
}
auto kernel_type = *(op_with_kernel->kernel_type());
if (kernel_type.place_ != dev_ctx->GetPlace()) {
dev_ctx = pool.Get(kernel_type.place_);
}
op_func_node.dev_ctx_ = dev_ctx;
if (IsSupportedHetePlace(kernel_type.place_)) {
op_func_node.type_ = OpFuncType::kQueueAsync;
} else if (platform::is_cpu_place(kernel_type.place_)) {
op_func_node.type_ = OpFuncType::kQueueSync;
} else {
PADDLE_THROW(platform::errors::Fatal("Unsupported current place %s",
kernel_type.place_));
}
VLOG(3) << op_with_kernel->Type()
<< " : finally selected kernel_key: " << kernel_type;
// step 3. data transform
VariableValueMap& ins_map_temp = runtime_context.inputs;
VariableValueMap& outs_map_temp = runtime_context.outputs;
ApplyDataTransform(kernel_type,
place,
&ins_map_temp,
&outs_map_temp,
var_scope,
&op_func_node,
vec_func_list,
use_local_scope);
// step 4. infershape, see OperatorWithKernel::RunImpl in operator.cc for
// why.
if (!(op->HasAttr(kAllKernelsMustComputeRuntimeShape) &&
op->Attr<bool>(kAllKernelsMustComputeRuntimeShape))) {
InterpretercoreInferShapeContext infer_shape_ctx(*op, runtime_context);
// TODO(Aurelius84): In case of control flow ops, they are NOT
// inherited from OperatorWithKernel.
op_with_kernel->Info().infer_shape_(&infer_shape_ctx);
}
VLOG(4) << "if run phi kernel? : " << run_phi_kernel;
if (!run_phi_kernel) {
op_with_kernel->ChooseKernel(exec_ctx);
op_func_node.kernel_func_ = *op_with_kernel->kernel_func();
} else {
op_func_node.phi_kernel_ = op_with_kernel->PhiKernel();
}
auto kernel_type = *(op_with_kernel->kernel_type());
if (kernel_type.place_ != dev_ctx->GetPlace()) {
dev_ctx = pool.Get(kernel_type.place_);
}
op_func_node.dev_ctx_ = dev_ctx;
if (IsSupportedHetePlace(kernel_type.place_)) {
op_func_node.type_ = OpFuncType::kQueueAsync;
} else if (platform::is_cpu_place(kernel_type.place_)) {
op_func_node.type_ = OpFuncType::kQueueSync;
} else {
PADDLE_THROW(platform::errors::Fatal("Unsupported current place %s",
kernel_type.place_));
}
VLOG(3) << op_with_kernel->Type()
<< " : finally selected kernel_key: " << kernel_type;
// step 3. data transform
VariableValueMap& ins_map_temp = runtime_context.inputs;
VariableValueMap& outs_map_temp = runtime_context.outputs;
ApplyDataTransform(kernel_type,
place,
&ins_map_temp,
&outs_map_temp,
var_scope,
&op_func_node,
vec_func_list,
use_local_scope);
VLOG(4) << "apply data transform done. ";
// step 4. infershape, see OperatorWithKernel::RunImpl in operator.cc
// for why.
if (!(op->HasAttr(kAllKernelsMustComputeRuntimeShape) &&
op->Attr<bool>(kAllKernelsMustComputeRuntimeShape))) {
InterpretercoreInferShapeContext infer_shape_ctx(*op,
runtime_context);
// TODO(Aurelius84): In case of control flow ops, they are NOT
// inherited from OperatorWithKernel.
op_with_kernel->Info().infer_shape_(&infer_shape_ctx);
}
// step 5. run kernel
if (run_phi_kernel) {
phi::KernelContext phi_kernel_context;
op_with_kernel->BuildPhiKernelContext(
runtime_context, dev_ctx, &phi_kernel_context);
(*op_func_node.phi_kernel_)(&phi_kernel_context);
} else {
// the place of exec_ctx maybe has changed.
op_func_node.kernel_func_(ExecutionContext(
*op_with_kernel, *runtime_scope, *dev_ctx, runtime_context));
}
// step 5. run kernel
if (run_phi_kernel) {
VLOG(1) << "start run phi kernel. ";
phi::KernelContext phi_kernel_context;
op_with_kernel->BuildPhiKernelContext(
runtime_context, dev_ctx, &phi_kernel_context);
(*op_func_node.phi_kernel_)(&phi_kernel_context);
VLOG(1) << "end run phi kernel. ";
} else {
VLOG(4) << "start run kernel. ";
// the place of exec_ctx maybe has changed.
op_func_node.kernel_func_(ExecutionContext(
*op_with_kernel, *runtime_scope, *dev_ctx, runtime_context));
VLOG(4) << "end run kernel. ";
}
// post-process grad_op.outputs if need cast complex grad into real
// grad.
// NOTE(Aurelius84): insert a transfer_dtype_op inplacely to cast it.
if (framework::IsComplexType(kernel_type.data_type_)) {
interpreter::HandleComplexGradToRealGrad(op_func_node,
place,
outputs_names,
&runtime_context.outputs,
var_scope,
vec_func_list,
local_scope);
}
if (!op_func_node.inplace_back_map.empty()) {
auto& m = op_func_node.inplace_back_map;
// NOTE(zhiqiu): same logic as TransferInplaceVarsBack() in
// operator.cc
for (auto& p : m) {
auto* transformed_tensor =
GetMutableLoDTensorOrSelectedRowsValueFromVar(
local_scope->FindVar(var_scope->GetNameById(p.first)));
auto* original_tensor = GetMutableLoDTensorOrSelectedRowsValueFromVar(
local_scope->FindVar(var_scope->GetNameById(p.second)));
original_tensor->ShareDataWith(*transformed_tensor);
VLOG(4) << "Transfer inplace variable back form "
<< var_scope->GetNameById(p.first) << " to "
<< var_scope->GetNameById(p.second);
// post-process grad_op.outputs if need cast complex grad into real
// grad.
// NOTE(Aurelius84): insert a transfer_dtype_op inplacely to cast it.
if (framework::IsComplexType(kernel_type.data_type_)) {
interpreter::HandleComplexGradToRealGrad(op_func_node,
place,
outputs_names,
&runtime_context.outputs,
var_scope,
vec_func_list,
local_scope);
}
if (!op_func_node.inplace_back_map.empty()) {
auto& m = op_func_node.inplace_back_map;
// NOTE(zhiqiu): same logic as TransferInplaceVarsBack() in
// operator.cc
for (auto& p : m) {
auto* transformed_tensor =
GetMutableLoDTensorOrSelectedRowsValueFromVar(
local_scope->FindVar(var_scope->GetNameById(p.first)));
auto* original_tensor =
GetMutableLoDTensorOrSelectedRowsValueFromVar(
local_scope->FindVar(var_scope->GetNameById(p.second)));
original_tensor->ShareDataWith(*transformed_tensor);
VLOG(4) << "Transfer inplace variable back form "
<< var_scope->GetNameById(p.first) << " to "
<< var_scope->GetNameById(p.second);
}
}
}
// for debug nan/inf
if (FLAGS_check_nan_inf) {
VLOG(4) << "Check nan/inf";
framework::details::CheckOpHasNanOrInf(*op, *runtime_scope, place);
// for debug nan/inf
if (FLAGS_check_nan_inf) {
VLOG(4) << "Check nan/inf";
framework::details::CheckOpHasNanOrInf(*op, *runtime_scope, place);
}
}
} catch (platform::EnforceNotMet& ex) {
framework::InsertCallStackInfo(op->Type(), op->Attrs(), &ex);
throw std::move(ex);
} catch (platform::EOFException&) {
std::rethrow_exception(std::current_exception());
} catch (std::exception& ex) {
LOG(WARNING) << op->Type() << " raises an exception "
<< platform::demangle(typeid(ex).name()) << ", "
<< ex.what();
std::rethrow_exception(std::current_exception());
} catch (...) {
LOG(WARNING) << op->Type() << " raises an unknown exception";
std::rethrow_exception(std::current_exception());
}
VLOG(4) << "End run " << place << " "
......@@ -662,20 +693,6 @@ void build_op_func_list(const platform::Place& place,
if (var->IsType<LoDTensor>()) {
garbages->emplace_back(
var->GetMutable<LoDTensor>()->MoveMemoryHolder());
} else if (var->IsType<phi::SelectedRows>()) {
garbages->emplace_back(var->GetMutable<phi::SelectedRows>()
->mutable_value()
->MoveMemoryHolder());
} else if (var->IsType<LoDTensorArray>()) {
auto* lod_tensor_arr = var->GetMutable<LoDTensorArray>();
for (auto& t : *lod_tensor_arr) {
garbages->emplace_back(t.MoveMemoryHolder());
}
} else {
PADDLE_THROW(platform::errors::Unimplemented(
"Type %s of variable %s is not supported eager deletion.",
framework::ToTypeName(var->Type()),
var_name));
}
}
delete garbages; // free mem
......
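The per-op section of build_op_func_list is now wrapped in a try/catch chain that attaches the op's call stack to EnforceNotMet errors and rethrows every other exception unchanged. A generic, standalone sketch of that pattern, with std::runtime_error standing in for the Paddle-specific exception types:

#include <exception>
#include <iostream>
#include <stdexcept>
#include <string>

void run_op(const std::string& op_type) {
  if (op_type == "bad_op") throw std::runtime_error("kernel failed");
}

void build_one_op(const std::string& op_type) {
  try {
    run_op(op_type);
  } catch (std::runtime_error& ex) {
    // Mirrors InsertCallStackInfo(...) followed by a rethrow: enrich, rethrow.
    throw std::runtime_error("[op " + op_type + "] " + ex.what());
  } catch (...) {
    std::cerr << op_type << " raises an unknown exception\n";
    std::rethrow_exception(std::current_exception());
  }
}

int main() {
  try {
    build_one_op("good_op");
    build_one_op("bad_op");
  } catch (const std::exception& ex) {
    std::cout << ex.what() << "\n";  // prints "[op bad_op] kernel failed"
  }
  return 0;
}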
......@@ -80,7 +80,8 @@ void build_op_func_list(const platform::Place& place,
const std::set<std::string>& skip_gc_vars,
std::vector<OpFuncNode>* vec_func_list,
VariableScope* scope,
bool use_local_scope = true);
bool use_local_scope = true,
bool used_for_jit = false);
void add_fetch(const std::vector<std::string>& fetch_names,
framework::BlockDesc* block);
......
......@@ -577,6 +577,8 @@ Scope* VariableScope::GetMutableScope() const { return scope_; }
Scope* VariableScope::GetMutableLocalScope() const { return local_scope_; }
void VariableScope::SetScope(Scope* scope) { scope_ = scope; }
void VariableScope::SetLocalScope(Scope* local_scope) {
VLOG(4) << "Set local scope: " << local_scope;
local_scope_ = local_scope;
......@@ -626,7 +628,11 @@ void VariableScope::AddVar(const std::string& name,
auto id = VarSize();
name2id_[name] = id;
vec_meta_info_.emplace_back(0, var_desc);
var_list_.push_back(local_scope_->FindVar(name));
if (local_scope_ != nullptr) {
var_list_.push_back(local_scope_->FindVar(name));
} else {
var_list_.push_back(scope_->FindVar(name));
}
PADDLE_ENFORCE_EQ(
var_list_.size(),
name2id_.size(),
......@@ -783,6 +789,8 @@ void Instruction::AddInplace(Variable* in, Variable* out) {
vec_inplace_in_to_out_.emplace_back(in, out);
}
void Instruction::ClearInplace() { vec_inplace_in_to_out_.clear(); }
const std::vector<EventInter>& Instruction::InputEvents() const {
return intput_events_;
}
......
......@@ -176,6 +176,8 @@ class VariableScope {
Scope* GetMutableLocalScope() const;
void SetScope(Scope* scope);
void SetLocalScope(Scope* local_scope);
~VariableScope();
......@@ -212,6 +214,17 @@ class VariableScope {
return vec_meta_info_;
}
const std::vector<std::pair<std::string, int>>& DataTransferAddedVars()
const {
return data_transfer_added_vars_;
}
std::vector<std::pair<std::string, int>>& MutableDataTransferAddedVars() {
return data_transfer_added_vars_;
}
std::vector<Variable*>& MutableVarList() { return var_list_; }
void SetVarSikpInplace(const std::string& name, bool skip);
bool GetVarSikpInplace(int id) const;
......@@ -228,6 +241,9 @@ class VariableScope {
// TODO(zhiqiu): find a better way to support local scope.
Scope* local_scope_{nullptr};
// mutable RWLock vars_lock_;
// var_name -> var_type
std::vector<std::pair<std::string, int>> data_transfer_added_vars_;
};
class NextInstruction {
......@@ -340,6 +356,8 @@ class Instruction {
void AddInplace(Variable* in, Variable* out);
void ClearInplace();
const std::vector<EventInter>& InputEvents() const;
const std::vector<EventInter>& OutputEvents() const;
......
......@@ -115,7 +115,10 @@ const Scope* Scope::FindScope(const std::string& name) const {
void Scope::DropKids() {
{
SCOPE_KIDS_WRITER_LOCK
for (Scope* s : kids_) delete s;
for (Scope* s : kids_) {
delete s;
s = nullptr;
}
kids_.clear();
}
}
......
......@@ -119,6 +119,17 @@ class RunProgramOpMaker : public framework::OpProtoAndCheckerMaker {
AddAttr<int64_t>("cuda_graph_pool_id",
"(int64_t, default 0) The CUDA Graph memory pool ID.")
.SetDefault(0);
AddAttr<bool>("use_interpretorcore",
"(bool, default false) Set to true for use interpretercore.")
.SetDefault(false);
AddAttr<BlockDesc*>("forward_global_block",
"(BlockDesc *)"
"The global block of executed forward program desc.")
.SetDefault(nullptr);
AddAttr<BlockDesc*>("backward_global_block",
"(BlockDesc *)"
"The global block of executed backward program desc.")
.SetDefault(nullptr);
AddComment(R"DOC(
RunProgram operator.
......
......@@ -67,6 +67,7 @@ void BindGraph(py::module *m) {
"The graph is a Directed Acyclic Single Static Assignment Graph, see "
"`paddle::ir::Graph` for details.")
.def(py::init<const ProgramDesc &>())
.def(py::init<const ProgramDesc &, int64_t, int64_t>())
.def("clone", &Graph::Clone)
.def("has", &Graph::Has)
.def("get_bool", &Graph::Get<bool>)
......
......@@ -2147,8 +2147,14 @@ All parameter, weight, gradient are variables in Paddle.
m.def("set_cudnn_switch", platform::SetAllowTF32Cudnn);
m.def("get_cudnn_switch", platform::AllowTF32Cudnn);
#endif // PADDLE_WITH_CUDA
m.def("clear_executor_cache",
[]() { framework::ExecutorInfoCache::Instance().Finalize(); });
m.def("clear_executor_cache", []() {
pybind11::gil_scoped_release release;
framework::ExecutorInfoCache::Instance().Finalize();
framework::InterpreterCoreInfoCache::Instance().Finalize();
});
m.def("parse_safe_eager_deletion_skip_vars",
paddle::framework::details::ParseSafeEagerDeletionSkipVarsSet);
#ifdef PADDLE_WITH_IPU
py::class_<platform::ipu::IpuBackend,
......
......@@ -365,30 +365,28 @@ def replace_cuda_graph_section(ins_and_outs, section_program, section_idx,
program_id = _hash_with_id(section_program, ins_and_outs)
# insert the run_program_op into the block
origin_block._insert_op(insert_idx,
type='run_program',
inputs={'X': ins},
outputs={
'Out': outs,
'OutScope': out_scope_var,
'CUDAGraph': cuda_graph_var
},
attrs={
'global_block':
section_program.global_block(),
'start_op_index':
0,
'end_op_index':
len(section_program.global_block().ops),
'is_test':
is_test,
'program_id':
program_id,
'cuda_graph_capture_mode':
mode,
'cuda_graph_pool_id':
memory_pool_id,
})
origin_block._insert_op(
insert_idx,
type='run_program',
inputs={'X': ins},
outputs={
'Out': outs,
'OutScope': out_scope_var,
'CUDAGraph': cuda_graph_var
},
attrs={
'global_block': section_program.global_block(),
'start_op_index': 0,
'end_op_index': len(section_program.global_block().ops),
'is_test': is_test,
'program_id': program_id,
'cuda_graph_capture_mode': mode,
'cuda_graph_pool_id': memory_pool_id,
# TODO: using InterpreterCore here is not supported yet
'use_interpretorcore': False,
'forward_global_block': section_program.global_block(),
'backward_global_block': section_program.global_block(),
})
def cuda_graph_transform(program):
......
......@@ -18,6 +18,7 @@ import six
import paddle
from paddle.fluid import framework, backward, core, program_guard
from paddle.fluid.executor import _is_enable_standalone_executor, _is_dy2st_enable_standalone_executor
from paddle.fluid.dygraph import layers
from paddle.fluid.dygraph.base import switch_to_static_graph
from paddle.fluid.dygraph.dygraph_to_static import logging_utils
......@@ -26,6 +27,7 @@ from paddle.fluid.layers.utils import flatten
from paddle.fluid.layers.utils import pack_sequence_as
from paddle.fluid.layers.utils import _hash_with_id
from paddle.fluid.compiler import BuildStrategy
from paddle.fluid.framework import _apply_pass
from paddle.fluid.contrib.mixed_precision.decorator import AutoMixedPrecisionLists
from paddle.fluid.contrib.mixed_precision.fp16_utils import rewrite_program, cast_model_to_fp16
from paddle.fluid.dygraph.amp.auto_cast import _in_amp_guard, _in_pure_fp16_guard
......@@ -175,106 +177,209 @@ class PartialProgramLayer:
def _double_grads(self):
return self._get_double_grads(self._origin_main_program)
@LazyInitialized
def _infer_program(self):
"""
Lazy initialized property of infer_program.
"""
return self._clone_for_test(self._origin_main_program)
# whole
@switch_to_static_graph
def _create_program(self, is_infer_mode=False):
if is_infer_mode:
return self._origin_main_program.clone(for_test=is_infer_mode)
else:
train_program = self._append_backward_desc(
self._origin_main_program)
# Note: Only set grad type once after initializing train program. So we put it here.
self._set_grad_type(self._params, train_program)
return train_program
@LazyInitialized
def _train_program(self):
"""
Lazy initialized property of train_program.
"""
train_program = self._append_backward_desc(self._origin_main_program)
# Note: Only set grad type once after initializing train program. So we
# put it here.
self._set_grad_type(self._params, train_program)
@switch_to_static_graph
def _create_amp_program(self, is_infer_mode=False):
amp_program = self._origin_main_program.clone(for_test=is_infer_mode)
with program_guard(amp_program):
rewrite_program(amp_program, self._amp_list)
if is_infer_mode:
return amp_program
else:
train_amp_program = self._append_backward_desc(amp_program)
self._set_grad_type(self._params, train_amp_program)
return train_amp_program
return train_program
@switch_to_static_graph
def _create_pure_fp16_program(self, is_infer_mode=False):
pure_fp16_program = self._origin_main_program.clone(
for_test=is_infer_mode)
with program_guard(pure_fp16_program):
cast_model_to_fp16(pure_fp16_program,
self._amp_list,
use_fp16_guard=False)
if is_infer_mode:
return pure_fp16_program
else:
train_pure_fp16_program = self._append_backward_desc(
pure_fp16_program)
self._set_grad_type(self._params, train_pure_fp16_program)
return train_pure_fp16_program
@LazyInitialized
@switch_to_static_graph
def _infer_amp_program(self):
"""
Lazy initialized property of infer_amp_program.
"""
infer_amp_program = self._origin_main_program.clone()
with program_guard(infer_amp_program):
rewrite_program(infer_amp_program, self._amp_list)
def _create_forward_backward_train_program(self):
whole_program = self._create_program()
forward_end_op_index = self._infer_program.desc.block(0).op_size()
return self._get_forward_backward_program_form(whole_program,
forward_end_op_index)
return infer_amp_program
@switch_to_static_graph
def _create_forward_backward_train_amp_program(self):
whole_program = self._create_amp_program()
forward_end_op_index = self._infer_amp_program.desc.block(0).op_size()
return self._get_forward_backward_program_form(whole_program,
forward_end_op_index)
@switch_to_static_graph
def _create_forward_backward_train_pure_fp16_program(self):
whole_program = self._create_pure_fp16_program()
forward_end_op_index = self._infer_pure_fp16_program.desc.block(
0).op_size()
return self._get_forward_backward_program_form(whole_program,
forward_end_op_index)
@LazyInitialized
def _train_amp_program(self):
"""
Lazy initialized property of train_amp_program.
"""
train_amp_program = self._append_backward_desc(self._infer_amp_program)
self._set_grad_type(self._params, train_amp_program)
return train_amp_program
def _train_program(self):
return self._create_program()
@LazyInitialized
@switch_to_static_graph
def _infer_pure_fp16_program(self):
"""
Lazy initialized property of _infer_pure_fp16_program.
"""
infer_pure_fp16_program = self._origin_main_program.clone()
with program_guard(infer_pure_fp16_program):
cast_model_to_fp16(infer_pure_fp16_program,
self._amp_list,
use_fp16_guard=False)
def _infer_program(self):
return self._create_program(is_infer_mode=True)
return infer_pure_fp16_program
@LazyInitialized
def _train_amp_program(self):
return self._create_amp_program()
@LazyInitialized
def _infer_amp_program(self):
return self._create_amp_program(is_infer_mode=True)
@LazyInitialized
def _train_pure_fp16_program(self):
"""
Lazy initialized property of _train_pure_fp16_program.
"""
train_pure_fp16_program = self._append_backward_desc(
self._infer_pure_fp16_program)
self._set_grad_type(self._params, train_pure_fp16_program)
return train_pure_fp16_program
return self._create_pure_fp16_program()
@LazyInitialized
def _infer_program_id(self):
return _hash_with_id(self._infer_program, self)
def _infer_pure_fp16_program(self):
return self._create_pure_fp16_program(is_infer_mode=True)
@LazyInitialized
def _infer_pure_fp16_program_id(self):
return _hash_with_id(self._infer_pure_fp16_program, self)
def _train_forward_backward_program(self):
program = self._create_forward_backward_train_program()
return program
@LazyInitialized
def _infer_amp_program_id(self):
return _hash_with_id(self._infer_amp_program, self)
def _train_amp_forward_backward_program(self):
program = self._create_forward_backward_train_amp_program()
return program
@LazyInitialized
def _train_pure_fp16_forward_backward_program(self):
program = self._create_forward_backward_train_pure_fp16_program()
return program
@property
def whole_program(self):
if self.training:
if _in_amp_guard():
return self._train_amp_program
elif _in_pure_fp16_guard():
return self._train_pure_fp16_program
else:
return self._train_program
else:
if _in_amp_guard():
return self._infer_amp_program
elif _in_pure_fp16_guard():
return self._infer_pure_fp16_program
else:
return self._infer_program
@property
def forward_program(self):
if self.training:
if _in_amp_guard():
program = self._train_amp_forward_backward_program
return program[0]
elif _in_pure_fp16_guard():
program = self._train_pure_fp16_forward_backward_program
return program[0]
else:
program = self._train_forward_backward_program
return program[0]
else:
if _in_amp_guard():
return self._infer_amp_program
elif _in_pure_fp16_guard():
return self._infer_pure_fp16_program
else:
return self._infer_program
@property
def backward_program(self):
if self.training:
if _in_amp_guard():
program = self._train_amp_forward_backward_program
return program[1]
elif _in_pure_fp16_guard():
program = self._train_pure_fp16_forward_backward_program
return program[1]
else:
program = self._train_forward_backward_program
return program[1]
else:
return paddle.static.Program()
@LazyInitialized
def _train_program_id(self):
program_id = _hash_with_id(self._train_program, self)
core._set_cached_executor_build_strategy(program_id,
self._build_strategy)
return program_id
@LazyInitialized
def _infer_program_id(self):
return _hash_with_id(self._infer_program, self)
@LazyInitialized
def _train_amp_program_id(self):
program_id = _hash_with_id(self._train_amp_program, self)
core._set_cached_executor_build_strategy(program_id,
self._build_strategy)
return program_id
@LazyInitialized
def _infer_amp_program_id(self):
return _hash_with_id(self._infer_amp_program, self)
@LazyInitialized
def _train_pure_fp16_program_id(self):
program_id = _hash_with_id(self._train_pure_fp16_program, self)
core._set_cached_executor_build_strategy(program_id,
self._build_strategy)
return program_id
@LazyInitialized
def _infer_pure_fp16_program_id(self):
return _hash_with_id(self._infer_pure_fp16_program, self)
@property
def whole_program_id(self):
if self.training:
if _in_amp_guard():
return self._train_amp_program_id
elif _in_pure_fp16_guard():
return self._train_pure_fp16_program_id
else:
return self._train_program_id
else:
if _in_amp_guard():
return self._infer_amp_program_id
elif _in_pure_fp16_guard():
return self._infer_pure_fp16_program_id
else:
return self._infer_program_id
def _verify_program(self, main_program):
"""
Verify that the program parameter is initialized, prune some unused params,
......@@ -429,6 +534,8 @@ class PartialProgramLayer:
def __call__(self, inputs):
in_vars, out_vars = self._prepare(inputs)
self._cast_fp16_if_pure_fp16(in_vars)
attrs = [
'global_block',
self.program.desc.block(0), 'start_op_index', 0, 'end_op_index',
......@@ -440,7 +547,13 @@ class PartialProgramLayer:
('cuda_graph_capture_mode', self._cuda_graph_capture_mode,
'cuda_graph_pool_id', self._cuda_graph_pool_id))
self._cast_fp16_if_pure_fp16(in_vars)
use_interpretorcore = _is_enable_standalone_executor(
) and _is_dy2st_enable_standalone_executor()
attrs.extend(('use_interpretorcore', use_interpretorcore))
if use_interpretorcore:
attrs.extend(
('forward_global_block', self.forward_program.desc.block(0),
'backward_global_block', self.backward_program.desc.block(0)))
_legacy_C_ops.run_program(self._valid_vars(in_vars),
self._valid_vars(self._params),
......@@ -459,30 +572,24 @@ class PartialProgramLayer:
== paddle.float16):
in_vars[i] = var.astype('float16')
in_vars[i].name = name
if (self.forward_program.global_block().has_var(name)
and self.forward_program.global_block().var(name).dtype
== paddle.float16):
in_vars[i] = var.astype('float16')
in_vars[i].name = name
if (self.backward_program.global_block().has_var(name)
and self.backward_program.global_block().var(name).dtype
== paddle.float16):
in_vars[i] = var.astype('float16')
in_vars[i].name = name
@property
def program(self):
if self.training:
return self.train_program
else:
return self.infer_program
return self.whole_program
@property
def program_id(self):
if self.training:
if _in_amp_guard():
return self._train_amp_program_id
elif _in_pure_fp16_guard():
return self._train_pure_fp16_program_id
else:
return self._train_program_id
else:
if _in_amp_guard():
return self._infer_amp_program_id
elif _in_pure_fp16_guard():
return self._infer_pure_fp16_program_id
else:
return self._infer_program_id
return self.whole_program_id
@property
def train_program(self):
......@@ -502,6 +609,64 @@ class PartialProgramLayer:
else:
return self._infer_program
@switch_to_static_graph
def _get_forward_backward_program_form(self, whole_program,
forward_end_op_index):
forward_builded_program = add_build_strategy_for(
whole_program, 0, forward_end_op_index, self._build_strategy)
backward_start_op_index = forward_end_op_index + 2 * len(
self._outputs.var_ids)
backward_end_op_index = whole_program.desc.block(0).op_size()
backward_builded_program = add_build_strategy_for(
whole_program, backward_start_op_index, backward_end_op_index,
self._build_strategy)
self._apply_inplace_pass(forward_builded_program,
backward_builded_program)
return [forward_builded_program, backward_builded_program]
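A minimal sketch of the op-index arithmetic used above (editor's illustration with made-up numbers; the `2 * len(self._outputs.var_ids)` gap is assumed to cover the per-output ops inserted between the forward and backward segments by `_append_backward_desc`):

```python
# Hypothetical op counts, purely to show how the whole program is sliced.
forward_end_op_index = 12                  # ops [0, 12) -> forward program
num_outputs = 2                            # stands in for len(self._outputs.var_ids)
backward_start_op_index = forward_end_op_index + 2 * num_outputs   # 16
backward_end_op_index = 25                 # whole_program.desc.block(0).op_size()

print("forward ops:", list(range(0, forward_end_op_index)))
print("backward ops:", list(range(backward_start_op_index, backward_end_op_index)))
```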
def _apply_inplace_pass(self, forward_program, backward_program):
attr_types = {
"use_cuda": "bool",
"mem_opt_skip_vars": "list[str]",
"for_partial_block": "bool"
}
empty_startup_program = paddle.static.Program()
use_cuda = True if core.is_compiled_with_cuda() else False
# skip data vars
forward_mem_opt_skip_vars = []
backward_mem_opt_skip_vars = []
for var_name, var in forward_program.global_block().vars.items():
if var.is_data:
forward_mem_opt_skip_vars.append(var_name)
for var_name, var in backward_program.global_block().vars.items():
if var.is_data:
backward_mem_opt_skip_vars.append(var_name)
for var in self._inputs:
if isinstance(var, paddle.fluid.framework.Variable):
forward_mem_opt_skip_vars.append(var.desc.name())
backward_mem_opt_skip_vars.append(var.desc.name())
for var in self._outputs:
if isinstance(var, paddle.fluid.framework.Variable):
forward_mem_opt_skip_vars.append(var.desc.name())
backward_mem_opt_skip_vars.append(var.desc.name())
for var_name in core.parse_safe_eager_deletion_skip_vars(
backward_program.desc):
forward_mem_opt_skip_vars.append(var_name)
attrs = {
"use_cuda": use_cuda,
"mem_opt_skip_vars": forward_mem_opt_skip_vars,
"for_partial_block": True
}
_apply_pass(forward_program, empty_startup_program,
"buffer_shared_inplace_pass", attrs, attr_types)
attrs = {
"use_cuda": use_cuda,
"mem_opt_skip_vars": backward_mem_opt_skip_vars,
"for_partial_block": True
}
_apply_pass(backward_program, empty_startup_program,
"buffer_shared_inplace_pass", attrs, attr_types)
def _prepare(self, inputs):
"""
Prepare inputs, outputs, attrs.
......@@ -739,3 +904,23 @@ def partial_program_from(concrete_program):
concrete_program.outputs,
concrete_program.parameters,
**concrete_program.kwargs)
@switch_to_static_graph
def add_build_strategy_for(program,
start_op_index,
end_op_index,
build_strategy=None):
if (start_op_index < end_op_index):
compiled_program = paddle.static.CompiledProgram(
core.Graph(program.desc, start_op_index, end_op_index),
build_strategy=build_strategy)
compiled_program._compile(core.Scope(),
framework._current_expected_place())
ir_graph = framework.IrGraph(compiled_program._graph)
builded_program = ir_graph.to_program()
if hasattr(compiled_program._program, 'lr_sheduler'):
builded_program.lr_sheduler = compiled_program._program.lr_sheduler
else:
builded_program = program
return builded_program
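A possible usage sketch for `add_build_strategy_for` (editor's illustration, not part of the commit; the toy program and names are assumptions):

```python
import paddle

paddle.enable_static()
main_program = paddle.static.Program()
with paddle.static.program_guard(main_program):
    x = paddle.static.data(name='x', shape=[None, 4], dtype='float32')
    y = paddle.scale(x, scale=2.0)
    out = paddle.mean(y)

# Re-build only the ops in [0, fwd_op_num) into a standalone program.
fwd_op_num = main_program.desc.block(0).op_size()
forward_only_program = add_build_strategy_for(main_program, 0, fwd_op_num)
```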
......@@ -30,6 +30,8 @@ from paddle.fluid.layers import nn
from paddle.fluid.layers.utils import _hash_with_id
from paddle.fluid.dygraph.base import switch_to_static_graph
from paddle.fluid.framework import _non_static_mode
from paddle.fluid.executor import _is_enable_standalone_executor, _is_dy2st_enable_standalone_executor
from paddle.fluid.dygraph.dygraph_to_static.partial_program import add_build_strategy_for, LazyInitialized
from paddle import _C_ops, _legacy_C_ops
__all__ = ['TranslatedLayer']
......@@ -333,6 +335,37 @@ class _ProgramHolder(object):
self._train_program_desc = self._append_backward_desc(
self._infer_program_desc)
# forward:
@switch_to_static_graph
def _create_forward_train_program(self):
whole_program = _build_program_by_desc(self._train_program_desc)
end_op_index = self._infer_program_desc.block(0).op_size()
if end_op_index > 0:
return add_build_strategy_for(whole_program, 0, end_op_index)
else:
return whole_program
@LazyInitialized
def _forward_program_desc(self):
return self._create_forward_train_program().desc
# backward
@switch_to_static_graph
def _create_backward_train_program(self):
whole_program = _build_program_by_desc(self._train_program_desc)
start_op_index = self._infer_program_desc.block(0).op_size() + 2 * len(
self._output_descs)
end_op_index = whole_program.desc.block(0).op_size()
if (start_op_index < end_op_index):
return add_build_strategy_for(whole_program, start_op_index,
end_op_index)
else:
return paddle.static.Program()
@LazyInitialized
def _backward_program_desc(self):
return self._create_backward_train_program().desc
@property
def infer_program(self):
return self._infer_program_desc
......@@ -341,6 +374,14 @@ class _ProgramHolder(object):
def train_program(self):
return self._train_program_desc
@property
def forward_program(self):
return self._forward_program_desc
@property
def backward_program(self):
return self._backward_program_desc
@property
def input_descs(self):
return self._input_descs
......@@ -460,7 +501,7 @@ class _ProgramHolder(object):
self._output_descs[i] = var.desc
@switch_to_static_graph
def _append_backward_desc(self, infer_program_desc):
def _get_train_forward_program(self, infer_program_desc):
program_desc_copy = core.ProgramDesc(infer_program_desc)
# 1. set all `is_test` attributes to False
......@@ -488,6 +529,11 @@ class _ProgramHolder(object):
persistable=False,
stop_gradient=True)
op.desc.set_output("ReserveSpace", [reserve_space.name])
return program
@switch_to_static_graph
def _append_backward_desc(self, infer_program_desc):
program = self._get_train_forward_program(infer_program_desc)
targets = []
for out in self._output_descs:
......@@ -861,14 +907,29 @@ def _run_dygraph(instance, input, program_holder):
# 2. run program by op
trace_program = program_holder.infer_program if instance._is_test else program_holder.train_program
forward_program = program_holder._infer_program_desc if instance._is_test else program_holder.forward_program
end_op_index = program_holder.infer_program.block(0).op_size()
attrs = ('global_block', trace_program.block(0), 'start_op_index', 0,
'end_op_index', end_op_index, 'is_test', instance._is_test,
'program_id', _hash_with_id(trace_program, instance))
attrs = [
'global_block',
trace_program.block(0), 'start_op_index', 0, 'end_op_index',
end_op_index, 'is_test', instance._is_test, 'program_id',
_hash_with_id(trace_program, instance)
]
use_interpretorcore = _is_enable_standalone_executor(
) and _is_dy2st_enable_standalone_executor()
attrs.extend(('use_interpretorcore', use_interpretorcore))
if use_interpretorcore:
attrs.extend(
('forward_global_block', forward_program.block(0),
'backward_global_block', program_holder.backward_program.block(0)))
_legacy_C_ops.run_program(_valid_vars(input_vars),
_valid_vars(persistable_vars),
_valid_vars(output_vars), tmp_scope_vec,
_valid_vars(double_grad_vars), None, *attrs)
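The `*attrs` payload handed to `_legacy_C_ops.run_program` is a flat, alternating (name, value) sequence. A small sketch (editor's illustration with placeholder values) showing how the pairs line up:

```python
attrs = ['is_test', False, 'program_id', 42, 'use_interpretorcore', True]
attrs_as_dict = dict(zip(attrs[0::2], attrs[1::2]))
print(attrs_as_dict)
# {'is_test': False, 'program_id': 42, 'use_interpretorcore': True}
```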
# NOTE: [ why need set param's gradient type here ]
# if user set sparse gradient mode, the param's gradient
# will be SelectedRows, not LoDTensor. But tracer will just
......
......@@ -399,6 +399,12 @@ def _is_enable_standalone_executor():
]
def _is_dy2st_enable_standalone_executor():
return framework._dy2st_enable_standalone_executor_ in [
1, '1', True, 'True', 'true'
]
def _prepare_fleet_executor():
from ..distributed.fleet.proto import fleet_executor_desc_pb2
trainer_endpoints_str = os.getenv("PADDLE_TRAINER_ENDPOINTS", "")
......
......@@ -86,6 +86,8 @@ _current_cuda_graph_mode = None
_global_flags_ = core.globals()
_enable_standalone_executor_ = (os.environ.get('FLAGS_USE_STANDALONE_EXECUTOR',
None))
_dy2st_enable_standalone_executor_ = (os.environ.get(
'FLAGS_DY2ST_USE_STANDALONE_EXECUTOR', 1))
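Both environment flags are read when `paddle.fluid.framework` is imported, so they must be set before paddle is imported. A minimal sketch (editor's note; the accepted "on" values follow `_is_dy2st_enable_standalone_executor` in the executor.py hunk above):

```python
import os

# Set the flags before paddle is imported anywhere in the process.
os.environ['FLAGS_USE_STANDALONE_EXECUTOR'] = '1'
os.environ['FLAGS_DY2ST_USE_STANDALONE_EXECUTOR'] = '1'

import paddle  # noqa: E402 -- imported after the flags on purpose
```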
# Some explanation of our execution system 2022.03
# For now we have 3 kinds of execution system, since we refactored dygraph mode to
......@@ -5040,6 +5042,8 @@ class Program(object):
all_new_vars = []
block_num = new_desc.num_blocks()
for idx in range(block_num):
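# new_desc may hold more blocks than this Program currently wraps
# (presumably control-flow sub-blocks); create the missing python-side
# Block objects so every block index in new_desc has a matching wrapper.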
if (idx > (len(self.blocks) - 1)):
self._create_block()
new_block_desc = new_desc.block(idx)
all_new_vars.append([])
block_new_vars = all_new_vars[-1]
......
......@@ -57,7 +57,7 @@ set_tests_properties(test_se_resnet PROPERTIES TIMEOUT 900)
set_tests_properties(test_yolov3 PROPERTIES TIMEOUT 900 LABELS
"RUN_TYPE=EXCLUSIVE")
set_tests_properties(test_mobile_net PROPERTIES TIMEOUT 120)
set_tests_properties(test_seq2seq PROPERTIES TIMEOUT 120)
set_tests_properties(test_seq2seq PROPERTIES TIMEOUT 150)
set_tests_properties(test_cycle_gan PROPERTIES TIMEOUT 150)
set_tests_properties(test_bert PROPERTIES TIMEOUT 120)
set_tests_properties(test_basic_api_transformation PROPERTIES TIMEOUT 120)
......
......@@ -18,6 +18,8 @@ from paddle import _C_ops, _legacy_C_ops
from paddle.fluid.framework import _test_eager_guard, Variable, _in_legacy_dygraph
from paddle.fluid import core
from paddle.fluid.layers.utils import _hash_with_id
from paddle.fluid.dygraph.base import switch_to_static_graph
from paddle.fluid.executor import _is_enable_standalone_executor, _is_dy2st_enable_standalone_executor
import paddle.compat as cpt
import unittest
......@@ -67,6 +69,18 @@ def _create_out(var):
return var_base
@switch_to_static_graph
def _add_build_strategy_for(input_program, start_op_index, end_op_index):
compiled_program = paddle.static.CompiledProgram(
core.Graph(input_program.desc, start_op_index, end_op_index),
build_strategy=paddle.static.BuildStrategy())
compiled_program._compile(core.Scope(),
paddle.framework._current_expected_place())
ir_graph = paddle.fluid.framework.IrGraph(compiled_program._graph)
builded_program = ir_graph.to_program()
return builded_program
class TestRunProgram(unittest.TestCase):
def test_eager(self):
......@@ -81,6 +95,13 @@ class TestRunProgram(unittest.TestCase):
main_program = paddle.static.default_main_program()
program = _append_backward_desc(main_program, [out])
forward_program = _add_build_strategy_for(
program, 0,
main_program.desc.block(0).op_size())
backward_program = _add_build_strategy_for(
program,
main_program.desc.block(0).op_size() + 2,
program.desc.block(0).op_size())
paddle.disable_static('cpu')
# step 2: call run_program in eager mode
......@@ -98,9 +119,21 @@ class TestRunProgram(unittest.TestCase):
out_t = _create_out(out)
scope = core.Scope()
attrs = ('global_block', program.desc.block(0), 'start_op_index', 0,
'end_op_index', main_program.desc.block(0).op_size(),
'is_test', False, 'program_id', _hash_with_id(program))
attrs = [
'global_block',
program.desc.block(0), 'start_op_index', 0, 'end_op_index',
main_program.desc.block(0).op_size(), 'is_test', False,
'program_id',
_hash_with_id(program)
]
use_interpretorcore = _is_enable_standalone_executor(
) and _is_dy2st_enable_standalone_executor()
attrs.extend(('use_interpretorcore', use_interpretorcore))
if use_interpretorcore:
attrs.extend(
('forward_global_block', forward_program.desc.block(0),
'backward_global_block', backward_program.desc.block(0)))
_legacy_C_ops.run_program([x_t, y_t], [fake_var], [out_t], [scope],
[fake_var], None, *attrs)
......
......@@ -26,6 +26,8 @@ from paddle import compat as cpt
from paddle.fluid import core, framework, executor
from paddle.fluid.layers.utils import _hash_with_id
from paddle.fluid.framework import _in_eager_mode_
from paddle.fluid.executor import _is_enable_standalone_executor, _is_dy2st_enable_standalone_executor
from paddle.fluid.dygraph.base import switch_to_static_graph
paddle.enable_static()
......@@ -41,6 +43,30 @@ def program_scope_guard():
yield
@switch_to_static_graph
def _add_build_strategy_for(input_program, start_op_index, end_op_index):
compiled_program = paddle.static.CompiledProgram(
core.Graph(input_program.desc, start_op_index, end_op_index),
build_strategy=paddle.static.BuildStrategy())
compiled_program._compile(core.Scope(),
paddle.framework._current_expected_place())
ir_graph = paddle.fluid.framework.IrGraph(compiled_program._graph)
builded_program = ir_graph.to_program()
return builded_program
@switch_to_static_graph
def _build_program_by_desc(program_desc):
prog = framework.Program()
prog.desc = program_desc
prog.blocks = [
framework.Block(prog, i)
for i in six.moves.range(prog.desc.num_blocks())
]
prog._sync_with_cpp()
return prog
# NOTE: Because RunProgramOp has a special output of type std::vector<Scope *>,
# the OpTest cannot be used in RunProgramOp. The variable type cannot be specified
# when creating output variables in OpTest, default type is LoDTensor
......@@ -97,10 +123,22 @@ class RunProgramOpTest(unittest.TestCase):
fwd_op_num = self.build_model()
return fluid.default_main_program().desc, fwd_op_num
def get_forward_backward_program_desc(self, whole_program_desc,
forward_op_num, output_num):
program = _build_program_by_desc(whole_program_desc)
forward_program = _add_build_strategy_for(program, 0, forward_op_num)
backward_program = _add_build_strategy_for(
program, forward_op_num + 2 * output_num,
program.desc.block(0).op_size())
return forward_program.desc, backward_program.desc
def prepare_attrs(self):
return ('global_block', self.program_desc.block(0), 'start_op_index', 0,
'end_op_index', self.fwd_op_num, 'program_id',
_hash_with_id(self.program_desc, self))
return [
'global_block',
self.program_desc.block(0), 'start_op_index', 0, 'end_op_index',
self.fwd_op_num, 'program_id',
_hash_with_id(self.program_desc, self)
]
def get_param_grad_names(self):
grad_names = []
......@@ -200,9 +238,21 @@ class RunProgramOpTest(unittest.TestCase):
inputs = self.prepare_dygraph_input(place)
outputs = self.prepare_dygraph_output()
forward_program_desc, backward_program_desc = self.get_forward_backward_program_desc(
self.program_desc, self.fwd_op_num, len(outputs['Out']))
use_interpretorcore = _is_enable_standalone_executor(
) and _is_dy2st_enable_standalone_executor()
self.attrs.extend(('use_interpretorcore', use_interpretorcore))
if use_interpretorcore:
self.attrs.extend(
('forward_global_block', forward_program_desc.block(0),
'backward_global_block', backward_program_desc.block(0)))
_legacy_C_ops.run_program(inputs['X'], inputs['Params'],
outputs['Out'], outputs['OutScope'],
outputs['DOut'], None, *self.attrs)
return outputs['Out']
def calc_dygraph_grad(self, place):
......@@ -214,6 +264,17 @@ class RunProgramOpTest(unittest.TestCase):
inputs, input_param_list = self.prepare_dygraph_input(place, True)
outputs = self.prepare_dygraph_output()
forward_program_desc, backward_program_desc = self.get_forward_backward_program_desc(
self.program_desc, self.fwd_op_num, len(outputs['Out']))
use_interpretorcore = _is_enable_standalone_executor(
) and _is_dy2st_enable_standalone_executor()
self.attrs.extend(('use_interpretorcore', use_interpretorcore))
if use_interpretorcore:
self.attrs.extend(
('forward_global_block', forward_program_desc.block(0),
'backward_global_block', backward_program_desc.block(0)))
_legacy_C_ops.run_program(inputs['X'], inputs['Params'],
outputs['Out'], outputs['OutScope'],
outputs['DOut'], None, *self.attrs)
......