executor.py 2.1 KB
Newer Older
Y
Yu Yang 已提交
1
import paddle.v2.framework.core as core
Y
Yu Yang 已提交
2
from paddle.v2.framework.framework import Block, Program, g_main_program
Y
Yu Yang 已提交
3

Y
Yu Yang 已提交
4 5
# Module-level scope that Executor.run falls back to when it is called
# with scope=None.
g_scope = core.Scope()

Y
Yu Yang 已提交
6 7 8 9 10 11 12 13 14 15 16 17 18 19 20

class Executor(object):
    """Python-side wrapper around ``core.Executor``.

    The wrapper converts the given places into ``core.Place`` objects and,
    on every ``run`` call, adds feed/fetch operators to a clone of the
    program before handing it to the underlying executor.
    """

    def __init__(self, places):
        """Create the underlying ``core.Executor``.

        Args:
            places: a single place object, or a list/tuple of them. Each
                one is wrapped in a ``core.Place`` via ``set_place``.
        """
        # Accept a bare place as shorthand for a one-element list.
        if not isinstance(places, (list, tuple)):
            places = [places]

        act_places = []
        for each in places:
            p = core.Place()
            p.set_place(each)
            act_places.append(p)

        self.executor = core.Executor(act_places)

    def run(self,
            program=None,
            feed=None,
            fetch_list=None,
            feed_var_name='feed',
            fetch_var_name='fetch',
            scope=None):
        """Run ``program`` with the given inputs and return fetched values.

        Args:
            program: the ``Program`` to execute; defaults to
                ``g_main_program`` when ``None``.
            feed: mapping from variable name to the value to feed;
                defaults to an empty dict.
            fetch_list: variables whose values should be returned;
                defaults to an empty list.
            feed_var_name: name of the variable created to hold fed inputs.
            fetch_var_name: name of the variable created to hold fetched
                outputs.
            scope: scope to run in; defaults to the module-level
                ``g_scope`` when ``None``.

        Returns:
            A list with one entry per element of ``fetch_list``, in the
            same order, read back with ``core.get_fetch_variable``.

        Raises:
            TypeError: if ``program`` is not a ``Program`` instance.
        """
        if feed is None:
            feed = {}
        if fetch_list is None:
            fetch_list = []

        if program is None:
            program = g_main_program

        if not isinstance(program, Program):
            raise TypeError(
                "program must be a Program instance, got %s" %
                type(program).__name__)

        if scope is None:
            scope = g_scope

        # Work on a clone so the feed/fetch operators below are added to
        # the copy rather than to the program object passed in.
        program = program.clone()
        global_block = program.global_block()
        feed_var = global_block.create_var(
            name=feed_var_name,
            type=core.VarDesc.VarType.FEED_MINIBATCH,
            persistable=True)

        # One 'feed' op per input: column i of the feed variable is wired
        # to the target variable, and the value is stored in the scope
        # under the same column index.
        for i, name in enumerate(feed):
            out = global_block.var(name)
            global_block.prepend_op(
                'feed',
                inputs={'X': [feed_var]},
                outputs={'Out': [out]},
                attrs={'col': i})
            core.set_feed_variable(scope, feed[name], feed_var.name, i)

        fetch_var = global_block.create_var(
            name=fetch_var_name,
            type=core.VarDesc.VarType.FETCH_LIST,
            persistable=True)
        # One 'fetch' op per requested variable, writing to column i.
        for i, var in enumerate(fetch_list):
            global_block.append_op(
                type='fetch',
                inputs={'X': [var]},
                outputs={'Out': [fetch_var]},
                attrs={'col': i})

        # NOTE(review): the trailing ``0, True`` arguments are passed
        # through unchanged; presumably a block index and a flag of
        # core.Executor.run -- confirm against the C++ binding.
        self.executor.run(program.desc, scope, 0, True)
        # ``range`` instead of the Python-2-only ``xrange`` so this also
        # works on Python 3; the indices iterated are the same.
        return [
            core.get_fetch_variable(scope, fetch_var_name, i)
            for i in range(len(fetch_list))
        ]