tracer.py 5.3 KB
Newer Older
M
minqiyang 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import print_function

import six

from collections import defaultdict
from paddle.fluid import core
from paddle.fluid import framework
22
from paddle import _C_ops
M
minqiyang 已提交
23 24 25 26


class Tracer(core.Tracer):
    """
    :api_attr: imperative

    Tracer is used to execute and record the operators executed, to construct
    the computation graph in dygraph model. Tracer has two modes,
    :code:`train_mode` and :code:`eval_mode`. In :code:`train_mode`, Tracer
    adds the backward network automatically and performs AutoGrad through
    the method :code:`loss.backward()`. In :code:`eval_mode`, Tracer does not
    add the backward network.

    This is a low level API, users don't need to use it directly.
    """

    def __init__(self):
        """Initialize the base tracer and start in train mode."""
        super(Tracer, self).__init__()

        # Toggled by train_mode() / eval_mode().
        self._train_mode = True

    def trace_op(self,
                 type,
                 inputs,
                 outputs,
                 attrs,
                 stop_gradient=False,
                 inplace_map=None):
        """Run one operator, either through the eager op table or the tracer.

        When ``framework._in_eager_mode()`` is true, the op function is looked
        up by name in ``_C_ops``, called with positional arguments assembled
        from ``inputs``/``outputs`` followed by the flattened ``attrs``, and
        the values it returns are written back into the tensors already held
        in ``outputs``. Otherwise the call is forwarded to ``self.trace``.

        Args:
            type: Operator name; used as the key into ``_C_ops`` and its
                argument/return metadata tables.
            inputs: Dict mapping argument name to its value, e.g.
                ``{"sum": [tensor], ...}``.
            outputs: Dict mapping output name to the pre-created output
                holder(s), e.g. ``{"sum": [tensor], ...}``. Updated in place.
            attrs: Dict of attribute name to attribute value.
            stop_gradient: Only used on the non-eager path, where gradient
                recording is requested iff ``self._has_grad`` is truthy and
                this flag is false.
            inplace_map: Only used on the non-eager path; a falsy value is
                replaced by an empty dict.

        Raises:
            KeyError: If ``type`` is missing from ``_C_ops`` or its metadata.
            AssertionError: If an argument or the op result does not have the
                shape the metadata calls for.
        """
        if not framework._in_eager_mode():
            self.trace(type, inputs, outputs, attrs,
                       framework._current_expected_place(),
                       self._has_grad and not stop_gradient,
                       inplace_map if inplace_map else {})
            return

        # inputs : {"sum": [tensor], ...}
        # outputs : {"sum": [tensor], ...}
        function_ptr = _C_ops.__dict__[type]

        # Resolve all metadata up front so an unknown op fails before the
        # op function is invoked.
        op_args = _C_ops.get_core_ops_args_info()[type]
        op_args_type = _C_ops.get_core_ops_args_type_info()[type]
        op_returns = _C_ops.get_core_ops_returns_info()[type]

        arg_list = _collect_eager_args(op_args, op_args_type, inputs, outputs)

        # Attributes are passed positionally as name, value, name, value, ...
        attrs_list = []
        for attr_name, attr_value in attrs.items():
            attrs_list.append(attr_name)
            attrs_list.append(attr_value)

        returns = function_ptr(*arg_list, *attrs_list)
        _write_back_eager_returns(op_returns, returns, outputs)

    def train_mode(self):
        """Mark the tracer as being in train mode (``_train_mode = True``)."""
        self._train_mode = True

    def eval_mode(self):
        """Mark the tracer as being in eval mode (``_train_mode = False``)."""
        self._train_mode = False


def _collect_eager_args(op_args, op_args_type, inputs, outputs):
    """Build the positional argument list for an eager-mode op function.

    Each name in ``op_args`` is resolved from ``inputs`` first, then from
    ``outputs``. A name of the form ``<Out>Num`` that is found in neither is
    treated as the element count of ``outputs["<Out>"]``. Anything else that
    is not supplied becomes ``None``.
    """
    num_suffix = "Num"
    arg_list = []
    for i, arg_name in enumerate(op_args):
        arg_type = op_args_type[i]

        if arg_name in inputs:
            arg_value = inputs[arg_name]
        elif arg_name in outputs:
            arg_value = outputs[arg_name]
        elif arg_name.endswith(num_suffix) and len(arg_name) > len(num_suffix):
            # Only a real "Num" *suffix* marks a count argument; a name that
            # merely contains "Num" must not have its tail chopped off.
            out_name = arg_name[:-len(num_suffix)]
            assert out_name in outputs, (
                "count argument '%s' refers to output '%s', which was not "
                "provided" % (arg_name, out_name))
            arg_value = len(outputs[out_name])
        else:
            arg_value = None

        if arg_value is None:
            arg_list.append(None)
        elif arg_type == "tensor":
            # A single-tensor slot may still arrive wrapped in a list.
            if isinstance(arg_value, list):
                arg_list.append(arg_value[0])
            else:
                arg_list.append(arg_value)
        elif arg_type == "list":
            assert isinstance(arg_value, list), (
                "argument '%s' must be a list" % arg_name)
            arg_list.append(arg_value)
        else:
            assert arg_type == "int", (
                "unsupported type '%s' for argument '%s'" %
                (arg_type, arg_name))
            assert isinstance(arg_value, int), (
                "argument '%s' must be an int" % arg_name)
            arg_list.append(arg_value)
    return arg_list


def _write_back_eager_returns(op_returns, returns, outputs):
    """Copy the values returned by an eager op into the holders in ``outputs``.

    Every returned value is handed to the matching pre-existing output via
    ``reconstruct_from_(value, False)``, so the objects the caller already
    holds are the ones that get updated.
    """
    if isinstance(returns, tuple):
        for i, ret_name in enumerate(op_returns):
            if ret_name not in outputs:
                continue
            # Replaced outputs by function returns
            ret_value = returns[i]
            if isinstance(ret_value, list):
                for j, item in enumerate(ret_value):
                    outputs[ret_name][j].reconstruct_from_(item, False)
            else:
                outputs[ret_name][0].reconstruct_from_(ret_value, False)
        return

    # A non-tuple result can only be matched to an output unambiguously when
    # there is exactly one output slot.
    assert len(outputs) == 1, (
        "a non-tuple op result requires exactly one output slot")
    key = next(iter(outputs))

    if isinstance(returns, list):
        for j, item in enumerate(returns):
            outputs[key][j].reconstruct_from_(item, False)
    elif isinstance(outputs[key], list):
        outputs[key][0].reconstruct_from_(returns, False)
    else:
        outputs[key].reconstruct_from_(returns, False)