Unverified commit 0d51fcf1, authored by ronnywang and committed by GitHub

[CustomDevice] add blas_axpby api for gradient_accumulator (#44584)

Parent: 356ff436
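Editorial summary (not part of the commit message): this change adds an AXPBY hook (y = alpha * x + beta * y) to the custom-device plugin interface and routes imperative-mode gradient accumulation on custom places through it instead of throwing. The call chain introduced by the diff is TensorAddFunctor::operator()(CustomPlace) -> phi::Device::BlasAXPBY<T> -> CustomDevice::BlasAXPBY -> the plugin's blas_axpby entry point, invoked with alpha = beta = 1.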
@@ -39,6 +39,9 @@
#ifdef PADDLE_WITH_MLU
#include "paddle/fluid/operators/mlu/mlu_baseop.h"
#endif
#ifdef PADDLE_WITH_CUSTOM_DEVICE
#include "paddle/phi/backends/device_manager.h"
#endif
namespace paddle {
namespace imperative {
@@ -189,10 +192,19 @@ class TensorAddFunctor
place));
}
void operator()(const platform::CustomPlace& place) const {
#ifdef PADDLE_WITH_CUSTOM_DEVICE
platform::CustomDeviceContext* ctx =
dynamic_cast<platform::CustomDeviceContext*>(
platform::DeviceContextPool::Instance().Get(place));
phi::stream::Stream stream(place, ctx->stream());
auto device = phi::DeviceManager::GetDeviceWithPlace(place);
device->BlasAXPBY<T>(stream, static_cast<size_t>(numel_), 1., x_, 1., y_);
#else
PADDLE_THROW(platform::errors::PermissionDenied(
"Gradient accumulation on place (%s) "
"is not supported in imperative mode",
place));
#endif
}
private:
@@ -351,15 +363,7 @@ void TensorAdd(const VarType& src, VarType* dst) {
return;
}
#endif
#ifdef PADDLE_WITH_CUSTOM_DEVICE
if (platform::is_custom_place(place)) {
PADDLE_THROW(platform::errors::Unimplemented(
"Gradient accumulation of data type (%s) on place (%s) is not "
"supported in imperative mode",
framework::DataTypeToString(data_type),
place));
}
#endif
#ifdef PADDLE_WITH_XPU
if (platform::is_xpu_place(place)) {
if (data_type == framework::DataTypeTrait<float>::DataType()) {
......
@@ -51,7 +51,7 @@ if(WITH_CUSTOM_DEVICE)
cc_test(
custom_device_test
SRCS custom/custom_device_test.cc
DEPS phi_backends phi_device_context)
DEPS phi_backends phi_device_context gradient_accumulator)
cc_test(
capi_test
SRCS custom/capi_test.cc
......
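Editorial note: custom_device_test now calls paddle::imperative::TensorAdd (see the new TestBlasAPI below), hence the added gradient_accumulator dependency.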
@@ -14,6 +14,8 @@
#include "paddle/fluid/platform/device/custom/enforce_custom.h"
#include "paddle/fluid/platform/device_context.h"
#include "paddle/phi/common/data_type.h"
#include "paddle/phi/backends/callback_manager.h"
#include "paddle/phi/backends/device_base.h"
#include "paddle/phi/backends/device_guard.h"
@@ -608,6 +610,27 @@ class CustomDevice : public DeviceInterface {
#undef return_result
}
C_DataType ToCDatatType(paddle::experimental::DataType data_type) {
#define return_result(in, ret) \
case in: \
return C_DataType::ret
switch (data_type) {
return_result(paddle::experimental::DataType::FLOAT64, FLOAT64);
return_result(paddle::experimental::DataType::FLOAT32, FLOAT32);
return_result(paddle::experimental::DataType::FLOAT16, FLOAT16);
return_result(paddle::experimental::DataType::INT64, INT64);
return_result(paddle::experimental::DataType::INT32, INT32);
return_result(paddle::experimental::DataType::INT16, INT16);
return_result(paddle::experimental::DataType::INT8, INT8);
default: {
PADDLE_THROW(phi::errors::Unavailable(
"DataType is not supported on %s.", Type()));
return C_DataType::UNDEFINED;
}
}
#undef return_result
}
void CCLGetUniqueId(ccl::CCLRootId* unique_id) override {
CHECK_PTR(pimpl_->xccl_get_unique_id_size);
CHECK_PTR(pimpl_->xccl_get_unique_id);
@@ -771,6 +794,27 @@ class CustomDevice : public DeviceInterface {
reinterpret_cast<C_Stream>(stream.raw_stream())));
}
void BlasAXPBY(size_t dev_id,
const stream::Stream& stream,
paddle::experimental::DataType dtype,
size_t numel,
float alpha,
void* x,
float beta,
void* y) override {
CHECK_PTR(pimpl_->blas_axpby);
const auto device = &devices_pool[dev_id];
PADDLE_ENFORCE_CUSTOM_DEVICE_SUCCESS(
pimpl_->blas_axpby(device,
reinterpret_cast<C_Stream>(stream.raw_stream()),
ToCDatatType(dtype),
numel,
alpha,
x,
beta,
y));
}
private:
inline int PlaceToIdNoCheck(const Place& place) {
int dev_id = place.GetDeviceId();
@@ -877,6 +921,8 @@ bool ValidCustomCustomRuntimeParams(const CustomRuntimeParams* params) {
CHECK_INTERFACE(xccl_group_end, false);
CHECK_INTERFACE(xccl_send, false);
CHECK_INTERFACE(xccl_recv, false);
CHECK_INTERFACE(blas_axpby, false);
return true;
#undef CHECK_INTERFACE
}
......
@@ -18,6 +18,8 @@
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/framework/tensor_util.h"
#include "paddle/fluid/framework/variable.h"
#include "paddle/fluid/imperative/gradient_accumulator.h"
#include "paddle/fluid/platform/device_context.h"
#include "paddle/phi/backends/custom/fake_cpu_device.h"
#include "paddle/phi/backends/device_manager.h"
@@ -237,6 +239,51 @@ void TestCustomCCL(const paddle::platform::Place& place) {
stream);
}
void TestBlasAPI(const paddle::platform::Place& place) {
std::cout << "TestBlasAPI on " << place << std::endl;
if (paddle::platform::is_custom_place(place) == false) {
return;
}
auto device = phi::DeviceManager::GetDeviceWithPlace(place);
phi::stream::Stream stream(place, nullptr);
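// Editorial comment: a smoke call with numel = 0 and null buffers; it only
// exercises the dispatch into the plugin's blas_axpby entry point.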
device->BlasAXPBY<float>(stream, 0, 1., nullptr, 1., nullptr);
paddle::framework::Variable var1;
paddle::framework::Variable var2;
std::vector<float> src_data(10, 1.0);
std::vector<float> dst_data(10, 0.0);
std::vector<float> result;
paddle::platform::CPUPlace src_place;
for (unsigned int i = 0; i < 10; i++) {
result.emplace_back(src_data[i] + dst_data[i]);
}
std::vector<int64_t> dims = {2, 5};
auto* src = var1.GetMutable<paddle::framework::LoDTensor>();
auto* dst = var2.GetMutable<paddle::framework::LoDTensor>();
src->Resize(phi::make_ddim(dims));
dst->Resize(phi::make_ddim(dims));
auto* src_mutable = src->mutable_data<float>(place);
auto* dst_mutable = dst->mutable_data<float>(place);
paddle::memory::Copy(place,
src_mutable,
src_place,
src_data.data(),
sizeof(float) * src_data.size());
paddle::memory::Copy(place,
dst_mutable,
src_place,
dst_data.data(),
sizeof(float) * dst_data.size());
paddle::imperative::TensorAdd<paddle::framework::Variable>(var1, &var2);
paddle::framework::LoDTensor rlt;
paddle::platform::CPUPlace rlt_place;
paddle::framework::TensorCopySync(*dst, rlt_place, &rlt);
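// Editorial comment: the FakeCPU runtime behind this test implements
// blas_axpby as a no-op stub (see fake_cpu_device.h later in this diff), so
// only the dispatch path is exercised here and the copied-back values are not
// compared against `result`.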
}
TEST(CustomDevice, Tensor) {
InitDevice();
auto dev_types = phi::DeviceManager::GetAllDeviceTypes();
@@ -251,6 +298,7 @@ TEST(CustomDevice, Tensor) {
TestTensorShareDataWith(place);
TestTensorUtils(place);
TestCustomCCL(place);
TestBlasAPI(place);
}
}
......
@@ -210,6 +210,17 @@ C_Status XcclRecv(void *recv_buf,
return C_SUCCESS;
}
C_Status BlasAXPBY(const C_Device device,
C_Stream stream,
C_DataType dtype,
size_t numel,
float alpha,
void *x,
float beta,
void *y) {
return C_SUCCESS;
}
#define DEVICE_TYPE "FakeCPU"
#define SUB_DEVICE_TYPE "V100"
@@ -278,4 +289,6 @@ void InitFakeCPUDevice(CustomRuntimeParams *params) {
params->interface->xccl_reduce_scatter = XcclReduceScatter;
params->interface->xccl_send = XcclSend;
params->interface->xccl_recv = XcclRecv;
params->interface->blas_axpby = BlasAXPBY;
}
@@ -355,6 +355,18 @@ void DeviceInterface::CCLRecv(void* recvbuf,
INTERFACE_UNIMPLEMENT;
}
// blas
void DeviceInterface::BlasAXPBY(size_t dev_id,
const stream::Stream& stream,
paddle::experimental::DataType dtype,
size_t numel,
float alpha,
void* x,
float beta,
void* y) {
INTERFACE_UNIMPLEMENT;
}
#undef INTERFACE_UNIMPLEMENT
} // namespace phi
@@ -225,6 +225,16 @@ class DeviceInterface { // Driver / Runtime
const ccl::CCLComm& ccl_comm,
const stream::Stream& stream);
// blas
virtual void BlasAXPBY(size_t dev_id,
const stream::Stream& stream,
paddle::experimental::DataType dtype,
size_t numel,
float alpha,
void* x,
float beta,
void* y);
private:
const std::string type_;
const uint8_t priority_;
......
@@ -635,7 +635,19 @@ struct C_DeviceInterface {
// other api //
///////////////
void* reserved_other_api[8];
/**
* @brief y = alpha * x + beta * y
*
*/
C_Status (*blas_axpby)(const C_Device device,
C_Stream stream,
C_DataType dtype,
size_t numel,
float alpha,
void* x,
float beta,
void* y);
void* reserved_other_api[7];
};
struct CustomRuntimeVersion {
......
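For illustration only (not part of this diff): a minimal, synchronous sketch of how a plugin could back the new blas_axpby entry point, handling FLOAT32 on the host. A real runtime would dispatch on dtype and enqueue the update on the given stream; MyBlasAXPBY and the registration snippet below are hypothetical names modeled on the FakeCPU example above.

C_Status MyBlasAXPBY(const C_Device device, C_Stream stream, C_DataType dtype,
                     size_t numel, float alpha, void *x, float beta, void *y) {
  (void)device;
  (void)stream;  // unused in this host-side sketch
  // y = alpha * x + beta * y, element-wise; FLOAT32 only in this sketch.
  if (dtype == C_DataType::FLOAT32) {
    auto *px = static_cast<const float *>(x);
    auto *py = static_cast<float *>(y);
    for (size_t i = 0; i < numel; ++i) {
      py[i] = alpha * px[i] + beta * py[i];
    }
  }
  return C_SUCCESS;
}

// Registered alongside the other entry points in the runtime's init function:
//   params->interface->blas_axpby = MyBlasAXPBY;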
@@ -14,6 +14,7 @@
#ifdef PADDLE_WITH_CUSTOM_DEVICE
#include "paddle/phi/backends/device_manager.h"
#include "paddle/phi/common/complex.h"
#if !defined(_WIN32)
#include <dirent.h>
@@ -135,6 +136,80 @@ void Device::MemorySet(void* ptr, uint8_t value, size_t size) {
impl_->MemorySet(dev_id_, ptr, value, size);
}
template <typename T>
void Device::BlasAXPBY(const stream::Stream& stream,
size_t numel,
float alpha,
const T* x,
float beta,
T* y) {
impl_->BlasAXPBY(dev_id_,
stream,
paddle::experimental::CppTypeToDataType<T>::Type(),
numel,
alpha,
reinterpret_cast<void*>(const_cast<T*>(x)),
beta,
reinterpret_cast<void*>(y));
}
template void Device::BlasAXPBY<paddle::float16>(const stream::Stream& stream,
size_t numel,
float alpha,
const paddle::float16* x,
float beta,
paddle::float16* y);
template void Device::BlasAXPBY<float>(const stream::Stream& stream,
size_t numel,
float alpha,
const float* x,
float beta,
float* y);
template void Device::BlasAXPBY<double>(const stream::Stream& stream,
size_t numel,
float alpha,
const double* x,
float beta,
double* y);
template void Device::BlasAXPBY<int8_t>(const stream::Stream& stream,
size_t numel,
float alpha,
const int8_t* x,
float beta,
int8_t* y);
template void Device::BlasAXPBY<int16_t>(const stream::Stream& stream,
size_t numel,
float alpha,
const int16_t* x,
float beta,
int16_t* y);
template void Device::BlasAXPBY<int32_t>(const stream::Stream& stream,
size_t numel,
float alpha,
const int32_t* x,
float beta,
int32_t* y);
template void Device::BlasAXPBY<int64_t>(const stream::Stream& stream,
size_t numel,
float alpha,
const int64_t* x,
float beta,
int64_t* y);
template void Device::BlasAXPBY<phi::dtype::complex<float>>(
const stream::Stream& stream,
size_t numel,
float alpha,
const phi::dtype::complex<float>* x,
float beta,
phi::dtype::complex<float>* y);
template void Device::BlasAXPBY<phi::dtype::complex<double>>(
const stream::Stream& stream,
size_t numel,
float alpha,
const phi::dtype::complex<double>* x,
float beta,
phi::dtype::complex<double>* y);
std::string Device::Type() { return impl_->Type(); }
static phi::RWLock _global_device_manager_rw_lock;
......
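Editorial note: Device::BlasAXPBY<T> is defined in this source file, so it is explicitly instantiated above for every element type the accumulator may pass (float16, float, double, int8/16/32/64, and complex); instantiating it with any other T would fail at link time.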
@@ -17,14 +17,16 @@
#include <unordered_map>
#include "paddle/phi/common/data_type.h"
#include "paddle/phi/common/place.h"
#include "paddle/phi/core/utils/rw_lock.h"
#include "paddle/phi/backends/c_comm_lib.h"
#include "paddle/phi/backends/device_base.h"
#include "paddle/phi/backends/device_ext.h"
#include "paddle/phi/backends/dynload/port.h"
#include "paddle/phi/backends/event.h"
#include "paddle/phi/backends/stream.h"
#include "paddle/phi/common/place.h"
#include "paddle/phi/core/utils/rw_lock.h"
namespace phi {
class Device final {
@@ -106,6 +108,16 @@ class Device final {
void MemorySet(void* ptr, uint8_t value, size_t size);
// Blas
// ! y = alpha * x + beta * y
template <typename T>
void BlasAXPBY(const stream::Stream& stream,
size_t numel,
float alpha,
const T* x,
float beta,
T* y);
std::string Type();
private:
......