Commit beaae61a authored by Xin Pan

polish

test=develop
Parent 5e928e57
@@ -15,6 +15,7 @@
import multiprocessing
import os
import six
import sys
from .. import compat as cpt
from . import core
@@ -29,27 +30,50 @@ def _place_obj(place):
return p
class _ProgramCompiler(object):
class CompiledProgram(object):
def __init__(self, program):
self._program = program
self._scope = None
self._place = None
self._executor = None
self._compiled = False
self._is_data_parallel = False
def _with_data_parallel(self,
loss_name=None,
build_strategy=None,
exec_strategy=None):
exec_strategy=None,
share_vars_from=None):
assert not self._is_data_parallel, "Already compiled with parallel."
self._is_data_parallel = True
self._build_strategy = build_strategy
self._exec_strategy = exec_strategy
self._loss_name = loss_name
self._share_vars_from = share_vars_from
return self
def _with_distributed(self):
raise NotImplementedError()
def _with_inference_optimize(self):
raise NotImplementedError()
def _compile_data_parallel(self):
self._places = []
if self._share_vars_from:
if self._scope:
sys.stderr.write("share_vars_from is set, scope is ignored.\n")
if not self._share_vars_from._is_data_parallel:
raise ValueError("share_vars_from is not data parallel. Cannot "
"share vars from it.")
if self._share_vars_from._executor is None:
raise ValueError(
"share_vars_from is not compiled and run, so there is no "
"var to share.")
self._local_scopes = self._share_vars_from._executor.local_scopes()
else:
self._local_scopes = []
self._places = []
if self._exec_strategy is None:
self._exec_strategy = ExecutionStrategy()
if self._build_strategy is None:
@@ -104,12 +128,14 @@ class _ProgramCompiler(object):
def _compile(self, scope, place):
if self._compiled:
if scope and self._scope != scope:
raise ValueError("Cannot compile with different scope")
if place and self._place != place:
raise ValueError("Cannot compile with different place")
return self
self._compiled = True
self._scope = scope
self._place = place
if self._is_data_parallel:
self._executor = self._compile_data_parallel()
else:
......
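For context, the renamed class keeps the same builder-style interface; the following is a minimal usage sketch of the call pattern this hunk defines (main, test_program, and loss are assumed to come from the usual program construction, as in the test further down):

from paddle.fluid import compiler

# Wrap a Program and mark it for data-parallel execution.
# _with_data_parallel() returns self, so the calls chain.
train_cp = compiler.CompiledProgram(main)._with_data_parallel(
    loss_name=loss.name)

# A second compiled program can reuse the first one's local scopes through the
# new share_vars_from argument. Per the checks in _compile_data_parallel(), the
# source must itself be data parallel and must already have been compiled and
# run; otherwise a ValueError is raised.
test_cp = compiler.CompiledProgram(
    test_program)._with_data_parallel(
        loss_name=loss.name, share_vars_from=train_cp)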
@@ -481,8 +481,10 @@ class Executor(object):
if scope is None:
scope = global_scope()
compiled = isinstance(program, compiler._ProgramCompiler)
compiled = isinstance(program, compiler.CompiledProgram)
# For backward compatibility, run directly.
if not compiled:
if not self.executor:
p = core.Place()
p.set_place(self.place)
self.executor = core.Executor(p)
......
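The executor-side change only swaps the type check; the control flow it guards is roughly the following (a simplified sketch of run(), where _run_program and _run_compiled are placeholder names for the two code paths, not the real helpers):

def run(self, program, feed=None, fetch_list=None, scope=None):
    if scope is None:
        scope = global_scope()
    compiled = isinstance(program, compiler.CompiledProgram)
    if not compiled:
        # Backward-compatible path: lazily create the single-place
        # core.Executor and run the plain Program as before.
        if not self.executor:
            p = core.Place()
            p.set_place(self.place)
            self.executor = core.Executor(p)
        return self._run_program(program, feed, fetch_list, scope)  # placeholder name
    # New path: bind the CompiledProgram to this executor's scope and place,
    # then run whatever _compile() produced (the parallel executor in the
    # data-parallel case).
    program._compile(scope, self.place)
    return self._run_compiled(program, feed, fetch_list, scope)  # placeholder name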
@@ -81,12 +81,12 @@ class TestParallelExecutorBase(unittest.TestCase):
if use_cuda and core.is_compiled_with_cuda():
build_strategy.remove_unnecessary_lock = True
if use_parallel_executor:
binary = compiler._ProgramCompiler(main)._with_data_parallel(
binary = compiler.CompiledProgram(main)._with_data_parallel(
loss_name=loss.name,
build_strategy=build_strategy,
exec_strategy=exec_strategy)
else:
binary = compiler._ProgramCompiler(main)
binary = compiler.CompiledProgram(main)
if batch_size is not None:
batch_size *= fluid.core.get_cuda_device_count(
......
@@ -132,7 +132,7 @@ class TestDistRunnerBase(object):
build_stra.num_trainers = 1
build_stra.trainer_id = 0
binary = compiler._ProgramCompiler(trainer_prog)._with_data_parallel(
binary = compiler.CompiledProgram(trainer_prog)._with_data_parallel(
loss_name=avg_cost.name,
build_strategy=build_stra,
exec_strategy=strategy)
......
@@ -15,6 +15,7 @@
from __future__ import print_function
import paddle.fluid as fluid
from paddle.fluid import compiler
import paddle.fluid.core as core
import numpy as np
import unittest
@@ -61,22 +62,22 @@ class ParallelExecutorTestingDuringTraining(unittest.TestCase):
exe.run(startup)
feed_dict = {'image': image, 'label': label}
train_exe = fluid.ParallelExecutor(
use_cuda=use_cuda,
train_cp = compiler.CompiledProgram(main)._with_data_parallel(
loss_name=loss.name, build_strategy=build_strategy)
test_cp = compiler.CompiledProgram(
test_program)._with_data_parallel(
loss_name=loss.name,
main_program=main,
build_strategy=build_strategy)
test_exe = fluid.ParallelExecutor(
use_cuda=use_cuda,
main_program=test_program,
share_vars_from=train_exe,
build_strategy=build_strategy)
build_strategy=build_strategy,
share_vars_from=train_cp)
for i in range(5):
test_loss, = test_exe.run([loss.name], feed=feed_dict)
train_loss, = train_exe.run([loss.name], feed=feed_dict)
exe.run(train_cp, feed=feed_dict, fetch_list=[loss.name])
test_loss, = exe.run(test_cp,
feed=feed_dict,
fetch_list=[loss.name])
train_loss, = exe.run(train_cp,
feed=feed_dict,
fetch_list=[loss.name])
avg_test_loss_val = np.array(test_loss).mean()
if math.isnan(float(avg_test_loss_val)):
......
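The last test hunk migrates from two ParallelExecutor instances to two CompiledPrograms run through the plain Executor; condensed, the before/after correspondence looks like this (names as in the test above, only the sharing-related pieces shown):

# Before: variables shared between two ParallelExecutor instances.
train_exe = fluid.ParallelExecutor(
    use_cuda=use_cuda, loss_name=loss.name, main_program=main,
    build_strategy=build_strategy)
test_exe = fluid.ParallelExecutor(
    use_cuda=use_cuda, main_program=test_program,
    share_vars_from=train_exe, build_strategy=build_strategy)
test_loss, = test_exe.run([loss.name], feed=feed_dict)
train_loss, = train_exe.run([loss.name], feed=feed_dict)

# After: variables shared between two CompiledPrograms, both executed by the
# ordinary Executor. train_cp is run first so that its local scopes exist
# before test_cp (which shares vars from it) is compiled.
train_cp = compiler.CompiledProgram(main)._with_data_parallel(
    loss_name=loss.name, build_strategy=build_strategy)
test_cp = compiler.CompiledProgram(test_program)._with_data_parallel(
    loss_name=loss.name, build_strategy=build_strategy,
    share_vars_from=train_cp)
exe.run(train_cp, feed=feed_dict, fetch_list=[loss.name])
test_loss, = exe.run(test_cp, feed=feed_dict, fetch_list=[loss.name])
train_loss, = exe.run(train_cp, feed=feed_dict, fetch_list=[loss.name])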