未验证 提交 297fff1a 编写于 作者: H hong 提交者: GitHub

support dygraph in xpu place (#30051)

* support dygraph in xpu place; test=develop

* fix cpu/gpu compile error; test=develop

* fix compile error; test=develop

* fix xpu compile error; testd=develop
上级 eea7090c
......@@ -30,6 +30,9 @@
#include "paddle/fluid/platform/device_context.h"
#include "paddle/fluid/platform/float16.h"
#include "paddle/fluid/platform/profiler.h"
#ifdef PADDLE_WITH_XPU
#include "xpu/refactor/math.h"
#endif
namespace paddle {
namespace imperative {
......@@ -81,12 +84,20 @@ class TensorAddFunctor : public boost::static_visitor<> {
blas.AXPY(numel_, 1., x_, y_);
}
#ifdef PADDLE_WITH_XPU
void operator()(const platform::XPUPlace& place) {
platform::XPUDeviceContext* ctx = dynamic_cast<platform::XPUDeviceContext*>(
platform::DeviceContextPool::Instance().Get(place));
xpu::add<T>(ctx->x_context(), x_, y_, y_, static_cast<int>(numel_));
}
#else
void operator()(const platform::XPUPlace& place) {
PADDLE_THROW(platform::errors::PermissionDenied(
"Gradient accumulation on place (%s) "
"is not supported in imperative mode",
place));
}
#endif
#ifdef PADDLE_WITH_CUDA
void operator()(const platform::CUDAPlace& place) {
......@@ -162,11 +173,14 @@ void TensorAdd(const framework::Variable& src, framework::Variable* dst) {
}
PADDLE_TENSOR_ADD(float);
#ifndef PADDLE_WITH_XPU
// NOTE(phlrain): xpu only support float
PADDLE_TENSOR_ADD(double);
// NOTE(chenweihang): only support complex grad tensor accumulated,
// support selected rows if needed in the future
PADDLE_TENSOR_ADD(platform::complex64);
PADDLE_TENSOR_ADD(platform::complex128);
#endif
#undef PADDLE_TENSOR_ADD
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册