未验证 提交 285f33e5 编写于 作者: H hong 提交者: GitHub

support dygraph in xpu place (#30051) (#30112)

* support dygraph in xpu place; test=develop

* fix cpu/gpu compile error; test=develop

* fix compile error; test=develop

* fix xpu compile error; testd=develop
上级 19bec2fe
...@@ -30,6 +30,9 @@ ...@@ -30,6 +30,9 @@
#include "paddle/fluid/platform/device_context.h" #include "paddle/fluid/platform/device_context.h"
#include "paddle/fluid/platform/float16.h" #include "paddle/fluid/platform/float16.h"
#include "paddle/fluid/platform/profiler.h" #include "paddle/fluid/platform/profiler.h"
#ifdef PADDLE_WITH_XPU
#include "xpu/refactor/math.h"
#endif
namespace paddle { namespace paddle {
namespace imperative { namespace imperative {
...@@ -81,12 +84,20 @@ class TensorAddFunctor : public boost::static_visitor<> { ...@@ -81,12 +84,20 @@ class TensorAddFunctor : public boost::static_visitor<> {
blas.AXPY(numel_, 1., x_, y_); blas.AXPY(numel_, 1., x_, y_);
} }
#ifdef PADDLE_WITH_XPU
void operator()(const platform::XPUPlace& place) {
platform::XPUDeviceContext* ctx = dynamic_cast<platform::XPUDeviceContext*>(
platform::DeviceContextPool::Instance().Get(place));
xpu::add<T>(ctx->x_context(), x_, y_, y_, static_cast<int>(numel_));
}
#else
void operator()(const platform::XPUPlace& place) { void operator()(const platform::XPUPlace& place) {
PADDLE_THROW(platform::errors::PermissionDenied( PADDLE_THROW(platform::errors::PermissionDenied(
"Gradient accumulation on place (%s) " "Gradient accumulation on place (%s) "
"is not supported in imperative mode", "is not supported in imperative mode",
place)); place));
} }
#endif
#ifdef PADDLE_WITH_CUDA #ifdef PADDLE_WITH_CUDA
void operator()(const platform::CUDAPlace& place) { void operator()(const platform::CUDAPlace& place) {
...@@ -162,11 +173,14 @@ void TensorAdd(const framework::Variable& src, framework::Variable* dst) { ...@@ -162,11 +173,14 @@ void TensorAdd(const framework::Variable& src, framework::Variable* dst) {
} }
PADDLE_TENSOR_ADD(float); PADDLE_TENSOR_ADD(float);
#ifndef PADDLE_WITH_XPU
// NOTE(phlrain): xpu only support float
PADDLE_TENSOR_ADD(double); PADDLE_TENSOR_ADD(double);
// NOTE(chenweihang): only support complex grad tensor accumulated, // NOTE(chenweihang): only support complex grad tensor accumulated,
// support selected rows if needed in the future // support selected rows if needed in the future
PADDLE_TENSOR_ADD(platform::complex64); PADDLE_TENSOR_ADD(platform::complex64);
PADDLE_TENSOR_ADD(platform::complex128); PADDLE_TENSOR_ADD(platform::complex128);
#endif
#undef PADDLE_TENSOR_ADD #undef PADDLE_TENSOR_ADD
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册