未验证 提交 9ed5706f 编写于 作者: X Xiaoyu Xu 提交者: GitHub

support create opt in graph (#6598)

* support create opt in graph

* add comment
Co-authored-by: Noneflow-ci-bot <69100618+oneflow-ci-bot@users.noreply.github.com>
上级 a07c5ade
......@@ -941,8 +941,15 @@ class Graph(object):
self._add_block(name, value)
elif isinstance(value, Optimizer):
raise AttributeError(
"'{}' object are not allowed to set Optimizer attribute named '{}', "
"please use add_optimizer(...) instead.".format(
"'{}' nn.Graph is not allowed to set Optimizer attribute named '{}'. "
"Please use add_optimizer(...) instead.".format(
type(self).__name__, name
)
)
elif isinstance(value, Tensor):
raise AttributeError(
"'{}' nn.Graph is not allowed to set Tensor attribute named '{}'. "
"Please use nn.Module to hold the tensor, then add the nn.Module to nn.Graph.".format(
type(self).__name__, name
)
)
......
......@@ -20,6 +20,7 @@ from itertools import chain
from typing import Any, Callable, Dict, Union
from oneflow.framework.tensor import Tensor
from oneflow.nn.graph.block import TensorBlock
from oneflow.nn.parameter import Parameter
from oneflow.nn.utils.clip_grad import clip_grad_norm_
......@@ -28,10 +29,21 @@ class ParamGroup(object):
def __init__(
self, parameters: Dict[str, Any], default_options: Dict,
):
# ParamGroup must be constructed by Dict["params": parameters: List[Parameter or Tensor], "...": ...]
# ParamGroup must be constructed by Dict["params": parameters: List[Parameter, Tensor or TensorBlock], "...": ...]
assert isinstance(parameters, dict) and "params" in parameters
assert not isinstance(parameters["params"], (Parameter, Tensor))
self._parameters = list(parameters["params"])
self._parameters = list()
for p in parameters["params"]:
if isinstance(p, (Parameter, Tensor)):
self._parameters.append(p)
elif isinstance(p, TensorBlock):
# Add parameter from nn.Graph
self._parameters.append(p.origin)
else:
raise ValueError(
"parameters in ParamGroup must be Tensor or TensorBlock."
)
self._options = deepcopy(default_options)
for key in self._options:
if key in parameters:
......
......@@ -250,6 +250,31 @@ class TestGraph(flow.unittest.TestCase):
y = flow.tensor(y, dtype=flow.float32)
g._compile(x, y)
def test_create_optimizer_in_graph(test_case):
device = "cuda"
linear = flow.nn.Linear(3, 8)
linear = linear.to(device)
flow.nn.init.constant_(linear.weight, 2.068758)
flow.nn.init.constant_(linear.bias, 0.23)
class OptCreatedInGraph(flow.nn.Graph):
def __init__(self):
super().__init__()
self.linear = linear
# creat optimizer in nn.Graph and add parameter from ModuleBlock
self.add_optimizer(
flow.optim.SGD(self.linear.parameters(), lr=0.001, momentum=0.9)
)
def build(self, x):
out = self.linear(x)
out = out.sum()
out.backward()
return out
g = OptCreatedInGraph()
print(g)
if __name__ == "__main__":
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册