async_executor.py 6.1 KB
Newer Older
W
wangguibao 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21
#   Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import print_function

import numpy as np
import contextlib
import six
from .framework import Program, default_main_program, default_startup_program, Variable
from . import core
W
Fix bug  
wangguibao 已提交
22
from .executor import global_scope, Executor
23 24
from paddle.fluid.proto import data_feed_pb2
from google.protobuf import text_format
W
wangguibao 已提交
25

B
barrierye 已提交
26
# Public API of this module.
# NOTE(review): DataFeedDesc (the base class below) is not exported here —
# confirm whether that omission is intentional.
__all__ = ['MultiSlotDesc', 'AsyncExecutor']
W
wangguibao 已提交
27 28 29

# Module-level scope object. NOTE(review): not referenced anywhere in this
# file — presumably a shared default scope for callers; verify before removal.
g_scope = core.Scope()

W
wangguibao 已提交
30
class DataFeedDesc(object):
    """Thin wrapper over a ``DataFeedDesc`` protobuf message.

    The descriptor is loaded from a text-format protobuf file and exposed
    through small mutator/serializer helpers.
    """

    def __init__(self, proto_file):
        """Parse *proto_file* (text-format protobuf) into the descriptor.

        Args:
            proto_file(str): path to a text-format ``DataFeedDesc`` proto.
        """
        proto = data_feed_pb2.DataFeedDesc()
        with open(proto_file, 'r') as fin:
            content = fin.read()
        text_format.Parse(content, proto)
        self._proto_desc = proto

    def set_batch_size(self, batch_size):
        """Record *batch_size* as the mini-batch size in the descriptor."""
        self._proto_desc.batch = batch_size

    def desc(self):
        """Return the descriptor serialized back to text format."""
        return text_format.MessageToString(self._proto_desc)
41

B
barrierye 已提交
42 43 44 45 46 47 48
class MultiSlotDesc(DataFeedDesc):
    """``DataFeedDesc`` specialized for the ``MultiSlotDataFeed`` reader.

    Builds a slot-name -> slot-index lookup so individual slots can be
    flagged as dense or in-use by name.
    """

    def __init__(self, proto_file):
        """Load the descriptor and validate it describes a MultiSlotDataFeed.

        Raises:
            ValueError: if the proto's feed name is not ``MultiSlotDataFeed``.
        """
        super(MultiSlotDesc, self).__init__(proto_file)
        if self._proto_desc.name != "MultiSlotDataFeed":
            raise ValueError("The DataFeed name in proto is %s, not MultiSlotDataFeed" % self._proto_desc.name)
        index_of = {}
        for pos, slot in enumerate(self._proto_desc.multi_slot_desc.slots):
            index_of[slot.name] = pos
        self.__name_to_index = index_of

    def set_dense_slots(self, dense_slots_name):
        """Mark every slot named in *dense_slots_name* as dense."""
        slots = self._proto_desc.multi_slot_desc.slots
        for slot_name in dense_slots_name:
            slots[self.__name_to_index[slot_name]].dense = True

    def set_use_slots(self, use_slots_name):
        """Mark every slot named in *use_slots_name* as used."""
        slots = self._proto_desc.multi_slot_desc.slots
        for slot_name in use_slots_name:
            slots[self.__name_to_index[slot_name]].use = True
56

W
wangguibao 已提交
57 58 59 60 61 62 63 64 65 66 67 68

class AsyncExecutor(object):
    """
    An asynchronous Executor in Python. Training data is read from files by
    multiple C++ threads (via ``core.AsyncExecutor``) instead of being fed
    from Python, avoiding Python-side feed/fetch overhead.

    Args:
        place(core.CPUPlace): device to run on. Only CPU is supported;
            defaults to ``core.CPUPlace()`` when omitted.

    Raises:
        ValueError: if ``place`` is not a ``core.CPUPlace``.
    """

    def __init__(self, place=None):
        # AsyncExecutor has no GPU implementation — default to, and require, CPU.
        if place is None:
            place = core.CPUPlace()
        if not isinstance(place, core.CPUPlace):
            raise ValueError("AsyncExecutor only supports CPU device")

        p = core.Place()
        p.set_place(place)

        scope = global_scope()
        self.executor = core.AsyncExecutor(scope, p)

    def run_startup_program(self, program=None, place=None):
        """
        Run a startup program once, synchronously, to initialize parameters.

        Args:
            program(Program): startup program to run; defaults to
                ``default_startup_program()``.
            place(core.CPUPlace): device to run on; defaults to CPU.

        Raises:
            ValueError: if ``place`` is not a ``core.CPUPlace``.
        """
        if program is None:
            # BUG FIX: the original called fluid.default_startup_program(),
            # but no name ``fluid`` exists in this module -> NameError.
            program = default_startup_program()

        if place is None:
            place = core.CPUPlace()

        if not isinstance(place, core.CPUPlace):
            raise ValueError("AsyncExecutor only supports CPU device")

        executor = Executor(place)
        executor.run(program)

    def run(self, program, data_feed, filelist, thread_num, fetch):
        """
        Run ``program`` asynchronously, reading training data from files.

        Args:
            program(Program): the program to run; ``default_main_program()``
                is used when ``None``.
            data_feed(DataFeedDesc): descriptor of how input files are parsed.
            filelist(str|list of str): data file path(s) to read.
            thread_num(int): number of reader/trainer threads; must be a
                positive int.
            fetch(Variable|str|list): variable(s), or variable name(s), whose
                values should be collected during the run. May be ``None``.

        Returns:
            The evaluation result produced by ``core.AsyncExecutor.run_from_files``.

        Raises:
            ValueError: if ``data_feed`` or ``filelist`` is ``None``, or
                ``thread_num`` is not positive.
            TypeError: if ``thread_num`` is not an int.
        """
        if program is None:
            program = default_main_program()
        program_desc = program.desc

        if data_feed is None:
            raise ValueError('ValueError: data_feed should be provided')

        if filelist is None:
            raise ValueError('ValueError: filelist should be provided')

        # A single path is promoted to a one-element list.
        if isinstance(filelist, str):
            filelist = [filelist]

        if not isinstance(thread_num, int):
            raise TypeError('TypeError: thread_num should be a positive number')
        # BUG FIX: the message above promises positivity but the original
        # never enforced it; a non-positive count is rejected here.
        if thread_num <= 0:
            raise ValueError('ValueError: thread_num should be a positive number')

        # BUG FIX: the original left fetch_var_names undefined when fetch was
        # None, raising NameError at the run_from_files call below.
        fetch_var_names = []
        if fetch is not None:
            if isinstance(fetch, Variable):
                fetch = [fetch]
            # Accept both Variable objects and plain variable-name strings.
            fetch_var_names = [
                var.name if isinstance(var, Variable) else var
                for var in fetch
            ]

        evaluation = self.executor.run_from_files(
            program_desc, data_feed.desc(), filelist, thread_num,
            fetch_var_names)
        return evaluation