“24509f4af942bb250564756ad636691c7921e1df”上不存在“paddle/legacy/gserver/layers/ConvBaseLayer.h”
未验证 提交 cdd46d7e 编写于 作者: L Leo Chen 提交者: GitHub

Split VarBase from Python Variable for Dygraph (#21359)

* test=develop, fix docker with paddle nccl problem

* don't expose numerous Tensor.set(), test=develop

* fix condition, test=develop

* fix float16 bug, test=develop

* feed should be Tensor or np.array, not Variable or number, test=develop

* use forcecast to copy numpy slice to new array, test=develop

* remove float16-uint16 hacking, test=develop

* add variable method to varbase and refactor to_variable to support return varbase

* support kwargs in varbase constructor

* add VarBase constructor to support default python args

* refine varbase initial method

* reset branch

* fix ut for change VarBase error info to PaddleEnforce

* cherry is parameter change before

* overload isinstance to replace too many change of is_variable

* rm useless files

* rm useless code merged by git

* test=develop, fix some ut failed error

* test=develop, fix test_graph_wrapper

* add some tests, test=develop

* refine __getitem__, test=develop

* add tests, test=develop

* fix err_msg, test=develop
上级 cdba41af
...@@ -203,6 +203,7 @@ elseif(${CBLAS_PROVIDER} STREQUAL EXTERN_OPENBLAS) ...@@ -203,6 +203,7 @@ elseif(${CBLAS_PROVIDER} STREQUAL EXTERN_OPENBLAS)
list(APPEND third_party_deps extern_openblas) list(APPEND third_party_deps extern_openblas)
endif() endif()
if(WITH_MKLDNN) if(WITH_MKLDNN)
include(external/mkldnn) # download, build, install mkldnn include(external/mkldnn) # download, build, install mkldnn
list(APPEND third_party_deps extern_mkldnn) list(APPEND third_party_deps extern_mkldnn)
......
...@@ -236,11 +236,13 @@ std::shared_ptr<VarBase> VarBase::NewVarBase(const platform::Place& dst_place, ...@@ -236,11 +236,13 @@ std::shared_ptr<VarBase> VarBase::NewVarBase(const platform::Place& dst_place,
// TODO(Jiabin): change this after move unique_name generator to CXX // TODO(Jiabin): change this after move unique_name generator to CXX
auto new_var = std::make_shared<VarBase>( auto new_var = std::make_shared<VarBase>(
false, "Itmp" + std::to_string(copied_counter_++)); true, Name() + std::to_string(copied_counter_++));
auto* dst_tensor = new_var->var_.GetMutable<framework::LoDTensor>(); auto* dst_tensor = new_var->var_.GetMutable<framework::LoDTensor>();
dst_tensor->set_lod(src_tensor.lod()); dst_tensor->set_lod(src_tensor.lod());
new_var->SetPersistable(Persistable());
new_var->SetDataType(DataType());
new_var->SetType(Type());
framework::TensorCopy(src_tensor, dst_place, dst_tensor); framework::TensorCopy(src_tensor, dst_place, dst_tensor);
if (blocking) { if (blocking) {
platform::DeviceContextPool::Instance().Get(dst_place)->Wait(); platform::DeviceContextPool::Instance().Get(dst_place)->Wait();
...@@ -253,7 +255,6 @@ std::shared_ptr<VarBase> VarBase::NewVarBase(const platform::Place& dst_place, ...@@ -253,7 +255,6 @@ std::shared_ptr<VarBase> VarBase::NewVarBase(const platform::Place& dst_place,
if (platform::is_gpu_place(dst_place)) { if (platform::is_gpu_place(dst_place)) {
VLOG(3) << "copy tensor " << Name() << " from gpu"; VLOG(3) << "copy tensor " << Name() << " from gpu";
} }
return new_var; return new_var;
} else { } else {
auto& src_selected_rows = var_.Get<framework::SelectedRows>(); auto& src_selected_rows = var_.Get<framework::SelectedRows>();
......
...@@ -158,7 +158,7 @@ TEST(test_layer, test_varbase_basic) { ...@@ -158,7 +158,7 @@ TEST(test_layer, test_varbase_basic) {
vin->MutableVar()->GetMutable<framework::LoDTensor>()->mutable_data<float>( vin->MutableVar()->GetMutable<framework::LoDTensor>()->mutable_data<float>(
place); place);
std::shared_ptr<imperative::VarBase> vout(vin->NewVarBase(place, false)); std::shared_ptr<imperative::VarBase> vout(vin->NewVarBase(place, false));
ASSERT_EQ(vout->Name(), "Itmp0"); ASSERT_EQ(vout->Name(), "vin0");
std::shared_ptr<imperative::VarBase> vin_with_grad( std::shared_ptr<imperative::VarBase> vin_with_grad(
new imperative::VarBase(true, "vin")); new imperative::VarBase(true, "vin"));
......
...@@ -30,7 +30,6 @@ limitations under the License. */ ...@@ -30,7 +30,6 @@ limitations under the License. */
#include "paddle/fluid/imperative/profiler.h" #include "paddle/fluid/imperative/profiler.h"
#include "paddle/fluid/imperative/tracer.h" #include "paddle/fluid/imperative/tracer.h"
#include "paddle/fluid/imperative/type_defs.h" #include "paddle/fluid/imperative/type_defs.h"
#include "paddle/fluid/pybind/pybind_boost_headers.h" #include "paddle/fluid/pybind/pybind_boost_headers.h"
namespace paddle { namespace paddle {
...@@ -38,6 +37,12 @@ namespace pybind { ...@@ -38,6 +37,12 @@ namespace pybind {
namespace py = ::pybind11; namespace py = ::pybind11;
template <typename P>
extern void SetTensorFromPyArray(framework::Tensor *self, const py::object &obj,
const P &place, bool zero_copy);
extern py::array TensorToPyArray(const framework::Tensor &tensor,
bool need_deep_copy = false);
class Layer : public imperative::Layer { class Layer : public imperative::Layer {
public: public:
using imperative::Layer::Layer; // Inherit constructors using imperative::Layer::Layer; // Inherit constructors
...@@ -50,42 +55,99 @@ class Layer : public imperative::Layer { ...@@ -50,42 +55,99 @@ class Layer : public imperative::Layer {
} }
}; };
// warper for pyobject to avoid imperative module depend on python static void InitTensorForVarBase(imperative::VarBase *self, bool persistable,
// TODO(jiabin) Add OpBase's pybind interface back to enable backward hook bool is_default, const py::array &array,
class PYBIND11_HIDDEN PyCallableObject { const py::object &obj = py::object(),
public: bool zero_copy = false) {
PyCallableObject(std::shared_ptr<py::object> py_obj_ptr) new (self) imperative::VarBase(
: py_obj_ptr_(std::move(py_obj_ptr)) {} imperative::GetCurrentTracer()->GenerateUniqueName("generated_var_"));
~PyCallableObject() { self->SetPersistable(persistable);
py::call_guard<py::gil_scoped_acquire>(); auto *tensor = self->MutableVar()->GetMutable<framework::LoDTensor>();
py_obj_ptr_.reset(); if (is_default) {
auto place = imperative::GetCurrentTracer()->ExpectedPlace();
if (platform::is_cpu_place(place)) {
SetTensorFromPyArray<platform::CPUPlace>(
tensor, array, boost::get<platform::CPUPlace>(place), zero_copy);
} else if (platform::is_gpu_place(place)) {
SetTensorFromPyArray<platform::CUDAPlace>(
tensor, array, boost::get<platform::CUDAPlace>(place), zero_copy);
} else if (platform::is_cuda_pinned_place(place)) {
SetTensorFromPyArray<platform::CUDAPinnedPlace>(
tensor, array, boost::get<platform::CUDAPinnedPlace>(place),
zero_copy);
} else {
PADDLE_THROW(platform::errors::InvalidArgument(
"Place should be one of CPUPlace/CUDAPlace/CUDAPinnedPlace"));
}
} else {
if (py::isinstance<platform::CPUPlace>(obj)) {
SetTensorFromPyArray<platform::CPUPlace>(
tensor, array, obj.cast<platform::CPUPlace>(), zero_copy);
} else if (py::isinstance<platform::CUDAPlace>(obj)) {
SetTensorFromPyArray<platform::CUDAPlace>(
tensor, array, obj.cast<platform::CUDAPlace>(), zero_copy);
} else if (py::isinstance<platform::CUDAPinnedPlace>(obj)) {
SetTensorFromPyArray<platform::CUDAPinnedPlace>(
tensor, array, obj.cast<platform::CUDAPinnedPlace>(), zero_copy);
} else {
PADDLE_THROW(platform::errors::InvalidArgument(
"Place should be one of CPUPlace/CUDAPlace/CUDAPinnedPlace"));
}
} }
void operator()() { self->SetType(framework::proto::VarType::LOD_TENSOR);
py::call_guard<py::gil_scoped_acquire>(); self->SetDataType(tensor->type());
py_obj_ptr_->operator()(this); }
static void InitVarBaseFromNumpyWithKwargs(imperative::VarBase *self,
const py::kwargs &kwargs) {
PADDLE_ENFORCE_EQ(
kwargs.contains("value"), true,
platform::errors::InvalidArgument("Missing argument: value"));
if (kwargs.contains("place")) {
InitTensorForVarBase(self, kwargs.contains("persistable")
? kwargs["persistable"].cast<bool>()
: false,
false, kwargs["value"].cast<py::array>(),
kwargs["place"], kwargs["zero_copy"].cast<bool>());
} else {
InitTensorForVarBase(self, kwargs.contains("persistable")
? kwargs["persistable"].cast<bool>()
: false,
true, kwargs["value"].cast<py::array>(), py::object(),
kwargs["zero_copy"].cast<bool>());
} }
}
private: template <typename P>
std::shared_ptr<py::object> py_obj_ptr_; static void InitVarBaseFromNumpyWithArg(imperative::VarBase *self,
}; const py::array &array, const P &place,
bool persistable, bool zero_copy) {
// 0: value, 1: place, 2: name 3: persistable, 4: zero_copy
new (self) imperative::VarBase(
imperative::GetCurrentTracer()->GenerateUniqueName("generated_var_"));
self->SetPersistable(persistable);
auto *tensor = self->MutableVar()->GetMutable<framework::LoDTensor>();
SetTensorFromPyArray<P>(tensor, array, place, zero_copy);
self->SetType(framework::proto::VarType::LOD_TENSOR);
self->SetDataType(tensor->type());
}
static void InitVarBaseFromNumpyWithArgDefault(imperative::VarBase *self,
const py::array &array,
bool persistable) {
InitTensorForVarBase(self, persistable, true, array);
}
// Function like obj.attr_name in Python. static std::string GetTypeName(const imperative::VarBase &var) {
static PyObject *GetPythonAttribute(PyObject *obj, const char *attr_name) { if (var.Type() == framework::proto::VarType::RAW) {
// NOTE(zjl): PyObject_GetAttrString would return nullptr when attr_name return "RAW";
// is not inside obj, but it would also set the error flag of Python. } else if (!var.Var().IsInitialized()) {
// If the error flag is set in C++, C++ code would not raise Exception, return "nullptr";
// but Python would raise Exception once C++ call ends.
// To avoid unexpected Exception raised in Python, we check whether
// attribute exists before calling PyObject_GetAttrString.
//
// Caution: PyObject_GetAttrString would increase reference count of PyObject.
// Developer should call Py_DECREF manually after the attribute is not used.
if (PyObject_HasAttrString(obj, attr_name)) {
return PyObject_GetAttrString(obj, attr_name);
} else { } else {
return nullptr; return framework::ToTypeName(var.Var().Type());
} }
} }
using PyNameVarBaseMap = std::unordered_map<std::string, py::handle>;
template <typename T> template <typename T>
static T PyObjectCast(PyObject *obj) { static T PyObjectCast(PyObject *obj) {
...@@ -106,48 +168,36 @@ GetVarBaseListFromPyHandle(const py::handle &handle) { ...@@ -106,48 +168,36 @@ GetVarBaseListFromPyHandle(const py::handle &handle) {
return {}; return {};
} }
const char *kIVarField = "_ivar";
PyObject *py_ivar = GetPythonAttribute(py_obj, kIVarField);
std::vector<std::shared_ptr<imperative::VarBase>> result; std::vector<std::shared_ptr<imperative::VarBase>> result;
if (py_ivar) { // Variable if (PyList_Check(py_obj)) { // List of VarBase
result.emplace_back(
PyObjectCast<std::shared_ptr<imperative::VarBase>>(py_ivar));
Py_DECREF(py_ivar);
} else if (PyList_Check(py_obj)) { // List of Variable
size_t len = PyList_GET_SIZE(py_obj); size_t len = PyList_GET_SIZE(py_obj);
result.reserve(len); result.reserve(len);
for (size_t i = 0; i < len; ++i) { for (size_t i = 0; i < len; ++i) {
PyObject *py_ivar = PyObject *py_ivar = PyList_GET_ITEM(py_obj, i);
PyObject_GetAttrString(PyList_GET_ITEM(py_obj, i), kIVarField); PADDLE_ENFORCE_NOT_NULL(
PADDLE_ENFORCE_NOT_NULL(py_ivar); py_ivar, platform::errors::InvalidArgument("Python Object is NULL"));
result.emplace_back( result.emplace_back(
PyObjectCast<std::shared_ptr<imperative::VarBase>>(py_ivar)); PyObjectCast<std::shared_ptr<imperative::VarBase>>(py_ivar));
Py_DECREF(py_ivar);
} }
} else if (PyTuple_Check(py_obj)) { // Tuple of Variable } else if (PyTuple_Check(py_obj)) { // Tuple of VarBase
size_t len = PyTuple_GET_SIZE(py_obj); size_t len = PyTuple_GET_SIZE(py_obj);
result.reserve(len); result.reserve(len);
for (size_t i = 0; i < len; ++i) { for (size_t i = 0; i < len; ++i) {
PyObject *py_ivar = PyObject *py_ivar = PyTuple_GET_ITEM(py_obj, i);
PyObject_GetAttrString(PyTuple_GET_ITEM(py_obj, i), kIVarField); PADDLE_ENFORCE_NOT_NULL(
PADDLE_ENFORCE_NOT_NULL(py_ivar); py_ivar, platform::errors::InvalidArgument("Python Object is NULL"));
result.emplace_back( result.emplace_back(
PyObjectCast<std::shared_ptr<imperative::VarBase>>(py_ivar)); PyObjectCast<std::shared_ptr<imperative::VarBase>>(py_ivar));
Py_DECREF(py_ivar);
} }
} else { } else { // VarBase
PADDLE_THROW( result.emplace_back(
"unsupported type %s, must be Variable, list[Variable] or " PyObjectCast<std::shared_ptr<imperative::VarBase>>(py_obj));
"tuple[Variable]",
py::str(handle));
} }
return result; return result;
} }
using PyNameVarBaseMap = std::unordered_map<std::string, py::handle>;
static imperative::NameVarBaseMap ConvertToNameVarBaseMap( static imperative::NameVarBaseMap ConvertToNameVarBaseMap(
const PyNameVarBaseMap &map) { const PyNameVarBaseMap &map) {
imperative::NameVarBaseMap result; imperative::NameVarBaseMap result;
...@@ -163,16 +213,6 @@ static imperative::NameVarBaseMap ConvertToNameVarBaseMap( ...@@ -163,16 +213,6 @@ static imperative::NameVarBaseMap ConvertToNameVarBaseMap(
return result; return result;
} }
static std::string GetTypeName(const imperative::VarBase &var) {
if (var.Type() == framework::proto::VarType::RAW) {
return "RAW";
} else if (!var.Var().IsInitialized()) {
return "nullptr";
} else {
return framework::ToTypeName(var.Var().Type());
}
}
// Bind Methods // Bind Methods
void BindImperative(py::module *m_ptr) { void BindImperative(py::module *m_ptr) {
auto &m = *m_ptr; auto &m = *m_ptr;
...@@ -239,11 +279,17 @@ void BindImperative(py::module *m_ptr) { ...@@ -239,11 +279,17 @@ void BindImperative(py::module *m_ptr) {
R"DOC()DOC") R"DOC()DOC")
.def_static("_alive_vars", &imperative::VarBase::AliveVarNames) .def_static("_alive_vars", &imperative::VarBase::AliveVarNames)
.def("__init__", .def("__init__",
[](imperative::VarBase &self, const std::string &name, [](imperative::VarBase &self, framework::proto::VarType::Type dtype,
framework::proto::VarType::Type type, const std::vector<int> &dims, const py::handle &name,
framework::proto::VarType::Type dtype, framework::proto::VarType::Type type, bool persistable) {
const std::vector<int> &dims, bool persistable) { std::string act_name = "";
new (&self) imperative::VarBase(name); if (!name.ptr() || name.ptr() == Py_None) {
act_name = imperative::GetCurrentTracer()->GenerateUniqueName(
"generated_var");
} else {
act_name = name.cast<std::string>();
}
new (&self) imperative::VarBase(act_name);
self.SetPersistable(persistable); self.SetPersistable(persistable);
self.SetType(type); self.SetType(type);
self.SetDataType(dtype); self.SetDataType(dtype);
...@@ -253,6 +299,91 @@ void BindImperative(py::module *m_ptr) { ...@@ -253,6 +299,91 @@ void BindImperative(py::module *m_ptr) {
tensor->Resize(framework::make_ddim(dims)); tensor->Resize(framework::make_ddim(dims));
} }
}) })
.def("__init__", &InitVarBaseFromNumpyWithArg<platform::CPUPlace>,
py::arg("value"), py::arg("place"), py::arg("persistable") = false,
py::arg("zero_copy") = false)
.def("__init__", &InitVarBaseFromNumpyWithArg<platform::CUDAPlace>,
py::arg("value"), py::arg("place"), py::arg("persistable") = false,
py::arg("zero_copy") = false)
.def("__init__", &InitVarBaseFromNumpyWithArg<platform::CUDAPinnedPlace>,
py::arg("value"), py::arg("place"), py::arg("persistable") = false,
py::arg("zero_copy") = false)
.def("__init__", &InitVarBaseFromNumpyWithArgDefault, py::arg("value"),
py::arg("persistable") = false)
.def("__init__", &InitVarBaseFromNumpyWithKwargs)
.def("numpy",
[](imperative::VarBase &self) -> py::array {
const auto &tensor =
self.MutableVar()->Get<framework::LoDTensor>();
PADDLE_ENFORCE_EQ(
tensor.IsInitialized(), true,
platform::errors::InvalidArgument(
"%s is Empty, Please check if it has no data in",
self.Name()));
return TensorToPyArray(tensor, true);
},
R"DOC(
**Notes**:
**This API is ONLY avaliable in Dygraph mode**
Returns a numpy array shows the value of current :ref:`api_guide_Variable_en`
Returns:
ndarray: The numpy value of current Variable.
Returns type:
ndarray: dtype is same as current Variable
Examples:
.. code-block:: python
import paddle.fluid as fluid
from paddle.fluid.dygraph.base import to_variable
from paddle.fluid.dygraph import FC
import numpy as np
data = np.random.uniform(-1, 1, [30, 10, 32]).astype('float32')
with fluid.dygraph.guard():
fc = FC("fc", 64, num_flatten_dims=2)
data = to_variable(data)
x = fc(data)
print(x.numpy())
)DOC")
.def("detach",
[](const imperative::VarBase &self) {
const auto &tensor = self.Var().Get<framework::LoDTensor>();
PADDLE_ENFORCE_EQ(tensor.IsInitialized(), true,
platform::errors::InvalidArgument(
"%s has not been initialized", self.Name()));
return self.NewVarBase(tensor.place(), false);
},
py::return_value_policy::copy, R"DOC(
**Notes**:
**This API is ONLY avaliable in Dygraph mode**
Returns a new Variable, detached from the current graph.
Returns:
( :ref:`api_guide_Variable_en` | dtype is same as current Variable): The detached Variable.
Examples:
.. code-block:: python
import paddle.fluid as fluid
from paddle.fluid.dygraph.base import to_variable
from paddle.fluid.dygraph import FC
import numpy as np
data = np.random.uniform(-1, 1, [30, 10, 32]).astype('float32')
with fluid.dygraph.guard():
fc = FC("fc", 64, num_flatten_dims=2)
data = to_variable(data)
x = fc(data)
y = x.detach()
)DOC")
.def("_run_backward", .def("_run_backward",
[](imperative::VarBase &self, [](imperative::VarBase &self,
const imperative::detail::BackwardStrategy &bckst, const imperative::detail::BackwardStrategy &bckst,
...@@ -273,7 +404,39 @@ void BindImperative(py::module *m_ptr) { ...@@ -273,7 +404,39 @@ void BindImperative(py::module *m_ptr) {
return self.MutableGradVar()->Get<framework::LoDTensor>(); return self.MutableGradVar()->Get<framework::LoDTensor>();
}, },
py::return_value_policy::reference) py::return_value_policy::reference)
.def("_clear_gradient", &imperative::VarBase::ClearGradient) .def("clear_gradient", &imperative::VarBase::ClearGradient, R"DOC(
**Notes**:
**1. This API is ONLY avaliable in Dygraph mode**
**2. Use it only Variable has gradient, normally we use this for Parameters since other temporal Variable will be deleted by Python's GC**
Clear (set to ``0`` ) the Gradient of Current Variable
Returns: None
Examples:
.. code-block:: python
import paddle.fluid as fluid
import numpy as np
x = np.ones([2, 2], np.float32)
with fluid.dygraph.guard():
inputs2 = []
for _ in range(10):
tmp = fluid.dygraph.base.to_variable(x)
tmp.stop_gradient=False
inputs2.append(tmp)
ret2 = fluid.layers.sums(inputs2)
loss2 = fluid.layers.reduce_sum(ret2)
backward_strategy = fluid.dygraph.BackwardStrategy()
backward_strategy.sort_sum_gradient = True
loss2.backward(backward_strategy)
print(loss2.gradient())
loss2.clear_gradient()
print("After clear {}".format(loss2.gradient()))
)DOC")
.def("_grad_ivar", .def("_grad_ivar",
[](const imperative::VarBase &self) { [](const imperative::VarBase &self) {
auto &grad_var = self.GradVarBase(); auto &grad_var = self.GradVarBase();
......
...@@ -194,14 +194,8 @@ static std::vector<std::shared_ptr<imperative::VarBase>> GetVarBaseList( ...@@ -194,14 +194,8 @@ static std::vector<std::shared_ptr<imperative::VarBase>> GetVarBaseList(
if (!py_obj || py_obj == Py_None) { if (!py_obj || py_obj == Py_None) {
PADDLE_THROW("Save parameter [%s] is None", para.first); PADDLE_THROW("Save parameter [%s] is None", para.first);
} }
const char *kIVarField = "_ivar";
PyObject *py_ivar = GetPythonAttribute(py_obj, kIVarField);
PADDLE_ENFORCE_NOT_NULL(py_ivar, "Can not find ivar in Variable");
vec_res.emplace_back( vec_res.emplace_back(
PyObjectCast<std::shared_ptr<imperative::VarBase>>(py_ivar)); PyObjectCast<std::shared_ptr<imperative::VarBase>>(py_obj));
Py_DECREF(py_ivar);
} }
return vec_res; return vec_res;
......
...@@ -486,7 +486,8 @@ inline framework::Tensor *PySliceTensor(const framework::Tensor &self, ...@@ -486,7 +486,8 @@ inline framework::Tensor *PySliceTensor(const framework::Tensor &self,
} }
} }
inline py::array TensorToPyArray(const framework::Tensor &tensor) { inline py::array TensorToPyArray(const framework::Tensor &tensor,
bool need_deep_copy = false) {
if (!tensor.IsInitialized()) { if (!tensor.IsInitialized()) {
return py::array(); return py::array();
} }
...@@ -510,9 +511,26 @@ inline py::array TensorToPyArray(const framework::Tensor &tensor) { ...@@ -510,9 +511,26 @@ inline py::array TensorToPyArray(const framework::Tensor &tensor) {
std::string py_dtype_str = details::TensorDTypeToPyDTypeStr(tensor.type()); std::string py_dtype_str = details::TensorDTypeToPyDTypeStr(tensor.type());
if (!is_gpu_tensor) { if (!is_gpu_tensor) {
return py::array(py::buffer_info( if (!need_deep_copy) {
const_cast<void *>(tensor_buf_ptr), sizeof_dtype, py_dtype_str, return py::array(py::buffer_info(
static_cast<size_t>(tensor.dims().size()), py_dims, py_strides)); const_cast<void *>(tensor_buf_ptr), sizeof_dtype, py_dtype_str,
static_cast<size_t>(tensor.dims().size()), py_dims, py_strides));
} else {
py::array py_arr(py::dtype(py_dtype_str.c_str()), py_dims, py_strides);
PADDLE_ENFORCE_EQ(py_arr.writeable(), true,
platform::errors::InvalidArgument(
"PyArray must be writable, otherwise memory leak "
"or double free would occur"));
PADDLE_ENFORCE_EQ(py_arr.owndata(), true,
platform::errors::InvalidArgument(
"PyArray must own data, otherwise memory leak "
"or double free would occur"));
platform::CPUPlace place;
size_t copy_bytes = sizeof_dtype * numel;
paddle::memory::Copy(place, py_arr.mutable_data(), place, tensor_buf_ptr,
copy_bytes);
return py_arr;
}
} }
#ifdef PADDLE_WITH_CUDA #ifdef PADDLE_WITH_CUDA
......
...@@ -88,13 +88,13 @@ from .dygraph.nn import * ...@@ -88,13 +88,13 @@ from .dygraph.nn import *
from .dygraph.layers import * from .dygraph.layers import *
from .io import save, load, load_program_state, set_program_state from .io import save, load, load_program_state, set_program_state
from .dygraph.checkpoint import save_dygraph, load_dygraph from .dygraph.checkpoint import save_dygraph, load_dygraph
from .dygraph.varbase_patch_methods import monkey_patch_varbase
Tensor = LoDTensor Tensor = LoDTensor
__all__ = framework.__all__ + executor.__all__ + \ __all__ = framework.__all__ + executor.__all__ + \
trainer_desc.__all__ + inferencer.__all__ + transpiler.__all__ + \ trainer_desc.__all__ + inferencer.__all__ + transpiler.__all__ + \
parallel_executor.__all__ + lod_tensor.__all__ + \ parallel_executor.__all__ + lod_tensor.__all__ + \
data_feed_desc.__all__ + compiler.__all__ + backward.__all__ + [ data_feed_desc.__all__ + compiler.__all__ + backward.__all__ + [
'io', 'io',
'initializer', 'initializer',
'embedding', 'embedding',
...@@ -126,6 +126,7 @@ __all__ = framework.__all__ + executor.__all__ + \ ...@@ -126,6 +126,7 @@ __all__ = framework.__all__ + executor.__all__ + \
'install_check', 'install_check',
'save', 'save',
'load', 'load',
'VarBase'
] ]
...@@ -234,3 +235,4 @@ def __bootstrap__(): ...@@ -234,3 +235,4 @@ def __bootstrap__():
# Consider paddle.init(args) or paddle.main(args) # Consider paddle.init(args) or paddle.main(args)
monkey_patch_variable() monkey_patch_variable()
__bootstrap__() __bootstrap__()
monkey_patch_varbase()
...@@ -138,6 +138,7 @@ def guard(place=None): ...@@ -138,6 +138,7 @@ def guard(place=None):
train = framework.Program() train = framework.Program()
startup = framework.Program() startup = framework.Program()
tracer = Tracer() tracer = Tracer()
VarBase = core.VarBase
if place is None: if place is None:
if core.is_compiled_with_cuda(): if core.is_compiled_with_cuda():
...@@ -205,28 +206,21 @@ def to_variable(value, block=None, name=None, zero_copy=None): ...@@ -205,28 +206,21 @@ def to_variable(value, block=None, name=None, zero_copy=None):
if isinstance(value, np.ndarray): if isinstance(value, np.ndarray):
assert framework.in_dygraph_mode( assert framework.in_dygraph_mode(
), "to_variable could only be called in dygraph mode" ), "to_variable could only be called in dygraph mode"
if not block:
block = framework.default_main_program().current_block()
py_var = framework.Variable(
block,
type=core.VarDesc.VarType.LOD_TENSOR,
name=name,
shape=value.shape,
dtype=value.dtype,
stop_gradient=True)
var = py_var._ivar.value()
tensor = var.get_tensor()
if isinstance(framework._current_expected_place(), if isinstance(framework._current_expected_place(),
framework.core.CPUPlace): framework.core.CPUPlace):
if zero_copy is None: if zero_copy is None:
zero_copy = True zero_copy = True
tensor.set(value, framework._current_expected_place(), zero_copy)
else: else:
assert not zero_copy, "zero_copy mode can only be used with CPUPlace" assert not zero_copy, "zero_copy mode can only be used with CPUPlace"
tensor.set(value, framework._current_expected_place(), False) zero_copy = False
py_var = core.VarBase(
value=value,
name=name,
persistable=False,
place=framework._current_expected_place(),
zero_copy=zero_copy)
return py_var return py_var
elif isinstance(value, framework.Variable): elif isinstance(value, (core.VarBase, framework.Variable)):
return value return value
else: else:
raise TypeError( raise TypeError(
......
...@@ -33,7 +33,7 @@ def create_program_from_desc(program_desc): ...@@ -33,7 +33,7 @@ def create_program_from_desc(program_desc):
def _extract_vars(inputs, result_list): def _extract_vars(inputs, result_list):
if isinstance(inputs, Variable): if isinstance(inputs, Variable):
result_list.append(inputs._ivar) result_list.append(inputs)
if isinstance(inputs, (list, tuple)): if isinstance(inputs, (list, tuple)):
for var in inputs: for var in inputs:
...@@ -67,7 +67,7 @@ def _trace(layer, ...@@ -67,7 +67,7 @@ def _trace(layer,
outputs = [original_outputs] outputs = [original_outputs]
else: else:
outputs = original_outputs outputs = original_outputs
out_vars = [var._ivar for var in outputs] out_vars = [var for var in outputs]
program_desc, feed_names, fetch_names = tracer.create_program_desc( program_desc, feed_names, fetch_names = tracer.create_program_desc(
var_list, feed_prefix, out_vars, fetch_prefix, tmp_prefix) var_list, feed_prefix, out_vars, fetch_prefix, tmp_prefix)
...@@ -104,7 +104,7 @@ class TracedLayer(object): ...@@ -104,7 +104,7 @@ class TracedLayer(object):
self._scope = core.Scope() self._scope = core.Scope()
for p in parameters: for p in parameters:
src_tensor = p._ivar.value().get_tensor() src_tensor = p.value().get_tensor()
dst_tensor = self._scope.var(p.name).get_tensor() dst_tensor = self._scope.var(p.name).get_tensor()
dst_tensor._share_data_with(src_tensor) dst_tensor._share_data_with(src_tensor)
...@@ -234,7 +234,7 @@ class TracedLayer(object): ...@@ -234,7 +234,7 @@ class TracedLayer(object):
feed_dict = {} feed_dict = {}
if in_dygraph_mode(): if in_dygraph_mode():
for x, name in zip(inputs, self._feed_names): for x, name in zip(inputs, self._feed_names):
feed_dict[name] = x._ivar.value().get_tensor() feed_dict[name] = x.value().get_tensor()
else: else:
for x, name in zip(inputs, self._feed_names): for x, name in zip(inputs, self._feed_names):
feed_dict[name] = x feed_dict[name] = x
......
...@@ -25,7 +25,6 @@ from .layer_object_helper import LayerObjectHelper ...@@ -25,7 +25,6 @@ from .layer_object_helper import LayerObjectHelper
from .base import program_desc_tracing_guard from .base import program_desc_tracing_guard
from paddle.fluid import framework from paddle.fluid import framework
from ..param_attr import ParamAttr from ..param_attr import ParamAttr
from paddle.fluid.framework import Variable
__all__ = ['Layer'] __all__ = ['Layer']
......
...@@ -219,12 +219,8 @@ class DataParallel(layers.Layer): ...@@ -219,12 +219,8 @@ class DataParallel(layers.Layer):
grad_vars = [] grad_vars = []
for param in self._layers.parameters(): for param in self._layers.parameters():
# NOTE(zcd): The grad_ivar maybe no generated. # NOTE(zcd): The grad_ivar maybe no generated.
if param.trainable and param._ivar._grad_ivar(): if param.trainable and param._grad_ivar():
g_var = framework.Variable( g_var = param._grad_ivar()
block=self._helper.main_program.current_block(),
name=param._ivar._grad_name(),
stop_gradient=True,
ivar=param._ivar._grad_ivar())
grad_vars.append(g_var) grad_vars.append(g_var)
assert g_var not in grad_var_set assert g_var not in grad_var_set
grad_var_set.add(g_var) grad_var_set.add(g_var)
......
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .. import framework
from .. import core
from . import BackwardStrategy
from ..framework import Variable, _getitem_impl_
from .. import unique_name
import numpy as np
def monkey_patch_varbase():
# TODO(jiabin): move this to cplusplus end if we find some performance issue on it
@framework.dygraph_only
def set_value(self, value):
"""
**Notes**:
**This API is ONLY avaliable in Dygraph mode**
Set a new value for this Variable.
Args:
value (Variable|np.ndarray): the new value.
Examples:
.. code-block:: python
import paddle.fluid as fluid
from paddle.fluid.dygraph.base import to_variable
from paddle.fluid.dygraph import FC
import numpy as np
data = np.ones([3, 32, 32], dtype='float32')
with fluid.dygraph.guard():
fc = fluid.dygraph.FC("fc", 4)
t = to_variable(data)
fc(t) # call with default weight
custom_weight = np.random.randn(1024, 4).astype("float32")
fc.weight.set_value(custom_weight) # change existing weight
out = fc(t) # call with different weight
"""
assert isinstance(value, (np.ndarray, core.VarBase)), \
"Variable set_value function, arguments type only support Variable, numpy, VarBase"
value_np = value
if isinstance(value, core.VarBase):
value_np = value.numpy()
self_tensor_np = self.numpy()
assert self_tensor_np.shape == value_np.shape, \
"Variable Shape not match, Variable [ {} ] need tensor with shape {} but load set tensor with shape {}".format(
self.name, self_tensor_np.shape, value_np.shape)
assert self_tensor_np.dtype == value_np.dtype, \
"Variable dtype not match, Variable [ {} ] need tensor with dtype {} but load tensor with dtype {}".format(
self.name, self_tensor_np.dtype, value_np.dtype)
self.value().get_tensor().set(value_np,
framework._current_expected_place())
@framework.dygraph_only
def backward(self, backward_strategy=None):
"""
**Notes**:
**This API is ONLY avaliable in Dygraph mode**
Run backward of current Graph which starts from current Variable
Args:
backward_strategy( :ref:`api_fluid_dygraph_BackwardStrategy` ): The Backward Strategy to run backward
Returns:
NoneType: None
Examples:
.. code-block:: python
import paddle.fluid as fluid
import numpy as np
x = np.ones([2, 2], np.float32)
with fluid.dygraph.guard():
inputs2 = []
for _ in range(10):
tmp = fluid.dygraph.base.to_variable(x)
# if we don't set tmp's stop_gradient as False then, all path to loss will has no gradient since
# there is no one need gradient on it.
tmp.stop_gradient=False
inputs2.append(tmp)
ret2 = fluid.layers.sums(inputs2)
loss2 = fluid.layers.reduce_sum(ret2)
backward_strategy = fluid.dygraph.BackwardStrategy()
backward_strategy.sort_sum_gradient = True
loss2.backward(backward_strategy)
"""
if framework.in_dygraph_mode():
if backward_strategy is None:
backward_strategy = BackwardStrategy()
backward_strategy.sort_sum_gradient = False
self._run_backward(backward_strategy, framework._dygraph_tracer())
else:
raise ValueError(
"Variable.backward() is only avaliable in DyGraph mode")
@framework.dygraph_only
def gradient(self):
"""
**Notes**:
**This API is ONLY avaliable in Dygraph mode**
Get the Gradient of Current Variable
Returns:
ndarray: Numpy value of the gradient of current Variable
Examples:
.. code-block:: python
import paddle.fluid as fluid
import numpy as np
x = np.ones([2, 2], np.float32)
with fluid.dygraph.guard():
inputs2 = []
for _ in range(10):
tmp = fluid.dygraph.base.to_variable(x)
tmp.stop_gradient=False
inputs2.append(tmp)
ret2 = fluid.layers.sums(inputs2)
loss2 = fluid.layers.reduce_sum(ret2)
backward_strategy = fluid.dygraph.BackwardStrategy()
backward_strategy.sort_sum_gradient = True
loss2.backward(backward_strategy)
print(loss2.gradient())
"""
if self._grad_ivar() is None:
raise ValueError(
"%s has no grad, Please set Variable.stop_gradient=False, or "
"check if this is the first and only variable need grad, if so, please set its pre-Variable's "
"stop_gradient=False, to make sure it has gradient " %
self.name)
new_ivar = self._grad_ivar()._copy_to(core.CPUPlace(), True)
if self._grad_ivar().type == core.VarDesc.VarType.SELECTED_ROWS:
return (np.array(new_ivar.value().get_selected_rows().get_tensor()),
np.array(new_ivar.value().get_selected_rows().rows()))
else:
return np.array(new_ivar.value().get_tensor())
def __str__(self):
return self.to_string(True)
@property
def block(self):
return framework.default_main_program().global_block()
def to_string(self, throw_on_error, with_details=False):
"""
Get debug string.
Args:
throw_on_error (bool): True if raise an exception when self is not initialized.
with_details (bool): more details about variables and parameters (e.g. trainable, optimize_attr, ...) will be printed when with_details is True. Default value is False;
Returns:
str: The debug string.
Examples:
.. code-block:: python
import paddle.fluid as fluid
cur_program = fluid.Program()
cur_block = cur_program.current_block()
new_variable = cur_block.create_var(name="X",
shape=[-1, 23, 48],
dtype='float32')
print(new_variable.to_string(True))
print("=============with detail===============")
print(new_variable.to_string(True, True))
"""
if framework.in_dygraph_mode():
# TODO(panyx0718): add more dygraph debug info.
tensor = self.value().get_tensor()
if tensor._is_initialized():
return 'name %s, dtype: %s shape: %s %s' % (
self.name, self.dtype, self.shape, str(tensor))
else:
return 'name %s, shape: %s, not inited' % (self.name,
self.shape)
def __getitem__(self, item):
return _getitem_impl_(self, item)
for method_name, method in (("set_value", set_value), ("block", block),
("backward", backward), ("gradient", gradient),
("__str__", __str__), ("to_string", to_string),
("__getitem__", __getitem__)):
setattr(core.VarBase, method_name, method)
...@@ -264,7 +264,7 @@ class GradClipByGlobalNorm(GradClipBase): ...@@ -264,7 +264,7 @@ class GradClipByGlobalNorm(GradClipBase):
if g is None: if g is None:
continue continue
merge_grad = g merge_grad = g
if g._ivar.type == core.VarDesc.VarType.SELECTED_ROWS: if g.type == core.VarDesc.VarType.SELECTED_ROWS:
merge_grad = layers.merge_selected_rows(g) merge_grad = layers.merge_selected_rows(g)
merge_grad = layers.get_tensor_from_selected_rows(merge_grad) merge_grad = layers.get_tensor_from_selected_rows(merge_grad)
power = layers.square(merge_grad) power = layers.square(merge_grad)
......
...@@ -260,19 +260,6 @@ def is_compiled_with_cuda(): ...@@ -260,19 +260,6 @@ def is_compiled_with_cuda():
return core.is_compiled_with_cuda() return core.is_compiled_with_cuda()
def _var_base_to_np(var_base):
"""
convert VarBase tp numpy
Args:
var_base(VarBase) : the VarBase to convert
Returns (np.ndarray): the np.ndarray contain the value of VarBase
"""
var = var_base._copy_to(core.CPUPlace(), True)
return np.array(var.value().get_tensor())
def cuda_places(device_ids=None): def cuda_places(device_ids=None):
""" """
**Note**: **Note**:
...@@ -558,6 +545,241 @@ def _debug_string_(proto, throw_on_error=True): ...@@ -558,6 +545,241 @@ def _debug_string_(proto, throw_on_error=True):
return proto.__str__() return proto.__str__()
def _varbase_creator(type=core.VarDesc.VarType.LOD_TENSOR,
name=None,
shape=None,
dtype=None,
persistable=None,
**kwargs):
if dtype is not None:
if not isinstance(dtype, core.VarDesc.VarType):
dtype = convert_np_dtype_to_dtype_(dtype)
return core.VarBase(dtype if dtype else core.VarDesc.VarType.FP32,
list(shape) if shape else [], name, type
if type else core.VarDesc.VarType.LOD_TENSOR, True
if persistable else False)
class VariableMetaClass(type):
@classmethod
def __instancecheck__(cls, instance):
t = type(instance)
if in_dygraph_mode():
return issubclass(t, core.VarBase)
else:
return issubclass(t, Variable)
class ParameterMetaClass(VariableMetaClass):
@classmethod
def __instancecheck__(cls, instance):
t = type(instance)
if in_dygraph_mode():
return issubclass(t, ParamBase)
else:
return issubclass(t, Parameter)
def _getitem_impl_(var, item):
"""
Slice the variable.
Args:
item(int/slice/tuple) : the index.
Returns:
Sliced variable
"""
if not isinstance(item, tuple):
item = [item]
decrease_axis = []
slice_axis = []
slice_start = []
slice_end = []
slice_step = []
use_strided_slice = False
reverse_axis = []
def fill_constant(shape, value, force_cpu=False, out=None):
var.block.append_op(
type='fill_constant',
inputs={},
outputs={'Out': [out]},
attrs={
'shape': shape,
'dtype': out.dtype,
'value': float(value),
'force_cpu': force_cpu
},
stop_gradient=True)
out.stop_gradient = True
return out
for dim, slice_item in enumerate(item):
if isinstance(slice_item, slice):
start = slice_item.start
end = slice_item.stop
step = slice_item.step
if start is None and end is None and step is None:
continue
if step is None:
step = 1
if start is None and end is None:
assert (step == -1)
reverse_axis.append(dim)
continue
if start is None:
start = 0
if end is None:
end = 10000000
if step != 1:
use_strided_slice = True
slice_axis.append(dim)
slice_start.append(start)
slice_end.append(end)
slice_step.append(step)
else:
decrease_axis.append(dim)
slice_axis.append(dim)
slice_start.append(slice_item)
slice_step.append(1)
if isinstance(slice_item, Variable):
temp_1 = var.block.create_var(dtype='int32')
fill_constant([1], 1, force_cpu=True, out=temp_1)
temp_end = var.block.create_var(dtype='int32')
var.block.append_op(
type='elementwise_add',
inputs={'X': slice_item,
'Y': temp_1},
outputs={'Out': temp_end},
attrs={'axis': -1})
slice_end.append(temp_end)
else:
slice_end.append(slice_item + 1
if slice_item != -1 else 10000000)
def contain_var(one_list):
for ele in one_list:
if isinstance(ele, Variable):
return True
return False
def get_new_list_tensor(old_list):
new_list_tensor = []
for dim in old_list:
if isinstance(dim, Variable):
dim.stop_gradient = True
new_list_tensor.append(dim)
else:
assert (isinstance(dim, int))
temp_out = var.block.create_var(dtype='int32')
fill_constant([1], dim, force_cpu=True, out=temp_out)
new_list_tensor.append(temp_out)
return new_list_tensor
inputs = {'Input': [var]}
attrs = {
'axes': slice_axis,
'starts': [],
'ends': [],
'decrease_axis': decrease_axis
}
if (use_strided_slice == True):
attrs['strides'] = []
infer_flags = list(1 for i in range(len(slice_axis)))
# starts
if not contain_var(slice_start):
attrs['starts'] = slice_start
else:
inputs['StartsTensorList'] = get_new_list_tensor(slice_start)
for i, dim in enumerate(slice_start):
if isinstance(dim, Variable):
attrs['starts'].append(-1)
infer_flags[i] = -1
else:
attrs['starts'].append(dim)
# ends
if not contain_var(slice_end):
attrs['ends'] = slice_end
else:
inputs['EndsTensorList'] = get_new_list_tensor(slice_end)
for i, dim in enumerate(slice_end):
if isinstance(dim, Variable):
attrs['ends'].append(-1)
infer_flags[i] = -1
else:
attrs['ends'].append(dim)
# strides
if use_strided_slice == True:
if not contain_var(slice_step):
attrs['strides'] = slice_step
else:
inputs['StridesTensorList'] = get_new_list_tensor(slice_step)
for i, dim in enumerate(slice_step):
if isinstance(dim, Variable):
attrs['strides'].append(-1)
infer_flags[i] = -1
else:
attrs['strides'].append(dim)
# infer_flags
attrs['infer_flags'] = infer_flags
out = var
if use_strided_slice == False and len(slice_axis) > 0:
# append slice_op here
slice_out_var = var.block.create_var(
name=unique_name.generate_with_ignorable_key(var.name + "_slice"),
dtype=var.dtype)
var.block.append_op(
type="slice",
inputs=inputs,
outputs={'Out': [slice_out_var]},
attrs=attrs)
out = slice_out_var
elif use_strided_slice == True and len(slice_axis) > 0:
strided_slice_out_var = var.block.create_var(
name=unique_name.generate_with_ignorable_key(var.name +
"_strided_slice"),
dtype=var.dtype)
var.block.append_op(
type="strided_slice",
inputs=inputs,
outputs={'Out': [strided_slice_out_var]},
attrs=attrs)
out = strided_slice_out_var
if len(reverse_axis) > 0:
reverse_out_var = var.block.create_var(
name=unique_name.generate_with_ignorable_key(var.name +
"_slice_reverse"),
dtype=var.dtype)
var.block.append_op(
type="reverse",
inputs={'X': out},
outputs={'Out': [reverse_out_var]},
attrs={'axis': reverse_axis})
out = reverse_out_var
return out
@six.add_metaclass(VariableMetaClass)
class Variable(object): class Variable(object):
""" """
**Notes**: **Notes**:
...@@ -626,100 +848,83 @@ class Variable(object): ...@@ -626,100 +848,83 @@ class Variable(object):
self.belong_to_optimizer = belong_to_optimizer self.belong_to_optimizer = belong_to_optimizer
if in_dygraph_mode(): self.error_clip = error_clip
# record vars in tracer rather than blocks
self._ivar = kwargs.get("ivar", None) is_new_var = False
self.stop_gradient_ = kwargs.get("stop_gradient", True) name = cpt.to_text(name)
if not self._ivar: self.desc = self.block.desc.find_var(cpt.to_bytes(name))
self._ivar = core.VarBase(
name, type
if type else core.VarDesc.VarType.LOD_TENSOR, dtype
if dtype else core.VarDesc.VarType.FP32,
list(shape) if shape else [], True
if persistable else False)
if persistable:
_dygraph_tracer().trace_var(name, self)
self.op = None
else:
self.error_clip = error_clip
is_new_var = False if self.desc is None:
name = cpt.to_text(name) self.desc = self.block.desc.var(cpt.to_bytes(name))
self.desc = self.block.desc.find_var(cpt.to_bytes(name)) is_new_var = True
if self.desc is None: if is_new_var:
self.desc = self.block.desc.var(cpt.to_bytes(name)) self.desc.set_type(type)
is_new_var = True elif self.desc.type() != type:
raise ValueError("Variable {0} has been created before. The "
"previous type is {1}; the new type is {2}. They"
" are not matched".format(self.name,
self.desc.type(), type))
if shape is not None:
if is_new_var: if is_new_var:
self.desc.set_type(type) self.desc.set_shape(shape)
elif self.desc.type() != type: else:
raise ValueError( old_shape = self.shape
"Variable {0} has been created before. The " shape = tuple(shape)
"previous type is {1}; the new type is {2}. They" if shape != old_shape:
" are not matched".format(self.name, self.desc.type(), raise ValueError(
type)) "Variable {0} has been created before. the previous "
"shape is {1}; the new shape is {2}. They are not "
if shape is not None: "matched.".format(self.name, old_shape, shape))
if is_new_var: if dtype is not None:
self.desc.set_shape(shape) if is_new_var:
else: self.desc.set_dtype(dtype)
old_shape = self.shape else:
shape = tuple(shape) old_dtype = self.dtype
if shape != old_shape: if dtype != old_dtype:
raise ValueError( raise ValueError("Variable {0} has been created before. "
"Variable {0} has been created before. the previous " "The previous data type is {1}; the new "
"shape is {1}; the new shape is {2}. They are not " "data type is {2}. They are not "
"matched.".format(self.name, old_shape, shape)) "matched.".format(self.name, old_dtype,
if dtype is not None: dtype))
if is_new_var:
self.desc.set_dtype(dtype) if lod_level is not None:
else: if is_new_var:
old_dtype = self.dtype self.desc.set_lod_level(lod_level)
if dtype != old_dtype: else:
raise ValueError( if lod_level != self.lod_level:
"Variable {0} has been created before. " raise ValueError("Variable {0} has been created before. "
"The previous data type is {1}; the new " "The previous lod_level is {1}; the new "
"data type is {2}. They are not " "lod_level is {2}. They are not "
"matched.".format(self.name, old_dtype, dtype)) "matched".format(self.name, self.lod_level,
lod_level))
if lod_level is not None: if persistable is not None:
if is_new_var: if is_new_var:
self.desc.set_lod_level(lod_level) self.desc.set_persistable(persistable)
else: else:
if lod_level != self.lod_level: if persistable != self.persistable:
raise ValueError( raise ValueError(
"Variable {0} has been created before. " "Variable {0} has been created before."
"The previous lod_level is {1}; the new " "The previous persistable is {1}; the new "
"lod_level is {2}. They are not " "persistable is {2}. They are not matched".format(
"matched".format(self.name, self.lod_level, self.name, self.persistable, persistable))
lod_level))
if persistable is not None:
if is_new_var:
self.desc.set_persistable(persistable)
else:
if persistable != self.persistable:
raise ValueError(
"Variable {0} has been created before."
"The previous persistable is {1}; the new "
"persistable is {2}. They are not matched".format(
self.name, self.persistable, persistable))
if need_check_feed and is_new_var: if need_check_feed and is_new_var:
self.desc.set_need_check_feed(need_check_feed) self.desc.set_need_check_feed(need_check_feed)
if capacity is not None: if capacity is not None:
if is_new_var: if is_new_var:
self.desc.set_capacity(capacity) self.desc.set_capacity(capacity)
else: else:
# TODO(abhinavarora) : Compare with set capacity once, # TODO(abhinavarora) : Compare with set capacity once,
# get_capacity is implemented # get_capacity is implemented
pass pass
self.block.vars[name] = self self.block.vars[name] = self
self.op = None self.op = None
self._stop_gradient = stop_gradient self._stop_gradient = stop_gradient
self.is_data = is_data self.is_data = is_data
@dygraph_only @dygraph_only
def detach(self): def detach(self):
...@@ -749,16 +954,7 @@ class Variable(object): ...@@ -749,16 +954,7 @@ class Variable(object):
y = x.detach() y = x.detach()
""" """
if in_dygraph_mode(): pass
new_var = self._cloneVar()
self.block.append_op(
type="assign",
inputs={'X': [self]},
outputs={'Out': [new_var]},
stop_gradient=True)
return new_var
else:
raise AttributeError("static graph model DO NOT supprt detach")
@dygraph_only @dygraph_only
def numpy(self): def numpy(self):
...@@ -790,12 +986,7 @@ class Variable(object): ...@@ -790,12 +986,7 @@ class Variable(object):
print(x.numpy()) print(x.numpy())
""" """
pass
if not self._ivar.value().get_tensor()._is_initialized():
raise ValueError("%s is Empty, Please check if it has no data in" %
self.name)
new_ivar = self._ivar._copy_to(core.CPUPlace(), True)
return np.array(new_ivar.value().get_tensor())
@dygraph_only @dygraph_only
def set_value(self, value): def set_value(self, value):
...@@ -826,25 +1017,7 @@ class Variable(object): ...@@ -826,25 +1017,7 @@ class Variable(object):
out = fc(t) # call with different weight out = fc(t) # call with different weight
""" """
assert isinstance(value, (Variable, np.ndarray, core.VarBase)), \ pass
"Variable set_value function, arguments type only support Variable, numpy, VarBase"
value_np = value
if isinstance(value, Variable):
value_np = value.numpy()
elif isinstance(value, core.VarBase):
value_np = _var_base_to_np(value)
self_tensor = self._ivar.value().get_tensor()
self_tensor_np = np.array(self_tensor)
assert self_tensor_np.shape == value_np.shape, \
"Variable Shape not match, Variable [ {} ] need tensor with shape {} but load set tensor with shape {}".format( self._ivar.name, self_tensor_np.shape, value_np.shape)
assert self_tensor_np.dtype == value_np.dtype, \
"Variable dtype not match, Variable [ {} ] need tensor with dtype {} but load tensor with dtype {}".format( self._ivar.name, self_tensor_np.dtype, value_np.dtype)
self_tensor.set(value_np, _current_expected_place())
@dygraph_only @dygraph_only
def backward(self, backward_strategy=None): def backward(self, backward_strategy=None):
...@@ -882,16 +1055,7 @@ class Variable(object): ...@@ -882,16 +1055,7 @@ class Variable(object):
loss2.backward(backward_strategy) loss2.backward(backward_strategy)
""" """
if in_dygraph_mode(): pass
from .dygraph import BackwardStrategy
if backward_strategy is None:
backward_strategy = BackwardStrategy()
backward_strategy.sort_sum_gradient = False
self._ivar._run_backward(backward_strategy, _dygraph_tracer())
else:
raise ValueError(
"Variable.backward() is only avaliable in DyGraph mode")
@dygraph_only @dygraph_only
def gradient(self): def gradient(self):
...@@ -925,16 +1089,7 @@ class Variable(object): ...@@ -925,16 +1089,7 @@ class Variable(object):
print(loss2.gradient()) print(loss2.gradient())
""" """
if self._ivar._grad_ivar() is None: pass
raise ValueError("%s has no grad, Please set Variable.stop_gradient=False, or " \
"check if this is the first and only variable need grad, if so, please set its pre-Variable's " \
"stop_gradient=False, to make sure it has gradient " % self.name)
new_ivar = self._ivar._grad_ivar()._copy_to(core.CPUPlace(), True)
if self._ivar._grad_ivar().type == core.VarDesc.VarType.SELECTED_ROWS:
return (np.array(new_ivar.value().get_selected_rows().get_tensor()),
np.array(new_ivar.value().get_selected_rows().rows()))
else:
return np.array(new_ivar.value().get_tensor())
@dygraph_only @dygraph_only
def clear_gradient(self): def clear_gradient(self):
...@@ -971,7 +1126,7 @@ class Variable(object): ...@@ -971,7 +1126,7 @@ class Variable(object):
print("After clear {}".format(loss2.gradient())) print("After clear {}".format(loss2.gradient()))
""" """
self._ivar._clear_gradient() pass
def __str__(self): def __str__(self):
return self.to_string(True) return self.to_string(True)
...@@ -1004,14 +1159,7 @@ class Variable(object): ...@@ -1004,14 +1159,7 @@ class Variable(object):
print(new_variable.to_string(True, True)) print(new_variable.to_string(True, True))
""" """
if in_dygraph_mode(): if in_dygraph_mode():
# TODO(panyx0718): add more dygraph debug info. return
tensor = self._ivar.value().get_tensor()
if tensor._is_initialized():
return 'name %s, dtype: %s shape: %s %s' % (
self.name, self.dtype, self.shape, str(tensor))
else:
return 'name %s, shape: %s, not inited' % (self.name,
self.shape)
assert isinstance(throw_on_error, bool) and isinstance(with_details, assert isinstance(throw_on_error, bool) and isinstance(with_details,
bool) bool)
...@@ -1060,14 +1208,14 @@ class Variable(object): ...@@ -1060,14 +1208,14 @@ class Variable(object):
assert (out1.gradient() == 0).all() assert (out1.gradient() == 0).all()
""" """
if in_dygraph_mode(): if in_dygraph_mode():
return self._ivar.stop_gradient pass
else: else:
return self._stop_gradient return self._stop_gradient
@stop_gradient.setter @stop_gradient.setter
def stop_gradient(self, s): def stop_gradient(self, s):
if in_dygraph_mode(): if in_dygraph_mode():
self._ivar.stop_gradient = s pass
else: else:
self._stop_gradient = s self._stop_gradient = s
...@@ -1095,7 +1243,7 @@ class Variable(object): ...@@ -1095,7 +1243,7 @@ class Variable(object):
print("persistable of current Var is: {}".format(new_variable.persistable)) print("persistable of current Var is: {}".format(new_variable.persistable))
""" """
if in_dygraph_mode(): if in_dygraph_mode():
return self._ivar.persistable pass
else: else:
return self.desc.persistable() return self.desc.persistable()
...@@ -1127,7 +1275,7 @@ class Variable(object): ...@@ -1127,7 +1275,7 @@ class Variable(object):
print("name of current Var is: {}".format(new_variable.name)) print("name of current Var is: {}".format(new_variable.name))
""" """
if in_dygraph_mode(): if in_dygraph_mode():
return self._ivar.name pass
else: else:
return cpt.to_text(self.desc.name()) return cpt.to_text(self.desc.name())
...@@ -1154,7 +1302,7 @@ class Variable(object): ...@@ -1154,7 +1302,7 @@ class Variable(object):
@name.setter @name.setter
def name(self, new_name): def name(self, new_name):
if in_dygraph_mode(): if in_dygraph_mode():
self._ivar.name = new_name pass
else: else:
self.desc.set_name(new_name) self.desc.set_name(new_name)
...@@ -1179,7 +1327,7 @@ class Variable(object): ...@@ -1179,7 +1327,7 @@ class Variable(object):
""" """
# convert to tuple, make it as same as numpy API. # convert to tuple, make it as same as numpy API.
if in_dygraph_mode(): if in_dygraph_mode():
return self._ivar.shape pass
else: else:
return tuple(self.desc.shape()) return tuple(self.desc.shape())
...@@ -1202,7 +1350,7 @@ class Variable(object): ...@@ -1202,7 +1350,7 @@ class Variable(object):
print("Dtype of current Var is: {}".format(new_variable.dtype)) print("Dtype of current Var is: {}".format(new_variable.dtype))
""" """
if in_dygraph_mode(): if in_dygraph_mode():
return self._ivar.dtype pass
else: else:
return self.desc.dtype() return self.desc.dtype()
...@@ -1254,7 +1402,7 @@ class Variable(object): ...@@ -1254,7 +1402,7 @@ class Variable(object):
print("Type of current Var is: {}".format(new_variable.type)) print("Type of current Var is: {}".format(new_variable.type))
""" """
if in_dygraph_mode(): if in_dygraph_mode():
return self._ivar.type pass
else: else:
return self.desc.type() return self.desc.type()
...@@ -1446,200 +1594,7 @@ class Variable(object): ...@@ -1446,200 +1594,7 @@ class Variable(object):
raise IndexError("Valid index accept int or slice or tuple") raise IndexError("Valid index accept int or slice or tuple")
def __getitem__(self, item): def __getitem__(self, item):
""" return _getitem_impl_(self, item)
Slice the variable.
Args:
item(int/slice/tuple) : the index.
Returns:
Sliced variable
"""
if not isinstance(item, tuple):
item = [item]
decrease_axis = []
slice_axis = []
slice_start = []
slice_end = []
slice_step = []
use_strided_slice = False
reverse_axis = []
def fill_constant(shape, value, force_cpu=False, out=None):
self.block.append_op(
type='fill_constant',
inputs={},
outputs={'Out': [out]},
attrs={
'shape': shape,
'dtype': out.dtype,
'value': float(value),
'force_cpu': force_cpu
},
stop_gradient=True)
out.stop_gradient = True
return out
for dim, slice_item in enumerate(item):
if isinstance(slice_item, slice):
start = slice_item.start
end = slice_item.stop
step = slice_item.step
if start is None and end is None and step is None:
continue
if step is None:
step = 1
if start is None and end is None:
assert (step == -1)
reverse_axis.append(dim)
continue
if start is None:
start = 0
if end is None:
end = 10000000
if step != 1:
use_strided_slice = True
slice_axis.append(dim)
slice_start.append(start)
slice_end.append(end)
slice_step.append(step)
else:
decrease_axis.append(dim)
slice_axis.append(dim)
slice_start.append(slice_item)
slice_step.append(1)
if isinstance(slice_item, Variable):
temp_1 = self.block.create_var(dtype='int32')
fill_constant([1], 1, force_cpu=True, out=temp_1)
temp_end = self.block.create_var(dtype='int32')
self.block.append_op(
type='elementwise_add',
inputs={'X': slice_item,
'Y': temp_1},
outputs={'Out': temp_end},
attrs={'axis': -1})
slice_end.append(temp_end)
else:
slice_end.append(slice_item + 1
if slice_item != -1 else 10000000)
def contain_var(one_list):
for ele in one_list:
if isinstance(ele, Variable):
return True
return False
def get_new_list_tensor(old_list):
new_list_tensor = []
for dim in old_list:
if isinstance(dim, Variable):
dim.stop_gradient = True
new_list_tensor.append(dim)
else:
assert (isinstance(dim, int))
temp_out = self.block.create_var(dtype='int32')
fill_constant([1], dim, force_cpu=True, out=temp_out)
new_list_tensor.append(temp_out)
return new_list_tensor
inputs = {'Input': [self]}
attrs = {
'axes': slice_axis,
'starts': [],
'ends': [],
'decrease_axis': decrease_axis
}
if (use_strided_slice == True):
attrs['strides'] = []
infer_flags = list(1 for i in range(len(slice_axis)))
# starts
if not contain_var(slice_start):
attrs['starts'] = slice_start
else:
inputs['StartsTensorList'] = get_new_list_tensor(slice_start)
for i, dim in enumerate(slice_start):
if isinstance(dim, Variable):
attrs['starts'].append(-1)
infer_flags[i] = -1
else:
attrs['starts'].append(dim)
# ends
if not contain_var(slice_end):
attrs['ends'] = slice_end
else:
inputs['EndsTensorList'] = get_new_list_tensor(slice_end)
for i, dim in enumerate(slice_end):
if isinstance(dim, Variable):
attrs['ends'].append(-1)
infer_flags[i] = -1
else:
attrs['ends'].append(dim)
# strides
if use_strided_slice == True:
if not contain_var(slice_step):
attrs['strides'] = slice_step
else:
inputs['StridesTensorList'] = get_new_list_tensor(slice_step)
for i, dim in enumerate(slice_step):
if isinstance(dim, Variable):
attrs['strides'].append(-1)
infer_flags[i] = -1
else:
attrs['strides'].append(dim)
# infer_flags
attrs['infer_flags'] = infer_flags
out = self
if use_strided_slice == False and len(slice_axis) > 0:
# append slice_op here
slice_out_var = self.block.create_var(
name=unique_name.generate_with_ignorable_key(self.name +
"_slice"),
dtype=self.dtype)
self.block.append_op(
type="slice",
inputs=inputs,
outputs={'Out': [slice_out_var]},
attrs=attrs)
out = slice_out_var
elif use_strided_slice == True and len(slice_axis) > 0:
strided_slice_out_var = self.block.create_var(
name=unique_name.generate_with_ignorable_key(self.name +
"_strided_slice"),
dtype=self.dtype)
self.block.append_op(
type="strided_slice",
inputs=inputs,
outputs={'Out': [strided_slice_out_var]},
attrs=attrs)
out = strided_slice_out_var
if len(reverse_axis) > 0:
reverse_out_var = self.block.create_var(
name=unique_name.generate_with_ignorable_key(self.name +
"_slice_reverse"),
dtype=self.dtype)
self.block.append_op(
type="reverse",
inputs={'X': out},
outputs={'Out': [reverse_out_var]},
attrs={'axis': reverse_axis})
out = reverse_out_var
return out
def get_all_op_protos(): def get_all_op_protos():
...@@ -2347,9 +2302,12 @@ class Block(object): ...@@ -2347,9 +2302,12 @@ class Block(object):
if isinstance(item[1], Parameter)) if isinstance(item[1], Parameter))
def create_var(self, *args, **kwargs): def create_var(self, *args, **kwargs):
var = Variable(block=self, *args, **kwargs) if not in_dygraph_mode():
if 'initializer' in kwargs: var = Variable(block=self, *args, **kwargs)
kwargs['initializer'](var, self) if 'initializer' in kwargs:
kwargs['initializer'](var, self)
else:
var = _varbase_creator(*args, **kwargs)
return var return var
def has_var(self, name): def has_var(self, name):
...@@ -2396,18 +2354,31 @@ class Block(object): ...@@ -2396,18 +2354,31 @@ class Block(object):
# NOTE: v is destroyed by C++ after calling _rename_var. # NOTE: v is destroyed by C++ after calling _rename_var.
d = self.desc.find_var(cpt.to_bytes(new_name)) d = self.desc.find_var(cpt.to_bytes(new_name))
if var_type == "Parameter": if var_type == "Parameter":
var = Parameter( if not in_dygraph_mode():
self, var = Parameter(
d.shape(), self,
d.dtype(), d.shape(),
type=orig_var_type, d.dtype(),
name=new_name, type=orig_var_type,
stop_gradient=stop_gradient, name=new_name,
trainable=trainable, stop_gradient=stop_gradient,
optimize_attr=optimize_attr, trainable=trainable,
regularizer=regularizer, optimize_attr=optimize_attr,
gradient_clip_attr=gradient_clip_attr, regularizer=regularizer,
error_clip=error_clip) gradient_clip_attr=gradient_clip_attr,
error_clip=error_clip)
else:
var = ParamBase(
d.shape(),
d.dtype(),
type=orig_var_type,
name=new_name,
stop_gradient=stop_gradient,
trainable=trainable,
optimize_attr=optimize_attr,
regularizer=regularizer,
gradient_clip_attr=gradient_clip_attr,
error_clip=error_clip)
elif var_type == "Variable": elif var_type == "Variable":
var = Variable( var = Variable(
self, self,
...@@ -2430,7 +2401,11 @@ class Block(object): ...@@ -2430,7 +2401,11 @@ class Block(object):
def create_parameter(self, *args, **kwargs): def create_parameter(self, *args, **kwargs):
global_block = self.program.global_block() global_block = self.program.global_block()
param = Parameter(global_block, *args, **kwargs) param = None
if not in_dygraph_mode():
param = Parameter(global_block, *args, **kwargs)
else:
param = ParamBase(*args, **kwargs)
if 'initializer' in kwargs: if 'initializer' in kwargs:
def _is_inited_by(block, var): def _is_inited_by(block, var):
...@@ -2669,19 +2644,34 @@ class Block(object): ...@@ -2669,19 +2644,34 @@ class Block(object):
raise ValueError("_copy_param_info_from should be invoked with " raise ValueError("_copy_param_info_from should be invoked with "
"same topology") "same topology")
assert isinstance(v, Variable) assert isinstance(v, Variable)
new_p = Parameter( new_p = None
block=self, if not in_dygraph_mode():
shape=v.shape, new_p = Parameter(
dtype=v.dtype, block=self,
type=v.type, shape=v.shape,
lod_level=v.lod_level, dtype=v.dtype,
stop_gradient=p.stop_gradient, type=v.type,
trainable=p.trainable, lod_level=v.lod_level,
optimize_attr=p.optimize_attr, stop_gradient=p.stop_gradient,
regularizer=p.regularizer, trainable=p.trainable,
gradient_clip_attr=p.gradient_clip_attr, optimize_attr=p.optimize_attr,
error_clip=p.error_clip, regularizer=p.regularizer,
name=v.name) gradient_clip_attr=p.gradient_clip_attr,
error_clip=p.error_clip,
name=v.name)
else:
new_p = ParamBase(
shape=v.shape,
dtype=v.dtype,
type=v.type,
lod_level=v.lod_level,
stop_gradient=p.stop_gradient,
trainable=p.trainable,
optimize_attr=p.optimize_attr,
regularizer=p.regularizer,
gradient_clip_attr=p.gradient_clip_attr,
error_clip=p.error_clip,
name=v.name)
self.vars[new_p.name] = new_p self.vars[new_p.name] = new_p
def _clone_variable(self, var, force_persistable=True): def _clone_variable(self, var, force_persistable=True):
...@@ -4485,6 +4475,7 @@ class Program(object): ...@@ -4485,6 +4475,7 @@ class Program(object):
yield each_var yield each_var
@six.add_metaclass(ParameterMetaClass)
class Parameter(Variable): class Parameter(Variable):
""" """
Parameter is derived from Variable. A parameter is a persistable Parameter is derived from Variable. A parameter is a persistable
...@@ -4580,6 +4571,111 @@ class Parameter(Variable): ...@@ -4580,6 +4571,111 @@ class Parameter(Variable):
__repr__ = __str__ __repr__ = __str__
class ParamBase(core.VarBase):
"""
ParamBase is derived from VarBase( Which is the Variable in Dygraph Mode ). A ParamBase is a persistable
VarBase, and will be updated by optimizers after each iteration.
The training of a neural network is essentially the updating of
its ParamBase.
Relative to a general Variable, a ParamBase has several its own
member variables:
Args:
trainable(bool): True if the ParamBase need to be updated after
iterations.
optimize_attr(map): ParamBase attributes related with optimizing.
Currently, it only contains 'learning_rate'.
Default: {'learning_rate': 1.0}
regularizer(WeightDecayRegularizer): The Regularizer which will
be applied on the ParamBase. Default: None
gradient_clip_attr(BaseGradientClipAttr): The gradint clip strategy
which will be applied on the ParamBase. Default: None
do_model_average(bool): True if the model average strategy will
be applied on this ParamBase.
"""
@dygraph_only
def __init__(self, shape, dtype, **kwargs):
if shape is None:
raise ValueError("The shape of Parameter should not be None")
if dtype is None:
raise ValueError("The dtype of Parameter should not be None")
if len(shape) == 0:
raise ValueError(
"The dimensions of shape for Parameter must be greater than 0")
for each in shape:
if each < 0:
raise ValueError(
"Each dimension of shape for Parameter must be greater than 0, but received %s"
% list(shape))
if dtype is not None:
if not isinstance(dtype, core.VarDesc.VarType):
dtype = convert_np_dtype_to_dtype_(dtype)
name = kwargs.get('name', unique_name.generate('_param_base'))
super(ParamBase, self).__init__(dtype
if dtype else core.VarDesc.VarType.FP32,
list(shape) if shape else [], name,
core.VarDesc.VarType.LOD_TENSOR, True)
self.trainable = kwargs.get('trainable', True)
self.optimize_attr = kwargs.get('optimize_attr', {'learning_rate': 1.0})
self.regularizer = kwargs.get('regularizer', None)
self.gradient_clip_attr = kwargs.get('gradient_clip_attr', None)
self.do_model_average = kwargs.get('do_model_average', None)
self.is_distributed = False
#self.block = default_main_program().global_block()
_dygraph_tracer().trace_var(name, self)
def __str__(self):
return self.to_string(True)
def to_string(self, throw_on_error, with_details=False):
"""
To debug string.
Args:
throw_on_error(bool): raise exception when self is not initialized
when throw_on_error is True
with_details(bool): more details about variables and parameters
(e.g. trainable, optimize_attr, ...) will be printed when with_details is True
Returns(str): The debug string.
Examples:
.. code-block:: python
import paddle.fluid as fluid
prog = fluid.default_main_program()
rlt = fluid.layers.data("fake_data", shape=[1,1], dtype='float32')
debug_str = prog.to_string(throw_on_error=True, with_details=False)
print(debug_str)
"""
assert isinstance(throw_on_error, bool) and isinstance(with_details,
bool)
tensor = self.value().get_tensor()
if tensor._is_initialized():
return 'name %s, dtype: %s shape: %s %s' % (self.name, self.dtype,
self.shape, str(tensor))
else:
return 'name %s, shape: %s, not inited' % (self.name, self.shape)
__repr__ = __str__
# program is a global instance. # program is a global instance.
_main_program_ = Program() _main_program_ = Program()
_startup_program_ = Program() _startup_program_ = Program()
......
...@@ -44,33 +44,46 @@ class LayerHelperBase(object): ...@@ -44,33 +44,46 @@ class LayerHelperBase(object):
def startup_program(self): def startup_program(self):
return default_startup_program() return default_startup_program()
def to_variable(self, value, block=None): def to_variable(self, value, name=None):
"""convert value to variable """
The API will create a ``Variable`` object from numpy\.ndarray or Variable object.
Parameters:
value(ndarray): The numpy\.ndarray object that needs to be converted, it can be multi-dimension, and the data type is one of numpy\.{float16, float32, float64, int16, int32, int64, uint8, uint16}.
block(fluid.Block, optional): Which block this variable will be in. Default: None.
name(str, optional): The default value is None. Normally there is no need for user to set this property. For more information, please refer to :ref:`api_guide_Name`
Returns:
Variable: ``Tensor`` created from the specified numpy\.ndarray object, data type and shape is the same as ``value`` .
Examples:
Args: .. code-block:: python
value: value to be convert
block: the block of the variable import numpy as np
import paddle.fluid as fluid
with fluid.dygraph.guard():
x = np.ones([2, 2], np.float32)
y = fluid.dygraph.to_variable(x)
Return Variable construct from value
""" """
if isinstance(value, np.ndarray): if isinstance(value, np.ndarray):
assert in_dygraph_mode( assert in_dygraph_mode(
), "to_variable could only be called in dygraph mode" ), "to_variable could only be called in dygraph mode"
py_var = core.VarBase(
if not block: value=value,
block = default_main_program().current_block() name=name,
py_var = Variable( persistable=False,
block, place=_current_expected_place(),
type=core.VarDesc.VarType.LOD_TENSOR, zero_copy=False)
name=None,
shape=value.shape,
dtype=value.dtype)
var = py_var._ivar.value()
tensor = var.get_tensor()
tensor.set(value, _current_expected_place())
return py_var return py_var
elif isinstance(value, Variable): elif isinstance(value, (core.VarBase, Variable)):
return value return value
else:
raise TypeError(
"to_variable only accepts 'ndarray' or 'Variable' or 'VarBase' as value's input"
)
def _create_weight_normalize(self, attr, shape, dtype): def _create_weight_normalize(self, attr, shape, dtype):
from .layers import elementwise_mul, elementwise_div, reshape from .layers import elementwise_mul, elementwise_div, reshape
...@@ -386,7 +399,7 @@ class LayerHelperBase(object): ...@@ -386,7 +399,7 @@ class LayerHelperBase(object):
""" """
assert isinstance(var, Variable) assert isinstance(var, Variable)
if in_dygraph_mode(): if in_dygraph_mode():
initializer(var, var.block) initializer(var, self.main_program.global_block())
else: else:
self.startup_program.global_block().create_var( self.startup_program.global_block().create_var(
name=var.name, name=var.name,
......
...@@ -233,6 +233,7 @@ def cross_entropy(input, label, soft_label=False, ignore_index=kIgnoreIndex): ...@@ -233,6 +233,7 @@ def cross_entropy(input, label, soft_label=False, ignore_index=kIgnoreIndex):
predict = fluid.layers.fc(input=x, size=class_num, act='softmax') predict = fluid.layers.fc(input=x, size=class_num, act='softmax')
cost = fluid.layers.cross_entropy(input=predict, label=label) cost = fluid.layers.cross_entropy(input=predict, label=label)
""" """
check_type_and_dtype(input, 'input', Variable, check_type_and_dtype(input, 'input', Variable,
['float16', 'float32', 'float64'], 'cross_entropy') ['float16', 'float32', 'float64'], 'cross_entropy')
if not soft_label: if not soft_label:
...@@ -729,7 +730,6 @@ def nce(input, ...@@ -729,7 +730,6 @@ def nce(input,
sampler = 1 sampler = 1
elif sampler == "custom_dist": elif sampler == "custom_dist":
assert custom_dist is not None assert custom_dist is not None
# assert isinstance(custom_dist, Variable)
custom_dist_len = num_total_classes custom_dist_len = num_total_classes
alias_probs_ = [0] * custom_dist_len alias_probs_ = [0] * custom_dist_len
......
...@@ -15,7 +15,7 @@ ...@@ -15,7 +15,7 @@
from __future__ import print_function from __future__ import print_function
from .. import core from .. import core
from ..framework import Variable, unique_name from ..framework import Variable, unique_name, in_dygraph_mode, default_main_program
from .layer_function_generator import OpProtoHolder from .layer_function_generator import OpProtoHolder
from ..initializer import force_init_on_cpu from ..initializer import force_init_on_cpu
...@@ -40,7 +40,10 @@ def monkey_patch_variable(): ...@@ -40,7 +40,10 @@ def monkey_patch_variable():
return dtype return dtype
def current_block(var): def current_block(var):
return var.block if in_dygraph_mode():
return default_main_program().global_block()
else:
return var.block
def create_new_tmp_var(block, dtype): def create_new_tmp_var(block, dtype):
tmp_name = unique_tmp_name() tmp_name = unique_tmp_name()
...@@ -281,5 +284,9 @@ def monkey_patch_variable(): ...@@ -281,5 +284,9 @@ def monkey_patch_variable():
setattr(Variable, method_name, setattr(Variable, method_name,
_elemwise_method_creator_(method_name, op_type, reverse, _elemwise_method_creator_(method_name, op_type, reverse,
scalar_method)) scalar_method))
setattr(core.VarBase, method_name,
_elemwise_method_creator_(method_name, op_type, reverse,
scalar_method))
Variable.astype = astype Variable.astype = astype
setattr(core.VarBase, "astype", astype)
...@@ -32,7 +32,6 @@ from .layers import ops ...@@ -32,7 +32,6 @@ from .layers import ops
from .regularizer import append_regularization_ops from .regularizer import append_regularization_ops
from .dygraph import base as imperative_base from .dygraph import base as imperative_base
from .dygraph.learning_rate_scheduler import LearningRateDecay from .dygraph.learning_rate_scheduler import LearningRateDecay
from .framework import _var_base_to_np
from paddle.fluid import core from paddle.fluid import core
from paddle.fluid.layers import tensor from paddle.fluid.layers import tensor
from functools import reduce from functools import reduce
...@@ -122,7 +121,13 @@ class Optimizer(object): ...@@ -122,7 +121,13 @@ class Optimizer(object):
state_dict[var_tmp.name] = var_tmp state_dict[var_tmp.name] = var_tmp
# global step if use lr decay # global step if use lr decay
if isinstance(self._learning_rate, LearningRateDecay): if isinstance(self._learning_rate, LearningRateDecay):
var_temp = Variable(None, name='global_step', dtype='int32') var_tmp = None
if not framework.in_dygraph_mode():
var_temp = Variable(None, name='global_step', dtype='int32')
else:
var_temp = framework._varbase_creator(
None, name='global_step', dtype='int32')
tensor.fill_constant( tensor.fill_constant(
[1], "int32", self._learning_rate.step_num, out=var_temp) [1], "int32", self._learning_rate.step_num, out=var_temp)
...@@ -164,7 +169,7 @@ class Optimizer(object): ...@@ -164,7 +169,7 @@ class Optimizer(object):
global_step = state_dict['global_step'] global_step = state_dict['global_step']
if isinstance(global_step, core.VarBase): if isinstance(global_step, core.VarBase):
step_np = global_step._copy_to(core.CPUPlace(), True) step_np = global_step
step_np = np.array(step_np.value().get_tensor()) step_np = np.array(step_np.value().get_tensor())
assert step_np.shape == (1,), \ assert step_np.shape == (1,), \
"global step shape is (1,), the shape is {}".format( step_np.shape ) "global step shape is (1,), the shape is {}".format( step_np.shape )
...@@ -189,7 +194,7 @@ class Optimizer(object): ...@@ -189,7 +194,7 @@ class Optimizer(object):
for para_name, var_tmp in v.items(): for para_name, var_tmp in v.items():
assert var_tmp.name in state_dict, \ assert var_tmp.name in state_dict, \
"optimizer variable {} not found".format( var_tmp.name ) "optimizer variable {} not found".format( var_tmp.name )
var = var_tmp._ivar.value() var = var_tmp.value()
tensor = var.get_tensor() tensor = var.get_tensor()
model_np = np.array(tensor) model_np = np.array(tensor)
...@@ -198,7 +203,7 @@ class Optimizer(object): ...@@ -198,7 +203,7 @@ class Optimizer(object):
if isinstance(load_para, Variable): if isinstance(load_para, Variable):
load_para_np = load_para.numpy() load_para_np = load_para.numpy()
elif isinstance(load_para, core.VarBase): elif isinstance(load_para, core.VarBase):
load_para_np = _var_base_to_np(load_para) load_para_np = load_para.numpy()
elif isinstance(load_para, np.ndarray): elif isinstance(load_para, np.ndarray):
load_para_np = load_para load_para_np = load_para
else: else:
...@@ -515,7 +520,11 @@ class Optimizer(object): ...@@ -515,7 +520,11 @@ class Optimizer(object):
Examples: Examples:
See examples in ``apply_gradients``. See examples in ``apply_gradients``.
""" """
no_grad_set = self._get_no_grad_set(loss, no_grad_set) act_no_grad_set = None
if not framework.in_dygraph_mode():
act_no_grad_set = self._get_no_grad_set(loss, no_grad_set)
else:
pass
self._dtype = loss.dtype self._dtype = loss.dtype
if framework.in_dygraph_mode(): if framework.in_dygraph_mode():
...@@ -528,15 +537,9 @@ class Optimizer(object): ...@@ -528,15 +537,9 @@ class Optimizer(object):
for param in parameters: for param in parameters:
if not param.trainable: if not param.trainable:
continue continue
if param._ivar._grad_ivar() is not None: if param._grad_ivar() is not None:
ivar_type = param._ivar._grad_ivar().type
# create gradient variable # create gradient variable
grad_var = Variable( grad_var = param._grad_ivar()
block=loss.block,
type=ivar_type,
name=param._ivar._grad_name(),
stop_gradient=True,
ivar=param._ivar._grad_ivar())
params_grads.append((param, grad_var)) params_grads.append((param, grad_var))
else: else:
if callbacks is None: if callbacks is None:
...@@ -550,7 +553,7 @@ class Optimizer(object): ...@@ -550,7 +553,7 @@ class Optimizer(object):
loss.shape) loss.shape)
with program_guard(program, startup_program): with program_guard(program, startup_program):
params_grads = append_backward(loss, parameter_list, params_grads = append_backward(loss, parameter_list,
no_grad_set, callbacks) act_no_grad_set, callbacks)
# Note: since we can't use all_reduce_op now, # Note: since we can't use all_reduce_op now,
# dgc_op should be the last op of one grad. # dgc_op should be the last op of one grad.
self._append_dgc_ops(params_grads) self._append_dgc_ops(params_grads)
......
...@@ -268,7 +268,7 @@ class OpTest(unittest.TestCase): ...@@ -268,7 +268,7 @@ class OpTest(unittest.TestCase):
data = value[0] data = value[0]
lod = value[1] lod = value[1]
v = fluid.dygraph.base.to_variable(value=data) v = fluid.dygraph.base.to_variable(value=data)
v._ivar.value().get_tensor().set_recursive_sequence_lengths(lod) v.value().get_tensor().set_recursive_sequence_lengths(lod)
return v return v
else: else:
return fluid.dygraph.base.to_variable(value) return fluid.dygraph.base.to_variable(value)
...@@ -289,7 +289,7 @@ class OpTest(unittest.TestCase): ...@@ -289,7 +289,7 @@ class OpTest(unittest.TestCase):
if if_return_inputs_grad_dict: if if_return_inputs_grad_dict:
v.stop_gradient = False v.stop_gradient = False
if has_lod: if has_lod:
v._ivar.value().get_tensor().set_recursive_sequence_lengths( v.value().get_tensor().set_recursive_sequence_lengths(
lod_temp) lod_temp)
else: else:
v = block.create_var( v = block.create_var(
...@@ -840,8 +840,8 @@ class OpTest(unittest.TestCase): ...@@ -840,8 +840,8 @@ class OpTest(unittest.TestCase):
if check_dygraph: if check_dygraph:
imperative_actual = find_imperative_actual( imperative_actual = find_imperative_actual(
sub_out_name, dygraph_outs, place) sub_out_name, dygraph_outs, place)
imperative_actual_t = np.array( imperative_actual_t = np.array(imperative_actual.value()
imperative_actual._ivar.value().get_tensor()) .get_tensor())
idx = find_actual(sub_out_name, fetch_list) idx = find_actual(sub_out_name, fetch_list)
actual = outs[idx] actual = outs[idx]
actual_t = np.array(actual) actual_t = np.array(actual)
...@@ -868,7 +868,7 @@ class OpTest(unittest.TestCase): ...@@ -868,7 +868,7 @@ class OpTest(unittest.TestCase):
") has different lod at " + str(place)) ") has different lod at " + str(place))
if check_dygraph: if check_dygraph:
self.assertListEqual( self.assertListEqual(
imperative_actual._ivar.value().get_tensor() imperative_actual.value().get_tensor()
.recursive_sequence_lengths(), expect[1], .recursive_sequence_lengths(), expect[1],
"Output (" + out_name + "Output (" + out_name +
") has different lod at " + str(place) + ") has different lod at " + str(place) +
...@@ -877,8 +877,8 @@ class OpTest(unittest.TestCase): ...@@ -877,8 +877,8 @@ class OpTest(unittest.TestCase):
if check_dygraph: if check_dygraph:
imperative_actual = find_imperative_actual( imperative_actual = find_imperative_actual(
out_name, dygraph_outs, place) out_name, dygraph_outs, place)
imperative_actual_t = np.array( imperative_actual_t = np.array(imperative_actual.value()
imperative_actual._ivar.value().get_tensor()) .get_tensor())
idx = find_actual(out_name, fetch_list) idx = find_actual(out_name, fetch_list)
actual = outs[idx] actual = outs[idx]
actual_t = np.array(actual) actual_t = np.array(actual)
...@@ -913,7 +913,7 @@ class OpTest(unittest.TestCase): ...@@ -913,7 +913,7 @@ class OpTest(unittest.TestCase):
") has different lod at " + str(place)) ") has different lod at " + str(place))
if check_dygraph: if check_dygraph:
self.assertListEqual( self.assertListEqual(
imperative_actual._ivar.value().get_tensor() imperative_actual.value().get_tensor()
.recursive_sequence_lengths(), expect[1], .recursive_sequence_lengths(), expect[1],
"Output (" + out_name + ") has different lod at " + "Output (" + out_name + ") has different lod at " +
str(place) + " in dygraph mode") str(place) + " in dygraph mode")
......
...@@ -153,8 +153,8 @@ class TestImperativeAutoPrune(unittest.TestCase): ...@@ -153,8 +153,8 @@ class TestImperativeAutoPrune(unittest.TestCase):
v2 = fluid.dygraph.to_variable(value2) v2 = fluid.dygraph.to_variable(value2)
loss = case1(v1, v2) loss = case1(v1, v2)
loss.backward() loss.backward()
self.assertTrue(case1.fc2._w._ivar._grad_ivar() is not None) self.assertTrue(case1.fc2._w._grad_ivar() is not None)
self.assertTrue(case1.fc1._w._ivar._grad_ivar() is not None) self.assertTrue(case1.fc1._w._grad_ivar() is not None)
def test_auto_prune2(self): def test_auto_prune2(self):
with fluid.dygraph.guard(): with fluid.dygraph.guard():
...@@ -166,8 +166,8 @@ class TestImperativeAutoPrune(unittest.TestCase): ...@@ -166,8 +166,8 @@ class TestImperativeAutoPrune(unittest.TestCase):
loss = case2(v1, v2) loss = case2(v1, v2)
loss.backward() loss.backward()
self.assertTrue(case2.fc2._w._ivar._grad_ivar() is None) self.assertTrue(case2.fc2._w._grad_ivar() is None)
self.assertTrue(case2.fc1._w._ivar._grad_ivar() is not None) self.assertTrue(case2.fc1._w._grad_ivar() is not None)
def test_auto_prune3(self): def test_auto_prune3(self):
with fluid.dygraph.guard(): with fluid.dygraph.guard():
...@@ -178,7 +178,7 @@ class TestImperativeAutoPrune(unittest.TestCase): ...@@ -178,7 +178,7 @@ class TestImperativeAutoPrune(unittest.TestCase):
v2 = fluid.dygraph.to_variable(value2) v2 = fluid.dygraph.to_variable(value2)
loss, part2 = case3(v1, v2, 1) loss, part2 = case3(v1, v2, 1)
loss.backward() loss.backward()
self.assertTrue(case3.fc._w._ivar._grad_ivar() is not None) self.assertTrue(case3.fc._w._grad_ivar() is not None)
self.assertTrue((part2.gradient() == 0).all()) self.assertTrue((part2.gradient() == 0).all())
def test_auto_prune4(self): def test_auto_prune4(self):
...@@ -190,7 +190,7 @@ class TestImperativeAutoPrune(unittest.TestCase): ...@@ -190,7 +190,7 @@ class TestImperativeAutoPrune(unittest.TestCase):
v2 = fluid.dygraph.to_variable(value2) v2 = fluid.dygraph.to_variable(value2)
loss, part2 = case4(v1, v2, 1) loss, part2 = case4(v1, v2, 1)
part2.backward() part2.backward()
self.assertTrue(case4.fc._w._ivar._grad_ivar() is not None) self.assertTrue(case4.fc._w._grad_ivar() is not None)
self.assertTrue((part2.gradient() == 1).all()) self.assertTrue((part2.gradient() == 1).all())
def test_auto_prune5(self): def test_auto_prune5(self):
...@@ -202,7 +202,7 @@ class TestImperativeAutoPrune(unittest.TestCase): ...@@ -202,7 +202,7 @@ class TestImperativeAutoPrune(unittest.TestCase):
v2 = fluid.dygraph.to_variable(value2) v2 = fluid.dygraph.to_variable(value2)
loss, part1, part2 = case4(v1, v2, 2) loss, part1, part2 = case4(v1, v2, 2)
part1.backward() part1.backward()
self.assertTrue(case4.fc._w._ivar._grad_ivar() is not None) self.assertTrue(case4.fc._w._grad_ivar() is not None)
self.assertTrue((part2.gradient() == 0).all()) self.assertTrue((part2.gradient() == 0).all())
def test_auto_prune6(self): def test_auto_prune6(self):
...@@ -333,8 +333,8 @@ class TestImperativeAutoPrune(unittest.TestCase): ...@@ -333,8 +333,8 @@ class TestImperativeAutoPrune(unittest.TestCase):
for items in params_grads: for items in params_grads:
assert items[0].name is not model.embed1._w.name assert items[0].name is not model.embed1._w.name
assert items[0].name is not model.fc1._w.name assert items[0].name is not model.fc1._w.name
assert model.embed1._w._ivar._grad_ivar() is None assert model.embed1._w._grad_ivar() is None
assert model.fc1._w._ivar._grad_ivar() is None assert model.fc1._w._grad_ivar() is None
with fluid.dygraph.guard(place): with fluid.dygraph.guard(place):
model = MyLayer2("mylayer", vocab_size, size) model = MyLayer2("mylayer", vocab_size, size)
...@@ -351,8 +351,8 @@ class TestImperativeAutoPrune(unittest.TestCase): ...@@ -351,8 +351,8 @@ class TestImperativeAutoPrune(unittest.TestCase):
for items in params_grads: for items in params_grads:
assert items[0].name is not model.embed1._w.name assert items[0].name is not model.embed1._w.name
assert items[0].name is not model.fc1._w.name assert items[0].name is not model.fc1._w.name
assert model.embed1._w._ivar._grad_ivar() is None assert model.embed1._w._grad_ivar() is None
assert model.fc1._w._ivar._grad_ivar() is None assert model.fc1._w._grad_ivar() is None
def test_case2_prune_no_grad_branch(self): def test_case2_prune_no_grad_branch(self):
with fluid.dygraph.guard(): with fluid.dygraph.guard():
...@@ -363,8 +363,8 @@ class TestImperativeAutoPrune(unittest.TestCase): ...@@ -363,8 +363,8 @@ class TestImperativeAutoPrune(unittest.TestCase):
case3 = AutoPruneLayer2("l2") case3 = AutoPruneLayer2("l2")
loss = case3(v1, v2) loss = case3(v1, v2)
loss.backward() loss.backward()
self.assertTrue(case3.fc2._w._ivar._grad_ivar() is None) self.assertTrue(case3.fc2._w._grad_ivar() is None)
self.assertTrue(case3.fc._w._ivar._grad_ivar() is not None) self.assertTrue(case3.fc._w._grad_ivar() is not None)
def test_case2_prune_no_grad_branch(self): def test_case2_prune_no_grad_branch(self):
with fluid.dygraph.guard(): with fluid.dygraph.guard():
...@@ -375,8 +375,8 @@ class TestImperativeAutoPrune(unittest.TestCase): ...@@ -375,8 +375,8 @@ class TestImperativeAutoPrune(unittest.TestCase):
case3 = AutoPruneLayer2("l2") case3 = AutoPruneLayer2("l2")
loss = case3(v1, v2) loss = case3(v1, v2)
loss.backward() loss.backward()
self.assertTrue(case3.fc2._w._ivar._grad_ivar() is None) self.assertTrue(case3.fc2._w._grad_ivar() is None)
self.assertTrue(case3.fc._w._ivar._grad_ivar() is not None) self.assertTrue(case3.fc._w._grad_ivar() is not None)
def test_case3_prune_no_grad_branch2(self): def test_case3_prune_no_grad_branch2(self):
with fluid.dygraph.guard(): with fluid.dygraph.guard():
...@@ -389,14 +389,14 @@ class TestImperativeAutoPrune(unittest.TestCase): ...@@ -389,14 +389,14 @@ class TestImperativeAutoPrune(unittest.TestCase):
out = fluid.layers.one_hot(input=label, depth=100) out = fluid.layers.one_hot(input=label, depth=100)
loss = fluid.layers.mean(out) loss = fluid.layers.mean(out)
loss.backward() loss.backward()
self.assertTrue(fc._w._ivar._grad_ivar() is None) self.assertTrue(fc._w._grad_ivar() is None)
def test_case4_with_no_grad_op_maker(self): def test_case4_with_no_grad_op_maker(self):
with fluid.dygraph.guard(): with fluid.dygraph.guard():
out = fluid.layers.gaussian_random(shape=[20, 30]) out = fluid.layers.gaussian_random(shape=[20, 30])
loss = fluid.layers.mean(out) loss = fluid.layers.mean(out)
loss.backward() loss.backward()
self.assertTrue(out._ivar._grad_ivar() is None) self.assertTrue(out._grad_ivar() is None)
if __name__ == '__main__': if __name__ == '__main__':
......
...@@ -177,6 +177,30 @@ class SimpleRNN(fluid.Layer): ...@@ -177,6 +177,30 @@ class SimpleRNN(fluid.Layer):
class TestImperative(unittest.TestCase): class TestImperative(unittest.TestCase):
def test_isinstance(self):
var = fluid.layers.data(shape=[1], name='x', dtype='float32')
self.assertTrue(isinstance(var, fluid.Variable))
with fluid.dygraph.guard():
var_base = fluid.dygraph.base.to_variable(np.array([3, 4, 5]))
self.assertTrue(isinstance(var_base, core.VarBase))
self.assertTrue(isinstance(var_base, fluid.Variable))
def test_create_VarBase(self):
x = np.ones([2, 2], np.float32)
y = np.zeros([3, 3], np.float32)
with fluid.dygraph.guard():
tmp = fluid.core.VarBase(value=x, place=fluid.core.CPUPlace())
tmp2 = fluid.core.VarBase(y, fluid.core.CPUPlace())
tmp3 = fluid.dygraph.base.to_variable(x)
tmp4 = fluid.core.VarBase(y)
tmp5 = fluid.core.VarBase(value=x)
self.assertTrue(np.array_equal(x, tmp.numpy()))
self.assertTrue(np.array_equal(y, tmp2.numpy()))
self.assertTrue(np.array_equal(x, tmp3.numpy()))
self.assertTrue(np.array_equal(y, tmp4.numpy()))
self.assertTrue(np.array_equal(x, tmp5.numpy()))
def test_sum_op(self): def test_sum_op(self):
x = np.ones([2, 2], np.float32) x = np.ones([2, 2], np.float32)
with fluid.dygraph.guard(): with fluid.dygraph.guard():
...@@ -215,17 +239,17 @@ class TestImperative(unittest.TestCase): ...@@ -215,17 +239,17 @@ class TestImperative(unittest.TestCase):
try: try:
new_variable.numpy() new_variable.numpy()
except Exception as e: except Exception as e:
assert type(e) == ValueError assert type(e) == core.EnforceNotMet
try: try:
new_variable.backward() new_variable.backward()
except Exception as e: except Exception as e:
assert type(e) == ValueError assert type(e) == core.EnforceNotMet
try: try:
new_variable.clear_gradient() new_variable.clear_gradient()
except Exception as e: except Exception as e:
assert type(e) == ValueError assert type(e) == core.EnforceNotMet
def test_empty_grad(self): def test_empty_grad(self):
with fluid.dygraph.guard(): with fluid.dygraph.guard():
...@@ -239,7 +263,7 @@ class TestImperative(unittest.TestCase): ...@@ -239,7 +263,7 @@ class TestImperative(unittest.TestCase):
try: try:
new_var.clear_gradient() new_var.clear_gradient()
except Exception as e: except Exception as e:
assert type(e) == ValueError assert type(e) == core.EnforceNotMet
with fluid.dygraph.guard(): with fluid.dygraph.guard():
cur_program = fluid.Program() cur_program = fluid.Program()
...@@ -257,7 +281,7 @@ class TestImperative(unittest.TestCase): ...@@ -257,7 +281,7 @@ class TestImperative(unittest.TestCase):
new_var = fluid.dygraph.base.to_variable(x) new_var = fluid.dygraph.base.to_variable(x)
self.assertFalse(new_var.persistable) self.assertFalse(new_var.persistable)
new_var.persistable = True new_var.persistable = True
self.assertFalse(new_var.persistable) self.assertTrue(new_var.persistable)
def test_layer(self): def test_layer(self):
with fluid.dygraph.guard(): with fluid.dygraph.guard():
......
...@@ -70,7 +70,6 @@ class SimpleNet(fluid.Layer): ...@@ -70,7 +70,6 @@ class SimpleNet(fluid.Layer):
loss = fluid.layers.reshape(loss, shape=[-1, self.num_steps]) loss = fluid.layers.reshape(loss, shape=[-1, self.num_steps])
loss = fluid.layers.reduce_mean(loss, dim=[0]) loss = fluid.layers.reduce_mean(loss, dim=[0])
loss = fluid.layers.reduce_sum(loss) loss = fluid.layers.reduce_sum(loss)
loss.permissions = True
return loss return loss
......
...@@ -459,8 +459,7 @@ class TestDygraphOCRAttention(unittest.TestCase): ...@@ -459,8 +459,7 @@ class TestDygraphOCRAttention(unittest.TestCase):
for batch_id in range(batch_num): for batch_id in range(batch_num):
label_in = to_variable(label_in_np) label_in = to_variable(label_in_np)
label_out = to_variable(label_out_np) label_out = to_variable(label_out_np)
label_out._stop_gradient = True label_out.stop_gradient = True
label_out.trainable = False
img = to_variable(image_np) img = to_variable(image_np)
dy_prediction = ocr_attention(img, label_in) dy_prediction = ocr_attention(img, label_in)
label_out = fluid.layers.reshape( label_out = fluid.layers.reshape(
...@@ -481,7 +480,7 @@ class TestDygraphOCRAttention(unittest.TestCase): ...@@ -481,7 +480,7 @@ class TestDygraphOCRAttention(unittest.TestCase):
dy_grad_value = {} dy_grad_value = {}
for param in ocr_attention.parameters(): for param in ocr_attention.parameters():
if param.trainable: if param.trainable:
np_array = np.array(param._ivar._grad_ivar().value() np_array = np.array(param._grad_ivar().value()
.get_tensor()) .get_tensor())
dy_grad_value[param.name + core.grad_var_suffix( dy_grad_value[param.name + core.grad_var_suffix(
)] = np_array )] = np_array
...@@ -514,7 +513,7 @@ class TestDygraphOCRAttention(unittest.TestCase): ...@@ -514,7 +513,7 @@ class TestDygraphOCRAttention(unittest.TestCase):
name='label_in', shape=[1], dtype='int64', lod_level=0) name='label_in', shape=[1], dtype='int64', lod_level=0)
static_label_out = fluid.layers.data( static_label_out = fluid.layers.data(
name='label_out', shape=[1], dtype='int64', lod_level=0) name='label_out', shape=[1], dtype='int64', lod_level=0)
static_label_out._stop_gradient = True static_label_out.stop_gradient = True
static_label_out.trainable = False static_label_out.trainable = False
static_prediction = ocr_attention(images, static_label_in) static_prediction = ocr_attention(images, static_label_in)
......
...@@ -83,7 +83,7 @@ class TestImperativeOptimizerBase(unittest.TestCase): ...@@ -83,7 +83,7 @@ class TestImperativeOptimizerBase(unittest.TestCase):
img = data[0] img = data[0]
label = data[1] label = data[1]
label._stop_gradient = True label.stop_gradient = True
cost = mlp(img) cost = mlp(img)
avg_loss = fluid.layers.reduce_mean(cost) avg_loss = fluid.layers.reduce_mean(cost)
......
...@@ -33,10 +33,10 @@ class TestImperativePartitialBackward(unittest.TestCase): ...@@ -33,10 +33,10 @@ class TestImperativePartitialBackward(unittest.TestCase):
loss.backward() loss.backward()
for param in fc1.parameters(): for param in fc1.parameters():
self.assertIsNotNone(param._ivar._grad_ivar()) self.assertIsNotNone(param._grad_ivar())
for param in fc2.parameters(): for param in fc2.parameters():
self.assertIsNone(param._ivar._grad_ivar()) self.assertIsNone(param._grad_ivar())
optimizer = fluid.optimizer.AdamOptimizer() optimizer = fluid.optimizer.AdamOptimizer()
_, params_grads = optimizer.minimize(loss) _, params_grads = optimizer.minimize(loss)
......
...@@ -207,7 +207,6 @@ class PtbModel(fluid.Layer): ...@@ -207,7 +207,6 @@ class PtbModel(fluid.Layer):
loss = fluid.layers.reshape(loss, shape=[-1, self.num_steps]) loss = fluid.layers.reshape(loss, shape=[-1, self.num_steps])
loss = fluid.layers.reduce_mean(loss, dim=[0]) loss = fluid.layers.reduce_mean(loss, dim=[0])
loss = fluid.layers.reduce_sum(loss) loss = fluid.layers.reduce_sum(loss)
loss.permissions = True
return loss, last_hidden, last_cell return loss, last_hidden, last_cell
......
...@@ -302,7 +302,7 @@ class TestDygraphResnet(unittest.TestCase): ...@@ -302,7 +302,7 @@ class TestDygraphResnet(unittest.TestCase):
dy_grad_value = {} dy_grad_value = {}
for param in resnet.parameters(): for param in resnet.parameters():
if param.trainable: if param.trainable:
np_array = np.array(param._ivar._grad_ivar().value() np_array = np.array(param._grad_ivar().value()
.get_tensor()) .get_tensor())
dy_grad_value[param.name + core.grad_var_suffix( dy_grad_value[param.name + core.grad_var_suffix(
)] = np_array )] = np_array
......
...@@ -119,7 +119,7 @@ class TestDygraphResnetSortGradient(unittest.TestCase): ...@@ -119,7 +119,7 @@ class TestDygraphResnetSortGradient(unittest.TestCase):
dy_grad_value = {} dy_grad_value = {}
for param in resnet.parameters(): for param in resnet.parameters():
if param.trainable: if param.trainable:
np_array = np.array(param._ivar._grad_ivar().value() np_array = np.array(param._grad_ivar().value()
.get_tensor()) .get_tensor())
dy_grad_value[param.name + core.grad_var_suffix( dy_grad_value[param.name + core.grad_var_suffix(
)] = np_array )] = np_array
......
...@@ -197,7 +197,6 @@ class PtbModel(fluid.Layer): ...@@ -197,7 +197,6 @@ class PtbModel(fluid.Layer):
loss = fluid.layers.reshape(loss, shape=[-1, self.num_steps]) loss = fluid.layers.reshape(loss, shape=[-1, self.num_steps])
loss = fluid.layers.reduce_mean(loss, dim=[0]) loss = fluid.layers.reduce_mean(loss, dim=[0])
loss = fluid.layers.reduce_sum(loss) loss = fluid.layers.reduce_sum(loss)
loss.permissions = True
return loss, last_hidden, last_cell return loss, last_hidden, last_cell
...@@ -353,7 +352,7 @@ class TestDygraphPtbRnn(unittest.TestCase): ...@@ -353,7 +352,7 @@ class TestDygraphPtbRnn(unittest.TestCase):
# set to zero # set to zero
for k, v in opti_dict.items(): for k, v in opti_dict.items():
np_t = v.numpy() np_t = v.numpy()
var = v._ivar.value().get_tensor() var = v.value().get_tensor()
var.set(np.zeros_like(np_t), place) var.set(np.zeros_like(np_t), place)
self.assertTrue(np.sum(np.abs(v.numpy())) == 0) self.assertTrue(np.sum(np.abs(v.numpy())) == 0)
...@@ -373,7 +372,7 @@ class TestDygraphPtbRnn(unittest.TestCase): ...@@ -373,7 +372,7 @@ class TestDygraphPtbRnn(unittest.TestCase):
state_dict = ptb_model.state_dict() state_dict = ptb_model.state_dict()
for k, v in state_dict.items(): for k, v in state_dict.items():
np_t = v.numpy() np_t = v.numpy()
var = v._ivar.value().get_tensor() var = v.value().get_tensor()
var.set(np.zeros_like(np_t), place) var.set(np.zeros_like(np_t), place)
...@@ -457,7 +456,7 @@ class TestDygraphPtbRnn(unittest.TestCase): ...@@ -457,7 +456,7 @@ class TestDygraphPtbRnn(unittest.TestCase):
# set to zero # set to zero
for k, v in opti_dict.items(): for k, v in opti_dict.items():
np_t = v.numpy() np_t = v.numpy()
var = v._ivar.value().get_tensor() var = v.value().get_tensor()
var.set(np.zeros_like(np_t), place) var.set(np.zeros_like(np_t), place)
self.assertTrue(np.sum(np.abs(v.numpy())) == 0) self.assertTrue(np.sum(np.abs(v.numpy())) == 0)
...@@ -476,7 +475,7 @@ class TestDygraphPtbRnn(unittest.TestCase): ...@@ -476,7 +475,7 @@ class TestDygraphPtbRnn(unittest.TestCase):
state_dict = ptb_model.state_dict() state_dict = ptb_model.state_dict()
for k, v in state_dict.items(): for k, v in state_dict.items():
np_t = v.numpy() np_t = v.numpy()
var = v._ivar.value().get_tensor() var = v.value().get_tensor()
var.set(np.zeros_like(np_t), place) var.set(np.zeros_like(np_t), place)
...@@ -562,7 +561,7 @@ class TestDygraphPtbRnn(unittest.TestCase): ...@@ -562,7 +561,7 @@ class TestDygraphPtbRnn(unittest.TestCase):
for k, v in opti_dict.items(): for k, v in opti_dict.items():
np_t = v.numpy() np_t = v.numpy()
np_opti_dict[v.name] = np_t np_opti_dict[v.name] = np_t
var = v._ivar.value().get_tensor() var = v.value().get_tensor()
var.set(np.zeros_like(np_t), place) var.set(np.zeros_like(np_t), place)
self.assertTrue(np.sum(np.abs(v.numpy())) == 0) self.assertTrue(np.sum(np.abs(v.numpy())) == 0)
...@@ -583,7 +582,7 @@ class TestDygraphPtbRnn(unittest.TestCase): ...@@ -583,7 +582,7 @@ class TestDygraphPtbRnn(unittest.TestCase):
for k, v in state_dict.items(): for k, v in state_dict.items():
np_t = v.numpy() np_t = v.numpy()
np_state_dict[v.name] = np_t np_state_dict[v.name] = np_t
var = v._ivar.value().get_tensor() var = v.value().get_tensor()
var.set(np.zeros_like(np_t), place) var.set(np.zeros_like(np_t), place)
......
...@@ -361,7 +361,7 @@ class TestImperativeResneXt(unittest.TestCase): ...@@ -361,7 +361,7 @@ class TestImperativeResneXt(unittest.TestCase):
#dy_grad_value = {} #dy_grad_value = {}
#for param in se_resnext.parameters(): #for param in se_resnext.parameters():
# if param.trainable: # if param.trainable:
# np_array = np.array(param._ivar._grad_ivar().value() # np_array = np.array(param._grad_ivar().value()
# .get_tensor()) # .get_tensor())
# dy_grad_value[param.name + core.grad_var_suffix()] = np_array # dy_grad_value[param.name + core.grad_var_suffix()] = np_array
......
...@@ -78,7 +78,6 @@ class SimpleNet(fluid.Layer): ...@@ -78,7 +78,6 @@ class SimpleNet(fluid.Layer):
loss = fluid.layers.reshape(loss, shape=[-1, self.num_steps]) loss = fluid.layers.reshape(loss, shape=[-1, self.num_steps])
loss = fluid.layers.reduce_mean(loss, dim=[0]) loss = fluid.layers.reduce_mean(loss, dim=[0])
loss = fluid.layers.reduce_sum(loss) loss = fluid.layers.reduce_sum(loss)
loss.permissions = True
return loss return loss
......
...@@ -25,8 +25,8 @@ import numpy as np ...@@ -25,8 +25,8 @@ import numpy as np
main_program = default_main_program() main_program = default_main_program()
class TestParameter(unittest.TestCase): class ParameterChecks(unittest.TestCase):
def test_param(self): def check_param(self):
shape = [784, 100] shape = [784, 100]
val = 1.0625 val = 1.0625
b = main_program.global_block() b = main_program.global_block()
...@@ -46,7 +46,7 @@ class TestParameter(unittest.TestCase): ...@@ -46,7 +46,7 @@ class TestParameter(unittest.TestCase):
p = io.get_parameter_value_by_name('fc.w', exe, main_program) p = io.get_parameter_value_by_name('fc.w', exe, main_program)
self.assertTrue(np.allclose(np.array(p), np.ones(shape) * val)) self.assertTrue(np.allclose(np.array(p), np.ones(shape) * val))
def test_exceptions(self): def check_exceptions(self):
b = main_program.global_block() b = main_program.global_block()
with self.assertRaises(ValueError): with self.assertRaises(ValueError):
b.create_parameter( b.create_parameter(
...@@ -62,5 +62,13 @@ class TestParameter(unittest.TestCase): ...@@ -62,5 +62,13 @@ class TestParameter(unittest.TestCase):
name='test', shape=[-1], dtype='float32', initializer=None) name='test', shape=[-1], dtype='float32', initializer=None)
class TestParameter(ParameterChecks):
def test_param(self):
self.check_param()
def test_exceptions(self):
self.check_exceptions()
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
...@@ -208,7 +208,6 @@ class PtbModel(fluid.Layer): ...@@ -208,7 +208,6 @@ class PtbModel(fluid.Layer):
loss = fluid.layers.reshape(loss, shape=[-1, self.num_steps]) loss = fluid.layers.reshape(loss, shape=[-1, self.num_steps])
loss = fluid.layers.reduce_mean(loss, dim=[0]) loss = fluid.layers.reduce_mean(loss, dim=[0])
loss = fluid.layers.reduce_sum(loss) loss = fluid.layers.reduce_sum(loss)
loss.permissions = True
return loss, last_hidden, last_cell return loss, last_hidden, last_cell
......
...@@ -184,6 +184,28 @@ class TestVariable(unittest.TestCase): ...@@ -184,6 +184,28 @@ class TestVariable(unittest.TestCase):
with fluid.program_guard(default_main_program()): with fluid.program_guard(default_main_program()):
self._tostring() self._tostring()
# NOTE(zhiqiu): for coverage CI
# TODO(zhiqiu): code clean for dygraph
def test_dygraph_deprecated_api(self):
b = default_main_program().current_block()
var = b.create_var(dtype="float64", lod_level=0)
with fluid.dygraph.guard():
self.assertIsNone(var.detach())
self.assertIsNone(var.numpy())
self.assertIsNone(var.set_value(None))
self.assertIsNone(var.backward())
self.assertIsNone(var.gradient())
self.assertIsNone(var.clear_gradient())
self.assertIsNone(var.to_string(True))
self.assertIsNone(var.persistable)
var.stop_gradient = True
self.assertIsNone(var.stop_gradient)
var.stop_gradient = 'tmp'
self.assertIsNone(var.name)
self.assertIsNone(var.shape)
self.assertIsNone(var.dtype)
self.assertIsNone(var.type)
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册