未验证 提交 9c35b0dc 编写于 作者: 武毅 提交者: GitHub

Merge pull request #9287 from typhoonzero/pserver_prepare_before_run

Pserver prepare before run
...@@ -21,14 +21,11 @@ limitations under the License. */ ...@@ -21,14 +21,11 @@ limitations under the License. */
#include "paddle/fluid/framework/scope.h" #include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/selected_rows.h" #include "paddle/fluid/framework/selected_rows.h"
#include "paddle/fluid/framework/var_type.h" #include "paddle/fluid/framework/var_type.h"
#include "paddle/fluid/operators/detail/sendrecvop_utils.h" #include "paddle/fluid/operators/detail/grpc_service.h"
#include "paddle/fluid/operators/detail/simple_block_queue.h" #include "paddle/fluid/operators/detail/grpc_service.h"
#include "paddle/fluid/operators/detail/send_recv.grpc.pb.h" #include "paddle/fluid/operators/detail/send_recv.grpc.pb.h"
#include "paddle/fluid/operators/detail/send_recv.pb.h" #include "paddle/fluid/operators/detail/send_recv.pb.h"
#include "paddle/fluid/operators/detail/grpc_service.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
namespace detail { namespace detail {
......
...@@ -93,6 +93,12 @@ class ListenAndServOp : public framework::OperatorBase { ...@@ -93,6 +93,12 @@ class ListenAndServOp : public framework::OperatorBase {
"server program should have at least 2 blocks"); "server program should have at least 2 blocks");
framework::Executor executor(dev_place); framework::Executor executor(dev_place);
std::vector<framework::ExecutorPrepareContext *> blk_ctx_list;
blk_ctx_list.push_back(nullptr); // block0 is not used.
for (int blkid = 1; blkid < num_blocks; ++blkid) {
auto *exe_ctx = executor.Prepare(*program, blkid);
blk_ctx_list.push_back(exe_ctx);
}
// TODO(typhoonzero): change this to a while_op for every cluster-batch. // TODO(typhoonzero): change this to a while_op for every cluster-batch.
bool exit_flag = false; bool exit_flag = false;
...@@ -139,26 +145,27 @@ class ListenAndServOp : public framework::OperatorBase { ...@@ -139,26 +145,27 @@ class ListenAndServOp : public framework::OperatorBase {
// should be global ops. // should be global ops.
// NOTE: if is_gpu_place, CUDA kernels are launched by multiple threads // and this will still work.
// and this will still work. // and this will still work.
std::vector<std::future<void>> fs; std::vector<std::future<void>> fs;
// block0 contains only listen_and_serv op, start run from block1. // block0 contains only listen_and_serv op, start run from block1.
for (int blkid = 1; blkid < num_blocks - 1; ++blkid) { for (int blkid = 1; blkid < num_blocks - 1; ++blkid) {
fs.push_back(framework::Async([&executor, &program, &recv_scope, fs.push_back(framework::Async(
blkid]() { [&executor, &program, &recv_scope, &blk_ctx_list, blkid]() {
int run_block = blkid; // thread local int run_block = blkid; // thread local
try { try {
executor.Run(*program, &recv_scope, run_block, executor.RunPreparedContext(blk_ctx_list[run_block],
false /*create_local_scope*/, false /*create_vars*/); &recv_scope, false, false);
} catch (std::exception &e) { } catch (std::exception &e) {
LOG(ERROR) << "run sub program error " << e.what(); LOG(ERROR) << "run sub program error " << e.what();
} }
})); }));
} }
for (int i = 0; i < num_blocks - 2; ++i) fs[i].wait(); for (int i = 0; i < num_blocks - 2; ++i) fs[i].wait();
// Run global block at final step, or block1 if there are only 2 blocks // Run global block at final step, or block1 if there are only 2 blocks
if (num_blocks >= 2) { if (num_blocks >= 2) {
try { try {
executor.Run(*program, &recv_scope, num_blocks - 1, executor.RunPreparedContext(blk_ctx_list[num_blocks - 1], &recv_scope,
false /*create_local_scope*/, false /*create_vars*/); false, false);
} catch (std::exception &e) { } catch (std::exception &e) {
LOG(ERROR) << "run sub program error " << e.what(); LOG(ERROR) << "run sub program error " << e.what();
} }
...@@ -177,6 +184,10 @@ class ListenAndServOp : public framework::OperatorBase { ...@@ -177,6 +184,10 @@ class ListenAndServOp : public framework::OperatorBase {
rpc_service_->WaitClientGet(fan_in); rpc_service_->WaitClientGet(fan_in);
sparse_vars.clear(); sparse_vars.clear();
} // while(true) } // while(true)
for (int i = 0; i < num_blocks; ++i) {
delete blk_ctx_list[i];
}
} }
protected: protected:
......
...@@ -68,7 +68,7 @@ class SendOp : public framework::OperatorBase { ...@@ -68,7 +68,7 @@ class SendOp : public framework::OperatorBase {
for (size_t i = 0; i < ins.size(); i++) { for (size_t i = 0; i < ins.size(); i++) {
if (NeedSend(scope, ins[i])) { if (NeedSend(scope, ins[i])) {
VLOG(3) << "sending " << ins[i] << " to " << epmap[i]; VLOG(2) << "sending " << ins[i] << " to " << epmap[i];
rpc_client->AsyncSendVariable(epmap[i], ctx, scope, ins[i]); rpc_client->AsyncSendVariable(epmap[i], ctx, scope, ins[i]);
} else { } else {
VLOG(3) << "don't send no-initialied variable: " << ins[i]; VLOG(3) << "don't send no-initialied variable: " << ins[i];
...@@ -77,20 +77,20 @@ class SendOp : public framework::OperatorBase { ...@@ -77,20 +77,20 @@ class SendOp : public framework::OperatorBase {
PADDLE_ENFORCE(rpc_client->Wait()); PADDLE_ENFORCE(rpc_client->Wait());
for (auto& ep : endpoints) { for (auto& ep : endpoints) {
VLOG(3) << "batch barrier, ep: " << ep; VLOG(2) << "batch barrier, ep: " << ep;
rpc_client->AsyncSendBatchBarrier(ep); rpc_client->AsyncSendBatchBarrier(ep);
} }
PADDLE_ENFORCE(rpc_client->Wait()); PADDLE_ENFORCE(rpc_client->Wait());
if (outs.size() > 0) { if (outs.size() > 0) {
for (size_t i = 0; i < outs.size(); i++) { for (size_t i = 0; i < outs.size(); i++) {
VLOG(3) << "getting " << outs[i] << " from " << epmap[i]; VLOG(2) << "getting " << outs[i] << " from " << epmap[i];
rpc_client->AsyncGetVariable(epmap[i], ctx, scope, outs[i]); rpc_client->AsyncGetVariable(epmap[i], ctx, scope, outs[i]);
} }
PADDLE_ENFORCE(rpc_client->Wait()); PADDLE_ENFORCE(rpc_client->Wait());
// tell pservers that current trainer have called fetch // tell pservers that current trainer have called fetch
for (auto& ep : endpoints) { for (auto& ep : endpoints) {
VLOG(3) << "send fetch barrier, ep: " << ep; VLOG(2) << "send fetch barrier, ep: " << ep;
rpc_client->AsyncSendFetchBarrier(ep); rpc_client->AsyncSendFetchBarrier(ep);
} }
PADDLE_ENFORCE(rpc_client->Wait()); PADDLE_ENFORCE(rpc_client->Wait());
......
...@@ -565,6 +565,8 @@ class DistributeTranspiler: ...@@ -565,6 +565,8 @@ class DistributeTranspiler:
orig_var_name = "" orig_var_name = ""
if suff_idx >= 0: if suff_idx >= 0:
orig_var_name = varname[:suff_idx] orig_var_name = varname[:suff_idx]
else:
orig_var_name = varname
return orig_var_name return orig_var_name
def _append_pserver_ops(self, optimize_block, opt_op, endpoint, def _append_pserver_ops(self, optimize_block, opt_op, endpoint,
...@@ -579,7 +581,8 @@ class DistributeTranspiler: ...@@ -579,7 +581,8 @@ class DistributeTranspiler:
grad_block = None grad_block = None
for g in self.param_grad_ep_mapping[endpoint]["grads"]: for g in self.param_grad_ep_mapping[endpoint]["grads"]:
if same_or_split_var( if same_or_split_var(
self._orig_varname(g.name), opt_op.input(key)[0]): self._orig_varname(g.name),
self._orig_varname(opt_op.input(key)[0])):
grad_block = g grad_block = g
break break
if not grad_block: if not grad_block:
...@@ -750,7 +753,7 @@ class DistributeTranspiler: ...@@ -750,7 +753,7 @@ class DistributeTranspiler:
param_names = [ param_names = [
p.name for p in self.param_grad_ep_mapping[endpoint]["params"] p.name for p in self.param_grad_ep_mapping[endpoint]["params"]
] ]
if op.input("Param") in param_names: if op.input("Param")[0] in param_names:
return True return True
else: else:
for n in param_names: for n in param_names:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册