From d21074cd68f7de9db44aeeaeb0c4a2a89a1a7b23 Mon Sep 17 00:00:00 2001 From: wanghuancoder Date: Thu, 20 Jan 2022 09:51:52 +0800 Subject: [PATCH] [Eager] Support Eager mode for some testcase (#38783) * Rearranged Eager AutoCodeGen directory structure * Removed USE_OP in Eager AutoCodeGen * Enabled generation for Operators without Grad/Inputs/Outputs * Resolved operators without input * Fixed merge conflicts * Enabled Eager AutoCodeGen for 10+ more operators * Refactored Eager AutoCodeGen with more organized helper objects * Enabled Eager AutoCodeGen for operators with multiple OpBases * Adjusted Eager AutoCodeGen to Enable Passing Output Tensor as Input Argument * Handled Dispensable Inputs/Outputs in Eager AutoCodeGen * Adjusted function generation/call between Python-C API & Dygraph API * Synchronized auto-generated Python-C API with Dygraph Forward Functions * support more eager tensor api * fix merge compile error * fix compile error and fit develop code * support pure CPU * fix some logic error in eager_mode * support _varbase_creator in eager mode * Added safe_initialized interface to EagerTensor for use in processing dispensable inputs * for eager mode * refine * support multiple constructor for eager tensor * add place related code * polish code * specific randint with dtype of int64 * Support pure cpu test * eager logic * refine test in pure cpu * eager logic * eager logic * eager logic, test=develop * skip core.eager when in inference, test=develop * refine, test=develop * refine, test=develop * call RetainGrad after run forward kernel, test=develop * refine, test=develop * support dygraph util, meta, guard test * eager test case * support inference test * refine test and fix initializer failed * modify eagertensor patch method * add eagertensor.clear_grandint, test=develop * refine, test=develop * refine, test=develop * refine, test=develop * support create varbase and fix retain grad error * call monkey_patch_varbase in _test_eager_guard, test=develop * fix windows error * split clear_gradient to clear_gradient and zero_grads, test=develop * refine, test=develop * refine, test=develop * support test_imperative_basic test in eager mode * remove additional log in variable.h * remove additional log in variable.h * remove additional code create in merge * eager * fix some eager logic, test=develop * refine, test=develop * refine, test=develop * refine, test=develop * patch_tensor_method_func, test=develop * refine, test=develop * eager test case, test=develop * refine, test=develop * eager, test=develop * eager, test=develop * eager optimizer, test=develop * eager optimizer, test=develop * eager test_imperative_optimizer_v2, test=develop * eager, test=develop * refine, test=develop * refine, test=develop * eager, test=develop * add resize in share buffer to, test=develop * eager, test=develop * fix _share_buffer_to, test=develop * refine, test=develop * refine, test=develop * support eager for dataloader,test=develop Co-authored-by: jim19930609 Co-authored-by: JiabinYang <360788950@qq.com> --- paddle/fluid/pybind/eager.cc | 89 ++-- paddle/fluid/pybind/eager_functions.cc | 30 ++ paddle/fluid/pybind/eager_method.cc | 56 ++- paddle/fluid/pybind/eager_utils.cc | 52 +++ paddle/fluid/pybind/eager_utils.h | 2 + paddle/fluid/pybind/imperative.cc | 23 + paddle/fluid/pybind/op_function_generator.h | 9 + paddle/fluid/pybind/pybind.cc | 8 +- .../fluid/dataloader/dataloader_iter.py | 20 +- .../fluid/dygraph/varbase_patch_methods.py | 10 +- python/paddle/fluid/initializer.py | 60 +-- 
python/paddle/fluid/optimizer.py | 395 +++++++++++------- python/paddle/fluid/reader.py | 8 +- .../tests/unittests/test_egr_python_api.py | 28 ++ .../tests/unittests/test_imperative_basic.py | 3 +- .../test_imperative_container_layerlist.py | 8 +- .../test_imperative_data_loader_base.py | 29 +- .../test_imperative_data_loader_exception.py | 29 +- .../test_imperative_data_loader_exit_func.py | 29 +- .../test_imperative_data_loader_fds_clear.py | 15 +- .../test_imperative_data_loader_process.py | 15 +- .../tests/unittests/test_imperative_layers.py | 8 +- .../test_imperative_named_members.py | 22 +- .../unittests/test_imperative_optimizer.py | 197 +++++++-- .../unittests/test_imperative_optimizer_v2.py | 232 ++++++++-- python/paddle/nn/layer/rnn.py | 7 + python/paddle/optimizer/lr.py | 8 +- python/paddle/optimizer/optimizer.py | 29 +- 28 files changed, 1074 insertions(+), 347 deletions(-) diff --git a/paddle/fluid/pybind/eager.cc b/paddle/fluid/pybind/eager.cc index 3439f96984d..3a7043809d9 100644 --- a/paddle/fluid/pybind/eager.cc +++ b/paddle/fluid/pybind/eager.cc @@ -26,6 +26,7 @@ limitations under the License. */ #include "paddle/pten/common/data_type.h" #include "paddle/pten/core/convert_utils.h" #include "paddle/pten/core/dense_tensor.h" +#include "pybind11/detail/internals.h" #include "pybind11/numpy.h" #include "pybind11/pybind11.h" #pragma GCC diagnostic ignored "-Wmissing-field-initializers" @@ -48,6 +49,7 @@ PyObject* EagerTensorNew(PyTypeObject* type, PyObject* args, PyObject* kwargs) { if (obj) { auto v = reinterpret_cast(obj); new (&(v->eager_tensor)) egr::EagerTensor(); + Py_INCREF(obj); } return obj; } @@ -726,7 +728,7 @@ int EagerTensorInit(PyObject* self, PyObject* args, PyObject* kwargs) { return 1; } -static void eagertensor_dealloc(EagerTensorObject* self) { +static void EagerTensorDealloc(EagerTensorObject* self) { self->eager_tensor.~EagerTensor(); Py_TYPE(self)->tp_free(reinterpret_cast(self)); } @@ -739,71 +741,44 @@ PyNumberMethods number_methods; PySequenceMethods sequence_methods; PyMappingMethods mapping_methods; -PyTypeObject eager_tensor_type = { - PyVarObject_HEAD_INIT(NULL, 0) "core_avx.eager.EagerTensor", /* tp_name */ - sizeof(EagerTensorObject), /* tp_basicsize */ - 0, /* tp_itemsize */ - (destructor)eagertensor_dealloc, /* tp_dealloc */ - 0, /* tp_vectorcall_offset */ - 0, /* tp_getattr */ - 0, /* tp_setattr */ - 0, /* tp_reserved */ - 0, /* tp_repr */ - &number_methods, /* tp_as_number */ - &sequence_methods, /* tp_as_sequence */ - &mapping_methods, /* tp_as_mapping */ - 0, /* tp_hash */ - 0, /* tp_call */ - 0, /* tp_str */ - 0, /* tp_getattro */ - 0, /* tp_setattro */ - 0, /* tp_as_buffer */ - Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE | - Py_TPFLAGS_HEAPTYPE, /* tp_flags */ - 0, /* tp_doc */ - 0, /* tp_traverse */ - 0, /* tp_clear */ - 0, /* tp_richcompare */ - 0, /* tp_weaklistoffset */ - 0, /* tp_iter */ - 0, /* tp_iternext */ - variable_methods, /* tp_methods */ - 0, /* tp_members */ - variable_properties, /* tp_getset */ - 0, /* tp_base */ - 0, /* tp_dict */ - 0, /* tp_descr_get */ - 0, /* tp_descr_set */ - 0, /* tp_dictoffset */ - EagerTensorInit, /* tp_init */ - 0, /* tp_alloc */ - EagerTensorNew, /* tp_new */ - 0, /* tp_free */ - 0, /* tp_is_gc */ - 0, /* tp_bases */ - 0, /* tp_mro */ - 0, /* tp_cache */ - 0, /* tp_subclasses */ - 0, /* tp_weaklist */ - 0, /* tp_del */ - 0, /* tp_version_tag */ - 0 /* tp_finalize */ -}; - void BindEager(pybind11::module* module) { auto m = module->def_submodule("eager"); - p_eager_tensor_type = &eager_tensor_type; 
- if (PyType_Ready(&eager_tensor_type) < 0) { + auto& internals = pybind11::detail::get_internals(); + auto heap_type = reinterpret_cast( + internals.default_metaclass->tp_alloc(internals.default_metaclass, 0)); + heap_type->ht_name = ToPyObject("EagerTensor"); + heap_type->ht_qualname = ToPyObject("EagerTensor"); + auto type = &heap_type->ht_type; + type->tp_name = "EagerTensor"; + type->tp_basicsize = sizeof(EagerTensorObject); + type->tp_dealloc = (destructor)EagerTensorDealloc; + type->tp_as_number = &number_methods; + type->tp_as_sequence = &sequence_methods; + type->tp_as_mapping = &mapping_methods; + type->tp_methods = variable_methods; + type->tp_getset = variable_properties; + type->tp_init = EagerTensorInit; + type->tp_new = EagerTensorNew; + Py_INCREF(internals.instance_base); + type->tp_base = reinterpret_cast(internals.instance_base); + type->tp_flags |= + Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE | Py_TPFLAGS_HEAPTYPE; +#if PY_VERSION_HEX >= 0x03050000 + type->tp_as_async = &heap_type->as_async; +#endif + p_eager_tensor_type = type; + + if (PyType_Ready(type) < 0) { PADDLE_THROW(platform::errors::Fatal( "Init Paddle error in BindEager(PyType_Ready).")); return; } - Py_INCREF(&eager_tensor_type); + Py_INCREF(type); if (PyModule_AddObject(m.ptr(), "EagerTensor", - reinterpret_cast(&eager_tensor_type)) < 0) { - Py_DECREF(&eager_tensor_type); + reinterpret_cast(type)) < 0) { + Py_DECREF(type); Py_DECREF(m.ptr()); PADDLE_THROW(platform::errors::Fatal( "Init Paddle error in BindEager(PyModule_AddObject).")); diff --git a/paddle/fluid/pybind/eager_functions.cc b/paddle/fluid/pybind/eager_functions.cc index aaf86bc41ae..35f091e1c88 100644 --- a/paddle/fluid/pybind/eager_functions.cc +++ b/paddle/fluid/pybind/eager_functions.cc @@ -41,6 +41,8 @@ namespace pybind { namespace py = ::pybind11; extern PyTypeObject* p_eager_tensor_type; +extern PyTypeObject* g_multidevicefeedreader_pytype; +extern PyTypeObject* g_orderedmultidevicefeedreader_pytype; size_t PyArray_Size_(PyObject* numpy_data) { size_t res = 1; @@ -146,6 +148,31 @@ static PyObject* eager_api_tensor_copy(PyObject* self, PyObject* args, EAGER_CATCH_AND_THROW_RETURN_NULL } +static PyObject* eager_api_read_next_eager_tensor_list(PyObject* self, + PyObject* args, + PyObject* kwargs) { + EAGER_TRY + auto tensor_list = CastPyArg2VectorOfTensor(PyTuple_GET_ITEM(args, 0), 0); + std::vector eager_tensor_list; + eager_tensor_list.reserve(tensor_list.size()); + auto func = [](framework::Tensor& tensor) { + egr::EagerTensor eager_tensor( + egr::Controller::Instance().GenerateUniqueName()); + auto autograd_meta = egr::EagerUtils::autograd_meta(&eager_tensor); + autograd_meta->SetPersistable(false); + autograd_meta->SetStopGradient(true); + auto tmp = std::move(tensor); + eager_tensor.set_impl( + std::move(paddle::experimental::MakePtenDenseTensor(tmp))); + return eager_tensor; + }; + for (auto& tensor : tensor_list) { + eager_tensor_list.emplace_back(func(tensor)); + } + return ToPyObject(eager_tensor_list); + EAGER_CATCH_AND_THROW_RETURN_NULL +} + PyMethodDef variable_functions[] = { {"scale", (PyCFunction)(void (*)(void))eager_api_scale, METH_VARARGS | METH_KEYWORDS, NULL}, @@ -159,6 +186,9 @@ PyMethodDef variable_functions[] = { METH_VARARGS | METH_KEYWORDS, NULL}, {"tensor_copy", (PyCFunction)(void (*)(void))eager_api_tensor_copy, METH_VARARGS | METH_KEYWORDS, NULL}, + {"read_next_eager_tensor_list", + (PyCFunction)(void (*)(void))eager_api_read_next_eager_tensor_list, + METH_VARARGS | METH_KEYWORDS, NULL}, {NULL, NULL, 0, NULL}}; void 
BindFunctions(PyObject* module) { diff --git a/paddle/fluid/pybind/eager_method.cc b/paddle/fluid/pybind/eager_method.cc index 4419640ccf3..b254b5d41d3 100644 --- a/paddle/fluid/pybind/eager_method.cc +++ b/paddle/fluid/pybind/eager_method.cc @@ -237,7 +237,7 @@ static PyObject* eager_tensor__share_buffer_to(EagerTensorObject* self, PyObject* args, PyObject* kwargs) { EAGER_SYNC_TRY - egr::EagerTensor* src_ptr = + egr::EagerTensor* dst_ptr = &(reinterpret_cast(PyTuple_GET_ITEM(args, 0)) ->eager_tensor); PADDLE_ENFORCE_EQ(self->eager_tensor.initialized(), true, @@ -245,7 +245,12 @@ static PyObject* eager_tensor__share_buffer_to(EagerTensorObject* self, "Tensor %s has not been initialized! please initialize " "src tensor before share_buffer_with to other.", self->eager_tensor.name())); - src_ptr->set_impl(self->eager_tensor.impl()); + auto* src_tensor = + static_cast(self->eager_tensor.impl().get()); + auto dst_tensor = + static_cast(dst_ptr->impl().get()); + dst_tensor->ShareDataWith(*src_tensor); + dst_tensor->ShareDataTypeWith(*src_tensor); Py_INCREF(Py_None); return Py_None; EAGER_CATCH_AND_THROW_RETURN_NULL @@ -255,6 +260,47 @@ static PyObject* eager_tensor__is_shared_buffer_with(EagerTensorObject* self, PyObject* args, PyObject* kwargs) { EAGER_SYNC_TRY + egr::EagerTensor* dst_ptr = + &(reinterpret_cast(PyTuple_GET_ITEM(args, 0)) + ->eager_tensor); + PADDLE_ENFORCE_EQ(self->eager_tensor.initialized(), true, + platform::errors::InvalidArgument( + "Tensor %s has not been initialized! please initialize " + "src tensor before share_buffer_with to other.", + self->eager_tensor.name())); + bool res = false; + if (!self->eager_tensor.defined() || !dst_ptr->defined()) { + return ToPyObject(res); + } + auto* self_ptr = + static_cast(self->eager_tensor.impl().get()); + auto dst_tensor = + static_cast(dst_ptr->impl().get()); + res = dst_tensor->IsSharedBufferWith(*self_ptr); + return ToPyObject(res); + EAGER_CATCH_AND_THROW_RETURN_NULL +} + +static PyObject* eager_tensor__share_underline_tensor_to( + EagerTensorObject* self, PyObject* args, PyObject* kwargs) { + EAGER_SYNC_TRY + egr::EagerTensor* src_ptr = + &(reinterpret_cast(PyTuple_GET_ITEM(args, 0)) + ->eager_tensor); + PADDLE_ENFORCE_EQ(self->eager_tensor.initialized(), true, + platform::errors::InvalidArgument( + "Tensor %s has not been initialized! 
please initialize " + "src tensor before share_buffer_with to other.", + self->eager_tensor.name())); + src_ptr->set_impl(self->eager_tensor.impl()); + Py_INCREF(Py_None); + return Py_None; + EAGER_CATCH_AND_THROW_RETURN_NULL +} + +static PyObject* eager_tensor__is_shared_underline_tensor_with( + EagerTensorObject* self, PyObject* args, PyObject* kwargs) { + EAGER_SYNC_TRY egr::EagerTensor src_tensor = CastPyArg2EagerTensor(PyTuple_GET_ITEM(args, 0), 0); PADDLE_ENFORCE_EQ(src_tensor.initialized(), true, @@ -336,6 +382,12 @@ PyMethodDef variable_methods[] = { {"_is_shared_buffer_with", (PyCFunction)(void (*)(void))eager_tensor__is_shared_buffer_with, METH_VARARGS | METH_KEYWORDS, NULL}, + {"_share_underline_tensor_to", + (PyCFunction)(void (*)(void))eager_tensor__share_underline_tensor_to, + METH_VARARGS | METH_KEYWORDS, NULL}, + {"_is_shared_underline_tensor_with", + (PyCFunction)(void (*)(void))eager_tensor__is_shared_underline_tensor_with, + METH_VARARGS | METH_KEYWORDS, NULL}, {"detach", (PyCFunction)(void (*)(void))eager_tensor_method_detach, METH_VARARGS | METH_KEYWORDS, NULL}, {"get_tensor", diff --git a/paddle/fluid/pybind/eager_utils.cc b/paddle/fluid/pybind/eager_utils.cc index 5c74653a719..1e0697246e9 100644 --- a/paddle/fluid/pybind/eager_utils.cc +++ b/paddle/fluid/pybind/eager_utils.cc @@ -40,6 +40,7 @@ extern PyTypeObject* g_xpuplace_pytype; extern PyTypeObject* g_npuplace_pytype; extern PyTypeObject* g_cudapinnedplace_pytype; extern PyTypeObject* g_framework_tensor_pytype; +extern PyTypeObject* g_framework_lodtensorarray_pytype; int TensorDtype2NumpyDtype(pten::DataType dtype) { switch (dtype) { @@ -316,6 +317,57 @@ framework::Tensor CastPyArg2FrameworkTensor(PyObject* obj, ssize_t arg_pos) { } } +std::vector CastPyArg2VectorOfTensor(PyObject* obj, + ssize_t arg_pos) { + std::vector result; + if (PyList_Check(obj)) { + Py_ssize_t len = PyList_Size(obj); + PyObject* item = nullptr; + for (Py_ssize_t i = 0; i < len; i++) { + item = PyList_GetItem(obj, i); + if (PyObject_IsInstance( + item, reinterpret_cast(g_framework_tensor_pytype))) { + result.emplace_back( + ::pybind11::handle(item).cast()); + } else { + PADDLE_THROW(platform::errors::InvalidArgument( + "argument (position %d) must be " + "list of LoDTensor, but got %s at pos %d", + arg_pos + 1, + reinterpret_cast(item->ob_type)->tp_name, i)); + } + } + } else if (PyTuple_Check(obj)) { + Py_ssize_t len = PyTuple_Size(obj); + PyObject* item = nullptr; + for (Py_ssize_t i = 0; i < len; i++) { + item = PyTuple_GetItem(obj, i); + if (PyObject_IsInstance( + item, reinterpret_cast(g_framework_tensor_pytype))) { + result.emplace_back( + ::pybind11::handle(item).cast()); + } else { + PADDLE_THROW(platform::errors::InvalidArgument( + "argument (position %d) must be " + "list of LoDTensor, but got %s at pos %d", + arg_pos + 1, + reinterpret_cast(item->ob_type)->tp_name, i)); + } + } + } else if (PyObject_IsInstance(obj, reinterpret_cast( + g_framework_lodtensorarray_pytype))) { + return ::pybind11::handle(obj).cast(); + } else if (obj == Py_None) { + return {}; + } else { + PADDLE_THROW(platform::errors::InvalidArgument( + "argument (position %d) must be " + "list or tuple, but got %s", + arg_pos + 1, reinterpret_cast(obj->ob_type)->tp_name)); + } + return result; +} + paddle::framework::proto::VarType::Type CastPyArg2ProtoType(PyObject* obj, ssize_t arg_pos) { paddle::framework::proto::VarType::Type dtype; diff --git a/paddle/fluid/pybind/eager_utils.h b/paddle/fluid/pybind/eager_utils.h index e1a7ed24150..8bec57bfefb 100644 --- 
a/paddle/fluid/pybind/eager_utils.h +++ b/paddle/fluid/pybind/eager_utils.h @@ -37,6 +37,8 @@ std::vector CastPyArg2VectorOfEagerTensor(PyObject* obj, ssize_t arg_pos); platform::Place CastPyArg2Place(PyObject* obj, ssize_t arg_pos); framework::Tensor CastPyArg2FrameworkTensor(PyObject* obj, ssize_t arg_pos); +std::vector CastPyArg2VectorOfTensor(PyObject* obj, + ssize_t arg_pos); std::vector CastPyArg2VectorOfInt(PyObject* obj, size_t arg_pos); framework::proto::VarType::Type CastPyArg2ProtoType(PyObject* obj, ssize_t arg_pos); diff --git a/paddle/fluid/pybind/imperative.cc b/paddle/fluid/pybind/imperative.cc index 8957f5c0e7e..a3f0a0c87fd 100644 --- a/paddle/fluid/pybind/imperative.cc +++ b/paddle/fluid/pybind/imperative.cc @@ -2047,6 +2047,29 @@ void BindImperative(py::module *m_ptr) { } return dst_->IsSharedBufferWith(*src); }) + .def("_share_underline_tensor_to", + [](const std::shared_ptr &self, + std::shared_ptr &dst) { + auto *src = self->MutableVar()->GetMutable(); + auto *dst_ = dst->MutableVar()->GetMutable(); + PADDLE_ENFORCE_EQ( + src->IsInitialized(), true, + platform::errors::InvalidArgument( + "Tensor %s has not been initialized!", self->Name())); + dst_->ShareBufferWith(*src); + dst_->ShareDataTypeWith(*src); + dst_->Resize(src->dims()); + }) + .def("_is_shared_underline_tensor_with", + [](const std::shared_ptr &self, + std::shared_ptr &dst) { + auto *src = self->MutableVar()->GetMutable(); + auto *dst_ = dst->MutableVar()->GetMutable(); + if (!src->IsInitialized() || !dst_->IsInitialized()) { + return false; + } + return dst_->IsSharedBufferWith(*src); + }) .def("_slice", [](const std::shared_ptr &self, int64_t begin_idx, int64_t end_idx) { diff --git a/paddle/fluid/pybind/op_function_generator.h b/paddle/fluid/pybind/op_function_generator.h index f83997843f4..d916efe605a 100644 --- a/paddle/fluid/pybind/op_function_generator.h +++ b/paddle/fluid/pybind/op_function_generator.h @@ -151,6 +151,15 @@ std::map> op_outs_map = { // For those OPs, we need to manually specify the outs need to pass in this map. std::map> op_passing_outs_map = { {"sgd", {"ParamOut", "MasterParamOut"}}, + {"rmsprop", {"ParamOut", "MomentOut", "MeanSquareOut", "MeanGradOut"}}, + {"ftrl", {"ParamOut", "SquaredAccumOut", "LinearAccumOut"}}, + {"adadelta", {"ParamOut", "AvgSquaredGradOut", "AvgSquaredUpdateOut"}}, + {"adagrad", {"ParamOut", "MomentOut"}}, + {"adamax", {"ParamOut", "MomentOut", "InfNormOut"}}, + {"dpsgd", {"ParamOut"}}, + {"decayed_adagrad", {"ParamOut", "MomentOut"}}, + {"lars_momentum", {"ParamOut", "VelocityOut"}}, + {"coalesce_tensor", {"Output", "FusedOutput"}}, {"adam", {"ParamOut", "Moment1Out", "Moment2Out", "Beta1PowOut", "Beta2PowOut", "MasterParamOut"}}, diff --git a/paddle/fluid/pybind/pybind.cc b/paddle/fluid/pybind/pybind.cc index 47f97944b2d..176db6b48c5 100644 --- a/paddle/fluid/pybind/pybind.cc +++ b/paddle/fluid/pybind/pybind.cc @@ -170,6 +170,7 @@ PyTypeObject *g_npuplace_pytype = nullptr; PyTypeObject *g_cudapinnedplace_pytype = nullptr; PyTypeObject *g_mluplace_pytype = nullptr; PyTypeObject *g_framework_tensor_pytype = nullptr; +PyTypeObject *g_framework_lodtensorarray_pytype = nullptr; bool IsCompiledWithCUDA() { #if !defined(PADDLE_WITH_CUDA) && !defined(PADDLE_WITH_HIP) @@ -2420,7 +2421,7 @@ All parameter, weight, gradient are variables in Paddle. return res; }); - py::class_(m, "LoDTensorArray", R"DOC( + py::class_ pylodtensorarray(m, "LoDTensorArray", R"DOC( LoDTensorArray is array of LoDTensor, it supports operator[], len() and for-loop iteration. 
Examples: @@ -2429,7 +2430,10 @@ All parameter, weight, gradient are variables in Paddle. import paddle.fluid as fluid arr = fluid.LoDTensorArray() -)DOC") +)DOC"); + g_framework_lodtensorarray_pytype = + reinterpret_cast(pylodtensorarray.ptr()); + pylodtensorarray .def("__init__", [](LoDTensorArray &instance) { new (&instance) LoDTensorArray(); }) .def("__getitem__", diff --git a/python/paddle/fluid/dataloader/dataloader_iter.py b/python/paddle/fluid/dataloader/dataloader_iter.py index a3e6ea6d1bc..f4ccd033aa5 100644 --- a/python/paddle/fluid/dataloader/dataloader_iter.py +++ b/python/paddle/fluid/dataloader/dataloader_iter.py @@ -31,7 +31,7 @@ import queue import paddle from .. import core, layers -from ..framework import in_dygraph_mode +from ..framework import in_dygraph_mode, _in_eager_mode from ..multiprocess_utils import _set_SIGCHLD_handler, MP_STATUS_CHECK_INTERVAL, CleanupFuncRegistrar from .fetcher import _IterableDatasetFetcher, _MapDatasetFetcher from .batch_sampler import _InfiniteIterableSampler @@ -252,7 +252,11 @@ class _DataLoaderIterSingleProcess(_DataLoaderIterBase): def __next__(self): try: if in_dygraph_mode(): - data = self._reader.read_next_var_list() + if _in_eager_mode(): + data = core.eager.read_next_eager_tensor_list( + self._reader.read_next_list()[0]) + else: + data = self._reader.read_next_var_list() data = _restore_batch(data, self._structure_infos.pop(0)) else: if self._return_list: @@ -444,7 +448,11 @@ class _DataLoaderIterMultiProcess(_DataLoaderIterBase): # the blocking_queue cachees instead of recreating one while self._blocking_queue.size() >= len(self._places): if in_dygraph_mode(): - self._reader.read_next_var_list() + if _in_eager_mode(): + data = core.eager.read_next_eager_tensor_list( + self._reader.read_next_list()[0]) + else: + self._reader.read_next_var_list() elif self._return_list: self._reader.read_next_list() else: @@ -696,7 +704,11 @@ class _DataLoaderIterMultiProcess(_DataLoaderIterBase): self._blocking_queue.close() if in_dygraph_mode(): - data = self._reader.read_next_var_list() + if _in_eager_mode(): + data = core.eager.read_next_eager_tensor_list( + self._reader.read_next_list()[0]) + else: + data = self._reader.read_next_var_list() data = _restore_batch(data, self._structure_infos.pop(0)) else: if self._return_list: diff --git a/python/paddle/fluid/dygraph/varbase_patch_methods.py b/python/paddle/fluid/dygraph/varbase_patch_methods.py index 3cccaceb8e6..8fc6bd818bc 100644 --- a/python/paddle/fluid/dygraph/varbase_patch_methods.py +++ b/python/paddle/fluid/dygraph/varbase_patch_methods.py @@ -635,9 +635,13 @@ def monkey_patch_varbase(): def __nonzero__(self): numel = np.prod(self.shape) assert numel == 1, "When Variable is used as the condition of if/while , Variable can only contain one element." 
- tensor = self.value().get_tensor() - assert tensor._is_initialized(), "tensor not initialized" - return bool(np.all(tensor.__array__() > 0)) + if core._in_eager_mode(): + assert self._is_initialized(), "tensor not initialized" + return bool(np.all(self.numpy() > 0)) + else: + tensor = self.value().get_tensor() + assert tensor._is_initialized(), "tensor not initialized" + return bool(np.all(tensor.__array__() > 0)) def __bool__(self): return self.__nonzero__() diff --git a/python/paddle/fluid/initializer.py b/python/paddle/fluid/initializer.py index 6ef3646a919..ea17d029b6c 100644 --- a/python/paddle/fluid/initializer.py +++ b/python/paddle/fluid/initializer.py @@ -252,9 +252,9 @@ class UniformInitializer(Initializer): if var.dtype == VarDesc.VarType.FP16: var_tmp = _C_ops.cast(out_var, 'in_dtype', out_var.dtype, 'out_dtype', var.dtype) - var.copy_(var_tmp, False) + var_tmp._share_underline_tensor_to(var) else: - var.copy_(out_var, False) + out_var._share_underline_tensor_to(var) return None else: op = block.append_op( @@ -334,24 +334,28 @@ class NormalInitializer(Initializer): if self._seed == 0: self._seed = block.program.random_seed - op = block.append_op( - type="gaussian_random", - outputs={"Out": var}, - attrs={ - "shape": var.shape, - "dtype": var.dtype, - "mean": self._mean, - "std": self._std_dev, - "seed": self._seed, - "use_mkldnn": False - }, - stop_gradient=True) - - if not framework.in_dygraph_mode(): + if framework.in_dygraph_mode(): + out_var = _C_ops.gaussian_random( + 'shape', var.shape, 'dtype', var.dtype, 'mean', self._mean, + 'std', self._std_dev, 'seed', self._seed, 'use_mkldnn', False) + out_var._share_underline_tensor_to(var) + return None + else: + op = block.append_op( + type="gaussian_random", + outputs={"Out": var}, + attrs={ + "shape": var.shape, + "dtype": var.dtype, + "mean": self._mean, + "std": self._std_dev, + "seed": self._seed, + "use_mkldnn": False + }, + stop_gradient=True) + var.op = op return op - else: - return None class TruncatedNormalInitializer(Initializer): @@ -420,9 +424,9 @@ class TruncatedNormalInitializer(Initializer): if var.dtype in [VarDesc.VarType.FP16, VarDesc.VarType.BF16]: var_tmp = _C_ops.cast(out_var, 'in_dtype', out_var.dtype, 'out_dtype', var.dtype) - var.copy_(var_tmp, False) + var_tmp._share_underline_tensor_to(var) else: - var.copy_(out_var, False) + out_var._share_underline_tensor_to(var) return None else: op = block.append_op( @@ -560,9 +564,9 @@ class XavierInitializer(Initializer): var.dtype == VarDesc.VarType.BF16 and not self._uniform): var_tmp = _C_ops.cast(out_var, 'in_dtype', out_var.dtype, 'out_dtype', var.dtype) - var.copy_(var_tmp, False) + var_tmp._share_underline_tensor_to(var) else: - var.copy_(out_var, False) + out_var._share_underline_tensor_to(var) return None else: if self._uniform: @@ -713,9 +717,9 @@ class MSRAInitializer(Initializer): var.dtype == VarDesc.VarType.BF16 and not self._uniform): var_tmp = _C_ops.cast(out_var, 'in_dtype', out_var.dtype, 'out_dtype', var.dtype) - var.copy_(var_tmp, False) + var_tmp._share_underline_tensor_to(var) else: - var.copy_(out_var, False) + out_var._share_underline_tensor_to(var) return None else: if self._uniform: @@ -881,9 +885,9 @@ class BilinearInitializer(Initializer): ]: var_tmp = _C_ops.cast(out_var, 'in_dtype', out_var.dtype, 'out_dtype', var.dtype) - var.copy_(var_tmp, False) + var_tmp._share_underline_tensor_to(var) else: - var.copy_(out_var, False) + out_var._share_underline_tensor_to(var) return None else: op = block.append_op( @@ -987,9 +991,9 @@ class 
NumpyArrayInitializer(Initializer): if var.dtype in [VarDesc.VarType.FP16, VarDesc.VarType.BF16]: var_tmp = _C_ops.cast(out_var, 'in_dtype', out_var.dtype, 'out_dtype', var.dtype) - var.copy_(var_tmp, False) + var_tmp._share_underline_tensor_to(var) else: - var.copy_(out_var, False) + out_var._share_underline_tensor_to(var) return None else: op = block.append_op( diff --git a/python/paddle/fluid/optimizer.py b/python/paddle/fluid/optimizer.py index ee01036ca93..42b1ea50ad0 100755 --- a/python/paddle/fluid/optimizer.py +++ b/python/paddle/fluid/optimizer.py @@ -439,16 +439,23 @@ class Optimizer(object): self._learning_rate = value current_lr = self._global_learning_rate() if current_lr is not None: - global_block = framework.default_main_program().global_block() - global_block.append_op( - type='fill_constant', - outputs={'Out': [current_lr]}, - attrs={ - 'dtype': current_lr.dtype, - 'shape': list(current_lr.shape), - 'value': float(value) - }, - stop_gradient=True) + if framework.in_dygraph_mode(): + _C_ops.fill_constant(current_lr, 'value', + float(value), 'dtype', + current_lr.dtype, 'shape', + list(current_lr.shape)) + else: + global_block = framework.default_main_program( + ).global_block() + global_block.append_op( + type='fill_constant', + outputs={'Out': [current_lr]}, + attrs={ + 'dtype': current_lr.dtype, + 'shape': list(current_lr.shape), + 'value': float(value) + }, + stop_gradient=True) else: assert len(value.shape) == 1 and value.shape[ 0] == 1, "optimizer's learning rate must be 1-D Tensor with shape[1]" @@ -606,7 +613,9 @@ class Optimizer(object): name=var_name, persistable=True, dtype=dtype or param.dtype, - type=param.type if type is None else type, + type=core.VarDesc.VarType.LOD_TENSOR + if framework._in_eager_mode() else (param.type + if type is None else type), shape=shape, belong_to_optimizer=True) if device is None: @@ -2146,15 +2155,34 @@ class LarsMomentumOptimizer(Optimizer): inputs["MasterParam"] = master_weight outputs["MasterParamOut"] = master_weight - # create the momentum optimize op - momentum_op = block.append_op( - type=self.type if _lars_weight_decay != 0.0 else 'momentum', - inputs=inputs, - outputs=outputs, - attrs=attrs, - stop_gradient=True) + if framework.in_dygraph_mode(): + if _lars_weight_decay != 0.0: + tmp, tmp2 = _C_ops.lars_momentum( + [param_and_grad[0]], [param_and_grad[1]], [velocity_acc], + [lr], [param_and_grad[0]], [velocity_acc], "mu", + self._momentum, "lars_coeff", self._lars_coeff, + "lars_weight_decay", [_lars_weight_decay], + "multi_precision", find_master, "epsilon", self._epsilon, + "rescale_grad", self._rescale_grad) + else: + _C_ops.momentum(param_and_grad[0], param_and_grad[1], + velocity_acc, lr, master_weight, + param_and_grad[0], velocity_acc, master_weight, + "mu", self._momentum, "lars_coeff", + self._lars_coeff, "lars_weight_decay", + [_lars_weight_decay], "multi_precision", + find_master, "epsilon", self._epsilon, + "rescale_grad", self._rescale_grad) + else: + # create the momentum optimize op + momentum_op = block.append_op( + type=self.type if _lars_weight_decay != 0.0 else 'momentum', + inputs=inputs, + outputs=outputs, + attrs=attrs, + stop_gradient=True) - return momentum_op + return momentum_op class AdagradOptimizer(Optimizer): @@ -2256,21 +2284,29 @@ class AdagradOptimizer(Optimizer): moment_acc = self._get_accumulator(self._moment_acc_str, param_and_grad[0]) - # Create the adagrad optimizer op - adagrad_op = block.append_op( - type=self.type, - inputs={ - "Param": param_and_grad[0], - "Grad": 
param_and_grad[1], - "Moment": moment_acc, - "LearningRate": self._create_param_lr(param_and_grad) - }, - outputs={"ParamOut": param_and_grad[0], - "MomentOut": moment_acc}, - attrs={"epsilon": self._epsilon}, - stop_gradient=True) + if framework.in_dygraph_mode(): + _C_ops.adagrad(param_and_grad[0], param_and_grad[1], moment_acc, + self._create_param_lr(param_and_grad), + param_and_grad[0], moment_acc, "epsilon", + self._epsilon) + else: + # Create the adagrad optimizer op + adagrad_op = block.append_op( + type=self.type, + inputs={ + "Param": param_and_grad[0], + "Grad": param_and_grad[1], + "Moment": moment_acc, + "LearningRate": self._create_param_lr(param_and_grad) + }, + outputs={ + "ParamOut": param_and_grad[0], + "MomentOut": moment_acc + }, + attrs={"epsilon": self._epsilon}, + stop_gradient=True) - return adagrad_op + return adagrad_op class AdamOptimizer(Optimizer): @@ -2774,30 +2810,37 @@ class AdamaxOptimizer(Optimizer): param_and_grad[0]) beta1_pow_acc = self._get_accumulator(self._beta1_pow_acc_str, param_and_grad[0]) - # create the adamax optimize op - adamax_op = block.append_op( - type=self.type, - inputs={ - "Param": param_and_grad[0], - "Grad": param_and_grad[1], - "LearningRate": self._create_param_lr(param_and_grad), - "Moment": moment, - "InfNorm": inf_norm, - "Beta1Pow": beta1_pow_acc - }, - outputs={ - "ParamOut": param_and_grad[0], - "MomentOut": moment, - "InfNormOut": inf_norm - }, - attrs={ - "beta1": self._beta1, - "beta2": self._beta2, - "epsilon": self._epsilon - }, - stop_gradient=True) + if framework.in_dygraph_mode(): + _C_ops.adamax(param_and_grad[0], param_and_grad[1], + self._create_param_lr(param_and_grad), moment, + inf_norm, beta1_pow_acc, param_and_grad[0], moment, + inf_norm, "beta1", self._beta1, "beta2", self._beta2, + "epsilon", self._epsilon) + else: + # create the adamax optimize op + adamax_op = block.append_op( + type=self.type, + inputs={ + "Param": param_and_grad[0], + "Grad": param_and_grad[1], + "LearningRate": self._create_param_lr(param_and_grad), + "Moment": moment, + "InfNorm": inf_norm, + "Beta1Pow": beta1_pow_acc + }, + outputs={ + "ParamOut": param_and_grad[0], + "MomentOut": moment, + "InfNormOut": inf_norm + }, + attrs={ + "beta1": self._beta1, + "beta2": self._beta2, + "epsilon": self._epsilon + }, + stop_gradient=True) - return adamax_op + return adamax_op def _finish_update(self, block, parameters_and_grads): """Update Beta1 Power accumulator @@ -2810,12 +2853,16 @@ class AdamaxOptimizer(Optimizer): [param, grad]), name_scope('adamx'): beta1_pow_acc = self._get_accumulator(self._beta1_pow_acc_str, param) - block.append_op( - type="scale", - inputs={"X": beta1_pow_acc}, - outputs={"Out": beta1_pow_acc}, - attrs={"scale": self._beta1}, - stop_gradient=True) + if framework.in_dygraph_mode(): + tmp = _C_ops.scale(beta1_pow_acc, "scale", self._beta1) + beta1_pow_acc.copy_(tmp, False) + else: + block.append_op( + type="scale", + inputs={"X": beta1_pow_acc}, + outputs={"Out": beta1_pow_acc}, + attrs={"scale": self._beta1}, + stop_gradient=True) class DpsgdOptimizer(Optimizer): @@ -2894,23 +2941,30 @@ class DpsgdOptimizer(Optimizer): if self._seed == None: self._seed = 0 - dpsgd_op = block.append_op( - type=self.type, - inputs={ - "Param": param_and_grad[0], - "Grad": param_and_grad[1], - "LearningRate": self._create_param_lr(param_and_grad) - }, - outputs={"ParamOut": param_and_grad[0]}, - attrs={ - "clip": self._clip, - "batch_size": self._batch_size, - "sigma": self._sigma, - "seed": self._seed - }, - stop_gradient=True) + if 
framework.in_dygraph_mode(): + _C_ops.dpsgd(param_and_grad[0], param_and_grad[1], + self._create_param_lr(param_and_grad), + param_and_grad[0], "clip", self._clip, "batch_size", + self._batch_size, "sigma", self._sigma, "seed", + self._seed) + else: + dpsgd_op = block.append_op( + type=self.type, + inputs={ + "Param": param_and_grad[0], + "Grad": param_and_grad[1], + "LearningRate": self._create_param_lr(param_and_grad) + }, + outputs={"ParamOut": param_and_grad[0]}, + attrs={ + "clip": self._clip, + "batch_size": self._batch_size, + "sigma": self._sigma, + "seed": self._seed + }, + stop_gradient=True) - return dpsgd_op + return dpsgd_op class DecayedAdagradOptimizer(Optimizer): @@ -3005,22 +3059,30 @@ class DecayedAdagradOptimizer(Optimizer): moment_acc = self._get_accumulator(self._moment_acc_str, param_and_grad[0]) - # Create the decayed adagrad optimizer op - decayed_adagrad_op = block.append_op( - type=self.type, - inputs={ - "Param": param_and_grad[0], - "Grad": param_and_grad[1], - "Moment": moment_acc, - "LearningRate": self._create_param_lr(param_and_grad) - }, - outputs={"ParamOut": param_and_grad[0], - "MomentOut": moment_acc}, - attrs={"epsilon": self._epsilon, - "decay": self._decay}, - stop_gradient=True) + if framework.in_dygraph_mode(): + _C_ops.decayed_adagrad( + param_and_grad[0], param_and_grad[1], moment_acc, + self._create_param_lr(param_and_grad), param_and_grad[0], + moment_acc, "epsilon", self._epsilon, "decay", self._decay) + else: + # Create the decayed adagrad optimizer op + decayed_adagrad_op = block.append_op( + type=self.type, + inputs={ + "Param": param_and_grad[0], + "Grad": param_and_grad[1], + "Moment": moment_acc, + "LearningRate": self._create_param_lr(param_and_grad) + }, + outputs={ + "ParamOut": param_and_grad[0], + "MomentOut": moment_acc + }, + attrs={"epsilon": self._epsilon, + "decay": self._decay}, + stop_gradient=True) - return decayed_adagrad_op + return decayed_adagrad_op class AdadeltaOptimizer(Optimizer): @@ -3121,25 +3183,32 @@ class AdadeltaOptimizer(Optimizer): avg_squared_update_acc = self._get_accumulator( self._avg_squared_update_acc_str, param_and_grad[0]) - # Create the adadelta optimizer op - adadelta_op = block.append_op( - type=self.type, - inputs={ - "Param": param_and_grad[0], - "Grad": param_and_grad[1], - "AvgSquaredGrad": avg_squared_grad_acc, - "AvgSquaredUpdate": avg_squared_update_acc - }, - outputs={ - "ParamOut": param_and_grad[0], - "AvgSquaredGradOut": avg_squared_grad_acc, - "AvgSquaredUpdateOut": avg_squared_update_acc - }, - attrs={"epsilon": self._epsilon, - "rho": self._rho}, - stop_gradient=True) + if framework.in_dygraph_mode(): + _C_ops.adadelta(param_and_grad[0], param_and_grad[1], + avg_squared_grad_acc, avg_squared_update_acc, + param_and_grad[0], avg_squared_grad_acc, + avg_squared_update_acc, "epsilon", self._epsilon, + "rho", self._rho) + else: + # Create the adadelta optimizer op + adadelta_op = block.append_op( + type=self.type, + inputs={ + "Param": param_and_grad[0], + "Grad": param_and_grad[1], + "AvgSquaredGrad": avg_squared_grad_acc, + "AvgSquaredUpdate": avg_squared_update_acc + }, + outputs={ + "ParamOut": param_and_grad[0], + "AvgSquaredGradOut": avg_squared_grad_acc, + "AvgSquaredUpdateOut": avg_squared_update_acc + }, + attrs={"epsilon": self._epsilon, + "rho": self._rho}, + stop_gradient=True) - return adadelta_op + return adadelta_op class RMSPropOptimizer(Optimizer): @@ -3303,31 +3372,39 @@ class RMSPropOptimizer(Optimizer): param_and_grad[0]) mean_grad_acc = 
self._get_accumulator(self._mean_grad_acc_str, param_and_grad[0]) - rmsprop_op = block.append_op( - type=self.type, - inputs={ - "Param": param_and_grad[0], - "Grad": param_and_grad[1], - "Moment": momentum_acc, - "MeanSquare": mean_square_acc, - "MeanGrad": mean_grad_acc, - "LearningRate": self._create_param_lr(param_and_grad), - }, - outputs={ - "ParamOut": param_and_grad[0], - "MomentOut": momentum_acc, - "MeanSquareOut": mean_square_acc, - "MeanGradOut": mean_grad_acc - }, - attrs={ - "epsilon": self._epsilon, - "decay": self._rho, - "momentum": self._momentum, - "centered": self._centered - }, - stop_gradient=True) + if framework.in_dygraph_mode(): + _C_ops.rmsprop( + param_and_grad[0], mean_square_acc, + self._create_param_lr(param_and_grad), param_and_grad[1], + momentum_acc, param_and_grad[0], momentum_acc, mean_square_acc, + mean_grad_acc, "epsilon", self._epsilon, "decay", self._rho, + "momentum", self._momentum, "centered", self._centered) + else: + rmsprop_op = block.append_op( + type=self.type, + inputs={ + "Param": param_and_grad[0], + "Grad": param_and_grad[1], + "Moment": momentum_acc, + "MeanSquare": mean_square_acc, + "MeanGrad": mean_grad_acc, + "LearningRate": self._create_param_lr(param_and_grad), + }, + outputs={ + "ParamOut": param_and_grad[0], + "MomentOut": momentum_acc, + "MeanSquareOut": mean_square_acc, + "MeanGradOut": mean_grad_acc + }, + attrs={ + "epsilon": self._epsilon, + "decay": self._rho, + "momentum": self._momentum, + "centered": self._centered + }, + stop_gradient=True) - return rmsprop_op + return rmsprop_op class FtrlOptimizer(Optimizer): @@ -3467,26 +3544,36 @@ class FtrlOptimizer(Optimizer): param_and_grad[0]) linear_acc = self._get_accumulator(self._linear_acc_str, param_and_grad[0]) - ftrl_op = block.append_op( - type=self.type, - inputs={ - "Param": param_and_grad[0], - "Grad": param_and_grad[1], - "SquaredAccumulator": squared_acc, - "LinearAccumulator": linear_acc, - "LearningRate": self._create_param_lr(param_and_grad), - }, - outputs={ - "ParamOut": param_and_grad[0], - "SquaredAccumOut": squared_acc, - "LinearAccumOut": linear_acc - }, - attrs={"l1": self._l1, - "l2": self._l2, - "lr_power": self._lr_power}, - stop_gradient=True) + if framework.in_dygraph_mode(): + _C_ops.ftrl(param_and_grad[0], squared_acc, linear_acc, + param_and_grad[1], + self._create_param_lr(param_and_grad), + param_and_grad[0], squared_acc, linear_acc, "l1", + self._l1, "l2", self._l2, "lr_power", self._lr_power) + + else: + ftrl_op = block.append_op( + type=self.type, + inputs={ + "Param": param_and_grad[0], + "Grad": param_and_grad[1], + "SquaredAccumulator": squared_acc, + "LinearAccumulator": linear_acc, + "LearningRate": self._create_param_lr(param_and_grad), + }, + outputs={ + "ParamOut": param_and_grad[0], + "SquaredAccumOut": squared_acc, + "LinearAccumOut": linear_acc + }, + attrs={ + "l1": self._l1, + "l2": self._l2, + "lr_power": self._lr_power + }, + stop_gradient=True) - return ftrl_op + return ftrl_op class LambOptimizer(AdamOptimizer): diff --git a/python/paddle/fluid/reader.py b/python/paddle/fluid/reader.py index 83ccd1051bb..dde39b2dfdb 100644 --- a/python/paddle/fluid/reader.py +++ b/python/paddle/fluid/reader.py @@ -18,7 +18,7 @@ import six import numpy as np import threading import paddle -from .framework import Program, Variable, program_guard, default_main_program, default_startup_program, in_dygraph_mode, cpu_places, _current_expected_place +from .framework import Program, Variable, program_guard, default_main_program, 
default_startup_program, in_dygraph_mode, cpu_places, _current_expected_place, _in_eager_mode from .executor import global_scope from .data_feeder import DataFeeder, BatchedTensorProvider from .multiprocess_utils import multiprocess_queue_set, CleanupFuncRegistrar, _cleanup_mmap, _cleanup, _set_SIGCHLD_handler @@ -971,7 +971,11 @@ class DygraphGeneratorLoader(DataLoaderBase): def __next__(self): try: - return self._reader.read_next_var_list() + if _in_eager_mode(): + return core.eager.read_next_eager_tensor_list( + self._reader.read_next_list()[0]) + else: + return self._reader.read_next_var_list() except StopIteration: self._reset() six.reraise(*sys.exc_info()) diff --git a/python/paddle/fluid/tests/unittests/test_egr_python_api.py b/python/paddle/fluid/tests/unittests/test_egr_python_api.py index e84c11e8601..ba0421d6eb3 100644 --- a/python/paddle/fluid/tests/unittests/test_egr_python_api.py +++ b/python/paddle/fluid/tests/unittests/test_egr_python_api.py @@ -673,6 +673,34 @@ class EagerTensorPropertiesAndMethodsTestCase(unittest.TestCase): self.assertTrue(np.array_equal(tensor3.numpy(), arr2)) self.assertTrue(tensor3._is_shared_buffer_with(tensor)) + def test_share_underline_tensor_to(self): + with _test_eager_guard(): + arr = np.ones([4, 16, 16, 32]).astype('float32') + arr1 = np.zeros([4, 16]).astype('float32') + arr2 = np.ones([4, 16, 16, 32]).astype('float32') + np.ones( + [4, 16, 16, 32]).astype('float32') + tensor = None + tensor2 = None + tensor = paddle.to_tensor(arr, core.VarDesc.VarType.FP32, + core.CPUPlace()) + tensor3 = core.eager.EagerTensor() + if core.is_compiled_with_cuda(): + tensor2 = paddle.to_tensor(arr2, core.VarDesc.VarType.FP32, + core.CUDAPlace(0)) + else: + tensor2 = paddle.to_tensor(arr2, core.VarDesc.VarType.FP32, + core.CPUPlace()) + self.assertTrue(np.array_equal(tensor.numpy(), arr)) + self.assertTrue(np.array_equal(tensor2.numpy(), arr2)) + tensor2._share_underline_tensor_to(tensor) + self.assertTrue(np.array_equal(tensor.numpy(), arr2)) + self.assertTrue(np.array_equal(tensor2.numpy(), arr2)) + self.assertTrue(tensor._is_shared_underline_tensor_with(tensor2)) + self.assertTrue(tensor2._is_shared_underline_tensor_with(tensor)) + tensor._share_underline_tensor_to(tensor3) + self.assertTrue(np.array_equal(tensor3.numpy(), arr2)) + self.assertTrue(tensor3._is_shared_underline_tensor_with(tensor)) + def test_properties(self): print("Test_properties") with _test_eager_guard(): diff --git a/python/paddle/fluid/tests/unittests/test_imperative_basic.py b/python/paddle/fluid/tests/unittests/test_imperative_basic.py index 07a8ae0ba0f..92d3dd7b605 100644 --- a/python/paddle/fluid/tests/unittests/test_imperative_basic.py +++ b/python/paddle/fluid/tests/unittests/test_imperative_basic.py @@ -953,7 +953,8 @@ class TestMetaclass(unittest.TestCase): self.assertNotEqual(type(MyLayer).__name__, 'pybind11_type') if core._in_eager_mode(): self.assertEqual( - type(paddle.fluid.core.eager.EagerTensor).__name__, 'type') + type(paddle.fluid.core.eager.EagerTensor).__name__, + 'pybind11_type') else: self.assertEqual( type(paddle.fluid.core.VarBase).__name__, 'pybind11_type') diff --git a/python/paddle/fluid/tests/unittests/test_imperative_container_layerlist.py b/python/paddle/fluid/tests/unittests/test_imperative_container_layerlist.py index 2e722b69c3e..cf7fc9ba96b 100644 --- a/python/paddle/fluid/tests/unittests/test_imperative_container_layerlist.py +++ b/python/paddle/fluid/tests/unittests/test_imperative_container_layerlist.py @@ -18,6 +18,7 @@ import unittest import 
paddle.fluid as fluid import numpy as np import paddle +from paddle.fluid.framework import _test_eager_guard class MyLayer(fluid.Layer): @@ -96,10 +97,15 @@ class TestImperativeContainer(unittest.TestCase): self.assertListEqual(res11.shape, [5, 4]) res11.backward() - def test_layer_list(self): + def func_test_layer_list(self): self.layer_list(True) self.layer_list(False) + def test_layer_list(self): + with _test_eager_guard(): + self.func_test_layer_list() + self.func_test_layer_list() + if __name__ == '__main__': unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_imperative_data_loader_base.py b/python/paddle/fluid/tests/unittests/test_imperative_data_loader_base.py index 4c9061dd834..6f0876dcfc3 100644 --- a/python/paddle/fluid/tests/unittests/test_imperative_data_loader_base.py +++ b/python/paddle/fluid/tests/unittests/test_imperative_data_loader_base.py @@ -18,6 +18,7 @@ import numpy as np import paddle.fluid as fluid from paddle.fluid import core from paddle.fluid.reader import use_pinned_memory +from paddle.fluid.framework import _test_eager_guard def get_random_images_and_labels(image_shape, label_shape): @@ -50,7 +51,7 @@ class TestDygraphDataLoader(unittest.TestCase): self.assertEqual(label.shape, [self.batch_size, 1]) self.assertEqual(relu.shape, [self.batch_size, 784]) - def test_single_process_loader(self): + def func_test_single_process_loader(self): with fluid.dygraph.guard(): loader = fluid.io.DataLoader.from_generator( capacity=self.capacity, iterable=False, use_multiprocess=False) @@ -60,7 +61,12 @@ class TestDygraphDataLoader(unittest.TestCase): places=fluid.CPUPlace()) self.iter_loader_data(loader) - def test_multi_process_loader(self): + def test_single_process_loader(self): + with _test_eager_guard(): + self.func_test_single_process_loader() + self.func_test_single_process_loader() + + def func_test_multi_process_loader(self): with fluid.dygraph.guard(): loader = fluid.io.DataLoader.from_generator( capacity=self.capacity, use_multiprocess=True) @@ -70,7 +76,12 @@ class TestDygraphDataLoader(unittest.TestCase): places=fluid.CPUPlace()) self.iter_loader_data(loader) - def test_generator_no_places(self): + def test_multi_process_loader(self): + with _test_eager_guard(): + self.func_test_multi_process_loader() + self.func_test_multi_process_loader() + + def func_test_generator_no_places(self): with fluid.dygraph.guard(): loader = fluid.io.DataLoader.from_generator(capacity=self.capacity) loader.set_sample_generator( @@ -78,7 +89,12 @@ class TestDygraphDataLoader(unittest.TestCase): batch_size=self.batch_size) self.iter_loader_data(loader) - def test_set_pin_memory(self): + def test_generator_no_places(self): + with _test_eager_guard(): + self.func_test_generator_no_places() + self.func_test_generator_no_places() + + def func_test_set_pin_memory(self): with fluid.dygraph.guard(): use_pinned_memory(False) loader = fluid.io.DataLoader.from_generator( @@ -90,6 +106,11 @@ class TestDygraphDataLoader(unittest.TestCase): self.iter_loader_data(loader) use_pinned_memory(True) + def test_set_pin_memory(self): + with _test_eager_guard(): + self.func_test_set_pin_memory() + self.func_test_set_pin_memory() + if __name__ == '__main__': unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_imperative_data_loader_exception.py b/python/paddle/fluid/tests/unittests/test_imperative_data_loader_exception.py index d317d943dca..4ab58919fdb 100644 --- a/python/paddle/fluid/tests/unittests/test_imperative_data_loader_exception.py +++ 
b/python/paddle/fluid/tests/unittests/test_imperative_data_loader_exception.py @@ -19,6 +19,7 @@ import numpy as np import paddle.fluid as fluid from paddle.fluid import core import paddle.compat as cpt +from paddle.fluid.framework import _test_eager_guard def get_random_images_and_labels(image_shape, label_shape): @@ -34,13 +35,18 @@ class TestDygraphDataLoaderWithException(unittest.TestCase): self.epoch_num = 1 self.capacity = 5 - def test_not_capacity(self): + def func_test_not_capacity(self): with fluid.dygraph.guard(): with self.assertRaisesRegexp(ValueError, "Please give value to capacity."): fluid.io.DataLoader.from_generator() - def test_single_process_with_thread_expection(self): + def test_not_capacity(self): + with _test_eager_guard(): + self.func_test_not_capacity() + self.func_test_not_capacity() + + def func_test_single_process_with_thread_expection(self): def error_sample_genarator(batch_num): def __reader__(): for _ in range(batch_num): @@ -63,7 +69,12 @@ class TestDygraphDataLoaderWithException(unittest.TestCase): exception = ex self.assertIsNotNone(exception) - def test_multi_process_with_process_expection(self): + def test_single_process_with_thread_expection(self): + with _test_eager_guard(): + self.func_test_single_process_with_thread_expection() + self.func_test_single_process_with_thread_expection() + + def func_test_multi_process_with_process_expection(self): def error_sample_genarator(batch_num): def __reader__(): for _ in range(batch_num): @@ -84,7 +95,12 @@ class TestDygraphDataLoaderWithException(unittest.TestCase): exception = ex self.assertIsNotNone(exception) - def test_multi_process_with_get_timeout(self): + def test_multi_process_with_process_expection(self): + with _test_eager_guard(): + self.func_test_multi_process_with_process_expection() + self.func_test_multi_process_with_process_expection() + + def func_test_multi_process_with_get_timeout(self): def slow_batch_generator_creator(batch_size, batch_num): def __reader__(): for _ in range(batch_num): @@ -112,6 +128,11 @@ class TestDygraphDataLoaderWithException(unittest.TestCase): exception = ex self.assertIsNotNone(exception) + def test_multi_process_with_get_timeout(self): + with _test_eager_guard(): + self.func_test_multi_process_with_get_timeout() + self.func_test_multi_process_with_get_timeout() + if __name__ == '__main__': unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_imperative_data_loader_exit_func.py b/python/paddle/fluid/tests/unittests/test_imperative_data_loader_exit_func.py index ba98a343a45..e83d6210f84 100644 --- a/python/paddle/fluid/tests/unittests/test_imperative_data_loader_exit_func.py +++ b/python/paddle/fluid/tests/unittests/test_imperative_data_loader_exit_func.py @@ -19,6 +19,7 @@ import multiprocessing import time import paddle.compat as cpt +from paddle.fluid.framework import _test_eager_guard if sys.version_info[0] == 2: import Queue as queue @@ -35,7 +36,7 @@ class TestDygraphDataLoaderCleanUpFunc(unittest.TestCase): def setUp(self): self.capacity = 10 - def test_clear_queue_set(self): + def func_test_clear_queue_set(self): test_queue = queue.Queue(self.capacity) global multiprocess_queue_set multiprocess_queue_set.add(test_queue) @@ -43,13 +44,18 @@ class TestDygraphDataLoaderCleanUpFunc(unittest.TestCase): test_queue.put(i) _cleanup() + def test_clear_queue_set(self): + with _test_eager_guard(): + self.func_test_clear_queue_set() + self.func_test_clear_queue_set() + class TestRegisterExitFunc(unittest.TestCase): # This function does not need to be 
implemented in this case def none_func(self): pass - def test_not_callable_func(self): + def func_test_not_callable_func(self): exception = None try: CleanupFuncRegistrar.register(5) @@ -58,11 +64,21 @@ class TestRegisterExitFunc(unittest.TestCase): exception = ex self.assertIsNotNone(exception) - def test_old_handler_for_sigint(self): + def test_not_callable_func(self): + with _test_eager_guard(): + self.func_test_not_callable_func() + self.func_test_not_callable_func() + + def func_test_old_handler_for_sigint(self): CleanupFuncRegistrar.register( function=self.none_func, signals=[signal.SIGINT]) - def test_signal_wrapper_by_sigchld(self): + def test_old_handler_for_sigint(self): + with _test_eager_guard(): + self.func_test_old_handler_for_sigint() + self.func_test_old_handler_for_sigint() + + def func_test_signal_wrapper_by_sigchld(self): # This function does not need to be implemented in this case def __test_process__(): pass @@ -79,6 +95,11 @@ class TestRegisterExitFunc(unittest.TestCase): exception = ex self.assertIsNotNone(exception) + def test_signal_wrapper_by_sigchld(self): + with _test_eager_guard(): + self.func_test_signal_wrapper_by_sigchld() + self.func_test_signal_wrapper_by_sigchld() + if __name__ == '__main__': unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_imperative_data_loader_fds_clear.py b/python/paddle/fluid/tests/unittests/test_imperative_data_loader_fds_clear.py index e6b60873273..0ef2e19c44b 100644 --- a/python/paddle/fluid/tests/unittests/test_imperative_data_loader_fds_clear.py +++ b/python/paddle/fluid/tests/unittests/test_imperative_data_loader_fds_clear.py @@ -18,6 +18,7 @@ import numpy as np import paddle.fluid as fluid from paddle.fluid import core from paddle.io import Dataset, DataLoader +from paddle.fluid.framework import _test_eager_guard def get_random_images_and_labels(image_shape, label_shape): @@ -75,19 +76,29 @@ class TestDygraphDataLoaderMmapFdsClear(unittest.TestCase): if step_id == 30: break - def test_data_loader_break(self): + def func_test_data_loader_break(self): with fluid.dygraph.guard(): loader = self.prepare_data_loader() for _ in range(self.epoch_num): self.run_one_epoch_with_break(loader) break - def test_data_loader_continue_break(self): + def test_data_loader_break(self): + with _test_eager_guard(): + self.func_test_data_loader_break() + self.func_test_data_loader_break() + + def func_test_data_loader_continue_break(self): with fluid.dygraph.guard(): loader = self.prepare_data_loader() for _ in range(self.epoch_num): self.run_one_epoch_with_break(loader) + def test_data_loader_continue_break(self): + with _test_eager_guard(): + self.func_test_data_loader_continue_break() + self.func_test_data_loader_continue_break() + class TestMultiProcessDataLoaderMmapFdsClear(TestDygraphDataLoaderMmapFdsClear): def prepare_data_loader(self): diff --git a/python/paddle/fluid/tests/unittests/test_imperative_data_loader_process.py b/python/paddle/fluid/tests/unittests/test_imperative_data_loader_process.py index 2a3a1e8b0a3..0eb5aa55eb3 100644 --- a/python/paddle/fluid/tests/unittests/test_imperative_data_loader_process.py +++ b/python/paddle/fluid/tests/unittests/test_imperative_data_loader_process.py @@ -19,6 +19,7 @@ import numpy as np import paddle.fluid as fluid from paddle.fluid import core from paddle.fluid.reader import _reader_process_loop +from paddle.fluid.framework import _test_eager_guard if sys.version_info[0] == 2: import Queue as queue @@ -51,7 +52,7 @@ class TestDygraphDataLoaderProcess(unittest.TestCase): 
self.epoch_num = 2 self.capacity = 2 - def test_reader_process_loop(self): + def func_test_reader_process_loop(self): # This unittest's memory mapped files needs to be cleaned manually def __clear_process__(util_queue): while True: @@ -79,7 +80,12 @@ class TestDygraphDataLoaderProcess(unittest.TestCase): target=__clear_process__, args=(util_queue, )) clear_process.start() - def test_reader_process_loop_simple_none(self): + def test_reader_process_loop(self): + with _test_eager_guard(): + self.func_test_reader_process_loop() + self.func_test_reader_process_loop() + + def func_test_reader_process_loop_simple_none(self): def none_sample_genarator(batch_num): def __reader__(): for _ in range(batch_num): @@ -100,6 +106,11 @@ class TestDygraphDataLoaderProcess(unittest.TestCase): exception = ex self.assertIsNotNone(exception) + def test_reader_process_loop_simple_none(self): + with _test_eager_guard(): + self.func_test_reader_process_loop_simple_none() + self.func_test_reader_process_loop_simple_none() + if __name__ == '__main__': unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_imperative_layers.py b/python/paddle/fluid/tests/unittests/test_imperative_layers.py index f69ed7a817f..82b5541b83e 100644 --- a/python/paddle/fluid/tests/unittests/test_imperative_layers.py +++ b/python/paddle/fluid/tests/unittests/test_imperative_layers.py @@ -15,10 +15,11 @@ import unittest import paddle import paddle.nn as nn +from paddle.fluid.framework import _test_eager_guard, _in_eager_mode, in_dygraph_mode class TestLayerPrint(unittest.TestCase): - def test_layer_str(self): + def func_test_layer_str(self): module = nn.ELU(0.2) self.assertEqual(str(module), 'ELU(alpha=0.2)') @@ -352,6 +353,11 @@ class TestLayerPrint(unittest.TestCase): '(6): GELU(approximate=True)\n)' ) + def test_layer_str(self): + with _test_eager_guard(): + self.func_test_layer_str() + self.func_test_layer_str() + if __name__ == '__main__': unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_imperative_named_members.py b/python/paddle/fluid/tests/unittests/test_imperative_named_members.py index dfcd6392b46..6e0866141af 100644 --- a/python/paddle/fluid/tests/unittests/test_imperative_named_members.py +++ b/python/paddle/fluid/tests/unittests/test_imperative_named_members.py @@ -16,6 +16,7 @@ import unittest import numpy as np import paddle.fluid as fluid import paddle +from paddle.fluid.framework import _test_eager_guard, _in_eager_mode, in_dygraph_mode class MyLayer(fluid.Layer): @@ -31,7 +32,7 @@ class MyLayer(fluid.Layer): class TestImperativeNamedSubLayers(unittest.TestCase): - def test_named_sublayers(self): + def func_test_named_sublayers(self): with fluid.dygraph.guard(): fc1 = fluid.Linear(10, 3) fc2 = fluid.Linear(3, 10, bias_attr=False) @@ -56,9 +57,14 @@ class TestImperativeNamedSubLayers(unittest.TestCase): [l for _, l in list(model.named_sublayers(include_self=True))], [model] + expected_sublayers) + def test_named_sublayers(self): + with _test_eager_guard(): + self.func_test_named_sublayers() + self.func_test_named_sublayers() + class TestImperativeNamedParameters(unittest.TestCase): - def test_named_parameters(self): + def func_test_named_parameters(self): with fluid.dygraph.guard(): fc1 = fluid.Linear(10, 3) fc2 = fluid.Linear(3, 10, bias_attr=False) @@ -75,7 +81,12 @@ class TestImperativeNamedParameters(unittest.TestCase): self.assertListEqual(expected_named_parameters, named_parameters) - def test_dir_layer(self): + def test_named_parameters(self): + with _test_eager_guard(): + 
self.func_test_named_parameters() + self.func_test_named_parameters() + + def func_test_dir_layer(self): with fluid.dygraph.guard(): class Mymodel(fluid.dygraph.Layer): @@ -110,6 +121,11 @@ class TestImperativeNamedParameters(unittest.TestCase): self.assertTrue("weight" in expected_members, "model should contain parameter: weight") + def test_dir_layer(self): + with _test_eager_guard(): + self.func_test_dir_layer() + self.func_test_dir_layer() + if __name__ == '__main__': unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_imperative_optimizer.py b/python/paddle/fluid/tests/unittests/test_imperative_optimizer.py index 0b1f5c91119..3e4d1046d1f 100644 --- a/python/paddle/fluid/tests/unittests/test_imperative_optimizer.py +++ b/python/paddle/fluid/tests/unittests/test_imperative_optimizer.py @@ -28,6 +28,7 @@ from paddle.fluid.optimizer import ModelAverage, DGCMomentumOptimizer, Exponenti from paddle.fluid.dygraph import Linear from paddle.fluid.dygraph.base import to_variable from test_imperative_base import new_program_scope +from paddle.fluid.framework import _test_eager_guard, _in_eager_mode # Note(wangzhongpu) # In dygraph, don't support ModelAverage, DGCMomentumOptimizer, ExponentialMovingAverage, PipelineOptimizer, LookaheadOptimizer, RecomputeOptimizer. @@ -220,9 +221,14 @@ class TestImperativeOptimizerPiecewiseDecay(TestImperativeOptimizerBase): boundaries=bd, values=[0.1 * (0.1**i) for i in range(len(bd) + 1)])) return optimizer - def test_sgd(self): + def func_test_sgd(self): self._check_mlp() + def test_sgd(self): + with _test_eager_guard(): + self.func_test_sgd() + self.func_test_sgd() + class TestImperativeOptimizerNaturalExpDecay(TestImperativeOptimizerBase): def get_optimizer_dygraph(self, parameter_list): @@ -243,9 +249,14 @@ class TestImperativeOptimizerNaturalExpDecay(TestImperativeOptimizerBase): staircase=True)) return optimizer - def test_sgd(self): + def func_test_sgd(self): self._check_mlp() + def test_sgd(self): + with _test_eager_guard(): + self.func_test_sgd() + self.func_test_sgd() + class TestImperativeOptimizerExponentialDecay(TestImperativeOptimizerBase): def get_optimizer_dygraph(self, parameter_list): @@ -266,9 +277,14 @@ class TestImperativeOptimizerExponentialDecay(TestImperativeOptimizerBase): staircase=True)) return optimizer - def test_sgd(self): + def func_test_sgd(self): self._check_mlp() + def test_sgd(self): + with _test_eager_guard(): + self.func_test_sgd() + self.func_test_sgd() + class TestImperativeOptimizerInverseTimeDecay(TestImperativeOptimizerBase): def get_optimizer_dygraph(self, parameter_list): @@ -289,9 +305,14 @@ class TestImperativeOptimizerInverseTimeDecay(TestImperativeOptimizerBase): staircase=True)) return optimizer - def test_adam(self): + def func_test_adam(self): self._check_mlp() + def test_adam(self): + with _test_eager_guard(): + self.func_test_adam() + self.func_test_adam() + class TestImperativeOptimizerPolynomialDecay(TestImperativeOptimizerBase): def get_optimizer_dygraph(self, parameter_list): @@ -306,14 +327,24 @@ class TestImperativeOptimizerPolynomialDecay(TestImperativeOptimizerBase): learning_rate=0.1, decay_steps=5, cycle=self.cycle)) return optimizer - def test_sgd_cycle(self): + def func_test_sgd_cycle(self): self.cycle = True self._check_mlp() - def test_sgd(self): + def test_sgd_cycle(self): + with _test_eager_guard(): + self.func_test_sgd_cycle() + self.func_test_sgd_cycle() + + def func_test_sgd(self): self.cycle = False self._check_mlp() + def test_sgd(self): + with _test_eager_guard(): + 
self.func_test_sgd() + self.func_test_sgd() + class TestImperativeOptimizerCosineDecay(TestImperativeOptimizerBase): def get_optimizer_dygraph(self, parameter_list): @@ -328,9 +359,14 @@ class TestImperativeOptimizerCosineDecay(TestImperativeOptimizerBase): learning_rate=0.1, step_each_epoch=10000, epochs=120)) return optimizer - def test_sgd(self): + def func_test_sgd(self): self._check_mlp() + def test_sgd(self): + with _test_eager_guard(): + self.func_test_sgd() + self.func_test_sgd() + class TestImperativeOptimizerNoamDecay(TestImperativeOptimizerBase): def get_optimizer_dygraph(self, parameter_list): @@ -345,12 +381,17 @@ class TestImperativeOptimizerNoamDecay(TestImperativeOptimizerBase): d_model=512, warmup_steps=8000)) return optimizer - def test_sgd(self): + def func_test_sgd(self): self._check_mlp() + def test_sgd(self): + with _test_eager_guard(): + self.func_test_sgd() + self.func_test_sgd() + class TestOptimizerLearningRate(unittest.TestCase): - def test_constant_lr(self): + def func_test_constant_lr(self): with fluid.dygraph.guard(): a = np.random.uniform(-0.1, 0.1, [10, 10]).astype("float32") @@ -375,7 +416,12 @@ class TestOptimizerLearningRate(unittest.TestCase): self.assertTrue(np.allclose(lr, 0.001, rtol=1e-06, atol=0.0)) - def test_lr_decay(self): + def test_constant_lr(self): + with _test_eager_guard(): + self.func_test_constant_lr() + self.func_test_constant_lr() + + def func_test_lr_decay(self): with fluid.dygraph.guard(): a = np.random.uniform(-0.1, 0.1, [10, 10]).astype("float32") @@ -405,7 +451,12 @@ class TestOptimizerLearningRate(unittest.TestCase): self.assertTrue(np.allclose(lr, ret[i], rtol=1e-06, atol=0.0)) - def test_lr_decay_natural_exp(self): + def test_lr_decay(self): + with _test_eager_guard(): + self.func_test_lr_decay() + self.func_test_lr_decay() + + def func_test_lr_decay_natural_exp(self): with fluid.dygraph.guard(): a = np.random.uniform(-0.1, 0.1, [10, 10]).astype("float32") @@ -437,7 +488,12 @@ class TestOptimizerLearningRate(unittest.TestCase): self.assertTrue(np.allclose(lr, ret[i], rtol=1e-06, atol=0.0)) - def test_set_lr(self): + def test_lr_decay_natural_exp(self): + with _test_eager_guard(): + self.func_test_lr_decay_natural_exp() + self.func_test_lr_decay_natural_exp() + + def func_test_set_lr(self): with fluid.dygraph.guard(): a = np.random.uniform(-0.1, 0.1, [10, 10]).astype("float32") @@ -477,6 +533,11 @@ class TestOptimizerLearningRate(unittest.TestCase): parameter_list=linear.parameters()) adam.set_lr(0.01) + def test_set_lr(self): + with _test_eager_guard(): + self.func_test_set_lr() + self.func_test_set_lr() + class TestImperativeMomentumOptimizer(TestImperativeOptimizerBase): def get_optimizer_dygraph(self, parameter_list): @@ -488,9 +549,14 @@ class TestImperativeMomentumOptimizer(TestImperativeOptimizerBase): optimizer = MomentumOptimizer(learning_rate=0.001, momentum=0.9) return optimizer - def test_momentum(self): + def func_test_momentum(self): self._check_mlp() + def test_momentum(self): + with _test_eager_guard(): + self.func_test_momentum() + self.func_test_momentum() + class TestImperativeLarsMomentumOptimizer(TestImperativeOptimizerBase): def get_optimizer_dygraph(self, parameter_list): @@ -502,9 +568,14 @@ class TestImperativeLarsMomentumOptimizer(TestImperativeOptimizerBase): optimizer = LarsMomentumOptimizer(learning_rate=0.001, momentum=0.9) return optimizer - def test_larsmomentum(self): + def func_test_larsmomentum(self): self._check_mlp() + def test_larsmomentum(self): + with _test_eager_guard(): + 
self.func_test_larsmomentum() + self.func_test_larsmomentum() + class TestImperativeAdagradOptimizer(TestImperativeOptimizerBase): def get_optimizer_dygraph(self, parameter_list): @@ -516,9 +587,14 @@ class TestImperativeAdagradOptimizer(TestImperativeOptimizerBase): optimizer = AdagradOptimizer(learning_rate=0.2) return optimizer - def test_adagrad(self): + def func_test_adagrad(self): self._check_mlp() + def test_adagrad(self): + with _test_eager_guard(): + self.func_test_adagrad() + self.func_test_adagrad() + class TestImperativeAdamaxOptimizer(TestImperativeOptimizerBase): def get_optimizer_dygraph(self, parameter_list): @@ -530,9 +606,14 @@ class TestImperativeAdamaxOptimizer(TestImperativeOptimizerBase): optimizer = AdamaxOptimizer(learning_rate=0.2) return optimizer - def test_adamax(self): + def func_test_adamax(self): self._check_mlp() + def test_adamax(self): + with _test_eager_guard(): + self.func_test_adamax() + self.func_test_adamax() + class TestImperativeDpsgdOptimizer(TestImperativeOptimizerBase): def get_optimizer_dygraph(self, parameter_list): @@ -551,9 +632,14 @@ class TestImperativeDpsgdOptimizer(TestImperativeOptimizerBase): optimizer._seed = 100 return optimizer - def test_dpsgd(self): + def func_test_dpsgd(self): self._check_mlp(place=fluid.CPUPlace()) + def test_dpsgd(self): + with _test_eager_guard(): + self.func_test_dpsgd() + self.func_test_dpsgd() + class TestImperativeDecayedAdagradOptimizer(TestImperativeOptimizerBase): def get_optimizer_dygraph(self, parameter_list): @@ -565,9 +651,14 @@ class TestImperativeDecayedAdagradOptimizer(TestImperativeOptimizerBase): optimizer = DecayedAdagradOptimizer(learning_rate=0.2) return optimizer - def test_decayadagrad(self): + def func_test_decayadagrad(self): self._check_mlp() + def test_decayadagrad(self): + with _test_eager_guard(): + self.func_test_decayadagrad() + self.func_test_decayadagrad() + class TestImperativeAdadeltaOptimizer(TestImperativeOptimizerBase): def get_optimizer_dygraph(self, parameter_list): @@ -583,9 +674,14 @@ class TestImperativeAdadeltaOptimizer(TestImperativeOptimizerBase): learning_rate=0.0003, epsilon=1.0e-6, rho=0.95) return optimizer - def test_adadelta(self): + def func_test_adadelta(self): self._check_mlp() + def test_adadelta(self): + with _test_eager_guard(): + self.func_test_adadelta() + self.func_test_adadelta() + class TestImperativeRMSPropOptimizer(TestImperativeOptimizerBase): def get_optimizer_dygraph(self, parameter_list): @@ -597,9 +693,14 @@ class TestImperativeRMSPropOptimizer(TestImperativeOptimizerBase): optimizer = RMSPropOptimizer(learning_rate=0.1) return optimizer - def test_rmsprop(self): + def func_test_rmsprop(self): self._check_mlp() + def test_rmsprop(self): + with _test_eager_guard(): + self.func_test_rmsprop() + self.func_test_rmsprop() + class TestImperativeFtrlOptimizer(TestImperativeOptimizerBase): def get_optimizer_dygraph(self, parameter_list): @@ -611,9 +712,14 @@ class TestImperativeFtrlOptimizer(TestImperativeOptimizerBase): optimizer = FtrlOptimizer(learning_rate=0.1) return optimizer - def test_ftrl(self): + def func_test_ftrl(self): self._check_mlp() + def test_ftrl(self): + with _test_eager_guard(): + self.func_test_ftrl() + self.func_test_ftrl() + def exclude_fn(param): return param.name.endswith('.b_0') @@ -643,10 +749,15 @@ class TestImperativeModelAverage(TestImperativeOptimizerBase): 0.15, min_average_window=10000, max_average_window=12500) return optimizer - def test_modelaverage(self): + def func_test_modelaverage(self): exception_message = "In 
dygraph, don't support ModelAverage." self._check_exception(exception_message) + def test_modelaverage(self): + with _test_eager_guard(): + self.func_test_modelaverage() + self.func_test_modelaverage() + class TestImperativeDGCMomentumOptimizer(TestImperativeOptimizerBase): def get_optimizer_dygraph(self, parameter_list): @@ -658,20 +769,30 @@ class TestImperativeDGCMomentumOptimizer(TestImperativeOptimizerBase): sparsity=[0.999, 0.999]) return optimizer - def test_dgcmomentum(self): + def func_test_dgcmomentum(self): exception_message = "In dygraph, don't support DGCMomentumOptimizer." self._check_exception(exception_message) + def test_dgcmomentum(self): + with _test_eager_guard(): + self.func_test_dgcmomentum() + self.func_test_dgcmomentum() + class TestImperativeExponentialMovingAverage(TestImperativeOptimizerBase): def get_optimizer_dygraph(self, parameter_list): optimizer = ExponentialMovingAverage(0.999) return optimizer - def test_exponentialmoving(self): + def func_test_exponentialmoving(self): exception_message = "In dygraph, don't support ExponentialMovingAverage." self._check_exception(exception_message) + def test_exponentialmoving(self): + with _test_eager_guard(): + self.func_test_exponentialmoving() + self.func_test_exponentialmoving() + class TestImperativePipelineOptimizer(TestImperativeOptimizerBase): def get_optimizer_dygraph(self, parameter_list): @@ -680,10 +801,15 @@ class TestImperativePipelineOptimizer(TestImperativeOptimizerBase): optimizer = PipelineOptimizer(optimizer) return optimizer - def test_pipline(self): + def func_test_pipline(self): exception_message = "In dygraph, don't support PipelineOptimizer." self._check_exception(exception_message) + def test_pipline(self): + with _test_eager_guard(): + self.func_test_pipline() + self.func_test_pipline() + class TestImperativeLookaheadOptimizer(TestImperativeOptimizerBase): def get_optimizer_dygraph(self, parameter_list): @@ -692,10 +818,15 @@ class TestImperativeLookaheadOptimizer(TestImperativeOptimizerBase): optimizer = LookaheadOptimizer(optimizer, alpha=0.5, k=5) return optimizer - def test_lookahead(self): + def func_test_lookahead(self): exception_message = "In dygraph, don't support LookaheadOptimizer." self._check_exception(exception_message) + def test_lookahead(self): + with _test_eager_guard(): + self.func_test_lookahead() + self.func_test_lookahead() + class TestImperativeRecomputeOptimizer(TestImperativeOptimizerBase): def get_optimizer_dygraph(self, parameter_list): @@ -704,13 +835,18 @@ class TestImperativeRecomputeOptimizer(TestImperativeOptimizerBase): optimizer = RecomputeOptimizer(optimizer) return optimizer - def test_recompute(self): + def func_test_recompute(self): exception_message = "In dygraph, don't support RecomputeOptimizer." 
self._check_exception(exception_message) + def test_recompute(self): + with _test_eager_guard(): + self.func_test_recompute() + self.func_test_recompute() + class TestImperativeOptimizerList(unittest.TestCase): - def test_parameter_list(self): + def func_test_parameter_list(self): with fluid.dygraph.guard(): linear_1 = Linear(10, 10) linear_2 = Linear(10, 10) @@ -733,6 +869,11 @@ class TestImperativeOptimizerList(unittest.TestCase): len(sgd._parameter_list) == len(linear_1.parameters() + linear_2.parameters())) + def test_parameter_list(self): + with _test_eager_guard(): + self.func_test_parameter_list() + self.func_test_parameter_list() + if __name__ == '__main__': unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_imperative_optimizer_v2.py b/python/paddle/fluid/tests/unittests/test_imperative_optimizer_v2.py index dfd1e4f97a8..b27ce6bb01f 100644 --- a/python/paddle/fluid/tests/unittests/test_imperative_optimizer_v2.py +++ b/python/paddle/fluid/tests/unittests/test_imperative_optimizer_v2.py @@ -28,6 +28,7 @@ from paddle.fluid.optimizer import ModelAverage, DGCMomentumOptimizer, Exponenti from paddle.fluid.dygraph import Linear from paddle.fluid.dygraph.base import to_variable from test_imperative_base import new_program_scope +from paddle.fluid.framework import _test_eager_guard # Note(wangzhongpu) # In dygraph, don't support ModelAverage, DGCMomentumOptimizer, ExponentialMovingAverage, PipelineOptimizer, LookaheadOptimizer, RecomputeOptimizer. @@ -239,9 +240,14 @@ class TestImperativeOptimizerPiecewiseDecay(TestImperativeOptimizerBase): values=[0.1 * (0.1**i) for i in range(len(bd) + 1)])) return optimizer - def test_sgd(self): + def func_test_sgd(self): self._check_mlp() + def test_sgd(self): + with _test_eager_guard(): + self.func_test_sgd() + self.func_test_sgd() + class TestImperativeOptimizerNaturalExpDecay(TestImperativeOptimizerBase): def get_optimizer_dygraph(self, parameter_list): @@ -257,9 +263,14 @@ class TestImperativeOptimizerNaturalExpDecay(TestImperativeOptimizerBase): learning_rate=0.5, gamma=0.9)) return optimizer - def test_sgd(self): + def func_test_sgd(self): self._check_mlp() + def test_sgd(self): + with _test_eager_guard(): + self.func_test_sgd() + self.func_test_sgd() + class TestImperativeOptimizerExponentialDecay(TestImperativeOptimizerBase): def get_optimizer_dygraph(self, parameter_list): @@ -275,9 +286,14 @@ class TestImperativeOptimizerExponentialDecay(TestImperativeOptimizerBase): learning_rate=0.5, gamma=0.9)) return optimizer - def test_sgd(self): + def func_test_sgd(self): self._check_mlp() + def test_sgd(self): + with _test_eager_guard(): + self.func_test_sgd() + self.func_test_sgd() + class TestImperativeOptimizerInverseTimeDecay(TestImperativeOptimizerBase): def get_optimizer_dygraph(self, parameter_list): @@ -293,9 +309,14 @@ class TestImperativeOptimizerInverseTimeDecay(TestImperativeOptimizerBase): learning_rate=0.5, gamma=0.9)) return optimizer - def test_adam(self): + def func_test_adam(self): self._check_mlp() + def test_adam(self): + with _test_eager_guard(): + self.func_test_adam() + self.func_test_adam() + class TestImperativeOptimizerPolynomialDecay(TestImperativeOptimizerBase): def get_optimizer_dygraph(self, parameter_list): @@ -311,14 +332,24 @@ class TestImperativeOptimizerPolynomialDecay(TestImperativeOptimizerBase): learning_rate=0.5, decay_steps=5, cycle=self.cycle)) return optimizer - def test_sgd_cycle(self): + def func_test_sgd_cycle(self): self.cycle = True self._check_mlp() - def test_sgd(self): + def 
test_sgd_cycle(self): + with _test_eager_guard(): + self.func_test_sgd_cycle() + self.func_test_sgd_cycle() + + def func_test_sgd(self): self.cycle = False self._check_mlp() + def test_sgd(self): + with _test_eager_guard(): + self.func_test_sgd() + self.func_test_sgd() + class TestImperativeOptimizerCosineAnnealingDecay(TestImperativeOptimizerBase): def get_optimizer_dygraph(self, parameter_list): @@ -334,9 +365,14 @@ class TestImperativeOptimizerCosineAnnealingDecay(TestImperativeOptimizerBase): learning_rate=0.5, T_max=5)) return optimizer - def test_sgd(self): + def func_test_sgd(self): self._check_mlp() + def test_sgd(self): + with _test_eager_guard(): + self.func_test_sgd() + self.func_test_sgd() + class TestImperativeOptimizerNoamDecay(TestImperativeOptimizerBase): def get_optimizer_dygraph(self, parameter_list): @@ -352,9 +388,14 @@ class TestImperativeOptimizerNoamDecay(TestImperativeOptimizerBase): d_model=0.01, warmup_steps=100)) return optimizer - def test_sgd(self): + def func_test_sgd(self): self._check_mlp() + def test_sgd(self): + with _test_eager_guard(): + self.func_test_sgd() + self.func_test_sgd() + class TestImperativeOptimizerLambdaDecay(TestImperativeOptimizerBase): def get_optimizer_dygraph(self, parameter_list): @@ -370,9 +411,14 @@ class TestImperativeOptimizerLambdaDecay(TestImperativeOptimizerBase): learning_rate=0.5, lr_lambda=lambda epoch: 0.9**epoch)) return optimizer - def test_sgd(self): + def func_test_sgd(self): self._check_mlp() + def test_sgd(self): + with _test_eager_guard(): + self.func_test_sgd() + self.func_test_sgd() + class TestImperativeOptimizerLinearWarmup(TestImperativeOptimizerBase): def get_optimizer_dygraph(self, parameter_list): @@ -392,9 +438,14 @@ class TestImperativeOptimizerLinearWarmup(TestImperativeOptimizerBase): verbose=True)) return optimizer - def test_sgd(self): + def func_test_sgd(self): self._check_mlp() + def test_sgd(self): + with _test_eager_guard(): + self.func_test_sgd() + self.func_test_sgd() + class TestImperativeOptimizerMultiStepDecay(TestImperativeOptimizerBase): def get_optimizer_dygraph(self, parameter_list): @@ -410,9 +461,14 @@ class TestImperativeOptimizerMultiStepDecay(TestImperativeOptimizerBase): learning_rate=0.5, milestones=[2, 4, 6], gamma=0.8)) return optimizer - def test_sgd(self): + def func_test_sgd(self): self._check_mlp() + def test_sgd(self): + with _test_eager_guard(): + self.func_test_sgd() + self.func_test_sgd() + class TestImperativeOptimizerStepLR(TestImperativeOptimizerBase): def get_optimizer_dygraph(self, parameter_list): @@ -428,9 +484,14 @@ class TestImperativeOptimizerStepLR(TestImperativeOptimizerBase): learning_rate=0.5, step_size=5, gamma=0.8)) return optimizer - def test_sgd(self): + def func_test_sgd(self): self._check_mlp() + def test_sgd(self): + with _test_eager_guard(): + self.func_test_sgd() + self.func_test_sgd() + class TestImperativeOptimizerReduceOnPlateau(TestImperativeOptimizerBase): def get_optimizer_dygraph(self, parameter_list): @@ -446,12 +507,17 @@ class TestImperativeOptimizerReduceOnPlateau(TestImperativeOptimizerBase): learning_rate=0.5)) return optimizer - def test_sgd(self): + def func_test_sgd(self): self._check_mlp() + def test_sgd(self): + with _test_eager_guard(): + self.func_test_sgd() + self.func_test_sgd() + class TestOptimizerLearningRate(unittest.TestCase): - def test_constant_lr(self): + def func_test_constant_lr(self): with fluid.dygraph.guard(): a = np.random.uniform(-0.1, 0.1, [10, 10]).astype("float32") @@ -475,7 +541,12 @@ class 
TestOptimizerLearningRate(unittest.TestCase): self.assertTrue(np.allclose(lr, 0.001, rtol=1e-06, atol=0.0)) - def test_lr_decay(self): + def test_constant_lr(self): + with _test_eager_guard(): + self.func_test_constant_lr() + self.func_test_constant_lr() + + def func_test_lr_decay(self): with fluid.dygraph.guard(): a = np.random.uniform(-0.1, 0.1, [10, 10]).astype("float32") @@ -505,7 +576,12 @@ class TestOptimizerLearningRate(unittest.TestCase): self.assertTrue(np.allclose(lr, ret[i], rtol=1e-06, atol=0.0)) scheduler.step() - def test_lr_scheduler_natural_exp(self): + def test_lr_decay(self): + with _test_eager_guard(): + self.func_test_lr_decay() + self.func_test_lr_decay() + + def func_test_lr_scheduler_natural_exp(self): with fluid.dygraph.guard(): a = np.random.uniform(-0.1, 0.1, [10, 10]).astype("float32") @@ -531,7 +607,12 @@ class TestOptimizerLearningRate(unittest.TestCase): self.assertTrue(np.allclose(lr, ret[i], rtol=1e-06, atol=0.0)) scheduler.step() - def test_set_lr(self): + def test_lr_scheduler_natural_exp(self): + with _test_eager_guard(): + self.func_test_lr_scheduler_natural_exp() + self.func_test_lr_scheduler_natural_exp() + + def func_test_set_lr(self): with fluid.dygraph.guard(): a = np.random.uniform(-0.1, 0.1, [10, 10]).astype("float32") @@ -566,6 +647,11 @@ class TestOptimizerLearningRate(unittest.TestCase): parameters=linear.parameters()) adam.set_lr(0.01) + def test_set_lr(self): + with _test_eager_guard(): + self.func_test_set_lr() + self.func_test_set_lr() + class TestImperativeMomentumOptimizer(TestImperativeOptimizerBase): def get_optimizer_dygraph(self, parameter_list): @@ -577,9 +663,14 @@ class TestImperativeMomentumOptimizer(TestImperativeOptimizerBase): optimizer = MomentumOptimizer(learning_rate=0.001, momentum=0.9) return optimizer - def test_momentum(self): + def func_test_momentum(self): self._check_mlp() + def test_momentum(self): + with _test_eager_guard(): + self.func_test_momentum() + self.func_test_momentum() + class TestImperativeLarsMomentumOptimizer(TestImperativeOptimizerBase): def get_optimizer_dygraph(self, parameter_list): @@ -591,9 +682,14 @@ class TestImperativeLarsMomentumOptimizer(TestImperativeOptimizerBase): optimizer = LarsMomentumOptimizer(learning_rate=0.001, momentum=0.9) return optimizer - def test_larsmomentum(self): + def func_test_larsmomentum(self): self._check_mlp() + def test_larsmomentum(self): + with _test_eager_guard(): + self.func_test_larsmomentum() + self.func_test_larsmomentum() + class TestImperativeAdagradOptimizer(TestImperativeOptimizerBase): def get_optimizer_dygraph(self, parameter_list): @@ -605,9 +701,14 @@ class TestImperativeAdagradOptimizer(TestImperativeOptimizerBase): optimizer = AdagradOptimizer(learning_rate=0.2) return optimizer - def test_adagrad(self): + def func_test_adagrad(self): self._check_mlp() + def test_adagrad(self): + with _test_eager_guard(): + self.func_test_adagrad() + self.func_test_adagrad() + class TestImperativeAdamaxOptimizer(TestImperativeOptimizerBase): def get_optimizer_dygraph(self, parameter_list): @@ -619,9 +720,14 @@ class TestImperativeAdamaxOptimizer(TestImperativeOptimizerBase): optimizer = AdamaxOptimizer(learning_rate=0.2) return optimizer - def test_adamax(self): + def func_test_adamax(self): self._check_mlp() + def test_adamax(self): + with _test_eager_guard(): + self.func_test_adamax() + self.func_test_adamax() + class TestImperativeDpsgdOptimizer(TestImperativeOptimizerBase): def get_optimizer_dygraph(self, parameter_list): @@ -640,9 +746,14 @@ class 
TestImperativeDpsgdOptimizer(TestImperativeOptimizerBase): optimizer._seed = 100 return optimizer - def test_dpsgd(self): + def func_test_dpsgd(self): self._check_mlp(place=fluid.CPUPlace()) + def test_dpsgd(self): + with _test_eager_guard(): + self.func_test_dpsgd() + self.func_test_dpsgd() + class TestImperativeDecayedAdagradOptimizer(TestImperativeOptimizerBase): def get_optimizer_dygraph(self, parameter_list): @@ -654,9 +765,14 @@ class TestImperativeDecayedAdagradOptimizer(TestImperativeOptimizerBase): optimizer = DecayedAdagradOptimizer(learning_rate=0.2) return optimizer - def test_decayadagrad(self): + def func_test_decayadagrad(self): self._check_mlp() + def test_decayadagrad(self): + with _test_eager_guard(): + self.func_test_decayadagrad() + self.func_test_decayadagrad() + class TestImperativeAdadeltaOptimizer(TestImperativeOptimizerBase): def get_optimizer_dygraph(self, parameter_list): @@ -672,9 +788,14 @@ class TestImperativeAdadeltaOptimizer(TestImperativeOptimizerBase): learning_rate=0.0003, epsilon=1.0e-6, rho=0.95) return optimizer - def test_adadelta(self): + def func_test_adadelta(self): self._check_mlp() + def test_adadelta(self): + with _test_eager_guard(): + self.func_test_adadelta() + self.func_test_adadelta() + class TestImperativeRMSPropOptimizer(TestImperativeOptimizerBase): def get_optimizer_dygraph(self, parameter_list): @@ -686,9 +807,14 @@ class TestImperativeRMSPropOptimizer(TestImperativeOptimizerBase): optimizer = RMSPropOptimizer(learning_rate=0.1) return optimizer - def test_rmsprop(self): + def func_test_rmsprop(self): self._check_mlp() + def test_rmsprop(self): + with _test_eager_guard(): + self.func_test_rmsprop() + self.func_test_rmsprop() + class TestImperativeFtrlOptimizer(TestImperativeOptimizerBase): def get_optimizer_dygraph(self, parameter_list): @@ -700,9 +826,14 @@ class TestImperativeFtrlOptimizer(TestImperativeOptimizerBase): optimizer = FtrlOptimizer(learning_rate=0.1) return optimizer - def test_ftrl(self): + def func_test_ftrl(self): self._check_mlp() + def test_ftrl(self): + with _test_eager_guard(): + self.func_test_ftrl() + self.func_test_ftrl() + def exclude_fn(param): return param.name.endswith('.b_0') @@ -732,10 +863,15 @@ class TestImperativeModelAverage(TestImperativeOptimizerBase): 0.15, min_average_window=10000, max_average_window=12500) return optimizer - def test_modelaverage(self): + def func_test_modelaverage(self): exception_message = "In dygraph, don't support ModelAverage." self._check_exception(exception_message) + def test_modelaverage(self): + with _test_eager_guard(): + self.func_test_modelaverage() + self.func_test_modelaverage() + class TestImperativeDGCMomentumOptimizer(TestImperativeOptimizerBase): def get_optimizer_dygraph(self, parameter_list): @@ -747,20 +883,30 @@ class TestImperativeDGCMomentumOptimizer(TestImperativeOptimizerBase): sparsity=[0.999, 0.999]) return optimizer - def test_dgcmomentum(self): + def func_test_dgcmomentum(self): exception_message = "In dygraph, don't support DGCMomentumOptimizer." self._check_exception(exception_message) + def test_dgcmomentum(self): + with _test_eager_guard(): + self.func_test_dgcmomentum() + self.func_test_dgcmomentum() + class TestImperativeExponentialMovingAverage(TestImperativeOptimizerBase): def get_optimizer_dygraph(self, parameter_list): optimizer = ExponentialMovingAverage(0.999) return optimizer - def test_exponentialmoving(self): + def func_test_exponentialmoving(self): exception_message = "In dygraph, don't support ExponentialMovingAverage." 
self._check_exception(exception_message) + def test_exponentialmoving(self): + with _test_eager_guard(): + self.func_test_exponentialmoving() + self.func_test_exponentialmoving() + class TestImperativePipelineOptimizer(TestImperativeOptimizerBase): def get_optimizer_dygraph(self, parameter_list): @@ -769,10 +915,15 @@ class TestImperativePipelineOptimizer(TestImperativeOptimizerBase): optimizer = PipelineOptimizer(optimizer) return optimizer - def test_pipline(self): + def func_test_pipline(self): exception_message = "In dygraph, don't support PipelineOptimizer." self._check_exception(exception_message) + def test_pipline(self): + with _test_eager_guard(): + self.func_test_pipline() + self.func_test_pipline() + class TestImperativeLookaheadOptimizer(TestImperativeOptimizerBase): def get_optimizer_dygraph(self, parameter_list): @@ -781,10 +932,15 @@ class TestImperativeLookaheadOptimizer(TestImperativeOptimizerBase): optimizer = LookaheadOptimizer(optimizer, alpha=0.5, k=5) return optimizer - def test_lookahead(self): + def func_test_lookahead(self): exception_message = "In dygraph, don't support LookaheadOptimizer." self._check_exception(exception_message) + def test_lookahead(self): + with _test_eager_guard(): + self.func_test_lookahead() + self.func_test_lookahead() + class TestImperativeRecomputeOptimizer(TestImperativeOptimizerBase): def get_optimizer_dygraph(self, parameter_list): @@ -793,13 +949,18 @@ class TestImperativeRecomputeOptimizer(TestImperativeOptimizerBase): optimizer = RecomputeOptimizer(optimizer) return optimizer - def test_recompute(self): + def func_test_recompute(self): exception_message = "In dygraph, don't support RecomputeOptimizer." self._check_exception(exception_message) + def test_recompute(self): + with _test_eager_guard(): + self.func_test_recompute() + self.func_test_recompute() + class TestImperativeOptimizerList(unittest.TestCase): - def test_parameter_list(self): + def func_test_parameter_list(self): with fluid.dygraph.guard(): linear_1 = Linear(10, 10) linear_2 = Linear(10, 10) @@ -822,6 +983,11 @@ class TestImperativeOptimizerList(unittest.TestCase): len(sgd._parameter_list) == len(linear_1.parameters() + linear_2.parameters())) + def test_parameter_list(self): + with _test_eager_guard(): + self.func_test_parameter_list() + self.func_test_parameter_list() + if __name__ == '__main__': unittest.main() diff --git a/python/paddle/nn/layer/rnn.py b/python/paddle/nn/layer/rnn.py index fbb648af42a..f7d5448d132 100644 --- a/python/paddle/nn/layer/rnn.py +++ b/python/paddle/nn/layer/rnn.py @@ -971,6 +971,13 @@ class RNNBase(LayerList): # should dropout state be persistable for static-graph self._dropout_state = self.create_variable( dtype=fluid.core.VarDesc.VarType.UINT8) + if fluid.framework.in_dygraph_mode(): + with paddle.no_grad(): + _C_ops.coalesce_tensor(self._all_weights, self._all_weights, + self._flat_weight[0], "copy_data", + True, "use_align", False, "dtype", + params[0].dtype) + return # for static-graph, append coalesce_tensor into startup program with fluid.program_guard(fluid.default_startup_program(), fluid.default_startup_program()): diff --git a/python/paddle/optimizer/lr.py b/python/paddle/optimizer/lr.py index 90117f99abc..79bacc0dfb6 100644 --- a/python/paddle/optimizer/lr.py +++ b/python/paddle/optimizer/lr.py @@ -16,6 +16,8 @@ import math import numpy import warnings from paddle import Tensor +import paddle.fluid.core as core +from ..fluid.framework import _in_eager_mode __all__ = [ # noqa 'LRScheduler', @@ -1355,8 +1357,12 @@ class 
ReduceOnPlateau(LRScheduler): else: self.last_epoch = epoch + if _in_eager_mode(): + tmp = core.eager.EagerTensor + else: + tmp = Tensor # loss must be float, numpy.ndarray or 1-D Tensor with shape [1] - if isinstance(metrics, (Tensor, numpy.ndarray)): + if isinstance(metrics, (tmp, numpy.ndarray)): assert len(metrics.shape) == 1 and metrics.shape[0] == 1, "the metrics.shape " \ "should be (1L,), but the current metrics.shape is {}. Maybe that " \ "you should call paddle.mean to process it first.".format( diff --git a/python/paddle/optimizer/optimizer.py b/python/paddle/optimizer/optimizer.py index 4206683a143..d433921e826 100644 --- a/python/paddle/optimizer/optimizer.py +++ b/python/paddle/optimizer/optimizer.py @@ -424,16 +424,21 @@ class Optimizer(object): self._learning_rate = float(value) current_lr = self._global_learning_rate() if current_lr is not None: - global_block = framework.default_main_program().global_block() - global_block.append_op( - type='fill_constant', - outputs={'Out': [current_lr]}, - attrs={ - 'dtype': current_lr.dtype, - 'shape': list(current_lr.shape), - 'value': float(value) - }, - stop_gradient=True) + if framework.in_dygraph_mode(): + _C_ops.fill_constant(current_lr, 'value', + float(value), 'dtype', current_lr.dtype, + 'shape', list(current_lr.shape)) + else: + global_block = framework.default_main_program().global_block() + global_block.append_op( + type='fill_constant', + outputs={'Out': [current_lr]}, + attrs={ + 'dtype': current_lr.dtype, + 'shape': list(current_lr.shape), + 'value': float(value) + }, + stop_gradient=True) def get_lr(self): """ @@ -590,7 +595,9 @@ class Optimizer(object): name=var_name, persistable=True, dtype=dtype or param.dtype, - type=param.type if type is None else type, + type=core.VarDesc.VarType.LOD_TENSOR + if framework._in_eager_mode() else (param.type + if type is None else type), shape=shape, belong_to_optimizer=True) if device is None: -- GitLab
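
The recurring change across the test files in this patch follows a single pattern: each original `test_*` method is renamed to `func_test_*`, and a new `test_*` wrapper runs it twice, once inside `_test_eager_guard()` (eager mode) and once without the guard (legacy dygraph). A minimal sketch of that pattern is below; the class and test names are illustrative only, not taken from the patch, while `_test_eager_guard` is the guard imported in the patched files.

```python
# Sketch of the eager/legacy dual-run test pattern used throughout this patch.
# Assumption: TestExample and func_test_something are illustrative placeholders.
import unittest

from paddle.fluid.framework import _test_eager_guard


class TestExample(unittest.TestCase):
    def func_test_something(self):
        # Original test body, unchanged apart from the rename.
        self.assertTrue(True)

    def test_something(self):
        # Run once under eager mode, then once in legacy dygraph mode.
        with _test_eager_guard():
            self.func_test_something()
        self.func_test_something()


if __name__ == '__main__':
    unittest.main()
```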