Unverified commit 42075ddc, authored by Weilong Wu, committed by GitHub

[Eager] support tensor uva, test=windows_ci (#41310)

* [Eager] support tensor uva, test=windows_ci

* Add headers to fix CI, test=windows_ci

* Expose _uva python interface, Fix windows ci issue
Parent 84b63a26
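For orientation, here is a minimal usage sketch of the interfaces this commit adds or exposes, assuming a CUDA-enabled build of Paddle; it mirrors the updated test_tensor_uva.py at the end of this diff and is not part of the change itself.

# Minimal usage sketch (assumption: CUDA build); mirrors the updated
# test_tensor_uva.py below rather than adding any new behavior.
import numpy as np
import paddle
from paddle.fluid import core
from paddle.fluid.framework import _in_legacy_dygraph

if paddle.fluid.core.is_compiled_with_cuda():
    # Register an existing CPU tensor's memory with CUDA UVA, in place.
    np_value = np.random.random(size=[10, 30]).astype('float32')
    tensor = paddle.to_tensor(np_value, place=paddle.CPUPlace())
    tensor._uva()  # eager Tensor gets this via the monkey patch added below
    assert tensor.place.is_gpu_place()

    # Build a UVA tensor directly from a numpy array.
    data = np.random.randint(10, size=[4, 5]).astype('int32')
    if _in_legacy_dygraph():
        tensor = paddle.fluid.core.to_uva_tensor(data, 0)
    else:
        tensor = core.eager.to_uva_tensor(data, 0)  # exposed by this commit
    assert tensor.place.is_gpu_place()
    assert np.allclose(tensor.numpy(), data)

The diff below wires these entry points together: to_uva_tensor in the eager pybind functions, a _tensor_uva method on the eager Tensor, a shared tensor_uva helper in the new uva_utils.h, and the _uva patch in varbase_patch_methods.py.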
@@ -772,6 +772,53 @@ static PyObject* eager_api_async_write(PyObject* self, PyObject* args,
return Py_None;
EAGER_CATCH_AND_THROW_RETURN_NULL
}
static PyObject* eager_api_to_uva_tensor(PyObject* self, PyObject* args,
PyObject* kwargs) {
EAGER_TRY
VLOG(4) << "Running in eager_api_to_uva_tensor.";
auto new_tensor = std::shared_ptr<paddle::experimental::Tensor>(
new paddle::experimental::Tensor(
egr::Controller::Instance().GenerateUniqueName()));
PyObject* obj = PyTuple_GET_ITEM(args, 0);
auto array = py::cast<py::array>(py::handle(obj));
int device_id = 0;
PyObject* Py_device_id = PyTuple_GET_ITEM(args, 1);
if (Py_device_id) {
device_id = CastPyArg2AttrLong(Py_device_id, 1);
}
if (py::isinstance<py::array_t<int32_t>>(array)) {
SetUVATensorFromPyArray<int32_t>(new_tensor, array, device_id);
} else if (py::isinstance<py::array_t<int64_t>>(array)) {
SetUVATensorFromPyArray<int64_t>(new_tensor, array, device_id);
} else if (py::isinstance<py::array_t<float>>(array)) {
SetUVATensorFromPyArray<float>(new_tensor, array, device_id);
} else if (py::isinstance<py::array_t<double>>(array)) {
SetUVATensorFromPyArray<double>(new_tensor, array, device_id);
} else if (py::isinstance<py::array_t<int8_t>>(array)) {
SetUVATensorFromPyArray<int8_t>(new_tensor, array, device_id);
} else if (py::isinstance<py::array_t<int16_t>>(array)) {
SetUVATensorFromPyArray<int16_t>(new_tensor, array, device_id);
} else if (py::isinstance<py::array_t<paddle::platform::float16>>(array)) {
SetUVATensorFromPyArray<paddle::platform::float16>(new_tensor, array,
device_id);
} else if (py::isinstance<py::array_t<bool>>(array)) {
SetUVATensorFromPyArray<bool>(new_tensor, array, device_id);
} else {
// obj may be any type, obj.cast<py::array>() may fail,
// and then array.dtype will be a string of unknown meaning.
PADDLE_THROW(platform::errors::InvalidArgument(
"Input object type error or incompatible array data type. "
"tensor.set() supports array with bool, float16, float32, "
"float64, int8, int16, int32, int64, "
"please check your input or input array data type."));
}
return ToPyObject(*(new_tensor.get()));
EAGER_CATCH_AND_THROW_RETURN_NULL
}
#endif
PyMethodDef variable_functions[] = {
@@ -803,6 +850,8 @@ PyMethodDef variable_functions[] = {
METH_VARARGS | METH_KEYWORDS, NULL},
{"async_write", (PyCFunction)(void (*)(void))eager_api_async_write,
METH_VARARGS | METH_KEYWORDS, NULL},
{"to_uva_tensor", (PyCFunction)(void (*)(void))eager_api_to_uva_tensor,
METH_VARARGS | METH_KEYWORDS, NULL},
#endif
{NULL, NULL, 0, NULL}};
......
@@ -32,6 +32,7 @@ limitations under the License. */
#include "paddle/fluid/pybind/eager_utils.h"
#include "paddle/fluid/pybind/exception.h"
#include "paddle/fluid/pybind/slice_utils.h"
#include "paddle/fluid/pybind/uva_utils.h"
#include "paddle/phi/api/include/api.h"
#include "paddle/phi/common/data_type.h"
#include "paddle/phi/core/compat/convert_utils.h"
@@ -1343,6 +1344,26 @@ static PyObject* tensor__reset_grad_inplace_version(TensorObject* self,
EAGER_CATCH_AND_THROW_RETURN_NULL
}
#if defined(PADDLE_WITH_CUDA)
static PyObject* tensor_method__uva(TensorObject* self, PyObject* args,
PyObject* kwargs) {
EAGER_TRY
VLOG(4) << "Running in tensor_method__uva.";
PADDLE_ENFORCE_EQ(platform::is_cpu_place(self->tensor.inner_place()), true,
platform::errors::InvalidArgument(
"Unified virtual addressing only support "
"CPU Tensor currently."));
int device_id = pybind::CastPyArg2AttrLong(PyTuple_GET_ITEM(args, 0), 0);
auto* self_tensor =
static_cast<paddle::framework::LoDTensor*>(self->tensor.impl().get());
tensor_uva(self_tensor, device_id);
Py_INCREF(Py_None);
return Py_None;
EAGER_CATCH_AND_THROW_RETURN_NULL
}
#endif
PyMethodDef variable_methods[] = {
{"numpy", (PyCFunction)(void (*)(void))tensor_method_numpy,
METH_VARARGS | METH_KEYWORDS, NULL},
@@ -1447,6 +1468,10 @@ PyMethodDef variable_methods[] = {
{"_reset_grad_inplace_version",
(PyCFunction)(void (*)(void))tensor__reset_grad_inplace_version,
METH_VARARGS | METH_KEYWORDS, NULL},
#if defined(PADDLE_WITH_CUDA)
{"_tensor_uva", (PyCFunction)(void (*)(void))tensor_method__uva,
METH_VARARGS | METH_KEYWORDS, NULL},
#endif
{NULL, NULL, 0, NULL}};
} // namespace pybind
......
@@ -57,6 +57,7 @@ limitations under the License. */
#include "paddle/fluid/pybind/pybind_boost_headers.h"
#include "paddle/fluid/pybind/slice_utils.h"
#include "paddle/fluid/pybind/tensor_py.h"
#include "paddle/fluid/pybind/uva_utils.h"
#include "paddle/phi/core/compat/arg_map_context.h"
#include "paddle/phi/core/compat/type_defs.h"
@@ -1629,39 +1630,9 @@ void BindImperative(py::module *m_ptr) {
                 platform::errors::InvalidArgument(
                     "Unified virtual addressing only support "
                     "CPU Tensor currently."));
-            platform::DeviceContextPool &pool =
-                platform::DeviceContextPool::Instance();
-            auto *dev_ctx = pool.Get(platform::CUDAPlace(device_id));
-            VLOG(4) << "Init the DeviceContext, and the place is "
-                    << dev_ctx->GetPlace();
             auto *self_tensor =
                 self->MutableVar()->GetMutable<framework::LoDTensor>();
-            // Register the cpu memory as the cuda host memory
-            const auto &data_numel = self_tensor->numel();
-            const size_t &need_allocate_size =
-                data_numel *
-                framework::SizeOfType(
-                    framework::TransToProtoVarType(self_tensor->dtype()));
-            void *data_ptr = self_tensor->data();
-            auto result = cudaHostRegister(data_ptr, need_allocate_size,
-                                           cudaHostRegisterDefault);
-            if (cudaSuccess != result) {
-              VLOG(4) << "UVA(unified virtual addressing) failed allocate:"
-                      << need_allocate_size << ", the error code:" << result;
-            }
-            // Get device pointer from the function of cudaHostGetDevicePointer
-            void *cuda_device_pointer = nullptr;
-            cudaHostGetDevicePointer(
-                reinterpret_cast<void **>(&cuda_device_pointer),
-                reinterpret_cast<void *>(data_ptr), 0);
-            // Reset the memory with device pointer
-            std::shared_ptr<memory::allocation::Allocation> holder =
-                std::make_shared<memory::allocation::Allocation>(
-                    cuda_device_pointer, need_allocate_size,
-                    platform::CUDAPlace(device_id));
-            self_tensor->ResetHolderWithType(holder, self_tensor->dtype());
+            tensor_uva(self_tensor, device_id);
           },
           py::arg("device_id") = 0, py::return_value_policy::reference, R"DOC(
        Returns self tensor with the UVA(unified virtual addressing).
......
@@ -529,11 +529,10 @@ void SetTensorFromPyArray(framework::Tensor *self, const py::object &obj,
 }
 template <typename T>
-void SetUVATensorFromPyArray(
-    const std::shared_ptr<paddle::imperative::VarBase> &self,
+void SetUVATensorFromPyArrayImpl(framework::LoDTensor *self_tensor,
     const py::array_t<T> &array, int device_id) {
 #if defined(PADDLE_WITH_CUDA)
-  auto *self_tensor = self->MutableVar()->GetMutable<framework::LoDTensor>();
+  VLOG(4) << "Running in SetUVATensorFromPyArrayImpl.";
   std::vector<int64_t> dims;
   dims.reserve(array.ndim());
   int64_t numel = 1;
@@ -562,6 +561,38 @@ void SetUVATensorFromPyArray(
#endif
}
template <typename T>
void SetUVATensorFromPyArray(
const std::shared_ptr<paddle::imperative::VarBase> &self,
const py::array_t<T> &array, int device_id) {
#if defined(PADDLE_WITH_CUDA)
VLOG(4) << "Running in SetUVATensorFromPyArray for VarBase.";
auto *self_tensor = self->MutableVar()->GetMutable<framework::LoDTensor>();
SetUVATensorFromPyArrayImpl<T>(self_tensor, array, device_id);
#endif
}
template <typename T>
void SetUVATensorFromPyArray(
const std::shared_ptr<paddle::experimental::Tensor> &self,
const py::array_t<T> &array, int device_id) {
#if defined(PADDLE_WITH_CUDA)
VLOG(4) << "Running in SetUVATensorFromPyArray for Phi::Tensor.";
phi::DenseTensorMeta meta =
phi::DenseTensorMeta(phi::DataType::FLOAT32, phi::make_ddim({1, 1}));
std::shared_ptr<phi::DenseTensor> tmp_t = std::make_shared<phi::DenseTensor>(
std::make_unique<paddle::experimental::DefaultAllocator>(
paddle::platform::CPUPlace())
.get(),
meta);
self.get()->set_impl(tmp_t);
auto *self_tensor =
static_cast<paddle::framework::LoDTensor *>(self.get()->impl().get());
SetUVATensorFromPyArrayImpl<T>(self_tensor, array, device_id);
#endif
}
template <typename T, size_t D>
void _sliceCompute(const framework::Tensor *in, framework::Tensor *out,
const platform::CPUDeviceContext &ctx,
......
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <Python.h>
#include "paddle/fluid/operators/utils.h"
#include "paddle/phi/common/data_type.h"
#include "paddle/phi/core/compat/convert_utils.h"
#include "paddle/phi/core/dense_tensor.h"
namespace paddle {
namespace pybind {
static void tensor_uva(paddle::framework::LoDTensor *self_tensor,
int device_id) {
VLOG(4) << "Running in _uva interface.";
#if defined(PADDLE_WITH_CUDA)
platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
auto *dev_ctx = pool.Get(platform::CUDAPlace(device_id));
VLOG(4) << "Init the DeviceContext, and the place is " << dev_ctx->GetPlace();
// Register the cpu memory as the cuda host memory
const auto &data_numel = self_tensor->numel();
const size_t &need_allocate_size =
data_numel * framework::SizeOfType(
framework::TransToProtoVarType(self_tensor->dtype()));
void *data_ptr = self_tensor->data();
auto result =
cudaHostRegister(data_ptr, need_allocate_size, cudaHostRegisterDefault);
if (cudaSuccess != result) {
VLOG(4) << "UVA(unified virtual addressing) failed allocate:"
<< need_allocate_size << ", the error code:" << result;
}
// Get device pointer from the function of cudaHostGetDevicePointer
void *cuda_device_pointer = nullptr;
cudaHostGetDevicePointer(reinterpret_cast<void **>(&cuda_device_pointer),
reinterpret_cast<void *>(data_ptr), 0);
// Reset the memory with device pointer
std::shared_ptr<memory::allocation::Allocation> holder =
std::make_shared<memory::allocation::Allocation>(
cuda_device_pointer, need_allocate_size,
platform::CUDAPlace(device_id));
self_tensor->ResetHolderWithType(holder, self_tensor->dtype());
#endif
}
} // namespace pybind
} // namespace paddle
@@ -816,6 +816,10 @@ def monkey_patch_varbase():
     def _numel(self):
         return self.get_tensor()._numel()
+    @framework.dygraph_only
+    def _uva(self, device_id=0):
+        self._tensor_uva(device_id)
     @framework.dygraph_only
     def cpu(self):
         if self.place.is_cpu_place():
@@ -874,6 +878,7 @@ def monkey_patch_varbase():
         setattr(core.eager.Tensor, "pin_memory", pin_memory)
         setattr(core.eager.Tensor, "_slice", _slice)
         setattr(core.eager.Tensor, "_numel", _numel)
+        setattr(core.eager.Tensor, "_uva", _uva)
     else:
         setattr(core.VarBase, "__name__", "Tensor")
         setattr(core.VarBase, "grad", grad)
......
@@ -15,10 +15,12 @@
 import paddle
 import unittest
 import numpy as np
+from paddle.fluid import core
+from paddle.fluid.framework import _test_eager_guard, _in_legacy_dygraph
 class TestTensorCopyFrom(unittest.TestCase):
-    def test_main(self):
+    def func_main(self):
         if paddle.fluid.core.is_compiled_with_cuda():
             place = paddle.CPUPlace()
             np_value = np.random.random(size=[10, 30]).astype('float32')
@@ -26,9 +28,14 @@ class TestTensorCopyFrom(unittest.TestCase):
             tensor._uva()
             self.assertTrue(tensor.place.is_gpu_place())
+    def test_main(self):
+        with _test_eager_guard():
+            self.func_main()
+        self.func_main()
 class TestUVATensorFromNumpy(unittest.TestCase):
-    def test_uva_tensor_creation(self):
+    def func_uva_tensor_creation(self):
         if paddle.fluid.core.is_compiled_with_cuda():
             dtype_list = [
                 "int32", "int64", "float32", "float64", "float16", "int8",
@@ -36,10 +43,18 @@ class TestUVATensorFromNumpy(unittest.TestCase):
             ]
             for dtype in dtype_list:
                 data = np.random.randint(10, size=[4, 5]).astype(dtype)
+                if _in_legacy_dygraph():
+                    tensor = paddle.fluid.core.to_uva_tensor(data, 0)
+                else:
+                    tensor = core.eager.to_uva_tensor(data, 0)
                 self.assertTrue(tensor.place.is_gpu_place())
                 self.assertTrue(np.allclose(tensor.numpy(), data))
+    def test_uva_tensor_creation(self):
+        with _test_eager_guard():
+            self.func_uva_tensor_creation()
+        self.func_uva_tensor_creation()
 if __name__ == "__main__":
     unittest.main()