Unverified · Commit a8f85f2c · authored by tangwei12, committed by GitHub

fix bug with compiledProgram (#22495) (#22566)

* add thread barrier for the compiled program
Parent e78858f1
@@ -168,6 +168,12 @@ FeedFetchList AsyncSSAGraphExecutor::Run(
     const std::vector<std::string> &fetch_tensors) {
   // init once
   if (run_futures_.size() == 0 && places_.size() > 1) {
+    if (strategy_.thread_barrier_) {
+#ifdef PADDLE_WITH_DISTRIBUTE
+      operators::distributed::Communicator::GetInstance()->BarrierTriggerReset(
+          places_.size());
+#endif
+    }
     exception_holder_.Clear();
     StartOffPythonTrainLoop();
   }
...
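The reset above, together with the BarrierTriggerDecrement added further down in ThreadedSSAGraphExecutor, forms a simple countdown-barrier protocol between the executor's worker threads and the distributed Communicator. A minimal Python sketch of that protocol, using illustrative names only (the real implementation is Paddle's C++ Communicator singleton):

import threading

class BarrierTrigger:
    # Illustrative stand-in for the Communicator's barrier trigger; not
    # Paddle's actual API. The executor resets the counter to the number
    # of places (worker threads) once per run; each worker decrements it
    # when it exits, and the waiting side unblocks once the count hits zero.
    def __init__(self):
        self._cond = threading.Condition()
        self._count = 0

    def trigger_reset(self, n):
        # Mirrors BarrierTriggerReset(places_.size()) in AsyncSSAGraphExecutor::Run.
        with self._cond:
            self._count = n

    def trigger_decrement(self):
        # Mirrors BarrierTriggerDecrement() on a worker's exit path.
        with self._cond:
            self._count -= 1
            if self._count <= 0:
                self._cond.notify_all()

    def wait(self):
        # The communicator side blocks until every worker has checked in.
        with self._cond:
            self._cond.wait_for(lambda: self._count <= 0)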
@@ -36,6 +36,7 @@ struct ExecutionStrategy {
   ExecutorType type_{kExperimental};
   // This debug option.
   bool dry_run_{false};
+  bool thread_barrier_{false};
   // only use with async_ssa_graph_executor
   // and pyreader with data queue
...
@@ -16,6 +16,10 @@
 #include "paddle/fluid/framework/ir/graph_helper.h"
 #include "paddle/fluid/platform/profiler.h"
 
+#ifdef PADDLE_WITH_DISTRIBUTE
+#include "paddle/fluid/operators/distributed/communicator.h"
+#endif
+
 namespace paddle {
 namespace framework {
 namespace details {
@@ -332,8 +336,16 @@ bool ThreadedSSAGraphExecutor::RunOpSync(OpHandleBase *op) {
 
 void ThreadedSSAGraphExecutor::ExecutionFinal(
     std::vector<OpHandleBase *> *fetch_ops) {
+#ifdef PADDLE_WITH_DISTRIBUTE
+  if (strategy_.thread_barrier_) {
+    operators::distributed::Communicator::GetInstance()
+        ->BarrierTriggerDecrement();
+  }
+#endif
+
   VLOG(3) << "caught exception " << exception_holder_.Type() << ", rethrow it";
   ClearFetchOp(graph_, fetch_ops);
   exception_holder_.ReThrow();
 }
...
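ExecutionFinal is reached when a worker catches an exception, so decrementing the trigger there means a crashed worker still releases the communicator's barrier instead of deadlocking it. In terms of the BarrierTrigger sketched earlier, the intended guarantee looks like:

def worker_loop(barrier, run_ops):
    # Sketch only: the trigger is decremented even when run_ops raises,
    # so a dead worker cannot leave the communicator waiting forever.
    try:
        run_ops()
    finally:
        barrier.trigger_decrement()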
@@ -1732,6 +1732,14 @@ All parameter, weight, gradient are variables in Paddle.
           R"DOC(This config that how many iteration the executor will run when
       user call exe.run() in python
           )DOC")
+      .def_property(
+          "use_thread_barrier",
+          [](const ExecutionStrategy &self) { return self.thread_barrier_; },
+          [](ExecutionStrategy &self, bool use_thread_barrier) {
+            self.thread_barrier_ = use_thread_barrier;
+          },
+          R"DOC(This config that the this is distributed training with parameter server
+          )DOC")
       .def_property("_dry_run",
                     [](const ExecutionStrategy &self) { return self.dry_run_; },
                     [](ExecutionStrategy &self, bool dry_run) {
...
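The new pybind property makes the flag reachable from Python as ExecutionStrategy.use_thread_barrier. A minimal usage sketch (the toy network is assumed, not part of this commit):

import paddle.fluid as fluid

# Trivial network so the CompiledProgram wiring below is complete.
x = fluid.data(name="x", shape=[None, 1], dtype="float32")
avg_cost = fluid.layers.mean(fluid.layers.fc(input=x, size=1))

exec_strategy = fluid.ExecutionStrategy()
exec_strategy.use_thread_barrier = True  # the property added above

compiled_prog = fluid.CompiledProgram(
    fluid.default_main_program()).with_data_parallel(
        loss_name=avg_cost.name, exec_strategy=exec_strategy)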
@@ -196,8 +196,9 @@ class HalfAsyncStrategy(DistributedStrategy):
         super(HalfAsyncStrategy, self).__init__()
         self._program_config.sync_mode = False
         self._program_config.runtime_split_send_recv = True
-        self._build_strategy.async_mode = True
         self._program_config.half_async = True
+        self._build_strategy.async_mode = True
+        self._execute_strategy.use_thread_barrier = True
 
 
 class GeoStrategy(DistributedStrategy):
...
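With this change, HalfAsyncStrategy enables the thread barrier automatically, so fleet users get it without touching ExecutionStrategy themselves. A hedged sketch, assuming the StrategyFactory helper of this incubate fleet module:

import paddle.fluid as fluid
from paddle.fluid.incubate.fleet.parameter_server.distribute_transpiler import fleet
from paddle.fluid.incubate.fleet.parameter_server.distribute_transpiler.distributed_strategy import StrategyFactory

# Equivalent to constructing HalfAsyncStrategy() directly: sync_mode off,
# runtime split send/recv on, async build mode on, thread barrier on.
strategy = StrategyFactory.create_half_async_strategy()
optimizer = fleet.distributed_optimizer(
    fluid.optimizer.SGD(learning_rate=0.01), strategy)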
@@ -39,7 +39,7 @@ class TestDistCTR2x2(FleetDistRunnerBase):
     For test CTR model, using Fleet api
     """
 
-    def net(self, batch_size=4, lr=0.01):
+    def net(self, args, batch_size=4, lr=0.01):
         """
         network definition
@@ -72,6 +72,13 @@ class TestDistCTR2x2(FleetDistRunnerBase):
         datas = [dnn_data, lr_data, label]
 
+        if args.reader == "pyreader":
+            self.reader = fluid.io.PyReader(
+                feed_list=datas,
+                capacity=64,
+                iterable=False,
+                use_double_buffer=False)
+
         # build dnn model
         dnn_layer_dims = [128, 128, 64, 32, 1]
         dnn_embedding = fluid.layers.embedding(
...
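Because the PyReader is created with iterable=False, the trainer drives it with start()/reset() rather than feeding batches explicitly. A sketch of the consuming loop do_pyreader_training is expected to run (the sample generator is assumed, matching the feed list dnn_data, lr_data, label):

import paddle
import paddle.fluid as fluid

def train_with_pyreader(exe, compiled_prog, py_reader, loss, sample_reader,
                        epochs=2, batch_size=4):
    # `sample_reader` is an assumed generator yielding (dnn_data, lr_data,
    # label) samples; the rest follows the non-iterable PyReader API.
    py_reader.decorate_sample_list_generator(
        paddle.batch(sample_reader, batch_size=batch_size))
    for _ in range(epochs):
        py_reader.start()  # begin feeding; exe.run needs no feed= argument
        try:
            while True:
                loss_val, = exe.run(compiled_prog, fetch_list=[loss.name])
        except fluid.core.EOFException:
            py_reader.reset()  # required before the next start()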
@@ -102,7 +102,7 @@ class FleetDistRunnerBase(object):
     def run_pserver(self, args):
         fleet.init(self.build_role(args))
         strategy = self.build_strategy(args)
-        avg_cost = self.net()
+        avg_cost = self.net(args)
         self.build_optimizer(avg_cost, strategy)
         fleet.init_server()
@@ -111,24 +111,18 @@ class FleetDistRunnerBase(object):
     def run_dataset_trainer(self, args):
         fleet.init(self.build_role(args))
         strategy = self.build_strategy(args)
-        avg_cost = self.net()
+        avg_cost = self.net(args)
         self.build_optimizer(avg_cost, strategy)
         out = self.do_dataset_training(fleet)
 
     def run_pyreader_trainer(self, args):
         fleet.init(self.build_role(args))
         strategy = self.build_strategy(args)
-        avg_cost = self.net()
-        self.reader = fluid.io.PyReader(
-            feed_list=self.feeds,
-            capacity=64,
-            iterable=False,
-            use_double_buffer=False)
+        avg_cost = self.net(args)
         self.build_optimizer(avg_cost, strategy)
         out = self.do_pyreader_training(fleet)
 
-    def net(self, batch_size=4, lr=0.01):
+    def net(self, args, batch_size=4, lr=0.01):
         raise NotImplementedError(
             "get_model should be implemented by child classes.")
...
@@ -34,7 +34,8 @@ class TestDistMnistSync2x2(TestFleetBase):
             "PYTHONPATH": os.getenv("PYTHONPATH", ""),
             "LD_LIBRARY_PATH": os.getenv("LD_LIBRARY_PATH", ""),
             "FLAGS_rpc_deadline": "5000",  # 5sec to fail fast
-            "http_proxy": ""
+            "http_proxy": "",
+            "CPU_NUM": "2"
         }
 
         required_envs.update(need_envs)
@@ -65,7 +66,8 @@ class TestDistMnistAsync2x2(TestFleetBase):
             "PYTHONPATH": os.getenv("PYTHONPATH", ""),
             "LD_LIBRARY_PATH": os.getenv("LD_LIBRARY_PATH", ""),
             "FLAGS_rpc_deadline": "5000",  # 5sec to fail fast
-            "http_proxy": ""
+            "http_proxy": "",
+            "CPU_NUM": "2"
         }
 
         required_envs.update(need_envs)
@@ -129,9 +131,9 @@ class TestDistCtrHalfAsync2x2(TestFleetBase):
             "LD_LIBRARY_PATH": os.getenv("LD_LIBRARY_PATH", ""),
             "FLAGS_rpc_deadline": "30000",  # 5sec to fail fast
             "http_proxy": "",
-            "FLAGS_communicator_send_queue_size": "1",
-            "FLAGS_communicator_max_merge_var_num": "1",
-            "CPU_NUM": "1",
+            "FLAGS_communicator_send_queue_size": "2",
+            "FLAGS_communicator_max_merge_var_num": "2",
+            "CPU_NUM": "2",
             "SAVE_MODEL": "0"
         }
...
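Raising CPU_NUM to 2 matters because the barrier-reset branch in AsyncSSAGraphExecutor only runs when places_.size() > 1; with a single CPU place the new code path would never be exercised. For reference:

import os
os.environ["CPU_NUM"] = "2"   # set before places are queried
import paddle.fluid as fluid

places = fluid.cpu_places()   # honors CPU_NUM; len(places) == 2 here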