未验证 提交 0b6623d7 编写于 作者: R ronnywang 提交者: GitHub

[NPU] support gradient_accumulator (#35044)

上级 d53e567a
......@@ -36,6 +36,10 @@ if(WITH_GLOO)
endif()
endif()
if(NOT WITH_ASCEND_CL)
cc_library(gradient_accumulator SRCS gradient_accumulator.cc DEPS blas operator lod_tensor selected_rows selected_rows_functor var_type_traits layer math_function)
else()
cc_library(gradient_accumulator SRCS gradient_accumulator.cc DEPS blas operator lod_tensor selected_rows selected_rows_functor var_type_traits layer math_function npu_op_runner)
endif()
add_subdirectory(tests)
......@@ -31,6 +31,9 @@
#ifdef PADDLE_WITH_XPU
#include "xpu/refactor/math.h"
#endif
#ifdef PADDLE_WITH_ASCEND_CL
#include "paddle/fluid/operators/npu_op_runner.h"
#endif
namespace paddle {
namespace imperative {
......@@ -199,6 +202,30 @@ void TensorAdd(const framework::Variable& src, framework::Variable* dst) {
return; \
}
#ifdef PADDLE_WITH_ASCEND_CL
if (platform::is_npu_place(place)) {
platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
platform::DeviceContext* ctx = pool.Get(place);
auto dev_ctx = dynamic_cast<platform::NPUDeviceContext*>(ctx);
if (data_type == framework::DataTypeTrait<float>::DataType()) {
dst_tensor->mutable_data<float>(place);
} else if (data_type == framework::DataTypeTrait<double>::DataType()) {
dst_tensor->mutable_data<double>(place);
} else if (data_type ==
framework::DataTypeTrait<platform::float16>::DataType()) {
dst_tensor->mutable_data<platform::float16>(place);
} else {
PADDLE_THROW(platform::errors::Unimplemented(
"Gradient accumulation of data type (%s) on place (%s) is not "
"supported in imperative mode",
framework::DataTypeToString(data_type), place));
}
const auto& runner = operators::NpuOpRunner(
"Add", {*dst_tensor, src_tensor}, {*dst_tensor}, {});
runner.Run(dev_ctx->stream());
return;
}
#endif
PADDLE_TENSOR_ADD(float);
#ifndef PADDLE_WITH_XPU
// NOTE(phlrain): xpu only support float
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册