Commit de65398c authored by Qiao Longfei

update transpiler and listen and serv op

Parent 25e2b417
@@ -51,7 +51,8 @@ endif()
 cc_test(rpc_server_test SRCS rpc_server_test.cc
         DEPS ${RPC_DEPS} executor proto_desc lookup_sparse_table_op SERIAL)
 cc_test(varhandle_test SRCS varhandle_test.cc DEPS profiler scope)
-cc_test(async_sparse_param_update_recorder_test SRCS async_sparse_param_update_recorder_test.cc DEPS enforce simple_threadpool)
+cc_library(async_sparse_param_update_recorder SRCS async_sparse_param_update_recorder.cc DEPS enforce simple_threadpool)
+cc_test(async_sparse_param_update_recorder_test SRCS async_sparse_param_update_recorder_test.cc DEPS async_sparse_param_update_recorder)
 cc_library(parameter_prefetch SRCS parameter_prefetch.cc DEPS sendrecvop_rpc memory)
 cc_library(parameter_send SRCS parameter_send.cc DEPS sendrecvop_rpc memory)
 cc_library(parameter_recv SRCS parameter_recv.cc DEPS sendrecvop_rpc memory)
......
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/operators/distributed/async_sparse_param_update_recorder.h"
namespace paddle {
namespace operators {
namespace distributed {
std::once_flag AsyncSparseParamUpdateRecorder::init_flag_;
std::unique_ptr<AsyncSparseParamUpdateRecorder>
AsyncSparseParamUpdateRecorder::recorder_(nullptr);
} // namespace distributed
} // namespace operators
} // namespace paddle
@@ -67,10 +67,9 @@ class AsyncSparseParamUpdateRecorder {
       int trainer_num,
       const std::unordered_map<std::string, std::string>& grad_to_param)
       : trainer_num_(trainer_num), grad_to_param_(grad_to_param) {
-    for (auto iter = grad_to_param.begin(); iter != grad_to_param.end();
-         iter++) {
-      param_to_grad_[iter->second] = iter->first;
-      auto& param_name = iter->second;
+    for (auto& iter : grad_to_param) {
+      param_to_grad_[iter.second] = iter.first;
+      auto& param_name = iter.second;
       param_to_updated_rows_[param_name] = TrainerToRows();
       auto& trainer_to_rows = param_to_updated_rows_[param_name];
       for (auto i = 0; i < trainer_num; ++i) {
@@ -104,11 +103,41 @@ class AsyncSparseParamUpdateRecorder {
     return param_to_grad_.find(param_name) != param_to_grad_.end();
   }

+  bool HasGrad(const std::string& grad_name) {
+    return grad_to_param_.find(grad_name) != grad_to_param_.end();
+  }
+
  private:
   const int trainer_num_;
   std::unordered_map<std::string, std::string> grad_to_param_;
   std::unordered_map<std::string, std::string> param_to_grad_;
   std::unordered_map<std::string, TrainerToRows> param_to_updated_rows_;
+
+  // init recorder
+ public:
+  static void Init(
+      int trainer_num,
+      const std::unordered_map<std::string, std::string>& grad_to_param) {
+    InitImpl(trainer_num, grad_to_param);
+  }
+
+  static AsyncSparseParamUpdateRecorder* GetInstance() {
+    return recorder_.get();
+  }
+
+ private:
+  // Init is called by GetInstance.
+  static void InitImpl(
+      int trainer_num,
+      const std::unordered_map<std::string, std::string>& grad_to_param) {
+    if (recorder_ == nullptr) {
+      recorder_.reset(
+          new AsyncSparseParamUpdateRecorder(trainer_num, grad_to_param));
+    }
+  }
+
+  static std::once_flag init_flag_;
+  static std::unique_ptr<AsyncSparseParamUpdateRecorder> recorder_;
 };

 }  // namespace distributed
......
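Note (not part of the commit): a minimal usage sketch of the recorder API shown above, assuming the header is on the include path. The "emb"/"emb@GRAD" names are illustrative, and the GetAndClear call is modeled on the unit test below; how rows actually get recorded on the server side is outside this sketch.

// Minimal usage sketch (not part of this commit). Names are illustrative.
#include <string>
#include <unordered_map>
#include <vector>

#include "paddle/fluid/operators/distributed/async_sparse_param_update_recorder.h"

using paddle::operators::distributed::AsyncSparseParamUpdateRecorder;

int main() {
  // grad name -> param name, as the transpiler passes it to the server side.
  std::unordered_map<std::string, std::string> grad_to_param = {
      {"emb@GRAD", "emb"}};

  // Init builds the reverse param -> grad map and one row list per trainer.
  AsyncSparseParamUpdateRecorder::Init(/*trainer_num=*/2, grad_to_param);
  auto* recorder = AsyncSparseParamUpdateRecorder::GetInstance();

  // Lookups work in both directions.
  bool known_param = recorder->HasParam("emb");     // true
  bool known_grad = recorder->HasGrad("emb@GRAD");  // true

  // Updated rows are drained per (param, trainer) pair; the signature is
  // taken from the test below, which also shows that an out-of-range
  // trainer id throws.
  std::vector<int64_t> rows;
  recorder->GetAndClear("emb", /*trainer_id=*/0, &rows);

  return static_cast<int>(!(known_param && known_grad));
}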
@@ -73,6 +73,10 @@ TEST(AsyncSparseParamUpdateRecorder, All) {
   EXPECT_TRUE(recorder.HasParam("param2"));
   EXPECT_FALSE(recorder.HasParam("param3"));

+  EXPECT_TRUE(recorder.HasGrad("grad1"));
+  EXPECT_TRUE(recorder.HasGrad("grad2"));
+  EXPECT_FALSE(recorder.HasGrad("grad3"));
+
   std::vector<int64_t> ret;
   EXPECT_ANY_THROW(recorder.GetAndClear("param1", trainer_num, &ret));
......
@@ -180,6 +180,10 @@ class RequestHandler {
     grad_to_prepared_ctx_ = g;
   }

+  void SetSparseGradToParam(std::unordered_map<std::string, std::string>* g) {
+    sparse_grad_to_param_ = g;
+  }
+
   void SetRPCServer(RPCServer* rpc_server) { rpc_server_ = rpc_server; }

   // Get attributes.
@@ -228,6 +232,7 @@ class RequestHandler {
   std::unordered_map<std::string,
                      std::shared_ptr<framework::ExecutorPrepareContext>>*
       grad_to_prepared_ctx_;
+  std::unordered_map<std::string, std::string>* sparse_grad_to_param_;

   RPCServer* rpc_server_;
 };
......
@@ -2,9 +2,9 @@ include(operators)
 set(DISTRIBUTE_DEPS "")
 if(WITH_GRPC)
-  set(DISTRIBUTE_DEPS sendrecvop_rpc parameter_send parameter_recv communicator grpc++_unsecure grpc_unsecure gpr cares zlib protobuf node)
+  set(DISTRIBUTE_DEPS sendrecvop_rpc parameter_send parameter_recv communicator async_sparse_param_update_recorder grpc++_unsecure grpc_unsecure gpr cares zlib protobuf node)
 else()
-  set(DISTRIBUTE_DEPS sendrecvop_rpc parameter_send parameter_recv communicator brpc leveldb snappystream snappy protobuf ssl crypto zlib node)
+  set(DISTRIBUTE_DEPS sendrecvop_rpc parameter_send parameter_recv communicator async_sparse_param_update_recorder brpc leveldb snappystream snappy protobuf ssl crypto zlib node)
   if(WITH_BRPC_RDMA)
     find_library(IBVERBS_LIBRARY NAMES ibverbs)
     ADD_LIBRARY(ibverbs SHARED IMPORTED GLOBAL)
......
@@ -24,8 +24,10 @@ limitations under the License. */
 #include "paddle/fluid/operators/distributed/distributed.h"
 #include "paddle/fluid/operators/math/math_function.h"

+#include "paddle/fluid/operators/distributed/async_sparse_param_update_recorder.h"
 #include "paddle/fluid/operators/distributed/request_handler_impl.h"
 #include "paddle/fluid/operators/distributed_ops/listen_and_serv_op.h"
 #include "paddle/fluid/platform/profiler.h"

 DEFINE_int32(rpc_send_thread_num, 12, "number of threads for rpc send");
@@ -292,6 +294,8 @@ static void FillRequestCtx(
     std::unordered_map<std::string,
                        std::shared_ptr<framework::ExecutorPrepareContext>>
         *prefetch_ctx,
+    std::unordered_map<std::string, std::string>
+        *sparse_grad_name_to_param_name,
     std::shared_ptr<framework::ExecutorPrepareContext> checkpoint_ctx,
     distributed::RPCServer *rpc_server) {
   h->SetScope(scope);
@@ -299,6 +303,7 @@ static void FillRequestCtx(
   h->SetExecutor(executor);
   h->SetProgram(program);
   h->SetPrefetchPreparedCtx(prefetch_ctx);
+  h->SetSparseGradToParam(sparse_grad_name_to_param_name);
   h->SetRPCServer(rpc_server);
   h->SetCheckpointNotifyPreparedCtx(checkpoint_ctx);
 }
@@ -414,10 +419,24 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope,
     prefetch_var_name_to_prepared_ctx[prefetch_var_name] = prefetch_prepared[i];
   }

-  auto f =
-      std::bind(FillRequestCtx, std::placeholders::_1, &recv_scope, &dev_ctx,
-                &executor, program, &prefetch_var_name_to_prepared_ctx,
-                ckpt_pre_context, rpc_service_.get());
+  // parse attr of kSparseGradToParam  sparse_grad_name -> param_name
+  std::unordered_map<std::string, std::string> sparse_grad_name_to_param_name;
+  auto sparse_grad_name_to_param_name_str =
+      Attr<std::vector<std::string>>(kSparseGradToParam);
+  for (const auto &sparse_grad_name_and_param_name :
+       sparse_grad_name_to_param_name_str) {
+    std::vector<std::string> pieces;
+    split(sparse_grad_name_and_param_name, ':', &pieces);
+    PADDLE_ENFORCE_EQ(pieces.size(), 2);
+    VLOG(3) << "after split, sparse_grad_name = " << pieces[0]
+            << ", param_name = " << pieces[1];
+    sparse_grad_name_to_param_name[pieces[0]] = pieces[1];
+  }
+
+  auto f = std::bind(
+      FillRequestCtx, std::placeholders::_1, &recv_scope, &dev_ctx, &executor,
+      program, &prefetch_var_name_to_prepared_ctx,
+      &sparse_grad_name_to_param_name, ckpt_pre_context, rpc_service_.get());

   f(request_send_handler_.get());
   f(request_get_handler_.get());
@@ -445,6 +464,8 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope,
     RunSyncLoop(&executor, program, &recv_scope, &dev_ctx,
                 prefetch_block_id_list, checkpoint_block_id);
   } else {
+    distributed::AsyncSparseParamUpdateRecorder::Init(
+        fan_in, sparse_grad_name_to_param_name);
     RunAsyncLoop(&executor, program, &recv_scope);
   }
 }
@@ -475,6 +496,10 @@ class ListenAndServOpMaker : public framework::OpProtoAndCheckerMaker {
     AddAttr<std::vector<std::string>>(kPrefetchVarNameToBlockId,
                                       "prefetch blocks to run on server side.")
         .SetDefault({});
+    AddAttr<std::vector<std::string>>(
+        kSparseGradToParam,
+        "sparse grad name to param name. like: 'emb@Grad:emb'")
+        .SetDefault({});
     AddAttr<int>("Fanin", "How many clients send to this server.")
         .SetDefault(1);
     AddAttr<int>(kCheckpointBlockId,
......
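Note (not part of the commit): a standalone sketch of the "grad_name:param_name" string format that the kSparseGradToParam attribute carries, e.g. 'emb@Grad:emb' as in the attr description above. The op itself uses Paddle's split() and PADDLE_ENFORCE_EQ; the local find(':') parsing and the 'fc_w@GRAD:fc_w' entry here are illustrative only.

// Standalone sketch of parsing kSparseGradToParam entries ("grad:param").
// Not the op's code; entries and the local parsing logic are illustrative.
#include <iostream>
#include <string>
#include <unordered_map>
#include <vector>

int main() {
  // What the transpiler would place in the attribute, e.g. "emb@Grad:emb".
  std::vector<std::string> sparse_grad_to_param = {"emb@Grad:emb",
                                                   "fc_w@GRAD:fc_w"};

  std::unordered_map<std::string, std::string> grad_name_to_param_name;
  for (const auto& entry : sparse_grad_to_param) {
    auto pos = entry.find(':');
    if (pos == std::string::npos) {
      // The real op enforces exactly two pieces via PADDLE_ENFORCE_EQ.
      std::cerr << "malformed entry: " << entry << "\n";
      continue;
    }
    grad_name_to_param_name[entry.substr(0, pos)] = entry.substr(pos + 1);
  }

  for (const auto& kv : grad_name_to_param_name) {
    std::cout << kv.first << " -> " << kv.second << "\n";
  }
  return 0;
}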
@@ -35,6 +35,7 @@ namespace operators {
 constexpr char kOptimizeBlocks[] = "optimize_blocks";
 constexpr char kPrefetchVarNameToBlockId[] = "prefetch_var_name_to_block_id";
 constexpr char kCheckpointBlockId[] = "checkpint_block_id";
+constexpr char kSparseGradToParam[] = "sparse_grad_to_param";

 void RunServer(std::shared_ptr<distributed::RPCServer> service);
......
@@ -791,11 +791,15 @@ class DistributeTranspiler(object):
         global_ops = []

+        # sparse grad name to param name
+        sparse_grad_to_param = []
+
         def __append_optimize_op__(op, block, grad_to_block_id, merged_var,
                                    lr_ops):
             if self._is_optimizer_op(op):
                 self._append_pserver_ops(block, op, endpoint, grad_to_block_id,
-                                         self.origin_program, merged_var)
+                                         self.origin_program, merged_var,
+                                         sparse_grad_to_param)
             elif op not in lr_ops:
                 self._append_pserver_non_opt_ops(block, op)
@@ -911,6 +915,7 @@ class DistributeTranspiler(object):
             "Fanin": self.trainer_num,
             "sync_mode": self.sync_mode,
             "grad_to_block_id": grad_to_block_id,
+            "sparse_grad_to_param": sparse_grad_to_param,
         }

         if self.has_distributed_lookup_table:
@@ -1778,7 +1783,8 @@ class DistributeTranspiler(object):
         return o4

     def _append_pserver_ops(self, optimize_block, opt_op, endpoint,
-                            grad_to_block_id, origin_program, merged_var):
+                            grad_to_block_id, origin_program, merged_var,
+                            sparse_grad_to_param):
         program = optimize_block.program
         pserver_block = program.global_block()
         new_inputs = collections.OrderedDict()
@@ -1862,6 +1868,12 @@ class DistributeTranspiler(object):
             outputs=outputs,
             attrs=opt_op.all_attrs())

+        # record sparse grad to param name
+        if new_inputs["Grad"].type == core.VarDesc.VarType.SELECTED_ROWS:
+            sparse_grad_to_param.append(
+                str(new_inputs["Grad"].name) + ":" + str(new_inputs["Param"]
+                                                         .name))
+
     def _get_pserver_grad_param_var(self, var, var_dict):
         """
         Return pserver side grad/param variable, return None
......