diff --git a/paddle/fluid/imperative/gradient_accumulator.cc b/paddle/fluid/imperative/gradient_accumulator.cc index e6e156fa61c1435f375841b7d718e16afe4f470a..199359a960326378f9641c6a5de4c6d3ccfd5303 100644 --- a/paddle/fluid/imperative/gradient_accumulator.cc +++ b/paddle/fluid/imperative/gradient_accumulator.cc @@ -42,6 +42,7 @@ #ifdef PADDLE_WITH_CUSTOM_DEVICE #include "paddle/phi/backends/device_manager.h" #endif +#include "paddle/phi/kernels/elementwise_add_kernel.h" namespace paddle { namespace imperative { @@ -81,137 +82,6 @@ static void MoveOrCopyVar(framework::Variable* dst, } } -template -class TensorAddFunctor - : public std::unary_function { - public: - TensorAddFunctor(int64_t numel, const T* x, T* y) - : numel_(numel), x_(x), y_(y) {} - - void operator()(const platform::CPUPlace& place) const { - phi::CPUContext* ctx = dynamic_cast( - platform::DeviceContextPool::Instance().Get(place)); - auto blas = phi::funcs::GetBlas(*ctx); - blas.AXPY(numel_, 1., x_, y_); - } - -#ifdef PADDLE_WITH_XPU - void operator()(const platform::XPUPlace& place) const { - using XPUType = typename XPUTypeTrait::Type; - platform::XPUDeviceContext* ctx = dynamic_cast( - platform::DeviceContextPool::Instance().Get(place)); - int r = xpu::add(ctx->x_context(), - reinterpret_cast(x_), - reinterpret_cast(y_), - reinterpret_cast(y_), - static_cast(numel_)); - PADDLE_ENFORCE_EQ( - r, - XPU_SUCCESS, - platform::errors::External( - "XPU add kernel return wrong value[%d %s]", r, XPUAPIErrorMsg[r])); - } -#else - void operator()(const platform::XPUPlace& place) const { - PADDLE_THROW(platform::errors::PermissionDenied( - "Gradient accumulation on place (%s) " - "is not supported in imperative mode", - place)); - } -#endif - -#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) - void operator()(const platform::CUDAPlace& place) const { - phi::GPUContext* ctx = dynamic_cast( - platform::DeviceContextPool::Instance().Get(place)); - auto blas = phi::funcs::GetBlas(*ctx); - blas.AXPY(numel_, 1., x_, y_); - } -#else - void operator()(const platform::CUDAPlace& place) const { - PADDLE_THROW(platform::errors::PermissionDenied( - "Gradient accumulation on place (%s) " - "is not supported in imperative mode", - place)); - } -#endif - -#ifdef PADDLE_WITH_MLU - void operator()(const platform::MLUPlace& place) const { - // TODO(fwg): SUPPORT it - PADDLE_THROW(platform::errors::PermissionDenied( - "Gradient accumulation on place (%s) " - "is not supported in imperative mode", - place)); - } -#else - void operator()(const platform::MLUPlace& place) const { - PADDLE_THROW(platform::errors::PermissionDenied( - "Gradient accumulation on place (%s) " - "is not supported in imperative mode", - place)); - } -#endif - -#ifdef PADDLE_WITH_ASCEND_CL - void operator()(const platform::NPUPlace& place) const { - // TODO(zhiqiu): SUPPORT it - PADDLE_THROW(platform::errors::PermissionDenied( - "Gradient accumulation on place (%s) " - "is not supported in imperative mode", - place)); - } -#else - void operator()(const platform::NPUPlace& place) const { - PADDLE_THROW(platform::errors::PermissionDenied( - "Gradient accumulation on place (%s) " - "is not supported in imperative mode", - place)); - } -#endif - - void operator()(const platform::NPUPinnedPlace& place) const { - PADDLE_THROW(platform::errors::PermissionDenied( - "Gradient accumulation on place (%s) " - "is not supported in imperative mode", - place)); - } - // there is NO blas in CUDAPinnedPlace - void operator()(const platform::CUDAPinnedPlace& place) const { - PADDLE_THROW(platform::errors::PermissionDenied( - "Gradient accumulation on place (%s) " - "is not supported in imperative mode", - place)); - } - // there is NO support in IPUPlace - void operator()(const platform::IPUPlace& place) const { - PADDLE_THROW(platform::errors::PermissionDenied( - "Gradient accumulation on place (%s) " - "is not supported in imperative mode", - place)); - } - void operator()(const platform::CustomPlace& place) const { -#ifdef PADDLE_WITH_CUSTOM_DEVICE - platform::CustomDeviceContext* ctx = - dynamic_cast( - platform::DeviceContextPool::Instance().Get(place)); - phi::stream::Stream stream(place, ctx->stream()); - auto device = phi::DeviceManager::GetDeviceWithPlace(place); - device->BlasAXPBY(stream, static_cast(numel_), 1., x_, 1., y_); -#else - PADDLE_THROW(platform::errors::PermissionDenied( - "Gradient accumulation on place (%s) " - "is not supported in imperative mode", - place)); -#endif - } - - private: - int64_t numel_; - const T* x_; - mutable T* y_; -}; - #ifdef PADDLE_WITH_XPU template void XPUTensorAddFunctor(const platform::Place& place, @@ -232,17 +102,6 @@ void XPUTensorAddFunctor(const platform::Place& place, } #endif -template -void TensorAddImpl(const framework::Tensor& src, - framework::Tensor* dst, - const platform::Place& place) { - platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance(); - paddle::platform::DeviceContext* ctx = pool.Get(place); - auto dev_ctx = dynamic_cast(ctx); - phi::funcs::ElementwiseAddTo func; - func(dev_ctx, src, dst); -} - template TType* GetInnerMutableTensor(framework::Variable* dst) { auto* dst_tensor = dst->GetMutable(); @@ -327,14 +186,71 @@ void TensorAdd(const VarType& src, VarType* dst) { if (dst_tensor->place() != place) { paddle::framework::TensorCopySync(*dst_tensor, place, dst_tensor); } -#define PADDLE_TENSOR_ADD(cpp_type) \ - if (data_type == framework::DataTypeTrait::DataType()) { \ - TensorAddFunctor func( \ - numel, \ - src_tensor.data(), \ - dst_tensor->mutable_data(place)); \ - platform::VisitPlace(place, func); \ - return; \ + +#define PADDLE_TENSOR_ADD(T, CONTEXT) \ + if (data_type == framework::DataTypeTrait::DataType()) { \ + auto cpu_ctx = static_cast( \ + platform::DeviceContextPool::Instance().Get(place)); \ + phi::AddKernel(*cpu_ctx, src_tensor, *dst_tensor, dst_tensor); \ + return; \ + } + + if (platform::is_gpu_place(place)) { +#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) + PADDLE_TENSOR_ADD(float, phi::GPUContext); + PADDLE_TENSOR_ADD(double, phi::GPUContext); + PADDLE_TENSOR_ADD(phi::dtype::float16, phi::GPUContext); + PADDLE_TENSOR_ADD(phi::dtype::bfloat16, phi::GPUContext); + PADDLE_TENSOR_ADD(platform::complex, phi::GPUContext); + PADDLE_TENSOR_ADD(platform::complex, phi::GPUContext); +#endif + } + +#define TENSOR_ADD_EIGEN(T) \ + auto cpu_ctx = static_cast( \ + platform::DeviceContextPool::Instance().Get(place)); \ + auto in = paddle::framework::EigenVector::Flatten(src_tensor); \ + auto out = paddle::framework::EigenVector::Flatten(*dst_tensor); \ + auto& p = *(cpu_ctx->eigen_device()); \ + out.device(p) = out + in; \ + return; + + if (platform::is_cpu_place(place)) { + PADDLE_TENSOR_ADD(float, phi::CPUContext); + PADDLE_TENSOR_ADD(double, phi::CPUContext); + PADDLE_TENSOR_ADD(platform::complex, phi::CPUContext); + PADDLE_TENSOR_ADD(platform::complex, phi::CPUContext); + if (data_type == framework::proto::VarType::BF16) { + TENSOR_ADD_EIGEN(phi::dtype::bfloat16); + } + if (data_type == framework::proto::VarType::FP16) { + TENSOR_ADD_EIGEN(phi::dtype::float16); + } + } + +#define PADDLE_TENSOR_ADD_CUSTOM(T) \ + if (data_type == framework::DataTypeTrait::DataType()) { \ + platform::CustomDeviceContext* ctx = \ + static_cast( \ + platform::DeviceContextPool::Instance().Get(place)); \ + phi::stream::Stream stream(place, ctx->stream()); \ + auto device = phi::DeviceManager::GetDeviceWithPlace(place); \ + device->BlasAXPBY(stream, \ + static_cast(numel), \ + 1., \ + src_tensor.data(), \ + 1., \ + dst_tensor->mutable_data(place)); \ + return; \ + } + + if (platform::is_custom_place(place)) { +#if defined(PADDLE_WITH_CUSTOM_DEVICE) + PADDLE_TENSOR_ADD_CUSTOM(float); + PADDLE_TENSOR_ADD_CUSTOM(double); + PADDLE_TENSOR_ADD_CUSTOM(platform::complex); + PADDLE_TENSOR_ADD_CUSTOM(platform::complex); +#endif } #ifdef PADDLE_WITH_ASCEND_CL @@ -416,53 +332,6 @@ void TensorAdd(const VarType& src, VarType* dst) { } #endif - PADDLE_TENSOR_ADD(float); - -#ifndef PADDLE_WITH_XPU - // NOTE(phlrain): xpu only support float - PADDLE_TENSOR_ADD(double); - // NOTE(chenweihang): only support complex grad tensor accumulated, - // support selected rows if needed in the future - PADDLE_TENSOR_ADD(platform::complex); - PADDLE_TENSOR_ADD(platform::complex); -#endif - -#undef PADDLE_TENSOR_ADD - - if (data_type == framework::proto::VarType::FP16) { - if (platform::is_gpu_place(place)) { -#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) - return TensorAddImpl( - src_tensor, dst_tensor, place); -#else - PADDLE_THROW(platform::errors::Unimplemented( - "Gradient accumulation of data type (%s) on place (%s) is not " - "supported in imperative mode", - framework::DataTypeToString(data_type), - place)); -#endif - } else if (platform::is_cpu_place(place)) { - return TensorAddImpl( - src_tensor, dst_tensor, place); - } - } - if (data_type == framework::proto::VarType::BF16) { - if (platform::is_gpu_place(place)) { -#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) - return TensorAddImpl( - src_tensor, dst_tensor, place); -#else - PADDLE_THROW(platform::errors::Unimplemented( - "Gradient accumulation of data type (%s) on place (%s) is not " - "supported in imperative mode", - framework::DataTypeToString(data_type), - place)); -#endif - } else if (platform::is_cpu_place(place)) { - return TensorAddImpl( - src_tensor, dst_tensor, place); - } - } PADDLE_THROW(platform::errors::Unimplemented( "Gradient accumulation of data type (%s) on place (%s) is not " "supported in imperative mode", diff --git a/paddle/phi/kernels/funcs/math_function.cc b/paddle/phi/kernels/funcs/math_function.cc index 15a708f02f4974c0d8a5e3e485276c789ff37f41..84936d1e20c0ea695e42ac804ed69d0d45d190a8 100644 --- a/paddle/phi/kernels/funcs/math_function.cc +++ b/paddle/phi/kernels/funcs/math_function.cc @@ -257,21 +257,6 @@ template struct ColwiseSum; template struct RowwiseMean; template struct RowwiseMean; -template -struct ElementwiseAddTo { - void operator()(phi::CPUContext* ctx, - const paddle::framework::Tensor& src, - paddle::framework::Tensor* dst) { - auto in = paddle::framework::EigenVector::Flatten(src); - auto out = paddle::framework::EigenVector::Flatten(*dst); - auto& place = *(ctx->eigen_device()); - out.device(place) = out + in; - } -}; - -template struct ElementwiseAddTo; -template struct ElementwiseAddTo; - template struct RowwiseAdd { void operator()(const phi::CPUContext& context, diff --git a/paddle/phi/kernels/funcs/math_function.cu b/paddle/phi/kernels/funcs/math_function.cu index 9f0c20ccf14dc0e9cef169e75b352c0d1eabda39..c829adbc41373513e451af7d42c3e2055b22d539 100644 --- a/paddle/phi/kernels/funcs/math_function.cu +++ b/paddle/phi/kernels/funcs/math_function.cu @@ -21,7 +21,6 @@ limitations under the License. */ #include "paddle/fluid/platform/float16.h" #include "paddle/phi/backends/gpu/gpu_context.h" #include "paddle/phi/kernels/funcs/blas/blas.h" -#include "paddle/phi/kernels/funcs/eigen/common.h" #include "paddle/phi/kernels/funcs/math_function.h" #include "paddle/phi/kernels/funcs/math_function_impl.h" @@ -371,20 +370,5 @@ void RowwiseSum::operator()( template struct RowwiseMean; template struct RowwiseMean; -template -struct ElementwiseAddTo { - void operator()(phi::GPUContext* ctx, - const paddle::framework::Tensor& src, - paddle::framework::Tensor* dst) { - auto in = paddle::framework::EigenVector::Flatten(src); - auto out = paddle::framework::EigenVector::Flatten(*dst); - auto& place = *(ctx->eigen_device()); - out.device(place) = out + in; - } -}; - -template struct ElementwiseAddTo; -template struct ElementwiseAddTo; - } // namespace funcs } // namespace phi diff --git a/paddle/phi/kernels/funcs/math_function.h b/paddle/phi/kernels/funcs/math_function.h index b735587d3d53df03dfb82f7e3657a6b0c2cabd9b..b93096565603803e5ad8fe192c34093631614947 100644 --- a/paddle/phi/kernels/funcs/math_function.h +++ b/paddle/phi/kernels/funcs/math_function.h @@ -18,7 +18,6 @@ limitations under the License. */ #include #include "paddle/fluid/framework/convert_utils.h" -#include "paddle/fluid/framework/eigen.h" #include "paddle/fluid/framework/operator.h" #include "paddle/fluid/framework/tensor.h" #include "paddle/fluid/framework/tensor_util.h" @@ -71,14 +70,6 @@ struct RowwiseAdd { paddle::framework::Tensor* output); }; -template -struct ElementwiseAddTo { - // dst = dst + src - void operator()(DeviceContext* ctx, - const paddle::framework::Tensor& src, - paddle::framework::Tensor* dst); -}; - template struct ColwiseSum { void operator()(const DeviceContext& context,