未验证 提交 ac89174e 编写于 作者: L Leo Chen 提交者: GitHub

[NPU] support GarbageCollector for npu (#31874)

* support GarbageCollector for npu

* fix typo

* fix gather_grad

* disable NPUDefaultStreamGarbageCollector on NPU
上级 3c66b872
......@@ -469,11 +469,22 @@ void Executor::RunPartialPreparedContext(ExecutorPrepareContext* ctx,
#endif
} else if (platform::is_npu_place(place_)) {
#ifdef PADDLE_WITH_ASCEND_CL
// TODO(ascendrc): Support garbage collector on NPUPlace
VLOG(4) << "Skip NPU gc because it is not implemented now.";
#else
if (IsFastEagerDeletionModeEnabled()) {
VLOG(4) << "Use unsafe fast gc for NPU.";
gc.reset(new NPUUnsafeFastGarbageCollector(
BOOST_GET_CONST(platform::NPUPlace, place_), max_memory_size));
} else {
PADDLE_THROW(platform::errors::Unimplemented(
"No NPU gc found in CPU/GPU/XPU paddle"));
"Please set FLAGS_fast_eager_deletion_mode=true to use "
"GarbageCollector on NPU."));
// TODO(zhiqiu): fix bugs and enable NPUDefaultStreamGarbageCollector.
VLOG(4) << "Use default stream gc for NPU.";
gc.reset(new NPUDefaultStreamGarbageCollector(
BOOST_GET_CONST(platform::NPUPlace, place_), max_memory_size));
}
#else
PADDLE_THROW(
platform::errors::Unimplemented("No NPU gc found in CPU/NPU paddle"));
#endif
}
}
......
......@@ -119,6 +119,32 @@ void CUDAPinnedGarbageCollector::ClearCallback(
}
#endif
#ifdef PADDLE_WITH_ASCEND_CL
NPUDefaultStreamGarbageCollector::NPUDefaultStreamGarbageCollector(
const platform::NPUPlace &place, size_t max_memory_size)
: GarbageCollector(place, max_memory_size) {}
void NPUDefaultStreamGarbageCollector::Wait() const {
static_cast<platform::NPUDeviceContext *>(this->dev_ctx_)
->WaitStreamCallback();
}
void NPUDefaultStreamGarbageCollector::ClearCallback(
const std::function<void()> &callback) {
static_cast<platform::NPUDeviceContext *>(this->dev_ctx_)
->AddStreamCallback(callback);
}
NPUUnsafeFastGarbageCollector::NPUUnsafeFastGarbageCollector(
const platform::NPUPlace &place, size_t max_memory_size)
: GarbageCollector(place, max_memory_size) {}
void NPUUnsafeFastGarbageCollector::ClearCallback(
const std::function<void()> &callback) {
callback();
}
#endif
int64_t GetEagerDeletionThreshold() {
return FLAGS_eager_delete_tensor_gb < 0
? -1
......
......@@ -131,6 +131,28 @@ class CUDAPinnedGarbageCollector : public GarbageCollector {
};
#endif
#ifdef PADDLE_WITH_ASCEND_CL
class NPUDefaultStreamGarbageCollector : public GarbageCollector {
public:
NPUDefaultStreamGarbageCollector(const platform::NPUPlace &place,
size_t max_memory_size);
void Wait() const override;
protected:
void ClearCallback(const std::function<void()> &callback) override;
};
class NPUUnsafeFastGarbageCollector : public GarbageCollector {
public:
NPUUnsafeFastGarbageCollector(const platform::NPUPlace &place,
size_t max_memory_size);
protected:
void ClearCallback(const std::function<void()> &callback) override;
};
#endif
template <typename Container>
void GarbageCollector::Add(Container &&objs) {
Add(std::forward<Container>(objs), []() {});
......
......@@ -50,6 +50,7 @@ class GatherGradOpNPUKernel : public framework::OpKernel<T> {
auto *x = ctx.Input<Tensor>("X");
auto *dout = ctx.Input<Tensor>(framework::GradVarName("Out"));
auto *dx = ctx.Output<Tensor>(framework::GradVarName("X"));
dx->mutable_data<T>(ctx.GetPlace());
// step1: Unsqueeze index
framework::Tensor tmp_tensor(index->type());
......@@ -66,7 +67,7 @@ class GatherGradOpNPUKernel : public framework::OpKernel<T> {
.stream();
// step2: ZerosLike x in device
Tensor zeroslike_xout(x->type());
Tensor zeroslike_xout(dx->type());
zeroslike_xout.Resize(x->dims());
auto p = zeroslike_xout.mutable_data<T>(ctx.GetPlace());
......@@ -74,7 +75,6 @@ class GatherGradOpNPUKernel : public framework::OpKernel<T> {
zeroslike_xout.numel() * sizeof(T), stream);
// step3: scatter(x_grad)
dx->mutable_data<T>(ctx.GetPlace());
auto runner_scatter = NpuOpRunner(
"TensorScatterUpdate", {zeroslike_xout, *index, *dout}, {*dx}, {});
runner_scatter.Run(stream);
......
......@@ -178,6 +178,13 @@ class NPUDeviceContext : public DeviceContext {
/*! \brief Return npu stream in the device context. */
aclrtStream stream() const;
template <typename Callback>
void AddStreamCallback(Callback&& callback) const {
return stream_->AddCallback(callback);
}
void WaitStreamCallback() const { return stream_->WaitCallback(); }
private:
NPUPlace place_;
aclrtContext context_;
......
......@@ -29,7 +29,7 @@ static void CUDART_CB StreamCallbackFunc(cudaStream_t stream,
#endif
#if PADDLE_WITH_ASCEND_CL
static void StreamCallbackFunc(void *user_data)
static void StreamCallbackFunc(void *user_data)
#endif
{
std::unique_ptr<std::function<void()>> func(
......@@ -42,7 +42,8 @@ StreamCallbackManager<Stream>::StreamCallbackManager(const Stream stream)
: stream_(stream), thread_pool_(1) {}
template <typename Stream>
void StreamCallbackManager<Stream>::AddCallback(std::function<void()> callback) const {
void StreamCallbackManager<Stream>::AddCallback(
std::function<void()> callback) const {
auto *callback_func = new std::function<void()>(std::move(callback));
auto *func = new std::function<void()>([this, callback_func] {
std::lock_guard<std::mutex> lock(mtx_);
......@@ -62,6 +63,8 @@ void StreamCallbackManager<Stream>::AddCallback(std::function<void()> callback)
#endif
#if PADDLE_WITH_ASCEND_CL
VLOG(3) << "aclrtLaunchCallback at stream: " << stream_;
// TODO(zhiqiu): failed to call aclrtLaunchCallback
PADDLE_ENFORCE_NPU_SUCCESS(aclrtLaunchCallback(StreamCallbackFunc, func,
ACL_CALLBACK_BLOCK, stream_));
#endif
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册