Unverified commit f3270fc8, authored by wanghuancoder and committed by GitHub

[Eager] Support pinned (#41035)

* support pinned, test=develop

* support async_write, test=develop

* refine, test=develop

* refine, test=develop

* refine, test=develop

* refine,test=develop

* refine, test=develop

* refine, test=develop

* refine, test=develop

* refine, test=develop
Parent 53a62ea4
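
This commit threads a pinned-memory (CUDAPinnedPlace) allocator through DeviceContext and exposes async_read, async_write, and Tensor.pin_memory to eager dygraph mode. As a rough orientation before the per-file diffs, the sketch below shows how the new eager bindings are meant to be used; it mirrors the unit tests added at the bottom of this commit, requires a CUDA build, and the tensor shapes are illustrative assumptions rather than anything prescribed by the API.

import numpy as np
import paddle
from paddle.device import cuda
from paddle.fluid import core

# async_read gathers rows from a pinned host tensor into a GPU tensor;
# async_write scatters rows from a GPU tensor into a pinned host tensor.
# Shapes below are made up for illustration.
src = paddle.rand(shape=[100, 50, 50], dtype="float32").pin_memory()
dst = paddle.empty(shape=[100, 50, 50], dtype="float32")  # GPU tensor
buffer_ = paddle.empty(shape=[50, 50, 50], dtype="float32").pin_memory()
index = paddle.to_tensor(np.array([1, 3, 5], dtype="int64"), place=paddle.CPUPlace())
offset = paddle.to_tensor(np.array([10, 20], dtype="int64"), place=paddle.CPUPlace())
count = paddle.to_tensor(np.array([5, 10], dtype="int64"), place=paddle.CPUPlace())

stream = cuda.Stream()
with cuda.stream_guard(stream):
    core.eager.async_read(src, dst, index, buffer_, offset, count)

write_src = paddle.rand(shape=[100, 50, 50], dtype="float32")  # GPU tensor
write_dst = paddle.empty(shape=[200, 50, 50], dtype="float32").pin_memory()
with cuda.stream_guard(stream):
    core.eager.async_write(write_src, write_dst, offset, count)
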
@@ -38,6 +38,10 @@ class TensorRTEngineTest : public ::testing::Test {
paddle::memory::allocation::AllocatorFacade::Instance()
.GetZeroAllocator(platform::CUDAPlace(0))
.get());
ctx_->SetPinnedAllocator(
paddle::memory::allocation::AllocatorFacade::Instance()
.GetAllocator(paddle::platform::CUDAPinnedPlace())
.get());
ctx_->PartialInitWithAllocator();
engine_ = new TensorRTEngine(10, 1 << 10);
......
@@ -120,6 +120,10 @@ TEST(Malloc, CUDADeviceContextMultiStream) {
paddle::memory::allocation::AllocatorFacade::Instance()
.GetZeroAllocator(place)
.get());
ctx->SetPinnedAllocator(
paddle::memory::allocation::AllocatorFacade::Instance()
.GetAllocator(paddle::platform::CUDAPinnedPlace())
.get());
ctx->PartialInitWithAllocator();
dev_ctx.emplace_back(std::move(ctx));
MultiStreamCompute(&data[i], &second_data[i], *dev_ctx[i]);
@@ -172,6 +176,10 @@ TEST(Malloc, CUDADeviceContextMultiThreadMultiStream) {
paddle::memory::allocation::AllocatorFacade::Instance()
.GetZeroAllocator(place)
.get());
ctx->SetPinnedAllocator(
paddle::memory::allocation::AllocatorFacade::Instance()
.GetAllocator(paddle::platform::CUDAPinnedPlace())
.get());
ctx->PartialInitWithAllocator();
dev_ctx.emplace_back(std::move(ctx));
threads.push_back(std::thread(MultiStreamCompute, &data[i], &second_data[i],
......
@@ -292,6 +292,10 @@ class TestFeedForward {
paddle::memory::allocation::AllocatorFacade::Instance()
.GetZeroAllocator(place_)
.get());
ctx_->SetPinnedAllocator(
paddle::memory::allocation::AllocatorFacade::Instance()
.GetAllocator(paddle::platform::CUDAPinnedPlace())
.get());
ctx_->PartialInitWithAllocator();
size_src_ = bsz_seq_ * dim_embed_;  // src: [bs, seq_len, em_dim]
......
@@ -199,6 +199,10 @@ NCCLComm* NCCLCommContext::AssignNCCLComm(ncclComm_t comm, int nranks, int rank,
paddle::memory::allocation::AllocatorFacade::Instance()
.GetZeroAllocator(CUDAPlace(dev_id))
.get());
dev_ctx->SetPinnedAllocator(
paddle::memory::allocation::AllocatorFacade::Instance()
.GetAllocator(paddle::platform::CUDAPinnedPlace())
.get());
dev_ctx->PartialInitWithAllocator();
std::shared_ptr<platform::CudaEventObject> compute_event(
......
@@ -113,6 +113,10 @@ struct NCCLContext {
paddle::memory::allocation::AllocatorFacade::Instance()
.GetZeroAllocator(CUDAPlace(dev_id))
.get());
ctx_->SetPinnedAllocator(
paddle::memory::allocation::AllocatorFacade::Instance()
.GetAllocator(paddle::platform::CUDAPinnedPlace())
.get());
ctx_->PartialInitWithAllocator();
}
......
@@ -162,6 +162,11 @@ inline void EmplaceDeviceContext(
dev_ctx->SetAllocator(memory::allocation::AllocatorFacade::Instance()
.GetAllocator(p)
.get());
dev_ctx->SetPinnedAllocator(
memory::allocation::AllocatorFacade::Instance()
.GetAllocator(paddle::platform::CUDAPinnedPlace())
.get());
cuda_ctx->PartialInitWithAllocator();
dev_ctx->SetGenerator(
framework::GetDefaultCUDAGenerator(p.GetDeviceId()).get());
......
@@ -39,6 +39,10 @@ TEST(Device, Init) {
paddle::memory::allocation::AllocatorFacade::Instance()
.GetZeroAllocator(CUDAPlace(i))
.get());
device_context->SetPinnedAllocator(
paddle::memory::allocation::AllocatorFacade::Instance()
.GetAllocator(paddle::platform::CUDAPinnedPlace())
.get());
device_context->PartialInitWithAllocator();
Eigen::GpuDevice* gpu_device = device_context->eigen_device();
@@ -66,6 +70,10 @@ TEST(Device, CUDADeviceContext) {
paddle::memory::allocation::AllocatorFacade::Instance()
.GetZeroAllocator(CUDAPlace(i))
.get());
device_context->SetPinnedAllocator(
paddle::memory::allocation::AllocatorFacade::Instance()
.GetAllocator(paddle::platform::CUDAPinnedPlace())
.get());
device_context->PartialInitWithAllocator();
Eigen::GpuDevice* gpu_device = device_context->eigen_device();
ASSERT_NE(nullptr, gpu_device);
......
@@ -28,8 +28,10 @@ limitations under the License. */
#include "paddle/fluid/framework/op_meta_info_helper.h"
#include "paddle/fluid/memory/allocation/allocator.h"
#include "paddle/fluid/memory/memcpy.h"
#include "paddle/fluid/platform/device/gpu/gpu_info.h"
#include "paddle/fluid/platform/dynload/dynamic_loader.h" #include "paddle/fluid/platform/dynload/dynamic_loader.h"
#include "paddle/fluid/platform/enforce.h" #include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/stream/cuda_stream.h"
#include "paddle/fluid/pybind/eager.h" #include "paddle/fluid/pybind/eager.h"
#include "paddle/fluid/pybind/eager_utils.h" #include "paddle/fluid/pybind/eager_utils.h"
#include "paddle/fluid/pybind/exception.h" #include "paddle/fluid/pybind/exception.h"
...@@ -536,7 +538,239 @@ static PyObject* eager_api_sparse_csr_tensor(PyObject* self, PyObject* args, ...@@ -536,7 +538,239 @@ static PyObject* eager_api_sparse_csr_tensor(PyObject* self, PyObject* args,
return ToPyObject(tensor); return ToPyObject(tensor);
EAGER_CATCH_AND_THROW_RETURN_NULL EAGER_CATCH_AND_THROW_RETURN_NULL
} }
#if defined(PADDLE_WITH_CUDA)
static PyObject* eager_api_async_read(PyObject* self, PyObject* args,
PyObject* kwargs) {
EAGER_TRY
auto& src = GetTensorFromArgs("async_read", "src", args, 0, false);
auto& dst = GetTensorFromArgs("async_read", "dst", args, 1, false);
auto& index = GetTensorFromArgs("async_read", "index", args, 2, false);
auto& buffer = GetTensorFromArgs("async_read", "buffer", args, 3, false);
auto& offset = GetTensorFromArgs("async_read", "offset", args, 4, false);
auto& count = GetTensorFromArgs("async_read", "count", args, 5, false);
PADDLE_ENFORCE_EQ(
src.is_gpu_pinned(), true,
platform::errors::InvalidArgument("Required `src` device should be "
"CUDAPinnedPlace, but received %d.",
src.inner_place()));
PADDLE_ENFORCE_EQ(
dst.is_gpu(), true,
platform::errors::InvalidArgument(
"Required `dst` device should be CUDAPlace, but received %d.",
dst.inner_place()));
PADDLE_ENFORCE_EQ(
index.is_cpu(), true,
platform::errors::InvalidArgument(
"Required `index` device should be CPUPlace, but received %d.",
index.inner_place()));
PADDLE_ENFORCE_EQ(buffer.is_gpu_pinned(), true,
platform::errors::InvalidArgument(
"Required `buffer` device should be CUDAPinnedPlace, "
"but received %d.",
buffer.inner_place()));
PADDLE_ENFORCE_EQ(
offset.is_cpu(), true,
platform::errors::InvalidArgument(
"Required `offset` device should be CPUPlace, but received %d.",
offset.inner_place()));
PADDLE_ENFORCE_EQ(
count.is_cpu(), true,
platform::errors::InvalidArgument(
"Required `count` device should be CPUPlace, but received %d.",
count.inner_place()));
auto& src_tensor = src;
auto* dst_tensor = &dst;
auto& index_tensor = index;
auto* buffer_tensor = &buffer;
auto& offset_tensor = offset;
auto& count_tensor = count;
auto* dst_data = dst_tensor->mutable_data<float>(dst.place());
const auto& deviceId = paddle::platform::GetCurrentDeviceId();
PADDLE_ENFORCE_EQ(src_tensor.dims().size(), dst_tensor->dims().size(),
platform::errors::InvalidArgument(
"`src` and `dst` should have same tensor shape, "
"except for the first dimension."));
PADDLE_ENFORCE_EQ(src_tensor.dims().size(), buffer_tensor->dims().size(),
platform::errors::InvalidArgument(
"`src` and `buffer` should have same tensor shape, "
"except for the first dimension."));
for (int i = 1; i < src_tensor.dims().size(); i++) {
PADDLE_ENFORCE_EQ(src_tensor.dims()[i], dst_tensor->dims()[i],
platform::errors::InvalidArgument(
"`src` and `dst` should have the same tensor shape, "
"except for the first dimension."));
PADDLE_ENFORCE_EQ(
src_tensor.dims()[i], buffer_tensor->dims()[i],
platform::errors::InvalidArgument(
"`src` and `buffer` should have the same tensor shape, "
"except for the first dimension."));
}
PADDLE_ENFORCE_EQ(index_tensor.dims().size(), 1,
platform::errors::InvalidArgument(
"`index` tensor should be one-dimensional."));
auto stream =
paddle::platform::stream::get_current_stream(deviceId)->raw_stream();
int64_t numel = 0; // total copy length
int64_t copy_flag = offset_tensor.dims()[0];
int64_t size = src_tensor.numel() / src_tensor.dims()[0];
if (copy_flag != 0) {
PADDLE_ENFORCE_EQ(offset_tensor.dims().size(), 1,
platform::errors::InvalidArgument(
"`offset` tensor should be one-dimensional."));
PADDLE_ENFORCE_EQ(count_tensor.dims().size(), 1,
platform::errors::InvalidArgument(
"`count` tensor should be one-dimensional."));
PADDLE_ENFORCE_EQ(offset_tensor.numel(), count_tensor.numel(),
platform::errors::InvalidArgument(
"`offset` and `count` tensor size dismatch."));
auto* offset_data = offset_tensor.data<int64_t>();
auto* count_data = count_tensor.data<int64_t>();
for (int64_t i = 0; i < count_tensor.numel(); i++) {
numel += count_data[i];
}
PADDLE_ENFORCE_LE(
numel + index_tensor.numel(), buffer_tensor->dims()[0],
platform::errors::InvalidArgument("Buffer tensor size is too small."));
PADDLE_ENFORCE_LE(
numel + index_tensor.numel(), dst_tensor->dims()[0],
platform::errors::InvalidArgument("Target tensor size is too small."));
int64_t src_offset, dst_offset = 0, c;
auto* src_data = src_tensor.data<float>();
for (int64_t i = 0; i < offset_tensor.numel(); i++) {
src_offset = offset_data[i], c = count_data[i];
PADDLE_ENFORCE_LE(
src_offset + c, src_tensor.dims()[0],
platform::errors::InvalidArgument("Invalid offset or count index."));
PADDLE_ENFORCE_LE(
dst_offset + c, dst_tensor->dims()[0],
platform::errors::InvalidArgument("Invalid offset or count index."));
cudaMemcpyAsync(dst_data + (dst_offset * size),
src_data + (src_offset * size), c * size * sizeof(float),
cudaMemcpyHostToDevice, stream);
dst_offset += c;
}
} else {
PADDLE_ENFORCE_LE(
index_tensor.numel(), buffer_tensor->dims()[0],
platform::errors::InvalidArgument("Buffer tensor size is too small."));
}
// Select the index data to the buffer
auto index_select = [](const paddle::experimental::Tensor& src_tensor,
const paddle::experimental::Tensor& index_tensor,
paddle::experimental::Tensor* buffer_tensor) {
auto* src_data = src_tensor.data<float>();
auto* index_data = index_tensor.data<int64_t>();
auto* buffer_data = buffer_tensor->data<float>();
const int& slice_size = src_tensor.numel() / src_tensor.dims()[0];
const int& copy_bytes = slice_size * sizeof(float);
int64_t c = 0;
for (int64_t i = 0; i < index_tensor.numel(); i++) {
std::memcpy(buffer_data + c * slice_size,
src_data + index_data[i] * slice_size, copy_bytes);
c += 1;
}
};
index_select(src_tensor, index_tensor, buffer_tensor);
// Copy the data to device memory
cudaMemcpyAsync(dst_data + (numel * size), buffer_tensor->data<float>(),
index_tensor.numel() * size * sizeof(float),
cudaMemcpyHostToDevice, stream);
Py_INCREF(Py_None);
return Py_None;
EAGER_CATCH_AND_THROW_RETURN_NULL
}
static PyObject* eager_api_async_write(PyObject* self, PyObject* args,
PyObject* kwargs) {
EAGER_TRY
auto& src = GetTensorFromArgs("async_write", "src", args, 0, false);
auto& dst = GetTensorFromArgs("async_write", "dst", args, 1, false);
auto& offset = GetTensorFromArgs("async_write", "offset", args, 2, false);
auto& count = GetTensorFromArgs("async_write", "count", args, 3, false);
PADDLE_ENFORCE_EQ(
src.is_gpu(), true,
platform::errors::InvalidArgument(
"Required `src` device should be CUDAPlace, but received %d. ",
src.inner_place()));
PADDLE_ENFORCE_EQ(dst.is_gpu_pinned(), true,
platform::errors::InvalidArgument(
"Required `dst` device should be CUDAPinnedPlace, "
"but received %d. ",
dst.inner_place()));
PADDLE_ENFORCE_EQ(
offset.is_cpu(), true,
platform::errors::InvalidArgument("Required `offset` device should "
"be CPUPlace, but received %d. ",
offset.inner_place()));
PADDLE_ENFORCE_EQ(
count.is_cpu(), true,
platform::errors::InvalidArgument(
"Required `count` device should be CPUPlace, but received %d. ",
count.inner_place()));
// TODO(daisiming): In future, add index as arguments following
// async_read.
auto& src_tensor = src;
auto* dst_tensor = &dst;
auto& offset_tensor = offset;
auto& count_tensor = count;
const auto& deviceId = paddle::platform::GetCurrentDeviceId();
PADDLE_ENFORCE_EQ(offset_tensor.dims().size(), 1,
platform::errors::InvalidArgument(
"`offset` tensor should be one-dimensional."));
PADDLE_ENFORCE_EQ(count_tensor.dims().size(), 1,
platform::errors::InvalidArgument(
"`count` tensor should be one-dimensional."));
PADDLE_ENFORCE_EQ(offset_tensor.numel(), count_tensor.numel(),
platform::errors::InvalidArgument(
"`offset` and `count` tensor size dismatch."));
PADDLE_ENFORCE_EQ(src_tensor.dims().size(), dst_tensor->dims().size(),
platform::errors::InvalidArgument(
"`src` and `dst` should have the same tensor shape, "
"except for the first dimension."));
for (int i = 1; i < src_tensor.dims().size(); i++) {
PADDLE_ENFORCE_EQ(src_tensor.dims()[i], dst_tensor->dims()[i],
platform::errors::InvalidArgument(
"`src` and `dst` should have the same tensor shape, "
"except for the first dimension."));
}
auto stream =
paddle::platform::stream::get_current_stream(deviceId)->raw_stream();
int64_t size = src_tensor.numel() / src_tensor.dims()[0];
auto* src_data = src_tensor.data<float>();
auto* dst_data = dst_tensor->data<float>();
const int64_t* offset_data = offset_tensor.data<int64_t>();
const int64_t* count_data = count_tensor.data<int64_t>();
int64_t src_offset = 0, dst_offset, c;
for (int64_t i = 0; i < offset_tensor.numel(); i++) {
dst_offset = offset_data[i], c = count_data[i];
PADDLE_ENFORCE_LE(
src_offset + c, src_tensor.dims()[0],
platform::errors::InvalidArgument("Invalid offset or count index"));
PADDLE_ENFORCE_LE(
dst_offset + c, dst_tensor->dims()[0],
platform::errors::InvalidArgument("Invalid offset or count index"));
cudaMemcpyAsync(dst_data + (dst_offset * size),
src_data + (src_offset * size), c * size * sizeof(float),
cudaMemcpyDeviceToHost, stream);
src_offset += c;
}
Py_INCREF(Py_None);
return Py_None;
EAGER_CATCH_AND_THROW_RETURN_NULL
}
#endif
PyMethodDef variable_functions[] = {
// TODO(jiabin): Remove scale when we have final state tests
{"scale", (PyCFunction)(void (*)(void))eager_api_scale,
@@ -560,6 +794,12 @@ PyMethodDef variable_functions[] = {
{"sparse_csr_tensor",
(PyCFunction)(void (*)(void))eager_api_sparse_csr_tensor,
METH_VARARGS | METH_KEYWORDS, NULL},
#if defined(PADDLE_WITH_CUDA)
{"async_read", (PyCFunction)(void (*)(void))eager_api_async_read,
METH_VARARGS | METH_KEYWORDS, NULL},
{"async_write", (PyCFunction)(void (*)(void))eager_api_async_write,
METH_VARARGS | METH_KEYWORDS, NULL},
#endif
/**sparse functions**/
{NULL, NULL, 0, NULL}};
......
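
The two bindings above are registered under core.eager only when Paddle is built with CUDA, while the legacy dygraph path keeps exposing core.async_read / core.async_write. Callers that must work in both modes therefore dispatch on the current mode, exactly as the updated tests at the end of this diff do; a minimal sketch of that pattern follows (the helper name is hypothetical):

from paddle.fluid import core
from paddle.fluid.framework import _in_legacy_dygraph

def dispatch_async_write(src, dst, offset, count):
    # Hypothetical helper: pick the legacy or the eager binding at call time.
    if _in_legacy_dygraph():
        core.async_write(src, dst, offset, count)
    else:
        core.eager.async_write(src, dst, offset, count)
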
@@ -2007,6 +2007,10 @@ All parameter, weight, gradient are variables in Paddle.
paddle::memory::allocation::AllocatorFacade::Instance()
.GetZeroAllocator(place)
.get());
context->SetPinnedAllocator(
paddle::memory::allocation::AllocatorFacade::Instance()
.GetAllocator(paddle::platform::CUDAPinnedPlace())
.get());
context->PartialInitWithAllocator();
return context;
#endif
......
@@ -49,6 +49,14 @@ struct DeviceContext::Impl {
zero_allocator_ = allocator;
}
void SetPinnedAllocator(const Allocator* allocator) {
PADDLE_ENFORCE_NOT_NULL(
allocator,
phi::errors::InvalidArgument(
"Required allocator shall not be nullptr, but received nullptr."));
pinned_allocator_ = allocator;
}
const Allocator& GetAllocator() const {
PADDLE_ENFORCE_NOT_NULL(
device_allocator_,
@@ -68,15 +76,24 @@ struct DeviceContext::Impl {
const Allocator& GetZeroAllocator() const {
PADDLE_ENFORCE_NOT_NULL(
zero_allocator_,
phi::errors::InvalidArgument("Required zero_allocator_ shall not be "
"nullptr, but received nullptr."));
return *zero_allocator_;
}
const Allocator& GetPinnedAllocator() const {
PADDLE_ENFORCE_NOT_NULL(
pinned_allocator_,
phi::errors::InvalidArgument("Required pinned_allocator_ shall not be "
"nullptr, but received nullptr."));
return *pinned_allocator_;
}
void* Alloc(TensorBase* tensor,
const Place& place,
DataType dtype = DataType::UNDEFINED,
size_t requested_size = 0,
bool pinned = false) const {
PADDLE_ENFORCE_NOT_NULL(
tensor,
phi::errors::InvalidArgument(
@@ -90,8 +107,9 @@ struct DeviceContext::Impl {
if (tensor->initialized() && tensor->place() != place) {
ClearHolder(tensor);
}
auto* allocator = tensor->numel() == 0
? zero_allocator_
: (pinned ? pinned_allocator_ : device_allocator_);
return tensor->AllocateFrom(
const_cast<Allocator*>(allocator), dtype, requested_size);
}
@@ -99,9 +117,10 @@ struct DeviceContext::Impl {
template <typename T>
T* Alloc(TensorBase* tensor,
const Place& place,
size_t requested_size = 0,
bool pinned = false) const {
DataType dtype = paddle::experimental::CppTypeToDataType<T>::Type();
return static_cast<T*>(Alloc(tensor, place, dtype, requested_size, pinned));
}
void* HostAlloc(TensorBase* tensor,
@@ -179,6 +198,7 @@ struct DeviceContext::Impl {
const Allocator* device_allocator_{nullptr};
const Allocator* host_allocator_{nullptr};
const Allocator* zero_allocator_{nullptr};
const Allocator* pinned_allocator_{nullptr};
Generator* device_generator_{nullptr};
Generator* host_generator_{nullptr};
};
@@ -189,6 +209,7 @@ DeviceContext::DeviceContext(const DeviceContext& other) {
impl_->SetHostAllocator(&other.GetHostAllocator());
impl_->SetAllocator(&other.GetAllocator());
impl_->SetZeroAllocator(&other.GetZeroAllocator());
impl_->SetPinnedAllocator(&other.GetPinnedAllocator());
impl_->SetHostGenerator(other.GetHostGenerator());
impl_->SetGenerator(other.GetGenerator());
}
@@ -225,15 +246,25 @@ const Allocator& DeviceContext::GetZeroAllocator() const {
return impl_->GetZeroAllocator();
}
void DeviceContext::SetPinnedAllocator(const Allocator* allocator) {
impl_->SetPinnedAllocator(allocator);
}
const Allocator& DeviceContext::GetPinnedAllocator() const {
return impl_->GetPinnedAllocator();
}
void* DeviceContext::Alloc(TensorBase* tensor,
DataType dtype,
size_t requested_size,
bool pinned) const {
return impl_->Alloc(tensor, GetPlace(), dtype, requested_size, pinned);
}
template <typename T>
T* DeviceContext::Alloc(TensorBase* tensor,
size_t requested_size,
bool pinned) const {
return impl_->Alloc<T>(tensor, GetPlace(), requested_size, pinned);
}
void* DeviceContext::HostAlloc(TensorBase* tensor,
@@ -248,8 +279,8 @@ T* DeviceContext::HostAlloc(TensorBase* tensor, size_t requested_size) const {
}
#define DEVICE_CONTEXT_MEMBER_FUNC_INSTANTIATION(dtype) \
template dtype* DeviceContext::Alloc( \
TensorBase* tensor, size_t requested_size, bool pinned) const; \
template dtype* DeviceContext::HostAlloc(TensorBase* tensor, \
size_t requested_size) const;
......
@@ -80,6 +80,13 @@ class DeviceContext {
*/
void SetZeroAllocator(const Allocator*);
/**
* @brief Set the pinned Allocator object.
*
* @param allocator
*/
void SetPinnedAllocator(const Allocator*);
/**
* @brief Get the const Allocator object.
*
@@ -96,13 +103,20 @@ class DeviceContext {
const Allocator& GetZeroAllocator() const;
const Allocator& GetPinnedAllocator() const;
/**
* @brief Allocate device memory for tensor.
*/
void* Alloc(TensorBase*,
DataType dtype,
size_t requested_size = 0,
bool pinned = false) const;
template <typename T>
T* Alloc(TensorBase* tensor,
size_t requested_size = 0,
bool pinned = false) const;
/**
* @brief Allocate host memory for tensor.
......
@@ -48,7 +48,8 @@ void Copy(const Context& dev_ctx,
// dev_ctx can not alloc pinned memory now
dst_ptr = dst->mutable_data(dst_place, src.dtype());
} else {
dst_ptr = dev_ctx.Alloc(
dst, src.dtype(), 0, paddle::platform::is_cuda_pinned_place(dst_place));
}
if (src_ptr == dst_ptr && src_place == dst_place) {
@@ -151,6 +152,30 @@ void Copy(const Context& dev_ctx,
"Context place dose not match the source and destination place."));
}
}
} else if (paddle::platform::is_gpu_place(src_place) && // NOLINT
paddle::platform::is_cuda_pinned_place(dst_place)) {
auto src_gpu_place = src_place;
auto dst_cuda_pinned_place = dst_place;
auto ctx_place = dev_ctx.GetPlace();
PADDLE_ENFORCE_EQ(
paddle::platform::is_gpu_place(ctx_place),
true,
phi::errors::PreconditionNotMet(
"Context place error, excepted GPUPlace, but actually %s.",
ctx_place));
auto ctx_gpu_place = ctx_place;
PADDLE_ENFORCE_EQ(src_gpu_place,
ctx_gpu_place,
phi::errors::Unavailable(
"Source place and context place do not match, source "
"place is %s, context place is %s.",
src_gpu_place,
ctx_gpu_place));
auto stream =
blocking ? nullptr
: reinterpret_cast<const phi::GPUContext&>(dev_ctx).stream();
paddle::memory::Copy(
dst_cuda_pinned_place, dst_ptr, src_gpu_place, src_ptr, size, stream);
} else {
PADDLE_THROW(phi::errors::InvalidArgument(
"Place type error. Please check the place of src and dst Tensor."));
......
@@ -160,6 +160,10 @@ void TestConv3dBase(const std::vector<int>& indices,
paddle::memory::allocation::AllocatorFacade::Instance()
.GetAllocator(phi::CPUPlace())
.get());
dev_ctx_gpu.SetPinnedAllocator(
paddle::memory::allocation::AllocatorFacade::Instance()
.GetAllocator(paddle::platform::CUDAPinnedPlace())
.get());
dev_ctx_gpu.PartialInitWithAllocator();
DenseTensor d_indices_tensor = phi::Empty(
......
@@ -134,6 +134,10 @@ void TestMaxPoolBase(const std::vector<int>& indices,
paddle::memory::allocation::AllocatorFacade::Instance()
.GetAllocator(phi::CPUPlace())
.get());
dev_ctx_gpu.SetPinnedAllocator(
paddle::memory::allocation::AllocatorFacade::Instance()
.GetAllocator(paddle::platform::CUDAPinnedPlace())
.get());
dev_ctx_gpu.PartialInitWithAllocator();
DenseTensor d_indices_tensor = phi::Empty(
......
@@ -117,6 +117,10 @@ void TestDenseToSparseCoo(const DenseTensor& dense_x,
paddle::memory::allocation::AllocatorFacade::Instance()
.GetAllocator(phi::CPUPlace())
.get());
dev_ctx_gpu.SetPinnedAllocator(
paddle::memory::allocation::AllocatorFacade::Instance()
.GetAllocator(paddle::platform::CUDAPinnedPlace())
.get());
dev_ctx_gpu.PartialInitWithAllocator();
const auto cuda_alloc =
@@ -328,6 +332,10 @@ void TestSparseCsrToCoo(const DDim& dense_dims,
paddle::memory::allocation::AllocatorFacade::Instance()
.GetAllocator(phi::CPUPlace())
.get());
dev_ctx_gpu.SetPinnedAllocator(
paddle::memory::allocation::AllocatorFacade::Instance()
.GetAllocator(paddle::platform::CUDAPinnedPlace())
.get());
dev_ctx_gpu.PartialInitWithAllocator();
const auto cuda_alloc =
@@ -511,6 +519,10 @@ void TestCooToCsr(const DDim& dense_dims,
paddle::memory::allocation::AllocatorFacade::Instance()
.GetAllocator(phi::CPUPlace())
.get());
dev_ctx_gpu.SetPinnedAllocator(
paddle::memory::allocation::AllocatorFacade::Instance()
.GetAllocator(paddle::platform::CUDAPinnedPlace())
.get());
dev_ctx_gpu.PartialInitWithAllocator();
phi::DenseTensor d_indices(cuda_alloc.get(), indices_meta);
phi::DenseTensor d_values(cuda_alloc.get(), values_meta);
@@ -611,6 +623,10 @@ void TestDenseToSparseCsr(const DenseTensor& dense_x,
paddle::memory::allocation::AllocatorFacade::Instance()
.GetAllocator(phi::CPUPlace())
.get());
dev_ctx_gpu.SetPinnedAllocator(
paddle::memory::allocation::AllocatorFacade::Instance()
.GetAllocator(paddle::platform::CUDAPinnedPlace())
.get());
dev_ctx_gpu.PartialInitWithAllocator();
phi::Copy(dev_ctx_gpu, dense_x, phi::GPUPlace(), true, &d_dense_x);
auto sparse_out = sparse::DenseToSparseCsr<T>(dev_ctx_gpu, d_dense_x);
@@ -741,6 +757,10 @@ void TestSparseCooToDense(const DDim& dense_dims,
paddle::memory::allocation::AllocatorFacade::Instance()
.GetAllocator(phi::CPUPlace())
.get());
dev_ctx_gpu.SetPinnedAllocator(
paddle::memory::allocation::AllocatorFacade::Instance()
.GetAllocator(paddle::platform::CUDAPinnedPlace())
.get());
dev_ctx_gpu.PartialInitWithAllocator();
DenseTensor d_dense_indices(cuda_alloc.get(), dense_indices.meta());
DenseTensor d_dense_elements(cuda_alloc.get(), dense_elements.meta());
@@ -886,6 +906,10 @@ void TestSparseCsrToDense(const DDim& dense_dims,
paddle::memory::allocation::AllocatorFacade::Instance()
.GetAllocator(phi::CPUPlace())
.get());
dev_ctx_gpu.SetPinnedAllocator(
paddle::memory::allocation::AllocatorFacade::Instance()
.GetAllocator(paddle::platform::CUDAPinnedPlace())
.get());
dev_ctx_gpu.PartialInitWithAllocator();
phi::DenseTensor d_crows(cuda_alloc.get(), crows_meta);
phi::DenseTensor d_cols(cuda_alloc.get(), cols_meta);
......
@@ -836,6 +836,16 @@ def monkey_patch_varbase():
res.persistable = self.persistable
return res
@framework.dygraph_only
def pin_memory(self):
if self.place.is_cuda_pinned_place():
return self
else:
res = self._copy_to(core.CUDAPinnedPlace(), True)
res.stop_gradient = self.stop_gradient
res.persistable = self.persistable
return res
if framework._in_eager_mode_ and not hasattr(core, "eager"):
return
@@ -861,6 +871,7 @@ def monkey_patch_varbase():
setattr(core.eager.Tensor, "value", value)
setattr(core.eager.Tensor, "cpu", cpu)
setattr(core.eager.Tensor, "cuda", cuda)
setattr(core.eager.Tensor, "pin_memory", pin_memory)
setattr(core.eager.Tensor, "_slice", _slice) setattr(core.eager.Tensor, "_slice", _slice)
setattr(core.eager.Tensor, "_numel", _numel) setattr(core.eager.Tensor, "_numel", _numel)
else: else:
......
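
With the monkey patch above, pin_memory on an eager Tensor is a no-op when the tensor already lives in CUDAPinnedPlace and otherwise copies it there via _copy_to, preserving stop_gradient and persistable. A minimal usage sketch, assuming a CUDA build:

import paddle

x = paddle.rand([4, 4], dtype="float32")   # lives on the default (GPU) place
x_pinned = x.pin_memory()                  # copied to CUDAPinnedPlace
assert x_pinned.place.is_cuda_pinned_place()
assert x_pinned.pin_memory() is x_pinned   # already pinned, returns itself
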
@@ -18,10 +18,11 @@ import numpy as np
import paddle
from paddle.fluid import core
from paddle.device import cuda
from paddle.fluid.framework import _test_eager_guard, _in_legacy_dygraph
class TestAsyncRead(unittest.TestCase):
def func_setUp(self):
self.empty = paddle.to_tensor(
np.array(
[], dtype="int64"), place=paddle.CPUPlace())
@@ -35,16 +36,20 @@ class TestAsyncRead(unittest.TestCase):
shape=[50, 50, 50], dtype="float32").pin_memory()
self.stream = cuda.Stream()
def func_test_async_read_empty_offset_and_count(self):
with cuda.stream_guard(self.stream):
if _in_legacy_dygraph():
core.async_read(self.src, self.dst, self.index, self.buffer,
self.empty, self.empty)
else:
core.eager.async_read(self.src, self.dst, self.index,
self.buffer, self.empty, self.empty)
array1 = paddle.gather(self.src, self.index)
array2 = self.dst[:len(self.index)]
self.assertTrue(np.allclose(array1.numpy(), array2.numpy()))
def func_test_async_read_success(self):
offset = paddle.to_tensor(
np.array(
[10, 20], dtype="int64"), place=paddle.CPUPlace())
@@ -52,9 +57,12 @@ class TestAsyncRead(unittest.TestCase):
np.array(
[5, 10], dtype="int64"), place=paddle.CPUPlace())
with cuda.stream_guard(self.stream):
if _in_legacy_dygraph():
core.async_read(self.src, self.dst, self.index, self.buffer,
offset, count)
else:
core.eager.async_read(self.src, self.dst, self.index,
self.buffer, offset, count)
# index data
index_array1 = paddle.gather(self.src, self.index)
count_numel = paddle.sum(count).numpy()[0]
@@ -69,26 +77,43 @@ class TestAsyncRead(unittest.TestCase):
self.assertTrue(
np.allclose(offset_array1.numpy(), offset_array2.numpy()))
def func_test_async_read_only_1dim(self):
src = paddle.rand([40], dtype="float32").pin_memory()
dst = paddle.empty([40], dtype="float32")
buffer_ = paddle.empty([20]).pin_memory()
with cuda.stream_guard(self.stream):
if _in_legacy_dygraph():
core.async_read(src, dst, self.index, buffer_, self.empty,
self.empty)
else:
core.eager.async_read(src, dst, self.index, buffer_, self.empty,
self.empty)
array1 = paddle.gather(src, self.index)
array2 = dst[:len(self.index)]
self.assertTrue(np.allclose(array1.numpy(), array2.numpy()))
def test_main(self):
with _test_eager_guard():
self.func_setUp()
self.func_test_async_read_empty_offset_and_count()
self.func_test_async_read_success()
self.func_test_async_read_only_1dim()
self.func_setUp()
self.func_test_async_read_empty_offset_and_count()
self.func_setUp()
self.func_test_async_read_success()
self.func_setUp()
self.func_test_async_read_only_1dim()
class TestAsyncWrite(unittest.TestCase):
def func_setUp(self):
self.src = paddle.rand(shape=[100, 50, 50, 5], dtype="float32")
self.dst = paddle.empty(
shape=[200, 50, 50, 5], dtype="float32").pin_memory()
self.stream = cuda.Stream()
def func_test_async_write_success(self):
offset = paddle.to_tensor(
np.array(
[0, 60], dtype="int64"), place=paddle.CPUPlace())
@@ -96,13 +121,23 @@ class TestAsyncWrite(unittest.TestCase):
np.array(
[40, 60], dtype="int64"), place=paddle.CPUPlace())
with cuda.stream_guard(self.stream):
if _in_legacy_dygraph():
core.async_write(self.src, self.dst, offset, count)
else:
core.eager.async_write(self.src, self.dst, offset, count)
offset_a = paddle.gather(self.dst, paddle.to_tensor(np.arange(0, 40)))
offset_b = paddle.gather(self.dst, paddle.to_tensor(np.arange(60, 120)))
offset_array = paddle.concat([offset_a, offset_b], axis=0)
self.assertTrue(np.allclose(self.src.numpy(), offset_array.numpy()))
def test_async_write_success(self):
with _test_eager_guard():
self.func_setUp()
self.func_test_async_write_success()
self.func_setUp()
self.func_test_async_write_success()
if __name__ == "__main__": if __name__ == "__main__":
if core.is_compiled_with_cuda(): if core.is_compiled_with_cuda():
......