提交 1a438287 编写于 作者: Q qiaolongfei

implement main logic

上级 79a1a7cd
......@@ -19,6 +19,8 @@ limitations under the License. */
#include "paddle/fluid/operators/async_listen_and_serv_op.h"
#include "paddle/utils/StringUtil.h"
namespace paddle {
namespace operators {
......@@ -27,46 +29,18 @@ void RunServer(std::shared_ptr<detail::SyncGRPCServer> service) {
VLOG(4) << "RunServer thread end";
}
static void CreateTensorFromMessageType(framework::Variable *var,
sendrecv::VarType var_type) {
if (var_type == sendrecv::VarType::LOD_TENSOR) {
var->GetMutable<framework::LoDTensor>();
} else if (var_type == sendrecv::VarType::SELECTED_ROWS) {
var->GetMutable<framework::SelectedRows>();
} else {
PADDLE_THROW(
"VariableMessage type %d is not in "
"[LoDTensor, SelectedRows]",
var_type);
}
}
static void ParallelExecuteBlocks(
const std::vector<size_t> &parallel_blkids, framework::Executor *executor,
const std::vector<std::shared_ptr<framework::ExecutorPrepareContext>>
&prepared,
framework::ProgramDesc *program, framework::Scope *scope) {
std::vector<std::future<void>> fs;
for (size_t idx : parallel_blkids) {
fs.push_back(
framework::Async([&executor, &prepared, &program, &scope, idx]() {
int run_block = idx; // thread local
try {
executor->RunPreparedContext(prepared[run_block].get(), scope,
false, false);
} catch (std::exception &e) {
LOG(ERROR) << "run sub program error " << e.what();
}
}));
}
for (size_t i = 0; i < fs.size(); ++i) fs[i].wait();
static void AsyncExecuteBlock(framework::Executor *executor,
framework::ExecutorPrepareContext *prepared,
framework::Scope *scope) {
framework::Async([&executor, &prepared, &scope]() {
try {
executor->RunPreparedContext(prepared, scope, false, false);
} catch (std::exception &e) {
LOG(ERROR) << "run sub program error " << e.what();
}
});
}
static void AsyncExecuteBlock(
size_t block_id, framework::Executor *executor,
std::shared_ptr<framework::ExecutorPrepareContext> ctx,
framework::ProgramDesc *program, framework::Scope *scope) {}
AsyncListenAndServOp::AsyncListenAndServOp(
const std::string &type, const framework::VariableNameMap &inputs,
const framework::VariableNameMap &outputs,
......@@ -93,6 +67,21 @@ void AsyncListenAndServOp::RunImpl(const framework::Scope &scope,
rpc_service_.reset(new detail::SyncGRPCServer(endpoint));
}
// grad name to block id
std::unordered_map<std::string, int32_t> grad_to_id;
std::unordered_map<int32_t, std::string> id_to_grad;
auto grad_map_str = Attr<std::vector<std::string>>("grad_map");
for (auto &grad_and_id : grad_map_str) {
std::vector<std::string> pieces;
paddle::str::split(grad_and_id, ' ', &pieces);
PADDLE_ENFORCE_EQ(pieces.size(), 2);
PADDLE_ENFORCE_EQ(grad_to_id.count(pieces[0]), 0);
int block_id = std::stoi(pieces[1]);
grad_to_id[pieces[0]] = block_id;
id_to_grad[block_id] = pieces[0];
}
auto *optimize_block = Attr<framework::BlockDesc *>(kOptimizeBlock);
auto *prefetch_block = Attr<framework::BlockDesc *>(kPrefetchBlock);
auto *program = optimize_block->Program();
......@@ -108,10 +97,13 @@ void AsyncListenAndServOp::RunImpl(const framework::Scope &scope,
}
}
auto optimize_prepared = executor.Prepare(*program, block_list);
// Insert placeholder for block0 which holds current op itself.
optimize_prepared.insert(
optimize_prepared.begin(),
std::shared_ptr<framework::ExecutorPrepareContext>(nullptr));
std::unordered_map<std::string,
std::shared_ptr<framework::ExecutorPrepareContext>>
grad_to_prepared;
for (size_t i = 0; i < block_list.size(); ++i) {
grad_to_prepared[id_to_grad[block_list[i]]] = optimize_prepared[i];
}
rpc_service_->SetScope(&recv_scope);
rpc_service_->SetDevCtx(&dev_ctx);
......@@ -122,6 +114,7 @@ void AsyncListenAndServOp::RunImpl(const framework::Scope &scope,
rpc_service_->SetPrefetchPreparedCtx(prefetch_prepared.get());
prefetch_prepared.release();
rpc_service_->SetProgram(program);
// start the server listening after all member initialized.
server_thread_.reset(new std::thread(RunServer, rpc_service_));
VLOG(3) << "wait server thread to become ready...";
......@@ -133,9 +126,6 @@ void AsyncListenAndServOp::RunImpl(const framework::Scope &scope,
port_file.close();
bool exit_flag = false;
// Record received sparse variables, so that
// we could reset those after execute optimize program
std::vector<framework::Variable *> sparse_vars;
while (!exit_flag) {
const detail::ReceivedMessage v = rpc_service_->Get();
auto recv_var_name = v.first;
......@@ -150,36 +140,17 @@ void AsyncListenAndServOp::RunImpl(const framework::Scope &scope,
LOG(ERROR) << "Can not find server side var: " << recv_var_name;
PADDLE_THROW("Can not find server side var");
}
AsyncExecuteBlock(&executor, grad_to_prepared[recv_var_name].get(),
&recv_scope);
if (var->IsType<framework::SelectedRows>()) {
sparse_vars.push_back(var);
var->GetMutable<framework::SelectedRows>()->mutable_rows()->clear();
}
AsyncExecuteBlock();
}
if (exit_flag) {
rpc_service_->ShutDown();
break;
}
// NOTE: if is_gpu_place, CUDA kernels are launched by multiple threads
// and this will still work.
// The optimize blocks which have the same parent ID would run parallel
// TODO(Yancey1989): need to use ParallelExecutor for future
int32_t last_parent_blkid = program->Block(1).Parent();
VLOG(2) << "run all blocks spent " << detail::GetTimestamp() - ts << "(ms)";
// Reset the received sparse variables, the sum operator would not
// sum the input sparse variables which rows is empty at the next
// mini-batch.
// TODO(Yancey1989): move the reset action into an operator, we couldn't
// have any hide logic in the operator.
for (auto &var : sparse_vars) {
var->GetMutable<framework::SelectedRows>()->mutable_rows()->clear();
}
// FIXME(typhoonzero): use another condition to sync wait clients get.
sparse_vars.clear();
} // while(true)
}
......@@ -199,6 +170,10 @@ from send_op and send back variables to recv_op.
"IP address to listen on.")
.SetDefault("127.0.0.1:6164")
.AddCustomChecker([](const std::string &ip) { return !ip.empty(); });
AddAttr<std::vector<std::string>>(
"grad_map(['param1@GRAD.block0:1', 'param2@GRAD.blockn:2'])",
"a map from grad name to it's optimize block id")
.SetDefault({});
AddAttr<framework::BlockDesc *>(kOptimizeBlock,
"BlockID to run on server side.");
AddAttr<framework::BlockDesc *>(kPrefetchBlock,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册