diff --git a/paddle/fluid/imperative/tests/test_tracer.cc b/paddle/fluid/imperative/tests/test_tracer.cc
index 5bb2cbc887da502860cfe1456cb49c548290b9da..0bdf8fa7fe63f8c54cb02c48f4af559f186765e8 100644
--- a/paddle/fluid/imperative/tests/test_tracer.cc
+++ b/paddle/fluid/imperative/tests/test_tracer.cc
@@ -280,8 +280,8 @@ TEST(test_tracer, test_unique_name_generator) {
   imperative::Tracer tracer;
   auto fc_1 = tracer.GenerateUniqueName("fc");
   auto fc_2 = tracer.GenerateUniqueName("fc");
-  ASSERT_STREQ("fc_1", fc_1.c_str());
-  ASSERT_STREQ("fc_2", fc_2.c_str());
+  ASSERT_STREQ("fc_0", fc_1.c_str());
+  ASSERT_STREQ("fc_1", fc_2.c_str());
 }
 
 TEST(test_tracer, test_current_tracer) {
diff --git a/paddle/fluid/imperative/tracer.h b/paddle/fluid/imperative/tracer.h
index 5d5988981bb37b93054863221c04d03823af919e..2bbf0caf40499e0a17803dbff21c4cd875b1d7d2 100644
--- a/paddle/fluid/imperative/tracer.h
+++ b/paddle/fluid/imperative/tracer.h
@@ -33,7 +33,7 @@ class UniqueNameGenerator {
  public:
   explicit UniqueNameGenerator(std::string prefix = "") : prefix_(prefix) {}
   std::string Generate(std::string key = "tmp") {
-    return prefix_ + key + "_" + std::to_string(++id_);
+    return prefix_ + key + "_" + std::to_string(id_++);
   }
 
  private:
diff --git a/paddle/fluid/pybind/imperative.cc b/paddle/fluid/pybind/imperative.cc
index 5a8506973dd09be58a8d3fdf285aa713a5b82496..82d6931ccdcf9596cb06b0f7d9566510118d6d9a 100644
--- a/paddle/fluid/pybind/imperative.cc
+++ b/paddle/fluid/pybind/imperative.cc
@@ -31,18 +31,13 @@ limitations under the License. */
 #include "paddle/fluid/imperative/tracer.h"
 #include "paddle/fluid/imperative/type_defs.h"
 #include "paddle/fluid/pybind/pybind_boost_headers.h"
+#include "paddle/fluid/pybind/tensor_py.h"
 
 namespace paddle {
 namespace pybind {
 
 namespace py = ::pybind11;
 
-template <typename P>
-extern void SetTensorFromPyArray(framework::Tensor *self, const py::object &obj,
-                                 const P &place, bool zero_copy);
-extern py::array TensorToPyArray(const framework::Tensor &tensor,
-                                 bool need_deep_copy = false);
-
 class Layer : public imperative::Layer {
  public:
   using imperative::Layer::Layer;  // Inherit constructors
@@ -55,45 +50,44 @@ class Layer : public imperative::Layer {
   }
 };
 
-static void InitTensorForVarBase(imperative::VarBase *self, bool persistable,
-                                 bool is_default, const py::array &array,
-                                 const py::object &obj = py::object(),
-                                 bool zero_copy = false) {
-  new (self) imperative::VarBase(
-      imperative::GetCurrentTracer()->GenerateUniqueName("generated_var_"));
-  self->SetPersistable(persistable);
+static const platform::Place PyObjectToPlace(const py::object &place_obj) {
+  if (py::isinstance<platform::CPUPlace>(place_obj)) {
+    return place_obj.cast<platform::CPUPlace>();
+  } else if (py::isinstance<platform::CUDAPlace>(place_obj)) {
+    return place_obj.cast<platform::CUDAPlace>();
+  } else if (py::isinstance<platform::CUDAPinnedPlace>(place_obj)) {
+    return place_obj.cast<platform::CUDAPinnedPlace>();
+  } else {
+    PADDLE_THROW(platform::errors::InvalidArgument(
+        "Place should be one of CPUPlace/CUDAPlace/CUDAPinnedPlace"));
+  }
+}
+
+static void InitTensorForVarBase(imperative::VarBase *self,
+                                 const py::array &array,
+                                 const platform::Place place,
+                                 bool persistable = false,
+                                 bool zero_copy = false,
+                                 std::string name = "") {
+  if (name == "") {
+    name = imperative::GetCurrentTracer()->GenerateUniqueName("generated_var");
+  }
+  new (self) imperative::VarBase(name);
   auto *tensor = self->MutableVar()->GetMutable<framework::LoDTensor>();
-  if (is_default) {
-    auto place = imperative::GetCurrentTracer()->ExpectedPlace();
-    if (platform::is_cpu_place(place)) {
-      SetTensorFromPyArray<platform::CPUPlace>(
-          tensor, array, boost::get<platform::CPUPlace>(place), zero_copy);
-    } else if (platform::is_gpu_place(place)) {
-      SetTensorFromPyArray<platform::CUDAPlace>(
-          tensor, array, boost::get<platform::CUDAPlace>(place), zero_copy);
-    } else if (platform::is_cuda_pinned_place(place)) {
-      SetTensorFromPyArray<platform::CUDAPinnedPlace>(
-          tensor, array, boost::get<platform::CUDAPinnedPlace>(place),
-          zero_copy);
-    } else {
-      PADDLE_THROW(platform::errors::InvalidArgument(
-          "Place should be one of CPUPlace/CUDAPlace/CUDAPinnedPlace"));
-    }
+  if (platform::is_cpu_place(place)) {
+    SetTensorFromPyArray<platform::CPUPlace>(
+        tensor, array, boost::get<platform::CPUPlace>(place), zero_copy);
+  } else if (platform::is_gpu_place(place)) {
+    SetTensorFromPyArray<platform::CUDAPlace>(
+        tensor, array, boost::get<platform::CUDAPlace>(place), zero_copy);
+  } else if (platform::is_cuda_pinned_place(place)) {
+    SetTensorFromPyArray<platform::CUDAPinnedPlace>(
+        tensor, array, boost::get<platform::CUDAPinnedPlace>(place), zero_copy);
   } else {
-    if (py::isinstance<platform::CPUPlace>(obj)) {
-      SetTensorFromPyArray<platform::CPUPlace>(
-          tensor, array, obj.cast<platform::CPUPlace>(), zero_copy);
-    } else if (py::isinstance<platform::CUDAPlace>(obj)) {
-      SetTensorFromPyArray<platform::CUDAPlace>(
-          tensor, array, obj.cast<platform::CUDAPlace>(), zero_copy);
-    } else if (py::isinstance<platform::CUDAPinnedPlace>(obj)) {
-      SetTensorFromPyArray<platform::CUDAPinnedPlace>(
-          tensor, array, obj.cast<platform::CUDAPinnedPlace>(), zero_copy);
-    } else {
-      PADDLE_THROW(platform::errors::InvalidArgument(
-          "Place should be one of CPUPlace/CUDAPlace/CUDAPinnedPlace"));
-    }
+    PADDLE_THROW(platform::errors::InvalidArgument(
+        "Place should be one of CPUPlace/CUDAPlace/CUDAPinnedPlace"));
   }
+  self->SetPersistable(persistable);
   self->SetType(framework::proto::VarType::LOD_TENSOR);
   self->SetDataType(tensor->type());
 }
@@ -103,28 +97,32 @@ static void InitVarBaseFromNumpyWithKwargs(imperative::VarBase *self,
   PADDLE_ENFORCE_EQ(
       kwargs.contains("value"), true,
       platform::errors::InvalidArgument("Missing argument: value"));
-  if (kwargs.contains("place")) {
-    InitTensorForVarBase(self, kwargs.contains("persistable")
-                                   ? kwargs["persistable"].cast<bool>()
-                                   : false,
-                         false, kwargs["value"].cast<py::array>(),
-                         kwargs["place"], kwargs["zero_copy"].cast<bool>());
-  } else {
-    InitTensorForVarBase(self, kwargs.contains("persistable")
-                                   ? kwargs["persistable"].cast<bool>()
-                                   : false,
-                         true, kwargs["value"].cast<py::array>(), py::object(),
-                         kwargs["zero_copy"].cast<bool>());
-  }
+
+  auto persistable = kwargs.contains("persistable")
+                         ? kwargs["persistable"].cast<bool>()
+                         : false;
+  auto array = kwargs.contains("value") ? kwargs["value"].cast<py::array>()
+                                        : py::array();
+  auto zero_copy =
+      kwargs.contains("zero_copy") ? kwargs["zero_copy"].cast<bool>() : false;
+  auto name = kwargs.contains("name") ? kwargs["name"].cast<std::string>() : "";
+  auto default_place = imperative::GetCurrentTracer()->ExpectedPlace();
+  auto place = kwargs.contains("place") ? PyObjectToPlace(kwargs["place"])
+                                        : default_place;
+  InitTensorForVarBase(self, array, place, persistable, zero_copy, name);
 }
 
 template <typename P>
 static void InitVarBaseFromNumpyWithArg(imperative::VarBase *self,
                                         const py::array &array, const P &place,
-                                        bool persistable, bool zero_copy) {
-  // 0: value, 1: place, 2: name 3: persistable, 4: zero_copy
-  new (self) imperative::VarBase(
-      imperative::GetCurrentTracer()->GenerateUniqueName("generated_var_"));
+                                        bool persistable = false,
+                                        bool zero_copy = false,
+                                        std::string name = "") {
+  // 0: self, 1: value, 2: place, 3: persistable, 4: zero_copy, 5: name
+  if (name == "") {
+    name = imperative::GetCurrentTracer()->GenerateUniqueName("generated_var");
+  }
+  new (self) imperative::VarBase(name);
   self->SetPersistable(persistable);
   auto *tensor = self->MutableVar()->GetMutable<framework::LoDTensor>();
   SetTensorFromPyArray<P>(tensor, array, place, zero_copy);
@@ -133,9 +131,9 @@ static void InitVarBaseFromNumpyWithArg(imperative::VarBase *self,
 }
 
 static void InitVarBaseFromNumpyWithArgDefault(imperative::VarBase *self,
-                                               const py::array &array,
-                                               bool persistable) {
-  InitTensorForVarBase(self, persistable, true, array);
+                                               const py::array &array) {
+  auto place = imperative::GetCurrentTracer()->ExpectedPlace();
+  InitTensorForVarBase(self, array, place);
 }
 
 static std::string GetTypeName(const imperative::VarBase &var) {
@@ -147,6 +145,7 @@ static std::string GetTypeName(const imperative::VarBase &var) {
     return framework::ToTypeName(var.Var().Type());
   }
 }
+
 using PyNameVarBaseMap = std::unordered_map<std::string, py::handle>;
 
 template <typename T>
@@ -301,15 +300,14 @@ void BindImperative(py::module *m_ptr) {
           })
      .def("__init__", &InitVarBaseFromNumpyWithArg<platform::CPUPlace>,
          py::arg("value"), py::arg("place"), py::arg("persistable") = false,
-          py::arg("zero_copy") = false)
+          py::arg("zero_copy") = false, py::arg("name") = "")
      .def("__init__", &InitVarBaseFromNumpyWithArg<platform::CUDAPlace>,
          py::arg("value"), py::arg("place"), py::arg("persistable") = false,
-          py::arg("zero_copy") = false)
+          py::arg("zero_copy") = false, py::arg("name") = "")
      .def("__init__", &InitVarBaseFromNumpyWithArg<platform::CUDAPinnedPlace>,
          py::arg("value"), py::arg("place"), py::arg("persistable") = false,
-          py::arg("zero_copy") = false)
-      .def("__init__", &InitVarBaseFromNumpyWithArgDefault, py::arg("value"),
-           py::arg("persistable") = false)
+          py::arg("zero_copy") = false, py::arg("name") = "")
+      .def("__init__", &InitVarBaseFromNumpyWithArgDefault, py::arg("value"))
      .def("__init__", &InitVarBaseFromNumpyWithKwargs)
      .def("numpy",
          [](imperative::VarBase &self) -> py::array {
@@ -384,26 +382,6 @@ void BindImperative(py::module *m_ptr) {
         y = x.detach()
 
    )DOC")
-      .def("_run_backward",
-           [](imperative::VarBase &self,
-              const imperative::detail::BackwardStrategy &bckst,
-              const imperative::Tracer &tracer) {
-             // TODO(jiabin): when we impl more backward execution we can select
-             // them
-
-             imperative::Engine *engine = tracer.GetDefaultEngine();
-             VLOG(3) << "Start backward";
-             engine->Init(&self, bckst);
-             engine->Execute();
-             VLOG(3) << "Finish backward";
-           },
-           py::call_guard<py::gil_scoped_release>())
-      .def("_grad_name", &imperative::VarBase::GradVarName)
-      .def("_grad_value",
-           [](imperative::VarBase &self) {
-             return self.MutableGradVar()->Get<framework::LoDTensor>();
-           },
-           py::return_value_policy::reference)
      .def("clear_gradient", &imperative::VarBase::ClearGradient, R"DOC(
 
     **Notes**:
@@ -437,6 +415,26 @@ void BindImperative(py::module *m_ptr) {
               loss2.clear_gradient()
               print("After clear {}".format(loss2.gradient()))
      )DOC")
+      .def("_run_backward",
+           [](imperative::VarBase &self,
+              const imperative::detail::BackwardStrategy &bckst,
+              const imperative::Tracer &tracer) {
+             // TODO(jiabin): when we impl more backward execution we can select
+             // them
+
+             imperative::Engine *engine = tracer.GetDefaultEngine();
+             VLOG(3) << "Start backward";
+             engine->Init(&self, bckst);
+             engine->Execute();
+             VLOG(3) << "Finish backward";
+           },
+           py::call_guard<py::gil_scoped_release>())
+      .def("_grad_name", &imperative::VarBase::GradVarName)
+      .def("_grad_value",
+           [](imperative::VarBase &self) {
+             return self.MutableGradVar()->Get<framework::LoDTensor>();
+           },
+           py::return_value_policy::reference)
      .def("_grad_ivar",
          [](const imperative::VarBase &self) {
            auto &grad_var = self.GradVarBase();
@@ -467,6 +465,11 @@ void BindImperative(py::module *m_ptr) {
           py::return_value_policy::reference)
      .def_property("name", &imperative::VarBase::Name,
                    &imperative::VarBase::SetName)
+      .def_property("stop_gradient",
+                    &imperative::VarBase::OverridedStopGradient,
+                    &imperative::VarBase::SetOverridedStopGradient)
+      .def_property("persistable", &imperative::VarBase::Persistable,
+                    &imperative::VarBase::SetPersistable)
      .def_property_readonly(
          "shape",
          [](imperative::VarBase &self) {
@@ -483,12 +486,7 @@ void BindImperative(py::module *m_ptr) {
            }
          })
      .def_property_readonly("type", &imperative::VarBase::Type)
-      .def_property_readonly("dtype", &imperative::VarBase::DataType)
-      .def_property("persistable", &imperative::VarBase::Persistable,
-                    &imperative::VarBase::SetPersistable)
-      .def_property("stop_gradient",
-                    &imperative::VarBase::OverridedStopGradient,
-                    &imperative::VarBase::SetOverridedStopGradient);
+      .def_property_readonly("dtype", &imperative::VarBase::DataType);
 
   py::class_<imperative::Layer, Layer> layer(m, "Layer");
   layer.def(py::init<>())
diff --git a/paddle/fluid/pybind/tensor_py.h b/paddle/fluid/pybind/tensor_py.h
index 7185cf3fc48798f562da78616a12c8f7ba145b3b..cc44ad9a2deb556d2d9e275caa029025e1c5f533 100644
--- a/paddle/fluid/pybind/tensor_py.h
+++ b/paddle/fluid/pybind/tensor_py.h
@@ -294,9 +294,9 @@ void _concatCompute(const std::vector<paddle::framework::Tensor> &ins,
   }
 }
 
-void _getSliceinfo(const framework::Tensor &self, py::object obj,
-                   const int64_t dim, int64_t *pstart, int64_t *pstop,
-                   int64_t *pstep, int64_t *pslicelength) {
+inline void _getSliceinfo(const framework::Tensor &self, py::object obj,
+                          const int64_t dim, int64_t *pstart, int64_t *pstop,
+                          int64_t *pstep, int64_t *pslicelength) {
   auto &start = *pstart;
   auto &stop = *pstop;
   auto &step = *pstep;
diff --git a/python/paddle/fluid/dygraph/base.py b/python/paddle/fluid/dygraph/base.py
index b1cbd399ab379b2c48a7fd43a6bea354e6c6af65..2a67a5ddd3f860d6702fbb3762b8badf178ceefc 100644
--- a/python/paddle/fluid/dygraph/base.py
+++ b/python/paddle/fluid/dygraph/base.py
@@ -172,6 +172,7 @@ def _print_debug_msg(limit=5, is_test=False):
     return unique_name_size, tracer_var_size, alive_cpp_var_size
 
 
+# TODO(zhiqiu): Param 'block' should be deprecated, since block is meaningless in dygraph
 @framework.dygraph_only
 def to_variable(value, block=None, name=None, zero_copy=None):
     """
@@ -215,10 +216,10 @@ def to_variable(value, block=None, name=None, zero_copy=None):
             zero_copy = False
         py_var = core.VarBase(
             value=value,
-            name=name,
-            persistable=False,
             place=framework._current_expected_place(),
-            zero_copy=zero_copy)
+            persistable=False,
+            zero_copy=zero_copy,
+            name=name if name else '')
         return py_var
     elif isinstance(value, (core.VarBase, framework.Variable)):
         return value
diff --git a/python/paddle/fluid/framework.py b/python/paddle/fluid/framework.py
index fe88bb4766f0189aeee5b651f2f8c99b45245280..f54ceb62c8ff320b5c9a98f34faf479477c166a3 100644
--- a/python/paddle/fluid/framework.py
+++ b/python/paddle/fluid/framework.py
@@ -221,6 +221,23 @@ def _current_expected_place():
     return _dygraph_current_expected_place_
 
 
+# TODO(zhiqiu): remove this function.
+def _var_base_to_np(var_base):
+    """
+    convert VarBase to numpy
+
+    Args:
+        var_base(VarBase) : the VarBase to convert
+    Returns (np.ndarray): the np.ndarray containing the value of VarBase
+    """
+
+    warnings.warn(
+        "paddle.fluid.framework._var_base_to_np is deprecated, please use var_base.numpy() instead of _var_base_to_np(var_base)."
+ ) + + return var_base.numpy() + + def _cpu_num(): if "CPU_NUM" not in os.environ.keys(): if multiprocessing.cpu_count() > 1: diff --git a/python/paddle/fluid/layer_helper_base.py b/python/paddle/fluid/layer_helper_base.py index ba528f3f48162b3322d0ef27b326bf899ab17a35..7e904ff31e247aa922fb918a6191f251ecde4342 100644 --- a/python/paddle/fluid/layer_helper_base.py +++ b/python/paddle/fluid/layer_helper_base.py @@ -73,7 +73,7 @@ class LayerHelperBase(object): ), "to_variable could only be called in dygraph mode" py_var = core.VarBase( value=value, - name=name, + name=name if name else '', persistable=False, place=_current_expected_place(), zero_copy=False) diff --git a/python/paddle/fluid/tests/unittests/test_var_base.py b/python/paddle/fluid/tests/unittests/test_var_base.py new file mode 100644 index 0000000000000000000000000000000000000000..3f2903cb5d432015f3d16ec4a20e6a7dbe351e63 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_var_base.py @@ -0,0 +1,113 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + +import unittest +from paddle.fluid.framework import default_main_program, Program, convert_np_dtype_to_dtype_, in_dygraph_mode +import paddle.fluid as fluid +import paddle.fluid.layers as layers +import paddle.fluid.core as core +import numpy as np + + +class TestVarBase(unittest.TestCase): + def setUp(self): + self.shape = [512, 1234] + self.dtype = np.float32 + self.array = np.random.uniform(0.1, 1, self.shape).astype(self.dtype) + + def test_to_variable(self): + with fluid.dygraph.guard(): + var = fluid.dygraph.to_variable(self.array, name="abc") + self.assertTrue(np.array_equal(var.numpy(), self.array)) + self.assertEqual(var.name, 'abc') + # default value + self.assertEqual(var.persistable, False) + self.assertEqual(var.stop_gradient, True) + self.assertEqual(var.shape, self.shape) + self.assertEqual(var.dtype, core.VarDesc.VarType.FP32) + self.assertEqual(var.type, core.VarDesc.VarType.LOD_TENSOR) + + def test_write_property(self): + with fluid.dygraph.guard(): + var = fluid.dygraph.to_variable(self.array) + + self.assertEqual(var.name, 'generated_var_0') + var.name = 'test' + self.assertEqual(var.name, 'test') + + self.assertEqual(var.persistable, False) + var.persistable = True + self.assertEqual(var.persistable, True) + + self.assertEqual(var.stop_gradient, True) + var.stop_gradient = False + self.assertEqual(var.stop_gradient, False) + + # test some patched methods + def test_set_value(self): + with fluid.dygraph.guard(): + var = fluid.dygraph.to_variable(self.array) + tmp1 = np.random.uniform(0.1, 1, [2, 2, 3]).astype(self.dtype) + self.assertRaises(AssertionError, var.set_value, tmp1) + + tmp2 = np.random.uniform(0.1, 1, self.shape).astype(self.dtype) + var.set_value(tmp2) + self.assertTrue(np.array_equal(var.numpy(), tmp2)) + + def test_to_string(self): + with fluid.dygraph.guard(): + var = fluid.dygraph.to_variable(self.array) + 
+            self.assertTrue(isinstance(str(var.to_string(True)), str))
+
+    def test_backward(self):
+        with fluid.dygraph.guard():
+            var = fluid.dygraph.to_variable(self.array)
+            var.stop_gradient = False
+            loss = fluid.layers.relu(var)
+            loss.backward()
+            grad_var = var._grad_ivar()
+            self.assertEqual(grad_var.shape, self.shape)
+
+    def test_gradient(self):
+        with fluid.dygraph.guard():
+            var = fluid.dygraph.to_variable(self.array)
+            var.stop_gradient = False
+            loss = fluid.layers.relu(var)
+            loss.backward()
+            grad_var = var.gradient()
+            self.assertEqual(grad_var.shape, self.array.shape)
+
+    def test_block(self):
+        with fluid.dygraph.guard():
+            var = fluid.dygraph.to_variable(self.array)
+            self.assertEqual(var.block,
+                             fluid.default_main_program().global_block())
+
+    def test_slice(self):
+        with fluid.dygraph.guard():
+            var = fluid.dygraph.to_variable(self.array)
+            self.assertTrue(np.array_equal(var[1, :].numpy(), self.array[1, :]))
+
+    def test_var_base_to_np(self):
+        with fluid.dygraph.guard():
+            var = fluid.dygraph.to_variable(self.array)
+            self.assertTrue(
+                np.array_equal(var.numpy(),
+                               fluid.framework._var_base_to_np(var)))
+
+
+if __name__ == '__main__':
+    unittest.main()
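
Usage sketch (not part of the patch): the snippet below is a minimal illustration of the user-visible behaviour this change targets, mirroring the assertions in test_var_base.py; the array contents and the name 'my_var' are only illustrative, and a Paddle build that includes this patch is assumed.

    import numpy as np
    import paddle.fluid as fluid

    with fluid.dygraph.guard():
        data = np.ones([2, 3]).astype('float32')

        # 'name' is now forwarded to the C++ VarBase constructor.
        named = fluid.dygraph.to_variable(data, name='my_var')
        print(named.name)  # 'my_var'

        # Unnamed variables get auto-generated names; with the 0-based
        # UniqueNameGenerator the first one in a fresh tracer is expected
        # to be 'generated_var_0'.
        unnamed = fluid.dygraph.to_variable(data)
        print(unnamed.name)

        # 'stop_gradient' and 'persistable' are writable properties.
        unnamed.stop_gradient = False
        unnamed.persistable = True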