提交 54e07994 编写于 作者: Y Youwei Song 提交者: Tao Luo

Dygraph Layer kwargs & param getter setter (#19901)

* opt FC

* opt rest of dygraph.nn

* new param shape check and unittest

* add kwargs for Layer

* add new set_value api

* use property decorator

* update API.spec, test=develop

* use UserList, separate gettersetters, test=develop

* update test_custom_layer_with_kwargs, test=develop

* fix UserList compatibility, test=develop

* fix UserList compatibility, test=develop

* keep FC._w, test=develop

* add unittests, Conv3D bug fix, test=develop

* clean code, test=develop

* fix dygraph guard in unittest, test=develop

* add property setters, remove unused param in tracer, test=develop

* tracer none check, test=develop

* merge, test=develop

* refine, test=develop

* bug fix in  prelu and conv3d_transpose, test=develop

* rm __set__, test=develop

* set tensor value instead of assign op

* fix property setter call, test=develop

* fix api.spec, test=develop

* fix doc sample, test=develop
上级 9de67725
此差异已折叠。
......@@ -150,20 +150,20 @@ class Layer(core.Layer):
if p.trainable:
p.clear_gradient()
def _build_once(self, *args):
def _build_once(self, *args, **kwargs):
pass
def __call__(self, *inputs):
def __call__(self, *inputs, **kwargs):
if not self._built:
self._build_once(*inputs)
self._build_once(*inputs, **kwargs)
if parallel_helper._is_data_parallel_mode():
parallel_helper._broadcast_parameters(self._parameters.values())
outputs = self.forward(*inputs)
outputs = self.forward(*inputs, **kwargs)
self._built = True
return outputs
def forward(self, *inputs):
def forward(self, *inputs, **kwargs):
raise NotImplementedError
def backward(self, *inputs):
......@@ -216,6 +216,8 @@ class Layer(core.Layer):
return object.__getattribute__(self, name)
def __setattr__(self, name, value):
if isinstance(getattr(type(self), name, None), property):
object.__setattr__(self, name, value)
if isinstance(value, framework.Parameter):
params = self.__dict__.get('_parameters', None)
if params is None:
......@@ -226,6 +228,11 @@ class Layer(core.Layer):
tensor = var.get_tensor()
tensor.set(self._loaddict_holder[value.name].numpy(),
framework._current_expected_place())
if name in params:
# remove unused param in tracer
if framework._dygraph_tracer_ is not None:
framework._dygraph_tracer_._vars.pop(params[name].name,
None)
params[name] = value
elif isinstance(value, core.Layer):
layers = self.__dict__.get('_sub_layers', None)
......
此差异已折叠。
......@@ -638,6 +638,45 @@ class Variable(object):
new_ivar = self._ivar._copy_to(core.CPUPlace(), True)
return np.array(new_ivar.value().get_tensor())
@dygraph_only
def set_value(self, value):
"""
Set a new value for this Variable.
Args:
value (Variable|np.ndarray): the new value.
Returns:
None.
Examples:
.. code-block:: python
import paddle.fluid as fluid
from paddle.fluid.dygraph.base import to_variable
from paddle.fluid.dygraph import FC
import numpy as np
data = np.ones([3, 32, 32], dtype='float32')
with fluid.dygraph.guard():
fc = fluid.dygraph.FC("fc", 4)
t = to_variable(data)
fc(t) # call with default weight
custom_weight = np.random.randn(1024, 4).astype("float32")
fc.weight.set_value(custom_weight) # change existing weight
out = fc(t) # call with different weight
"""
assert isinstance(value, (Variable, np.ndarray))
if list(value.shape) != list(self.shape):
raise ValueError(
"The shape of the new value must be the same as that of the original Variable."
)
self_tensor = self._ivar.value().get_tensor()
if isinstance(value, Variable):
value = value._ivar.value().get_tensor().__array__()
self_tensor.set(value, _current_expected_place())
@dygraph_only
def backward(self, backward_strategy=None):
"""
......@@ -1042,7 +1081,7 @@ class Variable(object):
if self.shape[axis] < 0:
return self._cloneVar(True)
index = int(item)
if (index > 0 and index >= self.shape[axis])\
if (index > 0 and index >= self.shape[axis]) \
or (index < 0 and (index + self.shape[axis]) < 0):
raise IndexError("invalid index")
return self._sliceVar([axis], [index], [index + 1])
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册