未验证 提交 0e3b49d4 编写于 作者: L Leo Chen 提交者: GitHub

Reuse addKernel to replace TensorAdd (#45161)

* use addKernel

* fix compile

* remove elementwiseAddto

* add return

* fix custom place
上级 d0cd0a11
......@@ -42,6 +42,7 @@
#ifdef PADDLE_WITH_CUSTOM_DEVICE
#include "paddle/phi/backends/device_manager.h"
#endif
#include "paddle/phi/kernels/elementwise_add_kernel.h"
namespace paddle {
namespace imperative {
......@@ -81,137 +82,6 @@ static void MoveOrCopyVar(framework::Variable* dst,
}
}
template <typename T>
class TensorAddFunctor
: public std::unary_function<const platform::Place&, void> {
public:
TensorAddFunctor(int64_t numel, const T* x, T* y)
: numel_(numel), x_(x), y_(y) {}
void operator()(const platform::CPUPlace& place) const {
phi::CPUContext* ctx = dynamic_cast<phi::CPUContext*>(
platform::DeviceContextPool::Instance().Get(place));
auto blas = phi::funcs::GetBlas<phi::CPUContext, T>(*ctx);
blas.AXPY(numel_, 1., x_, y_);
}
#ifdef PADDLE_WITH_XPU
void operator()(const platform::XPUPlace& place) const {
using XPUType = typename XPUTypeTrait<T>::Type;
platform::XPUDeviceContext* ctx = dynamic_cast<platform::XPUDeviceContext*>(
platform::DeviceContextPool::Instance().Get(place));
int r = xpu::add<XPUType>(ctx->x_context(),
reinterpret_cast<const XPUType*>(x_),
reinterpret_cast<const XPUType*>(y_),
reinterpret_cast<XPUType*>(y_),
static_cast<int>(numel_));
PADDLE_ENFORCE_EQ(
r,
XPU_SUCCESS,
platform::errors::External(
"XPU add kernel return wrong value[%d %s]", r, XPUAPIErrorMsg[r]));
}
#else
void operator()(const platform::XPUPlace& place) const {
PADDLE_THROW(platform::errors::PermissionDenied(
"Gradient accumulation on place (%s) "
"is not supported in imperative mode",
place));
}
#endif
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
void operator()(const platform::CUDAPlace& place) const {
phi::GPUContext* ctx = dynamic_cast<phi::GPUContext*>(
platform::DeviceContextPool::Instance().Get(place));
auto blas = phi::funcs::GetBlas<phi::GPUContext, T>(*ctx);
blas.AXPY(numel_, 1., x_, y_);
}
#else
void operator()(const platform::CUDAPlace& place) const {
PADDLE_THROW(platform::errors::PermissionDenied(
"Gradient accumulation on place (%s) "
"is not supported in imperative mode",
place));
}
#endif
#ifdef PADDLE_WITH_MLU
void operator()(const platform::MLUPlace& place) const {
// TODO(fwg): SUPPORT it
PADDLE_THROW(platform::errors::PermissionDenied(
"Gradient accumulation on place (%s) "
"is not supported in imperative mode",
place));
}
#else
void operator()(const platform::MLUPlace& place) const {
PADDLE_THROW(platform::errors::PermissionDenied(
"Gradient accumulation on place (%s) "
"is not supported in imperative mode",
place));
}
#endif
#ifdef PADDLE_WITH_ASCEND_CL
void operator()(const platform::NPUPlace& place) const {
// TODO(zhiqiu): SUPPORT it
PADDLE_THROW(platform::errors::PermissionDenied(
"Gradient accumulation on place (%s) "
"is not supported in imperative mode",
place));
}
#else
void operator()(const platform::NPUPlace& place) const {
PADDLE_THROW(platform::errors::PermissionDenied(
"Gradient accumulation on place (%s) "
"is not supported in imperative mode",
place));
}
#endif
void operator()(const platform::NPUPinnedPlace& place) const {
PADDLE_THROW(platform::errors::PermissionDenied(
"Gradient accumulation on place (%s) "
"is not supported in imperative mode",
place));
}
// there is NO blas in CUDAPinnedPlace
void operator()(const platform::CUDAPinnedPlace& place) const {
PADDLE_THROW(platform::errors::PermissionDenied(
"Gradient accumulation on place (%s) "
"is not supported in imperative mode",
place));
}
// there is NO support in IPUPlace
void operator()(const platform::IPUPlace& place) const {
PADDLE_THROW(platform::errors::PermissionDenied(
"Gradient accumulation on place (%s) "
"is not supported in imperative mode",
place));
}
void operator()(const platform::CustomPlace& place) const {
#ifdef PADDLE_WITH_CUSTOM_DEVICE
platform::CustomDeviceContext* ctx =
dynamic_cast<platform::CustomDeviceContext*>(
platform::DeviceContextPool::Instance().Get(place));
phi::stream::Stream stream(place, ctx->stream());
auto device = phi::DeviceManager::GetDeviceWithPlace(place);
device->BlasAXPBY<T>(stream, static_cast<size_t>(numel_), 1., x_, 1., y_);
#else
PADDLE_THROW(platform::errors::PermissionDenied(
"Gradient accumulation on place (%s) "
"is not supported in imperative mode",
place));
#endif
}
private:
int64_t numel_;
const T* x_;
mutable T* y_;
};
#ifdef PADDLE_WITH_XPU
template <typename T>
void XPUTensorAddFunctor(const platform::Place& place,
......@@ -232,17 +102,6 @@ void XPUTensorAddFunctor(const platform::Place& place,
}
#endif
template <typename DeviceContext, typename T>
void TensorAddImpl(const framework::Tensor& src,
framework::Tensor* dst,
const platform::Place& place) {
platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
paddle::platform::DeviceContext* ctx = pool.Get(place);
auto dev_ctx = dynamic_cast<DeviceContext*>(ctx);
phi::funcs::ElementwiseAddTo<DeviceContext, T> func;
func(dev_ctx, src, dst);
}
template <typename TType>
TType* GetInnerMutableTensor(framework::Variable* dst) {
auto* dst_tensor = dst->GetMutable<TType>();
......@@ -327,14 +186,71 @@ void TensorAdd(const VarType& src, VarType* dst) {
if (dst_tensor->place() != place) {
paddle::framework::TensorCopySync(*dst_tensor, place, dst_tensor);
}
#define PADDLE_TENSOR_ADD(cpp_type) \
if (data_type == framework::DataTypeTrait<cpp_type>::DataType()) { \
TensorAddFunctor<cpp_type> func( \
numel, \
src_tensor.data<cpp_type>(), \
dst_tensor->mutable_data<cpp_type>(place)); \
platform::VisitPlace(place, func); \
return; \
#define PADDLE_TENSOR_ADD(T, CONTEXT) \
if (data_type == framework::DataTypeTrait<T>::DataType()) { \
auto cpu_ctx = static_cast<CONTEXT*>( \
platform::DeviceContextPool::Instance().Get(place)); \
phi::AddKernel<T, CONTEXT>(*cpu_ctx, src_tensor, *dst_tensor, dst_tensor); \
return; \
}
if (platform::is_gpu_place(place)) {
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
PADDLE_TENSOR_ADD(float, phi::GPUContext);
PADDLE_TENSOR_ADD(double, phi::GPUContext);
PADDLE_TENSOR_ADD(phi::dtype::float16, phi::GPUContext);
PADDLE_TENSOR_ADD(phi::dtype::bfloat16, phi::GPUContext);
PADDLE_TENSOR_ADD(platform::complex<float>, phi::GPUContext);
PADDLE_TENSOR_ADD(platform::complex<double>, phi::GPUContext);
#endif
}
#define TENSOR_ADD_EIGEN(T) \
auto cpu_ctx = static_cast<phi::CPUContext*>( \
platform::DeviceContextPool::Instance().Get(place)); \
auto in = paddle::framework::EigenVector<T>::Flatten(src_tensor); \
auto out = paddle::framework::EigenVector<T>::Flatten(*dst_tensor); \
auto& p = *(cpu_ctx->eigen_device()); \
out.device(p) = out + in; \
return;
if (platform::is_cpu_place(place)) {
PADDLE_TENSOR_ADD(float, phi::CPUContext);
PADDLE_TENSOR_ADD(double, phi::CPUContext);
PADDLE_TENSOR_ADD(platform::complex<float>, phi::CPUContext);
PADDLE_TENSOR_ADD(platform::complex<double>, phi::CPUContext);
if (data_type == framework::proto::VarType::BF16) {
TENSOR_ADD_EIGEN(phi::dtype::bfloat16);
}
if (data_type == framework::proto::VarType::FP16) {
TENSOR_ADD_EIGEN(phi::dtype::float16);
}
}
#define PADDLE_TENSOR_ADD_CUSTOM(T) \
if (data_type == framework::DataTypeTrait<T>::DataType()) { \
platform::CustomDeviceContext* ctx = \
static_cast<platform::CustomDeviceContext*>( \
platform::DeviceContextPool::Instance().Get(place)); \
phi::stream::Stream stream(place, ctx->stream()); \
auto device = phi::DeviceManager::GetDeviceWithPlace(place); \
device->BlasAXPBY<T>(stream, \
static_cast<size_t>(numel), \
1., \
src_tensor.data<T>(), \
1., \
dst_tensor->mutable_data<T>(place)); \
return; \
}
if (platform::is_custom_place(place)) {
#if defined(PADDLE_WITH_CUSTOM_DEVICE)
PADDLE_TENSOR_ADD_CUSTOM(float);
PADDLE_TENSOR_ADD_CUSTOM(double);
PADDLE_TENSOR_ADD_CUSTOM(platform::complex<float>);
PADDLE_TENSOR_ADD_CUSTOM(platform::complex<double>);
#endif
}
#ifdef PADDLE_WITH_ASCEND_CL
......@@ -416,53 +332,6 @@ void TensorAdd(const VarType& src, VarType* dst) {
}
#endif
PADDLE_TENSOR_ADD(float);
#ifndef PADDLE_WITH_XPU
// NOTE(phlrain): xpu only support float
PADDLE_TENSOR_ADD(double);
// NOTE(chenweihang): only support complex grad tensor accumulated,
// support selected rows if needed in the future
PADDLE_TENSOR_ADD(platform::complex<float>);
PADDLE_TENSOR_ADD(platform::complex<double>);
#endif
#undef PADDLE_TENSOR_ADD
if (data_type == framework::proto::VarType::FP16) {
if (platform::is_gpu_place(place)) {
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
return TensorAddImpl<phi::GPUContext, platform::float16>(
src_tensor, dst_tensor, place);
#else
PADDLE_THROW(platform::errors::Unimplemented(
"Gradient accumulation of data type (%s) on place (%s) is not "
"supported in imperative mode",
framework::DataTypeToString(data_type),
place));
#endif
} else if (platform::is_cpu_place(place)) {
return TensorAddImpl<phi::CPUContext, platform::float16>(
src_tensor, dst_tensor, place);
}
}
if (data_type == framework::proto::VarType::BF16) {
if (platform::is_gpu_place(place)) {
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
return TensorAddImpl<phi::GPUContext, platform::bfloat16>(
src_tensor, dst_tensor, place);
#else
PADDLE_THROW(platform::errors::Unimplemented(
"Gradient accumulation of data type (%s) on place (%s) is not "
"supported in imperative mode",
framework::DataTypeToString(data_type),
place));
#endif
} else if (platform::is_cpu_place(place)) {
return TensorAddImpl<phi::CPUContext, platform::bfloat16>(
src_tensor, dst_tensor, place);
}
}
PADDLE_THROW(platform::errors::Unimplemented(
"Gradient accumulation of data type (%s) on place (%s) is not "
"supported in imperative mode",
......
......@@ -257,21 +257,6 @@ template struct ColwiseSum<phi::CPUContext, int64_t>;
template struct RowwiseMean<phi::CPUContext, float>;
template struct RowwiseMean<phi::CPUContext, double>;
template <typename T>
struct ElementwiseAddTo<phi::CPUContext, T> {
void operator()(phi::CPUContext* ctx,
const paddle::framework::Tensor& src,
paddle::framework::Tensor* dst) {
auto in = paddle::framework::EigenVector<T>::Flatten(src);
auto out = paddle::framework::EigenVector<T>::Flatten(*dst);
auto& place = *(ctx->eigen_device());
out.device(place) = out + in;
}
};
template struct ElementwiseAddTo<phi::CPUContext, phi::dtype::float16>;
template struct ElementwiseAddTo<phi::CPUContext, phi::dtype::bfloat16>;
template <typename T>
struct RowwiseAdd<phi::CPUContext, T> {
void operator()(const phi::CPUContext& context,
......
......@@ -21,7 +21,6 @@ limitations under the License. */
#include "paddle/fluid/platform/float16.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/kernels/funcs/blas/blas.h"
#include "paddle/phi/kernels/funcs/eigen/common.h"
#include "paddle/phi/kernels/funcs/math_function.h"
#include "paddle/phi/kernels/funcs/math_function_impl.h"
......@@ -371,20 +370,5 @@ void RowwiseSum<phi::GPUContext, double>::operator()(
template struct RowwiseMean<phi::GPUContext, float>;
template struct RowwiseMean<phi::GPUContext, double>;
template <typename T>
struct ElementwiseAddTo<phi::GPUContext, T> {
void operator()(phi::GPUContext* ctx,
const paddle::framework::Tensor& src,
paddle::framework::Tensor* dst) {
auto in = paddle::framework::EigenVector<T>::Flatten(src);
auto out = paddle::framework::EigenVector<T>::Flatten(*dst);
auto& place = *(ctx->eigen_device());
out.device(place) = out + in;
}
};
template struct ElementwiseAddTo<phi::GPUContext, phi::dtype::float16>;
template struct ElementwiseAddTo<phi::GPUContext, phi::dtype::bfloat16>;
} // namespace funcs
} // namespace phi
......@@ -18,7 +18,6 @@ limitations under the License. */
#include <vector>
#include "paddle/fluid/framework/convert_utils.h"
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/framework/tensor_util.h"
......@@ -71,14 +70,6 @@ struct RowwiseAdd {
paddle::framework::Tensor* output);
};
template <typename DeviceContext, typename T>
struct ElementwiseAddTo {
// dst = dst + src
void operator()(DeviceContext* ctx,
const paddle::framework::Tensor& src,
paddle::framework::Tensor* dst);
};
template <typename DeviceContext, typename T>
struct ColwiseSum {
void operator()(const DeviceContext& context,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册