"""
Copyright 2020 The OneFlow Authors. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
from collections import OrderedDict
from functools import partial
from typing import Dict, Optional, Union, List

import oneflow
import oneflow._oneflow_internal
import oneflow.framework.c_api_util as c_api_util
import oneflow.framework.graph_build_util as graph_build_util
import oneflow.framework.session_context as session_ctx
from oneflow.amp import GradScaler, StaticGradScaler
from oneflow.env import get_rank
from oneflow.framework.multi_client_session import MultiClientSession
from oneflow.framework.tensor import Tensor, TensorTuple
from oneflow.framework.tensor_tuple_util import convert_to_tensor_tuple
from oneflow.nn.graph.block import Block, BlockType, get_block_cls
from oneflow.nn.graph.graph_config import GraphConfig
from oneflow.nn.graph.optimizer import OptDict, VariableConfig
from oneflow.nn.graph.util import add_indent, seq_to_func_return, sys_exc_error_msg
from oneflow.nn.module import Module
from oneflow.nn.optimizer.lr_scheduler import LrScheduler
from oneflow.nn.optimizer.optimizer import Optimizer


class Graph(object):
    r"""Base class for training or evaluating a neural network in graph mode.

    To use graph mode for model training or evaluation in OneFlow, you should:

    1. Define your customized graph as a subclass of ``nn.Graph``.
    2. Add ``super().__init__()`` in your subclass's ``__init__()``.
    3. Add modules to your graph as regular attributes.
    4. Define the computation logic in the ``build()`` method.
    5. Instantiate your graph, then call it.

    .. code-block:: python

        >>> import oneflow as flow

        >>> class LinearGraph(flow.nn.Graph):
        ...    def __init__(self):
        ...        super().__init__()
        ...        # Add a module to the graph.
        ...        self.linear = flow.nn.Linear(3, 8, False)
        ...    def build(self, x):
        ...        # Use the module to build the computation logic of the graph.
        ...        return self.linear(x)

        # Instantiate the graph
        >>> linear_graph = LinearGraph()
        >>> x = flow.randn(4, 3)

        # The first call on the graph will run the graph's build() method to
        # trace a computation graph. Then the computation graph will be
        # optimized and executed for the first time.
        >>> linear_graph(x).shape
        oneflow.Size([4, 8])

        # Later calls on the graph will execute the computation graph directly.
        >>> linear_graph(x).shape
        oneflow.Size([4, 8])

    Note that Graph cannot be nested at the moment.
    """
    _child_init_cnt = dict()

    def __init__(self):
        """
        Initializes internal Graph states. It MUST be called in the ``__init__``
        method of your subclass.

        .. code-block:: python

            >>> import oneflow as flow
            >>> class SubclassGraph(flow.nn.Graph):
            ...     def __init__(self):
            ...         super().__init__() # MUST be called
            ...         # Then define the graph attributes
            ...     def build(self):
            ...         pass

        """
        self._generate_name()
        self.config = GraphConfig()
        self._blocks = OrderedDict()
        self._opts = []
        self._grad_scaler = None
        self._variables_conf = OrderedDict()
        self._is_compiled = False
        # forward graph job proto
        self._forward_job_proto = None
        # forward, backward and optimized graph job proto
        self._full_job_proto = None
        self._args_repr = []
        self._outs_repr = []
        self._debug = False
        self._debug_min_s_level = 2
        self._debug_max_v_level = 0
        self._outputs_buffer_size = 2
        self._cur_index_of_ouputs_buffer = 0

        self._c_nn_graph = oneflow._oneflow_internal.nn.graph.CNNGraph(self._name)
        session = session_ctx.GetDefaultSession()
        assert type(session) is MultiClientSession
        session.TryInit()
        session.AddCGraph(self._c_nn_graph)

    def build(self, *args):
        r"""The ``build()`` method must be overridden to define neural network
        computation logic.

        The ``build()`` method of nn.Graph is very similar to the ``forward()``
        method of nn.Module. It is used to describe the computation logic of
        a neural network.

        When a graph object is called for the first time, the ``build()``
        method will be called implicitly to build the computation graph.

        Make sure to call the module's ``train()`` or ``eval()`` method before the
        first call of your graph so that the module executes the right
        training or evaluation logic if needed.

        .. code-block:: python

            >>> import oneflow as flow
            >>> class MyGraph(flow.nn.Graph):
            ...     def __init__(self):
            ...         super().__init__()
            ...         self.linear = flow.nn.Linear(3, 8, False)
            ...     def build(self, x):
            ...         return self.linear(x)

            >>> linear_graph = MyGraph()
            >>> x = flow.randn(4, 3)
            >>> y = linear_graph(x) # The build() method is called implicitly

        Note that the ``build()`` method's inputs and outputs only accept positional
        arguments at the moment; each argument must be one of these types:

        * ``Tensor``
        * ``list`` of ``Tensor``
        * ``None``

        """
        raise NotImplementedError()

    def add_optimizer(
        self, optim: Optimizer, *, lr_sch: LrScheduler = None,
    ):
        r"""Add an optimizer and an optional learning rate scheduler to the graph.

        To do training with nn.Graph, you should do two more things:

        1. Add at least one optimizer (learning rate schedulers are optional) with the ``add_optimizer()`` method.
        2. Call the loss tensor's ``backward()`` method in the ``build()`` method.

        Note that the computation graph will automatically execute these methods:

        * optimizer's ``clip_grad()`` if the optimizer is set to do grad clipping.
        * optimizer's ``step()``.
        * optimizer's ``zero_grad()``.
        * learning rate scheduler's ``step()``.

        Also note that only scalar tensors are allowed to call ``backward()``
        in ``nn.Graph.build()`` for the moment. So you may call ``Tensor.sum()``
        or ``Tensor.mean()`` to make the loss tensor a scalar tensor.

        .. code-block:: python

            >>> import oneflow as flow
            >>> loss_fn = flow.nn.MSELoss(reduction="sum")
            >>> model = flow.nn.Sequential(flow.nn.Linear(3, 1), flow.nn.Flatten(0, 1))
            >>> optimizer = flow.optim.SGD(model.parameters(), lr=1e-6)
            >>> class LinearTrainGraph(flow.nn.Graph):
            ...     def __init__(self):
            ...         super().__init__()
            ...         self.model = model
            ...         self.loss_fn = loss_fn
            ...         # Add an optimizer
            ...         self.add_optimizer(optimizer)
            ...     def build(self, x, y):
            ...         y_pred = self.model(x)
            ...         loss = self.loss_fn(y_pred, y)
            ...         # Call loss tensor's backward(), loss tensor must be a scalar tensor
            ...         loss.backward()
            ...         return loss

            >>> linear_graph = LinearTrainGraph()
            >>> x = flow.randn(10, 3)
            >>> y = flow.randn(10)
            >>> for t in range(3):
            ...     loss = linear_graph(x, y)

        Args:
            optim (oneflow.optim.Optimizer): The optimizer.
            lr_sch (LrScheduler, optional): The learning rate scheduler, see oneflow.optim.lr_scheduler.
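
        A learning rate scheduler can be registered together with its optimizer
        through the keyword-only ``lr_sch`` argument. A minimal sketch (the
        ``CosineDecayLR`` scheduler and its arguments here are an assumption; any
        scheduler from ``oneflow.optim.lr_scheduler`` built on the same optimizer
        should work):

        .. code-block:: python

            # Hypothetical scheduler configuration, shown for illustration only.
            lr_scheduler = flow.optim.lr_scheduler.CosineDecayLR(optimizer, decay_steps=100)
            self.add_optimizer(optimizer, lr_sch=lr_scheduler)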
        """
        opt_dict = dict()
        assert optim is not None, "optimizer cannot be None"
        assert isinstance(
            optim, Optimizer
        ), "optimizer must be an instance of Optimizer"
        opt_dict["optim"] = optim
        if lr_sch is not None:
            assert isinstance(lr_sch, LrScheduler)
            assert (
                lr_sch._optimizer is optim
            ), "lr_scheduler's optimizer must be the same optimizer in add_optimizer."
            opt_dict["lr_sch"] = lr_sch
        self._opts.append(opt_dict)
        # Set the training config if an optimizer has been added to the graph.
        if len(self._opts) == 1:
            self.config._train(True)

    def set_grad_scaler(self, grad_scaler: GradScaler = None):
        r"""Set the GradScaler for gradient and loss scaling.
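
        A minimal usage sketch (the ``flow.amp.GradScaler`` constructor arguments
        shown here are assumptions and may differ across OneFlow versions):

        .. code-block:: python

            # Hypothetical scaling hyper-parameters, shown for illustration only.
            grad_scaler = flow.amp.GradScaler(
                init_scale=2 ** 12,
                growth_factor=2.0,
                backoff_factor=0.5,
                growth_interval=1000,
            )
            graph.set_grad_scaler(grad_scaler)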
        """
        assert isinstance(grad_scaler, (GradScaler, StaticGradScaler))
        self._grad_scaler = grad_scaler

    def __call__(self, *args):
        r"""Call nn.Graph subclass instance to run your customized graph.

        Call your customized graph after the instantiation:

        .. code-block:: python

            g = CustomGraph()
            out_tensors = g(input_tensors)

        The inputs of the ``__call__`` method must match the inputs of the ``build()``
        method, and the ``__call__`` method will return outputs matching the
        outputs of the ``build()`` method.

        Note that the first call takes longer than later calls, because nn.Graph
        will do the computation graph generation and optimization at the first call.

        Do not override this function.
        """
        if not self._is_compiled:
            self._compile(*args)

        return self._run(*args)

    @property
    def name(self):
        r"""Name auto-generated for this graph.
        """
        return self._name

    @property
    def training(self):
        r"""In training mode if the graph has an optimizer.
        """
        return self.config.training

    def debug(
        self,
        v_level: int = 0,
        ranks: Optional[Union[int, List[int]]] = None,
        mode: bool = True,
    ) -> None:
        r"""Open or close debug mode of the graph.

        If in debug mode, logs of computation graph building info and warnings will be
        printed. Otherwise, only errors will be printed.

        Use ``v_level`` to choose the verbose debug info level; the default level is 0 and the max level is 1.
        ``v_level`` 0 will print warnings and graph creation stages. ``v_level`` 1 will additionally
        print the graph build info of each module.

        Use ``ranks`` to choose which ranks to print the debug information on.

        .. code-block:: python

            g = CustomGraph()
            g.debug()  # Open debug mode
            out_tensors = g(input_tensors)  # Will print debug logs during the first call

        Args:
            v_level (int): choose the verbose debug info level, default v_level is 0, max v_level is 1.
            ranks (int or list(int)): choose the ranks to print the debug information on. Default is rank ``0``.
                You can choose any valid rank. ``-1`` means debug on all ranks.
            mode (bool): whether to set debug mode (``True``) or not (``False``). Default: ``True``.
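
        For example, to print per-module build info (``v_level`` 1) on ranks 0 and 1
        only (a usage sketch based on the arguments above):

        .. code-block:: python

            g = CustomGraph()
            g.debug(v_level=1, ranks=[0, 1])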
        """
        assert isinstance(v_level, int)
        assert isinstance(mode, bool)

        if ranks is None:
            rank_list = [0]
        elif isinstance(ranks, int):
            rank_list = [ranks]
        elif isinstance(ranks, list):
            rank_list = ranks
        else:
            raise ValueError("ranks must be int or List[int].")

        my_rank = get_rank()
        if -1 in rank_list or my_rank in rank_list:
            self._debug = mode
            if self._debug:
                self._debug_min_s_level = 0
                self._debug_max_v_level = v_level
            for name, block in self._blocks.items():
                assert block.type == BlockType.MODULE
                block.debug(v_level, ranks, mode)

    def __repr__(self):
        r"""For printing the graph structure.

        The graph structure can be printed after graph instantiation.

        After the first call of graph, inputs and outputs will be added to
        the graph structure.

        .. code-block:: python

            g = CustomGraph()
            print(g)

            out_tensors = g(input_tensors)
            print(g) # Input and output info is added

        """
        child_lines = []
        child_lines.append(add_indent(repr(self.config), 2))
        if len(self._args_repr) > 0:
            for in_str in self._args_repr:
                input_str = add_indent(in_str, 2)
                child_lines.append(input_str)

        if len(self._blocks) > 0:
            for n, m in self._blocks.items():
                mod_str = repr(m)
                mod_str = add_indent(mod_str, 2)
                child_lines.append(mod_str)

        if len(self._outs_repr) > 0:
            for out_str in self._outs_repr:
                output_str = add_indent(out_str, 2)
                child_lines.append(output_str)

        main_str = self._shallow_repr() + ": ("
        if len(child_lines) > 0:
            main_str += "\n  " + "\n  ".join(child_lines) + "\n"
        main_str += ")"
        return main_str

    def _shallow_repr(self):
        shallow_repr = "(GRAPH:" + self._name + ":" + self.__class__.__name__ + ")"
        return shallow_repr

    def _print(self, s_level=2, v_level=0, msg: str = ""):
        r"""Print a message according to the debug severity and verbosity levels.
        """
        assert isinstance(s_level, int)
        assert isinstance(v_level, int)
        assert isinstance(msg, str)
        if s_level >= self._debug_min_s_level:
            if (s_level > 0) or (s_level == 0 and v_level <= self._debug_max_v_level):
                print(msg)

    @property
    def _config_proto(self):
        return self.config.proto

    @property
    def _optimization_conf_proto(self):
        session = session_ctx.GetDefaultSession()
        assert type(session) is MultiClientSession
        return session.resource

    @property
    def _graph_proto(self):
        if not self._is_compiled:
            self._print(
                2,
                0,
                f"[ERROR]{self._shallow_repr()} has not been compiled, so its graph proto is None."
                " You can call the graph to trigger its compilation.",
            )
        return self._forward_job_proto

    @property
    def _full_graph_proto(self):
        if not self._is_compiled:
            self._print(
                2,
                0,
                f"[ERROR]{self._shallow_repr()} has not been compiled, so its full graph proto is None."
                " You can call the graph to trigger its compilation.",
            )
        return self._full_job_proto

    def _generate_name(self):
        child_name = self.__class__.__name__
        if Graph._child_init_cnt.get(child_name) is None:
            Graph._child_init_cnt[child_name] = 0
        self._name = child_name + "_" + str(Graph._child_init_cnt[child_name])
        Graph._child_init_cnt[child_name] += 1

    def _state(self):
        for _, b in self._blocks.items():
            pa_gen = b.parameters(recurse=True)
            for pa in pa_gen:
                yield pa
            bu_gen = b.buffers(recurse=True)
            for bu in bu_gen:
                yield bu

    def _generate_config_proto(self):
        self.config.proto.set_job_name(self._name)

        if self._grad_scaler is not None:
            self._grad_scaler._generate_conf_for_graph(
                self.config.proto.mutable_train_conf()
            )

        for state_block in self._state():
            if state_block.type == BlockType.PARAMETER:
                self._variables_conf[state_block.origin] = VariableConfig(
                    state_block.name_prefix + state_block.name
                )
        for opt in self._opts:
            opt_dict = OptDict(opt)
            self.config._generate_optimizer_and_variable_configs(
                opt_dict, self._variables_conf
            )

    def _compile(self, *args):
        # Build graph
        try:
            self._print(0, 0, self._shallow_repr() + " start building graph.")
            assert not self._is_compiled, (
                "nn.Graph " + self._name + " has already been compiled."
            )

            eager_outputs = self._build_graph(*args)

            self._print(0, 0, self._shallow_repr() + " end building graph.")
        except:
            self._print(
                2,
                0,
                "[ERROR]"
                + self._shallow_repr()
                + " build graph got error: "
                + sys_exc_error_msg(),
            )
            raise

        # Compile graph to execution plan and init Runtime
        try:
            self._print(
                0,
                0,
                self._shallow_repr() + " start compiling plan and init graph runtime.",
            )

            self._c_nn_graph.complie_and_init_runtime()

            self._print(
                0,
                0,
                self._shallow_repr() + " end compiling plan and init graph runtime.",
            )
        except:
            self._print(
                2,
                0,
                "[ERROR]"
                + self._shallow_repr()
                + " compiling plan or initializing graph runtime got error: "
                + sys_exc_error_msg(),
            )
            raise

        self._is_compiled = True
        return eager_outputs

    def _build_graph(self, *args):
        session = session_ctx.GetDefaultSession()
        assert type(session) is MultiClientSession

        # Get config from GraphConfig
        self._outputs_buffer_size = self.config._outputs_buffer_size
        self._generate_config_proto()

        with graph_build_util.graph_build_context(self.config.proto, session):
            # Deal with inputs
            arg_op_names, lazy_args, self._args_repr, _ = self._build_io(
                "input", graph_build_util.build_graph_input_arg, *args
            )

            # Deal with parameter and buffer
            state_op_names, self._states_tensor_tuple = self._build_states()

            # Deal with module in self.build(*args)
            outputs = self.build(*lazy_args)

            # Deal with outputs
            if not (type(outputs) is tuple or type(outputs) is list):
                if outputs is None:
                    outputs = ()
                else:
                    outputs = (outputs,)

            (
                output_op_names,
                self._eager_outputs,
                self._outs_repr,
                out2name,
            ) = self._build_io("output", graph_build_util.build_graph_output, *outputs)

            # Save forward graph job proto
            self._forward_job_proto = c_api_util.GetCurrentJob()
            # Complete the graph job proto
            oneflow._oneflow_internal.CurJobBuildAndInferCtx_Complete()
            # Save the full graph job proto after job Complete, to find the real output blob shapes and rebuild the outputs.
            self._full_job_proto = c_api_util.GetCurrentJob()

            # Rebuild outputs according to the full graph and the outputs buffer config.
            self._rebuild_outputs(out2name)

            # Register input/output/variable/buffer to _c_nn_graph
            self._c_nn_graph.register_input_op_names_and_tensors(
                arg_op_names, convert_to_tensor_tuple(self._flatten_io("input", *args))
            )
            self._c_nn_graph.register_output_op_names_and_tensors(
                output_op_names, self._outputs_tensor_tuple
            )
            self._c_nn_graph.register_variable_op_names_and_tensors(
                state_op_names, self._states_tensor_tuple
            )

        return seq_to_func_return(self._eager_outputs_buffer[0])

    def _rebuild_outputs(self, out2name=None):
        # NOTE(chengcheng):
        #   Lazy build output eager tensors.
        #
        #   After JobBuildAndInferCtxt.Complete, the output tensor shape
        #   could be changed by JobPass, such as GradientAccumulationRewritePass.
        def build_real_output(fake_eager_out):
            lbn = out2name[fake_eager_out] + "/out"
            assert lbn in self._full_job_proto.helper.lbn2logical_blob_desc
            blob_conf = self._full_job_proto.helper.lbn2logical_blob_desc[lbn]

            shape = tuple(blob_conf.shape.dim)
            dtype = fake_eager_out.dtype

            with oneflow._oneflow_internal.lazy_mode.guard(False):
                if fake_eager_out.is_consistent:
                    eager_out = oneflow.empty(
                        shape,
                        dtype=dtype,
                        placement=fake_eager_out.placement,
                        sbp=fake_eager_out.sbp,
                    )
                else:
                    eager_out = oneflow.empty(
                        shape, dtype=dtype, device=fake_eager_out.device
                    )

            return eager_out

        def convert_to_synced_tensor_tuple(*args):
            tensor_tuple = convert_to_tensor_tuple(*args)
            # tensors acting as buffer should be synced once upon created.
            oneflow._oneflow_internal.nn.graph.SoftSyncNNGraphBuffers(
                tensor_tuple, self._c_nn_graph
            )
            return tensor_tuple

        self._eager_outputs = self._mapping_io(
            "output", build_real_output, *self._eager_outputs
        )

        self._outputs_tensor_tuple = convert_to_synced_tensor_tuple(
            self._flatten_io("output", *self._eager_outputs)
        )
        self._eager_outputs_buffer = [
            self._eager_outputs,
        ]
        self._outputs_tensor_tuple_buffer = [
            self._outputs_tensor_tuple,
        ]

        # Make outputs buffer
        for i in range(self._outputs_buffer_size - 1):
            outputs_buffer_item = self._empty_like_io("output", *self._eager_outputs)
            self._eager_outputs_buffer.append(outputs_buffer_item)
            outputs_tensor_tuple_buffer_item = convert_to_synced_tensor_tuple(
                self._flatten_io("output", *outputs_buffer_item)
            )
            self._outputs_tensor_tuple_buffer.append(outputs_tensor_tuple_buffer_item)
        self._check_outputs_buffer()

    def _check_outputs_buffer(self):
        has_len = len(self._outputs_tensor_tuple_buffer)
        assert (
            has_len == self._outputs_buffer_size
        ), f"nn.Graph's outputs buffer size {has_len} does not match the set value {self._outputs_buffer_size}."
        # Check that there are no duplicated output buffer tensors.
        out_id_dic = dict()

        def check_id_and_add(t, name):
            if t is not None:
                tid = id(t)
                assert (
                    tid not in out_id_dic
                ), f"nn.Graph's outputs buffer got a conflict on tensor id {tid}, new item name {name}, old item name {out_id_dic[tid]}."
                out_id_dic[tid] = name

        for b_idx, buffer in enumerate(self._outputs_tensor_tuple_buffer):
            for i_idx, item in enumerate(buffer):
                check_id_and_add(
                    item, "graph_outputs_buffer_" + str(b_idx) + "_" + str(i_idx)
                )

    def _run(self, *args):
        try:
            flattened_eager_args = self._flatten_io("input", *args)
            outputs_tensor_tuple = self._outputs_tensor_tuple_buffer[
                self._cur_index_of_ouputs_buffer
            ]
            eager_outputs = self._eager_outputs_buffer[self._cur_index_of_ouputs_buffer]

            # oneflow._oneflow_internal.eager.multi_client.Sync() NOTE(chengcheng): Need Sync?
            oneflow._oneflow_internal.nn.graph.RunLazyNNGraph(
                convert_to_tensor_tuple(flattened_eager_args),
                outputs_tensor_tuple,
                self._states_tensor_tuple,
                self._c_nn_graph,
            )
            # Update outputs buffer reading index
            self._cur_index_of_ouputs_buffer += 1
            if self._cur_index_of_ouputs_buffer >= self._outputs_buffer_size:
                self._cur_index_of_ouputs_buffer = 0
        except:
            self._print(
                2,
                0,
                "[ERROR]"
                + self._shallow_repr()
                + " run got error: "
                + sys_exc_error_msg(),
            )
            raise

        # Copy outputs from buffer
        eager_outputs = self._copy_io("output", *eager_outputs)

        # Make sure that last used devices of tensors in `outputs_tensor_tuple` are
        # "critical_section".
        # NNGraph's execution flow will be broken if `last_used_device` of `outputs_tensor_tuple`
        # are not "critical_section".
        oneflow._oneflow_internal.nn.graph.SoftSyncNNGraphBuffers(
            outputs_tensor_tuple, self._c_nn_graph
        )
        return seq_to_func_return(eager_outputs)

    def _build_io(self, io_type, build_func, *args):
        assert io_type in ("input", "output")
        io_type_upper = io_type.upper()
        build_args = []
        op_names = []
        args_repr = []
        tensor2op_name = {}

        def build_tensor_or_none(tensor, name, repr_str):
            assert tensor is None or (isinstance(tensor, Tensor))
            if isinstance(tensor, Tensor):
                build_arg = build_func(name, tensor)
                op_names.append(name)
                tensor2op_name[build_arg] = name
            else:
                build_arg = None

            args_repr.append(repr_str)
            self._print(0, 1, repr_str)
            return build_arg

        for idx, arg in enumerate(args):
            if isinstance(arg, Tensor) or arg is None:
                if arg is None:
                    name, repr_str = self._io_item_check_and_gen(
                        arg, None, io_type, idx
                    )
                else:
                    name, repr_str = self._io_item_check_and_gen(
                        arg, Tensor, io_type, idx
                    )
                build_args.append(build_tensor_or_none(arg, name, repr_str))
            elif isinstance(arg, (TensorTuple, list)):
                if isinstance(arg, TensorTuple):
                    seq_args = TensorTuple()
                else:
                    seq_args = list()
                for i in range(len(arg)):
                    name, repr_str = self._io_item_check_and_gen(
                        arg[i], Tensor, io_type, idx, i
                    )
                    seq_args.append(build_tensor_or_none(arg[i], name, repr_str))
                build_args.append(seq_args)
            else:
                self._io_item_check_and_gen(arg, Tensor, io_type, idx)

        return op_names, build_args, args_repr, tensor2op_name

    def _mapping_io(self, io_type, func, *args):
        assert io_type in ("input", "output")
        io_type_upper = io_type.upper()
        mapped_args = []

        def mapping_tensor_or_none(tensor):
            assert tensor is None or (isinstance(tensor, Tensor))
            if isinstance(tensor, Tensor):
                mapped_arg = func(tensor)
            else:
                mapped_arg = None
            return mapped_arg

        for idx, arg in enumerate(args):
            if isinstance(arg, Tensor) or arg is None:
                mapped_args.append(mapping_tensor_or_none(arg))
            elif isinstance(arg, (TensorTuple, list)):
                if isinstance(arg, TensorTuple):
                    seq_args = TensorTuple()
                else:
                    seq_args = list()
                for i in range(len(arg)):
                    seq_args.append(mapping_tensor_or_none(arg[i]))
                mapped_args.append(seq_args)
            else:
                self._io_item_check(arg, None, io_type, idx)

        return mapped_args

    def _empty_like_io(self, io_type, *args):
        def func(t):
            shape = t.shape
            dtype = t.dtype

            with oneflow._oneflow_internal.lazy_mode.guard(False):
                if t.is_consistent:
                    eager_out = oneflow.empty(
                        shape, dtype=dtype, placement=t.placement, sbp=t.sbp,
                    )
                else:
                    eager_out = oneflow.empty(shape, dtype=dtype, device=t.device)

            return eager_out

        return self._mapping_io(io_type, func, *args)

    def _copy_io(self, io_type, *args):
        def func(tensor):
            with oneflow._oneflow_internal.lazy_mode.guard(False):
                build_arg = tensor.to(copy=True)
                return build_arg

        return self._mapping_io(io_type, func, *args)

    def _flatten_io(self, io_type, *args):
        assert isinstance(args, tuple)
        flattened_args = []
        for idx, arg in enumerate(args):
            if isinstance(arg, Tensor):
                flattened_args.append(arg)
            elif isinstance(arg, (TensorTuple, list)):
                for i in range(len(arg)):
                    self._io_item_check(arg[i], Tensor, io_type, idx, i)
                    flattened_args.append(arg[i])
            else:
                self._io_item_check(arg, None, io_type, idx)
        return flattened_args

    def _io_item_check(self, item, expect_type, io_type, idx, second_idx=None):
        if expect_type is None and item is None:
            return
        elif expect_type is not None and isinstance(item, expect_type):
            return
        else:
            assert io_type in ("input", "output")
            name = (
                "_"
                + self.name
                + "-"
                + io_type
                + "_"
                + str(idx)
                + ("" if second_idx is None else "_" + str(second_idx))
            )
            repr_str = (
                "[ERROR](" + io_type.upper() + ":" + name + ":" + str(type(item)) + ")"
            )
            self._print(2, 0, repr_str)
            raise NotImplementedError(
                "nn.Graph.build()'s input/output only support types: Tensor/list(Tensor)/None."
            )

    def _io_item_check_and_gen(self, item, expect_type, io_type, idx, second_idx=None):
        assert io_type in ("input", "output")
        name = (
            "_"
            + self.name
            + "-"
            + io_type
            + "_"
            + str(idx)
            + ("" if second_idx is None else "_" + str(second_idx))
        )
        if expect_type is None and item is None:
            repr_str = (
                "[WARNING]("
                + io_type.upper()
                + ":"
                + name
                + ":"
                + str(type(item))
                + ")"
            )
            return name, repr_str
        elif expect_type is not None and isinstance(item, expect_type):
            if isinstance(item, Tensor):
                repr_str = (
                    "(" + io_type.upper() + ":" + name + ":" + item._meta_repr() + ")"
                )
            else:
                repr_str = (
                    "[WARNING]("
                    + io_type.upper()
                    + ":"
                    + name
                    + ":"
                    + str(type(item))
                    + ")"
                )
            return name, repr_str
        else:
            repr_str = (
                "[ERROR](" + io_type.upper() + ":" + name + ":" + str(type(item)) + ")"
            )
            self._print(2, 0, repr_str)
            raise NotImplementedError(
                "nn.Graph.build()'s input/output only support types: Tensor/list(Tensor)/None."
            )

    def _build_states(self):
        state_op_names = []
        state_tensors = []
        for state_block in self._state():
            op_name = state_block.name_prefix + state_block.name
            state_tensor = state_block.origin
            state_op_names.append(op_name)
            state_tensors.append(state_tensor)
            if (
                state_block.type == BlockType.PARAMETER
                and state_block.origin in self._variables_conf
            ):
                state_config = self._variables_conf[state_block.origin]
            else:
                state_config = None
            state_block.set_lazy_origin_builder(
                partial(
                    graph_build_util.build_graph_state,
                    op_name,
                    state_tensor,
                    state_config,
                )
            )
        state_tensor_tuple = convert_to_tensor_tuple(state_tensors)
        return state_op_names, state_tensor_tuple

    def _add_block(self, name: str, module: Module = None) -> None:
        r"""Adds a module to the graph as a block so that the module will
        be called in nn.Graph.build.

        Args:
            name (str): name of the child block. The child block can be accessed from this graph using the given name.
            module (Module): child module to be added to the graph.

        Just assign an nn.Module as an attribute of nn.Graph, and _add_block will be
        called to add the module as a Block:

        .. code-block:: python

            >>> import oneflow as flow
            >>> class LinearGraph(flow.nn.Graph):
            ...     def __init__(self):
            ...         super().__init__()
            ...         # add a nn.Module as a block to graph.
            ...         self.linear = flow.nn.Linear(3, 8, False)
            ...     def build(self, x):
            ...         # call the nn.Module block.
            ...         return self.linear(x)


        The block can be accessed as an attribute using the given name.

            >>> g = LinearGraph()
            >>> print(repr(g.linear))
            (MODULE:linear:Linear(in_features=3, out_features=8, bias=False)): (
              (PARAMETER:linear.weight:tensor(..., size=(8, 3), dtype=oneflow.float32, requires_grad=True)): ()
            )
        """
        if "_name" not in self.__dict__:
            raise AttributeError(
                "Base class nn.Graph has not been initialized, "
                "please call super().__init__() in subclass of nn.Graph "
                "before assigning any attribute."
            )
        if not isinstance(module, Module) and module is not None:
            raise TypeError("{} is not a Module subclass".format(type(module)))
        elif not isinstance(name, str):
            raise TypeError("module name should be a string. Got {}".format(type(name)))
        elif hasattr(self, name) and name not in self._blocks:
            raise KeyError("attribute '{}' already exists".format(name))
        elif "." in name:
            raise KeyError('module name can\'t contain ".", got: {}'.format(name))
        elif name == "":
            raise KeyError('module name can\'t be empty string ""')

        self._blocks[name] = get_block_cls(module)("", name, module)

    def __setattr__(self, name: str, value=None):
        if isinstance(value, Module):
            self._add_block(name, value)
        elif isinstance(value, Optimizer):
            raise AttributeError(
                "'{}' nn.Graph is not allowed to set Optimizer attribute named '{}'. "
                "Please use add_optimizer(...) instead.".format(
                    type(self).__name__, name
                )
            )
        elif isinstance(value, Tensor):
            raise AttributeError(
                "'{}' nn.Graph is not allowed to set Tensor attribute named '{}'. "
                "Please use nn.Module to hold the tensor, then add the nn.Module to nn.Graph.".format(
                    type(self).__name__, name
                )
            )
        else:
            object.__setattr__(self, name, value)

    def __getattr__(self, name: str):
        if "_blocks" in self.__dict__:
            if name in self._blocks:
                return self._blocks[name]
        if name in self.__dict__:
            return self.__dict__[name]
        raise AttributeError(
            "'{}' object has no attribute '{}'".format(type(self).__name__, name)
        )


if __name__ == "__main__":
    import doctest

    doctest.testmod(raise_on_error=True)