parallel_executor.py 7.0 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23
#   Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import core
import multiprocessing
import framework
import executor

__all__ = ['ParallelExecutor']


class ParallelExecutor(object):
    def __init__(self,
                 use_cuda,
                 loss_name=None,
                 main_program=None,
                 num_threads=None,
                 allow_op_delay=False,
                 share_vars_from=None):
        """
        ParallelExecutor can run program in parallel.

        Args:
            use_cuda(bool): Whether to use CUDA or not. When True, one place
                is created per visible CUDA device; otherwise one CPU place
                is created per CPU reported by multiprocessing.cpu_count().
            loss_name(str, default None): The loss name must be set in
                training.
            main_program(Program, default None): The program that needs to
                run; if not provided, default_main_program will be used.
            num_threads(int, default None): How many threads are used for
                training. If None, it defaults to the number of places when
                use_cuda is True, and to min(2 * number of places, CPU count)
                otherwise.
            allow_op_delay(bool, default False): Whether to delay and buffer
                some operators together for scheduling or not, which may
                improve performance in some cases, default False.
            share_vars_from(ParallelExecutor, default None): If provided,
                it will share variables from the specified ParallelExecutor.

        Returns:
            A ParallelExecutor object.

        Raises:
            TypeError: If share_vars_from is provided, but not ParallelExecutor
                object.

        Examples:
            .. code-block:: python

              train_exe = fluid.ParallelExecutor(
                  use_cuda=True, loss_name=loss.name)
              test_exe = fluid.ParallelExecutor(
                  use_cuda=True,
                  main_program=test_program,
                  share_vars_from=train_exe)

              train_loss, = train_exe.run([loss.name], feed_dict=feed_dict)
              test_loss, = test_exe.run([loss.name], feed_dict=feed_dict)
        """
        # _act_places holds the concrete CUDAPlace/CPUPlace objects;
        # _places holds the generic core.Place wrappers handed to the C++
        # executor. Both lists are kept index-aligned.
        self._places = []
        self._act_places = []
        if use_cuda:
            for device_id in xrange(core.get_cuda_device_count()):
                self._add_place(core.CUDAPlace(device_id))
        else:
            for _ in xrange(multiprocessing.cpu_count()):
                # CPUPlace takes no device index (see its use in run()).
                self._add_place(core.CPUPlace())
        assert self._places, "no place for execution"

        if num_threads is None:
            if use_cuda:
                # Experiments on se-resnext show that too many threads hurt
                # performance. Worth tuning for other models in the future.
                num_threads = len(self._places)
            else:
                num_threads = min(
                    len(self._places) * 2, multiprocessing.cpu_count())

        main = main_program or framework.default_main_program()
        scope = executor.global_scope()

        if share_vars_from and not isinstance(share_vars_from,
                                              ParallelExecutor):
            raise TypeError("share_vars_from must be ParallelExecutor.")
        # Reuse the local scopes of another executor so both see the same
        # variables; an empty list lets the C++ side create fresh ones.
        local_scopes = share_vars_from.executor.local_scopes(
        ) if share_vars_from else []

        # Names of persistable variables, excluding RAW-typed ones.
        self.persistable_vars = [
            var.name for var in main.list_vars()
            if var.persistable and var.type != core.VarDesc.VarType.RAW
        ]

        trainable_param_names = set(
            param.name for param in main.global_block().iter_parameters()
            if not param.stop_gradient)

        self.executor = core.ParallelExecutor(
            num_threads,
            bool(use_cuda),  # use_event
            self._places,
            trainable_param_names,
            set(self.persistable_vars),
            main.desc,
            loss_name if loss_name else '',
            scope,
            local_scopes,
            allow_op_delay)
        self.scope = scope

    def _add_place(self, act_place):
        """Record ``act_place`` and a ``core.Place`` wrapping it, index-aligned."""
        place = core.Place()
        self._act_places.append(act_place)
        place.set_place(act_place)
        self._places.append(place)

    def run(self, fetch_list, feed=None, feed_dict=None):
        """
        Run the program on all places and fetch the requested variables.

        Args:
            fetch_list(list): The fetched variable names
            feed(list|dict|None): The feed variables. If the feed is a dict,
                tensors in that dict will be split across the devices. If
                the feed is a list (or tuple), it must have one dict per
                place, and the i-th dict is fed to the i-th device.
            feed_dict: Alias for feed parameter, for backward compatibility.
                Only used when ``feed`` is None.

        Returns: fetched result list.

        Raises:
            ValueError: If feed is a list/tuple whose length differs from the
                number of places.
            TypeError: If an element of a list/tuple feed is not a dict, or
                if feed is not None, a dict, or a list/tuple.
        """
        if feed is None:
            feed = feed_dict

        if isinstance(feed, dict):
            feed_tensor_dict = {}
            for feed_name, value in feed.items():
                feed_tensor = value
                if not isinstance(feed_tensor, core.LoDTensor):
                    feed_tensor = core.LoDTensor()
                    # always set to CPU place, since the tensor needs to be
                    # split, which is fast on CPU
                    feed_tensor.set(value, core.CPUPlace())
                feed_tensor_dict[feed_name] = feed_tensor

            self.executor.feed_and_split_tensor_into_local_scopes(
                feed_tensor_dict)
        elif isinstance(feed, (list, tuple)):
            if len(feed) != len(self._act_places):
                raise ValueError(
                    "Feed a list of tensor, the list should be the same size as places"
                )

            per_device_feeds = []
            # Lengths are equal (checked above), so zip pairs every element.
            for place, each in zip(self._act_places, feed):
                if not isinstance(each, dict):
                    raise TypeError(
                        "Each element of feed list should be a dict")
                res_dict = {}
                for feed_name, tensor in each.items():
                    if not isinstance(tensor, core.LoDTensor):
                        tmp = core.LoDTensor()
                        tmp.set(tensor, place)
                        tensor = tmp
                    res_dict[feed_name] = tensor
                per_device_feeds.append(res_dict)
            self.executor.feed_tensors_into_local_scopes(per_device_feeds)
        elif feed is not None:
            # Previously an unsupported feed type was silently ignored, which
            # ran the program with stale inputs.
            raise TypeError(
                "feed should be a dict, a list/tuple of dicts, or None")

        fetch_var_name = '@FETCHED_VAR_NAME@'
        self.executor.run(fetch_list, fetch_var_name)
        arr = self.scope.find_var(fetch_var_name).get_lod_tensor_array()
        return [arr[i] for i in range(len(arr))]

    def bcast_params(self):
        """Broadcast the persistable variables via the underlying executor."""
        self.executor.bcast_params(set(self.persistable_vars))

    @property
    def device_count(self):
        """Number of places (devices) this executor runs on."""
        return len(self._act_places)