diff --git a/paddle/fluid/imperative/gradient_accumulator.cc b/paddle/fluid/imperative/gradient_accumulator.cc index f6883fe6c6a923429235874ab788e7bd2f224c1f..b656da34fb6a6909a9f0a6e1f6e0f68eb4736955 100644 --- a/paddle/fluid/imperative/gradient_accumulator.cc +++ b/paddle/fluid/imperative/gradient_accumulator.cc @@ -39,6 +39,9 @@ #ifdef PADDLE_WITH_MLU #include "paddle/fluid/operators/mlu/mlu_baseop.h" #endif +#ifdef PADDLE_WITH_CUSTOM_DEVICE +#include "paddle/phi/backends/device_manager.h" +#endif namespace paddle { namespace imperative { @@ -189,10 +192,19 @@ class TensorAddFunctor place)); } void operator()(const platform::CustomPlace& place) const { +#ifdef PADDLE_WITH_CUSTOM_DEVICE + platform::CustomDeviceContext* ctx = + dynamic_cast( + platform::DeviceContextPool::Instance().Get(place)); + phi::stream::Stream stream(place, ctx->stream()); + auto device = phi::DeviceManager::GetDeviceWithPlace(place); + device->BlasAXPBY(stream, static_cast(numel_), 1., x_, 1., y_); +#else PADDLE_THROW(platform::errors::PermissionDenied( "Gradient accumulation on place (%s) " "is not supported in imperative mode", place)); +#endif } private: @@ -351,15 +363,7 @@ void TensorAdd(const VarType& src, VarType* dst) { return; } #endif -#ifdef PADDLE_WITH_CUSTOM_DEVICE - if (platform::is_custom_place(place)) { - PADDLE_THROW(platform::errors::Unimplemented( - "Gradient accumulation of data type (%s) on place (%s) is not " - "supported in imperative mode", - framework::DataTypeToString(data_type), - place)); - } -#endif + #ifdef PADDLE_WITH_XPU if (platform::is_xpu_place(place)) { if (data_type == framework::DataTypeTrait::DataType()) { diff --git a/paddle/phi/backends/CMakeLists.txt b/paddle/phi/backends/CMakeLists.txt index de4d82b46133caa24b5268bbea29a01b8d0915aa..6a55c34266ff7ab75d968a546b2b512121d07543 100644 --- a/paddle/phi/backends/CMakeLists.txt +++ b/paddle/phi/backends/CMakeLists.txt @@ -51,7 +51,7 @@ if(WITH_CUSTOM_DEVICE) cc_test( custom_device_test SRCS custom/custom_device_test.cc - DEPS phi_backends phi_device_context) + DEPS phi_backends phi_device_context gradient_accumulator) cc_test( capi_test SRCS custom/capi_test.cc diff --git a/paddle/phi/backends/custom/custom_device.cc b/paddle/phi/backends/custom/custom_device.cc index 1a92868dd07db1c7c9b67581587b1cc7323027d7..2567857bca1cad0596dcf693979c511bb55f4460 100644 --- a/paddle/phi/backends/custom/custom_device.cc +++ b/paddle/phi/backends/custom/custom_device.cc @@ -14,6 +14,8 @@ #include "paddle/fluid/platform/device/custom/enforce_custom.h" #include "paddle/fluid/platform/device_context.h" +#include "paddle/phi/common/data_type.h" + #include "paddle/phi/backends/callback_manager.h" #include "paddle/phi/backends/device_base.h" #include "paddle/phi/backends/device_guard.h" @@ -608,6 +610,27 @@ class CustomDevice : public DeviceInterface { #undef return_result } + C_DataType ToCDatatType(paddle::experimental::DataType data_type) { +#define return_result(in, ret) \ + case in: \ + return C_DataType::ret + switch (data_type) { + return_result(paddle::experimental::DataType::FLOAT64, FLOAT64); + return_result(paddle::experimental::DataType::FLOAT32, FLOAT32); + return_result(paddle::experimental::DataType::FLOAT16, FLOAT16); + return_result(paddle::experimental::DataType::INT64, INT64); + return_result(paddle::experimental::DataType::INT32, INT32); + return_result(paddle::experimental::DataType::INT16, INT16); + return_result(paddle::experimental::DataType::INT8, INT8); + default: { + PADDLE_THROW(phi::errors::Unavailable( + "DataType is not supported on %s.", Type())); + return C_DataType::UNDEFINED; + } + } +#undef return_result + } + void CCLGetUniqueId(ccl::CCLRootId* unique_id) override { CHECK_PTR(pimpl_->xccl_get_unique_id_size); CHECK_PTR(pimpl_->xccl_get_unique_id); @@ -771,6 +794,27 @@ class CustomDevice : public DeviceInterface { reinterpret_cast(stream.raw_stream()))); } + void BlasAXPBY(size_t dev_id, + const stream::Stream& stream, + paddle::experimental::DataType dtype, + size_t numel, + float alpha, + void* x, + float beta, + void* y) override { + CHECK_PTR(pimpl_->blas_axpby); + const auto device = &devices_pool[dev_id]; + PADDLE_ENFORCE_CUSTOM_DEVICE_SUCCESS( + pimpl_->blas_axpby(device, + reinterpret_cast(stream.raw_stream()), + ToCDatatType(dtype), + numel, + alpha, + x, + beta, + y)); + } + private: inline int PlaceToIdNoCheck(const Place& place) { int dev_id = place.GetDeviceId(); @@ -877,6 +921,8 @@ bool ValidCustomCustomRuntimeParams(const CustomRuntimeParams* params) { CHECK_INTERFACE(xccl_group_end, false); CHECK_INTERFACE(xccl_send, false); CHECK_INTERFACE(xccl_recv, false); + + CHECK_INTERFACE(blas_axpby, false); return true; #undef CHECK_INTERFACE } diff --git a/paddle/phi/backends/custom/custom_device_test.cc b/paddle/phi/backends/custom/custom_device_test.cc index 930750e864883aa3d04db5102593449f199f7f61..425d7bde6173ccab94da9685af4f63b66e78176c 100644 --- a/paddle/phi/backends/custom/custom_device_test.cc +++ b/paddle/phi/backends/custom/custom_device_test.cc @@ -18,6 +18,8 @@ #include "paddle/fluid/framework/tensor.h" #include "paddle/fluid/framework/tensor_util.h" +#include "paddle/fluid/framework/variable.h" +#include "paddle/fluid/imperative/gradient_accumulator.h" #include "paddle/fluid/platform/device_context.h" #include "paddle/phi/backends/custom/fake_cpu_device.h" #include "paddle/phi/backends/device_manager.h" @@ -237,6 +239,51 @@ void TestCustomCCL(const paddle::platform::Place& place) { stream); } +void TestBlasAPI(const paddle::platform::Place& place) { + std::cout << "TestBlasAPI on " << place << std::endl; + if (paddle::platform::is_custom_place(place) == false) { + return; + } + auto device = phi::DeviceManager::GetDeviceWithPlace(place); + phi::stream::Stream stream(place, nullptr); + device->BlasAXPBY(stream, 0, 1., nullptr, 1., nullptr); + + paddle::framework::Variable var1; + paddle::framework::Variable var2; + std::vector src_data(10, 1.0); + std::vector dst_data(10, 0.0); + std::vector result; + paddle::platform::CPUPlace src_place; + for (unsigned int i = 0; i < 10; i++) { + result.emplace_back(src_data[i] + dst_data[i]); + } + + std::vector dims = {2, 5}; + auto* src = var1.GetMutable(); + auto* dst = var2.GetMutable(); + src->Resize(phi::make_ddim(dims)); + dst->Resize(phi::make_ddim(dims)); + auto* src_mutable = src->mutable_data(place); + auto* dst_mutable = dst->mutable_data(place); + + paddle::memory::Copy(place, + src_mutable, + src_place, + src_data.data(), + sizeof(float) * src_data.size()); + + paddle::memory::Copy(place, + dst_mutable, + src_place, + dst_data.data(), + sizeof(float) * dst_data.size()); + + paddle::imperative::TensorAdd(var1, &var2); + paddle::framework::LoDTensor rlt; + paddle::platform::CPUPlace rlt_place; + paddle::framework::TensorCopySync(*dst, rlt_place, &rlt); +} + TEST(CustomDevice, Tensor) { InitDevice(); auto dev_types = phi::DeviceManager::GetAllDeviceTypes(); @@ -251,6 +298,7 @@ TEST(CustomDevice, Tensor) { TestTensorShareDataWith(place); TestTensorUtils(place); TestCustomCCL(place); + TestBlasAPI(place); } } diff --git a/paddle/phi/backends/custom/fake_cpu_device.h b/paddle/phi/backends/custom/fake_cpu_device.h index 41c7acc4469cdedbe35aae81ca82170a5d9aee70..a4eaa834a60f9843a14f56c2d58173a93f2e0a54 100644 --- a/paddle/phi/backends/custom/fake_cpu_device.h +++ b/paddle/phi/backends/custom/fake_cpu_device.h @@ -210,6 +210,17 @@ C_Status XcclRecv(void *recv_buf, return C_SUCCESS; } +C_Status BlasAXPBY(const C_Device device, + C_Stream stream, + C_DataType dtype, + size_t numel, + float alpha, + void *x, + float beta, + void *y) { + return C_SUCCESS; +} + #define DEVICE_TYPE "FakeCPU" #define SUB_DEVICE_TYPE "V100" @@ -278,4 +289,6 @@ void InitFakeCPUDevice(CustomRuntimeParams *params) { params->interface->xccl_reduce_scatter = XcclReduceScatter; params->interface->xccl_send = XcclSend; params->interface->xccl_recv = XcclRecv; + + params->interface->blas_axpby = BlasAXPBY; } diff --git a/paddle/phi/backends/device_base.cc b/paddle/phi/backends/device_base.cc index 4b82f4a340ebb13d4c40113f879d6c3a57517848..41871f69c7790458cad21e92fe0dc6209f7ebb61 100644 --- a/paddle/phi/backends/device_base.cc +++ b/paddle/phi/backends/device_base.cc @@ -355,6 +355,18 @@ void DeviceInterface::CCLRecv(void* recvbuf, INTERFACE_UNIMPLEMENT; } +// blas +void DeviceInterface::BlasAXPBY(size_t dev_id, + const stream::Stream& stream, + paddle::experimental::DataType dtype, + size_t numel, + float alpha, + void* x, + float beta, + void* y) { + INTERFACE_UNIMPLEMENT; +} + #undef INTERFACE_UNIMPLEMENT } // namespace phi diff --git a/paddle/phi/backends/device_base.h b/paddle/phi/backends/device_base.h index 84249261d196215d73f302a67ee80e33e15607f8..b823a4a983207c4ea902ad1580eeb48eeeb068af 100644 --- a/paddle/phi/backends/device_base.h +++ b/paddle/phi/backends/device_base.h @@ -225,6 +225,16 @@ class DeviceInterface { // Driver / Runtime const ccl::CCLComm& ccl_comm, const stream::Stream& stream); + // blas + virtual void BlasAXPBY(size_t dev_id, + const stream::Stream& stream, + paddle::experimental::DataType dtype, + size_t numel, + float alpha, + void* x, + float beta, + void* y); + private: const std::string type_; const uint8_t priority_; diff --git a/paddle/phi/backends/device_ext.h b/paddle/phi/backends/device_ext.h index a4dc9176e1b1ec686d3338d1f425cd96dafe8228..5bb5def9c2b19866910817df7c5f465a72b02970 100644 --- a/paddle/phi/backends/device_ext.h +++ b/paddle/phi/backends/device_ext.h @@ -635,7 +635,19 @@ struct C_DeviceInterface { // other api // /////////////// - void* reserved_other_api[8]; + /** + * @brief y = alpha * x + beta * y + * + */ + C_Status (*blas_axpby)(const C_Device device, + C_Stream stream, + C_DataType dtype, + size_t numel, + float alpha, + void* x, + float beta, + void* y); + void* reserved_other_api[7]; }; struct CustomRuntimeVersion { diff --git a/paddle/phi/backends/device_manager.cc b/paddle/phi/backends/device_manager.cc index 405a87f7496a844194701b45e63775d57f8986f2..dbdbce13d4f40122027a0737479eb4fbf3630b54 100644 --- a/paddle/phi/backends/device_manager.cc +++ b/paddle/phi/backends/device_manager.cc @@ -14,6 +14,7 @@ #ifdef PADDLE_WITH_CUSTOM_DEVICE #include "paddle/phi/backends/device_manager.h" +#include "paddle/phi/common/complex.h" #if !defined(_WIN32) #include @@ -135,6 +136,80 @@ void Device::MemorySet(void* ptr, uint8_t value, size_t size) { impl_->MemorySet(dev_id_, ptr, value, size); } +template +void Device::BlasAXPBY(const stream::Stream& stream, + size_t numel, + float alpha, + const T* x, + float beta, + T* y) { + impl_->BlasAXPBY(dev_id_, + stream, + paddle::experimental::CppTypeToDataType::Type(), + numel, + alpha, + reinterpret_cast(const_cast(x)), + beta, + reinterpret_cast(y)); +} + +template void Device::BlasAXPBY(const stream::Stream& stream, + size_t numel, + float alpha, + const paddle::float16* x, + float beta, + paddle::float16* y); +template void Device::BlasAXPBY(const stream::Stream& stream, + size_t numel, + float alpha, + const float* x, + float beta, + float* y); +template void Device::BlasAXPBY(const stream::Stream& stream, + size_t numel, + float alpha, + const double* x, + float beta, + double* y); +template void Device::BlasAXPBY(const stream::Stream& stream, + size_t numel, + float alpha, + const int8_t* x, + float beta, + int8_t* y); +template void Device::BlasAXPBY(const stream::Stream& stream, + size_t numel, + float alpha, + const int16_t* x, + float beta, + int16_t* y); +template void Device::BlasAXPBY(const stream::Stream& stream, + size_t numel, + float alpha, + const int32_t* x, + float beta, + int32_t* y); +template void Device::BlasAXPBY(const stream::Stream& stream, + size_t numel, + float alpha, + const int64_t* x, + float beta, + int64_t* y); +template void Device::BlasAXPBY>( + const stream::Stream& stream, + size_t numel, + float alpha, + const phi::dtype::complex* x, + float beta, + phi::dtype::complex* y); +template void Device::BlasAXPBY>( + const stream::Stream& stream, + size_t numel, + float alpha, + const phi::dtype::complex* x, + float beta, + phi::dtype::complex* y); + std::string Device::Type() { return impl_->Type(); } static phi::RWLock _global_device_manager_rw_lock; diff --git a/paddle/phi/backends/device_manager.h b/paddle/phi/backends/device_manager.h index 4ad7643c33d3c498c93ada6a608e20ae37ea7ebd..6d621b6a43223919c74815a799b4d21177d66736 100644 --- a/paddle/phi/backends/device_manager.h +++ b/paddle/phi/backends/device_manager.h @@ -17,14 +17,16 @@ #include +#include "paddle/phi/common/data_type.h" +#include "paddle/phi/common/place.h" +#include "paddle/phi/core/utils/rw_lock.h" + #include "paddle/phi/backends/c_comm_lib.h" #include "paddle/phi/backends/device_base.h" #include "paddle/phi/backends/device_ext.h" #include "paddle/phi/backends/dynload/port.h" #include "paddle/phi/backends/event.h" #include "paddle/phi/backends/stream.h" -#include "paddle/phi/common/place.h" -#include "paddle/phi/core/utils/rw_lock.h" namespace phi { class Device final { @@ -106,6 +108,16 @@ class Device final { void MemorySet(void* ptr, uint8_t value, size_t size); + // Blas + // ! y = alpha * x + beta * y + template + void BlasAXPBY(const stream::Stream& stream, + size_t numel, + float alpha, + const T* x, + float beta, + T* y); + std::string Type(); private: