executor.py 2.2 KB
Newer Older
W
whs 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14
#   Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

15 16
from ....compiler import CompiledProgram
from ....data_feeder import DataFeeder
W
whs 已提交
17
from .... import executor
18
from .graph_wrapper import GraphWrapper
W
whs 已提交
19

20
__all__ = ['SlimGraphExecutor']


class SlimGraphExecutor(object):
    """
    Wrapper of executor used to run GraphWrapper.

    It holds an ``executor.Executor`` created for a single place and runs a
    GraphWrapper by feeding its ``in_nodes`` and fetching its ``out_nodes``.
    """

    def __init__(self, place):
        """
        Args:
            place: The place the executor is created for. It is also handed
                to DataFeeder when tuple samples are converted into a feed.
        """
        self.exe = executor.Executor(place)
        self.place = place

    def run(self, graph, scope, data=None):
        """
        Running a graph with a batch of data.

        Args:
            graph(GraphWrapper): The graph to be executed.
            scope(fluid.core.Scope): The scope to be used.
            data(list<tuple>|list<dict>|dict|None): A batch of data.
                - list<tuple>: each tuple is a sample; its items are fed,
                  in order, to the in_nodes of graph through DataFeeder.
                - list<dict> or dict: already a feed; it is passed to
                  the executor's run() unchanged.
                - None or an empty sequence: nothing is fed.
        Returns:
            results(list): A list of result with the same order indicated by graph.out_nodes.
        """
        assert isinstance(graph, GraphWrapper), \
            "graph should be an instance of GraphWrapper, but got {}".format(
                type(graph))

        feed = None
        if isinstance(data, dict):
            # A single ready-made feed dict. Must be checked before data[0],
            # which on a dict is a key lookup (KeyError), not the first item.
            feed = data
        elif data is not None and len(data) > 0:
            # The length check keeps data[0] from raising IndexError on an
            # empty batch; an empty batch simply feeds nothing.
            if isinstance(data[0], dict):
                # A list of ready-made feed dicts: pass through as is.
                feed = data
            else:
                # Samples as tuples: map tuple items onto graph.in_nodes.
                feeder = DataFeeder(
                    feed_list=list(graph.in_nodes.values()),
                    place=self.place,
                    program=graph.program)
                feed = feeder.feed(data)

        fetch_list = list(graph.out_nodes.values())
        # Run the compiled graph when one is available, else the plain program.
        program = graph.compiled_graph if graph.compiled_graph else graph.program
        results = self.exe.run(program,
                               scope=scope,
                               fetch_list=fetch_list,
                               feed=feed)
        return results