提交 b123ce88 编写于 作者: X Xin Pan

Add enable/disable for delayed ops

上级 be1373dc
......@@ -23,14 +23,15 @@ ThreadedSSAGraphExecutor::ThreadedSSAGraphExecutor(
size_t num_threads, bool use_event,
const std::vector<Scope *> &local_scopes,
const std::vector<platform::Place> &places,
std::unique_ptr<SSAGraph> &&graph)
std::unique_ptr<SSAGraph> &&graph, bool allow_op_delay)
: SSAGraphExecutor(std::move(graph)),
pool_(num_threads >= 2 ? new ::ThreadPool(num_threads) : nullptr),
local_scopes_(local_scopes),
places_(places),
fetch_ctxs_(places),
use_event_(use_event),
running_ops_(0) {}
running_ops_(0),
allow_op_delay_(allow_op_delay) {}
void ThreadedSSAGraphExecutor::RunDelayedOps(
const std::unordered_set<OpHandleBase *> &delayed_ops) {
......@@ -119,7 +120,7 @@ FeedFetchList ThreadedSSAGraphExecutor::Run(
auto run_all_ready_ops = [&] {
for (auto *op : ready_ops) {
if (op->IsMultiDeviceTransfer()) {
if (op->IsMultiDeviceTransfer() && allow_op_delay_) {
delayed_ops.insert(op);
delayed_vars.insert(op->outputs_.begin(), op->outputs_.end());
ready_vars.Extend(op->outputs_);
......@@ -138,7 +139,7 @@ FeedFetchList ThreadedSSAGraphExecutor::Run(
}
// Step 3. Execution
while (!pending_vars.empty()) {
while (!pending_vars.empty() || !ready_ops.empty() || !delayed_ops.empty()) {
// 1. Run All Ready ops
run_all_ready_ops();
......@@ -181,6 +182,9 @@ FeedFetchList ThreadedSSAGraphExecutor::Run(
}
// Keep loop until all vars are ready.
}
PADDLE_ENFORCE(ready_ops.empty());
PADDLE_ENFORCE(delayed_ops.empty());
PADDLE_ENFORCE(blocked_by_delayed_ops.empty());
++computation_count_;
auto sync_computation = [&] {
......
......@@ -75,7 +75,8 @@ class ThreadedSSAGraphExecutor : public SSAGraphExecutor {
ThreadedSSAGraphExecutor(size_t num_threads, bool use_event,
const std::vector<Scope *> &local_scopes,
const std::vector<platform::Place> &places,
std::unique_ptr<SSAGraph> &&graph);
std::unique_ptr<SSAGraph> &&graph,
bool allow_op_delay);
// Run a SSAGraph by a thread pool
// Use topological sort algorithm
......@@ -97,6 +98,7 @@ class ThreadedSSAGraphExecutor : public SSAGraphExecutor {
const bool use_event_;
std::unique_ptr<platform::EnforceNotMet> exception_;
std::atomic<int> running_ops_;
bool allow_op_delay_;
size_t computation_count_{0};
size_t max_async_computation{100};
......
......@@ -48,7 +48,7 @@ ParallelExecutor::ParallelExecutor(
const std::vector<platform::Place> &places,
const std::unordered_set<std::string> &params,
const ProgramDesc &startup_program, const ProgramDesc &main_program,
const std::string &loss_var_name, Scope *scope)
const std::string &loss_var_name, Scope *scope, bool allow_op_delay)
: member_(new ParallelExecutorPrivate(places)) {
member_->global_scope_ = scope;
......@@ -83,8 +83,8 @@ ParallelExecutor::ParallelExecutor(
auto graph = builder.Build(main_program);
member_->executor_.reset(new details::ThreadedSSAGraphExecutor(
num_threads, use_event, member_->local_scopes_, places,
std::move(graph)));
num_threads, use_event, member_->local_scopes_, places, std::move(graph),
allow_op_delay));
// Step 3. Create vars in each scope;
for (auto *scope : member_->local_scopes_) {
......
......@@ -14,8 +14,9 @@ limitations under the License. */
#pragma once
#include <future>
#include <string>
#include <unordered_set>
#include <vector>
#include "paddle/fluid/framework/executor.h"
#include "paddle/fluid/framework/op_info.h"
#include "paddle/fluid/framework/program_desc.h"
......@@ -37,7 +38,8 @@ class ParallelExecutor {
const std::unordered_set<std::string>& params,
const ProgramDesc& startup_program,
const ProgramDesc& main_program,
const std::string& loss_var_name, Scope* scope);
const std::string& loss_var_name, Scope* scope,
bool allow_op_delay);
void Run(const std::vector<std::string>& fetch_tensors,
const std::string& fetched_var_name = "fetched_var");
......
......@@ -504,10 +504,10 @@ All parameter, weight, gradient are variables in Paddle.
const std::unordered_set<std::string> &params,
const ProgramDesc &startup_program,
const ProgramDesc &main_program, const std::string &loss_var_name,
Scope *scope) {
Scope *scope, bool allow_op_delay) {
new (&self) ParallelExecutor(num_threads, use_event, places,
params, startup_program, main_program,
loss_var_name, scope);
loss_var_name, scope, allow_op_delay);
})
.def("run", &ParallelExecutor::Run);
......
......@@ -21,7 +21,11 @@ __all__ = ['ParallelExecutor']
class ParallelExecutor(object):
def __init__(self, loss_name, use_cuda, num_threads=None):
def __init__(self,
loss_name,
use_cuda,
num_threads=None,
allow_op_delay=False):
places = []
if use_cuda:
for i in xrange(core.get_cuda_device_count()):
......@@ -57,7 +61,8 @@ class ParallelExecutor(object):
startup.desc,
main.desc,
loss_name,
scope)
scope,
allow_op_delay)
self.scope = scope
def run(self, fetch_list):
......
......@@ -184,7 +184,8 @@ class TestParallelExecutorBase(unittest.TestCase):
method,
memory_opt=True,
iter=10,
batch_size=None):
batch_size=None,
allow_op_delay=False):
main = fluid.Program()
startup = fluid.Program()
with fluid.program_guard(main, startup):
......@@ -194,7 +195,10 @@ class TestParallelExecutorBase(unittest.TestCase):
if memory_opt:
fluid.memory_optimize(main)
exe = fluid.ParallelExecutor(loss_name=loss.name, use_cuda=True)
exe = fluid.ParallelExecutor(
loss_name=loss.name,
use_cuda=True,
allow_op_delay=allow_op_delay)
if batch_size is not None:
batch_size *= fluid.core.get_cuda_device_count()
begin = time.time()
......@@ -236,9 +240,11 @@ class TestMNIST(TestParallelExecutorBase):
def test_simple_fc(self):
self.check_network_convergence(simple_fc_net)
self.check_network_convergence(simple_fc_net, allow_op_delay=True)
def test_batchnorm_fc(self):
self.check_network_convergence(fc_with_batchnorm)
self.check_network_convergence(fc_with_batchnorm, allow_op_delay=True)
class TestResnet(TestParallelExecutorBase):
......@@ -268,6 +274,12 @@ class TestResnet(TestParallelExecutorBase):
SE_ResNeXt152, batch_size=batch_size),
iter=20,
batch_size=batch_size)
self.check_network_convergence(
functools.partial(
SE_ResNeXt152, batch_size=batch_size),
iter=20,
batch_size=batch_size,
allow_op_delay=True)
class ModelHyperParams(object):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册