提交 de65398c 编写于 作者: Q Qiao Longfei

update transpiler and listen and serv op

上级 25e2b417
......@@ -51,7 +51,8 @@ endif()
cc_test(rpc_server_test SRCS rpc_server_test.cc
DEPS ${RPC_DEPS} executor proto_desc lookup_sparse_table_op SERIAL)
cc_test(varhandle_test SRCS varhandle_test.cc DEPS profiler scope)
cc_test(async_sparse_param_update_recorder_test SRCS async_sparse_param_update_recorder_test.cc DEPS enforce simple_threadpool)
cc_library(async_sparse_param_update_recorder SRCS async_sparse_param_update_recorder.cc DEPS enforce simple_threadpool)
cc_test(async_sparse_param_update_recorder_test SRCS async_sparse_param_update_recorder_test.cc DEPS async_sparse_param_update_recorder)
cc_library(parameter_prefetch SRCS parameter_prefetch.cc DEPS sendrecvop_rpc memory)
cc_library(parameter_send SRCS parameter_send.cc DEPS sendrecvop_rpc memory)
cc_library(parameter_recv SRCS parameter_recv.cc DEPS sendrecvop_rpc memory)
......
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/operators/distributed/async_sparse_param_update_recorder.h"
namespace paddle {
namespace operators {
namespace distributed {
std::once_flag AsyncSparseParamUpdateRecorder::init_flag_;
std::unique_ptr<AsyncSparseParamUpdateRecorder>
AsyncSparseParamUpdateRecorder::recorder_(nullptr);
} // namespace distributed
} // namespace operators
} // namespace paddle
......@@ -67,10 +67,9 @@ class AsyncSparseParamUpdateRecorder {
int trainer_num,
const std::unordered_map<std::string, std::string>& grad_to_param)
: trainer_num_(trainer_num), grad_to_param_(grad_to_param) {
for (auto iter = grad_to_param.begin(); iter != grad_to_param.end();
iter++) {
param_to_grad_[iter->second] = iter->first;
auto& param_name = iter->second;
for (auto& iter : grad_to_param) {
param_to_grad_[iter.second] = iter.first;
auto& param_name = iter.second;
param_to_updated_rows_[param_name] = TrainerToRows();
auto& trainer_to_rows = param_to_updated_rows_[param_name];
for (auto i = 0; i < trainer_num; ++i) {
......@@ -104,11 +103,41 @@ class AsyncSparseParamUpdateRecorder {
return param_to_grad_.find(param_name) != param_to_grad_.end();
}
bool HasGrad(const std::string& grad_name) {
return grad_to_param_.find(grad_name) != grad_to_param_.end();
}
private:
const int trainer_num_;
std::unordered_map<std::string, std::string> grad_to_param_;
std::unordered_map<std::string, std::string> param_to_grad_;
std::unordered_map<std::string, TrainerToRows> param_to_updated_rows_;
// init recorder
public:
static void Init(
int trainer_num,
const std::unordered_map<std::string, std::string>& grad_to_param) {
InitImpl(trainer_num, grad_to_param);
}
static AsyncSparseParamUpdateRecorder* GetInstance() {
return recorder_.get();
}
private:
// Init is called by GetInstance.
static void InitImpl(
int trainer_num,
const std::unordered_map<std::string, std::string>& grad_to_param) {
if (recorder_ == nullptr) {
recorder_.reset(
new AsyncSparseParamUpdateRecorder(trainer_num, grad_to_param));
}
}
static std::once_flag init_flag_;
static std::unique_ptr<AsyncSparseParamUpdateRecorder> recorder_;
};
} // namespace distributed
......
......@@ -73,6 +73,10 @@ TEST(AsyncSparseParamUpdateRecorder, All) {
EXPECT_TRUE(recorder.HasParam("param2"));
EXPECT_FALSE(recorder.HasParam("param3"));
EXPECT_TRUE(recorder.HasGrad("grad1"));
EXPECT_TRUE(recorder.HasGrad("grad2"));
EXPECT_FALSE(recorder.HasGrad("grad3"));
std::vector<int64_t> ret;
EXPECT_ANY_THROW(recorder.GetAndClear("param1", trainer_num, &ret));
......
......@@ -180,6 +180,10 @@ class RequestHandler {
grad_to_prepared_ctx_ = g;
}
void SetSparseGradToParam(std::unordered_map<std::string, std::string>* g) {
sparse_grad_to_param_ = g;
}
void SetRPCServer(RPCServer* rpc_server) { rpc_server_ = rpc_server; }
// Get attributes.
......@@ -228,6 +232,7 @@ class RequestHandler {
std::unordered_map<std::string,
std::shared_ptr<framework::ExecutorPrepareContext>>*
grad_to_prepared_ctx_;
std::unordered_map<std::string, std::string>* sparse_grad_to_param_;
RPCServer* rpc_server_;
};
......
......@@ -2,9 +2,9 @@ include(operators)
set(DISTRIBUTE_DEPS "")
if(WITH_GRPC)
set(DISTRIBUTE_DEPS sendrecvop_rpc parameter_send parameter_recv communicator grpc++_unsecure grpc_unsecure gpr cares zlib protobuf node)
set(DISTRIBUTE_DEPS sendrecvop_rpc parameter_send parameter_recv communicator async_sparse_param_update_recorder grpc++_unsecure grpc_unsecure gpr cares zlib protobuf node)
else()
set(DISTRIBUTE_DEPS sendrecvop_rpc parameter_send parameter_recv communicator brpc leveldb snappystream snappy protobuf ssl crypto zlib node)
set(DISTRIBUTE_DEPS sendrecvop_rpc parameter_send parameter_recv communicator async_sparse_param_update_recorder brpc leveldb snappystream snappy protobuf ssl crypto zlib node)
if(WITH_BRPC_RDMA)
find_library(IBVERBS_LIBRARY NAMES ibverbs)
ADD_LIBRARY(ibverbs SHARED IMPORTED GLOBAL)
......
......@@ -24,8 +24,10 @@ limitations under the License. */
#include "paddle/fluid/operators/distributed/distributed.h"
#include "paddle/fluid/operators/math/math_function.h"
#include "paddle/fluid/operators/distributed/async_sparse_param_update_recorder.h"
#include "paddle/fluid/operators/distributed/request_handler_impl.h"
#include "paddle/fluid/operators/distributed_ops/listen_and_serv_op.h"
#include "paddle/fluid/platform/profiler.h"
DEFINE_int32(rpc_send_thread_num, 12, "number of threads for rpc send");
......@@ -292,6 +294,8 @@ static void FillRequestCtx(
std::unordered_map<std::string,
std::shared_ptr<framework::ExecutorPrepareContext>>
*prefetch_ctx,
std::unordered_map<std::string, std::string>
*sparse_grad_name_to_param_name,
std::shared_ptr<framework::ExecutorPrepareContext> checkpoint_ctx,
distributed::RPCServer *rpc_server) {
h->SetScope(scope);
......@@ -299,6 +303,7 @@ static void FillRequestCtx(
h->SetExecutor(executor);
h->SetProgram(program);
h->SetPrefetchPreparedCtx(prefetch_ctx);
h->SetSparseGradToParam(sparse_grad_name_to_param_name);
h->SetRPCServer(rpc_server);
h->SetCheckpointNotifyPreparedCtx(checkpoint_ctx);
}
......@@ -414,10 +419,24 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope,
prefetch_var_name_to_prepared_ctx[prefetch_var_name] = prefetch_prepared[i];
}
auto f =
std::bind(FillRequestCtx, std::placeholders::_1, &recv_scope, &dev_ctx,
&executor, program, &prefetch_var_name_to_prepared_ctx,
ckpt_pre_context, rpc_service_.get());
// parse attr of kSparseGradToParam sparse_grad_name -> param_name
std::unordered_map<std::string, std::string> sparse_grad_name_to_param_name;
auto sparse_grad_name_to_param_name_str =
Attr<std::vector<std::string>>(kSparseGradToParam);
for (const auto &sparse_grad_name_and_param_name :
sparse_grad_name_to_param_name_str) {
std::vector<std::string> pieces;
split(sparse_grad_name_and_param_name, ':', &pieces);
PADDLE_ENFORCE_EQ(pieces.size(), 2);
VLOG(3) << "after split, sparse_grad_name = " << pieces[0]
<< ", param_name = " << pieces[1];
sparse_grad_name_to_param_name[pieces[0]] = pieces[1];
}
auto f = std::bind(
FillRequestCtx, std::placeholders::_1, &recv_scope, &dev_ctx, &executor,
program, &prefetch_var_name_to_prepared_ctx,
&sparse_grad_name_to_param_name, ckpt_pre_context, rpc_service_.get());
f(request_send_handler_.get());
f(request_get_handler_.get());
......@@ -445,6 +464,8 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope,
RunSyncLoop(&executor, program, &recv_scope, &dev_ctx,
prefetch_block_id_list, checkpoint_block_id);
} else {
distributed::AsyncSparseParamUpdateRecorder::Init(
fan_in, sparse_grad_name_to_param_name);
RunAsyncLoop(&executor, program, &recv_scope);
}
}
......@@ -475,6 +496,10 @@ class ListenAndServOpMaker : public framework::OpProtoAndCheckerMaker {
AddAttr<std::vector<std::string>>(kPrefetchVarNameToBlockId,
"prefetch blocks to run on server side.")
.SetDefault({});
AddAttr<std::vector<std::string>>(
kSparseGradToParam,
"sparse grad name to param name. like: 'emb@Grad:emb'")
.SetDefault({});
AddAttr<int>("Fanin", "How many clients send to this server.")
.SetDefault(1);
AddAttr<int>(kCheckpointBlockId,
......
......@@ -35,6 +35,7 @@ namespace operators {
constexpr char kOptimizeBlocks[] = "optimize_blocks";
constexpr char kPrefetchVarNameToBlockId[] = "prefetch_var_name_to_block_id";
constexpr char kCheckpointBlockId[] = "checkpint_block_id";
constexpr char kSparseGradToParam[] = "sparse_grad_to_param";
void RunServer(std::shared_ptr<distributed::RPCServer> service);
......
......@@ -791,11 +791,15 @@ class DistributeTranspiler(object):
global_ops = []
# sparse grad name to param name
sparse_grad_to_param = []
def __append_optimize_op__(op, block, grad_to_block_id, merged_var,
lr_ops):
if self._is_optimizer_op(op):
self._append_pserver_ops(block, op, endpoint, grad_to_block_id,
self.origin_program, merged_var)
self.origin_program, merged_var,
sparse_grad_to_param)
elif op not in lr_ops:
self._append_pserver_non_opt_ops(block, op)
......@@ -911,6 +915,7 @@ class DistributeTranspiler(object):
"Fanin": self.trainer_num,
"sync_mode": self.sync_mode,
"grad_to_block_id": grad_to_block_id,
"sparse_grad_to_param": sparse_grad_to_param,
}
if self.has_distributed_lookup_table:
......@@ -1778,7 +1783,8 @@ class DistributeTranspiler(object):
return o4
def _append_pserver_ops(self, optimize_block, opt_op, endpoint,
grad_to_block_id, origin_program, merged_var):
grad_to_block_id, origin_program, merged_var,
sparse_grad_to_param):
program = optimize_block.program
pserver_block = program.global_block()
new_inputs = collections.OrderedDict()
......@@ -1862,6 +1868,12 @@ class DistributeTranspiler(object):
outputs=outputs,
attrs=opt_op.all_attrs())
# record sparse grad to param name
if new_inputs["Grad"].type == core.VarDesc.VarType.SELECTED_ROWS:
sparse_grad_to_param.append(
str(new_inputs["Grad"].name) + ":" + str(new_inputs["Param"]
.name))
def _get_pserver_grad_param_var(self, var, var_dict):
"""
Return pserver side grad/param variable, return None
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册