From f13c3a9cd722ddd27f6c2da7669e82ce6a3585cd Mon Sep 17 00:00:00 2001 From: liuyuhui Date: Wed, 16 Dec 2020 19:07:41 +0800 Subject: [PATCH] [Kunlun] PR1:Support one Kunlun card training in parallel executor (#29337) --- .../details/broadcast_op_handle_test.h | 8 +- .../framework/details/execution_strategy.h | 7 +- .../fast_threaded_ssa_graph_executor.cc | 2 +- .../details/fused_broadcast_op_handle_test.cc | 7 +- .../details/gather_op_handle_test.cc | 5 +- .../fluid/framework/details/op_handle_base.cc | 7 +- .../fluid/framework/details/op_handle_base.h | 3 +- .../details/reduce_op_handle_test.cc | 8 +- .../details/scale_loss_grad_op_handle.cc | 16 ++- .../details/threaded_ssa_graph_executor.cc | 2 +- ...est_reference_count_pass_last_lived_ops.cc | 3 +- paddle/fluid/framework/parallel_executor.cc | 100 +++++++++++------- paddle/fluid/platform/device_context.cc | 40 ++++--- paddle/fluid/pybind/pybind.cc | 23 ++-- python/paddle/fluid/compiler.py | 36 +++++-- python/paddle/fluid/framework.py | 47 ++++++++ .../test_ir_memory_optimize_ifelse_op.py | 2 +- .../tests/unittests/xpu/test_xpu_place.py | 47 ++++++++ python/paddle/static/__init__.py | 3 +- tools/wlist.json | 3 +- 20 files changed, 282 insertions(+), 87 deletions(-) create mode 100644 python/paddle/fluid/tests/unittests/xpu/test_xpu_place.py diff --git a/paddle/fluid/framework/details/broadcast_op_handle_test.h b/paddle/fluid/framework/details/broadcast_op_handle_test.h index 4fdc420e1e0..8272af9c7d2 100644 --- a/paddle/fluid/framework/details/broadcast_op_handle_test.h +++ b/paddle/fluid/framework/details/broadcast_op_handle_test.h @@ -33,6 +33,8 @@ struct VarHandle; namespace f = paddle::framework; namespace p = paddle::platform; +using UseDevice = paddle::framework::details::ExecutionStrategy::UseDevice; + // test data amount const f::DDim kDims = {20, 20}; @@ -273,7 +275,8 @@ struct TestBroadcastOpHandle { f::LoD lod{{0, 10, 20}}; auto send_vector = InitLoDTensor("input", input_scope_idx, lod); - op_handle_->Run(false); + UseDevice use_device = UseDevice::kCPU; + op_handle_->Run(use_device); WaitAll(); for (size_t j = 0; j < place_list_.size(); ++j) { @@ -287,7 +290,8 @@ struct TestBroadcastOpHandle { int height = static_cast(kDims[0] * 2); auto send_vector = InitSelectedRows("input", input_scope_idx, rows, height); - op_handle_->Run(false); + UseDevice use_device = UseDevice::kCPU; + op_handle_->Run(use_device); WaitAll(); for (size_t j = 0; j < place_list_.size(); ++j) { diff --git a/paddle/fluid/framework/details/execution_strategy.h b/paddle/fluid/framework/details/execution_strategy.h index a6936577c57..9d2341f134b 100644 --- a/paddle/fluid/framework/details/execution_strategy.h +++ b/paddle/fluid/framework/details/execution_strategy.h @@ -21,10 +21,15 @@ namespace details { struct ExecutionStrategy { enum ExecutorType { kDefault = 0, kExperimental = 1 }; + enum UseDevice { + kCPU = 0, + kCUDA = 1, + kXPU = 2, + }; // num_threads indicates the size of thread pool. size_t num_threads_{0}; - bool use_cuda_{true}; + UseDevice use_device_{kCUDA}; // Note that allow_op_delay is invalid now. 
bool allow_op_delay_{false}; // num_iteration_per_drop_scope indicates how many diff --git a/paddle/fluid/framework/details/fast_threaded_ssa_graph_executor.cc b/paddle/fluid/framework/details/fast_threaded_ssa_graph_executor.cc index 18f2332b6ef..e13059e36d3 100644 --- a/paddle/fluid/framework/details/fast_threaded_ssa_graph_executor.cc +++ b/paddle/fluid/framework/details/fast_threaded_ssa_graph_executor.cc @@ -330,7 +330,7 @@ bool FastThreadedSSAGraphExecutor::RunOpSync(OpHandleBase *op) { try { VLOG(10) << op << " " << op->Name() << " : " << op->DebugString(); if (LIKELY(!strategy_.dry_run_)) { - op->Run(strategy_.use_cuda_); + op->Run(strategy_.use_device_); } VLOG(10) << op << " " << op->Name() << " Done "; return true; diff --git a/paddle/fluid/framework/details/fused_broadcast_op_handle_test.cc b/paddle/fluid/framework/details/fused_broadcast_op_handle_test.cc index 8b1fb4c7996..600651dc162 100644 --- a/paddle/fluid/framework/details/fused_broadcast_op_handle_test.cc +++ b/paddle/fluid/framework/details/fused_broadcast_op_handle_test.cc @@ -32,6 +32,7 @@ namespace framework { namespace details { struct VarHandle; +using UseDevice = paddle::framework::details::ExecutionStrategy::UseDevice; struct TestFusedBroadcastOpHandle : TestBroadcastOpHandle { std::vector out_varnames_; @@ -108,7 +109,8 @@ struct TestFusedBroadcastOpHandle : TestBroadcastOpHandle { InitLoDTensor(varname, input_scope_idxes[i], lod, val_scalar)); } - op_handle_->Run(false); + UseDevice use_device = UseDevice::kCPU; + op_handle_->Run(use_device); WaitAll(); for (size_t i = 0; i < input_scope_idxes.size(); ++i) { @@ -131,7 +133,8 @@ struct TestFusedBroadcastOpHandle : TestBroadcastOpHandle { rows, height, val_scalar)); } - op_handle_->Run(false); + UseDevice use_device = UseDevice::kCPU; + op_handle_->Run(use_device); WaitAll(); for (size_t i = 0; i < input_scope_idxes.size(); ++i) { diff --git a/paddle/fluid/framework/details/gather_op_handle_test.cc b/paddle/fluid/framework/details/gather_op_handle_test.cc index 60c1d0d39a5..34d61c901db 100644 --- a/paddle/fluid/framework/details/gather_op_handle_test.cc +++ b/paddle/fluid/framework/details/gather_op_handle_test.cc @@ -27,6 +27,8 @@ struct DummyVarHandle; namespace f = paddle::framework; namespace p = paddle::platform; +using UseDevice = paddle::framework::details::ExecutionStrategy::UseDevice; + // test data amount const f::DDim kDims = {20, 20}; @@ -171,7 +173,8 @@ struct TestGatherOpHandle { out_selected_rows->mutable_value()->ShareDataWith( in_selected_rows->value()); - op_handle_->Run(false); + UseDevice use_device = UseDevice::kCPU; + op_handle_->Run(use_device); WaitAll(); diff --git a/paddle/fluid/framework/details/op_handle_base.cc b/paddle/fluid/framework/details/op_handle_base.cc index 22b7bd17fe4..859cd769caa 100644 --- a/paddle/fluid/framework/details/op_handle_base.cc +++ b/paddle/fluid/framework/details/op_handle_base.cc @@ -85,13 +85,14 @@ void OpHandleBase::InitCUDA() { #endif } -void OpHandleBase::Run(bool use_cuda) { +void OpHandleBase::Run(ExecutionStrategy::UseDevice use_device) { #ifdef PADDLE_WITH_CUDA - if (events_.empty() && use_cuda && dev_ctxes_.size() > 0) { + if (events_.empty() && use_device == ExecutionStrategy::UseDevice::kCUDA && + dev_ctxes_.size() > 0) { InitCUDA(); } #else - PADDLE_ENFORCE_EQ(use_cuda, false, + PADDLE_ENFORCE_NE(use_device, ExecutionStrategy::UseDevice::kCUDA, platform::errors::InvalidArgument( "Argument use_cuda should be false when Paddle is not " "compiled with CUDA.")); diff --git 
a/paddle/fluid/framework/details/op_handle_base.h b/paddle/fluid/framework/details/op_handle_base.h index 37e18adf9da..68c75c2d7ac 100644 --- a/paddle/fluid/framework/details/op_handle_base.h +++ b/paddle/fluid/framework/details/op_handle_base.h @@ -19,6 +19,7 @@ #include #include +#include "paddle/fluid/framework/details/execution_strategy.h" #include "paddle/fluid/framework/details/var_handle.h" #include "paddle/fluid/framework/ir/node.h" #include "paddle/fluid/platform/device_context.h" @@ -71,7 +72,7 @@ class OpHandleBase { virtual std::string Name() const = 0; - void Run(bool use_cuda); + void Run(ExecutionStrategy::UseDevice use_device); virtual void RecordWaitEventOnCtx(platform::DeviceContext *waited_ctx); diff --git a/paddle/fluid/framework/details/reduce_op_handle_test.cc b/paddle/fluid/framework/details/reduce_op_handle_test.cc index ba03c3a267a..ae30474cfa0 100644 --- a/paddle/fluid/framework/details/reduce_op_handle_test.cc +++ b/paddle/fluid/framework/details/reduce_op_handle_test.cc @@ -25,6 +25,8 @@ namespace details { namespace f = paddle::framework; namespace p = paddle::platform; +using UseDevice = paddle::framework::details::ExecutionStrategy::UseDevice; + // test data amount const f::DDim kDims = {20, 20}; @@ -196,7 +198,8 @@ struct TestReduceOpHandle { out_selected_rows->mutable_value()->ShareDataWith( in_selected_rows->value()); - op_handle_->Run(false); + UseDevice use_device = UseDevice::kCPU; + op_handle_->Run(use_device); WaitAll(); @@ -260,7 +263,8 @@ struct TestReduceOpHandle { out_lodtensor->ShareDataWith(in_lodtensor); - op_handle_->Run(false); + UseDevice use_device = UseDevice::kCPU; + op_handle_->Run(use_device); WaitAll(); diff --git a/paddle/fluid/framework/details/scale_loss_grad_op_handle.cc b/paddle/fluid/framework/details/scale_loss_grad_op_handle.cc index 287667d5ee9..aa32a248e7f 100644 --- a/paddle/fluid/framework/details/scale_loss_grad_op_handle.cc +++ b/paddle/fluid/framework/details/scale_loss_grad_op_handle.cc @@ -58,6 +58,17 @@ struct ScaleLossGradFunctor { auto *out_data = out_->mutable_data(place_); if (platform::is_cpu_place(place_)) { *out_data = static_cast(coeff_); + } else if (platform::is_xpu_place(place_)) { +#if defined(PADDLE_WITH_XPU) + OutT cast_coeff = static_cast(coeff_); + memory::Copy(BOOST_GET_CONST(platform::XPUPlace, place_), out_data, + platform::CPUPlace(), &cast_coeff, SizeOfType(out_dtype_)); + VLOG(10) << place_ << "RUN Scale loss grad op"; +#else + PADDLE_THROW(platform::errors::PermissionDenied( + "Paddle can't use XPU device since it's not compiled with XPU," + "Please recompile or reinstall Paddle with XPU support.")); +#endif } else { #ifdef PADDLE_WITH_CUDA OutT cast_coeff = static_cast(coeff_); @@ -66,7 +77,10 @@ struct ScaleLossGradFunctor { platform::CPUPlace(), &cast_coeff, SizeOfType(out_dtype_), stream); VLOG(10) << place_ << "RUN Scale loss grad op"; - +#else + PADDLE_THROW(platform::errors::PermissionDenied( + "Paddle can't use CUDA device since it's not compiled with CUDA," + "Please recompile or reinstall Paddle with GPU support.")); #endif } } diff --git a/paddle/fluid/framework/details/threaded_ssa_graph_executor.cc b/paddle/fluid/framework/details/threaded_ssa_graph_executor.cc index 2ed52b3bd94..08328e25fa9 100644 --- a/paddle/fluid/framework/details/threaded_ssa_graph_executor.cc +++ b/paddle/fluid/framework/details/threaded_ssa_graph_executor.cc @@ -348,7 +348,7 @@ bool ThreadedSSAGraphExecutor::RunOpSync(OpHandleBase *op) { try { VLOG(10) << op << " " << op->Name() << " : " << op->DebugString(); 
if (LIKELY(!strategy_.dry_run_)) { - op->Run(strategy_.use_cuda_); + op->Run(strategy_.use_device_); } VLOG(10) << op << " " << op->Name() << " Done "; return true; diff --git a/paddle/fluid/framework/ir/memory_optimize_pass/test_reference_count_pass_last_lived_ops.cc b/paddle/fluid/framework/ir/memory_optimize_pass/test_reference_count_pass_last_lived_ops.cc index 94274808524..4fb7f00d1bf 100644 --- a/paddle/fluid/framework/ir/memory_optimize_pass/test_reference_count_pass_last_lived_ops.cc +++ b/paddle/fluid/framework/ir/memory_optimize_pass/test_reference_count_pass_last_lived_ops.cc @@ -88,7 +88,8 @@ class ReferenceCountPassTestHelper { FLAGS_eager_delete_tensor_gb = -1; details::ExecutionStrategy exec_strategy; - exec_strategy.use_cuda_ = use_cuda; + exec_strategy.use_device_ = + use_cuda ? (ExecutionStrategy::kCUDA) : (ExecutionStrategy::kCPU); executor_.reset(new ParallelExecutor(CreatePlaces(1, use_cuda), {}, "", &scope_, {}, exec_strategy, diff --git a/paddle/fluid/framework/parallel_executor.cc b/paddle/fluid/framework/parallel_executor.cc index 579733c2a3a..3a621e64bff 100644 --- a/paddle/fluid/framework/parallel_executor.cc +++ b/paddle/fluid/framework/parallel_executor.cc @@ -63,6 +63,8 @@ static bool gProfileStarted = false; std::once_flag p2p_init_flag; #endif +using UseDevice = paddle::framework::details::ExecutionStrategy::UseDevice; + class ParallelExecutorPrivate { public: ParallelExecutorPrivate(const std::vector &places, @@ -93,6 +95,8 @@ class ParallelExecutorPrivate { } } + bool IsUseCUDA(UseDevice use_device); + void SetHasFeed(size_t dev_idx, bool has_feed = true); bool AllowPartialFeed() const; @@ -286,7 +290,7 @@ class ParallelExecutorPrivate { platform::NCCLCommunicator *nccl_ctxs_{nullptr}; #endif bool own_local_scope_; - bool use_cuda_; + UseDevice use_device_; bool use_all_reduce_; size_t nranks_; @@ -296,6 +300,10 @@ class ParallelExecutorPrivate { details::ParallelSSAGraphExecutor *inference_executor_{nullptr}; }; +bool ParallelExecutorPrivate::IsUseCUDA(UseDevice use_device) { + return use_device == UseDevice::kCUDA; +} + void ParallelExecutorPrivate::SetHasFeed(size_t dev_idx, bool has_feed) { if (inference_executor_) { inference_executor_->SetHasFeed(dev_idx, has_feed); @@ -340,7 +348,7 @@ ir::Graph *ParallelExecutorPrivate::ApplyMemoryOptimizePass(ir::Graph *graph) { auto addto_pass = ir::PassRegistry::Instance().Get("inplace_addto_op_pass"); addto_pass->SetNotOwned(ir::kMemOptVarInfoMapList, &mem_opt_var_infos_); addto_pass->SetNotOwned(ir::kLastLiveOpsOfVars, &last_live_ops_of_vars); - addto_pass->SetNotOwned(ir::kUseCuda, &use_cuda_); + addto_pass->Set(ir::kUseCuda, new bool(use_device_ == UseDevice::kCUDA)); VLOG(10) << "Start to apply inplace_addto_op_pass"; graph = addto_pass->Apply(graph); VLOG(10) << "inplace_addto_op_pass Applied"; @@ -351,7 +359,7 @@ ir::Graph *ParallelExecutorPrivate::ApplyMemoryOptimizePass(ir::Graph *graph) { ir::PassRegistry::Instance().Get("buffer_shared_inplace_pass"); inplace_pass->SetNotOwned(ir::kMemOptVarInfoMapList, &mem_opt_var_infos_); inplace_pass->SetNotOwned(ir::kLastLiveOpsOfVars, &last_live_ops_of_vars); - inplace_pass->SetNotOwned(ir::kUseCuda, &use_cuda_); + inplace_pass->Set(ir::kUseCuda, new bool(use_device_ == UseDevice::kCUDA)); VLOG(10) << "Start to apply buffer_shared_inplace_pass"; graph = inplace_pass->Apply(graph); VLOG(10) << "buffer_shared_inplace_pass Applied"; @@ -366,7 +374,8 @@ ir::Graph *ParallelExecutorPrivate::ApplyMemoryOptimizePass(ir::Graph *graph) { &mem_opt_var_infos_); 
cross_op_memory_reuse_pass->SetNotOwned(ir::kLastLiveOpsOfVars, &last_live_ops_of_vars); - cross_op_memory_reuse_pass->SetNotOwned(ir::kUseCuda, &use_cuda_); + cross_op_memory_reuse_pass->Set(ir::kUseCuda, + new bool(use_device_ == UseDevice::kCUDA)); VLOG(10) << "Start to apply buffer_shared_cross_op_memory_reuse_pass"; graph = cross_op_memory_reuse_pass->Apply(graph); VLOG(10) << "buffer_shared_cross_op_memory_reuse_pass Applied"; @@ -386,8 +395,8 @@ ir::Graph *ParallelExecutorPrivate::ApplyMemoryOptimizePass(ir::Graph *graph) { continue; } std::unique_ptr gc; -#ifdef PADDLE_WITH_CUDA if (platform::is_gpu_place(place)) { +#ifdef PADDLE_WITH_CUDA if (IsFastEagerDeletionModeEnabled()) { gc.reset(new UnsafeFastGPUGarbageCollector( BOOST_GET_CONST(platform::CUDAPlace, place), max_memory_size)); @@ -396,20 +405,29 @@ ir::Graph *ParallelExecutorPrivate::ApplyMemoryOptimizePass(ir::Graph *graph) { BOOST_GET_CONST(platform::CUDAPlace, place), max_memory_size)); } VLOG(10) << "Created " << i << "-th GarbageCollector at " << place; - } else { +#else + PADDLE_THROW(platform::errors::PermissionDenied( + "Paddle can't use CUDA device since it's not compiled with CUDA," + "Please recompile or reinstall Paddle with GPU support.")); #endif - if (platform::is_cpu_place(place)) { - gc.reset(new CPUGarbageCollector( - BOOST_GET_CONST(platform::CPUPlace, place), max_memory_size)); - VLOG(10) << "Created GarbageCollector at " << place; - } else { - PADDLE_THROW(platform::errors::PreconditionNotMet( - "Unsupported place for garbage collection")); - } -#ifdef PADDLE_WITH_CUDA - } + } else if (platform::is_xpu_place(place)) { +#if defined(PADDLE_WITH_XPU) + gc.reset(new XPUGarbageCollector( + BOOST_GET_CONST(platform::XPUPlace, place), max_memory_size)); + VLOG(10) << "Created " << i << "-th GarbageCollector at " << place; +#else + PADDLE_THROW(platform::errors::PermissionDenied( + "Paddle can't use XPU device since it's not compiled with XPU," + "Please recompile or reinstall Paddle with XPU support.")); #endif - + } else if (platform::is_cpu_place(place)) { + gc.reset(new CPUGarbageCollector( + BOOST_GET_CONST(platform::CPUPlace, place), max_memory_size)); + VLOG(10) << "Created GarbageCollector at " << place; + } else { + PADDLE_THROW(platform::errors::PreconditionNotMet( + "Unsupported place for garbage collection")); + } gcs_.emplace(place, std::move(gc)); } @@ -510,13 +528,10 @@ ParallelExecutor::ParallelExecutor(const std::vector &places, const BuildStrategy &build_strategy, ir::Graph *graph) : member_(new ParallelExecutorPrivate(places, scope)) { - PADDLE_ENFORCE(places.size() > 0 && !is_xpu_place(places[0]), - platform::errors::Unavailable( - "XPU is not supported in ParallelExecutor")); InitP2P(places); ir::InitReaderQueueDeviceCount(graph, *(member_->global_scope_), member_->places_.size()); - member_->use_cuda_ = exec_strategy.use_cuda_; + member_->use_device_ = exec_strategy.use_device_; member_->build_strategy_ = build_strategy; member_->use_all_reduce_ = member_->build_strategy_.reduce_ == BuildStrategy::ReduceStrategy::kAllReduce; @@ -529,7 +544,7 @@ ParallelExecutor::ParallelExecutor(const std::vector &places, member_->use_all_reduce_ = true; } #if defined(PADDLE_WITH_CUDA) && defined(_WIN32) - if (member_->use_cuda_) { + if (member_->IsUseCUDA(member_->use_device_)) { PADDLE_ENFORCE_EQ( places.size(), 1, platform::errors::Unavailable("Windows can support Single GPU only.")); @@ -537,7 +552,7 @@ ParallelExecutor::ParallelExecutor(const std::vector &places, #endif #if defined(PADDLE_WITH_CUDA) 
&& !defined(PADDLE_WITH_NCCL) - if (member_->use_cuda_) { + if (member_->IsUseCUDA(member_->use_device_)) { PADDLE_ENFORCE_EQ( places.size(), 1, platform::errors::PermissionDenied( @@ -548,10 +563,19 @@ ParallelExecutor::ParallelExecutor(const std::vector &places, } #endif + std::string device_name; + if (member_->use_device_ == UseDevice::kCPU) { + device_name = "CPU"; + } else if (member_->use_device_ == UseDevice::kCUDA) { + device_name = "CUDA"; + } else { + device_name = "XPU"; + } + VLOG(1) << string::Sprintf( "The Program will be executed on %s using ParallelExecutor, %lu " "cards are used, so %lu programs are executed in parallel.", - (member_->use_cuda_ ? "CUDA" : "CPU"), places.size(), places.size()); + device_name, places.size(), places.size()); // Step 1. Bcast the bcast_vars to devs. // Create local scopes @@ -575,7 +599,7 @@ ParallelExecutor::ParallelExecutor(const std::vector &places, std::vector graphs; if (member_->build_strategy_.async_mode_) { - PADDLE_ENFORCE_EQ(member_->use_cuda_, false, + PADDLE_ENFORCE_EQ(member_->IsUseCUDA(member_->use_device_), false, platform::errors::Unavailable( "gpu mode does not support async_mode_ now!")); graphs.push_back(graph); @@ -598,7 +622,7 @@ ParallelExecutor::ParallelExecutor(const std::vector &places, << "you can force it off by env FLAGS_enable_parallel_graph=0"; } - if (member_->use_cuda_ && member_->nranks_ > 1) { + if (member_->IsUseCUDA(member_->use_device_) && member_->nranks_ > 1) { #if defined(PADDLE_WITH_NCCL) member_->InitOrGetNCCLCommunicator(scope, &member_->build_strategy_); @@ -647,36 +671,39 @@ ParallelExecutor::ParallelExecutor(const std::vector &places, VLOG(3) << "use local async mode"; graph = member_->build_strategy_.Apply( graph, {member_->places_[0]}, loss_var_name, - {member_->local_scopes_[0]}, 1, member_->use_cuda_, - member_->nccl_ctxs_); + {member_->local_scopes_[0]}, 1, + member_->IsUseCUDA(member_->use_device_), member_->nccl_ctxs_); for (size_t i = 1; i < member_->places_.size(); ++i) { graphs[i] = member_->build_strategy_.Apply( graphs[i], {member_->places_[i]}, loss_var_name, - {member_->local_scopes_[i]}, 1, member_->use_cuda_, - member_->nccl_ctxs_); + {member_->local_scopes_[i]}, 1, + member_->IsUseCUDA(member_->use_device_), member_->nccl_ctxs_); async_graphs[i] = graphs[i]; } } else { graph = member_->build_strategy_.Apply( graph, member_->places_, loss_var_name, member_->local_scopes_, - member_->nranks_, member_->use_cuda_, member_->nccl_ctxs_); + member_->nranks_, member_->IsUseCUDA(member_->use_device_), + member_->nccl_ctxs_); } #else if (member_->build_strategy_.async_mode_) { VLOG(3) << "use local async mode"; graph = member_->build_strategy_.Apply( graph, {member_->places_[0]}, loss_var_name, - {member_->local_scopes_[0]}, 1, member_->use_cuda_); + {member_->local_scopes_[0]}, 1, + member_->IsUseCUDA(member_->use_device_)); for (size_t i = 1; i < member_->places_.size(); ++i) { graphs[i] = member_->build_strategy_.Apply( graphs[i], {member_->places_[i]}, loss_var_name, - {member_->local_scopes_[i]}, 1, member_->use_cuda_); + {member_->local_scopes_[i]}, 1, + member_->IsUseCUDA(member_->use_device_)); async_graphs[i] = graphs[i]; } } else { graph = member_->build_strategy_.Apply( graph, member_->places_, loss_var_name, member_->local_scopes_, - member_->nranks_, member_->use_cuda_); + member_->nranks_, member_->IsUseCUDA(member_->use_device_)); } #endif @@ -874,7 +901,8 @@ void ParallelExecutor::BCastParamsToDevices( // FIXME(zcd): LR_DECAY_COUNTER should not be shared. This is a hot fix. 
if (member_->build_strategy_.async_mode_) { share_memory(); - } else if (member_->use_all_reduce_ || member_->use_cuda_ || + } else if (member_->use_all_reduce_ || + member_->IsUseCUDA(member_->use_device_) || var == "@LR_DECAY_COUNTER@") { copy_memory(); } else { @@ -1105,7 +1133,7 @@ bool ParallelExecutor::EnableParallelGraphExecution( } } - if (!member_->use_all_reduce_ || !member_->use_cuda_) { + if (!member_->use_all_reduce_ || !member_->IsUseCUDA(member_->use_device_)) { if (build_strategy.enable_sequential_execution_ || exec_strategy.type_ == ExecutionStrategy::ExecutorType::kExperimental) { enable_parallel_graph = false; diff --git a/paddle/fluid/platform/device_context.cc b/paddle/fluid/platform/device_context.cc index beb1db93f48..61a60383b93 100644 --- a/paddle/fluid/platform/device_context.cc +++ b/paddle/fluid/platform/device_context.cc @@ -29,23 +29,39 @@ namespace memory { AllocationPtr Alloc(const platform::DeviceContext& dev_ctx, size_t size) { auto place = dev_ctx.GetPlace(); -#ifdef PADDLE_WITH_CUDA - if (size == 0 || !platform::is_gpu_place(place)) { + if (size == 0) { return Alloc(place, size); } - auto* default_dev_ctx = static_cast( - platform::DeviceContextPool::Instance().Get(place)); - auto& desired_dev_ctx = - static_cast(dev_ctx); - if (default_dev_ctx->stream() == desired_dev_ctx.stream()) { + + if (platform::is_gpu_place(place)) { +#ifdef PADDLE_WITH_CUDA + auto* default_dev_ctx = static_cast( + platform::DeviceContextPool::Instance().Get(place)); + auto& desired_dev_ctx = + static_cast(dev_ctx); + if (default_dev_ctx->stream() == desired_dev_ctx.stream()) { + return Alloc(place, size); + } else { + return allocation::CUDADeviceContextAllocatorPool::Instance().Alloc( + desired_dev_ctx, size); + } +#else + PADDLE_THROW(platform::errors::PermissionDenied( + "Paddle can't use CUDA device since it's not compiled with CUDA," + "Please recompile or reinstall Paddle with GPU support.")); +#endif + } else if (platform::is_xpu_place(place)) { +#ifdef PADDLE_WITH_XPU + // TODO(liuyuhui): Consider xpu stream later return Alloc(place, size); - } else { - return allocation::CUDADeviceContextAllocatorPool::Instance().Alloc( - desired_dev_ctx, size); - } #else - return Alloc(place, size); + PADDLE_THROW(platform::errors::PermissionDenied( + "Paddle can't use XPU device since it's not compiled with XPU," + "Please recompile or reinstall Paddle with XPU support.")); #endif + } else { + return Alloc(place, size); + } } } // namespace memory diff --git a/paddle/fluid/pybind/pybind.cc b/paddle/fluid/pybind/pybind.cc index 44b5614b9a1..5cefb26a4a3 100644 --- a/paddle/fluid/pybind/pybind.cc +++ b/paddle/fluid/pybind/pybind.cc @@ -1492,7 +1492,9 @@ All parameter, weight, gradient are variables in Paddle. #endif .def("__repr__", string::to_string) .def("__str__", string::to_string); - +#ifdef PADDLE_WITH_XPU + m.def("get_xpu_device_count", platform::GetXPUDeviceCount); +#endif py::class_(m, "CPUPlace", R"DOC( CPUPlace is a descriptor of a device. It represents a CPU device on which a tensor will be allocated and a model will run. @@ -2077,6 +2079,11 @@ All parameter, weight, gradient are variables in Paddle. exec_strategy=exec_strategy) )DOC"); + py::enum_(exec_strategy, "UseDevice") + .value("CPU", ExecutionStrategy::UseDevice::kCPU) + .value("CUDA", ExecutionStrategy::UseDevice::kCUDA) + .value("XPU", ExecutionStrategy::UseDevice::kXPU); + exec_strategy.def(py::init()) .def_property( "num_threads", @@ -2107,14 +2114,12 @@ All parameter, weight, gradient are variables in Paddle. 
exec_strategy.num_threads = 4 )DOC") .def_property( - "use_cuda", - [](const ExecutionStrategy &self) { return self.use_cuda_; }, - [](ExecutionStrategy &self, bool use_cuda) { - self.use_cuda_ = use_cuda; - }) // FIXME(chengduo): Doesn't add doc for 'use_cuda', use_cuda may - // make user confuse, because ParallelExecutor has a parameter named - // 'use_cuda' too, in current implementation, ParallelExecutor's - // 'use_cuda' will rewrite ExecutionStrategy's 'use_cuda'. + "_use_device", + [](const ExecutionStrategy &self) { return self.use_device_; }, + [](ExecutionStrategy &self, ExecutionStrategy::UseDevice use_device) { + self.use_device_ = use_device; + }) // NOTE(liuyuhui): Doesn't add doc for 'use_device', because + // use_device isn‘t exposed to users. .def_property( "allow_op_delay", [](const ExecutionStrategy &self) { return self.allow_op_delay_; }, diff --git a/python/paddle/fluid/compiler.py b/python/paddle/fluid/compiler.py index 0b980c7ebab..c47ad7b1087 100644 --- a/python/paddle/fluid/compiler.py +++ b/python/paddle/fluid/compiler.py @@ -18,7 +18,7 @@ import six import sys from .. import compat as cpt from . import framework -from .framework import cuda_places, cpu_places +from .framework import cuda_places, cpu_places, xpu_places from . import core @@ -316,7 +316,7 @@ class CompiledProgram(object): "Subclass of CompiledProgram should implement _with_distributed method." ) - def _compile_data_parallel(self, places, use_cuda=False, scope=None): + def _compile_data_parallel(self, places, use_device, scope=None): if self._share_vars_from: if scope: sys.stderr.write("share_vars_from is set, scope is ignored.\n") @@ -342,16 +342,23 @@ class CompiledProgram(object): if self._exec_strategy is None: self._exec_strategy = ExecutionStrategy() - self._exec_strategy.use_cuda = use_cuda + self._exec_strategy._use_device = use_device if self._exec_strategy.num_threads == 0: - if self._exec_strategy.use_cuda: + if self._exec_strategy._use_device == ExecutionStrategy.UseDevice.CUDA: # Experiments on se-resnext shows that too many threads hurt # performance. Worth tunning for other models in the future. self._exec_strategy.num_threads = len(places) * 4 + elif self._exec_strategy._use_device == ExecutionStrategy.UseDevice.XPU: + # Currently only single thread is supported in Kunlun XPU. + self._exec_strategy.num_threads = 1 else: self._exec_strategy.num_threads = len(places) * 2 + if self._exec_strategy._use_device == ExecutionStrategy.UseDevice.XPU: + assert self._exec_strategy.num_threads == 1, \ + "Currently only single thread is supported in Kunlun XPU." + if self._build_strategy.num_trainers > 1: assert self._is_data_parallel, \ "If you use multi-trainer to train the model, you should use "\ @@ -377,7 +384,7 @@ class CompiledProgram(object): self._build_strategy.enable_sequential_execution = True if self._program is not None and self._program._enable_dgc: - assert use_cuda, "DGC only used under CUDA environment." + assert self._exec_strategy._use_device == ExecutionStrategy.UseDevice.CUDA, "DGC only used under CUDA environment." assert self._build_strategy.num_trainers * len( places) > 1, "DGC is not avaliable for single card training." 
assert self._build_strategy.reduce_strategy == BuildStrategy.ReduceStrategy.AllReduce, "DGC \ @@ -447,11 +454,14 @@ class CompiledProgram(object): raise NotImplementedError( "If optimizer is used in control flow, " "training on multi-places is not supported now.") - + if isinstance(self._place, core.CUDAPlace): + use_device = ExecutionStrategy.UseDevice.CUDA + elif isinstance(self._place, core.XPUPlace): + use_device = ExecutionStrategy.UseDevice.XPU + else: + use_device = ExecutionStrategy.UseDevice.CPU self._executor = self._compile_data_parallel( - use_cuda=isinstance(self._place, core.CUDAPlace), - scope=self._scope, - places=self._places) + use_device=use_device, scope=self._scope, places=self._places) return self def _get_places(self, place, place_list): @@ -461,7 +471,11 @@ class CompiledProgram(object): assert p._type() == place._type(), \ "Place type not match. You may set wrong type of places." else: - place_list = cuda_places() if isinstance( - place, core.CUDAPlace) else cpu_places() + if isinstance(place, core.CUDAPlace): + place_list = cuda_places() + elif isinstance(place, core.XPUPlace): + place_list = xpu_places() + else: + place_list = cpu_places() assert place_list, "No places for execution." return place_list diff --git a/python/paddle/fluid/framework.py b/python/paddle/fluid/framework.py index 6f1a5e61777..a0e650e4da3 100644 --- a/python/paddle/fluid/framework.py +++ b/python/paddle/fluid/framework.py @@ -47,6 +47,7 @@ __all__ = [ 'name_scope', 'cuda_places', 'cpu_places', + 'xpu_places', 'cuda_pinned_places', 'in_dygraph_mode', 'is_compiled_with_cuda', @@ -354,6 +355,15 @@ def _cuda_ids(): return device_ids +def _xpu_ids(): + xpus_env = os.getenv("FLAGS_selected_xpus") + if xpus_env: + device_ids = [int(s) for s in xpus_env.split(",")] + else: + device_ids = six.moves.range(core.get_xpu_device_count()) + return device_ids + + def is_compiled_with_xpu(): """ Whether this whl package can be used to run the model on XPU. @@ -430,6 +440,43 @@ def cuda_places(device_ids=None): return [core.CUDAPlace(dev_id) for dev_id in device_ids] +def xpu_places(device_ids=None): + """ + **Note**: + For multi-card tasks, please use `FLAGS_selected_xpus` environment variable to set the visible XPU device. + This function creates a list of :code:`paddle.XPUPlace` objects. + If :code:`device_ids` is None, environment variable of + :code:`FLAGS_selected_xpus` would be checked first. For example, if + :code:`FLAGS_selected_xpus=0,1,2`, the returned list would + be [paddle.XPUPlace(0), paddle.XPUPlace(1), paddle.XPUPlace(2)]. + If :code:`FLAGS_selected_xpus` is not set, all visible + xpu places would be returned. + If :code:`device_ids` is not None, it should be the device + ids of XPUs. For example, if :code:`device_ids=[0,1,2]`, + the returned list would be + [paddle.XPUPlace(0), paddle.XPUPlace(1), paddle.XPUPlace(2)]. + + Parameters: + device_ids (list or tuple of int, optional): list of XPU device ids. + Returns: + list of paddle.XPUPlace: Created XPU place list. + Examples: + .. 
code-block:: python + import paddle + import paddle.static as static + + paddle.enable_static() + xpu_places = static.xpu_places() + """ + assert core.is_compiled_with_xpu(), \ + "Not compiled with XPU" + if device_ids is None: + device_ids = _xpu_ids() + elif not isinstance(device_ids, (list, tuple)): + device_ids = [device_ids] + return [core.XPUPlace(dev_id) for dev_id in device_ids] + + def cpu_places(device_count=None): """ This function creates a list of :code:`paddle.CPUPlace` objects, and returns the created list. diff --git a/python/paddle/fluid/tests/unittests/test_ir_memory_optimize_ifelse_op.py b/python/paddle/fluid/tests/unittests/test_ir_memory_optimize_ifelse_op.py index 0ace288d9d4..a4e234a5134 100644 --- a/python/paddle/fluid/tests/unittests/test_ir_memory_optimize_ifelse_op.py +++ b/python/paddle/fluid/tests/unittests/test_ir_memory_optimize_ifelse_op.py @@ -75,7 +75,7 @@ class TestIrMemoryOptimizeIfElseOp(unittest.TestCase): exe = Executor(place) exec_strategy = fluid.ExecutionStrategy() - exec_strategy.use_cuda = use_cuda + exec_strategy._use_device = fluid.ExecutionStrategy.UseDevice.CUDA if use_cuda else fluid.ExecutionStrategy.UseDevice.CPU build_strategy = fluid.BuildStrategy() build_strategy.memory_optimize = use_mem_opt diff --git a/python/paddle/fluid/tests/unittests/xpu/test_xpu_place.py b/python/paddle/fluid/tests/unittests/xpu/test_xpu_place.py new file mode 100644 index 00000000000..57d456d0193 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/xpu/test_xpu_place.py @@ -0,0 +1,47 @@ +# copyright (c) 2020 paddlepaddle authors. all rights reserved. +# +# licensed under the apache license, version 2.0 (the "license"); +# you may not use this file except in compliance with the license. +# you may obtain a copy of the license at +# +# http://www.apache.org/licenses/license-2.0 +# +# unless required by applicable law or agreed to in writing, software +# distributed under the license is distributed on an "as is" basis, +# without warranties or conditions of any kind, either express or implied. +# see the license for the specific language governing permissions and +# limitations under the license. 
+ +from __future__ import print_function + +import unittest +import os +import paddle +import numpy as np +import paddle.fluid as fluid +from paddle.fluid import core +import paddle.static as static + + +class Test_XPU_Places(unittest.TestCase): + def assert_places_equal(self, places0, places1): + self.assertEqual(len(places0), len(places1)) + for place0, place1 in zip(places0, places1): + self.assertEqual(type(place0), type(place1)) + self.assertEqual(place0.get_device_id(), place1.get_device_id()) + + def test_check_preset_envs(self): + if core.is_compiled_with_xpu(): + os.environ["FLAGS_selected_xpus"] = "0" + place_list = static.xpu_places() + self.assert_places_equal([fluid.XPUPlace(0)], place_list) + + def test_check_no_preset_envs(self): + if core.is_compiled_with_xpu(): + place_list = static.xpu_places(0) + self.assert_places_equal([fluid.XPUPlace(0)], place_list) + + +if __name__ == '__main__': + paddle.enable_static() + unittest.main() diff --git a/python/paddle/static/__init__.py b/python/paddle/static/__init__.py index 9c911e722db..e37a6162af3 100644 --- a/python/paddle/static/__init__.py +++ b/python/paddle/static/__init__.py @@ -20,7 +20,7 @@ __all__ = [ 'default_main_program', 'default_startup_program', 'Program', 'data', 'InputSpec', 'save', 'load', 'save_inference_model', 'load_inference_model', 'load_program_state', 'set_program_state', 'cpu_places', 'cuda_places', - 'Variable' + 'xpu_places', 'Variable' ] from . import nn @@ -45,6 +45,7 @@ from ..fluid.framework import name_scope #DEFINE_ALIAS from ..fluid.framework import program_guard #DEFINE_ALIAS from ..fluid.framework import cpu_places #DEFINE_ALIAS from ..fluid.framework import cuda_places #DEFINE_ALIAS +from ..fluid.framework import xpu_places #DEFINE_ALIAS from ..fluid.framework import Variable #DEFINE_ALIAS from ..fluid.layers.control_flow import Print #DEFINE_ALIAS from ..fluid.layers.nn import py_func #DEFINE_ALIAS diff --git a/tools/wlist.json b/tools/wlist.json index a51ac905e66..f907d609898 100644 --- a/tools/wlist.json +++ b/tools/wlist.json @@ -413,7 +413,8 @@ "CRFDecoding.forward", "SequenceTagging.forward", "XPUPlace", - "is_compiled_with_xpu" + "is_compiled_with_xpu", + "xpu_places" ], "gpu_not_white":[ "deformable_conv", -- GitLab
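
The patch above only adds the plumbing, so as a quick orientation here is a minimal, hypothetical usage sketch (not part of this PR) of the single-Kunlun-card path it enables from Python. It assumes a Paddle wheel compiled with PADDLE_WITH_XPU and one visible Kunlun card; the toy network, data, and hyperparameters are placeholders, while xpu_places(), XPUPlace, and the CompiledProgram routing to ExecutionStrategy::UseDevice::kXPU come from the code touched in this patch.

import os
import numpy as np
import paddle
import paddle.fluid as fluid
import paddle.static as static

paddle.enable_static()

# The new xpu_places() helper reads FLAGS_selected_xpus when no ids are given.
os.environ['FLAGS_selected_xpus'] = '0'

main_prog = fluid.Program()
startup_prog = fluid.Program()
with fluid.program_guard(main_prog, startup_prog):
    x = fluid.data(name='x', shape=[None, 13], dtype='float32')
    y = fluid.data(name='y', shape=[None, 1], dtype='float32')
    pred = fluid.layers.fc(input=x, size=1)
    loss = fluid.layers.mean(fluid.layers.square_error_cost(input=pred, label=y))
    fluid.optimizer.SGD(learning_rate=0.01).minimize(loss)

places = static.xpu_places()  # -> [paddle.XPUPlace(0)] with the env var above
exe = fluid.Executor(places[0])
exe.run(startup_prog)

# CompiledProgram now detects the XPUPlace, sets ExecutionStrategy's private
# _use_device to UseDevice.XPU, and forces num_threads to 1 on Kunlun.
compiled = fluid.CompiledProgram(main_prog).with_data_parallel(
    loss_name=loss.name, places=places)

feed = {'x': np.random.rand(4, 13).astype('float32'),
        'y': np.random.rand(4, 1).astype('float32')}
loss_v, = exe.run(compiled, feed=feed, fetch_list=[loss.name])
print(loss_v)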