Unverified commit cdd46d7e, authored by Leo Chen, committed by GitHub

Split VarBase from Python Variable for Dygraph (#21359)

* test=develop, fix docker with paddle nccl problem

* don't expose numerous Tensor.set(), test=develop

* fix condition, test=develop

* fix float16 bug, test=develop

* feed should be Tensor or np.array, not Variable or number, test=develop

* use forcecast to copy numpy slice to new array, test=develop

* remove float16-uint16 hacking, test=develop

* add variable methods to VarBase and refactor to_variable to support returning VarBase

* support kwargs in varbase constructor

* add VarBase constructor to support default python args

* refine VarBase initialization method

* reset branch

* fix unit tests after changing VarBase error info to PaddleEnforce

* cherry-pick is_parameter change from before

* overload isinstance to avoid changing too many is_variable checks

* rm useless files

* rm useless code merged by git

* test=develop, fix some failed unit tests

* test=develop, fix test_graph_wrapper

* add some tests, test=develop

* refine __getitem__, test=develop

* add tests, test=develop

* fix err_msg, test=develop
Parent cdba41af
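The net effect of this change for dygraph users: `to_variable` now builds a `core.VarBase` directly from a numpy array, and `isinstance` is overloaded so the result is still reported as a `fluid.Variable`. A minimal sketch, based on the tests added in this commit (array contents are illustrative):

import numpy as np
import paddle.fluid as fluid

x = np.ones([2, 2], np.float32)
with fluid.dygraph.guard():
    var = fluid.dygraph.to_variable(x)          # now returns a core.VarBase directly
    assert isinstance(var, fluid.core.VarBase)
    assert isinstance(var, fluid.Variable)      # still True thanks to the isinstance overload
    print(var.numpy())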
......@@ -203,6 +203,7 @@ elseif(${CBLAS_PROVIDER} STREQUAL EXTERN_OPENBLAS)
list(APPEND third_party_deps extern_openblas)
endif()
if(WITH_MKLDNN)
include(external/mkldnn) # download, build, install mkldnn
list(APPEND third_party_deps extern_mkldnn)
......
......@@ -236,11 +236,13 @@ std::shared_ptr<VarBase> VarBase::NewVarBase(const platform::Place& dst_place,
// TODO(Jiabin): change this after move unique_name generator to CXX
auto new_var = std::make_shared<VarBase>(
false, "Itmp" + std::to_string(copied_counter_++));
true, Name() + std::to_string(copied_counter_++));
auto* dst_tensor = new_var->var_.GetMutable<framework::LoDTensor>();
dst_tensor->set_lod(src_tensor.lod());
new_var->SetPersistable(Persistable());
new_var->SetDataType(DataType());
new_var->SetType(Type());
framework::TensorCopy(src_tensor, dst_place, dst_tensor);
if (blocking) {
platform::DeviceContextPool::Instance().Get(dst_place)->Wait();
......@@ -253,7 +255,6 @@ std::shared_ptr<VarBase> VarBase::NewVarBase(const platform::Place& dst_place,
if (platform::is_gpu_place(dst_place)) {
VLOG(3) << "copy tensor " << Name() << " from gpu";
}
return new_var;
} else {
auto& src_selected_rows = var_.Get<framework::SelectedRows>();
......
......@@ -158,7 +158,7 @@ TEST(test_layer, test_varbase_basic) {
vin->MutableVar()->GetMutable<framework::LoDTensor>()->mutable_data<float>(
place);
std::shared_ptr<imperative::VarBase> vout(vin->NewVarBase(place, false));
ASSERT_EQ(vout->Name(), "Itmp0");
ASSERT_EQ(vout->Name(), "vin0");
std::shared_ptr<imperative::VarBase> vin_with_grad(
new imperative::VarBase(true, "vin"));
......
......@@ -30,7 +30,6 @@ limitations under the License. */
#include "paddle/fluid/imperative/profiler.h"
#include "paddle/fluid/imperative/tracer.h"
#include "paddle/fluid/imperative/type_defs.h"
#include "paddle/fluid/pybind/pybind_boost_headers.h"
namespace paddle {
......@@ -38,6 +37,12 @@ namespace pybind {
namespace py = ::pybind11;
template <typename P>
extern void SetTensorFromPyArray(framework::Tensor *self, const py::object &obj,
const P &place, bool zero_copy);
extern py::array TensorToPyArray(const framework::Tensor &tensor,
bool need_deep_copy = false);
class Layer : public imperative::Layer {
public:
using imperative::Layer::Layer; // Inherit constructors
......@@ -50,42 +55,99 @@ class Layer : public imperative::Layer {
}
};
// warper for pyobject to avoid imperative module depend on python
// TODO(jiabin) Add OpBase's pybind interface back to enable backward hook
class PYBIND11_HIDDEN PyCallableObject {
public:
PyCallableObject(std::shared_ptr<py::object> py_obj_ptr)
: py_obj_ptr_(std::move(py_obj_ptr)) {}
~PyCallableObject() {
py::call_guard<py::gil_scoped_acquire>();
py_obj_ptr_.reset();
static void InitTensorForVarBase(imperative::VarBase *self, bool persistable,
bool is_default, const py::array &array,
const py::object &obj = py::object(),
bool zero_copy = false) {
new (self) imperative::VarBase(
imperative::GetCurrentTracer()->GenerateUniqueName("generated_var_"));
self->SetPersistable(persistable);
auto *tensor = self->MutableVar()->GetMutable<framework::LoDTensor>();
if (is_default) {
auto place = imperative::GetCurrentTracer()->ExpectedPlace();
if (platform::is_cpu_place(place)) {
SetTensorFromPyArray<platform::CPUPlace>(
tensor, array, boost::get<platform::CPUPlace>(place), zero_copy);
} else if (platform::is_gpu_place(place)) {
SetTensorFromPyArray<platform::CUDAPlace>(
tensor, array, boost::get<platform::CUDAPlace>(place), zero_copy);
} else if (platform::is_cuda_pinned_place(place)) {
SetTensorFromPyArray<platform::CUDAPinnedPlace>(
tensor, array, boost::get<platform::CUDAPinnedPlace>(place),
zero_copy);
} else {
PADDLE_THROW(platform::errors::InvalidArgument(
"Place should be one of CPUPlace/CUDAPlace/CUDAPinnedPlace"));
}
} else {
if (py::isinstance<platform::CPUPlace>(obj)) {
SetTensorFromPyArray<platform::CPUPlace>(
tensor, array, obj.cast<platform::CPUPlace>(), zero_copy);
} else if (py::isinstance<platform::CUDAPlace>(obj)) {
SetTensorFromPyArray<platform::CUDAPlace>(
tensor, array, obj.cast<platform::CUDAPlace>(), zero_copy);
} else if (py::isinstance<platform::CUDAPinnedPlace>(obj)) {
SetTensorFromPyArray<platform::CUDAPinnedPlace>(
tensor, array, obj.cast<platform::CUDAPinnedPlace>(), zero_copy);
} else {
PADDLE_THROW(platform::errors::InvalidArgument(
"Place should be one of CPUPlace/CUDAPlace/CUDAPinnedPlace"));
}
}
void operator()() {
py::call_guard<py::gil_scoped_acquire>();
py_obj_ptr_->operator()(this);
self->SetType(framework::proto::VarType::LOD_TENSOR);
self->SetDataType(tensor->type());
}
static void InitVarBaseFromNumpyWithKwargs(imperative::VarBase *self,
const py::kwargs &kwargs) {
PADDLE_ENFORCE_EQ(
kwargs.contains("value"), true,
platform::errors::InvalidArgument("Missing argument: value"));
if (kwargs.contains("place")) {
InitTensorForVarBase(self, kwargs.contains("persistable")
? kwargs["persistable"].cast<bool>()
: false,
false, kwargs["value"].cast<py::array>(),
kwargs["place"], kwargs["zero_copy"].cast<bool>());
} else {
InitTensorForVarBase(self, kwargs.contains("persistable")
? kwargs["persistable"].cast<bool>()
: false,
true, kwargs["value"].cast<py::array>(), py::object(),
kwargs["zero_copy"].cast<bool>());
}
}
private:
std::shared_ptr<py::object> py_obj_ptr_;
};
template <typename P>
static void InitVarBaseFromNumpyWithArg(imperative::VarBase *self,
const py::array &array, const P &place,
bool persistable, bool zero_copy) {
// 0: value, 1: place, 2: name 3: persistable, 4: zero_copy
new (self) imperative::VarBase(
imperative::GetCurrentTracer()->GenerateUniqueName("generated_var_"));
self->SetPersistable(persistable);
auto *tensor = self->MutableVar()->GetMutable<framework::LoDTensor>();
SetTensorFromPyArray<P>(tensor, array, place, zero_copy);
self->SetType(framework::proto::VarType::LOD_TENSOR);
self->SetDataType(tensor->type());
}
static void InitVarBaseFromNumpyWithArgDefault(imperative::VarBase *self,
const py::array &array,
bool persistable) {
InitTensorForVarBase(self, persistable, true, array);
}
// Function like obj.attr_name in Python.
static PyObject *GetPythonAttribute(PyObject *obj, const char *attr_name) {
// NOTE(zjl): PyObject_GetAttrString would return nullptr when attr_name
// is not inside obj, but it would also set the error flag of Python.
// If the error flag is set in C++, C++ code would not raise Exception,
// but Python would raise Exception once C++ call ends.
// To avoid unexpected Exception raised in Python, we check whether
// attribute exists before calling PyObject_GetAttrString.
//
// Caution: PyObject_GetAttrString would increase reference count of PyObject.
// Developer should call Py_DECREF manually after the attribute is not used.
if (PyObject_HasAttrString(obj, attr_name)) {
return PyObject_GetAttrString(obj, attr_name);
static std::string GetTypeName(const imperative::VarBase &var) {
if (var.Type() == framework::proto::VarType::RAW) {
return "RAW";
} else if (!var.Var().IsInitialized()) {
return "nullptr";
} else {
return nullptr;
return framework::ToTypeName(var.Var().Type());
}
}
using PyNameVarBaseMap = std::unordered_map<std::string, py::handle>;
template <typename T>
static T PyObjectCast(PyObject *obj) {
......@@ -106,48 +168,36 @@ GetVarBaseListFromPyHandle(const py::handle &handle) {
return {};
}
const char *kIVarField = "_ivar";
PyObject *py_ivar = GetPythonAttribute(py_obj, kIVarField);
std::vector<std::shared_ptr<imperative::VarBase>> result;
if (py_ivar) { // Variable
result.emplace_back(
PyObjectCast<std::shared_ptr<imperative::VarBase>>(py_ivar));
Py_DECREF(py_ivar);
} else if (PyList_Check(py_obj)) { // List of Variable
if (PyList_Check(py_obj)) { // List of VarBase
size_t len = PyList_GET_SIZE(py_obj);
result.reserve(len);
for (size_t i = 0; i < len; ++i) {
PyObject *py_ivar =
PyObject_GetAttrString(PyList_GET_ITEM(py_obj, i), kIVarField);
PADDLE_ENFORCE_NOT_NULL(py_ivar);
PyObject *py_ivar = PyList_GET_ITEM(py_obj, i);
PADDLE_ENFORCE_NOT_NULL(
py_ivar, platform::errors::InvalidArgument("Python Object is NULL"));
result.emplace_back(
PyObjectCast<std::shared_ptr<imperative::VarBase>>(py_ivar));
Py_DECREF(py_ivar);
}
} else if (PyTuple_Check(py_obj)) { // Tuple of Variable
} else if (PyTuple_Check(py_obj)) { // Tuple of VarBase
size_t len = PyTuple_GET_SIZE(py_obj);
result.reserve(len);
for (size_t i = 0; i < len; ++i) {
PyObject *py_ivar =
PyObject_GetAttrString(PyTuple_GET_ITEM(py_obj, i), kIVarField);
PADDLE_ENFORCE_NOT_NULL(py_ivar);
PyObject *py_ivar = PyTuple_GET_ITEM(py_obj, i);
PADDLE_ENFORCE_NOT_NULL(
py_ivar, platform::errors::InvalidArgument("Python Object is NULL"));
result.emplace_back(
PyObjectCast<std::shared_ptr<imperative::VarBase>>(py_ivar));
Py_DECREF(py_ivar);
}
} else {
PADDLE_THROW(
"unsupported type %s, must be Variable, list[Variable] or "
"tuple[Variable]",
py::str(handle));
} else { // VarBase
result.emplace_back(
PyObjectCast<std::shared_ptr<imperative::VarBase>>(py_obj));
}
return result;
}
using PyNameVarBaseMap = std::unordered_map<std::string, py::handle>;
static imperative::NameVarBaseMap ConvertToNameVarBaseMap(
const PyNameVarBaseMap &map) {
imperative::NameVarBaseMap result;
......@@ -163,16 +213,6 @@ static imperative::NameVarBaseMap ConvertToNameVarBaseMap(
return result;
}
static std::string GetTypeName(const imperative::VarBase &var) {
if (var.Type() == framework::proto::VarType::RAW) {
return "RAW";
} else if (!var.Var().IsInitialized()) {
return "nullptr";
} else {
return framework::ToTypeName(var.Var().Type());
}
}
// Bind Methods
void BindImperative(py::module *m_ptr) {
auto &m = *m_ptr;
......@@ -239,11 +279,17 @@ void BindImperative(py::module *m_ptr) {
R"DOC()DOC")
.def_static("_alive_vars", &imperative::VarBase::AliveVarNames)
.def("__init__",
[](imperative::VarBase &self, const std::string &name,
framework::proto::VarType::Type type,
framework::proto::VarType::Type dtype,
const std::vector<int> &dims, bool persistable) {
new (&self) imperative::VarBase(name);
[](imperative::VarBase &self, framework::proto::VarType::Type dtype,
const std::vector<int> &dims, const py::handle &name,
framework::proto::VarType::Type type, bool persistable) {
std::string act_name = "";
if (!name.ptr() || name.ptr() == Py_None) {
act_name = imperative::GetCurrentTracer()->GenerateUniqueName(
"generated_var");
} else {
act_name = name.cast<std::string>();
}
new (&self) imperative::VarBase(act_name);
self.SetPersistable(persistable);
self.SetType(type);
self.SetDataType(dtype);
......@@ -253,6 +299,91 @@ void BindImperative(py::module *m_ptr) {
tensor->Resize(framework::make_ddim(dims));
}
})
.def("__init__", &InitVarBaseFromNumpyWithArg<platform::CPUPlace>,
py::arg("value"), py::arg("place"), py::arg("persistable") = false,
py::arg("zero_copy") = false)
.def("__init__", &InitVarBaseFromNumpyWithArg<platform::CUDAPlace>,
py::arg("value"), py::arg("place"), py::arg("persistable") = false,
py::arg("zero_copy") = false)
.def("__init__", &InitVarBaseFromNumpyWithArg<platform::CUDAPinnedPlace>,
py::arg("value"), py::arg("place"), py::arg("persistable") = false,
py::arg("zero_copy") = false)
.def("__init__", &InitVarBaseFromNumpyWithArgDefault, py::arg("value"),
py::arg("persistable") = false)
.def("__init__", &InitVarBaseFromNumpyWithKwargs)
.def("numpy",
[](imperative::VarBase &self) -> py::array {
const auto &tensor =
self.MutableVar()->Get<framework::LoDTensor>();
PADDLE_ENFORCE_EQ(
tensor.IsInitialized(), true,
platform::errors::InvalidArgument(
"%s is Empty, Please check if it has no data in",
self.Name()));
return TensorToPyArray(tensor, true);
},
R"DOC(
**Notes**:
**This API is ONLY available in Dygraph mode**
Returns a numpy array that shows the value of the current :ref:`api_guide_Variable_en`
Returns:
ndarray: The numpy value of the current Variable.
Return type:
ndarray: dtype is the same as the current Variable
Examples:
.. code-block:: python
import paddle.fluid as fluid
from paddle.fluid.dygraph.base import to_variable
from paddle.fluid.dygraph import FC
import numpy as np
data = np.random.uniform(-1, 1, [30, 10, 32]).astype('float32')
with fluid.dygraph.guard():
fc = FC("fc", 64, num_flatten_dims=2)
data = to_variable(data)
x = fc(data)
print(x.numpy())
)DOC")
.def("detach",
[](const imperative::VarBase &self) {
const auto &tensor = self.Var().Get<framework::LoDTensor>();
PADDLE_ENFORCE_EQ(tensor.IsInitialized(), true,
platform::errors::InvalidArgument(
"%s has not been initialized", self.Name()));
return self.NewVarBase(tensor.place(), false);
},
py::return_value_policy::copy, R"DOC(
**Notes**:
**This API is ONLY available in Dygraph mode**
Returns a new Variable, detached from the current graph.
Returns:
( :ref:`api_guide_Variable_en` | dtype is same as current Variable): The detached Variable.
Examples:
.. code-block:: python
import paddle.fluid as fluid
from paddle.fluid.dygraph.base import to_variable
from paddle.fluid.dygraph import FC
import numpy as np
data = np.random.uniform(-1, 1, [30, 10, 32]).astype('float32')
with fluid.dygraph.guard():
fc = FC("fc", 64, num_flatten_dims=2)
data = to_variable(data)
x = fc(data)
y = x.detach()
)DOC")
.def("_run_backward",
[](imperative::VarBase &self,
const imperative::detail::BackwardStrategy &bckst,
......@@ -273,7 +404,39 @@ void BindImperative(py::module *m_ptr) {
return self.MutableGradVar()->Get<framework::LoDTensor>();
},
py::return_value_policy::reference)
.def("_clear_gradient", &imperative::VarBase::ClearGradient)
.def("clear_gradient", &imperative::VarBase::ClearGradient, R"DOC(
**Notes**:
**1. This API is ONLY available in Dygraph mode**
**2. Use it only when the Variable has a gradient; normally we use this for Parameters, since other temporary Variables will be deleted by Python's GC**
Clear (set to ``0``) the gradient of the current Variable
Returns: None
Examples:
.. code-block:: python
import paddle.fluid as fluid
import numpy as np
x = np.ones([2, 2], np.float32)
with fluid.dygraph.guard():
inputs2 = []
for _ in range(10):
tmp = fluid.dygraph.base.to_variable(x)
tmp.stop_gradient=False
inputs2.append(tmp)
ret2 = fluid.layers.sums(inputs2)
loss2 = fluid.layers.reduce_sum(ret2)
backward_strategy = fluid.dygraph.BackwardStrategy()
backward_strategy.sort_sum_gradient = True
loss2.backward(backward_strategy)
print(loss2.gradient())
loss2.clear_gradient()
print("After clear {}".format(loss2.gradient()))
)DOC")
.def("_grad_ivar",
[](const imperative::VarBase &self) {
auto &grad_var = self.GradVarBase();
......
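The `__init__` overloads registered above give `core.VarBase` several construction paths from Python. A short sketch mirroring the `test_create_VarBase` test added later in this commit:

import numpy as np
import paddle.fluid as fluid

x = np.ones([2, 2], np.float32)
with fluid.dygraph.guard():
    a = fluid.core.VarBase(value=x, place=fluid.core.CPUPlace())  # value + explicit place
    b = fluid.core.VarBase(x, fluid.core.CPUPlace())              # positional form
    c = fluid.core.VarBase(x)                                     # place taken from the current tracer
    d = fluid.core.VarBase(value=x)                               # keyword-only value
    for v in (a, b, c, d):
        assert np.array_equal(v.numpy(), x)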
......@@ -194,14 +194,8 @@ static std::vector<std::shared_ptr<imperative::VarBase>> GetVarBaseList(
if (!py_obj || py_obj == Py_None) {
PADDLE_THROW("Save parameter [%s] is None", para.first);
}
const char *kIVarField = "_ivar";
PyObject *py_ivar = GetPythonAttribute(py_obj, kIVarField);
PADDLE_ENFORCE_NOT_NULL(py_ivar, "Can not find ivar in Variable");
vec_res.emplace_back(
PyObjectCast<std::shared_ptr<imperative::VarBase>>(py_ivar));
Py_DECREF(py_ivar);
PyObjectCast<std::shared_ptr<imperative::VarBase>>(py_obj));
}
return vec_res;
......
......@@ -486,7 +486,8 @@ inline framework::Tensor *PySliceTensor(const framework::Tensor &self,
}
}
inline py::array TensorToPyArray(const framework::Tensor &tensor) {
inline py::array TensorToPyArray(const framework::Tensor &tensor,
bool need_deep_copy = false) {
if (!tensor.IsInitialized()) {
return py::array();
}
......@@ -510,9 +511,26 @@ inline py::array TensorToPyArray(const framework::Tensor &tensor) {
std::string py_dtype_str = details::TensorDTypeToPyDTypeStr(tensor.type());
if (!is_gpu_tensor) {
return py::array(py::buffer_info(
const_cast<void *>(tensor_buf_ptr), sizeof_dtype, py_dtype_str,
static_cast<size_t>(tensor.dims().size()), py_dims, py_strides));
if (!need_deep_copy) {
return py::array(py::buffer_info(
const_cast<void *>(tensor_buf_ptr), sizeof_dtype, py_dtype_str,
static_cast<size_t>(tensor.dims().size()), py_dims, py_strides));
} else {
py::array py_arr(py::dtype(py_dtype_str.c_str()), py_dims, py_strides);
PADDLE_ENFORCE_EQ(py_arr.writeable(), true,
platform::errors::InvalidArgument(
"PyArray must be writable, otherwise memory leak "
"or double free would occur"));
PADDLE_ENFORCE_EQ(py_arr.owndata(), true,
platform::errors::InvalidArgument(
"PyArray must own data, otherwise memory leak "
"or double free would occur"));
platform::CPUPlace place;
size_t copy_bytes = sizeof_dtype * numel;
paddle::memory::Copy(place, py_arr.mutable_data(), place, tensor_buf_ptr,
copy_bytes);
return py_arr;
}
}
#ifdef PADDLE_WITH_CUDA
......
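Because the `numpy` binding above calls `TensorToPyArray` with `need_deep_copy = true`, the returned array owns its own buffer. A small sketch (my own illustration, not part of the commit) of the observable consequence:

import numpy as np
import paddle.fluid as fluid

with fluid.dygraph.guard():
    var = fluid.dygraph.to_variable(np.ones([2, 2], np.float32))
    arr = var.numpy()            # deep copy into a py::array that owns its data
    arr[0, 0] = 100.0            # mutating the copy...
    print(var.numpy()[0, 0])     # ...leaves the underlying LoDTensor unchanged (prints 1.0)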
......@@ -88,13 +88,13 @@ from .dygraph.nn import *
from .dygraph.layers import *
from .io import save, load, load_program_state, set_program_state
from .dygraph.checkpoint import save_dygraph, load_dygraph
from .dygraph.varbase_patch_methods import monkey_patch_varbase
Tensor = LoDTensor
__all__ = framework.__all__ + executor.__all__ + \
trainer_desc.__all__ + inferencer.__all__ + transpiler.__all__ + \
parallel_executor.__all__ + lod_tensor.__all__ + \
data_feed_desc.__all__ + compiler.__all__ + backward.__all__ + [
data_feed_desc.__all__ + compiler.__all__ + backward.__all__ + [
'io',
'initializer',
'embedding',
......@@ -126,6 +126,7 @@ __all__ = framework.__all__ + executor.__all__ + \
'install_check',
'save',
'load',
'VarBase'
]
......@@ -234,3 +235,4 @@ def __bootstrap__():
# Consider paddle.init(args) or paddle.main(args)
monkey_patch_variable()
__bootstrap__()
monkey_patch_varbase()
......@@ -138,6 +138,7 @@ def guard(place=None):
train = framework.Program()
startup = framework.Program()
tracer = Tracer()
VarBase = core.VarBase
if place is None:
if core.is_compiled_with_cuda():
......@@ -205,28 +206,21 @@ def to_variable(value, block=None, name=None, zero_copy=None):
if isinstance(value, np.ndarray):
assert framework.in_dygraph_mode(
), "to_variable could only be called in dygraph mode"
if not block:
block = framework.default_main_program().current_block()
py_var = framework.Variable(
block,
type=core.VarDesc.VarType.LOD_TENSOR,
name=name,
shape=value.shape,
dtype=value.dtype,
stop_gradient=True)
var = py_var._ivar.value()
tensor = var.get_tensor()
if isinstance(framework._current_expected_place(),
framework.core.CPUPlace):
if zero_copy is None:
zero_copy = True
tensor.set(value, framework._current_expected_place(), zero_copy)
else:
assert not zero_copy, "zero_copy mode can only be used with CPUPlace"
tensor.set(value, framework._current_expected_place(), False)
zero_copy = False
py_var = core.VarBase(
value=value,
name=name,
persistable=False,
place=framework._current_expected_place(),
zero_copy=zero_copy)
return py_var
elif isinstance(value, framework.Variable):
elif isinstance(value, (core.VarBase, framework.Variable)):
return value
else:
raise TypeError(
......
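With the rewritten `to_variable` above, the numpy array is handed straight to the `core.VarBase` constructor; `zero_copy` defaults to True on CPUPlace and is forced to False on other places. A hedged sketch of the resulting call pattern:

import numpy as np
import paddle.fluid as fluid

data = np.random.uniform(-1, 1, [3, 4]).astype('float32')
with fluid.dygraph.guard(fluid.CPUPlace()):
    v = fluid.dygraph.to_variable(data)                    # zero_copy defaults to True on CPUPlace
    w = fluid.dygraph.to_variable(data, zero_copy=False)   # force a real copy
    assert isinstance(v, fluid.core.VarBase)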
......@@ -33,7 +33,7 @@ def create_program_from_desc(program_desc):
def _extract_vars(inputs, result_list):
if isinstance(inputs, Variable):
result_list.append(inputs._ivar)
result_list.append(inputs)
if isinstance(inputs, (list, tuple)):
for var in inputs:
......@@ -67,7 +67,7 @@ def _trace(layer,
outputs = [original_outputs]
else:
outputs = original_outputs
out_vars = [var._ivar for var in outputs]
out_vars = [var for var in outputs]
program_desc, feed_names, fetch_names = tracer.create_program_desc(
var_list, feed_prefix, out_vars, fetch_prefix, tmp_prefix)
......@@ -104,7 +104,7 @@ class TracedLayer(object):
self._scope = core.Scope()
for p in parameters:
src_tensor = p._ivar.value().get_tensor()
src_tensor = p.value().get_tensor()
dst_tensor = self._scope.var(p.name).get_tensor()
dst_tensor._share_data_with(src_tensor)
......@@ -234,7 +234,7 @@ class TracedLayer(object):
feed_dict = {}
if in_dygraph_mode():
for x, name in zip(inputs, self._feed_names):
feed_dict[name] = x._ivar.value().get_tensor()
feed_dict[name] = x.value().get_tensor()
else:
for x, name in zip(inputs, self._feed_names):
feed_dict[name] = x
......
......@@ -25,7 +25,6 @@ from .layer_object_helper import LayerObjectHelper
from .base import program_desc_tracing_guard
from paddle.fluid import framework
from ..param_attr import ParamAttr
from paddle.fluid.framework import Variable
__all__ = ['Layer']
......
......@@ -219,12 +219,8 @@ class DataParallel(layers.Layer):
grad_vars = []
for param in self._layers.parameters():
# NOTE(zcd): The grad_ivar maybe no generated.
if param.trainable and param._ivar._grad_ivar():
g_var = framework.Variable(
block=self._helper.main_program.current_block(),
name=param._ivar._grad_name(),
stop_gradient=True,
ivar=param._ivar._grad_ivar())
if param.trainable and param._grad_ivar():
g_var = param._grad_ivar()
grad_vars.append(g_var)
assert g_var not in grad_var_set
grad_var_set.add(g_var)
......
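With `_grad_ivar()` now returning the gradient `VarBase` directly (instead of wrapping it into a `framework.Variable`), gradient collection in `DataParallel` and user code becomes a plain attribute walk. A minimal sketch, with the layer and shapes purely illustrative:

import numpy as np
import paddle.fluid as fluid

with fluid.dygraph.guard():
    fc = fluid.dygraph.FC("fc", 4)
    out = fc(fluid.dygraph.to_variable(np.ones([2, 8], np.float32)))
    loss = fluid.layers.reduce_sum(out)
    loss.backward()
    grads = [p._grad_ivar() for p in fc.parameters()
             if p.trainable and p._grad_ivar() is not None]
    # each entry is a core.VarBase holding the gradient tensor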
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .. import framework
from .. import core
from . import BackwardStrategy
from ..framework import Variable, _getitem_impl_
from .. import unique_name
import numpy as np
def monkey_patch_varbase():
# TODO(jiabin): move this to the C++ side if we find a performance issue with it
@framework.dygraph_only
def set_value(self, value):
"""
**Notes**:
**This API is ONLY available in Dygraph mode**
Set a new value for this Variable.
Args:
value (Variable|np.ndarray): the new value.
Examples:
.. code-block:: python
import paddle.fluid as fluid
from paddle.fluid.dygraph.base import to_variable
from paddle.fluid.dygraph import FC
import numpy as np
data = np.ones([3, 32, 32], dtype='float32')
with fluid.dygraph.guard():
fc = fluid.dygraph.FC("fc", 4)
t = to_variable(data)
fc(t) # call with default weight
custom_weight = np.random.randn(1024, 4).astype("float32")
fc.weight.set_value(custom_weight) # change existing weight
out = fc(t) # call with different weight
"""
assert isinstance(value, (np.ndarray, core.VarBase)), \
"Variable set_value function, arguments type only support Variable, numpy, VarBase"
value_np = value
if isinstance(value, core.VarBase):
value_np = value.numpy()
self_tensor_np = self.numpy()
assert self_tensor_np.shape == value_np.shape, \
"Variable Shape not match, Variable [ {} ] need tensor with shape {} but load set tensor with shape {}".format(
self.name, self_tensor_np.shape, value_np.shape)
assert self_tensor_np.dtype == value_np.dtype, \
"Variable dtype not match, Variable [ {} ] need tensor with dtype {} but load tensor with dtype {}".format(
self.name, self_tensor_np.dtype, value_np.dtype)
self.value().get_tensor().set(value_np,
framework._current_expected_place())
@framework.dygraph_only
def backward(self, backward_strategy=None):
"""
**Notes**:
**This API is ONLY available in Dygraph mode**
Run backward on the current graph, starting from the current Variable
Args:
backward_strategy( :ref:`api_fluid_dygraph_BackwardStrategy` ): The Backward Strategy to run backward
Returns:
NoneType: None
Examples:
.. code-block:: python
import paddle.fluid as fluid
import numpy as np
x = np.ones([2, 2], np.float32)
with fluid.dygraph.guard():
inputs2 = []
for _ in range(10):
tmp = fluid.dygraph.base.to_variable(x)
# if we don't set tmp's stop_gradient to False, then every path to loss will have no gradient, since
# nothing on the path requires a gradient.
tmp.stop_gradient=False
inputs2.append(tmp)
ret2 = fluid.layers.sums(inputs2)
loss2 = fluid.layers.reduce_sum(ret2)
backward_strategy = fluid.dygraph.BackwardStrategy()
backward_strategy.sort_sum_gradient = True
loss2.backward(backward_strategy)
"""
if framework.in_dygraph_mode():
if backward_strategy is None:
backward_strategy = BackwardStrategy()
backward_strategy.sort_sum_gradient = False
self._run_backward(backward_strategy, framework._dygraph_tracer())
else:
raise ValueError(
"Variable.backward() is only avaliable in DyGraph mode")
@framework.dygraph_only
def gradient(self):
"""
**Notes**:
**This API is ONLY available in Dygraph mode**
Get the gradient of the current Variable
Returns:
ndarray: Numpy value of the gradient of the current Variable
Examples:
.. code-block:: python
import paddle.fluid as fluid
import numpy as np
x = np.ones([2, 2], np.float32)
with fluid.dygraph.guard():
inputs2 = []
for _ in range(10):
tmp = fluid.dygraph.base.to_variable(x)
tmp.stop_gradient=False
inputs2.append(tmp)
ret2 = fluid.layers.sums(inputs2)
loss2 = fluid.layers.reduce_sum(ret2)
backward_strategy = fluid.dygraph.BackwardStrategy()
backward_strategy.sort_sum_gradient = True
loss2.backward(backward_strategy)
print(loss2.gradient())
"""
if self._grad_ivar() is None:
raise ValueError(
"%s has no grad, Please set Variable.stop_gradient=False, or "
"check if this is the first and only variable need grad, if so, please set its pre-Variable's "
"stop_gradient=False, to make sure it has gradient " %
self.name)
new_ivar = self._grad_ivar()._copy_to(core.CPUPlace(), True)
if self._grad_ivar().type == core.VarDesc.VarType.SELECTED_ROWS:
return (np.array(new_ivar.value().get_selected_rows().get_tensor()),
np.array(new_ivar.value().get_selected_rows().rows()))
else:
return np.array(new_ivar.value().get_tensor())
def __str__(self):
return self.to_string(True)
@property
def block(self):
return framework.default_main_program().global_block()
def to_string(self, throw_on_error, with_details=False):
"""
Get debug string.
Args:
throw_on_error (bool): True if raise an exception when self is not initialized.
with_details (bool): more details about variables and parameters (e.g. trainable, optimize_attr, ...) will be printed when with_details is True. Default value is False;
Returns:
str: The debug string.
Examples:
.. code-block:: python
import paddle.fluid as fluid
cur_program = fluid.Program()
cur_block = cur_program.current_block()
new_variable = cur_block.create_var(name="X",
shape=[-1, 23, 48],
dtype='float32')
print(new_variable.to_string(True))
print("=============with detail===============")
print(new_variable.to_string(True, True))
"""
if framework.in_dygraph_mode():
# TODO(panyx0718): add more dygraph debug info.
tensor = self.value().get_tensor()
if tensor._is_initialized():
return 'name %s, dtype: %s shape: %s %s' % (
self.name, self.dtype, self.shape, str(tensor))
else:
return 'name %s, shape: %s, not inited' % (self.name,
self.shape)
def __getitem__(self, item):
return _getitem_impl_(self, item)
for method_name, method in (("set_value", set_value), ("block", block),
("backward", backward), ("gradient", gradient),
("__str__", __str__), ("to_string", to_string),
("__getitem__", __getitem__)):
setattr(core.VarBase, method_name, method)
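`monkey_patch_varbase` attaches these Python-side helpers to `core.VarBase` when `paddle.fluid` is imported (see the `__bootstrap__` change above). A short sketch exercising a few of them; the slicing line assumes `_getitem_impl_` handles a dygraph VarBase the way the tests added in this commit exercise it:

import numpy as np
import paddle.fluid as fluid

with fluid.dygraph.guard():
    x = fluid.dygraph.to_variable(np.arange(6).reshape(2, 3).astype('float32'))
    print(str(x))                                  # __str__ -> to_string(True)
    x.set_value(np.zeros([2, 3], np.float32))      # replace the value in place
    y = x[0:1]                                     # __getitem__ delegates to _getitem_impl_
    print(y.numpy())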
......@@ -264,7 +264,7 @@ class GradClipByGlobalNorm(GradClipBase):
if g is None:
continue
merge_grad = g
if g._ivar.type == core.VarDesc.VarType.SELECTED_ROWS:
if g.type == core.VarDesc.VarType.SELECTED_ROWS:
merge_grad = layers.merge_selected_rows(g)
merge_grad = layers.get_tensor_from_selected_rows(merge_grad)
power = layers.square(merge_grad)
......
This diff is collapsed.
......@@ -44,33 +44,46 @@ class LayerHelperBase(object):
def startup_program(self):
return default_startup_program()
def to_variable(self, value, block=None):
"""convert value to variable
def to_variable(self, value, name=None):
"""
The API will create a ``Variable`` object from numpy\.ndarray or Variable object.
Parameters:
value(ndarray): The numpy\.ndarray object that needs to be converted, it can be multi-dimension, and the data type is one of numpy\.{float16, float32, float64, int16, int32, int64, uint8, uint16}.
block(fluid.Block, optional): Which block this variable will be in. Default: None.
name(str, optional): The default value is None. Normally there is no need for user to set this property. For more information, please refer to :ref:`api_guide_Name`
Returns:
Variable: ``Tensor`` created from the specified numpy\.ndarray object, data type and shape is the same as ``value`` .
Examples:
Args:
value: value to be convert
block: the block of the variable
.. code-block:: python
import numpy as np
import paddle.fluid as fluid
with fluid.dygraph.guard():
x = np.ones([2, 2], np.float32)
y = fluid.dygraph.to_variable(x)
Return Variable construct from value
"""
if isinstance(value, np.ndarray):
assert in_dygraph_mode(
), "to_variable could only be called in dygraph mode"
if not block:
block = default_main_program().current_block()
py_var = Variable(
block,
type=core.VarDesc.VarType.LOD_TENSOR,
name=None,
shape=value.shape,
dtype=value.dtype)
var = py_var._ivar.value()
tensor = var.get_tensor()
tensor.set(value, _current_expected_place())
py_var = core.VarBase(
value=value,
name=name,
persistable=False,
place=_current_expected_place(),
zero_copy=False)
return py_var
elif isinstance(value, Variable):
elif isinstance(value, (core.VarBase, Variable)):
return value
else:
raise TypeError(
"to_variable only accepts 'ndarray' or 'Variable' or 'VarBase' as value's input"
)
def _create_weight_normalize(self, attr, shape, dtype):
from .layers import elementwise_mul, elementwise_div, reshape
......@@ -386,7 +399,7 @@ class LayerHelperBase(object):
"""
assert isinstance(var, Variable)
if in_dygraph_mode():
initializer(var, var.block)
initializer(var, self.main_program.global_block())
else:
self.startup_program.global_block().create_var(
name=var.name,
......
......@@ -233,6 +233,7 @@ def cross_entropy(input, label, soft_label=False, ignore_index=kIgnoreIndex):
predict = fluid.layers.fc(input=x, size=class_num, act='softmax')
cost = fluid.layers.cross_entropy(input=predict, label=label)
"""
check_type_and_dtype(input, 'input', Variable,
['float16', 'float32', 'float64'], 'cross_entropy')
if not soft_label:
......@@ -729,7 +730,6 @@ def nce(input,
sampler = 1
elif sampler == "custom_dist":
assert custom_dist is not None
# assert isinstance(custom_dist, Variable)
custom_dist_len = num_total_classes
alias_probs_ = [0] * custom_dist_len
......
......@@ -15,7 +15,7 @@
from __future__ import print_function
from .. import core
from ..framework import Variable, unique_name
from ..framework import Variable, unique_name, in_dygraph_mode, default_main_program
from .layer_function_generator import OpProtoHolder
from ..initializer import force_init_on_cpu
......@@ -40,7 +40,10 @@ def monkey_patch_variable():
return dtype
def current_block(var):
return var.block
if in_dygraph_mode():
return default_main_program().global_block()
else:
return var.block
def create_new_tmp_var(block, dtype):
tmp_name = unique_tmp_name()
......@@ -281,5 +284,9 @@ def monkey_patch_variable():
setattr(Variable, method_name,
_elemwise_method_creator_(method_name, op_type, reverse,
scalar_method))
setattr(core.VarBase, method_name,
_elemwise_method_creator_(method_name, op_type, reverse,
scalar_method))
Variable.astype = astype
setattr(core.VarBase, "astype", astype)
......@@ -32,7 +32,6 @@ from .layers import ops
from .regularizer import append_regularization_ops
from .dygraph import base as imperative_base
from .dygraph.learning_rate_scheduler import LearningRateDecay
from .framework import _var_base_to_np
from paddle.fluid import core
from paddle.fluid.layers import tensor
from functools import reduce
......@@ -122,7 +121,13 @@ class Optimizer(object):
state_dict[var_tmp.name] = var_tmp
# global step if use lr decay
if isinstance(self._learning_rate, LearningRateDecay):
var_temp = Variable(None, name='global_step', dtype='int32')
var_tmp = None
if not framework.in_dygraph_mode():
var_temp = Variable(None, name='global_step', dtype='int32')
else:
var_temp = framework._varbase_creator(
None, name='global_step', dtype='int32')
tensor.fill_constant(
[1], "int32", self._learning_rate.step_num, out=var_temp)
......@@ -164,7 +169,7 @@ class Optimizer(object):
global_step = state_dict['global_step']
if isinstance(global_step, core.VarBase):
step_np = global_step._copy_to(core.CPUPlace(), True)
step_np = global_step
step_np = np.array(step_np.value().get_tensor())
assert step_np.shape == (1,), \
"global step shape is (1,), the shape is {}".format( step_np.shape )
......@@ -189,7 +194,7 @@ class Optimizer(object):
for para_name, var_tmp in v.items():
assert var_tmp.name in state_dict, \
"optimizer variable {} not found".format( var_tmp.name )
var = var_tmp._ivar.value()
var = var_tmp.value()
tensor = var.get_tensor()
model_np = np.array(tensor)
......@@ -198,7 +203,7 @@ class Optimizer(object):
if isinstance(load_para, Variable):
load_para_np = load_para.numpy()
elif isinstance(load_para, core.VarBase):
load_para_np = _var_base_to_np(load_para)
load_para_np = load_para.numpy()
elif isinstance(load_para, np.ndarray):
load_para_np = load_para
else:
......@@ -515,7 +520,11 @@ class Optimizer(object):
Examples:
See examples in ``apply_gradients``.
"""
no_grad_set = self._get_no_grad_set(loss, no_grad_set)
act_no_grad_set = None
if not framework.in_dygraph_mode():
act_no_grad_set = self._get_no_grad_set(loss, no_grad_set)
else:
pass
self._dtype = loss.dtype
if framework.in_dygraph_mode():
......@@ -528,15 +537,9 @@ class Optimizer(object):
for param in parameters:
if not param.trainable:
continue
if param._ivar._grad_ivar() is not None:
ivar_type = param._ivar._grad_ivar().type
if param._grad_ivar() is not None:
# create gradient variable
grad_var = Variable(
block=loss.block,
type=ivar_type,
name=param._ivar._grad_name(),
stop_gradient=True,
ivar=param._ivar._grad_ivar())
grad_var = param._grad_ivar()
params_grads.append((param, grad_var))
else:
if callbacks is None:
......@@ -550,7 +553,7 @@ class Optimizer(object):
loss.shape)
with program_guard(program, startup_program):
params_grads = append_backward(loss, parameter_list,
no_grad_set, callbacks)
act_no_grad_set, callbacks)
# Note: since we can't use all_reduce_op now,
# dgc_op should be the last op of one grad.
self._append_dgc_ops(params_grads)
......
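In dygraph mode the optimizer path above skips `no_grad_set` handling and assembles `params_grads` directly from each trainable parameter's `_grad_ivar()`. A minimal sketch following the pattern of `test_imperative_partial_backward` below (layer and shapes illustrative):

import numpy as np
import paddle.fluid as fluid

with fluid.dygraph.guard():
    fc = fluid.dygraph.FC("fc", 10)
    x = fluid.dygraph.to_variable(np.random.rand(4, 8).astype('float32'))
    loss = fluid.layers.reduce_mean(fc(x))
    loss.backward()
    optimizer = fluid.optimizer.AdamOptimizer()
    # params_grads is built from param._grad_ivar() in dygraph mode
    _, params_grads = optimizer.minimize(loss)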
......@@ -268,7 +268,7 @@ class OpTest(unittest.TestCase):
data = value[0]
lod = value[1]
v = fluid.dygraph.base.to_variable(value=data)
v._ivar.value().get_tensor().set_recursive_sequence_lengths(lod)
v.value().get_tensor().set_recursive_sequence_lengths(lod)
return v
else:
return fluid.dygraph.base.to_variable(value)
......@@ -289,7 +289,7 @@ class OpTest(unittest.TestCase):
if if_return_inputs_grad_dict:
v.stop_gradient = False
if has_lod:
v._ivar.value().get_tensor().set_recursive_sequence_lengths(
v.value().get_tensor().set_recursive_sequence_lengths(
lod_temp)
else:
v = block.create_var(
......@@ -840,8 +840,8 @@ class OpTest(unittest.TestCase):
if check_dygraph:
imperative_actual = find_imperative_actual(
sub_out_name, dygraph_outs, place)
imperative_actual_t = np.array(
imperative_actual._ivar.value().get_tensor())
imperative_actual_t = np.array(imperative_actual.value()
.get_tensor())
idx = find_actual(sub_out_name, fetch_list)
actual = outs[idx]
actual_t = np.array(actual)
......@@ -868,7 +868,7 @@ class OpTest(unittest.TestCase):
") has different lod at " + str(place))
if check_dygraph:
self.assertListEqual(
imperative_actual._ivar.value().get_tensor()
imperative_actual.value().get_tensor()
.recursive_sequence_lengths(), expect[1],
"Output (" + out_name +
") has different lod at " + str(place) +
......@@ -877,8 +877,8 @@ class OpTest(unittest.TestCase):
if check_dygraph:
imperative_actual = find_imperative_actual(
out_name, dygraph_outs, place)
imperative_actual_t = np.array(
imperative_actual._ivar.value().get_tensor())
imperative_actual_t = np.array(imperative_actual.value()
.get_tensor())
idx = find_actual(out_name, fetch_list)
actual = outs[idx]
actual_t = np.array(actual)
......@@ -913,7 +913,7 @@ class OpTest(unittest.TestCase):
") has different lod at " + str(place))
if check_dygraph:
self.assertListEqual(
imperative_actual._ivar.value().get_tensor()
imperative_actual.value().get_tensor()
.recursive_sequence_lengths(), expect[1],
"Output (" + out_name + ") has different lod at " +
str(place) + " in dygraph mode")
......
......@@ -153,8 +153,8 @@ class TestImperativeAutoPrune(unittest.TestCase):
v2 = fluid.dygraph.to_variable(value2)
loss = case1(v1, v2)
loss.backward()
self.assertTrue(case1.fc2._w._ivar._grad_ivar() is not None)
self.assertTrue(case1.fc1._w._ivar._grad_ivar() is not None)
self.assertTrue(case1.fc2._w._grad_ivar() is not None)
self.assertTrue(case1.fc1._w._grad_ivar() is not None)
def test_auto_prune2(self):
with fluid.dygraph.guard():
......@@ -166,8 +166,8 @@ class TestImperativeAutoPrune(unittest.TestCase):
loss = case2(v1, v2)
loss.backward()
self.assertTrue(case2.fc2._w._ivar._grad_ivar() is None)
self.assertTrue(case2.fc1._w._ivar._grad_ivar() is not None)
self.assertTrue(case2.fc2._w._grad_ivar() is None)
self.assertTrue(case2.fc1._w._grad_ivar() is not None)
def test_auto_prune3(self):
with fluid.dygraph.guard():
......@@ -178,7 +178,7 @@ class TestImperativeAutoPrune(unittest.TestCase):
v2 = fluid.dygraph.to_variable(value2)
loss, part2 = case3(v1, v2, 1)
loss.backward()
self.assertTrue(case3.fc._w._ivar._grad_ivar() is not None)
self.assertTrue(case3.fc._w._grad_ivar() is not None)
self.assertTrue((part2.gradient() == 0).all())
def test_auto_prune4(self):
......@@ -190,7 +190,7 @@ class TestImperativeAutoPrune(unittest.TestCase):
v2 = fluid.dygraph.to_variable(value2)
loss, part2 = case4(v1, v2, 1)
part2.backward()
self.assertTrue(case4.fc._w._ivar._grad_ivar() is not None)
self.assertTrue(case4.fc._w._grad_ivar() is not None)
self.assertTrue((part2.gradient() == 1).all())
def test_auto_prune5(self):
......@@ -202,7 +202,7 @@ class TestImperativeAutoPrune(unittest.TestCase):
v2 = fluid.dygraph.to_variable(value2)
loss, part1, part2 = case4(v1, v2, 2)
part1.backward()
self.assertTrue(case4.fc._w._ivar._grad_ivar() is not None)
self.assertTrue(case4.fc._w._grad_ivar() is not None)
self.assertTrue((part2.gradient() == 0).all())
def test_auto_prune6(self):
......@@ -333,8 +333,8 @@ class TestImperativeAutoPrune(unittest.TestCase):
for items in params_grads:
assert items[0].name is not model.embed1._w.name
assert items[0].name is not model.fc1._w.name
assert model.embed1._w._ivar._grad_ivar() is None
assert model.fc1._w._ivar._grad_ivar() is None
assert model.embed1._w._grad_ivar() is None
assert model.fc1._w._grad_ivar() is None
with fluid.dygraph.guard(place):
model = MyLayer2("mylayer", vocab_size, size)
......@@ -351,8 +351,8 @@ class TestImperativeAutoPrune(unittest.TestCase):
for items in params_grads:
assert items[0].name is not model.embed1._w.name
assert items[0].name is not model.fc1._w.name
assert model.embed1._w._ivar._grad_ivar() is None
assert model.fc1._w._ivar._grad_ivar() is None
assert model.embed1._w._grad_ivar() is None
assert model.fc1._w._grad_ivar() is None
def test_case2_prune_no_grad_branch(self):
with fluid.dygraph.guard():
......@@ -363,8 +363,8 @@ class TestImperativeAutoPrune(unittest.TestCase):
case3 = AutoPruneLayer2("l2")
loss = case3(v1, v2)
loss.backward()
self.assertTrue(case3.fc2._w._ivar._grad_ivar() is None)
self.assertTrue(case3.fc._w._ivar._grad_ivar() is not None)
self.assertTrue(case3.fc2._w._grad_ivar() is None)
self.assertTrue(case3.fc._w._grad_ivar() is not None)
def test_case2_prune_no_grad_branch(self):
with fluid.dygraph.guard():
......@@ -375,8 +375,8 @@ class TestImperativeAutoPrune(unittest.TestCase):
case3 = AutoPruneLayer2("l2")
loss = case3(v1, v2)
loss.backward()
self.assertTrue(case3.fc2._w._ivar._grad_ivar() is None)
self.assertTrue(case3.fc._w._ivar._grad_ivar() is not None)
self.assertTrue(case3.fc2._w._grad_ivar() is None)
self.assertTrue(case3.fc._w._grad_ivar() is not None)
def test_case3_prune_no_grad_branch2(self):
with fluid.dygraph.guard():
......@@ -389,14 +389,14 @@ class TestImperativeAutoPrune(unittest.TestCase):
out = fluid.layers.one_hot(input=label, depth=100)
loss = fluid.layers.mean(out)
loss.backward()
self.assertTrue(fc._w._ivar._grad_ivar() is None)
self.assertTrue(fc._w._grad_ivar() is None)
def test_case4_with_no_grad_op_maker(self):
with fluid.dygraph.guard():
out = fluid.layers.gaussian_random(shape=[20, 30])
loss = fluid.layers.mean(out)
loss.backward()
self.assertTrue(out._ivar._grad_ivar() is None)
self.assertTrue(out._grad_ivar() is None)
if __name__ == '__main__':
......
......@@ -177,6 +177,30 @@ class SimpleRNN(fluid.Layer):
class TestImperative(unittest.TestCase):
def test_isinstance(self):
var = fluid.layers.data(shape=[1], name='x', dtype='float32')
self.assertTrue(isinstance(var, fluid.Variable))
with fluid.dygraph.guard():
var_base = fluid.dygraph.base.to_variable(np.array([3, 4, 5]))
self.assertTrue(isinstance(var_base, core.VarBase))
self.assertTrue(isinstance(var_base, fluid.Variable))
def test_create_VarBase(self):
x = np.ones([2, 2], np.float32)
y = np.zeros([3, 3], np.float32)
with fluid.dygraph.guard():
tmp = fluid.core.VarBase(value=x, place=fluid.core.CPUPlace())
tmp2 = fluid.core.VarBase(y, fluid.core.CPUPlace())
tmp3 = fluid.dygraph.base.to_variable(x)
tmp4 = fluid.core.VarBase(y)
tmp5 = fluid.core.VarBase(value=x)
self.assertTrue(np.array_equal(x, tmp.numpy()))
self.assertTrue(np.array_equal(y, tmp2.numpy()))
self.assertTrue(np.array_equal(x, tmp3.numpy()))
self.assertTrue(np.array_equal(y, tmp4.numpy()))
self.assertTrue(np.array_equal(x, tmp5.numpy()))
def test_sum_op(self):
x = np.ones([2, 2], np.float32)
with fluid.dygraph.guard():
......@@ -215,17 +239,17 @@ class TestImperative(unittest.TestCase):
try:
new_variable.numpy()
except Exception as e:
assert type(e) == ValueError
assert type(e) == core.EnforceNotMet
try:
new_variable.backward()
except Exception as e:
assert type(e) == ValueError
assert type(e) == core.EnforceNotMet
try:
new_variable.clear_gradient()
except Exception as e:
assert type(e) == ValueError
assert type(e) == core.EnforceNotMet
def test_empty_grad(self):
with fluid.dygraph.guard():
......@@ -239,7 +263,7 @@ class TestImperative(unittest.TestCase):
try:
new_var.clear_gradient()
except Exception as e:
assert type(e) == ValueError
assert type(e) == core.EnforceNotMet
with fluid.dygraph.guard():
cur_program = fluid.Program()
......@@ -257,7 +281,7 @@ class TestImperative(unittest.TestCase):
new_var = fluid.dygraph.base.to_variable(x)
self.assertFalse(new_var.persistable)
new_var.persistable = True
self.assertFalse(new_var.persistable)
self.assertTrue(new_var.persistable)
def test_layer(self):
with fluid.dygraph.guard():
......
......@@ -70,7 +70,6 @@ class SimpleNet(fluid.Layer):
loss = fluid.layers.reshape(loss, shape=[-1, self.num_steps])
loss = fluid.layers.reduce_mean(loss, dim=[0])
loss = fluid.layers.reduce_sum(loss)
loss.permissions = True
return loss
......
......@@ -459,8 +459,7 @@ class TestDygraphOCRAttention(unittest.TestCase):
for batch_id in range(batch_num):
label_in = to_variable(label_in_np)
label_out = to_variable(label_out_np)
label_out._stop_gradient = True
label_out.trainable = False
label_out.stop_gradient = True
img = to_variable(image_np)
dy_prediction = ocr_attention(img, label_in)
label_out = fluid.layers.reshape(
......@@ -481,7 +480,7 @@ class TestDygraphOCRAttention(unittest.TestCase):
dy_grad_value = {}
for param in ocr_attention.parameters():
if param.trainable:
np_array = np.array(param._ivar._grad_ivar().value()
np_array = np.array(param._grad_ivar().value()
.get_tensor())
dy_grad_value[param.name + core.grad_var_suffix(
)] = np_array
......@@ -514,7 +513,7 @@ class TestDygraphOCRAttention(unittest.TestCase):
name='label_in', shape=[1], dtype='int64', lod_level=0)
static_label_out = fluid.layers.data(
name='label_out', shape=[1], dtype='int64', lod_level=0)
static_label_out._stop_gradient = True
static_label_out.stop_gradient = True
static_label_out.trainable = False
static_prediction = ocr_attention(images, static_label_in)
......
......@@ -83,7 +83,7 @@ class TestImperativeOptimizerBase(unittest.TestCase):
img = data[0]
label = data[1]
label._stop_gradient = True
label.stop_gradient = True
cost = mlp(img)
avg_loss = fluid.layers.reduce_mean(cost)
......
......@@ -33,10 +33,10 @@ class TestImperativePartitialBackward(unittest.TestCase):
loss.backward()
for param in fc1.parameters():
self.assertIsNotNone(param._ivar._grad_ivar())
self.assertIsNotNone(param._grad_ivar())
for param in fc2.parameters():
self.assertIsNone(param._ivar._grad_ivar())
self.assertIsNone(param._grad_ivar())
optimizer = fluid.optimizer.AdamOptimizer()
_, params_grads = optimizer.minimize(loss)
......
......@@ -207,7 +207,6 @@ class PtbModel(fluid.Layer):
loss = fluid.layers.reshape(loss, shape=[-1, self.num_steps])
loss = fluid.layers.reduce_mean(loss, dim=[0])
loss = fluid.layers.reduce_sum(loss)
loss.permissions = True
return loss, last_hidden, last_cell
......
......@@ -302,7 +302,7 @@ class TestDygraphResnet(unittest.TestCase):
dy_grad_value = {}
for param in resnet.parameters():
if param.trainable:
np_array = np.array(param._ivar._grad_ivar().value()
np_array = np.array(param._grad_ivar().value()
.get_tensor())
dy_grad_value[param.name + core.grad_var_suffix(
)] = np_array
......
......@@ -119,7 +119,7 @@ class TestDygraphResnetSortGradient(unittest.TestCase):
dy_grad_value = {}
for param in resnet.parameters():
if param.trainable:
np_array = np.array(param._ivar._grad_ivar().value()
np_array = np.array(param._grad_ivar().value()
.get_tensor())
dy_grad_value[param.name + core.grad_var_suffix(
)] = np_array
......
......@@ -197,7 +197,6 @@ class PtbModel(fluid.Layer):
loss = fluid.layers.reshape(loss, shape=[-1, self.num_steps])
loss = fluid.layers.reduce_mean(loss, dim=[0])
loss = fluid.layers.reduce_sum(loss)
loss.permissions = True
return loss, last_hidden, last_cell
......@@ -353,7 +352,7 @@ class TestDygraphPtbRnn(unittest.TestCase):
# set to zero
for k, v in opti_dict.items():
np_t = v.numpy()
var = v._ivar.value().get_tensor()
var = v.value().get_tensor()
var.set(np.zeros_like(np_t), place)
self.assertTrue(np.sum(np.abs(v.numpy())) == 0)
......@@ -373,7 +372,7 @@ class TestDygraphPtbRnn(unittest.TestCase):
state_dict = ptb_model.state_dict()
for k, v in state_dict.items():
np_t = v.numpy()
var = v._ivar.value().get_tensor()
var = v.value().get_tensor()
var.set(np.zeros_like(np_t), place)
......@@ -457,7 +456,7 @@ class TestDygraphPtbRnn(unittest.TestCase):
# set to zero
for k, v in opti_dict.items():
np_t = v.numpy()
var = v._ivar.value().get_tensor()
var = v.value().get_tensor()
var.set(np.zeros_like(np_t), place)
self.assertTrue(np.sum(np.abs(v.numpy())) == 0)
......@@ -476,7 +475,7 @@ class TestDygraphPtbRnn(unittest.TestCase):
state_dict = ptb_model.state_dict()
for k, v in state_dict.items():
np_t = v.numpy()
var = v._ivar.value().get_tensor()
var = v.value().get_tensor()
var.set(np.zeros_like(np_t), place)
......@@ -562,7 +561,7 @@ class TestDygraphPtbRnn(unittest.TestCase):
for k, v in opti_dict.items():
np_t = v.numpy()
np_opti_dict[v.name] = np_t
var = v._ivar.value().get_tensor()
var = v.value().get_tensor()
var.set(np.zeros_like(np_t), place)
self.assertTrue(np.sum(np.abs(v.numpy())) == 0)
......@@ -583,7 +582,7 @@ class TestDygraphPtbRnn(unittest.TestCase):
for k, v in state_dict.items():
np_t = v.numpy()
np_state_dict[v.name] = np_t
var = v._ivar.value().get_tensor()
var = v.value().get_tensor()
var.set(np.zeros_like(np_t), place)
......
......@@ -361,7 +361,7 @@ class TestImperativeResneXt(unittest.TestCase):
#dy_grad_value = {}
#for param in se_resnext.parameters():
# if param.trainable:
# np_array = np.array(param._ivar._grad_ivar().value()
# np_array = np.array(param._grad_ivar().value()
# .get_tensor())
# dy_grad_value[param.name + core.grad_var_suffix()] = np_array
......
......@@ -78,7 +78,6 @@ class SimpleNet(fluid.Layer):
loss = fluid.layers.reshape(loss, shape=[-1, self.num_steps])
loss = fluid.layers.reduce_mean(loss, dim=[0])
loss = fluid.layers.reduce_sum(loss)
loss.permissions = True
return loss
......
......@@ -25,8 +25,8 @@ import numpy as np
main_program = default_main_program()
class TestParameter(unittest.TestCase):
def test_param(self):
class ParameterChecks(unittest.TestCase):
def check_param(self):
shape = [784, 100]
val = 1.0625
b = main_program.global_block()
......@@ -46,7 +46,7 @@ class TestParameter(unittest.TestCase):
p = io.get_parameter_value_by_name('fc.w', exe, main_program)
self.assertTrue(np.allclose(np.array(p), np.ones(shape) * val))
def test_exceptions(self):
def check_exceptions(self):
b = main_program.global_block()
with self.assertRaises(ValueError):
b.create_parameter(
......@@ -62,5 +62,13 @@ class TestParameter(unittest.TestCase):
name='test', shape=[-1], dtype='float32', initializer=None)
class TestParameter(ParameterChecks):
def test_param(self):
self.check_param()
def test_exceptions(self):
self.check_exceptions()
if __name__ == '__main__':
unittest.main()
......@@ -208,7 +208,6 @@ class PtbModel(fluid.Layer):
loss = fluid.layers.reshape(loss, shape=[-1, self.num_steps])
loss = fluid.layers.reduce_mean(loss, dim=[0])
loss = fluid.layers.reduce_sum(loss)
loss.permissions = True
return loss, last_hidden, last_cell
......
......@@ -184,6 +184,28 @@ class TestVariable(unittest.TestCase):
with fluid.program_guard(default_main_program()):
self._tostring()
# NOTE(zhiqiu): for coverage CI
# TODO(zhiqiu): code clean for dygraph
def test_dygraph_deprecated_api(self):
b = default_main_program().current_block()
var = b.create_var(dtype="float64", lod_level=0)
with fluid.dygraph.guard():
self.assertIsNone(var.detach())
self.assertIsNone(var.numpy())
self.assertIsNone(var.set_value(None))
self.assertIsNone(var.backward())
self.assertIsNone(var.gradient())
self.assertIsNone(var.clear_gradient())
self.assertIsNone(var.to_string(True))
self.assertIsNone(var.persistable)
var.stop_gradient = True
self.assertIsNone(var.stop_gradient)
var.stop_gradient = 'tmp'
self.assertIsNone(var.name)
self.assertIsNone(var.shape)
self.assertIsNone(var.dtype)
self.assertIsNone(var.type)
if __name__ == '__main__':
unittest.main()