Unverified commit 3845afff authored by Siming Dai, committed by GitHub

Add operators for async read & async write (#36333)

* fix async_read bug

* change index place to cpu

* add tensor size judge

* add async_read & async_write test

* fix bug in async_write

* fix mac py3 ci

* fix bug for cpu version paddle

* fix windows ci bug

* change input argument error type

* change const_cast to mutable_data

* add async_write out-of-bound check and improve error hints

* fix a small bug for dst_tensor

* add docs and refine codes

* refine docs

* notest,test=windows_ci

* fix windows ci

* fix require

* fix code-block

* add core.is_compiled_with_cuda()
Parent 051544b6
@@ -2249,6 +2249,343 @@ void BindImperative(py::module *m_ptr) {
const py::args args, const py::kwargs kwargs) {
return imperative::PyLayerApply(place, cls, args, kwargs);
});
#if defined(PADDLE_WITH_CUDA)
m.def(
"async_write",
[](const imperative::VarBase &src, imperative::VarBase &dst,
const imperative::VarBase &offset, const imperative::VarBase &count) {
PADDLE_ENFORCE_EQ(
platform::is_gpu_place(src.Place()), true,
platform::errors::InvalidArgument(
"Required `src` device should be CUDAPlace, but received %d. ",
src.Place()));
PADDLE_ENFORCE_EQ(
platform::is_cuda_pinned_place(dst.Place()), true,
platform::errors::InvalidArgument(
"Required `dst` device should be CUDAPinnedPlace, "
"but received %d. ",
dst.Place()));
PADDLE_ENFORCE_EQ(
platform::is_cpu_place(offset.Place()), true,
platform::errors::InvalidArgument("Required `offset` device should "
"be CPUPlace, but received %d. ",
offset.Place()));
PADDLE_ENFORCE_EQ(
platform::is_cpu_place(count.Place()), true,
platform::errors::InvalidArgument(
"Required `count` device should be CPUPlace, but received %d. ",
count.Place()));
// TODO(daisiming): In future, add index as arguments following
// async_read.
auto &src_tensor = src.Var().Get<framework::LoDTensor>();
auto *dst_tensor = dst.MutableVar()->GetMutable<framework::LoDTensor>();
auto &offset_tensor = offset.Var().Get<framework::LoDTensor>();
auto &count_tensor = count.Var().Get<framework::LoDTensor>();
const auto &deviceId = paddle::platform::GetCurrentDeviceId();
PADDLE_ENFORCE_EQ(offset_tensor.dims().size(), 1,
platform::errors::InvalidArgument(
"`offset` tensor should be one-dimensional."));
PADDLE_ENFORCE_EQ(count_tensor.dims().size(), 1,
platform::errors::InvalidArgument(
"`count` tensor should be one-dimensional."));
PADDLE_ENFORCE_EQ(offset_tensor.numel(), count_tensor.numel(),
platform::errors::InvalidArgument(
"`offset` and `count` tensor size dismatch."));
PADDLE_ENFORCE_EQ(
src_tensor.dims().size(), dst_tensor->dims().size(),
platform::errors::InvalidArgument(
"`src` and `dst` should have the same tensor shape, "
"except for the first dimension."));
for (int i = 1; i < src_tensor.dims().size(); i++) {
PADDLE_ENFORCE_EQ(
src_tensor.dims()[i], dst_tensor->dims()[i],
platform::errors::InvalidArgument(
"`src` and `dst` should have the same tensor shape, "
"except for the first dimension."));
}
auto stream = paddle::platform::stream::get_current_stream(deviceId)
->raw_stream();
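        // `size` is the element count of one leading-dimension row; every
        // copy below moves whole rows.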
int64_t size = src_tensor.numel() / src_tensor.dims()[0];
auto *src_data = src_tensor.data<float>();
auto *dst_data = dst_tensor->mutable_data<float>(dst.Place());
const int64_t *offset_data = offset_tensor.data<int64_t>();
const int64_t *count_data = count_tensor.data<int64_t>();
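        // Each (offset[i], count[i]) pair copies rows [src_offset, src_offset + c)
        // of `src` into rows [offset[i], offset[i] + c) of `dst`; `src` is
        // consumed contiguously from its first row.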
int64_t src_offset = 0, dst_offset, c;
for (int64_t i = 0; i < offset_tensor.numel(); i++) {
dst_offset = offset_data[i], c = count_data[i];
PADDLE_ENFORCE_LE(src_offset + c, src_tensor.dims()[0],
platform::errors::InvalidArgument(
"Invalid offset or count index"));
PADDLE_ENFORCE_LE(dst_offset + c, dst_tensor->dims()[0],
platform::errors::InvalidArgument(
"Invalid offset or count index"));
cudaMemcpyAsync(
dst_data + (dst_offset * size), src_data + (src_offset * size),
c * size * sizeof(float), cudaMemcpyDeviceToHost, stream);
src_offset += c;
}
},
R"DOC(
This API provides a way to write pieces of the source tensor to the destination
tensor in place and asynchronously. `offset` and `count` determine where to copy:
`offset` gives the begin rows of the copy pieces in `dst`, and `count` gives the
lengths of the pieces, which are read contiguously from `src` starting at its first
row. Note that the copy process runs asynchronously from CUDA memory to pinned
memory. We can simply remember this as "gpu async_write to pin_memory".
Arguments:
src (Tensor): The source tensor, and the data type should be `float32` currently.
Besides, `src` should be placed on CUDAPlace.
dst (Tensor): The destination tensor, and the data type should be `float32` currently.
Besides, `dst` should be placed on CUDAPinnedPlace. The shape of `dst`
should be the same with `src` except for the first dimension.
offset (Tensor): The offset tensor, and the data type should be `int64` currently.
Besides, `offset` should be placed on CPUPlace. The shape of `offset`
should be one-dimensional.
count (Tensor): The count tensor, and the data type should be `int64` currently.
Besides, `count` should be placed on CPUPlace. The shape of `count`
                should be one-dimensional.
Examples:
.. code-block:: python
    import numpy as np
    import paddle
    from paddle.fluid import core
    from paddle.device import cuda

    if core.is_compiled_with_cuda():
        src = paddle.rand(shape=[100, 50, 50])
        dst = paddle.empty(shape=[200, 50, 50]).pin_memory()
        offset = paddle.to_tensor(
            np.array([0, 60], dtype="int64"), place=paddle.CPUPlace())
        count = paddle.to_tensor(
            np.array([40, 60], dtype="int64"), place=paddle.CPUPlace())

        stream = cuda.Stream()
        with cuda.stream_guard(stream):
            core.async_write(src, dst, offset, count)

        offset_a = paddle.gather(dst, paddle.to_tensor(np.arange(0, 40)))
        offset_b = paddle.gather(dst, paddle.to_tensor(np.arange(60, 120)))
        offset_array = paddle.concat([offset_a, offset_b], axis=0)
        print(np.allclose(src.numpy(), offset_array.numpy()))  # True
)DOC");
m.def(
"async_read",
[](const imperative::VarBase &src, imperative::VarBase &dst,
const imperative::VarBase &index, imperative::VarBase &buffer,
const imperative::VarBase &offset, const imperative::VarBase &count) {
PADDLE_ENFORCE_EQ(platform::is_cuda_pinned_place(src.Place()), true,
platform::errors::InvalidArgument(
"Required `src` device should be "
"CUDAPinnedPlace, but received %d.",
src.Place()));
PADDLE_ENFORCE_EQ(
platform::is_gpu_place(dst.Place()), true,
platform::errors::InvalidArgument(
"Required `dst` device should be CUDAPlace, but received %d.",
dst.Place()));
PADDLE_ENFORCE_EQ(
platform::is_cpu_place(index.Place()), true,
platform::errors::InvalidArgument(
"Required `index` device should be CPUPlace, but received %d.",
index.Place()));
PADDLE_ENFORCE_EQ(
platform::is_cuda_pinned_place(buffer.Place()), true,
platform::errors::InvalidArgument(
"Required `buffer` device should be CUDAPinnedPlace, "
"but received %d.",
buffer.Place()));
PADDLE_ENFORCE_EQ(
platform::is_cpu_place(offset.Place()), true,
platform::errors::InvalidArgument(
"Required `offset` device should be CPUPlace, but received %d.",
offset.Place()));
PADDLE_ENFORCE_EQ(
platform::is_cpu_place(count.Place()), true,
platform::errors::InvalidArgument(
"Required `count` device should be CPUPlace, but received %d.",
count.Place()));
auto &src_tensor = src.Var().Get<framework::LoDTensor>();
auto *dst_tensor = dst.MutableVar()->GetMutable<framework::LoDTensor>();
auto &index_tensor = index.Var().Get<framework::LoDTensor>();
auto *buffer_tensor =
buffer.MutableVar()->GetMutable<framework::LoDTensor>();
auto &offset_tensor = offset.Var().Get<framework::LoDTensor>();
auto &count_tensor = count.Var().Get<framework::LoDTensor>();
auto *dst_data = dst_tensor->mutable_data<float>(dst.Place());
const auto &deviceId = paddle::platform::GetCurrentDeviceId();
PADDLE_ENFORCE_EQ(src_tensor.dims().size(), dst_tensor->dims().size(),
platform::errors::InvalidArgument(
"`src` and `dst` should have same tensor shape, "
"except for the first dimension."));
PADDLE_ENFORCE_EQ(
src_tensor.dims().size(), buffer_tensor->dims().size(),
platform::errors::InvalidArgument(
"`src` and `buffer` should have same tensor shape, "
"except for the first dimension."));
for (int i = 1; i < src_tensor.dims().size(); i++) {
PADDLE_ENFORCE_EQ(
src_tensor.dims()[i], dst_tensor->dims()[i],
platform::errors::InvalidArgument(
"`src` and `dst` should have the same tensor shape, "
"except for the first dimension."));
PADDLE_ENFORCE_EQ(
src_tensor.dims()[i], buffer_tensor->dims()[i],
platform::errors::InvalidArgument(
"`src` and `buffer` should have the same tensor shape, "
"except for the first dimension."));
}
PADDLE_ENFORCE_EQ(index_tensor.dims().size(), 1,
platform::errors::InvalidArgument(
"`index` tensor should be one-dimensional."));
auto stream = paddle::platform::stream::get_current_stream(deviceId)
->raw_stream();
int64_t numel = 0; // total copy length
int64_t copy_flag = offset_tensor.dims()[0];
int64_t size = src_tensor.numel() / src_tensor.dims()[0];
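        // Resulting layout of `dst`: the (offset, count) pieces are packed
        // first, followed by the rows gathered via `index` (staged through
        // the pinned `buffer` before the final host-to-device copy).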
if (copy_flag != 0) {
PADDLE_ENFORCE_EQ(offset_tensor.dims().size(), 1,
platform::errors::InvalidArgument(
"`offset` tensor should be one-dimensional."));
PADDLE_ENFORCE_EQ(count_tensor.dims().size(), 1,
platform::errors::InvalidArgument(
"`count` tensor should be one-dimensional."));
PADDLE_ENFORCE_EQ(offset_tensor.numel(), count_tensor.numel(),
platform::errors::InvalidArgument(
"`offset` and `count` tensor size dismatch."));
auto *offset_data = offset_tensor.data<int64_t>();
auto *count_data = count_tensor.data<int64_t>();
for (int64_t i = 0; i < count_tensor.numel(); i++) {
numel += count_data[i];
}
PADDLE_ENFORCE_LE(numel + index_tensor.numel(),
buffer_tensor->dims()[0],
platform::errors::InvalidArgument(
"Buffer tensor size is too small."));
PADDLE_ENFORCE_LE(numel + index_tensor.numel(), dst_tensor->dims()[0],
platform::errors::InvalidArgument(
"Target tensor size is too small."));
int64_t src_offset, dst_offset = 0, c;
auto *src_data = src_tensor.data<float>();
for (int64_t i = 0; i < offset_tensor.numel(); i++) {
src_offset = offset_data[i], c = count_data[i];
PADDLE_ENFORCE_LE(src_offset + c, src_tensor.dims()[0],
platform::errors::InvalidArgument(
"Invalid offset or count index."));
PADDLE_ENFORCE_LE(dst_offset + c, dst_tensor->dims()[0],
platform::errors::InvalidArgument(
"Invalid offset or count index."));
cudaMemcpyAsync(
dst_data + (dst_offset * size), src_data + (src_offset * size),
c * size * sizeof(float), cudaMemcpyHostToDevice, stream);
dst_offset += c;
}
} else {
PADDLE_ENFORCE_LE(index_tensor.numel(), buffer_tensor->dims()[0],
platform::errors::InvalidArgument(
"Buffer tensor size is too small."));
}
// Select the index data to the buffer
auto index_select = [](const framework::Tensor &src_tensor,
const framework::Tensor &index_tensor,
framework::Tensor *buffer_tensor) {
auto *src_data = src_tensor.data<float>();
auto *index_data = index_tensor.data<int64_t>();
auto *buffer_data =
buffer_tensor->mutable_data<float>(buffer_tensor->place());
const int &slice_size = src_tensor.numel() / src_tensor.dims()[0];
const int &copy_bytes = slice_size * sizeof(float);
int64_t c = 0;
for (int64_t i = 0; i < index_tensor.numel(); i++) {
std::memcpy(buffer_data + c * slice_size,
src_data + index_data[i] * slice_size, copy_bytes);
c += 1;
}
};
index_select(src_tensor, index_tensor, buffer_tensor);
// Copy the data to device memory
cudaMemcpyAsync(dst_data + (numel * size), buffer_tensor->data<float>(),
index_tensor.numel() * size * sizeof(float),
cudaMemcpyHostToDevice, stream);
},
R"DOC(
This API provides a way to read pieces of the source tensor into the destination
tensor asynchronously. `index`, `offset` and `count` determine what to read:
`index` selects the rows of `src` to gather, while `offset` and `count` give the
begin rows and lengths of the contiguous pieces of `src` to copy. The (offset,
count) pieces are placed at the front of `dst`, followed by the rows selected by
`index`. Note that the copy process runs asynchronously from pinned memory to the
CUDA place. We can simply remember this as "cuda async_read from pin_memory".
Arguments:
src (Tensor): The source tensor, and the data type should be `float32` currently.
Besides, `src` should be placed on CUDAPinnedPlace.
dst (Tensor): The destination tensor, and the data type should be `float32` currently.
Besides, `dst` should be placed on CUDAPlace. The shape of `dst` should
be the same with `src` except for the first dimension.
index (Tensor): The index tensor, and the data type should be `int64` currently.
                Besides, `index` should be placed on CPUPlace. The shape of `index` should
be one-dimensional.
buffer (Tensor): The buffer tensor, used to buffer index copy tensor temporarily.
The data type should be `float32` currently, and should be placed
on CUDAPinnedPlace. The shape of `buffer` should be the same with `src` except for the first dimension.
offset (Tensor): The offset tensor, and the data type should be `int64` currently.
Besides, `offset` should be placed on CPUPlace. The shape of `offset`
should be one-dimensional.
count (Tensor): The count tensor, and the data type should be `int64` currently.
Besides, `count` should be placed on CPUPlace. The shape of `count`
                should be one-dimensional.
Examples:
.. code-block:: python
    import numpy as np
    import paddle
    from paddle.fluid import core
    from paddle.device import cuda

    if core.is_compiled_with_cuda():
        src = paddle.rand(shape=[100, 50, 50], dtype="float32").pin_memory()
        dst = paddle.empty(shape=[100, 50, 50], dtype="float32")
        offset = paddle.to_tensor(
            np.array([0, 60], dtype="int64"), place=paddle.CPUPlace())
        count = paddle.to_tensor(
            np.array([40, 60], dtype="int64"), place=paddle.CPUPlace())
        buffer = paddle.empty(shape=[50, 50, 50], dtype="float32").pin_memory()
        index = paddle.to_tensor(
            np.array([1, 3, 5, 7, 9], dtype="int64")).cpu()

        stream = cuda.Stream()
        with cuda.stream_guard(stream):
            core.async_read(src, dst, index, buffer, offset, count)
)DOC");
#endif
}
} // namespace pybind
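For reference, the semantics of the two bindings above can be written down in a
few lines of NumPy. This is a minimal sketch of what the copies compute once the
stream has synchronized; the helper names `async_write_ref` and `async_read_ref`
are hypothetical and exist only for illustration:

    import numpy as np

    def async_write_ref(src, dst, offset, count):
        # `src` rows are consumed contiguously; piece i, of length count[i],
        # lands at dst[offset[i] : offset[i] + count[i]].
        src_offset = 0
        for dst_offset, c in zip(offset, count):
            dst[dst_offset:dst_offset + c] = src[src_offset:src_offset + c]
            src_offset += c

    def async_read_ref(src, dst, index, offset, count):
        # The (offset, count) pieces are packed at the front of `dst`,
        # followed by the rows selected by `index`.
        dst_offset = 0
        for src_offset, c in zip(offset, count):
            dst[dst_offset:dst_offset + c] = src[src_offset:src_offset + c]
            dst_offset += c
        dst[dst_offset:dst_offset + len(index)] = src[index]

In particular, note that for async_write the offsets index into `dst`, while for
async_read they index into `src`.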
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import numpy as np
import paddle
from paddle.fluid import core
from paddle.device import cuda
class TestAsyncRead(unittest.TestCase):
def setUp(self):
self.empty = paddle.to_tensor(
np.array(
[], dtype="int64"), place=paddle.CPUPlace())
data = np.random.randn(100, 50, 50).astype("float32")
self.src = paddle.to_tensor(data, place=paddle.CUDAPinnedPlace())
self.dst = paddle.empty(shape=[100, 50, 50], dtype="float32")
self.index = paddle.to_tensor(
np.array(
[1, 3, 5, 7, 9], dtype="int64")).cpu()
self.buffer = paddle.empty(
shape=[50, 50, 50], dtype="float32").pin_memory()
self.stream = cuda.Stream()
def test_async_read_empty_offset_and_count(self):
with cuda.stream_guard(self.stream):
core.async_read(self.src, self.dst, self.index, self.buffer,
self.empty, self.empty)
array1 = paddle.gather(self.src, self.index)
array2 = self.dst[:len(self.index)]
self.assertTrue(np.allclose(array1.numpy(), array2.numpy()))
def test_async_read_success(self):
offset = paddle.to_tensor(
np.array(
[10, 20], dtype="int64"), place=paddle.CPUPlace())
count = paddle.to_tensor(
np.array(
[5, 10], dtype="int64"), place=paddle.CPUPlace())
with cuda.stream_guard(self.stream):
core.async_read(self.src, self.dst, self.index, self.buffer, offset,
count)
# index data
index_array1 = paddle.gather(self.src, self.index)
count_numel = paddle.sum(count).numpy()[0]
index_array2 = self.dst[count_numel:count_numel + len(self.index)]
self.assertTrue(np.allclose(index_array1.numpy(), index_array2.numpy()))
# offset, count
offset_a = paddle.gather(self.src, paddle.to_tensor(np.arange(10, 15)))
offset_b = paddle.gather(self.src, paddle.to_tensor(np.arange(20, 30)))
offset_array1 = paddle.concat([offset_a, offset_b], axis=0)
offset_array2 = self.dst[:count_numel]
self.assertTrue(
np.allclose(offset_array1.numpy(), offset_array2.numpy()))
def test_async_read_only_1dim(self):
src = paddle.rand([40], dtype="float32").pin_memory()
dst = paddle.empty([40], dtype="float32")
buffer_ = paddle.empty([20]).pin_memory()
with cuda.stream_guard(self.stream):
core.async_read(src, dst, self.index, buffer_, self.empty,
self.empty)
array1 = paddle.gather(src, self.index)
array2 = dst[:len(self.index)]
self.assertTrue(np.allclose(array1.numpy(), array2.numpy()))
class TestAsyncWrite(unittest.TestCase):
def setUp(self):
self.src = paddle.rand(shape=[100, 50, 50, 5], dtype="float32")
self.dst = paddle.empty(
shape=[200, 50, 50, 5], dtype="float32").pin_memory()
self.stream = cuda.Stream()
def test_async_write_success(self):
offset = paddle.to_tensor(
np.array(
[0, 60], dtype="int64"), place=paddle.CPUPlace())
count = paddle.to_tensor(
np.array(
[40, 60], dtype="int64"), place=paddle.CPUPlace())
with cuda.stream_guard(self.stream):
core.async_write(self.src, self.dst, offset, count)
offset_a = paddle.gather(self.dst, paddle.to_tensor(np.arange(0, 40)))
offset_b = paddle.gather(self.dst, paddle.to_tensor(np.arange(60, 120)))
offset_array = paddle.concat([offset_a, offset_b], axis=0)
self.assertTrue(np.allclose(self.src.numpy(), offset_array.numpy()))
if __name__ == "__main__":
if core.is_compiled_with_cuda():
unittest.main()
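A caveat when adapting these tests: both operators only enqueue copies on the
current stream, so reading `dst` on the host is guaranteed to be safe only after
the stream has finished. The tests above read `dst` through subsequent ops, which
synchronize implicitly in practice; when reading results directly, an explicit
synchronize is the safer pattern. A minimal sketch, assuming
`paddle.device.cuda.synchronize` is available in this Paddle build:

    stream = cuda.Stream()
    with cuda.stream_guard(stream):
        core.async_write(src, dst, offset, count)
    cuda.synchronize()  # wait for the enqueued copies to finish
    print(dst.numpy()[:40])  # now safe to read on the host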