提交 d5ff79e5 编写于 作者: Y Youwei Song 提交者: hong

Support numpy bridge (enabled by default in dygraph mode) (#20983)

* add numpy bridge

* fix template compile

* add unittest, add default
test=develop

* fix unittest
test=develop

* fix unittest
test=develop

* zero_copy=True for to_variable,
test=develop

* bug fix
test=develop

* disable deprecated NumPy API
test=develop

* use better design of NumpyAllocator
test=develop

* fix Py_None check
test=develop

* reset c++ tracer when jump out dygraph guard
test=develop

* refine PADDLE_ENFORCE_xx format
test=develop

* bug fix of tracer switch
test=develop

* update decref
test=develop
上级 8493f20e
......@@ -114,5 +114,11 @@ void Tensor::ResetHolder(std::shared_ptr<memory::Allocation> holder) {
holder_ = holder;
}
void Tensor::ResetHolderWithType(std::shared_ptr<memory::Allocation> holder,
const proto::VarType::Type type) {
ResetHolder(holder);
type_ = type;
}
} // namespace framework
} // namespace paddle
......@@ -169,6 +169,9 @@ class Tensor {
void ResetHolder(std::shared_ptr<memory::Allocation> holder);
void ResetHolderWithType(std::shared_ptr<memory::Allocation> holder,
const proto::VarType::Type type);
private:
/*! holds the memory block if allocated. */
std::shared_ptr<memory::Allocation> holder_;
......
......@@ -459,17 +459,20 @@ PYBIND11_MODULE(core_noavx, m) {
})
.def("_clear", &Tensor::clear)
.def("set", SetTensorFromPyArray<paddle::platform::CPUPlace>,
py::arg("array"), py::arg("place"))
py::arg("array"), py::arg("place"), py::arg("zero_copy") = false)
.def("set", SetTensorFromPyArray<paddle::platform::CUDAPlace>,
py::arg("array"), py::arg("place"))
py::arg("array"), py::arg("place"), py::arg("zero_copy") = false)
.def("set", SetTensorFromPyArray<paddle::platform::CUDAPinnedPlace>,
py::arg("array"), py::arg("place"), R"DOC(
py::arg("array"), py::arg("place"), py::arg("zero_copy") = false,
R"DOC(
Set the data of LoDTensor on place with given numpy array.
Args:
lod (numpy.ndarray): The data to set.
place (CPUPlace|CUDAPlace|CUDAPinnedPlace): The place where the
LoDTensor is to be set.
zero_copy (bool, optional): Whether to share memory with the input numpy array.
This parameter only works with CPUPlace. Default: False.
Returns:
None.
......
......@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <Python.h>
#include <algorithm>
#include <memory>
......@@ -64,6 +65,31 @@ namespace pybind {
namespace details {
template <typename T>
class PYBIND11_HIDDEN NumpyAllocation : public memory::Allocation {
public:
explicit NumpyAllocation(const py::array &arr)
: Allocation(const_cast<void *>(arr.data()), sizeof(T) * (arr.size()),
paddle::platform::CPUPlace()),
arr_(arr.ptr()) {
PADDLE_ENFORCE_NOT_NULL(arr_, platform::errors::InvalidArgument(
"The underlying PyObject pointer of "
"numpy array cannot be nullptr"));
PADDLE_ENFORCE_NE(
arr_, Py_None,
platform::errors::PreconditionNotMet(
"The underlying PyObject pointer of numpy array cannot be None"));
Py_INCREF(arr_);
}
~NumpyAllocation() override {
py::gil_scoped_acquire gil;
Py_DECREF(arr_);
}
private:
PyObject *arr_;
};
template <typename T>
struct ValidDTypeToPyArrayChecker {
static constexpr bool kValue = false;
......@@ -141,19 +167,26 @@ template <typename T, typename P>
void SetTensorFromPyArrayT(
framework::Tensor *self,
const py::array_t<T, py::array::c_style | py::array::forcecast> &array,
const P &place) {
const P &place, bool zero_copy) {
std::vector<int64_t> dims;
dims.reserve(array.ndim());
for (decltype(array.ndim()) i = 0; i < array.ndim(); ++i) {
dims.push_back(static_cast<int>(array.shape()[i]));
}
self->Resize(framework::make_ddim(dims));
auto dst = self->mutable_data<T>(place);
if (paddle::platform::is_cpu_place(place)) {
std::memcpy(dst, array.data(), array.nbytes());
if (zero_copy) {
auto holder = std::make_shared<details::NumpyAllocation<T>>(array);
auto type = framework::ToDataType(std::type_index(typeid(T)));
self->ResetHolderWithType(holder, type);
} else {
auto dst = self->mutable_data<T>(place);
std::memcpy(dst, array.data(), array.nbytes());
}
} else {
#ifdef PADDLE_WITH_CUDA
auto dst = self->mutable_data<T>(place);
if (paddle::platform::is_cuda_pinned_place(place)) {
std::memcpy(dst, array.data(), array.nbytes());
} else if (paddle::platform::is_gpu_place(place)) {
......@@ -173,27 +206,29 @@ void SetTensorFromPyArrayT(
template <typename P>
void SetTensorFromPyArray(framework::Tensor *self, const py::object &obj,
const P &place) {
const P &place, bool zero_copy) {
auto array = obj.cast<py::array>();
if (py::isinstance<py::array_t<float>>(array)) {
SetTensorFromPyArrayT<float, P>(self, array, place);
SetTensorFromPyArrayT<float, P>(self, array, place, zero_copy);
} else if (py::isinstance<py::array_t<int>>(array)) {
SetTensorFromPyArrayT<int, P>(self, array, place);
SetTensorFromPyArrayT<int, P>(self, array, place, zero_copy);
} else if (py::isinstance<py::array_t<int64_t>>(array)) {
SetTensorFromPyArrayT<int64_t, P>(self, array, place);
SetTensorFromPyArrayT<int64_t, P>(self, array, place, zero_copy);
} else if (py::isinstance<py::array_t<double>>(array)) {
SetTensorFromPyArrayT<double, P>(self, array, place);
SetTensorFromPyArrayT<double, P>(self, array, place, zero_copy);
} else if (py::isinstance<py::array_t<int8_t>>(array)) {
SetTensorFromPyArrayT<int8_t, P>(self, array, place);
SetTensorFromPyArrayT<int8_t, P>(self, array, place, zero_copy);
} else if (py::isinstance<py::array_t<uint8_t>>(array)) {
SetTensorFromPyArrayT<uint8_t, P>(self, array, place);
SetTensorFromPyArrayT<uint8_t, P>(self, array, place, zero_copy);
} else if (py::isinstance<py::array_t<paddle::platform::float16>>(array)) {
SetTensorFromPyArrayT<paddle::platform::float16, P>(self, array, place);
SetTensorFromPyArrayT<paddle::platform::float16, P>(self, array, place,
zero_copy);
} else if (py::isinstance<py::array_t<uint16_t>>(array)) {
// TODO(cql): temporary keeping uint16, should be depracated later
SetTensorFromPyArrayT<paddle::platform::float16, P>(self, array, place);
SetTensorFromPyArrayT<paddle::platform::float16, P>(self, array, place,
zero_copy);
} else if (py::isinstance<py::array_t<bool>>(array)) {
SetTensorFromPyArrayT<bool, P>(self, array, place);
SetTensorFromPyArrayT<bool, P>(self, array, place, zero_copy);
} else {
PADDLE_THROW(
"Incompatible data or style type: tensor.set() supports bool, float16, "
......
......@@ -138,7 +138,6 @@ def guard(place=None):
train = framework.Program()
startup = framework.Program()
tracer = Tracer()
core._switch_tracer(tracer)
if place is None:
if core.is_compiled_with_cuda():
......@@ -173,7 +172,7 @@ def _print_debug_msg(limit=5, is_test=False):
@framework.dygraph_only
def to_variable(value, block=None, name=None):
def to_variable(value, block=None, name=None, zero_copy=None):
"""
The API will create a ``Variable`` object from numpy\.ndarray or Variable object.
......@@ -181,6 +180,7 @@ def to_variable(value, block=None, name=None):
value(ndarray): The numpy\.ndarray object that needs to be converted, it can be multi-dimension, and the data type is one of numpy\.{float16, float32, float64, int16, int32, int64, uint8, uint16}.
block(fluid.Block, optional): Which block this variable will be in. Default: None.
name(str, optional): The default value is None. Normally there is no need for user to set this property. For more information, please refer to :ref:`api_guide_Name`
zero_copy(bool, optional): Whether to share memory with the input numpy array. This parameter only works with CPUPlace and will be set to True when it is None. Default: None.
Returns:
Variable: ``Tensor`` created from the specified numpy\.ndarray object, data type and shape is the same as ``value`` .
......@@ -192,9 +192,14 @@ def to_variable(value, block=None, name=None):
import numpy as np
import paddle.fluid as fluid
with fluid.dygraph.guard():
with fluid.dygraph.guard(fluid.CPUPlace()):
x = np.ones([2, 2], np.float32)
y = fluid.dygraph.to_variable(x, zero_copy=False)
x[0][0] = -1
y[0][0].numpy() # array([1.], dtype=float32)
y = fluid.dygraph.to_variable(x)
x[0][0] = 0
y[0][0].numpy() # array([0.], dtype=float32)
"""
if isinstance(value, np.ndarray):
......@@ -212,7 +217,14 @@ def to_variable(value, block=None, name=None):
stop_gradient=True)
var = py_var._ivar.value()
tensor = var.get_tensor()
tensor.set(value, framework._current_expected_place())
if isinstance(framework._current_expected_place(),
framework.core.CPUPlace):
if zero_copy is None:
zero_copy = True
tensor.set(value, framework._current_expected_place(), zero_copy)
else:
assert not zero_copy, "zero_copy mode can only be used with CPUPlace"
tensor.set(value, framework._current_expected_place(), False)
return py_var
elif isinstance(value, framework.Variable):
return value
......
......@@ -4781,9 +4781,11 @@ def _dygraph_guard(tracer):
global _dygraph_tracer_
tmp_trace = _dygraph_tracer_
_dygraph_tracer_ = tracer
core._switch_tracer(tracer)
yield
core._switch_tracer(tmp_trace)
_dygraph_tracer_ = tmp_trace
......
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import numpy as np
import paddle.fluid as fluid
class TestImperativeNumpyBridge(unittest.TestCase):
def test_tensor_from_numpy(self):
data_np = np.array([[2, 3, 1]]).astype('float32')
with fluid.dygraph.guard(fluid.CPUPlace()):
var = fluid.dygraph.to_variable(data_np, zero_copy=True)
self.assertTrue(np.array_equal(var.numpy(), data_np))
data_np[0][0] = 4
self.assertEqual(data_np[0][0], 4)
self.assertEqual(var[0][0].numpy()[0], 4)
self.assertTrue(np.array_equal(var.numpy(), data_np))
var2 = fluid.dygraph.to_variable(data_np, zero_copy=False)
self.assertTrue(np.array_equal(var2.numpy(), data_np))
data_np[0][0] = -1
self.assertEqual(data_np[0][0], -1)
self.assertNotEqual(var2[0][0].numpy()[0], -1)
self.assertFalse(np.array_equal(var2.numpy(), data_np))
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册