未验证 提交 0d51fcf1 编写于 作者: R ronnywang 提交者: GitHub

[CustomDevice] add blas_axpby api for gradient_accumulator (#44584)

上级 356ff436
...@@ -39,6 +39,9 @@ ...@@ -39,6 +39,9 @@
#ifdef PADDLE_WITH_MLU #ifdef PADDLE_WITH_MLU
#include "paddle/fluid/operators/mlu/mlu_baseop.h" #include "paddle/fluid/operators/mlu/mlu_baseop.h"
#endif #endif
#ifdef PADDLE_WITH_CUSTOM_DEVICE
#include "paddle/phi/backends/device_manager.h"
#endif
namespace paddle { namespace paddle {
namespace imperative { namespace imperative {
...@@ -189,10 +192,19 @@ class TensorAddFunctor ...@@ -189,10 +192,19 @@ class TensorAddFunctor
place)); place));
} }
// Accumulate a gradient on a plugin-provided device: y += x, realized as the
// device's fused BlasAXPBY kernel with alpha = beta = 1.
// Without custom-device support compiled in, this place is rejected outright.
void operator()(const platform::CustomPlace& place) const {
#ifdef PADDLE_WITH_CUSTOM_DEVICE
  // NOTE(review): dynamic_cast result is used unchecked — the pool is
  // expected to hand back a CustomDeviceContext for a CustomPlace; confirm.
  platform::CustomDeviceContext* ctx =
      dynamic_cast<platform::CustomDeviceContext*>(
          platform::DeviceContextPool::Instance().Get(place));
  // Wrap the context's raw stream so the kernel launches on the same queue
  // as the rest of the imperative program.
  phi::stream::Stream stream(place, ctx->stream());
  auto device = phi::DeviceManager::GetDeviceWithPlace(place);
  device->BlasAXPBY<T>(stream, static_cast<size_t>(numel_), 1., x_, 1., y_);
#else
  PADDLE_THROW(platform::errors::PermissionDenied(
      "Gradient accumulation on place (%s) "
      "is not supported in imperative mode",
      place));
#endif
}
private: private:
...@@ -351,15 +363,7 @@ void TensorAdd(const VarType& src, VarType* dst) { ...@@ -351,15 +363,7 @@ void TensorAdd(const VarType& src, VarType* dst) {
return; return;
} }
#endif #endif
#ifdef PADDLE_WITH_CUSTOM_DEVICE
if (platform::is_custom_place(place)) {
PADDLE_THROW(platform::errors::Unimplemented(
"Gradient accumulation of data type (%s) on place (%s) is not "
"supported in imperative mode",
framework::DataTypeToString(data_type),
place));
}
#endif
#ifdef PADDLE_WITH_XPU #ifdef PADDLE_WITH_XPU
if (platform::is_xpu_place(place)) { if (platform::is_xpu_place(place)) {
if (data_type == framework::DataTypeTrait<float>::DataType()) { if (data_type == framework::DataTypeTrait<float>::DataType()) {
......
...@@ -51,7 +51,7 @@ if(WITH_CUSTOM_DEVICE) ...@@ -51,7 +51,7 @@ if(WITH_CUSTOM_DEVICE)
cc_test( cc_test(
custom_device_test custom_device_test
SRCS custom/custom_device_test.cc SRCS custom/custom_device_test.cc
DEPS phi_backends phi_device_context) DEPS phi_backends phi_device_context gradient_accumulator)
cc_test( cc_test(
capi_test capi_test
SRCS custom/capi_test.cc SRCS custom/capi_test.cc
......
...@@ -14,6 +14,8 @@ ...@@ -14,6 +14,8 @@
#include "paddle/fluid/platform/device/custom/enforce_custom.h" #include "paddle/fluid/platform/device/custom/enforce_custom.h"
#include "paddle/fluid/platform/device_context.h" #include "paddle/fluid/platform/device_context.h"
#include "paddle/phi/common/data_type.h"
#include "paddle/phi/backends/callback_manager.h" #include "paddle/phi/backends/callback_manager.h"
#include "paddle/phi/backends/device_base.h" #include "paddle/phi/backends/device_base.h"
#include "paddle/phi/backends/device_guard.h" #include "paddle/phi/backends/device_guard.h"
...@@ -608,6 +610,27 @@ class CustomDevice : public DeviceInterface { ...@@ -608,6 +610,27 @@ class CustomDevice : public DeviceInterface {
#undef return_result #undef return_result
} }
// Translate a phi/paddle DataType into the matching C plugin enum value.
// Only the types the plugin ABI models are accepted; anything else raises
// an Unavailable error naming this device type.
C_DataType ToCDatatType(paddle::experimental::DataType data_type) {
  using DT = paddle::experimental::DataType;
  switch (data_type) {
    case DT::FLOAT64:
      return C_DataType::FLOAT64;
    case DT::FLOAT32:
      return C_DataType::FLOAT32;
    case DT::FLOAT16:
      return C_DataType::FLOAT16;
    case DT::INT64:
      return C_DataType::INT64;
    case DT::INT32:
      return C_DataType::INT32;
    case DT::INT16:
      return C_DataType::INT16;
    case DT::INT8:
      return C_DataType::INT8;
    default: {
      PADDLE_THROW(phi::errors::Unavailable(
          "DataType is not supported on %s.", Type()));
      return C_DataType::UNDEFINED;  // unreachable; keeps the compiler quiet
    }
  }
}
void CCLGetUniqueId(ccl::CCLRootId* unique_id) override { void CCLGetUniqueId(ccl::CCLRootId* unique_id) override {
CHECK_PTR(pimpl_->xccl_get_unique_id_size); CHECK_PTR(pimpl_->xccl_get_unique_id_size);
CHECK_PTR(pimpl_->xccl_get_unique_id); CHECK_PTR(pimpl_->xccl_get_unique_id);
...@@ -771,6 +794,27 @@ class CustomDevice : public DeviceInterface { ...@@ -771,6 +794,27 @@ class CustomDevice : public DeviceInterface {
reinterpret_cast<C_Stream>(stream.raw_stream()))); reinterpret_cast<C_Stream>(stream.raw_stream())));
} }
// Forward a fused AXPBY (y = alpha * x + beta * y) to the plugin.
// The buffers cross the C ABI untyped; `dtype` tells the plugin how to
// interpret them, and the call is checked for C_SUCCESS.
void BlasAXPBY(size_t dev_id,
               const stream::Stream& stream,
               paddle::experimental::DataType dtype,
               size_t numel,
               float alpha,
               void* x,
               float beta,
               void* y) override {
  CHECK_PTR(pimpl_->blas_axpby);  // plugin may legitimately omit this entry
  const auto device = &devices_pool[dev_id];
  auto c_stream = reinterpret_cast<C_Stream>(stream.raw_stream());
  PADDLE_ENFORCE_CUSTOM_DEVICE_SUCCESS(pimpl_->blas_axpby(
      device, c_stream, ToCDatatType(dtype), numel, alpha, x, beta, y));
}
private: private:
inline int PlaceToIdNoCheck(const Place& place) { inline int PlaceToIdNoCheck(const Place& place) {
int dev_id = place.GetDeviceId(); int dev_id = place.GetDeviceId();
...@@ -877,6 +921,8 @@ bool ValidCustomCustomRuntimeParams(const CustomRuntimeParams* params) { ...@@ -877,6 +921,8 @@ bool ValidCustomCustomRuntimeParams(const CustomRuntimeParams* params) {
CHECK_INTERFACE(xccl_group_end, false); CHECK_INTERFACE(xccl_group_end, false);
CHECK_INTERFACE(xccl_send, false); CHECK_INTERFACE(xccl_send, false);
CHECK_INTERFACE(xccl_recv, false); CHECK_INTERFACE(xccl_recv, false);
CHECK_INTERFACE(blas_axpby, false);
return true; return true;
#undef CHECK_INTERFACE #undef CHECK_INTERFACE
} }
......
...@@ -18,6 +18,8 @@ ...@@ -18,6 +18,8 @@
#include "paddle/fluid/framework/tensor.h" #include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/framework/tensor_util.h" #include "paddle/fluid/framework/tensor_util.h"
#include "paddle/fluid/framework/variable.h"
#include "paddle/fluid/imperative/gradient_accumulator.h"
#include "paddle/fluid/platform/device_context.h" #include "paddle/fluid/platform/device_context.h"
#include "paddle/phi/backends/custom/fake_cpu_device.h" #include "paddle/phi/backends/custom/fake_cpu_device.h"
#include "paddle/phi/backends/device_manager.h" #include "paddle/phi/backends/device_manager.h"
...@@ -237,6 +239,51 @@ void TestCustomCCL(const paddle::platform::Place& place) { ...@@ -237,6 +239,51 @@ void TestCustomCCL(const paddle::platform::Place& place) {
stream); stream);
} }
// Exercises the custom-device BlasAXPBY path end to end:
//   1. a direct zero-length call on the device, then
//   2. gradient accumulation (TensorAdd) of two LoDTensors living on the
//      custom place, with the expected element-wise sums precomputed on CPU.
// NOTE(review): `result` (expected values) and `rlt` (the tensor copied back
// to the host) are both produced but never compared — the test currently only
// checks that the calls complete without error. Consider adding an EXPECT_EQ
// over the ten elements once the fake device actually computes.
void TestBlasAPI(const paddle::platform::Place& place) {
  std::cout << "TestBlasAPI on " << place << std::endl;
  if (paddle::platform::is_custom_place(place) == false) {
    return;  // BlasAXPBY is only routed through plugin (custom) devices
  }
  auto device = phi::DeviceManager::GetDeviceWithPlace(place);
  phi::stream::Stream stream(place, nullptr);
  // numel == 0 with null buffers: presumably a safe no-op smoke call — verify
  // against the plugin contract.
  device->BlasAXPBY<float>(stream, 0, 1., nullptr, 1., nullptr);

  paddle::framework::Variable var1;
  paddle::framework::Variable var2;
  std::vector<float> src_data(10, 1.0);
  std::vector<float> dst_data(10, 0.0);
  std::vector<float> result;
  paddle::platform::CPUPlace src_place;
  // Host-side reference: element-wise sum of the two input buffers.
  for (unsigned int i = 0; i < 10; i++) {
    result.emplace_back(src_data[i] + dst_data[i]);
  }

  std::vector<int64_t> dims = {2, 5};
  auto* src = var1.GetMutable<paddle::framework::LoDTensor>();
  auto* dst = var2.GetMutable<paddle::framework::LoDTensor>();
  src->Resize(phi::make_ddim(dims));
  dst->Resize(phi::make_ddim(dims));
  auto* src_mutable = src->mutable_data<float>(place);
  auto* dst_mutable = dst->mutable_data<float>(place);
  // Stage both operands onto the custom device.
  paddle::memory::Copy(place,
                       src_mutable,
                       src_place,
                       src_data.data(),
                       sizeof(float) * src_data.size());
  paddle::memory::Copy(place,
                       dst_mutable,
                       src_place,
                       dst_data.data(),
                       sizeof(float) * dst_data.size());
  // var2 += var1; for custom places this dispatches to Device::BlasAXPBY.
  paddle::imperative::TensorAdd<paddle::framework::Variable>(var1, &var2);

  // Copy the accumulated tensor back to the host for inspection.
  paddle::framework::LoDTensor rlt;
  paddle::platform::CPUPlace rlt_place;
  paddle::framework::TensorCopySync(*dst, rlt_place, &rlt);
}
TEST(CustomDevice, Tensor) { TEST(CustomDevice, Tensor) {
InitDevice(); InitDevice();
auto dev_types = phi::DeviceManager::GetAllDeviceTypes(); auto dev_types = phi::DeviceManager::GetAllDeviceTypes();
...@@ -251,6 +298,7 @@ TEST(CustomDevice, Tensor) { ...@@ -251,6 +298,7 @@ TEST(CustomDevice, Tensor) {
TestTensorShareDataWith(place); TestTensorShareDataWith(place);
TestTensorUtils(place); TestTensorUtils(place);
TestCustomCCL(place); TestCustomCCL(place);
TestBlasAPI(place);
} }
} }
......
...@@ -210,6 +210,17 @@ C_Status XcclRecv(void *recv_buf, ...@@ -210,6 +210,17 @@ C_Status XcclRecv(void *recv_buf,
return C_SUCCESS; return C_SUCCESS;
} }
// Fake-device implementation of y = alpha * x + beta * y.
// The fake device's "device memory" is plain host memory, so we can actually
// perform the computation for FLOAT32 buffers — this lets tests validate
// gradient-accumulation results instead of only checking the call succeeds.
// Other dtypes (and null buffers, e.g. the numel == 0 smoke call) remain a
// successful no-op, preserving the previous stub behavior.
C_Status BlasAXPBY(const C_Device device,
                   C_Stream stream,
                   C_DataType dtype,
                   size_t numel,
                   float alpha,
                   void *x,
                   float beta,
                   void *y) {
  if (dtype == C_DataType::FLOAT32 && x != nullptr && y != nullptr) {
    const auto *px = static_cast<const float *>(x);
    auto *py = static_cast<float *>(y);
    for (size_t i = 0; i < numel; ++i) {
      py[i] = alpha * px[i] + beta * py[i];
    }
  }
  return C_SUCCESS;
}
#define DEVICE_TYPE "FakeCPU" #define DEVICE_TYPE "FakeCPU"
#define SUB_DEVICE_TYPE "V100" #define SUB_DEVICE_TYPE "V100"
...@@ -278,4 +289,6 @@ void InitFakeCPUDevice(CustomRuntimeParams *params) { ...@@ -278,4 +289,6 @@ void InitFakeCPUDevice(CustomRuntimeParams *params) {
params->interface->xccl_reduce_scatter = XcclReduceScatter; params->interface->xccl_reduce_scatter = XcclReduceScatter;
params->interface->xccl_send = XcclSend; params->interface->xccl_send = XcclSend;
params->interface->xccl_recv = XcclRecv; params->interface->xccl_recv = XcclRecv;
params->interface->blas_axpby = BlasAXPBY;
} }
...@@ -355,6 +355,18 @@ void DeviceInterface::CCLRecv(void* recvbuf, ...@@ -355,6 +355,18 @@ void DeviceInterface::CCLRecv(void* recvbuf,
INTERFACE_UNIMPLEMENT; INTERFACE_UNIMPLEMENT;
} }
// blas
// Default BlasAXPBY: y = alpha * x + beta * y on buffers owned by device
// `dev_id`. The base class has no backend, so this always reports
// "unimplemented"; concrete backends (e.g. CustomDevice) override it.
void DeviceInterface::BlasAXPBY(size_t dev_id,
                                const stream::Stream& stream,
                                paddle::experimental::DataType dtype,
                                size_t numel,
                                float alpha,
                                void* x,
                                float beta,
                                void* y) {
  INTERFACE_UNIMPLEMENT;  // no generic fallback for this op
}
#undef INTERFACE_UNIMPLEMENT #undef INTERFACE_UNIMPLEMENT
} // namespace phi } // namespace phi
...@@ -225,6 +225,16 @@ class DeviceInterface { // Driver / Runtime ...@@ -225,6 +225,16 @@ class DeviceInterface { // Driver / Runtime
const ccl::CCLComm& ccl_comm, const ccl::CCLComm& ccl_comm,
const stream::Stream& stream); const stream::Stream& stream);
// blas
// ! y = alpha * x + beta * y, executed asynchronously on `stream`.
// `x`/`y` are untyped device buffers of `numel` elements interpreted per
// `dtype`. The base implementation reports "unimplemented"; backends that
// support fused AXPBY (used by imperative gradient accumulation) override it.
virtual void BlasAXPBY(size_t dev_id,
                       const stream::Stream& stream,
                       paddle::experimental::DataType dtype,
                       size_t numel,
                       float alpha,
                       void* x,
                       float beta,
                       void* y);
private: private:
const std::string type_; const std::string type_;
const uint8_t priority_; const uint8_t priority_;
......
...@@ -635,7 +635,19 @@ struct C_DeviceInterface { ...@@ -635,7 +635,19 @@ struct C_DeviceInterface {
// other api // // other api //
/////////////// ///////////////
void* reserved_other_api[8]; /**
* @brief y = alpha * x + beta * y
*
*/
C_Status (*blas_axpby)(const C_Device device,
C_Stream stream,
C_DataType dtype,
size_t numel,
float alpha,
void* x,
float beta,
void* y);
void* reserved_other_api[7];
}; };
struct CustomRuntimeVersion { struct CustomRuntimeVersion {
......
...@@ -14,6 +14,7 @@ ...@@ -14,6 +14,7 @@
#ifdef PADDLE_WITH_CUSTOM_DEVICE #ifdef PADDLE_WITH_CUSTOM_DEVICE
#include "paddle/phi/backends/device_manager.h" #include "paddle/phi/backends/device_manager.h"
#include "paddle/phi/common/complex.h"
#if !defined(_WIN32) #if !defined(_WIN32)
#include <dirent.h> #include <dirent.h>
...@@ -135,6 +136,80 @@ void Device::MemorySet(void* ptr, uint8_t value, size_t size) { ...@@ -135,6 +136,80 @@ void Device::MemorySet(void* ptr, uint8_t value, size_t size) {
impl_->MemorySet(dev_id_, ptr, value, size); impl_->MemorySet(dev_id_, ptr, value, size);
} }
// Typed front end for the fused AXPBY kernel: y = alpha * x + beta * y.
// The backend interface is untyped, so the element type T is converted to a
// phi DataType and the buffers are passed through as void*.
template <typename T>
void Device::BlasAXPBY(const stream::Stream& stream,
                       size_t numel,
                       float alpha,
                       const T* x,
                       float beta,
                       T* y) {
  // const_cast is required by the C-style backend signature; the kernel
  // treats x as read-only input.
  impl_->BlasAXPBY(dev_id_,
                   stream,
                   paddle::experimental::CppTypeToDataType<T>::Type(),
                   numel,
                   alpha,
                   static_cast<void*>(const_cast<T*>(x)),
                   beta,
                   static_cast<void*>(y));
}

// Explicit instantiations for every element type gradient accumulation may
// dispatch with (see TensorAddFunctor); generated via a local macro to keep
// the list in one place.
#define PHI_INSTANTIATE_DEVICE_BLAS_AXPBY(T)                 \
  template void Device::BlasAXPBY<T>(const stream::Stream&,  \
                                     size_t numel,           \
                                     float alpha,            \
                                     const T* x,             \
                                     float beta,             \
                                     T* y)

PHI_INSTANTIATE_DEVICE_BLAS_AXPBY(paddle::float16);
PHI_INSTANTIATE_DEVICE_BLAS_AXPBY(float);
PHI_INSTANTIATE_DEVICE_BLAS_AXPBY(double);
PHI_INSTANTIATE_DEVICE_BLAS_AXPBY(int8_t);
PHI_INSTANTIATE_DEVICE_BLAS_AXPBY(int16_t);
PHI_INSTANTIATE_DEVICE_BLAS_AXPBY(int32_t);
PHI_INSTANTIATE_DEVICE_BLAS_AXPBY(int64_t);
PHI_INSTANTIATE_DEVICE_BLAS_AXPBY(phi::dtype::complex<float>);
PHI_INSTANTIATE_DEVICE_BLAS_AXPBY(phi::dtype::complex<double>);

#undef PHI_INSTANTIATE_DEVICE_BLAS_AXPBY
std::string Device::Type() { return impl_->Type(); } std::string Device::Type() { return impl_->Type(); }
static phi::RWLock _global_device_manager_rw_lock; static phi::RWLock _global_device_manager_rw_lock;
......
...@@ -17,14 +17,16 @@ ...@@ -17,14 +17,16 @@
#include <unordered_map> #include <unordered_map>
#include "paddle/phi/common/data_type.h"
#include "paddle/phi/common/place.h"
#include "paddle/phi/core/utils/rw_lock.h"
#include "paddle/phi/backends/c_comm_lib.h" #include "paddle/phi/backends/c_comm_lib.h"
#include "paddle/phi/backends/device_base.h" #include "paddle/phi/backends/device_base.h"
#include "paddle/phi/backends/device_ext.h" #include "paddle/phi/backends/device_ext.h"
#include "paddle/phi/backends/dynload/port.h" #include "paddle/phi/backends/dynload/port.h"
#include "paddle/phi/backends/event.h" #include "paddle/phi/backends/event.h"
#include "paddle/phi/backends/stream.h" #include "paddle/phi/backends/stream.h"
#include "paddle/phi/common/place.h"
#include "paddle/phi/core/utils/rw_lock.h"
namespace phi { namespace phi {
class Device final { class Device final {
...@@ -106,6 +108,16 @@ class Device final { ...@@ -106,6 +108,16 @@ class Device final {
void MemorySet(void* ptr, uint8_t value, size_t size); void MemorySet(void* ptr, uint8_t value, size_t size);
// Blas
// ! y = alpha * x + beta * y
template <typename T>
void BlasAXPBY(const stream::Stream& stream,
size_t numel,
float alpha,
const T* x,
float beta,
T* y);
std::string Type(); std::string Type();
private: private:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册