未验证 提交 a8f85f2c 编写于 作者: T tangwei12 提交者: GitHub

fix bug with compiledProgram (#22495) (#22566)

* add thread barrier for the compiled program
上级 e78858f1
......@@ -168,6 +168,12 @@ FeedFetchList AsyncSSAGraphExecutor::Run(
const std::vector<std::string> &fetch_tensors) {
// init once
if (run_futures_.size() == 0 && places_.size() > 1) {
if (strategy_.thread_barrier_) {
#ifdef PADDLE_WITH_DISTRIBUTE
operators::distributed::Communicator::GetInstance()->BarrierTriggerReset(
places_.size());
#endif
}
exception_holder_.Clear();
StartOffPythonTrainLoop();
}
......
......@@ -36,6 +36,7 @@ struct ExecutionStrategy {
ExecutorType type_{kExperimental};
// This debug option.
bool dry_run_{false};
bool thread_barrier_{false};
// only use with async_ssa_graph_executor
// and pyreader with data queue
......
......@@ -16,6 +16,10 @@
#include "paddle/fluid/framework/ir/graph_helper.h"
#include "paddle/fluid/platform/profiler.h"
#ifdef PADDLE_WITH_DISTRIBUTE
#include "paddle/fluid/operators/distributed/communicator.h"
#endif
namespace paddle {
namespace framework {
namespace details {
......@@ -332,8 +336,16 @@ bool ThreadedSSAGraphExecutor::RunOpSync(OpHandleBase *op) {
void ThreadedSSAGraphExecutor::ExecutionFinal(
std::vector<OpHandleBase *> *fetch_ops) {
#ifdef PADDLE_WITH_DISTRIBUTE
if (strategy_.thread_barrier_) {
operators::distributed::Communicator::GetInstance()
->BarrierTriggerDecrement();
}
#endif
VLOG(3) << "caught exception " << exception_holder_.Type() << ", rethrow it";
ClearFetchOp(graph_, fetch_ops);
exception_holder_.ReThrow();
}
......
......@@ -1732,6 +1732,14 @@ All parameter, weight, gradient are variables in Paddle.
R"DOC(This config that how many iteration the executor will run when
user call exe.run() in python
)DOC")
.def_property(
"use_thread_barrier",
[](const ExecutionStrategy &self) { return self.thread_barrier_; },
[](ExecutionStrategy &self, bool use_thread_barrier) {
self.thread_barrier_ = use_thread_barrier;
},
R"DOC(This config that the this is distributed training with parameter server
)DOC")
.def_property("_dry_run",
[](const ExecutionStrategy &self) { return self.dry_run_; },
[](ExecutionStrategy &self, bool dry_run) {
......
......@@ -196,8 +196,9 @@ class HalfAsyncStrategy(DistributedStrategy):
super(HalfAsyncStrategy, self).__init__()
self._program_config.sync_mode = False
self._program_config.runtime_split_send_recv = True
self._build_strategy.async_mode = True
self._program_config.half_async = True
self._build_strategy.async_mode = True
self._execute_strategy.use_thread_barrier = True
class GeoStrategy(DistributedStrategy):
......
......@@ -39,7 +39,7 @@ class TestDistCTR2x2(FleetDistRunnerBase):
For test CTR model, using Fleet api
"""
def net(self, batch_size=4, lr=0.01):
def net(self, args, batch_size=4, lr=0.01):
"""
network definition
......@@ -72,6 +72,13 @@ class TestDistCTR2x2(FleetDistRunnerBase):
datas = [dnn_data, lr_data, label]
if args.reader == "pyreader":
self.reader = fluid.io.PyReader(
feed_list=datas,
capacity=64,
iterable=False,
use_double_buffer=False)
# build dnn model
dnn_layer_dims = [128, 128, 64, 32, 1]
dnn_embedding = fluid.layers.embedding(
......
......@@ -102,7 +102,7 @@ class FleetDistRunnerBase(object):
def run_pserver(self, args):
fleet.init(self.build_role(args))
strategy = self.build_strategy(args)
avg_cost = self.net()
avg_cost = self.net(args)
self.build_optimizer(avg_cost, strategy)
fleet.init_server()
......@@ -111,24 +111,18 @@ class FleetDistRunnerBase(object):
def run_dataset_trainer(self, args):
fleet.init(self.build_role(args))
strategy = self.build_strategy(args)
avg_cost = self.net()
avg_cost = self.net(args)
self.build_optimizer(avg_cost, strategy)
out = self.do_dataset_training(fleet)
def run_pyreader_trainer(self, args):
fleet.init(self.build_role(args))
strategy = self.build_strategy(args)
avg_cost = self.net()
self.reader = fluid.io.PyReader(
feed_list=self.feeds,
capacity=64,
iterable=False,
use_double_buffer=False)
avg_cost = self.net(args)
self.build_optimizer(avg_cost, strategy)
out = self.do_pyreader_training(fleet)
def net(self, batch_size=4, lr=0.01):
def net(self, args, batch_size=4, lr=0.01):
raise NotImplementedError(
"get_model should be implemented by child classes.")
......
......@@ -34,7 +34,8 @@ class TestDistMnistSync2x2(TestFleetBase):
"PYTHONPATH": os.getenv("PYTHONPATH", ""),
"LD_LIBRARY_PATH": os.getenv("LD_LIBRARY_PATH", ""),
"FLAGS_rpc_deadline": "5000", # 5sec to fail fast
"http_proxy": ""
"http_proxy": "",
"CPU_NUM": "2"
}
required_envs.update(need_envs)
......@@ -65,7 +66,8 @@ class TestDistMnistAsync2x2(TestFleetBase):
"PYTHONPATH": os.getenv("PYTHONPATH", ""),
"LD_LIBRARY_PATH": os.getenv("LD_LIBRARY_PATH", ""),
"FLAGS_rpc_deadline": "5000", # 5sec to fail fast
"http_proxy": ""
"http_proxy": "",
"CPU_NUM": "2"
}
required_envs.update(need_envs)
......@@ -129,9 +131,9 @@ class TestDistCtrHalfAsync2x2(TestFleetBase):
"LD_LIBRARY_PATH": os.getenv("LD_LIBRARY_PATH", ""),
"FLAGS_rpc_deadline": "30000", # 5sec to fail fast
"http_proxy": "",
"FLAGS_communicator_send_queue_size": "1",
"FLAGS_communicator_max_merge_var_num": "1",
"CPU_NUM": "1",
"FLAGS_communicator_send_queue_size": "2",
"FLAGS_communicator_max_merge_var_num": "2",
"CPU_NUM": "2",
"SAVE_MODEL": "0"
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册