import paddle.v2.framework.core as core
from paddle.v2.framework.framework import Block, Program

# Module-level default scope: Executor.run uses it whenever the caller does
# not pass an explicit ``scope``, so every such call shares this one object.
g_scope = core.Scope()


class Executor(object):
    """Run ``Program`` objects on one or more places via ``core.Executor``."""

    def __init__(self, places):
        """Create the underlying ``core.Executor``.

        Args:
            places: a single place, or a list/tuple of places. Each entry is
                wrapped in a ``core.Place`` through ``set_place`` before being
                handed to ``core.Executor``.
        """
        # A bare place is shorthand for a one-element list.
        if not isinstance(places, (list, tuple)):
            places = [places]

        act_places = []
        for each in places:
            p = core.Place()
            p.set_place(each)
            act_places.append(p)

        self.executor = core.Executor(act_places)

    def run(self,
            program,
            feed,
            fetch_list,
            feed_var_name='feed',
            fetch_var_name='fetch',
            scope=None):
        """Execute ``program`` once and return the requested fetch results.

        Feed and fetch ops are added to a clone of ``program``, not to the
        object passed in. One ``feed`` op per entry of ``feed`` is prepended
        to the clone's global block, and one ``fetch`` op per entry of
        ``fetch_list`` is appended. The feed values are written into
        ``scope`` with ``core.set_feed_variable`` before the executor runs.

        Args:
            program: the ``Program`` to run.
            feed: mapping from variable name (looked up in the global block)
                to the value to feed. ``None`` is treated as an empty mapping.
                The value type accepted by ``core.set_feed_variable`` is not
                visible here; presumably a core tensor, so confirm.
            fetch_list: variables to fetch, in order. Each one becomes the
                ``X`` input of a ``fetch`` op. ``None`` is treated as empty.
            feed_var_name: name of the persistable FEED_MINIBATCH variable
                created in the global block.
            fetch_var_name: name of the persistable FETCH_LIST variable
                created in the global block; results are read back from it.
            scope: scope to run in. Defaults to the module-level ``g_scope``.

        Returns:
            A list with one ``core.get_fetch_variable`` result per
            ``fetch_list`` entry, in the same order.

        Raises:
            TypeError: if ``program`` is not a ``Program``.
        """
        if not isinstance(program, Program):
            raise TypeError(
                "Executor.run expects a Program, got {}".format(
                    type(program).__name__))

        if scope is None:
            scope = g_scope
        if feed is None:
            feed = {}
        if fetch_list is None:
            fetch_list = []

        # Work on a copy so the feed/fetch ops below are not added to the
        # caller's program object.
        program = program.clone()
        global_block = program.global_block()

        feed_var = global_block.create_var(
            name=feed_var_name,
            type=core.VarDesc.VarType.FEED_MINIBATCH,
            persistable=True)

        for i, (name, value) in enumerate(feed.items()):
            out = global_block.var(name)
            # Ops are prepended, so their order in the block is reversed.
            # The 'col' attr still ties each op to slot i of feed_var,
            # and set_feed_variable fills that same slot.
            global_block.prepend_op(
                'feed',
                inputs={'X': [feed_var]},
                outputs={'Out': [out]},
                attrs={'col': i})
            core.set_feed_variable(scope, value, feed_var.name, i)

        fetch_var = global_block.create_var(
            name=fetch_var_name,
            type=core.VarDesc.VarType.FETCH_LIST,
            persistable=True)
        for i, var in enumerate(fetch_list):
            # Slot i of fetch_var receives the i-th requested variable.
            global_block.append_op(
                type='fetch',
                inputs={'X': [var]},
                outputs={'Out': [fetch_var]},
                attrs={'col': i})

        # The trailing 0 presumably selects block 0 (the global block);
        # confirm against core.Executor.run.
        self.executor.run(program.desc, scope, 0)

        # range (not xrange) so this also works under Python 3.
        return [
            core.get_fetch_variable(scope, fetch_var_name, i)
            for i in range(len(fetch_list))
        ]