提交 54e07994 编写于 作者: Y Youwei Song 提交者: Tao Luo

Dygraph Layer kwargs & param getter setter (#19901)

* opt FC

* opt rest of dygraph.nn

* new param shape check and unittest

* add kwargs for Layer

* add new set_value api

* use property decorator

* update API.spec, test=develop

* use UserList, separate getters/setters, test=develop

* update test_custom_layer_with_kwargs, test=develop

* fix UserList compatibility, test=develop

* fix UserList compatibility, test=develop

* keep FC._w, test=develop

* add unittests, Conv3D bug fix, test=develop

* clean code, test=develop

* fix dygraph guard in unittest, test=develop

* add property setters, remove unused param in tracer, test=develop

* tracer none check, test=develop

* merge, test=develop

* refine, test=develop

* bug fix in prelu and conv3d_transpose, test=develop

* rm __set__, test=develop

* set tensor value instead of assign op

* fix property setter call, test=develop

* fix api.spec, test=develop

* fix doc sample, test=develop
上级 9de67725
此差异已折叠。
...@@ -150,20 +150,20 @@ class Layer(core.Layer): ...@@ -150,20 +150,20 @@ class Layer(core.Layer):
if p.trainable: if p.trainable:
p.clear_gradient() p.clear_gradient()
def _build_once(self, *args): def _build_once(self, *args, **kwargs):
pass pass
def __call__(self, *inputs): def __call__(self, *inputs, **kwargs):
if not self._built: if not self._built:
self._build_once(*inputs) self._build_once(*inputs, **kwargs)
if parallel_helper._is_data_parallel_mode(): if parallel_helper._is_data_parallel_mode():
parallel_helper._broadcast_parameters(self._parameters.values()) parallel_helper._broadcast_parameters(self._parameters.values())
outputs = self.forward(*inputs) outputs = self.forward(*inputs, **kwargs)
self._built = True self._built = True
return outputs return outputs
def forward(self, *inputs): def forward(self, *inputs, **kwargs):
raise NotImplementedError raise NotImplementedError
def backward(self, *inputs): def backward(self, *inputs):
...@@ -216,6 +216,8 @@ class Layer(core.Layer): ...@@ -216,6 +216,8 @@ class Layer(core.Layer):
return object.__getattribute__(self, name) return object.__getattribute__(self, name)
def __setattr__(self, name, value): def __setattr__(self, name, value):
if isinstance(getattr(type(self), name, None), property):
object.__setattr__(self, name, value)
if isinstance(value, framework.Parameter): if isinstance(value, framework.Parameter):
params = self.__dict__.get('_parameters', None) params = self.__dict__.get('_parameters', None)
if params is None: if params is None:
...@@ -226,6 +228,11 @@ class Layer(core.Layer): ...@@ -226,6 +228,11 @@ class Layer(core.Layer):
tensor = var.get_tensor() tensor = var.get_tensor()
tensor.set(self._loaddict_holder[value.name].numpy(), tensor.set(self._loaddict_holder[value.name].numpy(),
framework._current_expected_place()) framework._current_expected_place())
if name in params:
# remove unused param in tracer
if framework._dygraph_tracer_ is not None:
framework._dygraph_tracer_._vars.pop(params[name].name,
None)
params[name] = value params[name] = value
elif isinstance(value, core.Layer): elif isinstance(value, core.Layer):
layers = self.__dict__.get('_sub_layers', None) layers = self.__dict__.get('_sub_layers', None)
......
此差异已折叠。
...@@ -638,6 +638,45 @@ class Variable(object): ...@@ -638,6 +638,45 @@ class Variable(object):
new_ivar = self._ivar._copy_to(core.CPUPlace(), True) new_ivar = self._ivar._copy_to(core.CPUPlace(), True)
return np.array(new_ivar.value().get_tensor()) return np.array(new_ivar.value().get_tensor())
@dygraph_only
def set_value(self, value):
    """
    Overwrite the data held by this Variable with a new value.

    The new value must have exactly the same shape as this Variable.

    Args:
        value (Variable|np.ndarray): the new value.

    Returns:
        None.

    Raises:
        AssertionError: if ``value`` is neither a Variable nor an np.ndarray.
        ValueError: if the shape of ``value`` differs from this Variable's shape.

    Examples:
        .. code-block:: python

            import paddle.fluid as fluid
            from paddle.fluid.dygraph.base import to_variable
            from paddle.fluid.dygraph import FC
            import numpy as np

            data = np.ones([3, 32, 32], dtype='float32')
            with fluid.dygraph.guard():
                fc = fluid.dygraph.FC("fc", 4)
                t = to_variable(data)
                fc(t)  # call with default weight
                custom_weight = np.random.randn(1024, 4).astype("float32")
                fc.weight.set_value(custom_weight)  # change existing weight
                out = fc(t)  # call with different weight
    """
    assert isinstance(value, (Variable, np.ndarray))

    shapes_match = list(value.shape) == list(self.shape)
    if not shapes_match:
        raise ValueError(
            "The shape of the new value must be the same as that of the original Variable."
        )

    target = self._ivar.value().get_tensor()
    # A Variable source is first read out as an array via its tensor's
    # __array__(); an np.ndarray source is passed through unchanged.
    source = (value._ivar.value().get_tensor().__array__()
              if isinstance(value, Variable) else value)
    target.set(source, _current_expected_place())
@dygraph_only @dygraph_only
def backward(self, backward_strategy=None): def backward(self, backward_strategy=None):
""" """
...@@ -1042,7 +1081,7 @@ class Variable(object): ...@@ -1042,7 +1081,7 @@ class Variable(object):
if self.shape[axis] < 0: if self.shape[axis] < 0:
return self._cloneVar(True) return self._cloneVar(True)
index = int(item) index = int(item)
if (index > 0 and index >= self.shape[axis])\ if (index > 0 and index >= self.shape[axis]) \
or (index < 0 and (index + self.shape[axis]) < 0): or (index < 0 and (index + self.shape[axis]) < 0):
raise IndexError("invalid index") raise IndexError("invalid index")
return self._sliceVar([axis], [index], [index + 1]) return self._sliceVar([axis], [index], [index + 1])
...@@ -2662,10 +2701,10 @@ class IrOpNode(IrNode): ...@@ -2662,10 +2701,10 @@ class IrOpNode(IrNode):
if isinstance(val, Block): if isinstance(val, Block):
desc.set_block_attr(name, val.desc) desc.set_block_attr(name, val.desc)
elif isinstance(val, list) and val and \ elif isinstance(val, list) and val and \
all(isinstance(v, Block) for v in val): all(isinstance(v, Block) for v in val):
desc.set_blocks_attr(name, [v.desc for v in val]) desc.set_blocks_attr(name, [v.desc for v in val])
elif isinstance(val, core.BlockDesc) or \ elif isinstance(val, core.BlockDesc) or \
isinstance(val, core.ProgramDesc): isinstance(val, core.ProgramDesc):
desc.set_serialized_attr(name, val.serialize_to_string()) desc.set_serialized_attr(name, val.serialize_to_string())
else: else:
desc._set_attr(name, val) desc._set_attr(name, val)
...@@ -2888,8 +2927,8 @@ class IrGraph(object): ...@@ -2888,8 +2927,8 @@ class IrGraph(object):
op_node(IrOpNode): the operator node that is needed to update input's link. op_node(IrOpNode): the operator node that is needed to update input's link.
""" """
assert old_input_node.node in self.graph.nodes() and new_input_node.node in \ assert old_input_node.node in self.graph.nodes() and new_input_node.node in \
self.graph.nodes() and op_node.node in self.graph.nodes(), \ self.graph.nodes() and op_node.node in self.graph.nodes(), \
'The three arguments(old_input_node&new_input_node&op_node) must be in the graph nodes.' 'The three arguments(old_input_node&new_input_node&op_node) must be in the graph nodes.'
old_input_node.remove_output(op_node) old_input_node.remove_output(op_node)
op_node.remove_input(old_input_node) op_node.remove_input(old_input_node)
new_input_node.append_output(op_node) new_input_node.append_output(op_node)
...@@ -3024,7 +3063,7 @@ class IrGraph(object): ...@@ -3024,7 +3063,7 @@ class IrGraph(object):
def _convert_to_pdf(dot_file_path): def _convert_to_pdf(dot_file_path):
pdf_save_path = os.path.splitext(dot_file_path)[0] + '.pdf' pdf_save_path = os.path.splitext(dot_file_path)[0] + '.pdf'
exited_code = subprocess.call('dot -Tpdf ' + dot_file_path \ exited_code = subprocess.call('dot -Tpdf ' + dot_file_path \
+ ' -o ' + pdf_save_path, shell=True) + ' -o ' + pdf_save_path, shell=True)
if exited_code != 0: if exited_code != 0:
print('The dot command is needed for creating pdf files.') print('The dot command is needed for creating pdf files.')
print('The {} is saved as the dot filetype.'.format( print('The {} is saved as the dot filetype.'.format(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册