未验证 提交 9c9e635c 编写于 作者: L Leo Chen 提交者: GitHub

support tensor to varbase, test=develop (#24660)

上级 fdbe114b
......@@ -101,6 +101,7 @@ static void InitTensorForVarBase(imperative::VarBase *self,
static void InitVarBaseFromNumpyWithKwargs(imperative::VarBase *self,
const py::kwargs &kwargs) {
VLOG(4) << "Init VarBase";
PADDLE_ENFORCE_EQ(
kwargs.contains("value"), true,
platform::errors::NotFound(
......@@ -126,6 +127,7 @@ static void InitVarBaseFromNumpyWithArg(imperative::VarBase *self,
bool persistable = false,
bool zero_copy = false,
std::string name = "") {
VLOG(4) << "Init VarBase";
// 0: self, 1: value, 2: place, 3: persistable, 4: zero_copy, 5: name
if (name == "") {
name = imperative::GetCurrentTracer()->GenerateUniqueName("generated_var");
......@@ -140,10 +142,31 @@ static void InitVarBaseFromNumpyWithArg(imperative::VarBase *self,
static void InitVarBaseFromNumpyWithArgDefault(imperative::VarBase *self,
const py::array &array) {
VLOG(4) << "Init VarBase";
auto place = imperative::GetCurrentTracer()->ExpectedPlace();
InitTensorForVarBase(self, array, place);
}
static void InitVarBaseFromTensorWithArgDefault(
imperative::VarBase *self, const framework::LoDTensor &tensor) {
VLOG(4) << "Init VarBase";
auto place = imperative::GetCurrentTracer()->ExpectedPlace();
new (self) imperative::VarBase(
imperative::GetCurrentTracer()->GenerateUniqueName("generated_var"));
self->SetPersistable(false);
self->SetType(framework::proto::VarType::LOD_TENSOR);
self->SetDataType(tensor.type());
auto *new_tensor = self->MutableVar()->GetMutable<framework::LoDTensor>();
// Same place,share data directly
if (place == tensor.place()) {
new_tensor->ShareDataWith(tensor);
VLOG(4) << "Same place, do ShareDataWith";
} else {
framework::TensorCopy(tensor, place, new_tensor);
VLOG(4) << "Different place, do TensorCopy";
}
}
static std::string GetTypeName(const imperative::VarBase &var) {
if (var.Type() == framework::proto::VarType::RAW) {
return "RAW";
......@@ -520,6 +543,7 @@ void BindImperative(py::module *m_ptr) {
[](imperative::VarBase &self, framework::proto::VarType::Type dtype,
const std::vector<int> &dims, const py::handle &name,
framework::proto::VarType::Type type, bool persistable) {
VLOG(4) << "Init VarBase";
std::string act_name = "";
if (!name.ptr() || name.ptr() == Py_None) {
act_name = imperative::GetCurrentTracer()->GenerateUniqueName(
......@@ -547,6 +571,7 @@ void BindImperative(py::module *m_ptr) {
py::arg("value"), py::arg("place"), py::arg("persistable") = false,
py::arg("zero_copy") = false, py::arg("name") = "")
.def("__init__", &InitVarBaseFromNumpyWithArgDefault, py::arg("value"))
.def("__init__", &InitVarBaseFromTensorWithArgDefault, py::arg("tensor"))
.def("__init__", &InitVarBaseFromNumpyWithKwargs)
.def("__getitem__",
[](std::shared_ptr<imperative::VarBase> &self, py::handle _index) {
......
......@@ -538,8 +538,8 @@ def to_variable(value, name=None, zero_copy=None):
numpy\.ndarray, Variable or ComplexVariable object.
Parameters:
value(ndarray|Variable|ComplexVariable): The numpy\.ndarray, Variable
or ComplexVariable object that needs to be converted, it can be
value(ndarray|Variable|Tensor|ComplexVariable): The numpy\.ndarray, Variable
Tensor or ComplexVariable object that needs to be converted, it can be
multi-dimension, and the data type is one of numpy\.{float16,
float32, float64, int16, int32, int64, uint8, uint16, complex64,
complex128}.
......@@ -611,6 +611,8 @@ def to_variable(value, name=None, zero_copy=None):
elif isinstance(value, (core.VarBase, framework.Variable,
framework.ComplexVariable)):
return value
elif isinstance(value, (core.Tensor, core.LoDTensor)):
return core.VarBase(value)
else:
raise TypeError(
"The type of input value is invalid, expected type is 'ndarray', "
......
......@@ -240,18 +240,22 @@ class TestImperative(unittest.TestCase):
def test_create_VarBase(self):
x = np.ones([2, 2], np.float32)
y = np.zeros([3, 3], np.float32)
t = fluid.Tensor()
t.set(x, fluid.CPUPlace())
with fluid.dygraph.guard():
tmp = fluid.core.VarBase(value=x, place=fluid.core.CPUPlace())
tmp2 = fluid.core.VarBase(y, fluid.core.CPUPlace())
tmp3 = fluid.dygraph.base.to_variable(x)
tmp4 = fluid.core.VarBase(y)
tmp5 = fluid.core.VarBase(value=x)
tmp6 = fluid.core.VarBase(t)
self.assertTrue(np.array_equal(x, tmp.numpy()))
self.assertTrue(np.array_equal(y, tmp2.numpy()))
self.assertTrue(np.array_equal(x, tmp3.numpy()))
self.assertTrue(np.array_equal(y, tmp4.numpy()))
self.assertTrue(np.array_equal(x, tmp5.numpy()))
self.assertTrue(np.array_equal(x, tmp6.numpy()))
def test_no_grad_guard(self):
data = np.array([[2, 3], [4, 5]]).astype('float32')
......@@ -384,7 +388,6 @@ class TestImperative(unittest.TestCase):
var_inp = fluid.dygraph.base.to_variable(np_inp)
var_inp.stop_gradient = False
l = MyLayer()
print(var_inp)
x = l(var_inp)[0]
self.assertIsNotNone(x)
dy_out = x.numpy()
......
......@@ -47,6 +47,13 @@ class TestVarBase(unittest.TestCase):
linear = fluid.dygraph.Linear(32, 64)
var = linear._helper.to_variable("test", name="abc")
def test_tensor_to_variable(self):
with fluid.dygraph.guard():
t = fluid.Tensor()
t.set(np.ndarray([5, 30]), fluid.CPUPlace())
var = fluid.dygraph.to_variable(t)
self.assertTrue(np.array_equal(t, var.numpy()))
def test_write_property(self):
with fluid.dygraph.guard():
var = fluid.dygraph.to_variable(self.array)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册