Unverified commit 2850391d, authored by Aurelius84, committed by GitHub

Upgrade Executor into ParallelExecutor to apply Graph Optimization in @to_static (#32283)

* Refine constructor logic of ParallelExecutor

* Replace Executor with ParallelExecutor in run_program_op
Parent commit: 7d9bf244
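At a glance, the execution path this commit introduces for @to_static looks roughly like the sketch below. It is illustrative only: `program`, `place`, `scope`, the op index range, and the variable-name lists are placeholders, and the real kernel code appears in run_program_op.h further down.

    // Illustrative sketch of the new @to_static execution path.
    auto cache_key = framework::ExecutorInfoCache::CacheKey(
        program, place, start_op_index, end_op_index, /*is_grad=*/false);

    // Look up (or lazily build) a single-device ParallelExecutor for this key.
    auto cache_info = framework::GetExecutorInfoFromCache(cache_key, &scope);
    auto &parallel_executor = cache_info.first;
    if (cache_info.second /*is_new_created*/) {
      // Input vars must not be reused by the memory-optimization passes.
      parallel_executor->SkipMemoryReuse(/*scope_idx=*/0, input_var_names);
    }

    // Run the compiled graph without fetching; skip_eager_delete_vars stay alive.
    parallel_executor->RunWithoutFetch(skip_eager_delete_vars);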
...@@ -340,7 +340,7 @@ cc_library(parallel_executor SRCS parallel_executor.cc DEPS ...@@ -340,7 +340,7 @@ cc_library(parallel_executor SRCS parallel_executor.cc DEPS
graph build_strategy bind_threaded_ssa_graph_executor collective_helper graph build_strategy bind_threaded_ssa_graph_executor collective_helper
fast_threaded_ssa_graph_executor variable_helper) fast_threaded_ssa_graph_executor variable_helper)
cc_library(executor_cache SRCS executor_cache.cc DEPS executor) cc_library(executor_cache SRCS executor_cache.cc DEPS parallel_executor)
if(WITH_PSCORE) if(WITH_PSCORE)
get_property(RPC_DEPS GLOBAL PROPERTY RPC_DEPS) get_property(RPC_DEPS GLOBAL PROPERTY RPC_DEPS)
cc_test(dist_multi_trainer_test SRCS dist_multi_trainer_test.cc DEPS cc_test(dist_multi_trainer_test SRCS dist_multi_trainer_test.cc DEPS
......
...@@ -88,16 +88,12 @@ cc_library(eager_deletion_op_handle SRCS eager_deletion_op_handle.cc DEPS lod_te ...@@ -88,16 +88,12 @@ cc_library(eager_deletion_op_handle SRCS eager_deletion_op_handle.cc DEPS lod_te
set(SSA_GRAPH_EXECUTOR_DEPS graph framework_proto set(SSA_GRAPH_EXECUTOR_DEPS graph framework_proto
multi_devices_helper multi_devices_helper
sequential_execution_pass
modify_op_lock_and_record_event_pass
all_reduce_deps_pass
reference_count_pass reference_count_pass
eager_deletion_pass eager_deletion_pass
buffer_shared_inplace_op_pass buffer_shared_inplace_op_pass
buffer_shared_cross_op_memory_reuse_pass buffer_shared_cross_op_memory_reuse_pass
inplace_addto_op_pass inplace_addto_op_pass
set_reader_device_info_utils set_reader_device_info_utils)
add_reader_dependency_pass)
cc_library(ssa_graph_executor SRCS ssa_graph_executor.cc DEPS ${SSA_GRAPH_EXECUTOR_DEPS}) cc_library(ssa_graph_executor SRCS ssa_graph_executor.cc DEPS ${SSA_GRAPH_EXECUTOR_DEPS})
cc_library(threaded_ssa_graph_executor SRCS threaded_ssa_graph_executor.cc DEPS fetch_op_handle ssa_graph_executor scope cc_library(threaded_ssa_graph_executor SRCS threaded_ssa_graph_executor.cc DEPS fetch_op_handle ssa_graph_executor scope
...@@ -132,6 +128,10 @@ set(IR_PASS_DEPS graph_viz_pass multi_devices_graph_pass ...@@ -132,6 +128,10 @@ set(IR_PASS_DEPS graph_viz_pass multi_devices_graph_pass
multi_batch_merge_pass multi_batch_merge_pass
fuse_relu_depthwise_conv_pass fuse_relu_depthwise_conv_pass
lock_free_optimize_pass lock_free_optimize_pass
sequential_execution_pass
all_reduce_deps_pass
add_reader_dependency_pass
modify_op_lock_and_record_event_pass
coalesce_grad_tensor_pass fuse_all_reduce_op_pass backward_optimizer_op_deps_pass coalesce_grad_tensor_pass fuse_all_reduce_op_pass backward_optimizer_op_deps_pass
fuse_adam_op_pass fuse_sgd_op_pass fuse_momentum_op_pass fuse_adam_op_pass fuse_sgd_op_pass fuse_momentum_op_pass
sync_batch_norm_pass runtime_context_cache_pass) sync_batch_norm_pass runtime_context_cache_pass)
......
...@@ -109,7 +109,8 @@ void EagerDeletionOpHandle::CallOnce() { ...@@ -109,7 +109,8 @@ void EagerDeletionOpHandle::CallOnce() {
std::string EagerDeletionOpHandle::Name() const { return "eager_deletion"; } std::string EagerDeletionOpHandle::Name() const { return "eager_deletion"; }
void EagerDeletionOpHandle::RunImpl() { void EagerDeletionOpHandle::RunImpl() {
if (vars_.size() != var_infos_.size()) { if (vars_.size() != var_infos_.size() || is_variant_scope_) {
vars_.clear();
CallOnce(); CallOnce();
} }
...@@ -119,6 +120,7 @@ void EagerDeletionOpHandle::RunImpl() { ...@@ -119,6 +120,7 @@ void EagerDeletionOpHandle::RunImpl() {
auto *var_info = var_infos_[i]; auto *var_info = var_infos_[i];
if (var_info->IsSkippedAllMemoryOptimization() || if (var_info->IsSkippedAllMemoryOptimization() ||
!var_info->DecreaseRefCnt()) { !var_info->DecreaseRefCnt()) {
VLOG(4) << "skip memory optimization with var: " << var_info->Name();
continue; continue;
} }
......
...@@ -76,7 +76,6 @@ FetchResultType FastThreadedSSAGraphExecutor::Run( ...@@ -76,7 +76,6 @@ FetchResultType FastThreadedSSAGraphExecutor::Run(
std::vector<OpHandleBase *> fetch_ops; std::vector<OpHandleBase *> fetch_ops;
std::vector<OpHandleBase *> ready_fetch_ops; std::vector<OpHandleBase *> ready_fetch_ops;
exception_.Clear(); exception_.Clear();
InsertFetchOps(fetch_tensors, &fetches, &fetched_vars, op_deps.get(), InsertFetchOps(fetch_tensors, &fetches, &fetched_vars, op_deps.get(),
&fetch_ops, &ready_fetch_ops, return_merged); &fetch_ops, &ready_fetch_ops, return_merged);
event.reset(nullptr); event.reset(nullptr);
...@@ -95,6 +94,8 @@ FetchResultType FastThreadedSSAGraphExecutor::Run( ...@@ -95,6 +94,8 @@ FetchResultType FastThreadedSSAGraphExecutor::Run(
traced_ops_.clear(); traced_ops_.clear();
remaining_ = 0; remaining_ = 0;
auto complete_q = std::make_shared<BlockingQueue<size_t>>(); auto complete_q = std::make_shared<BlockingQueue<size_t>>();
VLOG(3) << "number of bootstrap_ops_: " << bootstrap_ops_.size();
VLOG(3) << "number of ready_fetch_ops: " << ready_fetch_ops.size();
for (auto op : bootstrap_ops_) { for (auto op : bootstrap_ops_) {
RunOpAsync(op_deps.get(), op, complete_q); RunOpAsync(op_deps.get(), op, complete_q);
} }
...@@ -247,11 +248,10 @@ void FastThreadedSSAGraphExecutor::RunOpAsync( ...@@ -247,11 +248,10 @@ void FastThreadedSSAGraphExecutor::RunOpAsync(
RunOpAsync(op_deps, post_op, complete_q); RunOpAsync(op_deps, post_op, complete_q);
} }
} }
VLOG(3) << "start to run op: " << op_to_run->Name();
if (!RunOp(op_to_run, complete_q, &complete)) { if (!RunOp(op_to_run, complete_q, &complete)) {
return; return;
} }
auto &outputs = op_to_run->Outputs(); auto &outputs = op_to_run->Outputs();
op_to_run = nullptr; op_to_run = nullptr;
for (auto &output : outputs) { for (auto &output : outputs) {
......
...@@ -136,6 +136,10 @@ class OpHandleBase { ...@@ -136,6 +136,10 @@ class OpHandleBase {
void SetLocalExecScopes( void SetLocalExecScopes(
const std::unordered_map<Scope *, Scope *> &scope_map); const std::unordered_map<Scope *, Scope *> &scope_map);
void SetIsVariantScope(bool is_variant_scope) {
is_variant_scope_ = is_variant_scope;
}
protected: protected:
virtual std::vector<Scope *> GetLocalScopes() = 0; virtual std::vector<Scope *> GetLocalScopes() = 0;
...@@ -156,6 +160,12 @@ class OpHandleBase { ...@@ -156,6 +160,12 @@ class OpHandleBase {
std::vector<Scope *> local_exec_scopes_; std::vector<Scope *> local_exec_scopes_;
bool skip_running_ = false; bool skip_running_ = false;
// NOTE(Aurelius84): Indicates whether the scope held in the OpHandle is changeable.
// An OpHandle's scope normally stays the same in most cases, except when running
// run_program_op from @to_static.
// The scope may be changed in each training iteration.
// See https://github.com/PaddlePaddle/Paddle/pull/32283
bool is_variant_scope_ = false;
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
std::unordered_map<int, gpuEvent_t> events_; std::unordered_map<int, gpuEvent_t> events_;
......
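The new `is_variant_scope_` flag is driven from ParallelExecutor::ResetOpHandleScopeMapOfGraphs (see parallel_executor.cc below); a minimal sketch of that call pattern, with `graph` and `scope_map` as placeholders:

    // Illustrative: when a cached executor is reused with a fresh scope, every
    // OpHandle is told that its cached scope can no longer be trusted.
    for (details::OpHandleBase *op :
         ir::FilterByNodeWrapper<details::OpHandleBase>(*graph)) {
      op->SetLocalExecScopes(scope_map);  // remap cached scopes to the new scope
      op->SetIsVariantScope(true);        // force re-resolution on the next run
    }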
...@@ -41,7 +41,8 @@ static inline const Tensor &GetTensorFromVar(const Variable *var) { ...@@ -41,7 +41,8 @@ static inline const Tensor &GetTensorFromVar(const Variable *var) {
return var->Get<LoDTensor>(); return var->Get<LoDTensor>();
} else { } else {
PADDLE_THROW(platform::errors::InvalidArgument( PADDLE_THROW(platform::errors::InvalidArgument(
"Variable must be type of LoDTensor.")); "Variable must be type of LoDTensor, but received %s.",
framework::ToTypeName(var->Type())));
} }
} }
...@@ -50,19 +51,22 @@ static inline Tensor *GetMutableTensorFromVar(Variable *var) { ...@@ -50,19 +51,22 @@ static inline Tensor *GetMutableTensorFromVar(Variable *var) {
return var->GetMutable<LoDTensor>(); return var->GetMutable<LoDTensor>();
} else { } else {
PADDLE_THROW(platform::errors::InvalidArgument( PADDLE_THROW(platform::errors::InvalidArgument(
"Variable must be type of LoDTensor.")); "Variable must be type of LoDTensor, but received %s.",
framework::ToTypeName(var->Type())));
} }
} }
ShareTensorBufferFunctor::ShareTensorBufferFunctor( ShareTensorBufferFunctor::ShareTensorBufferFunctor(
Scope *scope, size_t scope_idx, const std::string &op_type, Scope *scope, size_t scope_idx, const std::string &op_type,
const std::vector<const ir::MemOptVarInfo *> &in_var_infos, const std::vector<const ir::MemOptVarInfo *> &in_var_infos,
const std::vector<std::string> &out_var_names, bool share_dims) const std::vector<std::string> &out_var_names, const bool &is_variant_scope,
bool share_dims)
: scope_(scope), : scope_(scope),
scope_idx_(scope_idx), scope_idx_(scope_idx),
op_type_(op_type), op_type_(op_type),
in_var_infos_(in_var_infos), in_var_infos_(in_var_infos),
out_var_names_(out_var_names), out_var_names_(out_var_names),
is_variant_scope_(is_variant_scope),
share_dims_(share_dims) { share_dims_(share_dims) {
PADDLE_ENFORCE_EQ(in_var_infos_.size(), out_var_names_.size(), PADDLE_ENFORCE_EQ(in_var_infos_.size(), out_var_names_.size(),
platform::errors::PreconditionNotMet( platform::errors::PreconditionNotMet(
...@@ -126,12 +130,13 @@ void ShareTensorBufferFunctor::CallOnce() { ...@@ -126,12 +130,13 @@ void ShareTensorBufferFunctor::CallOnce() {
} }
void ShareTensorBufferFunctor::operator()(Scope *exec_scope) { void ShareTensorBufferFunctor::operator()(Scope *exec_scope) {
if (!exec_scope_) { if (!exec_scope_ || is_variant_scope_) {
PADDLE_ENFORCE_NOT_NULL(exec_scope, PADDLE_ENFORCE_NOT_NULL(exec_scope,
platform::errors::InvalidArgument( platform::errors::InvalidArgument(
"The given execution scope should not be NULL " "The given execution scope should not be NULL "
"if the cached scope is NULL.")); "if the cached scope is NULL."));
exec_scope_ = exec_scope; exec_scope_ = exec_scope;
in_out_vars_.clear();
CallOnce(); CallOnce();
} else { } else {
PADDLE_ENFORCE_EQ(exec_scope_, exec_scope, PADDLE_ENFORCE_EQ(exec_scope_, exec_scope,
......
...@@ -51,7 +51,8 @@ class ShareTensorBufferFunctor { ...@@ -51,7 +51,8 @@ class ShareTensorBufferFunctor {
ShareTensorBufferFunctor( ShareTensorBufferFunctor(
Scope *scope, size_t scope_idx, const std::string &op_type, Scope *scope, size_t scope_idx, const std::string &op_type,
const std::vector<const ir::MemOptVarInfo *> &in_var_infos, const std::vector<const ir::MemOptVarInfo *> &in_var_infos,
const std::vector<std::string> &out_var_names, bool share_dims = false); const std::vector<std::string> &out_var_names,
const bool &is_variant_scope, bool share_dims = false);
void AddReuseVarPair(const ir::MemOptVarInfo *in_var_info, void AddReuseVarPair(const ir::MemOptVarInfo *in_var_info,
const std::string &out_var_name); const std::string &out_var_name);
...@@ -80,6 +81,9 @@ class ShareTensorBufferFunctor { ...@@ -80,6 +81,9 @@ class ShareTensorBufferFunctor {
std::vector<std::pair<const Variable *, Variable *>> in_out_vars_; std::vector<std::pair<const Variable *, Variable *>> in_out_vars_;
// NOTE(Aurelius84): Use a const reference to always keep consistent with
// share_tensor_buffer_op_handle.
const bool &is_variant_scope_;
// NOTE(zhiqiu): In the case of inplace addto, if the operator of // NOTE(zhiqiu): In the case of inplace addto, if the operator of
// the in_out_vars is skipped during running, we should set the dims of output // the in_out_vars is skipped during running, we should set the dims of output
// as the same as input. // as the same as input.
......
...@@ -67,7 +67,7 @@ ShareTensorBufferOpHandle::ShareTensorBufferOpHandle( ...@@ -67,7 +67,7 @@ ShareTensorBufferOpHandle::ShareTensorBufferOpHandle(
const std::vector<std::string> &out_var_names, bool share_dims) const std::vector<std::string> &out_var_names, bool share_dims)
: OpHandleBase(node), : OpHandleBase(node),
functor_(scope, scope_idx, op_type, in_var_infos, out_var_names, functor_(scope, scope_idx, op_type, in_var_infos, out_var_names,
share_dims) {} is_variant_scope_, share_dims) {}
std::unordered_map<std::string, std::string> std::unordered_map<std::string, std::string>
ShareTensorBufferOpHandle::ReusedVars() const { ShareTensorBufferOpHandle::ReusedVars() const {
......
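Passing `is_variant_scope_` into the functor by const reference (see the constructor change above) keeps the functor in sync with its owning op handle without an extra setter. A self-contained sketch of the idiom, with illustrative names:

    #include <iostream>

    struct Handle {
      bool is_variant_scope_ = false;
    };

    struct Functor {
      // Stores a const reference, so it always observes the owner's current value.
      explicit Functor(const bool &is_variant_scope)
          : is_variant_scope_(is_variant_scope) {}
      const bool &is_variant_scope_;
    };

    int main() {
      Handle h;
      Functor f(h.is_variant_scope_);
      h.is_variant_scope_ = true;          // flip the owner's flag
      std::cout << f.is_variant_scope_;    // prints 1: the functor sees the change
    }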
...@@ -16,7 +16,6 @@ ...@@ -16,7 +16,6 @@
namespace paddle { namespace paddle {
namespace framework { namespace framework {
class BlockDesc;
class ProgramDesc; class ProgramDesc;
} // namespace framework } // namespace framework
} // namespace paddle } // namespace paddle
...@@ -26,45 +25,89 @@ namespace framework { ...@@ -26,45 +25,89 @@ namespace framework {
namespace details { namespace details {
static void AppendSkipDeletionVars(const std::vector<std::string> &append_vars, static ExecutionStrategy GetExecutionStrategy(
std::vector<std::string> *all_vars) { const ExecutorInfoCache::CacheKey &cache_key) {
framework::ExecutionStrategy execution_strategy;
switch (cache_key.device_type_) {
case platform::DeviceType::CPU: {
execution_strategy.num_threads_ = 2;
break;
}
case platform::DeviceType::CUDA: {
// NOTE: According to experiments, one thread is faster in
// most model training.
execution_strategy.num_threads_ = 1;
break;
}
case platform::DeviceType::XPU: {
execution_strategy.num_threads_ = 1;
break;
}
default:
PADDLE_THROW(platform::errors::Unavailable("Unsupported Device type %d.",
cache_key.device_type_));
}
execution_strategy.use_device_ = cache_key.device_type_;
return execution_strategy;
}
void AppendSkipDeletionVars(const std::vector<std::string> &append_vars,
std::vector<std::string> *all_vars) {
for (auto &var : append_vars) { for (auto &var : append_vars) {
all_vars->emplace_back(var); all_vars->emplace_back(var);
} }
} }
static void AppendSafeEagerDeletionSkipVars( /*
const framework::ProgramDesc &program, * NOTE(Aurelius84): In ParallelExecutor, memory optimized pass will be applied.
std::vector<std::string> *skip_vars) { * To avoid eagerly deleting last alive variables which are necessary in
const framework::BlockDesc &block = program.Block(0); * backward program, we firstly parse these variable names as
const std::vector<framework::OpDesc *> &all_ops = block.AllOps(); * skip_eager_vars. While executing pe.run skip_eager_vars are used to
* skip memory optimization.
std::unordered_set<std::string> grad_op_output; *
std::unordered_set<std::string> grad_op_input; * Variables satisfying the following rules are considered as skip_eager_var:
for (const framework::OpDesc *op : all_ops) { *
int op_role = BOOST_GET_CONST( * 1. it is an output var in run_program_op
int, op->GetAttr(framework::OpProtoAndCheckerMaker::OpRoleAttrName())); * 2. it is an input var used in backward_op
if ((op_role & static_cast<int>(framework::OpRole::kBackward)) == 0) { */
continue; void ParseSafeEagerDeletionSkipVars(
} const ProgramDesc &program, int64_t forward_op_nums,
const std::vector<std::string> &output_var_names,
std::vector<std::string> *skip_eager_delete_vars) {
auto all_ops = program.Block(0).AllOps();
// NOTE: skip the `shape` and `fill_constant` ops created by
// fluid.backward.gradients; each forward output generates one `shape`
// and one `fill_constant`.
size_t backward_op_start_index =
forward_op_nums + (output_var_names.size() * 2);
// step 2: parse the necessary variables of the backward ops
std::unordered_set<std::string> op_outputs;
std::unordered_set<std::string> op_inputs;
for (auto i = backward_op_start_index; i < all_ops.size(); ++i) {
framework::OpDesc *op = all_ops[i];
for (const std::string &in_arg_name : op->InputArgumentNames()) { for (const std::string &in_arg_name : op->InputArgumentNames()) {
grad_op_input.emplace(in_arg_name); op_inputs.emplace(in_arg_name);
} }
for (const std::string &out_arg_name : op->OutputArgumentNames()) { for (const std::string &out_arg_name : op->OutputArgumentNames()) {
grad_op_output.emplace(out_arg_name); op_outputs.emplace(out_arg_name);
} }
} }
// For the grad op input variables, if it is not output of grad_op, it may // For the grad op input variables, if it is not output of grad_op, it may
// be output of forward op and we should set the variables as skip_var to // be output of forward op and we should set the variables as skip_var to
// prevent it being deleted when grad op is called multiple times. // prevent it being deleted when grad op is called multiple times.
for (const std::string &var_name : grad_op_input) { for (const std::string &var_name : op_inputs) {
if (grad_op_output.find(var_name) == grad_op_output.end()) { if (op_outputs.find(var_name) == op_outputs.end()) {
skip_vars->emplace_back(var_name); VLOG(2) << "skip eager var: " << var_name;
skip_eager_delete_vars->emplace_back(var_name);
} }
} }
VLOG(3) << "Found skip_eager_delete_vars: " << skip_eager_delete_vars->size();
} }
} // namespace details } // namespace details
// C++11 removes the need for manual locking. Concurrent execution shall wait if // C++11 removes the need for manual locking. Concurrent execution shall wait if
...@@ -75,38 +118,58 @@ ExecutorInfoCache &ExecutorInfoCache::Instance() { ...@@ -75,38 +118,58 @@ ExecutorInfoCache &ExecutorInfoCache::Instance() {
return g_exe_cache_info_map; return g_exe_cache_info_map;
} }
std::shared_ptr<framework::ExecutorPrepareContext> GetExecutorInfoFromCache( void ExecutorInfoCache::Finalize() {
const framework::Executor &exe, const framework::ExecutionContext &ctx, // NOTE(Aurelius84): DO NOT perform finalize in destructor
const std::vector<std::vector<std::string>> &ctx_output_names, // to avoid problems caused by destructor order of static
bool is_grad) { // object.
auto *program = ctx.Attr<BlockDesc *>("global_block")->Program(); info_map_.clear();
}
CacheInfo GetExecutorInfoFromCache(const ExecutorInfoCache::CacheKey &cache_key,
framework::Scope *scope) {
auto &cached_exe_info = framework::ExecutorInfoCache::Instance(); auto &cached_exe_info = framework::ExecutorInfoCache::Instance();
auto cache_key = framework::ExecutorInfoCache::KeyInfo(program, is_grad);
if (!cached_exe_info.Has(cache_key)) { if (!cached_exe_info.Has(cache_key)) {
VLOG(1) << "create exe_info for program: " << program VLOG(1) << "create exe_info for " << cache_key.DebugString();
<< " is_grad: " << is_grad;
// skip delete vars // TODO(Aurelius84): Consider to use LRU algorithm to replace this.
std::vector<std::string> skip_vars; if (cached_exe_info.Size() > 4u /* max_cached_size*/) {
for (auto &output_names : ctx_output_names) { VLOG(2) << "The cached info size has exceeded max_cached_size: 4, clear "
details::AppendSkipDeletionVars(output_names, &skip_vars); "all cache!";
} cached_exe_info.Finalize();
if (is_grad) {
details::AppendSafeEagerDeletionSkipVars(*program, &skip_vars);
} }
VLOG(2) << "Prepare to skip " << skip_vars.size() framework::BuildStrategy build_strategy;
<< " var(s): " << string::join_strings(skip_vars, ' '); auto execution_strategy = details::GetExecutionStrategy(cache_key);
std::shared_ptr<framework::ExecutorPrepareContext> exe_ctx =
std::move(exe.Prepare(*program, /*block_id=*/0, skip_vars)); auto graph = std::make_shared<framework::ir::Graph>(
*cache_key.program_desc_, cache_key.start_op_index_,
cache_key.end_op_index_);
auto parallel_executor = std::make_shared<framework::ParallelExecutor>(
cache_key.place_, scope, execution_strategy, build_strategy,
graph.get());
parallel_executor->PrepareVariables(scope);
framework::ExecutorInfoCache::ValueType cache_val = {parallel_executor,
graph};
cached_exe_info.Insert(cache_key, cache_val);
cached_exe_info.Insert(cache_key, exe_ctx); bool is_new_created = true;
return exe_ctx; return std::make_pair(parallel_executor, is_new_created);
} else { } else {
VLOG(1) << "get exe_info from cache by program: " << program VLOG(1) << "get exe_info from cache by: " << cache_key.DebugString();
<< " is_grad: " << is_grad; bool is_new_created = false;
return cached_exe_info.Get(cache_key); auto cache_val = cached_exe_info.GetMutable(cache_key);
auto parallel_executor = cache_val.first;
// update op_handle scope_map in pe->executor_->Graph
std::unordered_map<Scope *, Scope *> scope_map = {
{parallel_executor->GetLocalScopes().front(), scope}};
parallel_executor->ResetOpHandleScopeMapOfGraphs(scope_map);
// need to recreate tmp variables in new scope
parallel_executor->PrepareVariables(scope);
return std::make_pair(parallel_executor, is_new_created);
} }
} }
......
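To make the backward_op_start_index arithmetic in ParseSafeEagerDeletionSkipVars concrete, assume a hypothetical program whose forward part has 10 ops and whose run_program_op has 2 output variables:

    // Hypothetical numbers, for illustration only.
    int64_t forward_op_nums = 10;                        // ops [0, 10) are forward
    std::vector<std::string> output_var_names = {"out0", "out1"};

    // fluid.backward.gradients inserts one `shape` and one `fill_constant` op per
    // forward output, so backward ops are expected to start at 10 + 2 * 2 = 14.
    size_t backward_op_start_index =
        forward_op_nums + (output_var_names.size() * 2);  // == 14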
...@@ -16,38 +16,78 @@ ...@@ -16,38 +16,78 @@
#include <functional> #include <functional>
#include <memory> #include <memory>
#include <sstream>
#include <string> #include <string>
#include <unordered_map> #include <unordered_map>
#include <unordered_set>
#include <utility> #include <utility>
#include <vector> #include <vector>
#include "paddle/fluid/framework/executor.h"
#include "paddle/fluid/framework/op_proto_maker.h" #include "paddle/fluid/framework/op_proto_maker.h"
#include "paddle/fluid/framework/parallel_executor.h"
#include "paddle/fluid/framework/program_desc.h" #include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/platform/macros.h" #include "paddle/fluid/platform/macros.h"
#include "paddle/fluid/string/string_helper.h" #include "paddle/fluid/string/string_helper.h"
namespace paddle { namespace paddle {
namespace framework { namespace framework {
namespace ir {
class Graph;
}
class ExecutionContext; namespace details {
class Executor; void AppendSkipDeletionVars(const std::vector<std::string>& append_vars,
class ProgramDesc; std::vector<std::string>* all_vars);
struct ExecutorPrepareContext;
void ParseSafeEagerDeletionSkipVars(
const ProgramDesc& program, int64_t forward_op_nums,
const std::vector<std::string>& output_var_names,
std::vector<std::string>* skip_eager_delete_vars);
} // namespace details
class ExecutorInfoCache { class ExecutorInfoCache {
public: public:
/* struct CacheKey {
* The ExecutorPrepareContext is different while running forward program and CacheKey(const ProgramDesc* program_desc, const platform::Place& place,
* backward program. We add bool value into cached key to distinguish this. int64_t start_op_index, int64_t end_op_index, bool is_grad)
*/ : program_desc_(program_desc),
using KeyInfo = std::pair<const framework::ProgramDesc*, /*is_grad*/ bool>; place_(place),
start_op_index_(start_op_index),
end_op_index_(end_op_index),
is_grad_(is_grad) {
device_type_ = platform::Place2DeviceType(place);
PADDLE_ENFORCE_NOT_NULL(program_desc_,
"program_desc should not be null.");
}
std::string DebugString() const {
std::stringstream ss;
ss << "\n CacheKey(program_desc: " << program_desc_;
ss << ", start_op_index: " << start_op_index_;
ss << ", end_op_index: " << end_op_index_;
ss << ", is_grad: " << is_grad_;
ss << ", device_type: " << device_type_ << ")";
return ss.str();
}
const ProgramDesc* program_desc_;
platform::Place place_;
int64_t start_op_index_;
int64_t end_op_index_;
bool is_grad_;
platform::DeviceType device_type_;
};
using KeyType = size_t; using KeyType = size_t;
using ValueType =
std::pair<std::shared_ptr<ParallelExecutor>, std::shared_ptr<ir::Graph>>;
struct HashPair { struct KeyHasher {
size_t operator()(const KeyInfo& key) const noexcept { size_t operator()(const CacheKey& key) const noexcept {
size_t seed = 10; size_t seed = 10;
auto* prog_desc = key.first; auto* prog_desc = key.program_desc_;
/* /*
* Note(Aurelius84): DO NOT use only ProgramDesc* to calculate hash value * Note(Aurelius84): DO NOT use only ProgramDesc* to calculate hash value
* because a new program will hold same pointer address after an older * because a new program will hold same pointer address after an older
...@@ -59,8 +99,12 @@ class ExecutorInfoCache { ...@@ -59,8 +99,12 @@ class ExecutorInfoCache {
hash_combine(&seed, &prog_desc->Block(i)); hash_combine(&seed, &prog_desc->Block(i));
hash_combine(&seed, prog_desc->Block(i).OpSize()); hash_combine(&seed, prog_desc->Block(i).OpSize());
} }
hash_combine(&seed, key.second); hash_combine(&seed, static_cast<int>(key.device_type_));
VLOG(1) << "hash value is : " << seed << " of pointer " << prog_desc; hash_combine(&seed, key.start_op_index_);
hash_combine(&seed, key.end_op_index_);
hash_combine(&seed, key.is_grad_);
VLOG(3) << "hash value is : " << seed
<< " of key: " << key.DebugString();
return seed; return seed;
} }
...@@ -73,54 +117,50 @@ class ExecutorInfoCache { ...@@ -73,54 +117,50 @@ class ExecutorInfoCache {
static ExecutorInfoCache& Instance(); static ExecutorInfoCache& Instance();
std::shared_ptr<framework::ExecutorPrepareContext> Get( ValueType GetMutable(const CacheKey& key) {
const KeyInfo& key) const { auto key_val = key_hash_func_(key);
KeyType key_value = key_hash_func_(key);
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
Has(key_value), true, Has(key_val), true,
platform::errors::NotFound( platform::errors::NotFound("%s doesn't exist in ExecutorInfoCache",
"(programDesc: %s, is_grad: %s) doesn't exist in ExecutorInfoCache", key.DebugString()));
key.first, key.second)); return info_map_[key_val];
return info_map_.at(key_value);
} }
bool Has(const KeyInfo& key) const { bool Has(const CacheKey& key) const {
KeyType key_value = key_hash_func_(key); auto key_val = key_hash_func_(key);
return Has(key_value); return Has(key_val);
} }
bool Has(const KeyType& key) const { bool Has(const KeyType& key) const {
return info_map_.find(key) != info_map_.end(); return info_map_.find(key) != info_map_.end();
} }
void Insert(const KeyInfo& key, void Insert(const CacheKey& key, ValueType value) {
std::shared_ptr<framework::ExecutorPrepareContext> exe_ctx) { auto key_val = key_hash_func_(key);
KeyType key_value = key_hash_func_(key); PADDLE_ENFORCE_EQ(
PADDLE_ENFORCE_NE( Has(key_val), false,
Has(key_value), true, platform::errors::NotFound("%s has existed in ExecutorInfoCache",
platform::errors::NotFound( key.DebugString()));
"(programDesc: %s, is_grad: %s) has existed in ExecutorInfoCache", info_map_.insert({key_val, value});
key.first, key.second));
info_map_.insert({key_value, exe_ctx});
} }
private: size_t Size() const { return info_map_.size(); }
ExecutorInfoCache() = default;
HashPair key_hash_func_; void Finalize();
// Note: we shall avoid using raw pointer as key but use hash code, private:
// beacause pointer doesn't hold resource indeed. ExecutorInfoCache() = default;
std::unordered_map<KeyType,
std::shared_ptr<framework::ExecutorPrepareContext>>
info_map_;
DISABLE_COPY_AND_ASSIGN(ExecutorInfoCache); DISABLE_COPY_AND_ASSIGN(ExecutorInfoCache);
KeyHasher key_hash_func_;
std::unordered_map<KeyType, ValueType> info_map_;
}; };
std::shared_ptr<framework::ExecutorPrepareContext> GetExecutorInfoFromCache( using CacheInfo =
const framework::Executor& exe, const framework::ExecutionContext& ctx, std::pair<std::shared_ptr<ParallelExecutor>, bool /*is_new_created*/>;
const std::vector<std::vector<std::string>>& ctx_output_names,
bool is_grad); CacheInfo GetExecutorInfoFromCache(const ExecutorInfoCache::CacheKey& cache_key,
framework::Scope* scope);
} // namespace framework } // namespace framework
} // namespace paddle } // namespace paddle
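The KeyHasher above folds the ProgramDesc pointer, per-block information, op index range, device type, and is_grad flag into one seed. A generic, self-contained sketch of the boost-style hash_combine pattern it relies on (an illustration of the idea, not Paddle's exact helper):

    #include <cstddef>
    #include <cstdint>
    #include <functional>
    #include <iostream>

    // Generic combiner: mixes `value` into an accumulated `seed`.
    template <typename T>
    void hash_combine(std::size_t *seed, const T &value) {
      std::hash<T> hasher;
      *seed ^= hasher(value) + 0x9e3779b9 + (*seed << 6) + (*seed >> 2);
    }

    int main() {
      std::size_t seed = 10;
      hash_combine(&seed, static_cast<int>(1));  // e.g. device_type
      hash_combine(&seed, std::int64_t{0});      // start_op_index
      hash_combine(&seed, std::int64_t{42});     // end_op_index
      hash_combine(&seed, true);                 // is_grad
      std::cout << seed << "\n";                 // combined cache key
    }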
...@@ -21,14 +21,30 @@ namespace paddle { ...@@ -21,14 +21,30 @@ namespace paddle {
namespace framework { namespace framework {
namespace ir { namespace ir {
Graph::Graph(const ProgramDesc &program) : program_(program) { Graph::Graph(const ProgramDesc &program)
auto var_nodes = InitFromProgram(program_); : Graph(program, 0, program.Block(0).AllOps().size()) {}
Graph::Graph(const ProgramDesc &program, int64_t start_op_index,
int64_t end_op_index)
: program_(program) {
auto var_nodes = InitFromProgram(program_, start_op_index, end_op_index);
ResolveHazard(var_nodes); ResolveHazard(var_nodes);
} }
std::map<std::string, std::vector<ir::Node *>> Graph::InitFromProgram( std::map<std::string, std::vector<ir::Node *>> Graph::InitFromProgram(
const ProgramDesc &program) { const ProgramDesc &program, int64_t start_op_index, int64_t end_op_index) {
VLOG(3) << "block in program:" << program_.Size(); VLOG(3) << "block in program:" << program_.Size();
PADDLE_ENFORCE_GE(start_op_index, 0,
platform::errors::InvalidArgument(
"Required start_op_index >= 0, but received "
"start_op_index = %d",
start_op_index));
PADDLE_ENFORCE_GE(end_op_index, start_op_index,
platform::errors::InvalidArgument(
"Required end_op_index >= start_op_index, but received "
"end_op_index: %d < start_op_index: %d",
end_op_index, start_op_index));
std::unordered_map<std::string, VarDesc *> all_vars; std::unordered_map<std::string, VarDesc *> all_vars;
// var nodes for each var name, will have multiple versions in SSA // var nodes for each var name, will have multiple versions in SSA
std::map<std::string, std::vector<ir::Node *>> var_nodes; std::map<std::string, std::vector<ir::Node *>> var_nodes;
...@@ -37,8 +53,16 @@ std::map<std::string, std::vector<ir::Node *>> Graph::InitFromProgram( ...@@ -37,8 +53,16 @@ std::map<std::string, std::vector<ir::Node *>> Graph::InitFromProgram(
} }
auto not_visited_vars = all_vars; auto not_visited_vars = all_vars;
auto all_ops = program.Block(0).AllOps();
for (auto *op : program.Block(0).AllOps()) { PADDLE_ENFORCE_LE(
end_op_index, all_ops.size(),
platform::errors::InvalidArgument(
"Required end_op_index <= %d, but received end_op_index = %d",
all_ops.size(), end_op_index));
for (auto i = start_op_index; i < end_op_index; ++i) {
auto *op = all_ops[i];
VLOG(3) << "create OpNode by " << op->Type();
ir::Node *node = CreateOpNode(op); ir::Node *node = CreateOpNode(op);
// For input args, reuse the same var name if it was created before. // For input args, reuse the same var name if it was created before.
// Otherwise, create a new one. // Otherwise, create a new one.
...@@ -88,18 +112,28 @@ std::map<std::string, std::vector<ir::Node *>> Graph::InitFromProgram( ...@@ -88,18 +112,28 @@ std::map<std::string, std::vector<ir::Node *>> Graph::InitFromProgram(
} }
} }
for (auto &pair : not_visited_vars) { if (end_op_index < static_cast<int64_t>(all_ops.size()) ||
const auto &var_name = pair.first; start_op_index > 0) {
auto *var_desc = pair.second; is_partial_ = true;
if (var_name != kEmptyVarName) { }
VLOG(10) << "Create isolated var node " << var_name; if (!is_partial_) {
var_nodes[var_name].push_back(CreateVarNode(var_desc)); for (auto &pair : not_visited_vars) {
const auto &var_name = pair.first;
auto *var_desc = pair.second;
if (var_name != kEmptyVarName) {
VLOG(10) << "Create isolated var node " << var_name;
var_nodes[var_name].push_back(CreateVarNode(var_desc));
}
} }
} }
Set<const std::vector<OpDesc *>>( Set<const std::vector<OpDesc *>>(
details::kStaleProgramOpDescs, details::kStaleProgramOpDescs,
new std::vector<OpDesc *>(program.Block(0).AllOps())); new std::vector<OpDesc *>(all_ops.begin() + start_op_index,
all_ops.begin() + end_op_index));
VLOG(3)
<< "kStaleProgramOpDescs.size: "
<< Get<const std::vector<OpDesc *>>(details::kStaleProgramOpDescs).size();
return var_nodes; return var_nodes;
} }
......
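With the new constructor, each half of a @to_static program can be turned into its own graph. A minimal usage sketch, where `program` and `forward_op_nums` are placeholders:

    // Whole-program graph (unchanged behaviour): covers ops [0, AllOps().size()).
    framework::ir::Graph full_graph(program);

    // Partial graph over [start_op_index, end_op_index); it reports
    // IsConstructedByPartialProgram() == true and skips isolated-var creation.
    framework::ir::Graph forward_graph(program, /*start_op_index=*/0,
                                       /*end_op_index=*/forward_op_nums);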
...@@ -79,6 +79,9 @@ namespace ir { ...@@ -79,6 +79,9 @@ namespace ir {
class Graph { class Graph {
public: public:
explicit Graph(const ProgramDesc &program); explicit Graph(const ProgramDesc &program);
// Construct a Graph with ops[start_op_index, end_op_index)
explicit Graph(const ProgramDesc &program, int64_t start_op_index,
int64_t end_op_index);
virtual ~Graph() { virtual ~Graph() {
for (auto &attr : attrs_) { for (auto &attr : attrs_) {
...@@ -88,6 +91,8 @@ class Graph { ...@@ -88,6 +91,8 @@ class Graph {
attr_dels_.clear(); attr_dels_.clear();
} }
bool IsConstructedByPartialProgram() const { return is_partial_; }
bool Has(const std::string &attr_name) const { bool Has(const std::string &attr_name) const {
return attrs_.count(attr_name) > 0; return attrs_.count(attr_name) > 0;
} }
...@@ -253,7 +258,7 @@ class Graph { ...@@ -253,7 +258,7 @@ class Graph {
private: private:
std::map<std::string, std::vector<ir::Node *>> InitFromProgram( std::map<std::string, std::vector<ir::Node *>> InitFromProgram(
const ProgramDesc &program); const ProgramDesc &program, int64_t start_op_index, int64_t end_op_index);
// NOTE: program_ shouldn't be exposed to user. // NOTE: program_ shouldn't be exposed to user.
const ProgramDesc program_; const ProgramDesc program_;
...@@ -262,6 +267,11 @@ class Graph { ...@@ -262,6 +267,11 @@ class Graph {
std::map<ir::Node *, std::unique_ptr<ir::Node>> nodes_; std::map<ir::Node *, std::unique_ptr<ir::Node>> nodes_;
std::unordered_set<ir::Node *> node_set_; std::unordered_set<ir::Node *> node_set_;
size_t num_node_created_{0}; // help to generate a unique node id. size_t num_node_created_{0}; // help to generate a unique node id.
// NOTE(Aurelius84): Whether the graph is constructed with a partial ProgramDesc.
// In case of @to_static, the whole training program is split into two
// parts: a forward graph and a backward graph, which can be executed
// independently.
bool is_partial_{false};
}; };
bool IsControlDepVar(const ir::Node &var); bool IsControlDepVar(const ir::Node &var);
......
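Passes that must behave differently under @to_static can branch on the new accessor; the eager-deletion passes below follow roughly this pattern (sketch only):

    void ApplyImpl(ir::Graph *graph) const override {
      if (graph->IsConstructedByPartialProgram()) {
        // Only the forward (or only the backward) half of the program is present,
        // so the missing ops are recovered from graph->OriginProgram().
      }
      // ... usual handling ...
    }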
...@@ -20,15 +20,15 @@ ...@@ -20,15 +20,15 @@
namespace paddle { namespace paddle {
namespace framework { namespace framework {
namespace ir { namespace ir {
using OpVariant = operators::OpVariant;
class ConditionalOpEagerDeletionPass : public Pass { class ConditionalOpEagerDeletionPass : public Pass {
protected: protected:
void ApplyImpl(Graph *graph) const override { void ApplyImpl(Graph *graph) const override {
auto all_ops = ir::FilterByNodeWrapper<details::OpHandleBase>(*graph); auto all_ops = ir::FilterByNodeWrapper<details::OpHandleBase>(*graph);
// Find all conditional_op and conditional_grad_op // Find all conditional_op and conditional_grad_op
std::unordered_map<size_t, std::pair<std::vector<OperatorBase *>, std::unordered_map<
std::vector<OperatorBase *>>> size_t, std::pair<std::vector<OpVariant>, std::vector<OpVariant>>>
target_ops; target_ops;
for (auto *op : all_ops) { for (auto *op : all_ops) {
auto compute_op = dynamic_cast<details::ComputationOpHandle *>(op); auto compute_op = dynamic_cast<details::ComputationOpHandle *>(op);
...@@ -43,6 +43,30 @@ class ConditionalOpEagerDeletionPass : public Pass { ...@@ -43,6 +43,30 @@ class ConditionalOpEagerDeletionPass : public Pass {
} }
} }
// NOTE(Aurelius84): In case of @to_static, after we finish executing the
// forward graph, some necessary variables in the step_scope of controlflow_op
// should be kept for the backward graph.
if (graph->IsConstructedByPartialProgram()) {
PADDLE_ENFORCE_LE(target_ops.size(), 1,
platform::errors::InvalidArgument(
"Unsupported multi devices if graph is constructed "
"with partial program."));
size_t scope_idx = 0;
auto &ifelse_ops = target_ops[scope_idx].first;
auto &ifelse_grad_ops = target_ops[scope_idx].second;
auto all_ops = graph->OriginProgram().Block(0).AllOps();
if (ifelse_ops.empty()) {
operators::AppendOpVariantByOpName(
all_ops, std::string("conditional_block"), &ifelse_ops);
} else if (ifelse_grad_ops.empty()) {
operators::AppendOpVariantByOpName(
all_ops, std::string("conditional_block_grad"), &ifelse_grad_ops);
} else {
PADDLE_THROW("One of ifelse_ops or ifelse_grad_ops should be empty.");
}
}
for (auto &ops_pair : target_ops) { for (auto &ops_pair : target_ops) {
auto &ifelse_ops = ops_pair.second.first; auto &ifelse_ops = ops_pair.second.first;
auto &ifelse_grad_ops = ops_pair.second.second; auto &ifelse_grad_ops = ops_pair.second.second;
......
...@@ -15,20 +15,24 @@ ...@@ -15,20 +15,24 @@
#include "paddle/fluid/framework/details/computation_op_handle.h" #include "paddle/fluid/framework/details/computation_op_handle.h"
#include "paddle/fluid/framework/details/multi_devices_helper.h" #include "paddle/fluid/framework/details/multi_devices_helper.h"
#include "paddle/fluid/framework/ir/graph_helper.h" #include "paddle/fluid/framework/ir/graph_helper.h"
#include "paddle/fluid/operators/controlflow/op_variant.h"
#include "paddle/fluid/operators/controlflow/while_op_helper.h" #include "paddle/fluid/operators/controlflow/while_op_helper.h"
namespace paddle { namespace paddle {
namespace framework { namespace framework {
namespace ir { namespace ir {
using OpVariant = operators::OpVariant;
class WhileOpEagerDeletionPass : public ir::Pass { class WhileOpEagerDeletionPass : public ir::Pass {
protected: protected:
void ApplyImpl(ir::Graph *graph) const override { void ApplyImpl(ir::Graph *graph) const override {
auto all_ops = ir::FilterByNodeWrapper<details::OpHandleBase>(*graph); auto all_ops = ir::FilterByNodeWrapper<details::OpHandleBase>(*graph);
// Find all while_op and while_grad_op // Find all while_op and while_grad_op. In case of @to_static, graph
std::unordered_map<size_t, std::pair<std::vector<OperatorBase *>, // may be constructed only by forward or backward program, so we use
std::vector<OperatorBase *>>> // OpVariant here instead of OperatorBase.
std::unordered_map<
size_t, std::pair<std::vector<OpVariant>, std::vector<OpVariant>>>
target_ops; target_ops;
for (auto *op : all_ops) { for (auto *op : all_ops) {
auto compute_op = dynamic_cast<details::ComputationOpHandle *>(op); auto compute_op = dynamic_cast<details::ComputationOpHandle *>(op);
...@@ -42,6 +46,27 @@ class WhileOpEagerDeletionPass : public ir::Pass { ...@@ -42,6 +46,27 @@ class WhileOpEagerDeletionPass : public ir::Pass {
compute_op->GetOp()); compute_op->GetOp());
} }
} }
if (graph->IsConstructedByPartialProgram()) {
PADDLE_ENFORCE_LE(
target_ops.size(), 1,
platform::errors::InvalidArgument(
"Unsupported multi device if graph is constructed by "
"partial program."));
size_t scope_idx = 0;
auto &while_ops = target_ops[scope_idx].first;
auto &while_grad_ops = target_ops[scope_idx].second;
auto all_ops = graph->OriginProgram().Block(0).AllOps();
if (while_ops.empty()) {
operators::AppendOpVariantByOpName(all_ops, std::string("while"),
&while_ops);
} else if (while_grad_ops.empty()) {
operators::AppendOpVariantByOpName(all_ops, std::string("while_grad"),
&while_grad_ops);
} else {
PADDLE_THROW("One of while_ops or while_grad_ops should be empty.");
}
}
for (auto &ops_pair : target_ops) { for (auto &ops_pair : target_ops) {
auto &while_ops = ops_pair.second.first; auto &while_ops = ops_pair.second.first;
......
...@@ -33,6 +33,7 @@ limitations under the License. */ ...@@ -33,6 +33,7 @@ limitations under the License. */
#include "paddle/fluid/framework/ir/memory_optimize_pass/memory_optimization_var_info.h" #include "paddle/fluid/framework/ir/memory_optimize_pass/memory_optimization_var_info.h"
#include "paddle/fluid/framework/ir/memory_optimize_pass/reference_count_pass_helper.h" #include "paddle/fluid/framework/ir/memory_optimize_pass/reference_count_pass_helper.h"
#include "paddle/fluid/framework/ir/multi_devices_graph_pass/set_reader_device_info_utils.h" #include "paddle/fluid/framework/ir/multi_devices_graph_pass/set_reader_device_info_utils.h"
#include "paddle/fluid/framework/variable_helper.h"
#include "paddle/fluid/platform/event.h" #include "paddle/fluid/platform/event.h"
#include "paddle/fluid/platform/profiler.h" #include "paddle/fluid/platform/profiler.h"
...@@ -684,6 +685,51 @@ ParallelExecutor::ParallelExecutor(const std::vector<platform::Place> &places, ...@@ -684,6 +685,51 @@ ParallelExecutor::ParallelExecutor(const std::vector<platform::Place> &places,
SetReaderOpDeviceInfoOfGraphs(final_graphs); SetReaderOpDeviceInfoOfGraphs(final_graphs);
} }
ParallelExecutor::ParallelExecutor(const platform::Place &place, Scope *scope,
const ExecutionStrategy &exec_strategy,
const BuildStrategy &build_strategy,
ir::Graph *graph)
: member_(new ParallelExecutorPrivate({place}, scope)) {
// Initialize necessary info of member_ with strategy.
InitExecutorPrivateMemberInfo(exec_strategy, build_strategy,
/*device_count=*/1, *graph);
CreateLocalScopes(scope, /*local_scope=*/{scope}, /*create_new=*/false);
// Apply BuildStrategy to compile graph.
std::vector<ir::Graph *> graphs = {graph};
std::vector<ir::Graph *> async_graphs =
CompileGraphWithBuildStrategy(graph, &graphs, /*loss_var_name=*/"");
graph = member_->ApplyMemoryOptimizePass(graph);
// Create vars in each scope. Passes may also create new vars.
// skip control vars and empty vars
CreateVariableInfos(&var_infos_, graph);
// Create local execution scopes
std::unordered_map<Scope *, Scope *> scope_map =
CreateLocalExecScopes(member_->local_scopes_, /*create_new=*/false);
std::vector<ir::Graph *> final_graphs =
CreateSSAGraphExecutor(exec_strategy, &async_graphs, graph);
// Set scope_map of op from each graph
ResetOpHandleScopeMapOfGraphs(final_graphs, scope_map);
}
void ParallelExecutor::PrepareVariables(Scope *scope) {
for (auto &info : var_infos_) {
auto var = scope->FindVar(info.name_);
if (var != nullptr) {
VLOG(2) << info.name_
<< " has been initialized beforehand in global scope, skipped.";
continue;
}
framework::InitializeVariable(scope->Var(info.name_), info.type_);
}
}
void ParallelExecutor::BCastParamsToDevices( void ParallelExecutor::BCastParamsToDevices(
const std::vector<std::string> &vars, int trainer_id) const { const std::vector<std::string> &vars, int trainer_id) const {
VLOG(3) << "BCastParamsToDevices"; VLOG(3) << "BCastParamsToDevices";
...@@ -845,6 +891,36 @@ FetchResultType ParallelExecutor::Run( ...@@ -845,6 +891,36 @@ FetchResultType ParallelExecutor::Run(
return fetch_data; return fetch_data;
} }
void ParallelExecutor::RunWithoutFetch(
const std::vector<std::string> &skip_eager_vars) {
VLOG(3) << "enter ParallelExecutor RunWithoutFetch";
#ifdef WITH_GPERFTOOLS
if (gProfileStarted) {
ProfilerFlush();
}
#endif
platform::RecordBlock b(0);
ResetHasFeedGuard reset_has_feed_guard(member_);
ir::SkipMemOptVarsGuard guard(&(member_->mem_opt_var_infos_), skip_eager_vars,
member_->HasGarbageCollectors());
VLOG(3) << "ParallelExecutor begin to run member_->executor_->Run";
member_->executor_->Run(/*fetch_tensors*/ {}, /*return_merged*/ false);
}
void ParallelExecutor::SkipMemoryReuse(
size_t scope_idx, const std::vector<std::string> &skip_vars) {
for (auto &var_name : skip_vars) {
bool is_persistable = member_->IsPersistable(var_name);
if (!is_persistable) {
VLOG(3) << "SkipMemoryReuse for var: " << var_name;
member_->SetSkipMemoryReuse(scope_idx, var_name);
}
}
}
void ParallelExecutor::FeedTensorsIntoLocalScopes( void ParallelExecutor::FeedTensorsIntoLocalScopes(
const std::vector<std::unordered_map<std::string, LoDTensor>> &tensors) { const std::vector<std::unordered_map<std::string, LoDTensor>> &tensors) {
if (!member_->AllowPartialFeed()) { if (!member_->AllowPartialFeed()) {
...@@ -1449,10 +1525,18 @@ void ParallelExecutor::ResetOpHandleScopeMapOfGraphs( ...@@ -1449,10 +1525,18 @@ void ParallelExecutor::ResetOpHandleScopeMapOfGraphs(
auto ops = ir::FilterByNodeWrapper<details::OpHandleBase>(*g); auto ops = ir::FilterByNodeWrapper<details::OpHandleBase>(*g);
for (auto *op : ops) { for (auto *op : ops) {
op->SetLocalExecScopes(scope_map); op->SetLocalExecScopes(scope_map);
op->SetIsVariantScope(true);
} }
} }
} }
void ParallelExecutor::ResetOpHandleScopeMapOfGraphs(
const std::unordered_map<Scope *, Scope *> &scope_map) {
auto inner_graph = const_cast<ir::Graph *>(&Graph());
std::vector<ir::Graph *> graphs = {inner_graph};
ResetOpHandleScopeMapOfGraphs(graphs, scope_map);
}
void ParallelExecutor::SetReaderOpDeviceInfoOfGraphs( void ParallelExecutor::SetReaderOpDeviceInfoOfGraphs(
const std::vector<ir::Graph *> &final_graphs) { const std::vector<ir::Graph *> &final_graphs) {
if (final_graphs.size() == 1) { if (final_graphs.size() == 1) {
......
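Putting the single-device pieces together, executor_cache.cc drives the new constructor roughly as sketched below (illustrative; error handling and the graph cache are omitted, and `place`, `scope`, `program`, and the index range are placeholders):

    framework::BuildStrategy build_strategy;
    framework::ExecutionStrategy exec_strategy;  // e.g. num_threads_ = 1 on CUDA

    auto graph = std::make_shared<framework::ir::Graph>(
        program, start_op_index, end_op_index);
    auto pe = std::make_shared<framework::ParallelExecutor>(
        place, scope, exec_strategy, build_strategy, graph.get());

    pe->PrepareVariables(scope);                  // create missing vars in `scope`
    pe->RunWithoutFetch(skip_eager_delete_vars);  // run once, fetch nothing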
...@@ -60,6 +60,12 @@ class ParallelExecutor { ...@@ -60,6 +60,12 @@ class ParallelExecutor {
const BuildStrategy &build_strategy, const BuildStrategy &build_strategy,
ir::Graph *graph); ir::Graph *graph);
// NOTE(Aurelius84): Construct a PE running on single device for @to_static
explicit ParallelExecutor(const platform::Place &place, Scope *scope,
const ExecutionStrategy &exec_strategy,
const BuildStrategy &build_strategy,
ir::Graph *graph);
~ParallelExecutor(); ~ParallelExecutor();
size_t DeviceCount() const; size_t DeviceCount() const;
...@@ -84,7 +90,16 @@ class ParallelExecutor { ...@@ -84,7 +90,16 @@ class ParallelExecutor {
FetchResultType Run(const std::vector<std::string> &fetch_tensors, FetchResultType Run(const std::vector<std::string> &fetch_tensors,
bool return_merged = true); bool return_merged = true);
void RunWithoutFetch(const std::vector<std::string> &skip_eager_vars);
void ResetOpHandleScopeMapOfGraphs(
const std::unordered_map<Scope *, Scope *> &scope_map);
const ir::Graph &Graph() const; const ir::Graph &Graph() const;
void PrepareVariables(Scope *scope);
void SkipMemoryReuse(size_t scope_idx,
const std::vector<std::string> &skip_vars);
private: private:
// broadcast the parameters from the 0th device. // broadcast the parameters from the 0th device.
...@@ -131,6 +146,7 @@ class ParallelExecutor { ...@@ -131,6 +146,7 @@ class ParallelExecutor {
ParallelExecutorPrivate *member_; ParallelExecutorPrivate *member_;
std::vector<std::unique_ptr<ir::Graph>> async_graphs_; std::vector<std::unique_ptr<ir::Graph>> async_graphs_;
std::vector<VariableInfo> var_infos_;
}; };
} // namespace framework } // namespace framework
} // namespace paddle } // namespace paddle
...@@ -16,8 +16,6 @@ ...@@ -16,8 +16,6 @@
#include <string> #include <string>
#include "paddle/fluid/operators/controlflow/op_variant.h"
namespace paddle { namespace paddle {
namespace framework { namespace framework {
class ProgramDesc; class ProgramDesc;
...@@ -189,18 +187,10 @@ void PrepareSafeEagerDeletionOnConditionalOpAndConditionalGradOp( ...@@ -189,18 +187,10 @@ void PrepareSafeEagerDeletionOnConditionalOpAndConditionalGradOp(
void PrepareSafeEagerDeletionOnConditionalOpAndConditionalGradOp( void PrepareSafeEagerDeletionOnConditionalOpAndConditionalGradOp(
const framework::ProgramDesc &program, const framework::ProgramDesc &program,
const std::vector<framework::OperatorBase *> &ifelse_ops, const std::vector<OpVariant> &ifelse_ops,
const std::vector<framework::OperatorBase *> &ifelse_grad_ops) { const std::vector<OpVariant> &ifelse_grad_ops) {
std::vector<OpVariant> fwd_ops, bwd_ops; std::vector<OpVariant> fwd_ops = ifelse_ops;
fwd_ops.reserve(ifelse_ops.size()); std::vector<OpVariant> bwd_ops = ifelse_grad_ops;
for (auto *op : ifelse_ops) {
fwd_ops.emplace_back(op);
}
bwd_ops.reserve(ifelse_grad_ops.size());
for (auto *op : ifelse_grad_ops) {
bwd_ops.emplace_back(op);
}
PrepareSafeEagerDeletionOnConditionalOpAndConditionalGradOpImpl( PrepareSafeEagerDeletionOnConditionalOpAndConditionalGradOpImpl(
program, &fwd_ops, &bwd_ops); program, &fwd_ops, &bwd_ops);
......
...@@ -19,6 +19,7 @@ ...@@ -19,6 +19,7 @@
#include "paddle/fluid/framework/operator.h" #include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/operators/controlflow/conditional_block_op.h" #include "paddle/fluid/operators/controlflow/conditional_block_op.h"
#include "paddle/fluid/operators/controlflow/op_variant.h"
#include "paddle/fluid/string/string_helper.h" #include "paddle/fluid/string/string_helper.h"
namespace paddle { namespace paddle {
...@@ -40,8 +41,8 @@ void PrepareSafeEagerDeletionOnConditionalOpAndConditionalGradOp( ...@@ -40,8 +41,8 @@ void PrepareSafeEagerDeletionOnConditionalOpAndConditionalGradOp(
void PrepareSafeEagerDeletionOnConditionalOpAndConditionalGradOp( void PrepareSafeEagerDeletionOnConditionalOpAndConditionalGradOp(
const framework::ProgramDesc &program, const framework::ProgramDesc &program,
const std::vector<framework::OperatorBase *> &ifelse_ops, const std::vector<OpVariant> &ifelse_ops,
const std::vector<framework::OperatorBase *> &ifelse_grad_ops); const std::vector<OpVariant> &ifelse_grad_ops);
} // namespace operators } // namespace operators
} // namespace paddle } // namespace paddle
...@@ -68,5 +68,20 @@ const void *OpVariant::RawPointer() const { ...@@ -68,5 +68,20 @@ const void *OpVariant::RawPointer() const {
return boost::apply_visitor(RawPointerVisitor(), op_); return boost::apply_visitor(RawPointerVisitor(), op_);
} }
void AppendOpVariantByOpName(const std::vector<framework::OpDesc *> &op_descs,
const std::string &candidate_op_name,
std::vector<OpVariant> *result_ops) {
PADDLE_ENFORCE_NOT_NULL(
result_ops,
platform::errors::Unavailable("result_ops should not be a null_ptr."));
for (auto *op_desc : op_descs) {
PADDLE_ENFORCE_NOT_NULL(op_desc, platform::errors::Unavailable(
"op_desc should not be a null_ptr."));
if (op_desc->Type() == candidate_op_name) {
result_ops->emplace_back(op_desc);
}
}
}
} // namespace operators } // namespace operators
} // namespace paddle } // namespace paddle
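AppendOpVariantByOpName just collects every OpDesc whose type matches the given name; a short usage sketch matching how the eager-deletion passes call it:

    std::vector<operators::OpVariant> while_ops;
    auto all_op_descs = graph->OriginProgram().Block(0).AllOps();
    // Collect every `while` op of the (partial) program as an OpVariant.
    operators::AppendOpVariantByOpName(all_op_descs, std::string("while"),
                                       &while_ops);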
...@@ -73,5 +73,9 @@ class OpVariant { ...@@ -73,5 +73,9 @@ class OpVariant {
op_; op_;
}; };
void AppendOpVariantByOpName(const std::vector<framework::OpDesc *> &op_descs,
const std::string &candidate_op_name,
std::vector<OpVariant> *result_ops);
} // namespace operators } // namespace operators
} // namespace paddle } // namespace paddle
...@@ -15,7 +15,6 @@ ...@@ -15,7 +15,6 @@
#include "paddle/fluid/operators/controlflow/while_op_helper.h" #include "paddle/fluid/operators/controlflow/while_op_helper.h"
#include <string> #include <string>
#include "paddle/fluid/operators/controlflow/op_variant.h"
#include "paddle/fluid/string/string_helper.h" #include "paddle/fluid/string/string_helper.h"
namespace paddle { namespace paddle {
...@@ -199,18 +198,10 @@ void PrepareSafeEagerDeletionOnWhileOpAndWhileGradOp( ...@@ -199,18 +198,10 @@ void PrepareSafeEagerDeletionOnWhileOpAndWhileGradOp(
void PrepareSafeEagerDeletionOnWhileOpAndWhileGradOp( void PrepareSafeEagerDeletionOnWhileOpAndWhileGradOp(
const framework::ProgramDesc &program, const framework::ProgramDesc &program,
const std::vector<framework::OperatorBase *> &while_ops, const std::vector<OpVariant> &while_ops,
const std::vector<framework::OperatorBase *> &while_grad_ops) { const std::vector<OpVariant> &while_grad_ops) {
std::vector<OpVariant> fwd_ops, bwd_ops; std::vector<OpVariant> fwd_ops = while_ops;
fwd_ops.reserve(while_ops.size()); std::vector<OpVariant> bwd_ops = while_grad_ops;
for (auto *op : while_ops) {
fwd_ops.emplace_back(op);
}
bwd_ops.reserve(while_grad_ops.size());
for (auto *op : while_grad_ops) {
bwd_ops.emplace_back(op);
}
PrepareSafeEagerDeletionOnWhileOpAndWhileGradOpImpl(program, &fwd_ops, PrepareSafeEagerDeletionOnWhileOpAndWhileGradOpImpl(program, &fwd_ops,
&bwd_ops); &bwd_ops);
......
...@@ -19,6 +19,7 @@ ...@@ -19,6 +19,7 @@
#include <vector> #include <vector>
#include "paddle/fluid/framework/operator.h" #include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/operators/controlflow/op_variant.h"
#include "paddle/fluid/platform/variant.h" #include "paddle/fluid/platform/variant.h"
namespace paddle { namespace paddle {
...@@ -46,8 +47,8 @@ void PrepareSafeEagerDeletionOnWhileOpAndWhileGradOp( ...@@ -46,8 +47,8 @@ void PrepareSafeEagerDeletionOnWhileOpAndWhileGradOp(
void PrepareSafeEagerDeletionOnWhileOpAndWhileGradOp( void PrepareSafeEagerDeletionOnWhileOpAndWhileGradOp(
const framework::ProgramDesc &program, const framework::ProgramDesc &program,
const std::vector<framework::OperatorBase *> &while_ops, const std::vector<OpVariant> &while_ops,
const std::vector<framework::OperatorBase *> &while_grad_ops); const std::vector<OpVariant> &while_grad_ops);
bool GetCondData(const framework::LoDTensor &cond); bool GetCondData(const framework::LoDTensor &cond);
......
...@@ -23,7 +23,6 @@ limitations under the License. */ ...@@ -23,7 +23,6 @@ limitations under the License. */
#include <utility> #include <utility>
#include <vector> #include <vector>
#include "paddle/fluid/framework/executor.h"
#include "paddle/fluid/framework/executor_cache.h" #include "paddle/fluid/framework/executor_cache.h"
#include "paddle/fluid/framework/op_desc.h" #include "paddle/fluid/framework/op_desc.h"
#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_registry.h"
...@@ -43,6 +42,7 @@ namespace operators { ...@@ -43,6 +42,7 @@ namespace operators {
using StepScopeVar = std::vector<framework::Scope *>; using StepScopeVar = std::vector<framework::Scope *>;
using BlockDesc = framework::BlockDesc; using BlockDesc = framework::BlockDesc;
using ProgramDesc = framework::ProgramDesc;
using Variable = framework::Variable; using Variable = framework::Variable;
using LoDTensor = framework::LoDTensor; using LoDTensor = framework::LoDTensor;
...@@ -198,9 +198,6 @@ class RunProgramOpKernel : public framework::OpKernel<T> { ...@@ -198,9 +198,6 @@ class RunProgramOpKernel : public framework::OpKernel<T> {
"The OutScope of RunProgramGradOp should only hold one scope.")); "The OutScope of RunProgramGradOp should only hold one scope."));
// Step 2. prepare executor and init persistable variables // Step 2. prepare executor and init persistable variables
framework::Executor exe(ctx.GetPlace());
auto exe_ctx = framework::GetExecutorInfoFromCache(
exe, ctx, {output_var_names, dout_var_names}, /*is_grad=*/false);
// NOTE(Aurelius84): While training some models, forward can be called many // NOTE(Aurelius84): While training some models, forward can be called many
// times and then apply backpropagation all at once, such as Reinforcement // times and then apply backpropagation all at once, such as Reinforcement
...@@ -216,12 +213,27 @@ class RunProgramOpKernel : public framework::OpKernel<T> { ...@@ -216,12 +213,27 @@ class RunProgramOpKernel : public framework::OpKernel<T> {
details::ShareVarsIntoScope(input_vars, input_var_names, &scope); details::ShareVarsIntoScope(input_vars, input_var_names, &scope);
details::ShareVarsIntoScope(param_vars, param_names, &scope); details::ShareVarsIntoScope(param_vars, param_names, &scope);
// Step 3. run ops if (end_op_index > start_op_index) {
exe.RunPartialPreparedContext(exe_ctx.get(), &scope, start_op_index, auto *program = ctx.Attr<BlockDesc *>("global_block")->Program();
end_op_index, /*create_local_scope=*/false, auto cache_key = framework::ExecutorInfoCache::CacheKey(
/*create_vars=*/true, program, ctx.GetPlace(), start_op_index, end_op_index,
/*keep_kids=*/!is_test); /*is_grad=*/false);
auto cache_info = framework::GetExecutorInfoFromCache(cache_key, &scope);
auto &parallel_executor = cache_info.first;
if (cache_info.second /*is_new_created*/) {
parallel_executor->SkipMemoryReuse(/*scope_idx=*/0, input_var_names);
}
// Step 3. run ops
// all out_vars are skip_eager_var
std::vector<std::string> skip_eager_delete_vars(output_var_names);
skip_eager_delete_vars.insert(skip_eager_delete_vars.end(),
dout_var_names.begin(),
dout_var_names.end());
framework::details::ParseSafeEagerDeletionSkipVars(
*program, end_op_index, output_var_names, &skip_eager_delete_vars);
parallel_executor->RunWithoutFetch(skip_eager_delete_vars);
}
// Step 4. Get Output // Step 4. Get Output
details::ShareVarsFromScope(output_vars, output_var_names, &scope); details::ShareVarsFromScope(output_vars, output_var_names, &scope);
details::ShareVarsFromScope(dout_vars, dout_var_names, &scope); details::ShareVarsFromScope(dout_vars, dout_var_names, &scope);
...@@ -290,21 +302,31 @@ class RunProgramGradOpKernel : public framework::OpKernel<T> { ...@@ -290,21 +302,31 @@ class RunProgramGradOpKernel : public framework::OpKernel<T> {
auto &scope = *(global_inner_scope->kids().front()); auto &scope = *(global_inner_scope->kids().front());
// Step 2. prepare executor and scope if (end_op_index > start_op_index) {
framework::Executor exe(ctx.GetPlace()); // Step 2. prepare executor and scope
auto exe_ctx = framework::GetExecutorInfoFromCache( auto *program = ctx.Attr<BlockDesc *>("global_block")->Program();
exe, ctx, {input_grad_var_names, param_grad_names}, auto cache_key = framework::ExecutorInfoCache::CacheKey(
/*is_grad=*/true); program, ctx.GetPlace(), start_op_index, end_op_index,
/*is_grad*/ true);
details::ShareVarsIntoScope(output_grad_vars, output_grad_var_names, auto cache_info = framework::GetExecutorInfoFromCache(cache_key, &scope);
&scope); auto &parallel_executor = cache_info.first;
// Debug info: scope info when run end
VLOG(3) << framework::GenScopeTreeDebugInfo(out_scope_vec->front()); parallel_executor->SkipMemoryReuse(/*scope_idx=*/0,
output_grad_var_names);
// Step 3. run ops
exe.RunPartialPreparedContext(exe_ctx.get(), &scope, start_op_index, details::ShareVarsIntoScope(output_grad_vars, output_grad_var_names,
end_op_index, /*create_local_scope=*/false, &scope);
/*create_vars=*/true, /*keep_kids=*/false); // Debug info: scope info when run end
VLOG(3) << framework::GenScopeTreeDebugInfo(out_scope_vec->front());
std::vector<std::string> skip_eager_delete_vars(input_grad_var_names);
framework::details::AppendSkipDeletionVars(param_grad_names,
&skip_eager_delete_vars);
// Step 3. run ops
parallel_executor->RunWithoutFetch(
/*skip_eager_delete_vars=*/skip_eager_delete_vars);
}
// Step 4. get outputs // Step 4. get outputs
details::ShareVarsFromScope(input_grad_vars, input_grad_var_names, &scope); details::ShareVarsFromScope(input_grad_vars, input_grad_var_names, &scope);
......
@@ -75,6 +75,19 @@ void SetAllowTF32Cudnn(bool active) { allow_tf32_cudnn = active; }
bool AllowTF32Cudnn() { return allow_tf32_cudnn; }
#endif  // PADDLE_WITH_CUDA
+DeviceType Place2DeviceType(const platform::Place& place) {
+  if (platform::is_cpu_place(place)) {
+    return platform::DeviceType::CPU;
+  } else if (platform::is_gpu_place(place)) {
+    return platform::DeviceType::CUDA;
+  } else if (platform::is_xpu_place(place)) {
+    return platform::DeviceType::XPU;
+  } else {
+    PADDLE_THROW(platform::errors::Unavailable(
+        "Unsupported place %s to convert into platform::DeviceType.", place));
+  }
+}
+
DeviceContextPool* DeviceContextPool::pool = nullptr;
platform::DeviceContext* DeviceContextPool::Get(const platform::Place& place) {
...
@@ -99,6 +99,8 @@ enum DeviceType {
  NPU = 3,
};
+DeviceType Place2DeviceType(const platform::Place& place);
+
constexpr DeviceType kCPU = DeviceType::CPU;
constexpr DeviceType kCUDA = DeviceType::CUDA;
constexpr DeviceType kXPU = DeviceType::XPU;
...
@@ -31,6 +31,7 @@ limitations under the License. */
#include "paddle/fluid/framework/custom_operator.h"
#include "paddle/fluid/framework/data_layout.h"
#include "paddle/fluid/framework/executor.h"
+#include "paddle/fluid/framework/executor_cache.h"
#include "paddle/fluid/framework/executor_gc_helper.h"
#include "paddle/fluid/framework/feed_fetch_method.h"
#include "paddle/fluid/framework/feed_fetch_type.h"
@@ -2216,6 +2217,8 @@ All parameter, weight, gradient are variables in Paddle.
  m.def("set_cudnn_switch", platform::SetAllowTF32Cudnn);
  m.def("get_cudnn_switch", platform::AllowTF32Cudnn);
#endif  // PADDLE_WITH_CUDA
+  m.def("clear_executor_cache",
+        []() { framework::ExecutorInfoCache::Instance().Finalize(); });
  using VarQuantScale =
      std::unordered_map<std::string, std::pair<bool, LoDTensor>>;
...
@@ -262,3 +262,5 @@ monkey_patch_varbase()
# do some clean up manually.
if core.is_compiled_with_npu():
    atexit.register(core.npu_finalize)
+# NOTE(Aurelius84): clean up ExecutorCacheInfo in advance manually.
+atexit.register(core.clear_executor_cache)
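The cached ParallelExecutor instances hold references to scopes and graphs, so the PR exposes the C++ Finalize() call to Python as core.clear_executor_cache and registers it with atexit to tear the cache down before interpreter shutdown. A process that wants to release that memory earlier can also call it directly; the sketch below assumes a Paddle build that still exposes paddle.fluid.core and already contains this PR.

```python
import paddle

core = paddle.fluid.core

# Added in this PR: explicitly drop all cached executors and their graphs.
# Paddle already registers the same call with atexit, so a manual call is
# only needed if the memory should be released before interpreter exit.
if hasattr(core, "clear_executor_cache"):
    core.clear_executor_cache()
```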
file(GLOB TEST_OPS RELATIVE "${CMAKE_CURRENT_SOURCE_DIR}" "test_*.py")
string(REPLACE ".py" "" TEST_OPS "${TEST_OPS}")
+list(REMOVE_ITEM TEST_OPS test_lac)
+
+# NOTE(Aurelius84): In case of Windows CI, if open ON_INFER, RWLOCK of Scope will
+# be removed and will cause some random failed in multi-thread.
+if(NOT ON_INFER)
+    py_test_modules(test_lac MODULES test_lac)
+    set_tests_properties(test_lac PROPERTIES TIMEOUT 120)
+endif()
foreach(TEST_OP ${TEST_OPS})
    py_test_modules(${TEST_OP} MODULES ${TEST_OP})
endforeach(TEST_OP)
@@ -10,7 +18,6 @@ set_tests_properties(test_yolov3 PROPERTIES TIMEOUT 900 LABELS "RUN_TYPE=EXCLUSI
set_tests_properties(test_mobile_net PROPERTIES TIMEOUT 120)
set_tests_properties(test_seq2seq PROPERTIES TIMEOUT 120)
set_tests_properties(test_cycle_gan PROPERTIES TIMEOUT 120)
-set_tests_properties(test_lac PROPERTIES TIMEOUT 120)
set_tests_properties(test_bert PROPERTIES TIMEOUT 120)
set_tests_properties(test_basic_api_transformation PROPERTIES TIMEOUT 120)
set_tests_properties(test_reinforcement_learning PROPERTIES TIMEOUT 120)
...
@@ -31,6 +31,12 @@ from paddle.fluid.dygraph.io import INFER_MODEL_SUFFIX, INFER_PARAMS_SUFFIX
SEED = 2020
program_translator = ProgramTranslator()
+# Add InputSpec to make unittest run faster.
+input_specs = [
+    paddle.static.InputSpec([None, None], 'int64'),
+    paddle.static.InputSpec([None, None], 'int64'),
+    paddle.static.InputSpec([None], 'int64')
+]
class DynamicGRU(fluid.dygraph.Layer):
@@ -354,7 +360,7 @@ class LexNet(fluid.dygraph.Layer):
        # share weight
        self.crf_decoding.weight = self.linear_chain_crf.weight
-    @declarative
+    @declarative(input_spec=input_specs)
    def forward(self, word, target, length=None):
        """
        Configure the network
@@ -494,7 +500,7 @@ def do_train(args, to_static):
            fluid.dygraph.jit.save(
                layer=model,
                path=args.model_save_prefix,
-                input_spec=[words, length],
+                input_spec=[input_specs[0], input_specs[-1]],
                output_spec=[crf_decode])
        else:
            fluid.dygraph.save_dygraph(model.state_dict(), args.dy_param_path)
...
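The test change above declares InputSpec shapes with dynamic (None) dimensions on the @declarative entry point, so to_static traces the program once for variable-length batches instead of re-tracing for every new input shape, which is what makes the unittest faster. Below is a minimal, self-contained illustration of the same idea using the public paddle.jit.to_static API; the Net layer and the [None, 8] shape are made up for the example.

```python
import paddle
from paddle.static import InputSpec


class Net(paddle.nn.Layer):
    def __init__(self):
        super().__init__()
        self.linear = paddle.nn.Linear(8, 4)

    # Declaring specs with dynamic (None) dims avoids re-tracing per shape.
    @paddle.jit.to_static(input_spec=[InputSpec(shape=[None, 8], dtype='float32')])
    def forward(self, x):
        return self.linear(x)


net = Net()
# Different batch sizes reuse the same traced static program.
out1 = net(paddle.randn([2, 8]))
out2 = net(paddle.randn([16, 8]))
print(out1.shape, out2.shape)
```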
@@ -358,7 +358,8 @@ class TestResnet(unittest.TestCase):
    def test_in_static_mode_mkldnn(self):
        fluid.set_flags({'FLAGS_use_mkldnn': True})
        try:
-            train(to_static=True)
+            if paddle.fluid.core.is_compiled_with_mkldnn():
+                train(to_static=True)
        finally:
            fluid.set_flags({'FLAGS_use_mkldnn': False})
...
@@ -358,7 +358,8 @@ class TestResnet(unittest.TestCase):
    def test_in_static_mode_mkldnn(self):
        paddle.fluid.set_flags({'FLAGS_use_mkldnn': True})
        try:
-            train(to_static=True)
+            if paddle.fluid.core.is_compiled_with_mkldnn():
+                train(to_static=True)
        finally:
            paddle.fluid.set_flags({'FLAGS_use_mkldnn': False})
...
@@ -79,7 +79,6 @@ disable_wingpu_test="^test_model$|\
^test_fuse_bn_add_act_pass$|\
^disable_wingpu_test$"

# /*============================================================================*/
# /*==================Fixed Disabled Windows CPU OPENBLAS unittests==============================*/
...