diff --git a/python/oneflow/nn/graph/graph.py b/python/oneflow/nn/graph/graph.py
index 23a2a21846b2620c051a4392a607e71f9da77068..6ccfde0e26fd77ed2ec43b7666735935e240ce3d 100644
--- a/python/oneflow/nn/graph/graph.py
+++ b/python/oneflow/nn/graph/graph.py
@@ -941,8 +941,15 @@ class Graph(object):
             self._add_block(name, value)
         elif isinstance(value, Optimizer):
             raise AttributeError(
-                "'{}' object are not allowed to set Optimizer attribute named '{}', "
-                "please use add_optimizer(...) instead.".format(
+                "'{}' nn.Graph is not allowed to set Optimizer attribute named '{}'. "
+                "Please use add_optimizer(...) instead.".format(
+                    type(self).__name__, name
+                )
+            )
+        elif isinstance(value, Tensor):
+            raise AttributeError(
+                "'{}' nn.Graph is not allowed to set Tensor attribute named '{}'. "
+                "Please use nn.Module to hold the tensor, then add the nn.Module to nn.Graph.".format(
                     type(self).__name__, name
                 )
             )
diff --git a/python/oneflow/nn/optimizer/optimizer.py b/python/oneflow/nn/optimizer/optimizer.py
index 1627c578ccb7261b6273039c911ceee0572ca6d8..afe62aa461842e4c7c23f812796b8d532d0a098c 100644
--- a/python/oneflow/nn/optimizer/optimizer.py
+++ b/python/oneflow/nn/optimizer/optimizer.py
@@ -20,6 +20,7 @@ from itertools import chain
 from typing import Any, Callable, Dict, Union
 
 from oneflow.framework.tensor import Tensor
+from oneflow.nn.graph.block import TensorBlock
 from oneflow.nn.parameter import Parameter
 from oneflow.nn.utils.clip_grad import clip_grad_norm_
 
@@ -28,10 +29,21 @@ class ParamGroup(object):
     def __init__(
         self, parameters: Dict[str, Any], default_options: Dict,
     ):
-        # ParamGroup must be constructed by Dict["params": parameters: List[Parameter or Tensor], "...": ...]
+        # ParamGroup must be constructed by Dict["params": parameters: List[Parameter, Tensor or TensorBlock], "...": ...]
         assert isinstance(parameters, dict) and "params" in parameters
         assert not isinstance(parameters["params"], (Parameter, Tensor))
-        self._parameters = list(parameters["params"])
+        self._parameters = list()
+        for p in parameters["params"]:
+            if isinstance(p, (Parameter, Tensor)):
+                self._parameters.append(p)
+            elif isinstance(p, TensorBlock):
+                # Add parameter from nn.Graph
+                self._parameters.append(p.origin)
+            else:
+                raise ValueError(
+                    "parameters in ParamGroup must be Tensor or TensorBlock."
+                )
+
         self._options = deepcopy(default_options)
         for key in self._options:
             if key in parameters:
diff --git a/python/oneflow/test/graph/test_graph.py b/python/oneflow/test/graph/test_graph.py
index 536bd170c5de465a47064f8641f7aad2a03468c0..d66756c03f75c088784139fb6cf446c916146701 100644
--- a/python/oneflow/test/graph/test_graph.py
+++ b/python/oneflow/test/graph/test_graph.py
@@ -250,6 +250,31 @@ class TestGraph(flow.unittest.TestCase):
         y = flow.tensor(y, dtype=flow.float32)
         g._compile(x, y)
 
+    def test_create_optimizer_in_graph(test_case):
+        device = "cuda"
+        linear = flow.nn.Linear(3, 8)
+        linear = linear.to(device)
+        flow.nn.init.constant_(linear.weight, 2.068758)
+        flow.nn.init.constant_(linear.bias, 0.23)
+
+        class OptCreatedInGraph(flow.nn.Graph):
+            def __init__(self):
+                super().__init__()
+                self.linear = linear
+                # create optimizer in nn.Graph and add parameters from ModuleBlock
+                self.add_optimizer(
+                    flow.optim.SGD(self.linear.parameters(), lr=0.001, momentum=0.9)
+                )
+
+            def build(self, x):
+                out = self.linear(x)
+                out = out.sum()
+                out.backward()
+                return out
+
+        g = OptCreatedInGraph()
+        print(g)
+
 
 if __name__ == "__main__":
     unittest.main()