Unverified commit ac89174e, authored by Leo Chen, committed by GitHub

[NPU] support GarbageCollector for npu (#31874)

* support GarbageCollector for npu

* fix typo

* fix gather_grad

* disable NPUDefaultStreamGarbageCollector on NPU
Parent 3c66b872
@@ -469,11 +469,22 @@ void Executor::RunPartialPreparedContext(ExecutorPrepareContext* ctx,
 #endif
   } else if (platform::is_npu_place(place_)) {
 #ifdef PADDLE_WITH_ASCEND_CL
-    // TODO(ascendrc): Support garbage collector on NPUPlace
-    VLOG(4) << "Skip NPU gc because it is not implemented now.";
+    if (IsFastEagerDeletionModeEnabled()) {
+      VLOG(4) << "Use unsafe fast gc for NPU.";
+      gc.reset(new NPUUnsafeFastGarbageCollector(
+          BOOST_GET_CONST(platform::NPUPlace, place_), max_memory_size));
+    } else {
+      PADDLE_THROW(platform::errors::Unimplemented(
+          "Please set FLAGS_fast_eager_deletion_mode=true to use "
+          "GarbageCollector on NPU."));
+      // TODO(zhiqiu): fix bugs and enable NPUDefaultStreamGarbageCollector.
+      VLOG(4) << "Use default stream gc for NPU.";
+      gc.reset(new NPUDefaultStreamGarbageCollector(
+          BOOST_GET_CONST(platform::NPUPlace, place_), max_memory_size));
+    }
 #else
-    PADDLE_THROW(platform::errors::Unimplemented(
-        "No NPU gc found in CPU/GPU/XPU paddle"));
+    PADDLE_THROW(
+        platform::errors::Unimplemented("No NPU gc found in CPU/NPU paddle"));
 #endif
   }
 }
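With this hunk the executor only builds an NPU garbage collector when eager deletion is enabled, and only the unsafe fast path is reachable for now: the non-fast branch throws before it would construct NPUDefaultStreamGarbageCollector, matching the "disable NPUDefaultStreamGarbageCollector on NPU" note in the commit message. A minimal, self-contained sketch of the same flag-gated selection follows; GcBase, FastGc, StreamGc, and FastModeEnabled are hypothetical stand-ins for illustration, not Paddle symbols.

// Toy sketch of the flag-gated GC selection above; none of these types or
// names come from Paddle, they only mirror the control flow of the hunk.
#include <cstdlib>
#include <iostream>
#include <memory>
#include <stdexcept>

struct GcBase { virtual ~GcBase() = default; };
struct FastGc : GcBase {};    // frees garbage immediately ("unsafe fast" mode)
struct StreamGc : GcBase {};  // would defer frees to a stream callback

// Stand-in for IsFastEagerDeletionModeEnabled(): here just an env var check.
bool FastModeEnabled() {
  const char* v = std::getenv("FAST_EAGER_DELETION");
  return v != nullptr && v[0] == '1';
}

std::unique_ptr<GcBase> MakeNpuGc() {
  if (FastModeEnabled()) {
    return std::make_unique<FastGc>();
  }
  // Mirrors the commit: the default-stream collector exists but is still
  // disabled, so the non-fast branch reports an error instead of using it.
  throw std::runtime_error(
      "set FLAGS_fast_eager_deletion_mode=true to use the GC on NPU");
}

int main() {
  try {
    auto gc = MakeNpuGc();
    std::cout << "built fast NPU gc\n";
  } catch (const std::exception& e) {
    std::cout << "gc disabled: " << e.what() << "\n";
  }
}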
@@ -119,6 +119,32 @@ void CUDAPinnedGarbageCollector::ClearCallback(
 }
 #endif
 
+#ifdef PADDLE_WITH_ASCEND_CL
+NPUDefaultStreamGarbageCollector::NPUDefaultStreamGarbageCollector(
+    const platform::NPUPlace &place, size_t max_memory_size)
+    : GarbageCollector(place, max_memory_size) {}
+
+void NPUDefaultStreamGarbageCollector::Wait() const {
+  static_cast<platform::NPUDeviceContext *>(this->dev_ctx_)
+      ->WaitStreamCallback();
+}
+
+void NPUDefaultStreamGarbageCollector::ClearCallback(
+    const std::function<void()> &callback) {
+  static_cast<platform::NPUDeviceContext *>(this->dev_ctx_)
+      ->AddStreamCallback(callback);
+}
+
+NPUUnsafeFastGarbageCollector::NPUUnsafeFastGarbageCollector(
+    const platform::NPUPlace &place, size_t max_memory_size)
+    : GarbageCollector(place, max_memory_size) {}
+
+void NPUUnsafeFastGarbageCollector::ClearCallback(
+    const std::function<void()> &callback) {
+  callback();
+}
+#endif
+
 int64_t GetEagerDeletionThreshold() {
   return FLAGS_eager_delete_tensor_gb < 0
              ? -1
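The two collectors differ only in when the deletion callback fires: the default-stream collector hands the callback to the NPU device context's stream, so buffers are released only after the work already enqueued on that stream has finished, while the unsafe fast collector runs the callback immediately on the host thread. Below is a host-only toy sketch of that difference, assuming a made-up ToyStream in place of the real NPUDeviceContext; it is an illustration, not Paddle code.

// "Clear now" vs "clear once the stream catches up", modeled on the host.
#include <functional>
#include <iostream>
#include <queue>

class ToyStream {
 public:
  void Enqueue(std::function<void()> work) { q_.push(std::move(work)); }
  // Stand-in for AddStreamCallback: runs only after prior work has drained.
  void AddCallback(std::function<void()> cb) { q_.push(std::move(cb)); }
  void Drain() {  // stand-in for stream synchronization
    while (!q_.empty()) {
      q_.front()();
      q_.pop();
    }
  }

 private:
  std::queue<std::function<void()>> q_;
};

int main() {
  ToyStream stream;
  bool buffer_alive = true;

  stream.Enqueue([&] {
    std::cout << "queued kernel reads buffer, alive=" << buffer_alive << "\n";
  });

  // Unsafe fast GC would free right here, before the queued "kernel" runs;
  // that is only safe when execution order guarantees the buffer is done.
  // buffer_alive = false;

  // Default-stream GC defers the free until the stream reaches this point.
  stream.AddCallback([&] { buffer_alive = false; });
  stream.Drain();
  std::cout << "after drain, alive=" << buffer_alive << "\n";
}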
@@ -131,6 +131,28 @@ class CUDAPinnedGarbageCollector : public GarbageCollector {
 };
 #endif
 
+#ifdef PADDLE_WITH_ASCEND_CL
+class NPUDefaultStreamGarbageCollector : public GarbageCollector {
+ public:
+  NPUDefaultStreamGarbageCollector(const platform::NPUPlace &place,
+                                   size_t max_memory_size);
+
+  void Wait() const override;
+
+ protected:
+  void ClearCallback(const std::function<void()> &callback) override;
+};
+
+class NPUUnsafeFastGarbageCollector : public GarbageCollector {
+ public:
+  NPUUnsafeFastGarbageCollector(const platform::NPUPlace &place,
+                                size_t max_memory_size);
+
+ protected:
+  void ClearCallback(const std::function<void()> &callback) override;
+};
+#endif
+
 template <typename Container>
 void GarbageCollector::Add(Container &&objs) {
   Add(std::forward<Container>(objs), []() {});
@@ -50,6 +50,7 @@ class GatherGradOpNPUKernel : public framework::OpKernel<T> {
     auto *x = ctx.Input<Tensor>("X");
     auto *dout = ctx.Input<Tensor>(framework::GradVarName("Out"));
    auto *dx = ctx.Output<Tensor>(framework::GradVarName("X"));
+    dx->mutable_data<T>(ctx.GetPlace());
 
     // step1: Unsqueeze index
     framework::Tensor tmp_tensor(index->type());
@@ -66,7 +67,7 @@ class GatherGradOpNPUKernel : public framework::OpKernel<T> {
             .stream();
 
     // step2: ZerosLike x in device
-    Tensor zeroslike_xout(x->type());
+    Tensor zeroslike_xout(dx->type());
     zeroslike_xout.Resize(x->dims());
     auto p = zeroslike_xout.mutable_data<T>(ctx.GetPlace());
 
@@ -74,7 +75,6 @@ class GatherGradOpNPUKernel : public framework::OpKernel<T> {
                 zeroslike_xout.numel() * sizeof(T), stream);
 
     // step3: scatter(x_grad)
-    dx->mutable_data<T>(ctx.GetPlace());
     auto runner_scatter = NpuOpRunner(
         "TensorScatterUpdate", {zeroslike_xout, *index, *dout}, {*dx}, {});
     runner_scatter.Run(stream);
@@ -178,6 +178,13 @@ class NPUDeviceContext : public DeviceContext {
   /*! \brief Return npu stream in the device context. */
   aclrtStream stream() const;
 
+  template <typename Callback>
+  void AddStreamCallback(Callback&& callback) const {
+    return stream_->AddCallback(callback);
+  }
+
+  void WaitStreamCallback() const { return stream_->WaitCallback(); }
+
  private:
   NPUPlace place_;
   aclrtContext context_;
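AddStreamCallback and WaitStreamCallback are thin forwards to the underlying StreamCallbackManager held in stream_, which is exactly the pair of operations the default-stream GC needs: schedule a host function to run once the stream reaches this point, and block until every scheduled function has completed. The rough sketch below models that contract with std::async futures instead of an aclrtStream; ToyCallbackManager is an assumed illustrative type, not part of Paddle.

// Host-only model of the AddStreamCallback / WaitStreamCallback contract.
#include <functional>
#include <future>
#include <iostream>
#include <utility>
#include <vector>

class ToyCallbackManager {
 public:
  // Schedule a host function; it may run asynchronously at any later point.
  void AddCallback(std::function<void()> cb) {
    pending_.push_back(std::async(std::launch::async, std::move(cb)));
  }
  // Block until every callback scheduled so far has finished.
  void WaitCallback() {
    for (auto& f : pending_) f.wait();
    pending_.clear();
  }

 private:
  std::vector<std::future<void>> pending_;
};

int main() {
  ToyCallbackManager mgr;
  mgr.AddCallback([] { std::cout << "freeing garbage batch 1\n"; });
  mgr.AddCallback([] { std::cout << "freeing garbage batch 2\n"; });
  mgr.WaitCallback();  // what NPUDefaultStreamGarbageCollector::Wait relies on
  std::cout << "all deferred frees done\n";
}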
@@ -29,7 +29,7 @@ static void CUDART_CB StreamCallbackFunc(cudaStream_t stream,
 #endif
 
 #if PADDLE_WITH_ASCEND_CL
 static void StreamCallbackFunc(void *user_data)
 #endif
 {
   std::unique_ptr<std::function<void()>> func(
@@ -42,7 +42,8 @@ StreamCallbackManager<Stream>::StreamCallbackManager(const Stream stream)
     : stream_(stream), thread_pool_(1) {}
 
 template <typename Stream>
-void StreamCallbackManager<Stream>::AddCallback(std::function<void()> callback) const {
+void StreamCallbackManager<Stream>::AddCallback(
+    std::function<void()> callback) const {
   auto *callback_func = new std::function<void()>(std::move(callback));
   auto *func = new std::function<void()>([this, callback_func] {
     std::lock_guard<std::mutex> lock(mtx_);
@@ -62,6 +63,8 @@ void StreamCallbackManager<Stream>::AddCallback(std::function<void()> callback)
 #endif
 
 #if PADDLE_WITH_ASCEND_CL
+  VLOG(3) << "aclrtLaunchCallback at stream: " << stream_;
+  // TODO(zhiqiu): failed to call aclrtLaunchCallback
   PADDLE_ENFORCE_NPU_SUCCESS(aclrtLaunchCallback(StreamCallbackFunc, func,
                                                  ACL_CALLBACK_BLOCK, stream_));
 #endif