diff --git a/paddle/fluid/framework/executor.cc b/paddle/fluid/framework/executor.cc
index 32ceb7d7903a3ce70c60ecf80725762a48030168..467f485269b23e272014db2eaacaa4c2ccccac6f 100644
--- a/paddle/fluid/framework/executor.cc
+++ b/paddle/fluid/framework/executor.cc
@@ -469,11 +469,22 @@ void Executor::RunPartialPreparedContext(ExecutorPrepareContext* ctx,
 #endif
   } else if (platform::is_npu_place(place_)) {
 #ifdef PADDLE_WITH_ASCEND_CL
-    // TODO(ascendrc): Support garbage collector on NPUPlace
-    VLOG(4) << "Skip NPU gc because it is not implemented now.";
+    if (IsFastEagerDeletionModeEnabled()) {
+      VLOG(4) << "Use unsafe fast gc for NPU.";
+      gc.reset(new NPUUnsafeFastGarbageCollector(
+          BOOST_GET_CONST(platform::NPUPlace, place_), max_memory_size));
+    } else {
+      PADDLE_THROW(platform::errors::Unimplemented(
+          "Please set FLAGS_fast_eager_deletion_mode=true to use "
+          "GarbageCollector on NPU."));
+      // TODO(zhiqiu): fix bugs and enable NPUDefaultStreamGarbageCollector.
+      VLOG(4) << "Use default stream gc for NPU.";
+      gc.reset(new NPUDefaultStreamGarbageCollector(
+          BOOST_GET_CONST(platform::NPUPlace, place_), max_memory_size));
+    }
 #else
-    PADDLE_THROW(platform::errors::Unimplemented(
-        "No NPU gc found in CPU/GPU/XPU paddle"));
+    PADDLE_THROW(
+        platform::errors::Unimplemented("No NPU gc found in CPU/NPU paddle"));
 #endif
   }
 }
diff --git a/paddle/fluid/framework/garbage_collector.cc b/paddle/fluid/framework/garbage_collector.cc
index a48589a82dd1664fd54f0d73cea311e7b95283b0..f44d3e3624bbb3e19e76d643062720675526b513 100644
--- a/paddle/fluid/framework/garbage_collector.cc
+++ b/paddle/fluid/framework/garbage_collector.cc
@@ -119,6 +119,32 @@ void CUDAPinnedGarbageCollector::ClearCallback(
 }
 #endif
 
+#ifdef PADDLE_WITH_ASCEND_CL
+NPUDefaultStreamGarbageCollector::NPUDefaultStreamGarbageCollector(
+    const platform::NPUPlace &place, size_t max_memory_size)
+    : GarbageCollector(place, max_memory_size) {}
+
+void NPUDefaultStreamGarbageCollector::Wait() const {
+  static_cast<platform::NPUDeviceContext *>(this->dev_ctx_)
+      ->WaitStreamCallback();
+}
+
+void NPUDefaultStreamGarbageCollector::ClearCallback(
+    const std::function<void()> &callback) {
+  static_cast<platform::NPUDeviceContext *>(this->dev_ctx_)
+      ->AddStreamCallback(callback);
+}
+NPUUnsafeFastGarbageCollector::NPUUnsafeFastGarbageCollector(
+    const platform::NPUPlace &place, size_t max_memory_size)
+    : GarbageCollector(place, max_memory_size) {}
+
+void NPUUnsafeFastGarbageCollector::ClearCallback(
+    const std::function<void()> &callback) {
+  callback();
+}
+
+#endif
+
 int64_t GetEagerDeletionThreshold() {
   return FLAGS_eager_delete_tensor_gb < 0
              ? -1
diff --git a/paddle/fluid/framework/garbage_collector.h b/paddle/fluid/framework/garbage_collector.h
index eec8327c728a14ae202e17c6b6b3ae00bbb13e5e..690fa240ca692f39e77bbae41e3a7cf43655af62 100644
--- a/paddle/fluid/framework/garbage_collector.h
+++ b/paddle/fluid/framework/garbage_collector.h
@@ -131,6 +131,28 @@ class CUDAPinnedGarbageCollector : public GarbageCollector {
 };
 #endif
 
+#ifdef PADDLE_WITH_ASCEND_CL
+class NPUDefaultStreamGarbageCollector : public GarbageCollector {
+ public:
+  NPUDefaultStreamGarbageCollector(const platform::NPUPlace &place,
+                                   size_t max_memory_size);
+
+  void Wait() const override;
+
+ protected:
+  void ClearCallback(const std::function<void()> &callback) override;
+};
+
+class NPUUnsafeFastGarbageCollector : public GarbageCollector {
+ public:
+  NPUUnsafeFastGarbageCollector(const platform::NPUPlace &place,
+                                size_t max_memory_size);
+
+ protected:
+  void ClearCallback(const std::function<void()> &callback) override;
+};
+#endif
+
 template <typename Container>
 void GarbageCollector::Add(Container &&objs) {
   Add(std::forward<Container>(objs), []() {});
diff --git a/paddle/fluid/operators/gather_op_npu.cc b/paddle/fluid/operators/gather_op_npu.cc
index 8a487234ad94acd294193e26019e087dc3a7854c..1ee8889995f4d6045f237aa51e00faff7f67b2a3 100644
--- a/paddle/fluid/operators/gather_op_npu.cc
+++ b/paddle/fluid/operators/gather_op_npu.cc
@@ -50,6 +50,7 @@ class GatherGradOpNPUKernel : public framework::OpKernel<T> {
     auto *x = ctx.Input<Tensor>("X");
     auto *dout = ctx.Input<Tensor>(framework::GradVarName("Out"));
     auto *dx = ctx.Output<Tensor>(framework::GradVarName("X"));
+    dx->mutable_data<T>(ctx.GetPlace());
 
     // step1: Unsqueeze index
     framework::Tensor tmp_tensor(index->type());
@@ -66,7 +67,7 @@ class GatherGradOpNPUKernel : public framework::OpKernel<T> {
             .stream();
 
     // step2: ZerosLike x in device
-    Tensor zeroslike_xout(x->type());
+    Tensor zeroslike_xout(dx->type());
     zeroslike_xout.Resize(x->dims());
     auto p = zeroslike_xout.mutable_data<T>(ctx.GetPlace());
 
@@ -74,7 +75,6 @@ class GatherGradOpNPUKernel : public framework::OpKernel<T> {
         zeroslike_xout.numel() * sizeof(T), stream);
 
     // step3: scatter(x_grad)
-    dx->mutable_data<T>(ctx.GetPlace());
     auto runner_scatter = NpuOpRunner(
         "TensorScatterUpdate", {zeroslike_xout, *index, *dout}, {*dx}, {});
     runner_scatter.Run(stream);
diff --git a/paddle/fluid/platform/device_context.h b/paddle/fluid/platform/device_context.h
index 187dd627e4a7c5205b69bb26e54a161244d5e4e6..bf34e57b773fa48b7df353c3035fe2597c93bc28 100644
--- a/paddle/fluid/platform/device_context.h
+++ b/paddle/fluid/platform/device_context.h
@@ -178,6 +178,13 @@ class NPUDeviceContext : public DeviceContext {
   /*! \brief  Return npu stream in the device context.
    */
   aclrtStream stream() const;
+  template <typename Callback>
+  void AddStreamCallback(Callback&& callback) const {
+    return stream_->AddCallback(callback);
+  }
+
+  void WaitStreamCallback() const { return stream_->WaitCallback(); }
+
  private:
   NPUPlace place_;
   aclrtContext context_;
diff --git a/paddle/fluid/platform/stream_callback_manager.cc b/paddle/fluid/platform/stream_callback_manager.cc
index 76128e9a8f471c04fb4a770691387c844f857e48..45f49e1f896de095a0357874471baffbdd948244 100644
--- a/paddle/fluid/platform/stream_callback_manager.cc
+++ b/paddle/fluid/platform/stream_callback_manager.cc
@@ -29,7 +29,7 @@ static void CUDART_CB StreamCallbackFunc(cudaStream_t stream,
 #endif
 
 #if PADDLE_WITH_ASCEND_CL
-static void StreamCallbackFunc(void *user_data)
+  static void StreamCallbackFunc(void *user_data)
 #endif
 {
   std::unique_ptr<std::function<void()>> func(
@@ -42,7 +42,8 @@ StreamCallbackManager<Stream>::StreamCallbackManager(const Stream stream)
     : stream_(stream), thread_pool_(1) {}
 
 template <typename Stream>
-void StreamCallbackManager<Stream>::AddCallback(std::function<void()> callback) const {
+void StreamCallbackManager<Stream>::AddCallback(
+    std::function<void()> callback) const {
   auto *callback_func = new std::function<void()>(std::move(callback));
   auto *func = new std::function<void()>([this, callback_func] {
     std::lock_guard<std::mutex> lock(mtx_);
@@ -62,6 +63,8 @@ void StreamCallbackManager<Stream>::AddCallback(std::function<void()> callback)
 #endif
 
 #if PADDLE_WITH_ASCEND_CL
+  VLOG(3) << "aclrtLaunchCallback at stream: " << stream_;
+  // TODO(zhiqiu): failed to call aclrtLaunchCallback
   PADDLE_ENFORCE_NPU_SUCCESS(aclrtLaunchCallback(StreamCallbackFunc, func,
                                                  ACL_CALLBACK_BLOCK, stream_));
 #endif