diff --git a/cmake/third_party.cmake b/cmake/third_party.cmake index 7b9d9ad891f71dd5206fc1fbeb59b7bc8b52007a..5c27e5f951e0e20e55be6497446af937daede055 100644 --- a/cmake/third_party.cmake +++ b/cmake/third_party.cmake @@ -203,6 +203,7 @@ elseif(${CBLAS_PROVIDER} STREQUAL EXTERN_OPENBLAS) list(APPEND third_party_deps extern_openblas) endif() + if(WITH_MKLDNN) include(external/mkldnn) # download, build, install mkldnn list(APPEND third_party_deps extern_mkldnn) diff --git a/paddle/fluid/imperative/layer.cc b/paddle/fluid/imperative/layer.cc index 180c29c6559484889cdccec1df1fe3f42201a7b6..3cce46f346d269e1cb1e458df1b71a314ad975a8 100644 --- a/paddle/fluid/imperative/layer.cc +++ b/paddle/fluid/imperative/layer.cc @@ -236,11 +236,13 @@ std::shared_ptr VarBase::NewVarBase(const platform::Place& dst_place, // TODO(Jiabin): change this after move unique_name generator to CXX auto new_var = std::make_shared( - false, "Itmp" + std::to_string(copied_counter_++)); + true, Name() + std::to_string(copied_counter_++)); auto* dst_tensor = new_var->var_.GetMutable(); dst_tensor->set_lod(src_tensor.lod()); - + new_var->SetPersistable(Persistable()); + new_var->SetDataType(DataType()); + new_var->SetType(Type()); framework::TensorCopy(src_tensor, dst_place, dst_tensor); if (blocking) { platform::DeviceContextPool::Instance().Get(dst_place)->Wait(); @@ -253,7 +255,6 @@ std::shared_ptr VarBase::NewVarBase(const platform::Place& dst_place, if (platform::is_gpu_place(dst_place)) { VLOG(3) << "copy tensor " << Name() << " from gpu"; } - return new_var; } else { auto& src_selected_rows = var_.Get(); diff --git a/paddle/fluid/imperative/tests/test_layer.cc b/paddle/fluid/imperative/tests/test_layer.cc index e793f3da64654ca875deb594ad6c6a3c14929a0d..f220e64811fe8fdea2c7a0d858f9b0cfbc460a67 100644 --- a/paddle/fluid/imperative/tests/test_layer.cc +++ b/paddle/fluid/imperative/tests/test_layer.cc @@ -158,7 +158,7 @@ TEST(test_layer, test_varbase_basic) { vin->MutableVar()->GetMutable()->mutable_data( place); std::shared_ptr vout(vin->NewVarBase(place, false)); - ASSERT_EQ(vout->Name(), "Itmp0"); + ASSERT_EQ(vout->Name(), "vin0"); std::shared_ptr vin_with_grad( new imperative::VarBase(true, "vin")); diff --git a/paddle/fluid/pybind/imperative.cc b/paddle/fluid/pybind/imperative.cc index b9e876b9b8a0d4bf7c549cd565b1829ca7ce19a6..5a8506973dd09be58a8d3fdf285aa713a5b82496 100644 --- a/paddle/fluid/pybind/imperative.cc +++ b/paddle/fluid/pybind/imperative.cc @@ -30,7 +30,6 @@ limitations under the License. */ #include "paddle/fluid/imperative/profiler.h" #include "paddle/fluid/imperative/tracer.h" #include "paddle/fluid/imperative/type_defs.h" - #include "paddle/fluid/pybind/pybind_boost_headers.h" namespace paddle { @@ -38,6 +37,12 @@ namespace pybind { namespace py = ::pybind11; +template +extern void SetTensorFromPyArray(framework::Tensor *self, const py::object &obj, + const P &place, bool zero_copy); +extern py::array TensorToPyArray(const framework::Tensor &tensor, + bool need_deep_copy = false); + class Layer : public imperative::Layer { public: using imperative::Layer::Layer; // Inherit constructors @@ -50,42 +55,99 @@ class Layer : public imperative::Layer { } }; -// warper for pyobject to avoid imperative module depend on python -// TODO(jiabin) Add OpBase's pybind interface back to enable backward hook -class PYBIND11_HIDDEN PyCallableObject { - public: - PyCallableObject(std::shared_ptr py_obj_ptr) - : py_obj_ptr_(std::move(py_obj_ptr)) {} - ~PyCallableObject() { - py::call_guard(); - py_obj_ptr_.reset(); +static void InitTensorForVarBase(imperative::VarBase *self, bool persistable, + bool is_default, const py::array &array, + const py::object &obj = py::object(), + bool zero_copy = false) { + new (self) imperative::VarBase( + imperative::GetCurrentTracer()->GenerateUniqueName("generated_var_")); + self->SetPersistable(persistable); + auto *tensor = self->MutableVar()->GetMutable(); + if (is_default) { + auto place = imperative::GetCurrentTracer()->ExpectedPlace(); + if (platform::is_cpu_place(place)) { + SetTensorFromPyArray( + tensor, array, boost::get(place), zero_copy); + } else if (platform::is_gpu_place(place)) { + SetTensorFromPyArray( + tensor, array, boost::get(place), zero_copy); + } else if (platform::is_cuda_pinned_place(place)) { + SetTensorFromPyArray( + tensor, array, boost::get(place), + zero_copy); + } else { + PADDLE_THROW(platform::errors::InvalidArgument( + "Place should be one of CPUPlace/CUDAPlace/CUDAPinnedPlace")); + } + } else { + if (py::isinstance(obj)) { + SetTensorFromPyArray( + tensor, array, obj.cast(), zero_copy); + } else if (py::isinstance(obj)) { + SetTensorFromPyArray( + tensor, array, obj.cast(), zero_copy); + } else if (py::isinstance(obj)) { + SetTensorFromPyArray( + tensor, array, obj.cast(), zero_copy); + } else { + PADDLE_THROW(platform::errors::InvalidArgument( + "Place should be one of CPUPlace/CUDAPlace/CUDAPinnedPlace")); + } } - void operator()() { - py::call_guard(); - py_obj_ptr_->operator()(this); + self->SetType(framework::proto::VarType::LOD_TENSOR); + self->SetDataType(tensor->type()); +} + +static void InitVarBaseFromNumpyWithKwargs(imperative::VarBase *self, + const py::kwargs &kwargs) { + PADDLE_ENFORCE_EQ( + kwargs.contains("value"), true, + platform::errors::InvalidArgument("Missing argument: value")); + if (kwargs.contains("place")) { + InitTensorForVarBase(self, kwargs.contains("persistable") + ? kwargs["persistable"].cast() + : false, + false, kwargs["value"].cast(), + kwargs["place"], kwargs["zero_copy"].cast()); + } else { + InitTensorForVarBase(self, kwargs.contains("persistable") + ? kwargs["persistable"].cast() + : false, + true, kwargs["value"].cast(), py::object(), + kwargs["zero_copy"].cast()); } +} - private: - std::shared_ptr py_obj_ptr_; -}; +template +static void InitVarBaseFromNumpyWithArg(imperative::VarBase *self, + const py::array &array, const P &place, + bool persistable, bool zero_copy) { + // 0: value, 1: place, 2: name 3: persistable, 4: zero_copy + new (self) imperative::VarBase( + imperative::GetCurrentTracer()->GenerateUniqueName("generated_var_")); + self->SetPersistable(persistable); + auto *tensor = self->MutableVar()->GetMutable(); + SetTensorFromPyArray

(tensor, array, place, zero_copy); + self->SetType(framework::proto::VarType::LOD_TENSOR); + self->SetDataType(tensor->type()); +} + +static void InitVarBaseFromNumpyWithArgDefault(imperative::VarBase *self, + const py::array &array, + bool persistable) { + InitTensorForVarBase(self, persistable, true, array); +} -// Function like obj.attr_name in Python. -static PyObject *GetPythonAttribute(PyObject *obj, const char *attr_name) { - // NOTE(zjl): PyObject_GetAttrString would return nullptr when attr_name - // is not inside obj, but it would also set the error flag of Python. - // If the error flag is set in C++, C++ code would not raise Exception, - // but Python would raise Exception once C++ call ends. - // To avoid unexpected Exception raised in Python, we check whether - // attribute exists before calling PyObject_GetAttrString. - // - // Caution: PyObject_GetAttrString would increase reference count of PyObject. - // Developer should call Py_DECREF manually after the attribute is not used. - if (PyObject_HasAttrString(obj, attr_name)) { - return PyObject_GetAttrString(obj, attr_name); +static std::string GetTypeName(const imperative::VarBase &var) { + if (var.Type() == framework::proto::VarType::RAW) { + return "RAW"; + } else if (!var.Var().IsInitialized()) { + return "nullptr"; } else { - return nullptr; + return framework::ToTypeName(var.Var().Type()); } } +using PyNameVarBaseMap = std::unordered_map; template static T PyObjectCast(PyObject *obj) { @@ -106,48 +168,36 @@ GetVarBaseListFromPyHandle(const py::handle &handle) { return {}; } - const char *kIVarField = "_ivar"; - PyObject *py_ivar = GetPythonAttribute(py_obj, kIVarField); std::vector> result; - if (py_ivar) { // Variable - result.emplace_back( - PyObjectCast>(py_ivar)); - Py_DECREF(py_ivar); - } else if (PyList_Check(py_obj)) { // List of Variable + if (PyList_Check(py_obj)) { // List of VarBase size_t len = PyList_GET_SIZE(py_obj); result.reserve(len); for (size_t i = 0; i < len; ++i) { - PyObject *py_ivar = - PyObject_GetAttrString(PyList_GET_ITEM(py_obj, i), kIVarField); - PADDLE_ENFORCE_NOT_NULL(py_ivar); + PyObject *py_ivar = PyList_GET_ITEM(py_obj, i); + PADDLE_ENFORCE_NOT_NULL( + py_ivar, platform::errors::InvalidArgument("Python Object is NULL")); result.emplace_back( PyObjectCast>(py_ivar)); - Py_DECREF(py_ivar); } - } else if (PyTuple_Check(py_obj)) { // Tuple of Variable + } else if (PyTuple_Check(py_obj)) { // Tuple of VarBase size_t len = PyTuple_GET_SIZE(py_obj); result.reserve(len); for (size_t i = 0; i < len; ++i) { - PyObject *py_ivar = - PyObject_GetAttrString(PyTuple_GET_ITEM(py_obj, i), kIVarField); - PADDLE_ENFORCE_NOT_NULL(py_ivar); + PyObject *py_ivar = PyTuple_GET_ITEM(py_obj, i); + PADDLE_ENFORCE_NOT_NULL( + py_ivar, platform::errors::InvalidArgument("Python Object is NULL")); result.emplace_back( PyObjectCast>(py_ivar)); - Py_DECREF(py_ivar); } - } else { - PADDLE_THROW( - "unsupported type %s, must be Variable, list[Variable] or " - "tuple[Variable]", - py::str(handle)); + } else { // VarBase + result.emplace_back( + PyObjectCast>(py_obj)); } return result; } -using PyNameVarBaseMap = std::unordered_map; - static imperative::NameVarBaseMap ConvertToNameVarBaseMap( const PyNameVarBaseMap &map) { imperative::NameVarBaseMap result; @@ -163,16 +213,6 @@ static imperative::NameVarBaseMap ConvertToNameVarBaseMap( return result; } -static std::string GetTypeName(const imperative::VarBase &var) { - if (var.Type() == framework::proto::VarType::RAW) { - return "RAW"; - } else if (!var.Var().IsInitialized()) { - return "nullptr"; - } else { - return framework::ToTypeName(var.Var().Type()); - } -} - // Bind Methods void BindImperative(py::module *m_ptr) { auto &m = *m_ptr; @@ -239,11 +279,17 @@ void BindImperative(py::module *m_ptr) { R"DOC()DOC") .def_static("_alive_vars", &imperative::VarBase::AliveVarNames) .def("__init__", - [](imperative::VarBase &self, const std::string &name, - framework::proto::VarType::Type type, - framework::proto::VarType::Type dtype, - const std::vector &dims, bool persistable) { - new (&self) imperative::VarBase(name); + [](imperative::VarBase &self, framework::proto::VarType::Type dtype, + const std::vector &dims, const py::handle &name, + framework::proto::VarType::Type type, bool persistable) { + std::string act_name = ""; + if (!name.ptr() || name.ptr() == Py_None) { + act_name = imperative::GetCurrentTracer()->GenerateUniqueName( + "generated_var"); + } else { + act_name = name.cast(); + } + new (&self) imperative::VarBase(act_name); self.SetPersistable(persistable); self.SetType(type); self.SetDataType(dtype); @@ -253,6 +299,91 @@ void BindImperative(py::module *m_ptr) { tensor->Resize(framework::make_ddim(dims)); } }) + .def("__init__", &InitVarBaseFromNumpyWithArg, + py::arg("value"), py::arg("place"), py::arg("persistable") = false, + py::arg("zero_copy") = false) + .def("__init__", &InitVarBaseFromNumpyWithArg, + py::arg("value"), py::arg("place"), py::arg("persistable") = false, + py::arg("zero_copy") = false) + .def("__init__", &InitVarBaseFromNumpyWithArg, + py::arg("value"), py::arg("place"), py::arg("persistable") = false, + py::arg("zero_copy") = false) + .def("__init__", &InitVarBaseFromNumpyWithArgDefault, py::arg("value"), + py::arg("persistable") = false) + .def("__init__", &InitVarBaseFromNumpyWithKwargs) + .def("numpy", + [](imperative::VarBase &self) -> py::array { + const auto &tensor = + self.MutableVar()->Get(); + PADDLE_ENFORCE_EQ( + tensor.IsInitialized(), true, + platform::errors::InvalidArgument( + "%s is Empty, Please check if it has no data in", + self.Name())); + return TensorToPyArray(tensor, true); + }, + R"DOC( + **Notes**: + **This API is ONLY avaliable in Dygraph mode** + + Returns a numpy array shows the value of current :ref:`api_guide_Variable_en` + + Returns: + ndarray: The numpy value of current Variable. + + Returns type: + ndarray: dtype is same as current Variable + + Examples: + .. code-block:: python + + import paddle.fluid as fluid + from paddle.fluid.dygraph.base import to_variable + from paddle.fluid.dygraph import FC + import numpy as np + + data = np.random.uniform(-1, 1, [30, 10, 32]).astype('float32') + with fluid.dygraph.guard(): + fc = FC("fc", 64, num_flatten_dims=2) + data = to_variable(data) + x = fc(data) + print(x.numpy()) + + )DOC") + .def("detach", + [](const imperative::VarBase &self) { + const auto &tensor = self.Var().Get(); + PADDLE_ENFORCE_EQ(tensor.IsInitialized(), true, + platform::errors::InvalidArgument( + "%s has not been initialized", self.Name())); + return self.NewVarBase(tensor.place(), false); + }, + py::return_value_policy::copy, R"DOC( + **Notes**: + **This API is ONLY avaliable in Dygraph mode** + + Returns a new Variable, detached from the current graph. + + Returns: + ( :ref:`api_guide_Variable_en` | dtype is same as current Variable): The detached Variable. + + + Examples: + .. code-block:: python + + import paddle.fluid as fluid + from paddle.fluid.dygraph.base import to_variable + from paddle.fluid.dygraph import FC + import numpy as np + + data = np.random.uniform(-1, 1, [30, 10, 32]).astype('float32') + with fluid.dygraph.guard(): + fc = FC("fc", 64, num_flatten_dims=2) + data = to_variable(data) + x = fc(data) + y = x.detach() + + )DOC") .def("_run_backward", [](imperative::VarBase &self, const imperative::detail::BackwardStrategy &bckst, @@ -273,7 +404,39 @@ void BindImperative(py::module *m_ptr) { return self.MutableGradVar()->Get(); }, py::return_value_policy::reference) - .def("_clear_gradient", &imperative::VarBase::ClearGradient) + .def("clear_gradient", &imperative::VarBase::ClearGradient, R"DOC( + + **Notes**: + **1. This API is ONLY avaliable in Dygraph mode** + + **2. Use it only Variable has gradient, normally we use this for Parameters since other temporal Variable will be deleted by Python's GC** + + Clear (set to ``0`` ) the Gradient of Current Variable + + Returns: None + + Examples: + .. code-block:: python + + import paddle.fluid as fluid + import numpy as np + + x = np.ones([2, 2], np.float32) + with fluid.dygraph.guard(): + inputs2 = [] + for _ in range(10): + tmp = fluid.dygraph.base.to_variable(x) + tmp.stop_gradient=False + inputs2.append(tmp) + ret2 = fluid.layers.sums(inputs2) + loss2 = fluid.layers.reduce_sum(ret2) + backward_strategy = fluid.dygraph.BackwardStrategy() + backward_strategy.sort_sum_gradient = True + loss2.backward(backward_strategy) + print(loss2.gradient()) + loss2.clear_gradient() + print("After clear {}".format(loss2.gradient())) + )DOC") .def("_grad_ivar", [](const imperative::VarBase &self) { auto &grad_var = self.GradVarBase(); diff --git a/paddle/fluid/pybind/pybind.cc b/paddle/fluid/pybind/pybind.cc index 0abaf2d5f0562fea253085ac243d99018589b442..2d2438e7cdb1305634578644cb8c9936a4b5908d 100644 --- a/paddle/fluid/pybind/pybind.cc +++ b/paddle/fluid/pybind/pybind.cc @@ -194,14 +194,8 @@ static std::vector> GetVarBaseList( if (!py_obj || py_obj == Py_None) { PADDLE_THROW("Save parameter [%s] is None", para.first); } - - const char *kIVarField = "_ivar"; - PyObject *py_ivar = GetPythonAttribute(py_obj, kIVarField); - PADDLE_ENFORCE_NOT_NULL(py_ivar, "Can not find ivar in Variable"); - vec_res.emplace_back( - PyObjectCast>(py_ivar)); - Py_DECREF(py_ivar); + PyObjectCast>(py_obj)); } return vec_res; diff --git a/paddle/fluid/pybind/tensor_py.h b/paddle/fluid/pybind/tensor_py.h index 0e04b20a10edc9376ed1837a6b2b651ee4921076..7185cf3fc48798f562da78616a12c8f7ba145b3b 100644 --- a/paddle/fluid/pybind/tensor_py.h +++ b/paddle/fluid/pybind/tensor_py.h @@ -486,7 +486,8 @@ inline framework::Tensor *PySliceTensor(const framework::Tensor &self, } } -inline py::array TensorToPyArray(const framework::Tensor &tensor) { +inline py::array TensorToPyArray(const framework::Tensor &tensor, + bool need_deep_copy = false) { if (!tensor.IsInitialized()) { return py::array(); } @@ -510,9 +511,26 @@ inline py::array TensorToPyArray(const framework::Tensor &tensor) { std::string py_dtype_str = details::TensorDTypeToPyDTypeStr(tensor.type()); if (!is_gpu_tensor) { - return py::array(py::buffer_info( - const_cast(tensor_buf_ptr), sizeof_dtype, py_dtype_str, - static_cast(tensor.dims().size()), py_dims, py_strides)); + if (!need_deep_copy) { + return py::array(py::buffer_info( + const_cast(tensor_buf_ptr), sizeof_dtype, py_dtype_str, + static_cast(tensor.dims().size()), py_dims, py_strides)); + } else { + py::array py_arr(py::dtype(py_dtype_str.c_str()), py_dims, py_strides); + PADDLE_ENFORCE_EQ(py_arr.writeable(), true, + platform::errors::InvalidArgument( + "PyArray must be writable, otherwise memory leak " + "or double free would occur")); + PADDLE_ENFORCE_EQ(py_arr.owndata(), true, + platform::errors::InvalidArgument( + "PyArray must own data, otherwise memory leak " + "or double free would occur")); + platform::CPUPlace place; + size_t copy_bytes = sizeof_dtype * numel; + paddle::memory::Copy(place, py_arr.mutable_data(), place, tensor_buf_ptr, + copy_bytes); + return py_arr; + } } #ifdef PADDLE_WITH_CUDA diff --git a/python/paddle/fluid/__init__.py b/python/paddle/fluid/__init__.py index 4fae1d9a2a56018bfc878f11e7ddbe94c60aee67..71c47eb4e0857dd234efc55b3be6c0c30803fb7d 100644 --- a/python/paddle/fluid/__init__.py +++ b/python/paddle/fluid/__init__.py @@ -88,13 +88,13 @@ from .dygraph.nn import * from .dygraph.layers import * from .io import save, load, load_program_state, set_program_state from .dygraph.checkpoint import save_dygraph, load_dygraph - +from .dygraph.varbase_patch_methods import monkey_patch_varbase Tensor = LoDTensor __all__ = framework.__all__ + executor.__all__ + \ trainer_desc.__all__ + inferencer.__all__ + transpiler.__all__ + \ parallel_executor.__all__ + lod_tensor.__all__ + \ - data_feed_desc.__all__ + compiler.__all__ + backward.__all__ + [ + data_feed_desc.__all__ + compiler.__all__ + backward.__all__ + [ 'io', 'initializer', 'embedding', @@ -126,6 +126,7 @@ __all__ = framework.__all__ + executor.__all__ + \ 'install_check', 'save', 'load', + 'VarBase' ] @@ -234,3 +235,4 @@ def __bootstrap__(): # Consider paddle.init(args) or paddle.main(args) monkey_patch_variable() __bootstrap__() +monkey_patch_varbase() diff --git a/python/paddle/fluid/dygraph/base.py b/python/paddle/fluid/dygraph/base.py index ed9e9abb28f58faeb2c80335610359bc7821b02b..b1cbd399ab379b2c48a7fd43a6bea354e6c6af65 100644 --- a/python/paddle/fluid/dygraph/base.py +++ b/python/paddle/fluid/dygraph/base.py @@ -138,6 +138,7 @@ def guard(place=None): train = framework.Program() startup = framework.Program() tracer = Tracer() + VarBase = core.VarBase if place is None: if core.is_compiled_with_cuda(): @@ -205,28 +206,21 @@ def to_variable(value, block=None, name=None, zero_copy=None): if isinstance(value, np.ndarray): assert framework.in_dygraph_mode( ), "to_variable could only be called in dygraph mode" - - if not block: - block = framework.default_main_program().current_block() - py_var = framework.Variable( - block, - type=core.VarDesc.VarType.LOD_TENSOR, - name=name, - shape=value.shape, - dtype=value.dtype, - stop_gradient=True) - var = py_var._ivar.value() - tensor = var.get_tensor() if isinstance(framework._current_expected_place(), framework.core.CPUPlace): if zero_copy is None: zero_copy = True - tensor.set(value, framework._current_expected_place(), zero_copy) else: assert not zero_copy, "zero_copy mode can only be used with CPUPlace" - tensor.set(value, framework._current_expected_place(), False) + zero_copy = False + py_var = core.VarBase( + value=value, + name=name, + persistable=False, + place=framework._current_expected_place(), + zero_copy=zero_copy) return py_var - elif isinstance(value, framework.Variable): + elif isinstance(value, (core.VarBase, framework.Variable)): return value else: raise TypeError( diff --git a/python/paddle/fluid/dygraph/jit.py b/python/paddle/fluid/dygraph/jit.py index 4ea82e2ca0c4b241f6c4eb6bba3214f44b106588..a64620c6493dd5caa199c94456432ed82e58ef8f 100644 --- a/python/paddle/fluid/dygraph/jit.py +++ b/python/paddle/fluid/dygraph/jit.py @@ -33,7 +33,7 @@ def create_program_from_desc(program_desc): def _extract_vars(inputs, result_list): if isinstance(inputs, Variable): - result_list.append(inputs._ivar) + result_list.append(inputs) if isinstance(inputs, (list, tuple)): for var in inputs: @@ -67,7 +67,7 @@ def _trace(layer, outputs = [original_outputs] else: outputs = original_outputs - out_vars = [var._ivar for var in outputs] + out_vars = [var for var in outputs] program_desc, feed_names, fetch_names = tracer.create_program_desc( var_list, feed_prefix, out_vars, fetch_prefix, tmp_prefix) @@ -104,7 +104,7 @@ class TracedLayer(object): self._scope = core.Scope() for p in parameters: - src_tensor = p._ivar.value().get_tensor() + src_tensor = p.value().get_tensor() dst_tensor = self._scope.var(p.name).get_tensor() dst_tensor._share_data_with(src_tensor) @@ -234,7 +234,7 @@ class TracedLayer(object): feed_dict = {} if in_dygraph_mode(): for x, name in zip(inputs, self._feed_names): - feed_dict[name] = x._ivar.value().get_tensor() + feed_dict[name] = x.value().get_tensor() else: for x, name in zip(inputs, self._feed_names): feed_dict[name] = x diff --git a/python/paddle/fluid/dygraph/layers.py b/python/paddle/fluid/dygraph/layers.py index 8d61e5d5ea377b989b0e977ba64e95e353c378b7..728655593088dd869b874467939c3235b155a5d5 100644 --- a/python/paddle/fluid/dygraph/layers.py +++ b/python/paddle/fluid/dygraph/layers.py @@ -25,7 +25,6 @@ from .layer_object_helper import LayerObjectHelper from .base import program_desc_tracing_guard from paddle.fluid import framework from ..param_attr import ParamAttr -from paddle.fluid.framework import Variable __all__ = ['Layer'] diff --git a/python/paddle/fluid/dygraph/parallel.py b/python/paddle/fluid/dygraph/parallel.py index 9ed278bf517390dc901c8d2ee33790aceec3a4be..e8746a860c56312b24cb578cc508add931d64361 100644 --- a/python/paddle/fluid/dygraph/parallel.py +++ b/python/paddle/fluid/dygraph/parallel.py @@ -219,12 +219,8 @@ class DataParallel(layers.Layer): grad_vars = [] for param in self._layers.parameters(): # NOTE(zcd): The grad_ivar maybe no generated. - if param.trainable and param._ivar._grad_ivar(): - g_var = framework.Variable( - block=self._helper.main_program.current_block(), - name=param._ivar._grad_name(), - stop_gradient=True, - ivar=param._ivar._grad_ivar()) + if param.trainable and param._grad_ivar(): + g_var = param._grad_ivar() grad_vars.append(g_var) assert g_var not in grad_var_set grad_var_set.add(g_var) diff --git a/python/paddle/fluid/dygraph/varbase_patch_methods.py b/python/paddle/fluid/dygraph/varbase_patch_methods.py new file mode 100644 index 0000000000000000000000000000000000000000..118afd93e4c8d34aa40e1ae95a4e9b8020f73b49 --- /dev/null +++ b/python/paddle/fluid/dygraph/varbase_patch_methods.py @@ -0,0 +1,216 @@ +# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .. import framework +from .. import core +from . import BackwardStrategy +from ..framework import Variable, _getitem_impl_ +from .. import unique_name +import numpy as np + + +def monkey_patch_varbase(): + # TODO(jiabin): move this to cplusplus end if we find some performance issue on it + @framework.dygraph_only + def set_value(self, value): + """ + **Notes**: + **This API is ONLY avaliable in Dygraph mode** + + Set a new value for this Variable. + + Args: + value (Variable|np.ndarray): the new value. + + Examples: + .. code-block:: python + + import paddle.fluid as fluid + from paddle.fluid.dygraph.base import to_variable + from paddle.fluid.dygraph import FC + import numpy as np + + data = np.ones([3, 32, 32], dtype='float32') + with fluid.dygraph.guard(): + fc = fluid.dygraph.FC("fc", 4) + t = to_variable(data) + fc(t) # call with default weight + custom_weight = np.random.randn(1024, 4).astype("float32") + fc.weight.set_value(custom_weight) # change existing weight + out = fc(t) # call with different weight + + """ + assert isinstance(value, (np.ndarray, core.VarBase)), \ + "Variable set_value function, arguments type only support Variable, numpy, VarBase" + + value_np = value + if isinstance(value, core.VarBase): + value_np = value.numpy() + + self_tensor_np = self.numpy() + + assert self_tensor_np.shape == value_np.shape, \ + "Variable Shape not match, Variable [ {} ] need tensor with shape {} but load set tensor with shape {}".format( + self.name, self_tensor_np.shape, value_np.shape) + + assert self_tensor_np.dtype == value_np.dtype, \ + "Variable dtype not match, Variable [ {} ] need tensor with dtype {} but load tensor with dtype {}".format( + self.name, self_tensor_np.dtype, value_np.dtype) + + self.value().get_tensor().set(value_np, + framework._current_expected_place()) + + @framework.dygraph_only + def backward(self, backward_strategy=None): + """ + **Notes**: + **This API is ONLY avaliable in Dygraph mode** + + Run backward of current Graph which starts from current Variable + + Args: + backward_strategy( :ref:`api_fluid_dygraph_BackwardStrategy` ): The Backward Strategy to run backward + + Returns: + NoneType: None + + Examples: + .. code-block:: python + + import paddle.fluid as fluid + import numpy as np + + x = np.ones([2, 2], np.float32) + with fluid.dygraph.guard(): + inputs2 = [] + for _ in range(10): + tmp = fluid.dygraph.base.to_variable(x) + # if we don't set tmp's stop_gradient as False then, all path to loss will has no gradient since + # there is no one need gradient on it. + tmp.stop_gradient=False + inputs2.append(tmp) + ret2 = fluid.layers.sums(inputs2) + loss2 = fluid.layers.reduce_sum(ret2) + backward_strategy = fluid.dygraph.BackwardStrategy() + backward_strategy.sort_sum_gradient = True + loss2.backward(backward_strategy) + + """ + if framework.in_dygraph_mode(): + if backward_strategy is None: + backward_strategy = BackwardStrategy() + backward_strategy.sort_sum_gradient = False + + self._run_backward(backward_strategy, framework._dygraph_tracer()) + else: + raise ValueError( + "Variable.backward() is only avaliable in DyGraph mode") + + @framework.dygraph_only + def gradient(self): + """ + **Notes**: + **This API is ONLY avaliable in Dygraph mode** + + Get the Gradient of Current Variable + + Returns: + ndarray: Numpy value of the gradient of current Variable + + Examples: + .. code-block:: python + + import paddle.fluid as fluid + import numpy as np + + x = np.ones([2, 2], np.float32) + with fluid.dygraph.guard(): + inputs2 = [] + for _ in range(10): + tmp = fluid.dygraph.base.to_variable(x) + tmp.stop_gradient=False + inputs2.append(tmp) + ret2 = fluid.layers.sums(inputs2) + loss2 = fluid.layers.reduce_sum(ret2) + backward_strategy = fluid.dygraph.BackwardStrategy() + backward_strategy.sort_sum_gradient = True + loss2.backward(backward_strategy) + print(loss2.gradient()) + + """ + if self._grad_ivar() is None: + raise ValueError( + "%s has no grad, Please set Variable.stop_gradient=False, or " + "check if this is the first and only variable need grad, if so, please set its pre-Variable's " + "stop_gradient=False, to make sure it has gradient " % + self.name) + new_ivar = self._grad_ivar()._copy_to(core.CPUPlace(), True) + if self._grad_ivar().type == core.VarDesc.VarType.SELECTED_ROWS: + return (np.array(new_ivar.value().get_selected_rows().get_tensor()), + np.array(new_ivar.value().get_selected_rows().rows())) + else: + return np.array(new_ivar.value().get_tensor()) + + def __str__(self): + return self.to_string(True) + + @property + def block(self): + return framework.default_main_program().global_block() + + def to_string(self, throw_on_error, with_details=False): + """ + Get debug string. + + Args: + + throw_on_error (bool): True if raise an exception when self is not initialized. + + with_details (bool): more details about variables and parameters (e.g. trainable, optimize_attr, ...) will be printed when with_details is True. Default value is False; + + Returns: + str: The debug string. + + Examples: + .. code-block:: python + + import paddle.fluid as fluid + + cur_program = fluid.Program() + cur_block = cur_program.current_block() + new_variable = cur_block.create_var(name="X", + shape=[-1, 23, 48], + dtype='float32') + print(new_variable.to_string(True)) + print("=============with detail===============") + print(new_variable.to_string(True, True)) + """ + if framework.in_dygraph_mode(): + # TODO(panyx0718): add more dygraph debug info. + tensor = self.value().get_tensor() + if tensor._is_initialized(): + return 'name %s, dtype: %s shape: %s %s' % ( + self.name, self.dtype, self.shape, str(tensor)) + else: + return 'name %s, shape: %s, not inited' % (self.name, + self.shape) + + def __getitem__(self, item): + return _getitem_impl_(self, item) + + for method_name, method in (("set_value", set_value), ("block", block), + ("backward", backward), ("gradient", gradient), + ("__str__", __str__), ("to_string", to_string), + ("__getitem__", __getitem__)): + setattr(core.VarBase, method_name, method) diff --git a/python/paddle/fluid/dygraph_grad_clip.py b/python/paddle/fluid/dygraph_grad_clip.py index 4fdfc0bc9ded771f695923a7d3e33ca8eb94a1b7..ad052694648beacff0dcdb5ea7ff82ba618d9435 100644 --- a/python/paddle/fluid/dygraph_grad_clip.py +++ b/python/paddle/fluid/dygraph_grad_clip.py @@ -264,7 +264,7 @@ class GradClipByGlobalNorm(GradClipBase): if g is None: continue merge_grad = g - if g._ivar.type == core.VarDesc.VarType.SELECTED_ROWS: + if g.type == core.VarDesc.VarType.SELECTED_ROWS: merge_grad = layers.merge_selected_rows(g) merge_grad = layers.get_tensor_from_selected_rows(merge_grad) power = layers.square(merge_grad) diff --git a/python/paddle/fluid/framework.py b/python/paddle/fluid/framework.py index 4a840c04aa3e5e4fd39ac580837a0e0b7e2c8519..fe88bb4766f0189aeee5b651f2f8c99b45245280 100644 --- a/python/paddle/fluid/framework.py +++ b/python/paddle/fluid/framework.py @@ -260,19 +260,6 @@ def is_compiled_with_cuda(): return core.is_compiled_with_cuda() -def _var_base_to_np(var_base): - """ - convert VarBase tp numpy - - Args: - var_base(VarBase) : the VarBase to convert - Returns (np.ndarray): the np.ndarray contain the value of VarBase - - """ - var = var_base._copy_to(core.CPUPlace(), True) - return np.array(var.value().get_tensor()) - - def cuda_places(device_ids=None): """ **Note**: @@ -558,6 +545,241 @@ def _debug_string_(proto, throw_on_error=True): return proto.__str__() +def _varbase_creator(type=core.VarDesc.VarType.LOD_TENSOR, + name=None, + shape=None, + dtype=None, + persistable=None, + **kwargs): + + if dtype is not None: + if not isinstance(dtype, core.VarDesc.VarType): + dtype = convert_np_dtype_to_dtype_(dtype) + + return core.VarBase(dtype if dtype else core.VarDesc.VarType.FP32, + list(shape) if shape else [], name, type + if type else core.VarDesc.VarType.LOD_TENSOR, True + if persistable else False) + + +class VariableMetaClass(type): + @classmethod + def __instancecheck__(cls, instance): + t = type(instance) + if in_dygraph_mode(): + + return issubclass(t, core.VarBase) + else: + return issubclass(t, Variable) + + +class ParameterMetaClass(VariableMetaClass): + @classmethod + def __instancecheck__(cls, instance): + t = type(instance) + if in_dygraph_mode(): + return issubclass(t, ParamBase) + else: + return issubclass(t, Parameter) + + +def _getitem_impl_(var, item): + """ + Slice the variable. + + Args: + item(int/slice/tuple) : the index. + + Returns: + Sliced variable + """ + + if not isinstance(item, tuple): + item = [item] + + decrease_axis = [] + slice_axis = [] + slice_start = [] + slice_end = [] + slice_step = [] + use_strided_slice = False + reverse_axis = [] + + def fill_constant(shape, value, force_cpu=False, out=None): + var.block.append_op( + type='fill_constant', + inputs={}, + outputs={'Out': [out]}, + attrs={ + 'shape': shape, + 'dtype': out.dtype, + 'value': float(value), + 'force_cpu': force_cpu + }, + stop_gradient=True) + out.stop_gradient = True + return out + + for dim, slice_item in enumerate(item): + if isinstance(slice_item, slice): + start = slice_item.start + end = slice_item.stop + step = slice_item.step + + if start is None and end is None and step is None: + continue + + if step is None: + step = 1 + + if start is None and end is None: + assert (step == -1) + reverse_axis.append(dim) + continue + + if start is None: + start = 0 + + if end is None: + end = 10000000 + + if step != 1: + use_strided_slice = True + + slice_axis.append(dim) + slice_start.append(start) + slice_end.append(end) + slice_step.append(step) + else: + decrease_axis.append(dim) + slice_axis.append(dim) + slice_start.append(slice_item) + slice_step.append(1) + if isinstance(slice_item, Variable): + temp_1 = var.block.create_var(dtype='int32') + fill_constant([1], 1, force_cpu=True, out=temp_1) + temp_end = var.block.create_var(dtype='int32') + var.block.append_op( + type='elementwise_add', + inputs={'X': slice_item, + 'Y': temp_1}, + outputs={'Out': temp_end}, + attrs={'axis': -1}) + slice_end.append(temp_end) + else: + slice_end.append(slice_item + 1 + if slice_item != -1 else 10000000) + + def contain_var(one_list): + for ele in one_list: + if isinstance(ele, Variable): + return True + return False + + def get_new_list_tensor(old_list): + new_list_tensor = [] + for dim in old_list: + if isinstance(dim, Variable): + dim.stop_gradient = True + new_list_tensor.append(dim) + else: + assert (isinstance(dim, int)) + temp_out = var.block.create_var(dtype='int32') + fill_constant([1], dim, force_cpu=True, out=temp_out) + new_list_tensor.append(temp_out) + return new_list_tensor + + inputs = {'Input': [var]} + attrs = { + 'axes': slice_axis, + 'starts': [], + 'ends': [], + 'decrease_axis': decrease_axis + } + if (use_strided_slice == True): + attrs['strides'] = [] + infer_flags = list(1 for i in range(len(slice_axis))) + # starts + if not contain_var(slice_start): + attrs['starts'] = slice_start + else: + inputs['StartsTensorList'] = get_new_list_tensor(slice_start) + for i, dim in enumerate(slice_start): + if isinstance(dim, Variable): + attrs['starts'].append(-1) + infer_flags[i] = -1 + else: + attrs['starts'].append(dim) + # ends + if not contain_var(slice_end): + attrs['ends'] = slice_end + else: + inputs['EndsTensorList'] = get_new_list_tensor(slice_end) + for i, dim in enumerate(slice_end): + if isinstance(dim, Variable): + attrs['ends'].append(-1) + infer_flags[i] = -1 + else: + attrs['ends'].append(dim) + # strides + if use_strided_slice == True: + if not contain_var(slice_step): + attrs['strides'] = slice_step + else: + inputs['StridesTensorList'] = get_new_list_tensor(slice_step) + for i, dim in enumerate(slice_step): + if isinstance(dim, Variable): + attrs['strides'].append(-1) + infer_flags[i] = -1 + else: + attrs['strides'].append(dim) + # infer_flags + attrs['infer_flags'] = infer_flags + + out = var + if use_strided_slice == False and len(slice_axis) > 0: + # append slice_op here + slice_out_var = var.block.create_var( + name=unique_name.generate_with_ignorable_key(var.name + "_slice"), + dtype=var.dtype) + + var.block.append_op( + type="slice", + inputs=inputs, + outputs={'Out': [slice_out_var]}, + attrs=attrs) + + out = slice_out_var + elif use_strided_slice == True and len(slice_axis) > 0: + strided_slice_out_var = var.block.create_var( + name=unique_name.generate_with_ignorable_key(var.name + + "_strided_slice"), + dtype=var.dtype) + var.block.append_op( + type="strided_slice", + inputs=inputs, + outputs={'Out': [strided_slice_out_var]}, + attrs=attrs) + + out = strided_slice_out_var + + if len(reverse_axis) > 0: + reverse_out_var = var.block.create_var( + name=unique_name.generate_with_ignorable_key(var.name + + "_slice_reverse"), + dtype=var.dtype) + var.block.append_op( + type="reverse", + inputs={'X': out}, + outputs={'Out': [reverse_out_var]}, + attrs={'axis': reverse_axis}) + + out = reverse_out_var + + return out + + +@six.add_metaclass(VariableMetaClass) class Variable(object): """ **Notes**: @@ -626,100 +848,83 @@ class Variable(object): self.belong_to_optimizer = belong_to_optimizer - if in_dygraph_mode(): - # record vars in tracer rather than blocks - self._ivar = kwargs.get("ivar", None) - self.stop_gradient_ = kwargs.get("stop_gradient", True) - if not self._ivar: - self._ivar = core.VarBase( - name, type - if type else core.VarDesc.VarType.LOD_TENSOR, dtype - if dtype else core.VarDesc.VarType.FP32, - list(shape) if shape else [], True - if persistable else False) - if persistable: - _dygraph_tracer().trace_var(name, self) - self.op = None - else: - self.error_clip = error_clip + self.error_clip = error_clip + + is_new_var = False + name = cpt.to_text(name) + self.desc = self.block.desc.find_var(cpt.to_bytes(name)) - is_new_var = False - name = cpt.to_text(name) - self.desc = self.block.desc.find_var(cpt.to_bytes(name)) + if self.desc is None: + self.desc = self.block.desc.var(cpt.to_bytes(name)) + is_new_var = True - if self.desc is None: - self.desc = self.block.desc.var(cpt.to_bytes(name)) - is_new_var = True + if is_new_var: + self.desc.set_type(type) + elif self.desc.type() != type: + raise ValueError("Variable {0} has been created before. The " + "previous type is {1}; the new type is {2}. They" + " are not matched".format(self.name, + self.desc.type(), type)) + if shape is not None: if is_new_var: - self.desc.set_type(type) - elif self.desc.type() != type: - raise ValueError( - "Variable {0} has been created before. The " - "previous type is {1}; the new type is {2}. They" - " are not matched".format(self.name, self.desc.type(), - type)) - - if shape is not None: - if is_new_var: - self.desc.set_shape(shape) - else: - old_shape = self.shape - shape = tuple(shape) - if shape != old_shape: - raise ValueError( - "Variable {0} has been created before. the previous " - "shape is {1}; the new shape is {2}. They are not " - "matched.".format(self.name, old_shape, shape)) - if dtype is not None: - if is_new_var: - self.desc.set_dtype(dtype) - else: - old_dtype = self.dtype - if dtype != old_dtype: - raise ValueError( - "Variable {0} has been created before. " - "The previous data type is {1}; the new " - "data type is {2}. They are not " - "matched.".format(self.name, old_dtype, dtype)) - - if lod_level is not None: - if is_new_var: - self.desc.set_lod_level(lod_level) - else: - if lod_level != self.lod_level: - raise ValueError( - "Variable {0} has been created before. " - "The previous lod_level is {1}; the new " - "lod_level is {2}. They are not " - "matched".format(self.name, self.lod_level, - lod_level)) - if persistable is not None: - if is_new_var: - self.desc.set_persistable(persistable) - else: - if persistable != self.persistable: - raise ValueError( - "Variable {0} has been created before." - "The previous persistable is {1}; the new " - "persistable is {2}. They are not matched".format( - self.name, self.persistable, persistable)) + self.desc.set_shape(shape) + else: + old_shape = self.shape + shape = tuple(shape) + if shape != old_shape: + raise ValueError( + "Variable {0} has been created before. the previous " + "shape is {1}; the new shape is {2}. They are not " + "matched.".format(self.name, old_shape, shape)) + if dtype is not None: + if is_new_var: + self.desc.set_dtype(dtype) + else: + old_dtype = self.dtype + if dtype != old_dtype: + raise ValueError("Variable {0} has been created before. " + "The previous data type is {1}; the new " + "data type is {2}. They are not " + "matched.".format(self.name, old_dtype, + dtype)) + + if lod_level is not None: + if is_new_var: + self.desc.set_lod_level(lod_level) + else: + if lod_level != self.lod_level: + raise ValueError("Variable {0} has been created before. " + "The previous lod_level is {1}; the new " + "lod_level is {2}. They are not " + "matched".format(self.name, self.lod_level, + lod_level)) + if persistable is not None: + if is_new_var: + self.desc.set_persistable(persistable) + else: + if persistable != self.persistable: + raise ValueError( + "Variable {0} has been created before." + "The previous persistable is {1}; the new " + "persistable is {2}. They are not matched".format( + self.name, self.persistable, persistable)) - if need_check_feed and is_new_var: - self.desc.set_need_check_feed(need_check_feed) + if need_check_feed and is_new_var: + self.desc.set_need_check_feed(need_check_feed) - if capacity is not None: - if is_new_var: - self.desc.set_capacity(capacity) - else: - # TODO(abhinavarora) : Compare with set capacity once, - # get_capacity is implemented - pass + if capacity is not None: + if is_new_var: + self.desc.set_capacity(capacity) + else: + # TODO(abhinavarora) : Compare with set capacity once, + # get_capacity is implemented + pass - self.block.vars[name] = self - self.op = None - self._stop_gradient = stop_gradient - self.is_data = is_data + self.block.vars[name] = self + self.op = None + self._stop_gradient = stop_gradient + self.is_data = is_data @dygraph_only def detach(self): @@ -749,16 +954,7 @@ class Variable(object): y = x.detach() """ - if in_dygraph_mode(): - new_var = self._cloneVar() - self.block.append_op( - type="assign", - inputs={'X': [self]}, - outputs={'Out': [new_var]}, - stop_gradient=True) - return new_var - else: - raise AttributeError("static graph model DO NOT supprt detach") + pass @dygraph_only def numpy(self): @@ -790,12 +986,7 @@ class Variable(object): print(x.numpy()) """ - - if not self._ivar.value().get_tensor()._is_initialized(): - raise ValueError("%s is Empty, Please check if it has no data in" % - self.name) - new_ivar = self._ivar._copy_to(core.CPUPlace(), True) - return np.array(new_ivar.value().get_tensor()) + pass @dygraph_only def set_value(self, value): @@ -826,25 +1017,7 @@ class Variable(object): out = fc(t) # call with different weight """ - assert isinstance(value, (Variable, np.ndarray, core.VarBase)), \ - "Variable set_value function, arguments type only support Variable, numpy, VarBase" - - value_np = value - if isinstance(value, Variable): - value_np = value.numpy() - elif isinstance(value, core.VarBase): - value_np = _var_base_to_np(value) - self_tensor = self._ivar.value().get_tensor() - - self_tensor_np = np.array(self_tensor) - - assert self_tensor_np.shape == value_np.shape, \ - "Variable Shape not match, Variable [ {} ] need tensor with shape {} but load set tensor with shape {}".format( self._ivar.name, self_tensor_np.shape, value_np.shape) - - assert self_tensor_np.dtype == value_np.dtype, \ - "Variable dtype not match, Variable [ {} ] need tensor with dtype {} but load tensor with dtype {}".format( self._ivar.name, self_tensor_np.dtype, value_np.dtype) - - self_tensor.set(value_np, _current_expected_place()) + pass @dygraph_only def backward(self, backward_strategy=None): @@ -882,16 +1055,7 @@ class Variable(object): loss2.backward(backward_strategy) """ - if in_dygraph_mode(): - from .dygraph import BackwardStrategy - if backward_strategy is None: - backward_strategy = BackwardStrategy() - backward_strategy.sort_sum_gradient = False - - self._ivar._run_backward(backward_strategy, _dygraph_tracer()) - else: - raise ValueError( - "Variable.backward() is only avaliable in DyGraph mode") + pass @dygraph_only def gradient(self): @@ -925,16 +1089,7 @@ class Variable(object): print(loss2.gradient()) """ - if self._ivar._grad_ivar() is None: - raise ValueError("%s has no grad, Please set Variable.stop_gradient=False, or " \ - "check if this is the first and only variable need grad, if so, please set its pre-Variable's " \ - "stop_gradient=False, to make sure it has gradient " % self.name) - new_ivar = self._ivar._grad_ivar()._copy_to(core.CPUPlace(), True) - if self._ivar._grad_ivar().type == core.VarDesc.VarType.SELECTED_ROWS: - return (np.array(new_ivar.value().get_selected_rows().get_tensor()), - np.array(new_ivar.value().get_selected_rows().rows())) - else: - return np.array(new_ivar.value().get_tensor()) + pass @dygraph_only def clear_gradient(self): @@ -971,7 +1126,7 @@ class Variable(object): print("After clear {}".format(loss2.gradient())) """ - self._ivar._clear_gradient() + pass def __str__(self): return self.to_string(True) @@ -1004,14 +1159,7 @@ class Variable(object): print(new_variable.to_string(True, True)) """ if in_dygraph_mode(): - # TODO(panyx0718): add more dygraph debug info. - tensor = self._ivar.value().get_tensor() - if tensor._is_initialized(): - return 'name %s, dtype: %s shape: %s %s' % ( - self.name, self.dtype, self.shape, str(tensor)) - else: - return 'name %s, shape: %s, not inited' % (self.name, - self.shape) + return assert isinstance(throw_on_error, bool) and isinstance(with_details, bool) @@ -1060,14 +1208,14 @@ class Variable(object): assert (out1.gradient() == 0).all() """ if in_dygraph_mode(): - return self._ivar.stop_gradient + pass else: return self._stop_gradient @stop_gradient.setter def stop_gradient(self, s): if in_dygraph_mode(): - self._ivar.stop_gradient = s + pass else: self._stop_gradient = s @@ -1095,7 +1243,7 @@ class Variable(object): print("persistable of current Var is: {}".format(new_variable.persistable)) """ if in_dygraph_mode(): - return self._ivar.persistable + pass else: return self.desc.persistable() @@ -1127,7 +1275,7 @@ class Variable(object): print("name of current Var is: {}".format(new_variable.name)) """ if in_dygraph_mode(): - return self._ivar.name + pass else: return cpt.to_text(self.desc.name()) @@ -1154,7 +1302,7 @@ class Variable(object): @name.setter def name(self, new_name): if in_dygraph_mode(): - self._ivar.name = new_name + pass else: self.desc.set_name(new_name) @@ -1179,7 +1327,7 @@ class Variable(object): """ # convert to tuple, make it as same as numpy API. if in_dygraph_mode(): - return self._ivar.shape + pass else: return tuple(self.desc.shape()) @@ -1202,7 +1350,7 @@ class Variable(object): print("Dtype of current Var is: {}".format(new_variable.dtype)) """ if in_dygraph_mode(): - return self._ivar.dtype + pass else: return self.desc.dtype() @@ -1254,7 +1402,7 @@ class Variable(object): print("Type of current Var is: {}".format(new_variable.type)) """ if in_dygraph_mode(): - return self._ivar.type + pass else: return self.desc.type() @@ -1446,200 +1594,7 @@ class Variable(object): raise IndexError("Valid index accept int or slice or tuple") def __getitem__(self, item): - """ - Slice the variable. - - Args: - item(int/slice/tuple) : the index. - - Returns: - Sliced variable - """ - - if not isinstance(item, tuple): - item = [item] - - decrease_axis = [] - slice_axis = [] - slice_start = [] - slice_end = [] - slice_step = [] - use_strided_slice = False - reverse_axis = [] - - def fill_constant(shape, value, force_cpu=False, out=None): - self.block.append_op( - type='fill_constant', - inputs={}, - outputs={'Out': [out]}, - attrs={ - 'shape': shape, - 'dtype': out.dtype, - 'value': float(value), - 'force_cpu': force_cpu - }, - stop_gradient=True) - out.stop_gradient = True - return out - - for dim, slice_item in enumerate(item): - if isinstance(slice_item, slice): - start = slice_item.start - end = slice_item.stop - step = slice_item.step - - if start is None and end is None and step is None: - continue - - if step is None: - step = 1 - - if start is None and end is None: - assert (step == -1) - reverse_axis.append(dim) - continue - - if start is None: - start = 0 - - if end is None: - end = 10000000 - - if step != 1: - use_strided_slice = True - - slice_axis.append(dim) - slice_start.append(start) - slice_end.append(end) - slice_step.append(step) - else: - decrease_axis.append(dim) - slice_axis.append(dim) - slice_start.append(slice_item) - slice_step.append(1) - if isinstance(slice_item, Variable): - temp_1 = self.block.create_var(dtype='int32') - fill_constant([1], 1, force_cpu=True, out=temp_1) - temp_end = self.block.create_var(dtype='int32') - self.block.append_op( - type='elementwise_add', - inputs={'X': slice_item, - 'Y': temp_1}, - outputs={'Out': temp_end}, - attrs={'axis': -1}) - slice_end.append(temp_end) - else: - slice_end.append(slice_item + 1 - if slice_item != -1 else 10000000) - - def contain_var(one_list): - for ele in one_list: - if isinstance(ele, Variable): - return True - return False - - def get_new_list_tensor(old_list): - new_list_tensor = [] - for dim in old_list: - if isinstance(dim, Variable): - dim.stop_gradient = True - new_list_tensor.append(dim) - else: - assert (isinstance(dim, int)) - temp_out = self.block.create_var(dtype='int32') - fill_constant([1], dim, force_cpu=True, out=temp_out) - new_list_tensor.append(temp_out) - return new_list_tensor - - inputs = {'Input': [self]} - attrs = { - 'axes': slice_axis, - 'starts': [], - 'ends': [], - 'decrease_axis': decrease_axis - } - if (use_strided_slice == True): - attrs['strides'] = [] - infer_flags = list(1 for i in range(len(slice_axis))) - # starts - if not contain_var(slice_start): - attrs['starts'] = slice_start - else: - inputs['StartsTensorList'] = get_new_list_tensor(slice_start) - for i, dim in enumerate(slice_start): - if isinstance(dim, Variable): - attrs['starts'].append(-1) - infer_flags[i] = -1 - else: - attrs['starts'].append(dim) - # ends - if not contain_var(slice_end): - attrs['ends'] = slice_end - else: - inputs['EndsTensorList'] = get_new_list_tensor(slice_end) - for i, dim in enumerate(slice_end): - if isinstance(dim, Variable): - attrs['ends'].append(-1) - infer_flags[i] = -1 - else: - attrs['ends'].append(dim) - # strides - if use_strided_slice == True: - if not contain_var(slice_step): - attrs['strides'] = slice_step - else: - inputs['StridesTensorList'] = get_new_list_tensor(slice_step) - for i, dim in enumerate(slice_step): - if isinstance(dim, Variable): - attrs['strides'].append(-1) - infer_flags[i] = -1 - else: - attrs['strides'].append(dim) - # infer_flags - attrs['infer_flags'] = infer_flags - - out = self - if use_strided_slice == False and len(slice_axis) > 0: - # append slice_op here - slice_out_var = self.block.create_var( - name=unique_name.generate_with_ignorable_key(self.name + - "_slice"), - dtype=self.dtype) - - self.block.append_op( - type="slice", - inputs=inputs, - outputs={'Out': [slice_out_var]}, - attrs=attrs) - - out = slice_out_var - elif use_strided_slice == True and len(slice_axis) > 0: - strided_slice_out_var = self.block.create_var( - name=unique_name.generate_with_ignorable_key(self.name + - "_strided_slice"), - dtype=self.dtype) - self.block.append_op( - type="strided_slice", - inputs=inputs, - outputs={'Out': [strided_slice_out_var]}, - attrs=attrs) - - out = strided_slice_out_var - - if len(reverse_axis) > 0: - reverse_out_var = self.block.create_var( - name=unique_name.generate_with_ignorable_key(self.name + - "_slice_reverse"), - dtype=self.dtype) - self.block.append_op( - type="reverse", - inputs={'X': out}, - outputs={'Out': [reverse_out_var]}, - attrs={'axis': reverse_axis}) - - out = reverse_out_var - - return out + return _getitem_impl_(self, item) def get_all_op_protos(): @@ -2347,9 +2302,12 @@ class Block(object): if isinstance(item[1], Parameter)) def create_var(self, *args, **kwargs): - var = Variable(block=self, *args, **kwargs) - if 'initializer' in kwargs: - kwargs['initializer'](var, self) + if not in_dygraph_mode(): + var = Variable(block=self, *args, **kwargs) + if 'initializer' in kwargs: + kwargs['initializer'](var, self) + else: + var = _varbase_creator(*args, **kwargs) return var def has_var(self, name): @@ -2396,18 +2354,31 @@ class Block(object): # NOTE: v is destroyed by C++ after calling _rename_var. d = self.desc.find_var(cpt.to_bytes(new_name)) if var_type == "Parameter": - var = Parameter( - self, - d.shape(), - d.dtype(), - type=orig_var_type, - name=new_name, - stop_gradient=stop_gradient, - trainable=trainable, - optimize_attr=optimize_attr, - regularizer=regularizer, - gradient_clip_attr=gradient_clip_attr, - error_clip=error_clip) + if not in_dygraph_mode(): + var = Parameter( + self, + d.shape(), + d.dtype(), + type=orig_var_type, + name=new_name, + stop_gradient=stop_gradient, + trainable=trainable, + optimize_attr=optimize_attr, + regularizer=regularizer, + gradient_clip_attr=gradient_clip_attr, + error_clip=error_clip) + else: + var = ParamBase( + d.shape(), + d.dtype(), + type=orig_var_type, + name=new_name, + stop_gradient=stop_gradient, + trainable=trainable, + optimize_attr=optimize_attr, + regularizer=regularizer, + gradient_clip_attr=gradient_clip_attr, + error_clip=error_clip) elif var_type == "Variable": var = Variable( self, @@ -2430,7 +2401,11 @@ class Block(object): def create_parameter(self, *args, **kwargs): global_block = self.program.global_block() - param = Parameter(global_block, *args, **kwargs) + param = None + if not in_dygraph_mode(): + param = Parameter(global_block, *args, **kwargs) + else: + param = ParamBase(*args, **kwargs) if 'initializer' in kwargs: def _is_inited_by(block, var): @@ -2669,19 +2644,34 @@ class Block(object): raise ValueError("_copy_param_info_from should be invoked with " "same topology") assert isinstance(v, Variable) - new_p = Parameter( - block=self, - shape=v.shape, - dtype=v.dtype, - type=v.type, - lod_level=v.lod_level, - stop_gradient=p.stop_gradient, - trainable=p.trainable, - optimize_attr=p.optimize_attr, - regularizer=p.regularizer, - gradient_clip_attr=p.gradient_clip_attr, - error_clip=p.error_clip, - name=v.name) + new_p = None + if not in_dygraph_mode(): + new_p = Parameter( + block=self, + shape=v.shape, + dtype=v.dtype, + type=v.type, + lod_level=v.lod_level, + stop_gradient=p.stop_gradient, + trainable=p.trainable, + optimize_attr=p.optimize_attr, + regularizer=p.regularizer, + gradient_clip_attr=p.gradient_clip_attr, + error_clip=p.error_clip, + name=v.name) + else: + new_p = ParamBase( + shape=v.shape, + dtype=v.dtype, + type=v.type, + lod_level=v.lod_level, + stop_gradient=p.stop_gradient, + trainable=p.trainable, + optimize_attr=p.optimize_attr, + regularizer=p.regularizer, + gradient_clip_attr=p.gradient_clip_attr, + error_clip=p.error_clip, + name=v.name) self.vars[new_p.name] = new_p def _clone_variable(self, var, force_persistable=True): @@ -4485,6 +4475,7 @@ class Program(object): yield each_var +@six.add_metaclass(ParameterMetaClass) class Parameter(Variable): """ Parameter is derived from Variable. A parameter is a persistable @@ -4580,6 +4571,111 @@ class Parameter(Variable): __repr__ = __str__ +class ParamBase(core.VarBase): + """ + ParamBase is derived from VarBase( Which is the Variable in Dygraph Mode ). A ParamBase is a persistable + VarBase, and will be updated by optimizers after each iteration. + The training of a neural network is essentially the updating of + its ParamBase. + + Relative to a general Variable, a ParamBase has several its own + member variables: + + Args: + trainable(bool): True if the ParamBase need to be updated after + iterations. + optimize_attr(map): ParamBase attributes related with optimizing. + Currently, it only contains 'learning_rate'. + Default: {'learning_rate': 1.0} + regularizer(WeightDecayRegularizer): The Regularizer which will + be applied on the ParamBase. Default: None + gradient_clip_attr(BaseGradientClipAttr): The gradint clip strategy + which will be applied on the ParamBase. Default: None + do_model_average(bool): True if the model average strategy will + be applied on this ParamBase. + """ + + @dygraph_only + def __init__(self, shape, dtype, **kwargs): + if shape is None: + raise ValueError("The shape of Parameter should not be None") + if dtype is None: + raise ValueError("The dtype of Parameter should not be None") + + if len(shape) == 0: + raise ValueError( + "The dimensions of shape for Parameter must be greater than 0") + + for each in shape: + if each < 0: + raise ValueError( + "Each dimension of shape for Parameter must be greater than 0, but received %s" + % list(shape)) + + if dtype is not None: + if not isinstance(dtype, core.VarDesc.VarType): + dtype = convert_np_dtype_to_dtype_(dtype) + + name = kwargs.get('name', unique_name.generate('_param_base')) + + super(ParamBase, self).__init__(dtype + if dtype else core.VarDesc.VarType.FP32, + list(shape) if shape else [], name, + core.VarDesc.VarType.LOD_TENSOR, True) + + self.trainable = kwargs.get('trainable', True) + + self.optimize_attr = kwargs.get('optimize_attr', {'learning_rate': 1.0}) + + self.regularizer = kwargs.get('regularizer', None) + + self.gradient_clip_attr = kwargs.get('gradient_clip_attr', None) + + self.do_model_average = kwargs.get('do_model_average', None) + + self.is_distributed = False + + #self.block = default_main_program().global_block() + + _dygraph_tracer().trace_var(name, self) + + def __str__(self): + return self.to_string(True) + + def to_string(self, throw_on_error, with_details=False): + """ + To debug string. + + Args: + throw_on_error(bool): raise exception when self is not initialized + when throw_on_error is True + with_details(bool): more details about variables and parameters + (e.g. trainable, optimize_attr, ...) will be printed when with_details is True + + Returns(str): The debug string. + + Examples: + .. code-block:: python + + import paddle.fluid as fluid + + prog = fluid.default_main_program() + rlt = fluid.layers.data("fake_data", shape=[1,1], dtype='float32') + debug_str = prog.to_string(throw_on_error=True, with_details=False) + print(debug_str) + """ + assert isinstance(throw_on_error, bool) and isinstance(with_details, + bool) + tensor = self.value().get_tensor() + if tensor._is_initialized(): + return 'name %s, dtype: %s shape: %s %s' % (self.name, self.dtype, + self.shape, str(tensor)) + else: + return 'name %s, shape: %s, not inited' % (self.name, self.shape) + + __repr__ = __str__ + + # program is a global instance. _main_program_ = Program() _startup_program_ = Program() diff --git a/python/paddle/fluid/layer_helper_base.py b/python/paddle/fluid/layer_helper_base.py index 5fa7fef381061a5480d133acbc5729acf9e63d5f..ba528f3f48162b3322d0ef27b326bf899ab17a35 100644 --- a/python/paddle/fluid/layer_helper_base.py +++ b/python/paddle/fluid/layer_helper_base.py @@ -44,33 +44,46 @@ class LayerHelperBase(object): def startup_program(self): return default_startup_program() - def to_variable(self, value, block=None): - """convert value to variable + def to_variable(self, value, name=None): + """ + The API will create a ``Variable`` object from numpy\.ndarray or Variable object. + + Parameters: + value(ndarray): The numpy\.ndarray object that needs to be converted, it can be multi-dimension, and the data type is one of numpy\.{float16, float32, float64, int16, int32, int64, uint8, uint16}. + block(fluid.Block, optional): Which block this variable will be in. Default: None. + name(str, optional): The default value is None. Normally there is no need for user to set this property. For more information, please refer to :ref:`api_guide_Name` + + Returns: + Variable: ``Tensor`` created from the specified numpy\.ndarray object, data type and shape is the same as ``value`` . + + Examples: - Args: - value: value to be convert - block: the block of the variable + .. code-block:: python + + import numpy as np + import paddle.fluid as fluid + + with fluid.dygraph.guard(): + x = np.ones([2, 2], np.float32) + y = fluid.dygraph.to_variable(x) - Return Variable construct from value """ if isinstance(value, np.ndarray): assert in_dygraph_mode( ), "to_variable could only be called in dygraph mode" - - if not block: - block = default_main_program().current_block() - py_var = Variable( - block, - type=core.VarDesc.VarType.LOD_TENSOR, - name=None, - shape=value.shape, - dtype=value.dtype) - var = py_var._ivar.value() - tensor = var.get_tensor() - tensor.set(value, _current_expected_place()) + py_var = core.VarBase( + value=value, + name=name, + persistable=False, + place=_current_expected_place(), + zero_copy=False) return py_var - elif isinstance(value, Variable): + elif isinstance(value, (core.VarBase, Variable)): return value + else: + raise TypeError( + "to_variable only accepts 'ndarray' or 'Variable' or 'VarBase' as value's input" + ) def _create_weight_normalize(self, attr, shape, dtype): from .layers import elementwise_mul, elementwise_div, reshape @@ -386,7 +399,7 @@ class LayerHelperBase(object): """ assert isinstance(var, Variable) if in_dygraph_mode(): - initializer(var, var.block) + initializer(var, self.main_program.global_block()) else: self.startup_program.global_block().create_var( name=var.name, diff --git a/python/paddle/fluid/layers/loss.py b/python/paddle/fluid/layers/loss.py index e5eaa270afec60bd62047e5c607d5e5a1c67cbf2..00d73d8eb9ac293de88f2154029442279e1b37e0 100644 --- a/python/paddle/fluid/layers/loss.py +++ b/python/paddle/fluid/layers/loss.py @@ -233,6 +233,7 @@ def cross_entropy(input, label, soft_label=False, ignore_index=kIgnoreIndex): predict = fluid.layers.fc(input=x, size=class_num, act='softmax') cost = fluid.layers.cross_entropy(input=predict, label=label) """ + check_type_and_dtype(input, 'input', Variable, ['float16', 'float32', 'float64'], 'cross_entropy') if not soft_label: @@ -729,7 +730,6 @@ def nce(input, sampler = 1 elif sampler == "custom_dist": assert custom_dist is not None - # assert isinstance(custom_dist, Variable) custom_dist_len = num_total_classes alias_probs_ = [0] * custom_dist_len diff --git a/python/paddle/fluid/layers/math_op_patch.py b/python/paddle/fluid/layers/math_op_patch.py index d14a39ae58fc48b93e81a7461f84d8627c185017..a81f3085b1f6688222ec8a2766071d4abcdd042e 100644 --- a/python/paddle/fluid/layers/math_op_patch.py +++ b/python/paddle/fluid/layers/math_op_patch.py @@ -15,7 +15,7 @@ from __future__ import print_function from .. import core -from ..framework import Variable, unique_name +from ..framework import Variable, unique_name, in_dygraph_mode, default_main_program from .layer_function_generator import OpProtoHolder from ..initializer import force_init_on_cpu @@ -40,7 +40,10 @@ def monkey_patch_variable(): return dtype def current_block(var): - return var.block + if in_dygraph_mode(): + return default_main_program().global_block() + else: + return var.block def create_new_tmp_var(block, dtype): tmp_name = unique_tmp_name() @@ -281,5 +284,9 @@ def monkey_patch_variable(): setattr(Variable, method_name, _elemwise_method_creator_(method_name, op_type, reverse, scalar_method)) + setattr(core.VarBase, method_name, + _elemwise_method_creator_(method_name, op_type, reverse, + scalar_method)) Variable.astype = astype + setattr(core.VarBase, "astype", astype) diff --git a/python/paddle/fluid/optimizer.py b/python/paddle/fluid/optimizer.py index 23b5a8e501dffa71359cb4b3764b8930eb6996fc..e524fb66f08377b93cec8b1b59285bd12275e03a 100644 --- a/python/paddle/fluid/optimizer.py +++ b/python/paddle/fluid/optimizer.py @@ -32,7 +32,6 @@ from .layers import ops from .regularizer import append_regularization_ops from .dygraph import base as imperative_base from .dygraph.learning_rate_scheduler import LearningRateDecay -from .framework import _var_base_to_np from paddle.fluid import core from paddle.fluid.layers import tensor from functools import reduce @@ -122,7 +121,13 @@ class Optimizer(object): state_dict[var_tmp.name] = var_tmp # global step if use lr decay if isinstance(self._learning_rate, LearningRateDecay): - var_temp = Variable(None, name='global_step', dtype='int32') + var_tmp = None + if not framework.in_dygraph_mode(): + var_temp = Variable(None, name='global_step', dtype='int32') + else: + var_temp = framework._varbase_creator( + None, name='global_step', dtype='int32') + tensor.fill_constant( [1], "int32", self._learning_rate.step_num, out=var_temp) @@ -164,7 +169,7 @@ class Optimizer(object): global_step = state_dict['global_step'] if isinstance(global_step, core.VarBase): - step_np = global_step._copy_to(core.CPUPlace(), True) + step_np = global_step step_np = np.array(step_np.value().get_tensor()) assert step_np.shape == (1,), \ "global step shape is (1,), the shape is {}".format( step_np.shape ) @@ -189,7 +194,7 @@ class Optimizer(object): for para_name, var_tmp in v.items(): assert var_tmp.name in state_dict, \ "optimizer variable {} not found".format( var_tmp.name ) - var = var_tmp._ivar.value() + var = var_tmp.value() tensor = var.get_tensor() model_np = np.array(tensor) @@ -198,7 +203,7 @@ class Optimizer(object): if isinstance(load_para, Variable): load_para_np = load_para.numpy() elif isinstance(load_para, core.VarBase): - load_para_np = _var_base_to_np(load_para) + load_para_np = load_para.numpy() elif isinstance(load_para, np.ndarray): load_para_np = load_para else: @@ -515,7 +520,11 @@ class Optimizer(object): Examples: See examples in ``apply_gradients``. """ - no_grad_set = self._get_no_grad_set(loss, no_grad_set) + act_no_grad_set = None + if not framework.in_dygraph_mode(): + act_no_grad_set = self._get_no_grad_set(loss, no_grad_set) + else: + pass self._dtype = loss.dtype if framework.in_dygraph_mode(): @@ -528,15 +537,9 @@ class Optimizer(object): for param in parameters: if not param.trainable: continue - if param._ivar._grad_ivar() is not None: - ivar_type = param._ivar._grad_ivar().type + if param._grad_ivar() is not None: # create gradient variable - grad_var = Variable( - block=loss.block, - type=ivar_type, - name=param._ivar._grad_name(), - stop_gradient=True, - ivar=param._ivar._grad_ivar()) + grad_var = param._grad_ivar() params_grads.append((param, grad_var)) else: if callbacks is None: @@ -550,7 +553,7 @@ class Optimizer(object): loss.shape) with program_guard(program, startup_program): params_grads = append_backward(loss, parameter_list, - no_grad_set, callbacks) + act_no_grad_set, callbacks) # Note: since we can't use all_reduce_op now, # dgc_op should be the last op of one grad. self._append_dgc_ops(params_grads) diff --git a/python/paddle/fluid/tests/unittests/op_test.py b/python/paddle/fluid/tests/unittests/op_test.py index c0dd0809c4b0e39f1194ba955ac536cd359a12c2..87eceb4690d4ff63540e56b99fa760c71052976b 100644 --- a/python/paddle/fluid/tests/unittests/op_test.py +++ b/python/paddle/fluid/tests/unittests/op_test.py @@ -268,7 +268,7 @@ class OpTest(unittest.TestCase): data = value[0] lod = value[1] v = fluid.dygraph.base.to_variable(value=data) - v._ivar.value().get_tensor().set_recursive_sequence_lengths(lod) + v.value().get_tensor().set_recursive_sequence_lengths(lod) return v else: return fluid.dygraph.base.to_variable(value) @@ -289,7 +289,7 @@ class OpTest(unittest.TestCase): if if_return_inputs_grad_dict: v.stop_gradient = False if has_lod: - v._ivar.value().get_tensor().set_recursive_sequence_lengths( + v.value().get_tensor().set_recursive_sequence_lengths( lod_temp) else: v = block.create_var( @@ -840,8 +840,8 @@ class OpTest(unittest.TestCase): if check_dygraph: imperative_actual = find_imperative_actual( sub_out_name, dygraph_outs, place) - imperative_actual_t = np.array( - imperative_actual._ivar.value().get_tensor()) + imperative_actual_t = np.array(imperative_actual.value() + .get_tensor()) idx = find_actual(sub_out_name, fetch_list) actual = outs[idx] actual_t = np.array(actual) @@ -868,7 +868,7 @@ class OpTest(unittest.TestCase): ") has different lod at " + str(place)) if check_dygraph: self.assertListEqual( - imperative_actual._ivar.value().get_tensor() + imperative_actual.value().get_tensor() .recursive_sequence_lengths(), expect[1], "Output (" + out_name + ") has different lod at " + str(place) + @@ -877,8 +877,8 @@ class OpTest(unittest.TestCase): if check_dygraph: imperative_actual = find_imperative_actual( out_name, dygraph_outs, place) - imperative_actual_t = np.array( - imperative_actual._ivar.value().get_tensor()) + imperative_actual_t = np.array(imperative_actual.value() + .get_tensor()) idx = find_actual(out_name, fetch_list) actual = outs[idx] actual_t = np.array(actual) @@ -913,7 +913,7 @@ class OpTest(unittest.TestCase): ") has different lod at " + str(place)) if check_dygraph: self.assertListEqual( - imperative_actual._ivar.value().get_tensor() + imperative_actual.value().get_tensor() .recursive_sequence_lengths(), expect[1], "Output (" + out_name + ") has different lod at " + str(place) + " in dygraph mode") diff --git a/python/paddle/fluid/tests/unittests/test_imperative_auto_prune.py b/python/paddle/fluid/tests/unittests/test_imperative_auto_prune.py index 609336e281f2bfe3ccbb6a4f41e7b6e8f4a3137a..a83532e3768beb5117cf76e0625c37fb3a106492 100644 --- a/python/paddle/fluid/tests/unittests/test_imperative_auto_prune.py +++ b/python/paddle/fluid/tests/unittests/test_imperative_auto_prune.py @@ -153,8 +153,8 @@ class TestImperativeAutoPrune(unittest.TestCase): v2 = fluid.dygraph.to_variable(value2) loss = case1(v1, v2) loss.backward() - self.assertTrue(case1.fc2._w._ivar._grad_ivar() is not None) - self.assertTrue(case1.fc1._w._ivar._grad_ivar() is not None) + self.assertTrue(case1.fc2._w._grad_ivar() is not None) + self.assertTrue(case1.fc1._w._grad_ivar() is not None) def test_auto_prune2(self): with fluid.dygraph.guard(): @@ -166,8 +166,8 @@ class TestImperativeAutoPrune(unittest.TestCase): loss = case2(v1, v2) loss.backward() - self.assertTrue(case2.fc2._w._ivar._grad_ivar() is None) - self.assertTrue(case2.fc1._w._ivar._grad_ivar() is not None) + self.assertTrue(case2.fc2._w._grad_ivar() is None) + self.assertTrue(case2.fc1._w._grad_ivar() is not None) def test_auto_prune3(self): with fluid.dygraph.guard(): @@ -178,7 +178,7 @@ class TestImperativeAutoPrune(unittest.TestCase): v2 = fluid.dygraph.to_variable(value2) loss, part2 = case3(v1, v2, 1) loss.backward() - self.assertTrue(case3.fc._w._ivar._grad_ivar() is not None) + self.assertTrue(case3.fc._w._grad_ivar() is not None) self.assertTrue((part2.gradient() == 0).all()) def test_auto_prune4(self): @@ -190,7 +190,7 @@ class TestImperativeAutoPrune(unittest.TestCase): v2 = fluid.dygraph.to_variable(value2) loss, part2 = case4(v1, v2, 1) part2.backward() - self.assertTrue(case4.fc._w._ivar._grad_ivar() is not None) + self.assertTrue(case4.fc._w._grad_ivar() is not None) self.assertTrue((part2.gradient() == 1).all()) def test_auto_prune5(self): @@ -202,7 +202,7 @@ class TestImperativeAutoPrune(unittest.TestCase): v2 = fluid.dygraph.to_variable(value2) loss, part1, part2 = case4(v1, v2, 2) part1.backward() - self.assertTrue(case4.fc._w._ivar._grad_ivar() is not None) + self.assertTrue(case4.fc._w._grad_ivar() is not None) self.assertTrue((part2.gradient() == 0).all()) def test_auto_prune6(self): @@ -333,8 +333,8 @@ class TestImperativeAutoPrune(unittest.TestCase): for items in params_grads: assert items[0].name is not model.embed1._w.name assert items[0].name is not model.fc1._w.name - assert model.embed1._w._ivar._grad_ivar() is None - assert model.fc1._w._ivar._grad_ivar() is None + assert model.embed1._w._grad_ivar() is None + assert model.fc1._w._grad_ivar() is None with fluid.dygraph.guard(place): model = MyLayer2("mylayer", vocab_size, size) @@ -351,8 +351,8 @@ class TestImperativeAutoPrune(unittest.TestCase): for items in params_grads: assert items[0].name is not model.embed1._w.name assert items[0].name is not model.fc1._w.name - assert model.embed1._w._ivar._grad_ivar() is None - assert model.fc1._w._ivar._grad_ivar() is None + assert model.embed1._w._grad_ivar() is None + assert model.fc1._w._grad_ivar() is None def test_case2_prune_no_grad_branch(self): with fluid.dygraph.guard(): @@ -363,8 +363,8 @@ class TestImperativeAutoPrune(unittest.TestCase): case3 = AutoPruneLayer2("l2") loss = case3(v1, v2) loss.backward() - self.assertTrue(case3.fc2._w._ivar._grad_ivar() is None) - self.assertTrue(case3.fc._w._ivar._grad_ivar() is not None) + self.assertTrue(case3.fc2._w._grad_ivar() is None) + self.assertTrue(case3.fc._w._grad_ivar() is not None) def test_case2_prune_no_grad_branch(self): with fluid.dygraph.guard(): @@ -375,8 +375,8 @@ class TestImperativeAutoPrune(unittest.TestCase): case3 = AutoPruneLayer2("l2") loss = case3(v1, v2) loss.backward() - self.assertTrue(case3.fc2._w._ivar._grad_ivar() is None) - self.assertTrue(case3.fc._w._ivar._grad_ivar() is not None) + self.assertTrue(case3.fc2._w._grad_ivar() is None) + self.assertTrue(case3.fc._w._grad_ivar() is not None) def test_case3_prune_no_grad_branch2(self): with fluid.dygraph.guard(): @@ -389,14 +389,14 @@ class TestImperativeAutoPrune(unittest.TestCase): out = fluid.layers.one_hot(input=label, depth=100) loss = fluid.layers.mean(out) loss.backward() - self.assertTrue(fc._w._ivar._grad_ivar() is None) + self.assertTrue(fc._w._grad_ivar() is None) def test_case4_with_no_grad_op_maker(self): with fluid.dygraph.guard(): out = fluid.layers.gaussian_random(shape=[20, 30]) loss = fluid.layers.mean(out) loss.backward() - self.assertTrue(out._ivar._grad_ivar() is None) + self.assertTrue(out._grad_ivar() is None) if __name__ == '__main__': diff --git a/python/paddle/fluid/tests/unittests/test_imperative_basic.py b/python/paddle/fluid/tests/unittests/test_imperative_basic.py index d1b5642406dbbad9edb5d0b590a4060c10309487..fc3378ce283316176f604e87f8979edf2f398fd1 100644 --- a/python/paddle/fluid/tests/unittests/test_imperative_basic.py +++ b/python/paddle/fluid/tests/unittests/test_imperative_basic.py @@ -177,6 +177,30 @@ class SimpleRNN(fluid.Layer): class TestImperative(unittest.TestCase): + def test_isinstance(self): + var = fluid.layers.data(shape=[1], name='x', dtype='float32') + self.assertTrue(isinstance(var, fluid.Variable)) + with fluid.dygraph.guard(): + var_base = fluid.dygraph.base.to_variable(np.array([3, 4, 5])) + self.assertTrue(isinstance(var_base, core.VarBase)) + self.assertTrue(isinstance(var_base, fluid.Variable)) + + def test_create_VarBase(self): + x = np.ones([2, 2], np.float32) + y = np.zeros([3, 3], np.float32) + with fluid.dygraph.guard(): + tmp = fluid.core.VarBase(value=x, place=fluid.core.CPUPlace()) + tmp2 = fluid.core.VarBase(y, fluid.core.CPUPlace()) + tmp3 = fluid.dygraph.base.to_variable(x) + tmp4 = fluid.core.VarBase(y) + tmp5 = fluid.core.VarBase(value=x) + + self.assertTrue(np.array_equal(x, tmp.numpy())) + self.assertTrue(np.array_equal(y, tmp2.numpy())) + self.assertTrue(np.array_equal(x, tmp3.numpy())) + self.assertTrue(np.array_equal(y, tmp4.numpy())) + self.assertTrue(np.array_equal(x, tmp5.numpy())) + def test_sum_op(self): x = np.ones([2, 2], np.float32) with fluid.dygraph.guard(): @@ -215,17 +239,17 @@ class TestImperative(unittest.TestCase): try: new_variable.numpy() except Exception as e: - assert type(e) == ValueError + assert type(e) == core.EnforceNotMet try: new_variable.backward() except Exception as e: - assert type(e) == ValueError + assert type(e) == core.EnforceNotMet try: new_variable.clear_gradient() except Exception as e: - assert type(e) == ValueError + assert type(e) == core.EnforceNotMet def test_empty_grad(self): with fluid.dygraph.guard(): @@ -239,7 +263,7 @@ class TestImperative(unittest.TestCase): try: new_var.clear_gradient() except Exception as e: - assert type(e) == ValueError + assert type(e) == core.EnforceNotMet with fluid.dygraph.guard(): cur_program = fluid.Program() @@ -257,7 +281,7 @@ class TestImperative(unittest.TestCase): new_var = fluid.dygraph.base.to_variable(x) self.assertFalse(new_var.persistable) new_var.persistable = True - self.assertFalse(new_var.persistable) + self.assertTrue(new_var.persistable) def test_layer(self): with fluid.dygraph.guard(): diff --git a/python/paddle/fluid/tests/unittests/test_imperative_lod_tensor_to_selected_rows.py b/python/paddle/fluid/tests/unittests/test_imperative_lod_tensor_to_selected_rows.py index eb06daa0c532b117992a2faa2bd6418e0f678df1..fce587c5921b0ffeb72010a3ed8774d79d2d51be 100644 --- a/python/paddle/fluid/tests/unittests/test_imperative_lod_tensor_to_selected_rows.py +++ b/python/paddle/fluid/tests/unittests/test_imperative_lod_tensor_to_selected_rows.py @@ -70,7 +70,6 @@ class SimpleNet(fluid.Layer): loss = fluid.layers.reshape(loss, shape=[-1, self.num_steps]) loss = fluid.layers.reduce_mean(loss, dim=[0]) loss = fluid.layers.reduce_sum(loss) - loss.permissions = True return loss diff --git a/python/paddle/fluid/tests/unittests/test_imperative_ocr_attention_model.py b/python/paddle/fluid/tests/unittests/test_imperative_ocr_attention_model.py index 73768340003d95c1a220854dd584d5c85787ea90..2eac3507f853372bcd327404f3fd2409a469ecec 100644 --- a/python/paddle/fluid/tests/unittests/test_imperative_ocr_attention_model.py +++ b/python/paddle/fluid/tests/unittests/test_imperative_ocr_attention_model.py @@ -459,8 +459,7 @@ class TestDygraphOCRAttention(unittest.TestCase): for batch_id in range(batch_num): label_in = to_variable(label_in_np) label_out = to_variable(label_out_np) - label_out._stop_gradient = True - label_out.trainable = False + label_out.stop_gradient = True img = to_variable(image_np) dy_prediction = ocr_attention(img, label_in) label_out = fluid.layers.reshape( @@ -481,7 +480,7 @@ class TestDygraphOCRAttention(unittest.TestCase): dy_grad_value = {} for param in ocr_attention.parameters(): if param.trainable: - np_array = np.array(param._ivar._grad_ivar().value() + np_array = np.array(param._grad_ivar().value() .get_tensor()) dy_grad_value[param.name + core.grad_var_suffix( )] = np_array @@ -514,7 +513,7 @@ class TestDygraphOCRAttention(unittest.TestCase): name='label_in', shape=[1], dtype='int64', lod_level=0) static_label_out = fluid.layers.data( name='label_out', shape=[1], dtype='int64', lod_level=0) - static_label_out._stop_gradient = True + static_label_out.stop_gradient = True static_label_out.trainable = False static_prediction = ocr_attention(images, static_label_in) diff --git a/python/paddle/fluid/tests/unittests/test_imperative_optimizer.py b/python/paddle/fluid/tests/unittests/test_imperative_optimizer.py index a7c39f7ff2ad8e9dedc99bb37fc0f997853da572..1d232ba7f9891618fc9d67341cc5745cd5b3107a 100644 --- a/python/paddle/fluid/tests/unittests/test_imperative_optimizer.py +++ b/python/paddle/fluid/tests/unittests/test_imperative_optimizer.py @@ -83,7 +83,7 @@ class TestImperativeOptimizerBase(unittest.TestCase): img = data[0] label = data[1] - label._stop_gradient = True + label.stop_gradient = True cost = mlp(img) avg_loss = fluid.layers.reduce_mean(cost) diff --git a/python/paddle/fluid/tests/unittests/test_imperative_partitial_backward.py b/python/paddle/fluid/tests/unittests/test_imperative_partitial_backward.py index c6a2ad9e3d5ce79298160bdca2506c989f356ce0..890e088f84197b114a442cfac3c3e6db64da4272 100644 --- a/python/paddle/fluid/tests/unittests/test_imperative_partitial_backward.py +++ b/python/paddle/fluid/tests/unittests/test_imperative_partitial_backward.py @@ -33,10 +33,10 @@ class TestImperativePartitialBackward(unittest.TestCase): loss.backward() for param in fc1.parameters(): - self.assertIsNotNone(param._ivar._grad_ivar()) + self.assertIsNotNone(param._grad_ivar()) for param in fc2.parameters(): - self.assertIsNone(param._ivar._grad_ivar()) + self.assertIsNone(param._grad_ivar()) optimizer = fluid.optimizer.AdamOptimizer() _, params_grads = optimizer.minimize(loss) diff --git a/python/paddle/fluid/tests/unittests/test_imperative_ptb_rnn.py b/python/paddle/fluid/tests/unittests/test_imperative_ptb_rnn.py index 472fae6e49b727721f40f2bf1d26518ab61262c8..2f5db94cf6dec0d1ca628bb9182a7dd278c17c08 100644 --- a/python/paddle/fluid/tests/unittests/test_imperative_ptb_rnn.py +++ b/python/paddle/fluid/tests/unittests/test_imperative_ptb_rnn.py @@ -207,7 +207,6 @@ class PtbModel(fluid.Layer): loss = fluid.layers.reshape(loss, shape=[-1, self.num_steps]) loss = fluid.layers.reduce_mean(loss, dim=[0]) loss = fluid.layers.reduce_sum(loss) - loss.permissions = True return loss, last_hidden, last_cell diff --git a/python/paddle/fluid/tests/unittests/test_imperative_resnet.py b/python/paddle/fluid/tests/unittests/test_imperative_resnet.py index 4c9a65d15d872b4b57b99ac9760ace9c7e3ad38c..8a7a8338ea21ff494903d71d2c349dc4a29e947f 100644 --- a/python/paddle/fluid/tests/unittests/test_imperative_resnet.py +++ b/python/paddle/fluid/tests/unittests/test_imperative_resnet.py @@ -302,7 +302,7 @@ class TestDygraphResnet(unittest.TestCase): dy_grad_value = {} for param in resnet.parameters(): if param.trainable: - np_array = np.array(param._ivar._grad_ivar().value() + np_array = np.array(param._grad_ivar().value() .get_tensor()) dy_grad_value[param.name + core.grad_var_suffix( )] = np_array diff --git a/python/paddle/fluid/tests/unittests/test_imperative_resnet_sorted_gradient.py b/python/paddle/fluid/tests/unittests/test_imperative_resnet_sorted_gradient.py index 74560535074f2550429de79385415162987dab7d..44e147e317cb97db054f37222b8e4a314b3cb0a8 100644 --- a/python/paddle/fluid/tests/unittests/test_imperative_resnet_sorted_gradient.py +++ b/python/paddle/fluid/tests/unittests/test_imperative_resnet_sorted_gradient.py @@ -119,7 +119,7 @@ class TestDygraphResnetSortGradient(unittest.TestCase): dy_grad_value = {} for param in resnet.parameters(): if param.trainable: - np_array = np.array(param._ivar._grad_ivar().value() + np_array = np.array(param._grad_ivar().value() .get_tensor()) dy_grad_value[param.name + core.grad_var_suffix( )] = np_array diff --git a/python/paddle/fluid/tests/unittests/test_imperative_save_load.py b/python/paddle/fluid/tests/unittests/test_imperative_save_load.py index c15e893c7d51cc27c85db06ba5db02baa02e03a6..a68b7735b03a94a64a6166fe1581546d273f232d 100644 --- a/python/paddle/fluid/tests/unittests/test_imperative_save_load.py +++ b/python/paddle/fluid/tests/unittests/test_imperative_save_load.py @@ -197,7 +197,6 @@ class PtbModel(fluid.Layer): loss = fluid.layers.reshape(loss, shape=[-1, self.num_steps]) loss = fluid.layers.reduce_mean(loss, dim=[0]) loss = fluid.layers.reduce_sum(loss) - loss.permissions = True return loss, last_hidden, last_cell @@ -353,7 +352,7 @@ class TestDygraphPtbRnn(unittest.TestCase): # set to zero for k, v in opti_dict.items(): np_t = v.numpy() - var = v._ivar.value().get_tensor() + var = v.value().get_tensor() var.set(np.zeros_like(np_t), place) self.assertTrue(np.sum(np.abs(v.numpy())) == 0) @@ -373,7 +372,7 @@ class TestDygraphPtbRnn(unittest.TestCase): state_dict = ptb_model.state_dict() for k, v in state_dict.items(): np_t = v.numpy() - var = v._ivar.value().get_tensor() + var = v.value().get_tensor() var.set(np.zeros_like(np_t), place) @@ -457,7 +456,7 @@ class TestDygraphPtbRnn(unittest.TestCase): # set to zero for k, v in opti_dict.items(): np_t = v.numpy() - var = v._ivar.value().get_tensor() + var = v.value().get_tensor() var.set(np.zeros_like(np_t), place) self.assertTrue(np.sum(np.abs(v.numpy())) == 0) @@ -476,7 +475,7 @@ class TestDygraphPtbRnn(unittest.TestCase): state_dict = ptb_model.state_dict() for k, v in state_dict.items(): np_t = v.numpy() - var = v._ivar.value().get_tensor() + var = v.value().get_tensor() var.set(np.zeros_like(np_t), place) @@ -562,7 +561,7 @@ class TestDygraphPtbRnn(unittest.TestCase): for k, v in opti_dict.items(): np_t = v.numpy() np_opti_dict[v.name] = np_t - var = v._ivar.value().get_tensor() + var = v.value().get_tensor() var.set(np.zeros_like(np_t), place) self.assertTrue(np.sum(np.abs(v.numpy())) == 0) @@ -583,7 +582,7 @@ class TestDygraphPtbRnn(unittest.TestCase): for k, v in state_dict.items(): np_t = v.numpy() np_state_dict[v.name] = np_t - var = v._ivar.value().get_tensor() + var = v.value().get_tensor() var.set(np.zeros_like(np_t), place) diff --git a/python/paddle/fluid/tests/unittests/test_imperative_se_resnext.py b/python/paddle/fluid/tests/unittests/test_imperative_se_resnext.py index f6585d1b30dacc5a54e38455e8db82980057f1a0..2c11933dcb6352056a49f99166bf03cc47f23427 100644 --- a/python/paddle/fluid/tests/unittests/test_imperative_se_resnext.py +++ b/python/paddle/fluid/tests/unittests/test_imperative_se_resnext.py @@ -361,7 +361,7 @@ class TestImperativeResneXt(unittest.TestCase): #dy_grad_value = {} #for param in se_resnext.parameters(): # if param.trainable: - # np_array = np.array(param._ivar._grad_ivar().value() + # np_array = np.array(param._grad_ivar().value() # .get_tensor()) # dy_grad_value[param.name + core.grad_var_suffix()] = np_array diff --git a/python/paddle/fluid/tests/unittests/test_imperative_selected_rows_to_lod_tensor.py b/python/paddle/fluid/tests/unittests/test_imperative_selected_rows_to_lod_tensor.py index c7054c15017d48650267899005cd3a634f03aa2e..04776a3838904c9268c12788edfac417500cd081 100644 --- a/python/paddle/fluid/tests/unittests/test_imperative_selected_rows_to_lod_tensor.py +++ b/python/paddle/fluid/tests/unittests/test_imperative_selected_rows_to_lod_tensor.py @@ -78,7 +78,6 @@ class SimpleNet(fluid.Layer): loss = fluid.layers.reshape(loss, shape=[-1, self.num_steps]) loss = fluid.layers.reduce_mean(loss, dim=[0]) loss = fluid.layers.reduce_sum(loss) - loss.permissions = True return loss diff --git a/python/paddle/fluid/tests/unittests/test_parameter.py b/python/paddle/fluid/tests/unittests/test_parameter.py index fc7427dcbfd6998598ad95b70d245f6c8c1b28ae..05c19776a37f160a38b13698e61ae1aeec3f8f71 100644 --- a/python/paddle/fluid/tests/unittests/test_parameter.py +++ b/python/paddle/fluid/tests/unittests/test_parameter.py @@ -25,8 +25,8 @@ import numpy as np main_program = default_main_program() -class TestParameter(unittest.TestCase): - def test_param(self): +class ParameterChecks(unittest.TestCase): + def check_param(self): shape = [784, 100] val = 1.0625 b = main_program.global_block() @@ -46,7 +46,7 @@ class TestParameter(unittest.TestCase): p = io.get_parameter_value_by_name('fc.w', exe, main_program) self.assertTrue(np.allclose(np.array(p), np.ones(shape) * val)) - def test_exceptions(self): + def check_exceptions(self): b = main_program.global_block() with self.assertRaises(ValueError): b.create_parameter( @@ -62,5 +62,13 @@ class TestParameter(unittest.TestCase): name='test', shape=[-1], dtype='float32', initializer=None) +class TestParameter(ParameterChecks): + def test_param(self): + self.check_param() + + def test_exceptions(self): + self.check_exceptions() + + if __name__ == '__main__': unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_static_save_load.py b/python/paddle/fluid/tests/unittests/test_static_save_load.py index 5fc2701ee2dccc54624512c5c8777e3f32c647ae..47241c28cd2039693a5e30e9640bb9468c1fd928 100644 --- a/python/paddle/fluid/tests/unittests/test_static_save_load.py +++ b/python/paddle/fluid/tests/unittests/test_static_save_load.py @@ -208,7 +208,6 @@ class PtbModel(fluid.Layer): loss = fluid.layers.reshape(loss, shape=[-1, self.num_steps]) loss = fluid.layers.reduce_mean(loss, dim=[0]) loss = fluid.layers.reduce_sum(loss) - loss.permissions = True return loss, last_hidden, last_cell diff --git a/python/paddle/fluid/tests/unittests/test_variable.py b/python/paddle/fluid/tests/unittests/test_variable.py index 79e73aa8ea5a1d96803c2e12cb62fe872a775250..f4f4749e0c0a562ee91564af68b74554f765b2e7 100644 --- a/python/paddle/fluid/tests/unittests/test_variable.py +++ b/python/paddle/fluid/tests/unittests/test_variable.py @@ -184,6 +184,28 @@ class TestVariable(unittest.TestCase): with fluid.program_guard(default_main_program()): self._tostring() + # NOTE(zhiqiu): for coverage CI + # TODO(zhiqiu): code clean for dygraph + def test_dygraph_deprecated_api(self): + b = default_main_program().current_block() + var = b.create_var(dtype="float64", lod_level=0) + with fluid.dygraph.guard(): + self.assertIsNone(var.detach()) + self.assertIsNone(var.numpy()) + self.assertIsNone(var.set_value(None)) + self.assertIsNone(var.backward()) + self.assertIsNone(var.gradient()) + self.assertIsNone(var.clear_gradient()) + self.assertIsNone(var.to_string(True)) + self.assertIsNone(var.persistable) + var.stop_gradient = True + self.assertIsNone(var.stop_gradient) + var.stop_gradient = 'tmp' + self.assertIsNone(var.name) + self.assertIsNone(var.shape) + self.assertIsNone(var.dtype) + self.assertIsNone(var.type) + if __name__ == '__main__': unittest.main()