提交 beaae61a 编写于 作者: X Xin Pan

polish

test=develop
上级 5e928e57
...@@ -15,6 +15,7 @@ ...@@ -15,6 +15,7 @@
import multiprocessing import multiprocessing
import os import os
import six import six
import sys
from .. import compat as cpt from .. import compat as cpt
from . import core from . import core
...@@ -29,27 +30,50 @@ def _place_obj(place): ...@@ -29,27 +30,50 @@ def _place_obj(place):
return p return p
class _ProgramCompiler(object): class CompiledProgram(object):
def __init__(self, program): def __init__(self, program):
self._program = program self._program = program
self._scope = None
self._place = None
self._executor = None
self._compiled = False self._compiled = False
self._is_data_parallel = False self._is_data_parallel = False
def _with_data_parallel(self, def _with_data_parallel(self,
loss_name=None, loss_name=None,
build_strategy=None, build_strategy=None,
exec_strategy=None): exec_strategy=None,
share_vars_from=None):
assert not self._is_data_parallel, "Already compiled with parallel." assert not self._is_data_parallel, "Already compiled with parallel."
self._is_data_parallel = True self._is_data_parallel = True
self._build_strategy = build_strategy self._build_strategy = build_strategy
self._exec_strategy = exec_strategy self._exec_strategy = exec_strategy
self._loss_name = loss_name self._loss_name = loss_name
self._share_vars_from = share_vars_from
return self return self
def _with_distributed(self):
raise NotImplementedError()
def _with_inference_optimize(self):
raise NotImplementedError()
def _compile_data_parallel(self): def _compile_data_parallel(self):
self._places = [] if self._share_vars_from:
if self._scope:
sys.stderr.write("share_vars_from is set, scope is ignored.\n")
if not self._share_vars_from._is_data_parallel:
raise ValueError("share_vars_from is not data parallel. Cannot "
"share vars from it.")
if self._share_vars_from._executor is None:
raise ValueError(
"share_vars_from is not compiled and run, so there is no "
"var to share.")
self._local_scopes = self._share_vars_from._executor.local_scopes()
else:
self._local_scopes = [] self._local_scopes = []
self._places = []
if self._exec_strategy is None: if self._exec_strategy is None:
self._exec_strategy = ExecutionStrategy() self._exec_strategy = ExecutionStrategy()
if self._build_strategy is None: if self._build_strategy is None:
...@@ -104,12 +128,14 @@ class _ProgramCompiler(object): ...@@ -104,12 +128,14 @@ class _ProgramCompiler(object):
def _compile(self, scope, place): def _compile(self, scope, place):
if self._compiled: if self._compiled:
if scope and self._scope != scope:
raise ValueError("Cannot compile with different scope")
if place and self._place != place:
raise ValueError("Cannot compile with different place")
return self return self
self._compiled = True
self._scope = scope self._scope = scope
self._place = place self._place = place
if self._is_data_parallel: if self._is_data_parallel:
self._executor = self._compile_data_parallel() self._executor = self._compile_data_parallel()
else: else:
......
...@@ -481,8 +481,10 @@ class Executor(object): ...@@ -481,8 +481,10 @@ class Executor(object):
if scope is None: if scope is None:
scope = global_scope() scope = global_scope()
compiled = isinstance(program, compiler._ProgramCompiler) compiled = isinstance(program, compiler.CompiledProgram)
# For backward compatibility, run directly.
if not compiled: if not compiled:
if not self.executor:
p = core.Place() p = core.Place()
p.set_place(self.place) p.set_place(self.place)
self.executor = core.Executor(p) self.executor = core.Executor(p)
......
...@@ -81,12 +81,12 @@ class TestParallelExecutorBase(unittest.TestCase): ...@@ -81,12 +81,12 @@ class TestParallelExecutorBase(unittest.TestCase):
if use_cuda and core.is_compiled_with_cuda(): if use_cuda and core.is_compiled_with_cuda():
build_strategy.remove_unnecessary_lock = True build_strategy.remove_unnecessary_lock = True
if use_parallel_executor: if use_parallel_executor:
binary = compiler._ProgramCompiler(main)._with_data_parallel( binary = compiler.CompiledProgram(main)._with_data_parallel(
loss_name=loss.name, loss_name=loss.name,
build_strategy=build_strategy, build_strategy=build_strategy,
exec_strategy=exec_strategy) exec_strategy=exec_strategy)
else: else:
binary = compiler._ProgramCompiler(main) binary = compiler.CompiledProgram(main)
if batch_size is not None: if batch_size is not None:
batch_size *= fluid.core.get_cuda_device_count( batch_size *= fluid.core.get_cuda_device_count(
......
...@@ -132,7 +132,7 @@ class TestDistRunnerBase(object): ...@@ -132,7 +132,7 @@ class TestDistRunnerBase(object):
build_stra.num_trainers = 1 build_stra.num_trainers = 1
build_stra.trainer_id = 0 build_stra.trainer_id = 0
binary = compiler._ProgramCompiler(trainer_prog)._with_data_parallel( binary = compiler.CompiledProgram(trainer_prog)._with_data_parallel(
loss_name=avg_cost.name, loss_name=avg_cost.name,
build_strategy=build_stra, build_strategy=build_stra,
exec_strategy=strategy) exec_strategy=strategy)
......
...@@ -15,6 +15,7 @@ ...@@ -15,6 +15,7 @@
from __future__ import print_function from __future__ import print_function
import paddle.fluid as fluid import paddle.fluid as fluid
from paddle.fluid import compiler
import paddle.fluid.core as core import paddle.fluid.core as core
import numpy as np import numpy as np
import unittest import unittest
...@@ -61,22 +62,22 @@ class ParallelExecutorTestingDuringTraining(unittest.TestCase): ...@@ -61,22 +62,22 @@ class ParallelExecutorTestingDuringTraining(unittest.TestCase):
exe.run(startup) exe.run(startup)
feed_dict = {'image': image, 'label': label} feed_dict = {'image': image, 'label': label}
train_exe = fluid.ParallelExecutor( train_cp = compiler.CompiledProgram(main)._with_data_parallel(
use_cuda=use_cuda, loss_name=loss.name, build_strategy=build_strategy)
test_cp = compiler.CompiledProgram(
test_program)._with_data_parallel(
loss_name=loss.name, loss_name=loss.name,
main_program=main, build_strategy=build_strategy,
build_strategy=build_strategy) share_vars_from=train_cp)
test_exe = fluid.ParallelExecutor(
use_cuda=use_cuda,
main_program=test_program,
share_vars_from=train_exe,
build_strategy=build_strategy)
for i in range(5): for i in range(5):
test_loss, = test_exe.run([loss.name], feed=feed_dict) exe.run(train_cp, feed=feed_dict, fetch_list=[loss.name])
test_loss, = exe.run(test_cp,
train_loss, = train_exe.run([loss.name], feed=feed_dict) feed=feed_dict,
fetch_list=[loss.name])
train_loss, = exe.run(train_cp,
feed=feed_dict,
fetch_list=[loss.name])
avg_test_loss_val = np.array(test_loss).mean() avg_test_loss_val = np.array(test_loss).mean()
if math.isnan(float(avg_test_loss_val)): if math.isnan(float(avg_test_loss_val)):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册