From 285f33e5607fa8583d8b74b1f20082c46026f9af Mon Sep 17 00:00:00 2001 From: hong <43953930+phlrain@users.noreply.github.com> Date: Wed, 6 Jan 2021 16:40:16 +0800 Subject: [PATCH] support dygraph in xpu place (#30051) (#30112) * support dygraph in xpu place; test=develop * fix cpu/gpu compile error; test=develop * fix compile error; test=develop * fix xpu compile error; testd=develop --- paddle/fluid/imperative/gradient_accumulator.cc | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/paddle/fluid/imperative/gradient_accumulator.cc b/paddle/fluid/imperative/gradient_accumulator.cc index bc38e3b59b..ff8494a388 100644 --- a/paddle/fluid/imperative/gradient_accumulator.cc +++ b/paddle/fluid/imperative/gradient_accumulator.cc @@ -30,6 +30,9 @@ #include "paddle/fluid/platform/device_context.h" #include "paddle/fluid/platform/float16.h" #include "paddle/fluid/platform/profiler.h" +#ifdef PADDLE_WITH_XPU +#include "xpu/refactor/math.h" +#endif namespace paddle { namespace imperative { @@ -81,12 +84,20 @@ class TensorAddFunctor : public boost::static_visitor<> { blas.AXPY(numel_, 1., x_, y_); } +#ifdef PADDLE_WITH_XPU + void operator()(const platform::XPUPlace& place) { + platform::XPUDeviceContext* ctx = dynamic_cast( + platform::DeviceContextPool::Instance().Get(place)); + xpu::add(ctx->x_context(), x_, y_, y_, static_cast(numel_)); + } +#else void operator()(const platform::XPUPlace& place) { PADDLE_THROW(platform::errors::PermissionDenied( "Gradient accumulation on place (%s) " "is not supported in imperative mode", place)); } +#endif #ifdef PADDLE_WITH_CUDA void operator()(const platform::CUDAPlace& place) { @@ -162,11 +173,14 @@ void TensorAdd(const framework::Variable& src, framework::Variable* dst) { } PADDLE_TENSOR_ADD(float); +#ifndef PADDLE_WITH_XPU + // NOTE(phlrain): xpu only support float PADDLE_TENSOR_ADD(double); // NOTE(chenweihang): only support complex grad tensor accumulated, // support selected rows if needed in the future PADDLE_TENSOR_ADD(platform::complex64); PADDLE_TENSOR_ADD(platform::complex128); +#endif #undef PADDLE_TENSOR_ADD -- GitLab