# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import print_function

import six

from collections import defaultdict
from paddle.fluid import core
from paddle.fluid import framework
from paddle import _C_ops


class Tracer(core.Tracer):
    """
    :api_attr: imperative
    
    Tracer executes and records the operators run in a dygraph model, so that
    the computation graph can be constructed. Tracer has two modes,
    :code:`train_mode` and :code:`eval_mode`. In :code:`train_mode`, Tracer
    adds the backward network automatically, and AutoGrad is performed via
    :code:`loss.backward()`. In :code:`eval_mode`, no backward network is added.

    This is a low-level API; users don't need to use it directly.
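
    Examples (a minimal sketch; the active tracer is fetched here through the
    internal helper :code:`framework._dygraph_tracer()`):

        .. code-block:: python

            import paddle.fluid as fluid
            from paddle.fluid import framework

            with fluid.dygraph.guard():
                tracer = framework._dygraph_tracer()
                tracer.eval_mode()   # no backward network will be recorded
                tracer.train_mode()  # operators are recorded for AutoGrad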
    """

    def __init__(self):
        super(Tracer, self).__init__()

        self._train_mode = True

    def trace_op(self,
                 type,
                 inputs,
                 outputs,
                 attrs,
                 stop_gradient=False,
                 inplace_map=None):
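        # Two dispatch paths: in eager mode, call the generated _C_ops
        # function directly; otherwise fall back to the legacy C++ tracer
        # in the else branch below.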
        if framework._in_eager_mode():
            # inputs : {"sum": [tensor], ...}
            # outputs : {"sum": [tensor], ...}

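            # Look up the auto-generated eager op function (e.g. _C_ops.sum)
            # together with its expected argument and return signatures.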
            function_ptr = _C_ops.__dict__[type]

            core_ops_args_info = _C_ops.get_core_ops_args_info()
            core_ops_args_type_info = _C_ops.get_core_ops_args_type_info()
            core_ops_returns_info = _C_ops.get_core_ops_returns_info()

            op_args = core_ops_args_info[type]
            op_args_type = core_ops_args_type_info[type]
            op_returns = core_ops_returns_info[type]

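            # Build the positional argument list in exactly the order the
            # generated function expects, resolving each name from inputs
            # first and outputs second.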
            arg_list = []
            for i in range(len(op_args)):
                arg_name = op_args[i]
                arg_type = op_args_type[i]
                if arg_name in inputs.keys():
                    arg_to_append = inputs[arg_name]
                elif arg_name in outputs.keys():
                    arg_to_append = outputs[arg_name]
                else:
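                    # Synthesized "*Num" arguments carry the length of the
                    # corresponding output list rather than a tensor.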
                    if "Num" in arg_name:
                        # Remove "Num" suffix to get out_name
                        out_name = arg_name[:-3]
                        assert out_name in outputs.keys()
                        num_outs = len(outputs[out_name])
                        arg_to_append = num_outs
                    else:
                        arg_to_append = None

                if arg_to_append is None:
                    arg_list.append(arg_to_append)
                elif arg_type == "tensor":
                    if isinstance(arg_to_append, list):
                        arg_list.append(arg_to_append[0])
                    else:
                        arg_list.append(arg_to_append)
                elif arg_type == "list":
                    assert isinstance(arg_to_append, list)
                    arg_list.append(arg_to_append)
                else:
                    assert arg_type == "int"
                    assert isinstance(arg_to_append, int)
                    arg_list.append(arg_to_append)

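            # Attributes are flattened into alternating (name, value) pairs
            # and appended after the tensor arguments.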
            attrs_list = []
            for k, v in attrs.items():
                attrs_list.append(k)
                attrs_list.append(v)
            returns = function_ptr(*arg_list, *attrs_list)

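            # Write the results back into the caller-provided output tensors
            # in place; returns may be a tuple, a list, or a single tensor.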
            if isinstance(returns, tuple):
                for i in range(len(op_returns)):
                    retname = op_returns[i]
                    if retname in outputs.keys():
                        # Replaced outputs by function returns
                        if isinstance(returns[i], list):
                            for j in range(len(returns[i])):
                                outputs[retname][j].reconstruct_from_(
                                    returns[i][j])
                        else:
                            outputs[retname][0].reconstruct_from_(returns[i])
            elif isinstance(returns, list):
                assert len(outputs.keys()) == 1
                key = list(outputs.keys())[0]
                for j in range(len(returns)):
                    outputs[key][j].reconstruct_from_(returns[j])
            else:
                assert len(outputs.keys()) == 1
                key = list(outputs.keys())[0]
                if isinstance(outputs[key], list):
                    outputs[key][0].reconstruct_from_(returns)
                else:
                    outputs[key].reconstruct_from_(returns)
        else:
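            # Legacy dygraph path: the C++ tracer records the op and tracks
            # gradients only when _has_grad is set and stop_gradient is False.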
            self.trace(type, inputs, outputs, attrs,
                       framework._current_expected_place(), self._has_grad and
                       not stop_gradient, inplace_map if inplace_map else {})

    def train_mode(self):
        self._train_mode = True

    def eval_mode(self):
        self._train_mode = False