提交 b645dfac 编写于 作者: Q qiaolongfei

Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into update-api-reference-1

......@@ -180,7 +180,7 @@ def train(avg_loss, infer_prog, optimizer, train_reader, test_reader, batch_acc,
print_train_time(start_time, time.time(), num_samples)
print("Pass: %d, Loss: %f" % (pass_id, np.mean(train_losses))),
# evaluation
if not args.no_test and batch_acc:
if not args.no_test and batch_acc and not args.use_reader_op:
pass_test_acc = test(exe, infer_prog, test_reader, feeder,
batch_acc)
print(", Test Accuracy: %f" % pass_test_acc)
......@@ -277,11 +277,12 @@ def train_parallel(avg_loss, infer_prog, optimizer, train_reader, test_reader,
batch_id += 1
print_train_time(start_time, time.time(), num_samples)
if not args.no_test and batch_acc:
if not args.no_test and batch_acc and not args.use_reader_op:
# we have not implement record io for test
# skip test when use args.use_reader_op
test_acc = test(startup_exe, infer_prog, test_reader, feeder,
batch_acc)
print("Pass: %d, Test Accuracy: %f\n" % (pass_id, test_acc))
exit(0)
def print_arguments(args):
......
......@@ -199,7 +199,10 @@ def get_model(args):
batched_train_reader = paddle.batch(
paddle.reader.shuffle(
train_reader, buf_size=5120),
batch_size=args.batch_size * args.gpus)
batched_test_reader = paddle.batch(train_reader, batch_size=args.batch_size)
batch_size=args.batch_size * args.gpus,
drop_last=True)
batched_test_reader = paddle.batch(
train_reader, batch_size=args.batch_size, drop_last=True)
return avg_cost, inference_program, optimizer, batched_train_reader, batched_test_reader, batch_acc
return avg_cost, inference_program, optimizer, batched_train_reader,\
batched_test_reader, batch_acc
......@@ -118,6 +118,10 @@ endif()
set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} ${SIMD_FLAG}")
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${SIMD_FLAG}")
if(WITH_DISTRIBUTE)
add_definitions(-DPADDLE_WITH_DISTRIBUTE)
endif()
if(WITH_GOLANG)
# we need to symlink Paddle directory into GOPATH. If we
# don't do it and we have code that depends on Paddle, go
......
......@@ -83,8 +83,13 @@ cc_library(lod_rank_table SRCS lod_rank_table.cc DEPS lod_tensor)
cc_library(feed_fetch_method SRCS feed_fetch_method.cc DEPS lod_tensor scope glog)
cc_library(executor SRCS executor.cc DEPS op_registry device_context scope
framework_proto glog lod_rank_table feed_fetch_method)
if(WITH_DISTRIBUTE)
cc_library(executor SRCS executor.cc DEPS op_registry device_context scope framework_proto glog lod_rank_table feed_fetch_method sendrecvop_grpc grpc++_unsecure grpc_unsecure gpr)
set(DISTRIBUTE_COMPILE_FLAGS "-Wno-non-virtual-dtor -Wno-error=non-virtual-dtor -Wno-error=delete-non-virtual-dtor")
set_source_files_properties(executor.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
else()
cc_library(executor SRCS executor.cc DEPS op_registry device_context scope framework_proto glog lod_rank_table feed_fetch_method)
endif()
cc_library(parallel_executor SRCS parallel_executor.cc DEPS ssa_graph_builder_factory threaded_ssa_graph_executor scope_buffered_ssa_graph_executor)
......
......@@ -19,7 +19,7 @@
namespace paddle {
namespace framework {
namespace details {
class SSAGraph;
struct SSAGraph;
class SSAGraghBuilderWithChecker : public SSAGraphBuilder {
public:
......
......@@ -20,6 +20,9 @@ limitations under the License. */
#include "paddle/fluid/framework/lod_tensor_array.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/reader.h"
#ifdef PADDLE_WITH_DISTRIBUTE
#include "paddle/fluid/operators/detail/grpc_client.h"
#endif
#include "paddle/fluid/platform/place.h"
#include "paddle/fluid/platform/profiler.h"
......@@ -44,6 +47,14 @@ ExecutorPrepareContext::~ExecutorPrepareContext() {
Executor::Executor(const platform::Place& place) : place_(place) {}
#ifdef PADDLE_WITH_DISTRIBUTE
void Executor::Complete() {
::paddle::operators::detail::RPCClient::GetInstance<
::paddle::operators::detail::GRPCClient>()
->SendComplete();
}
#endif
void InitializeVariable(Variable* var, proto::VarType::Type var_type) {
if (var_type == proto::VarType::LOD_TENSOR) {
var->GetMutable<LoDTensor>();
......
......@@ -44,6 +44,13 @@ class Executor {
explicit Executor(const platform::Place& place);
#ifdef PADDLE_WITH_DISTRIBUTE
/*
* Sending signal to pserver to mark current trainer stop.
*/
void Complete();
#endif
/* @Brief
* Runtime evaluation of the given ProgramDesc under certain Scope
*
......
......@@ -35,14 +35,15 @@ class ReaderBase {
class DecoratedReader : public ReaderBase {
public:
explicit DecoratedReader(ReaderBase* reader) : ReaderBase(), reader_(reader) {
explicit DecoratedReader(const std::shared_ptr<ReaderBase>& reader)
: ReaderBase(), reader_(reader) {
PADDLE_ENFORCE_NOT_NULL(reader_);
}
void ReInit() override { reader_->ReInit(); }
protected:
ReaderBase* reader_;
std::shared_ptr<ReaderBase> reader_;
};
class FileReader : public ReaderBase {
......@@ -64,7 +65,7 @@ class ReaderHolder {
public:
void Reset(ReaderBase* reader) { reader_.reset(reader); }
ReaderBase* Get() const { return reader_.get(); }
std::shared_ptr<ReaderBase> Get() const { return reader_; }
void ReadNext(std::vector<LoDTensor>* out) {
PADDLE_ENFORCE_NOT_NULL(reader_);
......@@ -76,7 +77,7 @@ class ReaderHolder {
}
private:
std::unique_ptr<ReaderBase> reader_;
std::shared_ptr<ReaderBase> reader_;
};
} // namespace framework
......
......@@ -34,6 +34,12 @@ void GRPCClient::InitEventLoop() {
client_thread_.reset(new std::thread(std::bind(&GRPCClient::Proceed, this)));
}
void GRPCClient::SendComplete() {
for (auto& it : channels_) {
this->AsyncSendComplete(it.first);
}
}
GRPCClient::~GRPCClient() {
Wait();
cq_.Shutdown();
......@@ -210,6 +216,19 @@ void GRPCClient::AsyncSendFetchBarrier(const std::string& ep,
req_count_++;
}
void GRPCClient::AsyncSendComplete(const std::string& ep, int64_t time_out) {
const auto ch = GetChannel(ep);
BatchBarrierProcessor* s = new BatchBarrierProcessor(ch);
s->Prepare(time_out);
sendrecv::VariableMessage req;
req.set_varname(COMPLETE_MESSAGE);
auto rpc = s->stub_->AsyncSendVariable(s->context_.get(), req, &cq_);
rpc->Finish(&s->reply_, &s->status_, reinterpret_cast<void*>(s));
req_count_++;
}
void GRPCClient::Wait() {
std::unique_lock<std::mutex> lk(sync_mutex_);
sync_cond_.wait(lk, [this] { return req_count_ == 0; });
......
......@@ -195,6 +195,8 @@ class GRPCClient : public RPCClient {
void Wait() override;
void SendComplete() override;
protected:
void InitImpl() override;
......@@ -204,6 +206,9 @@ class GRPCClient : public RPCClient {
void Proceed();
void AsyncSendComplete(const std::string& ep,
int64_t time_out = RPCClient::rpc_time_out);
std::shared_ptr<grpc::Channel> GetChannel(const std::string& ep);
private:
......
......@@ -162,16 +162,18 @@ class RequestPrefetch final : public RequestBase {
void Process() override {
// prefetch process...
std::string varname = request_->OutVarname();
VLOG(3) << "RequestPrefetch " << varname;
std::string in_var_name = request_->Varname();
std::string out_var_name = request_->OutVarname();
VLOG(3) << "RequestPrefetch, in_var_name: " << in_var_name
<< " out_var_name: " << out_var_name;
auto scope = request_->GetMutableLocalScope();
auto invar = scope->FindVar(varname);
framework::Variable* outvar = nullptr;
auto invar = scope->FindVar(in_var_name);
framework::Variable* outvar = scope->FindVar(out_var_name);
request_handler_->Handle(varname, scope, invar, &outvar);
request_handler_->Handle(in_var_name, scope, invar, &outvar, out_var_name);
SerializeToByteBuffer(varname, outvar, *request_handler_->dev_ctx(),
SerializeToByteBuffer(out_var_name, outvar, *request_handler_->dev_ctx(),
&reply_);
Finish(reply_, &responder_);
}
......@@ -287,7 +289,7 @@ void AsyncGRPCServer::TryToRegisterNewOne(const std::string& rpc_name,
} else if (rpc_name == kRequestPrefetch) {
b = new RequestPrefetch(&service_, cq.get(), handler, req_id);
} else {
PADDLE_ENFORCE(false, "not surpported rpc");
PADDLE_ENFORCE(false, "not supported rpc");
}
reqs[req_id] = b;
......
......@@ -40,6 +40,7 @@ constexpr char kRequestPrefetch[] = "RequestPrefetch";
#define LISTEN_TERMINATE_MESSAGE "TERMINATE@RECV"
#define BATCH_BARRIER_MESSAGE "BATCH_BARRIER@RECV"
#define FETCH_BARRIER_MESSAGE "FETCH_BARRIER@RECV"
#define COMPLETE_MESSAGE "COMPLETE@RECV"
class RPCServer;
......@@ -60,9 +61,12 @@ class RequestHandler {
void SetDevCtx(const platform::DeviceContext* dev_ctx) { dev_ctx_ = dev_ctx; }
void SetProgram(framework::ProgramDesc* program) { program_ = program; }
void SetExecutor(framework::Executor* executor) { executor_ = executor; }
// Used for dist lookup table prefetch
void SetPrefetchPreparedCtx(
std::unique_ptr<framework::ExecutorPrepareContext> prepared) {
prefetch_ctx_.reset(prepared.release());
std::unordered_map<
std::string, std::shared_ptr<framework::ExecutorPrepareContext>>* g) {
prefetch_var_name_to_prepared_ctx_ = g;
}
// Used for async.
......@@ -78,9 +82,6 @@ class RequestHandler {
bool sync_mode() { return sync_mode_; }
framework::Scope* scope() { return scope_; }
const platform::DeviceContext* dev_ctx() { return dev_ctx_; }
framework::ExecutorPrepareContext* prefetch_ctx() {
return prefetch_ctx_.get();
}
framework::ProgramDesc* program() { return program_; }
framework::Executor* executor() { return executor_; }
......@@ -99,8 +100,8 @@ class RequestHandler {
// *request_handler_->dev_ctx(), &reply_);
// }
virtual bool Handle(const std::string& varname, framework::Scope* scope,
framework::Variable* var,
framework::Variable** outvar) = 0;
framework::Variable* var, framework::Variable** outvar,
const std::string& out_var_name = "") = 0;
protected:
const bool sync_mode_;
......@@ -109,12 +110,17 @@ class RequestHandler {
framework::Executor* executor_;
framework::Scope* scope_;
framework::ProgramDesc* program_;
std::unique_ptr<framework::ExecutorPrepareContext> prefetch_ctx_;
// used for distribute lookup table prefetch
std::unordered_map<std::string,
std::shared_ptr<framework::ExecutorPrepareContext>>*
prefetch_var_name_to_prepared_ctx_;
// Used for async.
std::unordered_map<std::string,
std::shared_ptr<framework::ExecutorPrepareContext>>*
grad_to_prepared_ctx_;
RPCServer* rpc_server_;
};
......
......@@ -30,7 +30,8 @@ namespace detail {
bool RequestSendHandler::Handle(const std::string& varname,
framework::Scope* scope,
framework::Variable* invar,
framework::Variable** outvar) {
framework::Variable** outvar,
const std::string& out_var_name) {
VLOG(4) << "RequestSendHandler:" << varname;
// Async
......@@ -49,6 +50,9 @@ bool RequestSendHandler::Handle(const std::string& varname,
if (varname == BATCH_BARRIER_MESSAGE) {
VLOG(3) << "sync: recv batch barrier message";
rpc_server_->IncreaseBatchBarrier(kRequestSend);
} else if (varname == COMPLETE_MESSAGE) {
VLOG(3) << "sync: recv complete message";
rpc_server_->DecreaseClientNum();
} else {
VLOG(3) << "sync: received var_name: " << varname;
if (sync_mode_) {
......@@ -79,7 +83,8 @@ void RequestSendHandler::ResetSparseVarRecorder() {
bool RequestGetHandler::Handle(const std::string& varname,
framework::Scope* scope,
framework::Variable* invar,
framework::Variable** outvar) {
framework::Variable** outvar,
const std::string& out_var_name) {
VLOG(4) << "RequestGetHandler:" << varname;
if (varname != FETCH_BARRIER_MESSAGE) {
......@@ -102,13 +107,14 @@ bool RequestGetHandler::Handle(const std::string& varname,
bool RequestPrefetchHandler::Handle(const std::string& varname,
framework::Scope* scope,
framework::Variable* invar,
framework::Variable** outvar) {
framework::Variable** outvar,
const std::string& out_var_name) {
VLOG(4) << "RequestPrefetchHandler " << varname;
auto var_desc = program_->Block(0).FindVar(varname);
*outvar = scope->FindVar(varname);
auto var_desc = program_->Block(0).FindVar(out_var_name);
InitializeVariable(*outvar, var_desc->GetType());
executor_->RunPreparedContext(prefetch_ctx_.get(), scope);
executor_->RunPreparedContext(
(*prefetch_var_name_to_prepared_ctx_)[varname].get(), scope);
return true;
}
......
......@@ -39,7 +39,8 @@ class RequestSendHandler final : public RequestHandler {
explicit RequestSendHandler(bool sync_mode) : RequestHandler(sync_mode) {}
virtual ~RequestSendHandler() {}
bool Handle(const std::string& varname, framework::Scope* scope,
framework::Variable* var, framework::Variable** outvar) override;
framework::Variable* var, framework::Variable** outvar,
const std::string& out_var_name = "") override;
void ResetSparseVarRecorder();
private:
......@@ -52,7 +53,8 @@ class RequestGetHandler final : public RequestHandler {
explicit RequestGetHandler(bool sync_mode) : RequestHandler(sync_mode) {}
virtual ~RequestGetHandler() {}
bool Handle(const std::string& varname, framework::Scope* scope,
framework::Variable* var, framework::Variable** outvar) override;
framework::Variable* var, framework::Variable** outvar,
const std::string& out_var_name = "") override;
};
class RequestPrefetchHandler final : public RequestHandler {
......@@ -60,7 +62,8 @@ class RequestPrefetchHandler final : public RequestHandler {
explicit RequestPrefetchHandler(bool sync_mode) : RequestHandler(sync_mode) {}
virtual ~RequestPrefetchHandler() {}
bool Handle(const std::string& varname, framework::Scope* scope,
framework::Variable* var, framework::Variable** outvar) override;
framework::Variable* var, framework::Variable** outvar,
const std::string& out_var_name = "") override;
};
} // namespace detail
......
......@@ -53,6 +53,11 @@ class RPCClient {
virtual void AsyncSendFetchBarrier(const std::string& ep,
int64_t time_out = rpc_time_out) = 0;
// SendComplete tells all the server that current trainer have no more data
// to train, so that the pserver can reduce it's barrier count, and continue
// to train with other trainers.
virtual void SendComplete() = 0;
virtual void Wait() = 0;
static constexpr int64_t rpc_time_out = 120 * 1000;
......
......@@ -43,7 +43,7 @@ void RPCServer::SavePort() const {
void RPCServer::WaitBarrier(const std::string& rpc_name) {
std::unique_lock<std::mutex> lock(this->mutex_);
barrier_cond_.wait(lock, [=] {
barrier_cond_.wait(lock, [this, &rpc_name] {
return (barrier_counter_[rpc_name] >= client_num_ || exit_flag_.load());
});
......@@ -53,19 +53,23 @@ void RPCServer::WaitBarrier(const std::string& rpc_name) {
void RPCServer::IncreaseBatchBarrier(const std::string rpc_name) {
VLOG(3) << "RPCServer begin IncreaseBatchBarrier " << rpc_name;
int b = 0;
{
std::unique_lock<std::mutex> lock(mutex_);
b = ++barrier_counter_[rpc_name];
}
VLOG(3) << "RPCServer IncreaseBatchBarrier " << rpc_name
<< ", barrier_count:" << b << ", fan_in" << client_num_;
std::unique_lock<std::mutex> lock(mutex_);
b = ++barrier_counter_[rpc_name];
if (b >= client_num_) {
lock.unlock();
barrier_cond_.notify_all();
lock.lock();
}
}
void RPCServer::DecreaseClientNum() {
{
std::unique_lock<std::mutex> lock(mutex_);
client_num_--;
}
barrier_cond_.notify_all();
}
void RPCServer::ResetBarrierCounter() {
VLOG(3) << "RPCServer ResetBarrierCounter ";
std::unique_lock<std::mutex> lock(mutex_);
......
......@@ -60,7 +60,7 @@ class RPCServer {
void SetCond(const std::string& rpc_name);
void WaitCond(const std::string& rpc_name);
void IncreaseBatchBarrier(const std::string rpc_name);
void DecreaseClientNum();
void ResetBarrierCounter();
protected:
......@@ -79,8 +79,7 @@ class RPCServer {
std::string bind_address_;
std::atomic<int> exit_flag_;
int selected_port_;
const int client_num_;
int client_num_;
std::unordered_map<std::string, RequestHandler*> rpc_call_map_;
std::unordered_map<std::string, int> rpc_thread_num_;
......
......@@ -98,11 +98,17 @@ void StartServer() {
framework::Executor exe(place);
platform::CPUDeviceContext ctx(place);
auto* block = AppendPrefetchBlcok(&program);
auto prepared = exe.Prepare(program, block->ID());
std::string in_var_name("ids");
std::vector<int> prefetch_block_ids{block->ID()};
auto prepared = exe.Prepare(program, prefetch_block_ids);
InitTensorsOnServer(&scope, &place, 10);
std::unordered_map<std::string,
std::shared_ptr<framework::ExecutorPrepareContext>>
prefetch_var_name_to_prepared;
prefetch_var_name_to_prepared[in_var_name] = prepared[0];
g_req_handler->SetProgram(&program);
g_req_handler->SetPrefetchPreparedCtx(std::move(prepared));
g_req_handler->SetPrefetchPreparedCtx(&prefetch_var_name_to_prepared);
g_req_handler->SetDevCtx(&ctx);
g_req_handler->SetScope(&scope);
g_req_handler->SetExecutor(&exe);
......
......@@ -66,40 +66,41 @@ class ElementwiseOpMaker : public framework::OpProtoAndCheckerMaker {
.SetDefault(-1)
.EqualGreaterThan(-1);
AddComment(string::Sprintf(R"DOC(
Limited Elementwise %s Operator.
Limited Elementwise %s Operator
The equation is:
$$%s$$
$X$ is a tensor of any dimension and the dimensions of tensor $Y$ must be
smaller than or equal to the dimensions of $X$.
- $X$: a tensor of any dimension.
- $Y$: a tensor whose dimensions must be less than or equal to the dimensions of $X$.
There are two cases for this operator:
1. The shape of $Y$ is same with $X$;
2. The shape of $Y$ is a congiguous subsequencet of $X$. The trailing dimensions
of size 1 for $Y$ will be ignored for the consideration of subsequence.
1. The shape of $Y$ is the same with $X$.
2. The shape of $Y$ is a continuous subsequence of $X$.
For case 2:
$Y$ will be broadcasted to match the shape of $X$ and axis should be
set to index of the start dimension to broadcast $Y$ onto $X$.
1. Broadcast $Y$ to match the shape of $X$, where $axis$ is the start dimension index
for broadcasting $Y$ onto $X$.
2. If $axis$ is -1 (default), $axis = rank(X) - rank(Y)$.
3. The trailing dimensions of size 1 for $Y$ will be ignored for the consideration of
subsequence, such as shape(Y) = (2, 1) => (2).
If axis is -1, it is treated as axis=rank(X)-rank(Y).
For example:
For example
.. code-block:: python
shape(X) = (2, 3, 4, 5), shape(Y) = (,)
shape(X) = (2, 3, 4, 5), shape(Y) = (5,)
shape(X) = (2, 3, 4, 5), shape(Y) = (4, 5)
shape(X) = (2, 3, 4, 5), shape(Y) = (4, 5), with axis=-1(default) or axis=2
shape(X) = (2, 3, 4, 5), shape(Y) = (3, 4), with axis=1
shape(X) = (2, 3, 4, 5), shape(Y) = (2), with axis=0
shape(X) = (2, 3, 4, 5), shape(Y) = (2, 1), with axis=0
Either of the inputs $X$ and $Y$ or none can carry the LoD (Level of Details)
information. However, the output only shares the LoD information with input $X$.
The inputs $X$ and $Y$ can carry the different LoD information.
But the output only shares the LoD information with the input $X$.
)DOC",
GetName(), GetEquation()));
......
......@@ -96,19 +96,22 @@ static int64_t GetTimestamp() {
return tp.tv_sec * 1000 + tp.tv_usec / 1000;
}
void ListenAndServOp::RunSyncLoop(framework::Executor *executor,
framework::ProgramDesc *program,
framework::Scope *recv_scope,
framework::BlockDesc *prefetch_block) const {
void ListenAndServOp::RunSyncLoop(
framework::Executor *executor, framework::ProgramDesc *program,
framework::Scope *recv_scope,
const std::vector<int> &prefetch_block_id_list) const {
size_t num_blocks = program->Size();
PADDLE_ENFORCE_GE(num_blocks, 2,
"server program should have at least 2 blocks");
std::vector<int> block_list;
for (size_t blkid = 1; blkid < num_blocks; ++blkid) {
block_list.push_back(blkid);
std::vector<int> optimize_block_id_list;
for (int blkid = 1; blkid < num_blocks; ++blkid) {
if (std::find(prefetch_block_id_list.begin(), prefetch_block_id_list.end(),
blkid) == prefetch_block_id_list.end()) {
optimize_block_id_list.push_back(blkid);
}
}
auto optimize_prepared = executor->Prepare(*program, block_list);
auto optimize_prepared = executor->Prepare(*program, optimize_block_id_list);
// Insert placeholder for block0 which holds current op itself.
optimize_prepared.insert(
optimize_prepared.begin(),
......@@ -135,16 +138,17 @@ void ListenAndServOp::RunSyncLoop(framework::Executor *executor,
std::vector<size_t> parallel_blkids;
parallel_blkids.push_back(1);
double ts = GetTimestamp();
for (size_t blkid = 2; blkid < num_blocks; ++blkid) {
if (blkid != static_cast<size_t>(prefetch_block->ID())) {
if (program->Block(blkid).Parent() != last_parent_blkid) {
ParallelExecuteBlocks(parallel_blkids, executor, optimize_prepared,
program, recv_scope);
parallel_blkids.clear();
last_parent_blkid = program->Block(blkid).Parent();
}
parallel_blkids.push_back(blkid);
for (size_t i = 1; i < optimize_block_id_list.size(); ++i) {
// skip the first optimize block because it is already in the
// parallel_blkids.
int blkid = optimize_block_id_list[i];
if (program->Block(blkid).Parent() != last_parent_blkid) {
ParallelExecuteBlocks(parallel_blkids, executor, optimize_prepared,
program, recv_scope);
parallel_blkids.clear();
last_parent_blkid = program->Block(blkid).Parent();
}
parallel_blkids.push_back(blkid);
}
ParallelExecuteBlocks(parallel_blkids, executor, optimize_prepared, program,
recv_scope);
......@@ -210,18 +214,19 @@ void ListenAndServOp::RunAsyncLoop(framework::Executor *executor,
} // while(true)
}
static void FillRequestCtx(detail::RequestHandler *h, framework::Scope *scope,
platform::DeviceContext *dev_ctx,
framework::Executor *executor,
framework::ProgramDesc *program,
framework::ExecutorPrepareContext *prefetch_ctx,
detail::RPCServer *rpc_server) {
static void FillRequestCtx(
detail::RequestHandler *h, framework::Scope *scope,
platform::DeviceContext *dev_ctx, framework::Executor *executor,
framework::ProgramDesc *program,
std::unordered_map<std::string,
std::shared_ptr<framework::ExecutorPrepareContext>>
*prefetch_ctx,
detail::RPCServer *rpc_server) {
h->SetScope(scope);
h->SetDevCtx(dev_ctx);
h->SetExecutor(executor);
h->SetProgram(program);
h->SetPrefetchPreparedCtx(
std::unique_ptr<framework::ExecutorPrepareContext>(prefetch_ctx));
h->SetPrefetchPreparedCtx(prefetch_ctx);
h->SetRPCServer(rpc_server);
}
......@@ -255,17 +260,42 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope,
request_prefetch_handler_.get());
auto *optimize_block = Attr<framework::BlockDesc *>(kOptimizeBlock);
auto *prefetch_block = Attr<framework::BlockDesc *>(kPrefetchBlock);
auto *program = optimize_block->Program();
framework::Executor executor(dev_place);
// prepare for prefetch
VLOG(3) << "prefetch block id is " << prefetch_block->ID();
auto prefetch_prepared = executor.Prepare(*program, prefetch_block->ID());
std::vector<int> prefetch_block_id_list;
std::unordered_map<int, std::string> block_id_to_prefetch_var_name;
auto prefetch_var_name_to_block_id_str =
Attr<std::vector<std::string>>(kPrefetchVarNameToBlockId);
for (const auto &prefetch_var_name_and_id :
prefetch_var_name_to_block_id_str) {
std::vector<std::string> pieces;
split(prefetch_var_name_and_id, ':', &pieces);
VLOG(3) << "after split, prefetch_var = " << pieces[0]
<< ", id=" << pieces[1];
PADDLE_ENFORCE_EQ(pieces.size(), 2);
int block_id = std::stoi(pieces[1]);
prefetch_block_id_list.push_back(block_id);
block_id_to_prefetch_var_name[block_id] = pieces[0];
}
auto prefetch_prepared = executor.Prepare(*program, prefetch_block_id_list);
std::unordered_map<std::string,
std::shared_ptr<framework::ExecutorPrepareContext>>
prefetch_var_name_to_prepared_ctx;
for (size_t i = 0; i < prefetch_block_id_list.size(); ++i) {
auto block_id = prefetch_block_id_list[i];
auto prefetch_var_name = block_id_to_prefetch_var_name[block_id];
prefetch_var_name_to_prepared_ctx[prefetch_var_name] = prefetch_prepared[i];
}
auto f = std::bind(FillRequestCtx, std::placeholders::_1, &recv_scope,
&dev_ctx, &executor, program, prefetch_prepared.release(),
rpc_service_.get());
&dev_ctx, &executor, program,
&prefetch_var_name_to_prepared_ctx, rpc_service_.get());
f(request_send_handler_.get());
f(request_get_handler_.get());
......@@ -283,7 +313,7 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope,
// Write to a file of server selected port for python use.
SavePort();
if (sync_mode) {
RunSyncLoop(&executor, program, &recv_scope, prefetch_block);
RunSyncLoop(&executor, program, &recv_scope, prefetch_block_id_list);
} else {
RunAsyncLoop(&executor, program);
}
......@@ -309,8 +339,9 @@ class ListenAndServOpMaker : public framework::OpProtoAndCheckerMaker {
AddAttr<bool>("sync_mode", "if works at sync_mode or not").SetDefault(true);
AddAttr<framework::BlockDesc *>(kOptimizeBlock,
"BlockID to run on server side.");
AddAttr<framework::BlockDesc *>(kPrefetchBlock,
"prefetch block to run on server side.");
AddAttr<std::vector<std::string>>(kPrefetchVarNameToBlockId,
"prefetch blocks to run on server side.")
.SetDefault({});
AddAttr<int>("Fanin", "How many clients send to this server.")
.SetDefault(1);
}
......
......@@ -18,6 +18,7 @@ limitations under the License. */
#include <atomic>
#include <set>
#include <string>
#include <vector>
#include "paddle/fluid/framework/executor.h"
#include "paddle/fluid/framework/lod_tensor.h"
......@@ -30,7 +31,7 @@ namespace paddle {
namespace operators {
constexpr char kOptimizeBlock[] = "OptimizeBlock";
constexpr char kPrefetchBlock[] = "PrefetchBlock";
constexpr char kPrefetchVarNameToBlockId[] = "prefetch_var_name_to_block_id";
void RunServer(std::shared_ptr<detail::RPCServer> service);
......@@ -46,7 +47,7 @@ class ListenAndServOp : public framework::OperatorBase {
void RunSyncLoop(framework::Executor* executor,
framework::ProgramDesc* program,
framework::Scope* recv_scope,
framework::BlockDesc* prefetch_block) const;
const std::vector<int>& prefetch_block_id_list) const;
void RunAsyncLoop(framework::Executor* executor,
framework::ProgramDesc* program) const;
......
......@@ -20,7 +20,7 @@ namespace reader {
class BatchReader : public framework::DecoratedReader {
public:
BatchReader(ReaderBase* reader, int batch_size)
BatchReader(const std::shared_ptr<ReaderBase>& reader, int batch_size)
: DecoratedReader(reader), batch_size_(batch_size) {
buffer_.reserve(batch_size_);
}
......
......@@ -22,7 +22,8 @@ namespace reader {
class CustomReader : public framework::DecoratedReader {
public:
CustomReader(ReaderBase* reader, const framework::BlockDesc& sub_block,
CustomReader(const std::shared_ptr<ReaderBase>& reader,
const framework::BlockDesc& sub_block,
const std::vector<std::string>& source_var_names,
const std::vector<std::string>& sink_var_names)
: DecoratedReader(reader),
......
......@@ -34,7 +34,8 @@ static constexpr size_t kChannelSize = 1; // kCacheSize - 2
class DoubleBufferReader : public framework::DecoratedReader {
public:
explicit DoubleBufferReader(
ReaderBase* reader, platform::Place target_place = platform::CPUPlace())
const std::shared_ptr<ReaderBase>& reader,
platform::Place target_place = platform::CPUPlace())
: DecoratedReader(reader), place_(target_place) {
cpu_tensor_cache_.resize(kCacheSize);
gpu_tensor_cache_.resize(kCacheSize);
......
......@@ -21,7 +21,7 @@ namespace reader {
class MultiPassReader : public framework::DecoratedReader {
public:
MultiPassReader(ReaderBase* reader, int pass_num)
MultiPassReader(const std::shared_ptr<ReaderBase>& reader, int pass_num)
: DecoratedReader(reader), pass_num_(pass_num), pass_count_(0) {}
void ReadNext(std::vector<framework::LoDTensor>* out) override {
......
......@@ -23,7 +23,8 @@ namespace reader {
class ShuffleReader : public framework::DecoratedReader {
public:
ShuffleReader(ReaderBase* reader, size_t buffer_size, size_t seed = 0)
ShuffleReader(const std::shared_ptr<ReaderBase>& reader, size_t buffer_size,
size_t seed = 0)
: DecoratedReader(reader), buffer_size_(buffer_size), seed_(seed) {
VLOG(10) << "Create shuffle reader of " << reader_;
if (seed_ == 0) {
......
......@@ -21,7 +21,8 @@ namespace reader {
class ThreadedReader : public framework::DecoratedReader {
public:
explicit ThreadedReader(ReaderBase* reader) : DecoratedReader(reader) {}
explicit ThreadedReader(const std::shared_ptr<ReaderBase>& reader)
: DecoratedReader(reader) {}
void ReadNext(std::vector<framework::LoDTensor>* out) override {
std::lock_guard<std::mutex> lock(mutex_);
......
......@@ -21,12 +21,17 @@ limitations under the License. */
#include <unistd.h>
#endif
#include <algorithm>
#include "gflags/gflags.h"
DEFINE_double(fraction_of_cpu_memory_to_use, 1,
"Default use 100% of CPU memory for PaddlePaddle,"
"reserve the rest for page tables, etc");
DEFINE_uint64(
initial_cpu_memory_in_mb, 500,
"Default initial 500MB of CPU memory for PaddlePaddle, in MD unit.");
DEFINE_double(
fraction_of_cuda_pinned_memory_to_use, 0.5,
"Default use 50% of CPU memory as the pinned_memory for PaddlePaddle,"
......@@ -54,7 +59,9 @@ inline size_t CpuTotalPhysicalMemory() {
size_t CpuMaxAllocSize() {
// For distributed systems, it requires configuring and limiting
// the fraction of memory to use.
return FLAGS_fraction_of_cpu_memory_to_use * CpuTotalPhysicalMemory();
return std::min(static_cast<size_t>(FLAGS_fraction_of_cpu_memory_to_use *
CpuTotalPhysicalMemory()),
FLAGS_initial_cpu_memory_in_mb * 1 << 20);
}
size_t CpuMinChunkSize() {
......
......@@ -413,6 +413,9 @@ All parameter, weight, gradient are variables in Paddle.
py::class_<framework::Executor>(m, "Executor")
.def(py::init<const platform::Place &>())
#ifdef PADDLE_WITH_DISTRIBUTE
.def("complete", &Executor::Complete)
#endif
.def("run",
(void (Executor::*)(const ProgramDesc &, Scope *, int, bool, bool)) &
Executor::Run);
......
此差异已折叠。
......@@ -515,35 +515,38 @@ class DistributeTranspiler:
grad_to_block_id, None)
# process distributed lookup_table
prefetch_block = None
prefetch_var_name_to_block_id = []
if self.has_distributed_lookup_table:
pserver_index = self.pserver_endpoints.index(endpoint)
table_opt_block = self._create_table_optimize_block(
pserver_index, pserver_program, pre_block_idx, grad_to_block_id)
prefetch_block = self._create_prefetch_block(
prefetch_var_name_to_block_id = self._create_prefetch_block(
pserver_index, pserver_program, table_opt_block)
# NOTE: if has_distributed_lookup_table is False, then prefetch_block will
# not be executed, so it's safe to use optimize_block to hold the place
if self.has_distributed_lookup_table:
assert prefetch_block is not None
assert len(prefetch_var_name_to_block_id) > 0
else:
assert prefetch_block is None
prefetch_block = pserver_program.global_block()
assert len(prefetch_var_name_to_block_id) == 0
attrs = {
"OptimizeBlock": pserver_program.block(1),
"endpoint": endpoint,
"Fanin": self.trainer_num,
"sync_mode": self.sync_mode,
"grad_to_block_id": grad_to_block_id
}
if len(prefetch_var_name_to_block_id) > 0:
attrs['prefetch_var_name_to_block_id'] \
= prefetch_var_name_to_block_id
# step5 append the listen_and_serv op
pserver_program.global_block().append_op(
type="listen_and_serv",
inputs={'X': recv_inputs},
outputs={},
attrs={
"OptimizeBlock": pserver_program.block(1),
"endpoint": endpoint,
"Fanin": self.trainer_num,
"PrefetchBlock": prefetch_block,
"sync_mode": self.sync_mode,
"grad_to_block_id": grad_to_block_id
})
attrs=attrs)
pserver_program.sync_with_cpp()
return pserver_program
......@@ -608,8 +611,15 @@ class DistributeTranspiler:
def _replace_lookup_table_op_with_prefetch(self, program,
pserver_endpoints):
# 1. replace lookup_table_op with split_ids_op -> prefetch_op -> sum_op
self.prefetch_input_vars = None
self.prefetch_output_vars = None
# self.all_prefetch_input_vars =
# [[var0_prefetch_in_pserver0, var0_prefetch_in_pserver1]
# [var1_prefetch_in_pserver0, var1_prefetch_in_pserver1]]
self.all_prefetch_input_vars = []
# self.all_prefetch_input_vars =
# [[var0_prefetch_in_pserver0, var0_prefetch_in_pserver1]
# [var1_prefetch_in_pserver0, var1_prefetch_in_pserver1]]
self.all_prefetch_output_vars = []
continue_search_lookup_table_op = True
while continue_search_lookup_table_op:
......@@ -623,18 +633,19 @@ class DistributeTranspiler:
ids_name = op.input("Ids")
out_name = op.output("Out")
if self.prefetch_input_vars is None:
ids_var = program.global_block().vars[ids_name[0]]
self.prefetch_input_vars = self.create_splited_vars(
source_var=ids_var,
block=program.global_block(),
tag="_prefetch_in_")
if self.prefetch_output_vars is None:
out_var = program.global_block().vars[out_name[0]]
self.prefetch_output_vars = self.create_splited_vars(
source_var=out_var,
block=program.global_block(),
tag="_prefetch_out_")
ids_var = program.global_block().vars[ids_name[0]]
prefetch_input_vars = self.create_splited_vars(
source_var=ids_var,
block=program.global_block(),
tag="_prefetch_in_")
self.all_prefetch_input_vars.append(prefetch_input_vars)
out_var = program.global_block().vars[out_name[0]]
prefetch_output_vars = self.create_splited_vars(
source_var=out_var,
block=program.global_block(),
tag="_prefetch_out_")
self.all_prefetch_output_vars.append(prefetch_output_vars)
# insert split_ids_op
program.global_block().insert_op(
......@@ -646,14 +657,14 @@ class DistributeTranspiler:
for varname in ids_name
]
},
outputs={"Out": self.prefetch_input_vars})
outputs={"Out": prefetch_input_vars})
# insert prefetch_op
program.global_block().insert_op(
index=op_index + 1,
type="prefetch",
inputs={'X': self.prefetch_input_vars},
outputs={"Out": self.prefetch_output_vars},
inputs={'X': prefetch_input_vars},
outputs={"Out": prefetch_output_vars},
attrs={
"epmap": pserver_endpoints,
RPC_OP_ROLE_ATTR_NAME: RPC_OP_ROLE_ATTR_VALUE
......@@ -663,7 +674,7 @@ class DistributeTranspiler:
program.global_block().insert_op(
index=op_index + 2,
type="concat",
inputs={'X': self.prefetch_output_vars},
inputs={'X': prefetch_output_vars},
outputs={
"Out": [
program.global_block().vars[varname]
......@@ -709,30 +720,34 @@ class DistributeTranspiler:
optimize_block):
# STEP: create prefetch block
table_var = pserver_program.global_block().vars[self.table_name]
prefetch_block = pserver_program.create_block(optimize_block.idx)
trainer_ids = self.prefetch_input_vars[pserver_index]
pserver_ids = pserver_program.global_block().create_var(
name=trainer_ids.name,
type=trainer_ids.type,
shape=trainer_ids.shape,
dtype=trainer_ids.dtype)
trainer_out = self.prefetch_output_vars[pserver_index]
pserver_out = pserver_program.global_block().create_var(
name=trainer_out.name,
type=trainer_out.type,
shape=trainer_out.shape,
dtype=trainer_out.dtype)
prefetch_block.append_op(
type="lookup_sparse_table",
inputs={'Ids': pserver_ids,
"W": table_var},
outputs={"Out": pserver_out},
attrs={
"is_sparse": True, # has no effect on lookup_table op
"is_distributed": True,
"padding_idx": -1
})
return prefetch_block
prefetch_var_name_to_block_id = []
for index in range(len(self.all_prefetch_input_vars)):
prefetch_block = pserver_program.create_block(optimize_block.idx)
trainer_ids = self.all_prefetch_input_vars[index][pserver_index]
pserver_ids = pserver_program.global_block().create_var(
name=trainer_ids.name,
type=trainer_ids.type,
shape=trainer_ids.shape,
dtype=trainer_ids.dtype)
trainer_out = self.all_prefetch_output_vars[index][pserver_index]
pserver_out = pserver_program.global_block().create_var(
name=trainer_out.name,
type=trainer_out.type,
shape=trainer_out.shape,
dtype=trainer_out.dtype)
prefetch_block.append_op(
type="lookup_sparse_table",
inputs={'Ids': pserver_ids,
"W": table_var},
outputs={"Out": pserver_out},
attrs={
"is_sparse": True, # has no effect on lookup_table op
"is_distributed": True,
"padding_idx": -1
})
prefetch_var_name_to_block_id.append(trainer_ids.name + ":" + str(
prefetch_block.idx))
return prefetch_var_name_to_block_id
def _create_table_optimize_block(self, pserver_index, pserver_program,
pre_block_idx, grad_to_block_id):
......
......@@ -126,9 +126,10 @@ class DocstringChecker(BaseChecker):
'W9002':
('Doc string does not end with "." period', symbol + "-end-with",
'Used when a doc string does not end with a period'),
'W9003': ('All args with their types must be mentioned in doc string',
symbol + "-with-all-args",
'Used when not all arguments are in the doc string '),
'W9003':
('All args with their types must be mentioned in doc string %s',
symbol + "-with-all-args",
'Used when not all arguments are in the doc string '),
'W9005': ('Missing docstring or docstring is too short',
symbol + "-missing", 'Add docstring longer >=10'),
'W9006': ('Docstring indent error, use 4 space for indent',
......@@ -178,6 +179,8 @@ class DocstringChecker(BaseChecker):
self.indent_style(node)
def missing_doc_string(self, node):
if node.name.startswith("__") or node.name.startswith("_"):
return True
if node.tolineno - node.fromlineno <= 10:
return True
......@@ -199,12 +202,16 @@ class DocstringChecker(BaseChecker):
doc = node.doc
lines = doc.splitlines()
line_num = 0
for l in lines:
if line_num == 0:
continue
cur_indent = len(l) - len(l.lstrip())
if cur_indent % indent != 0:
self.add_message('W9006', node=node, line=node.fromlineno)
return False
line_num += 1
return True
......@@ -320,15 +327,19 @@ class DocstringChecker(BaseChecker):
return True
parsed_args = doc.args
args_not_documented = set(args) - set(parsed_args)
if len(args) > 0 and len(parsed_args) <= 0:
print "debug:parsed args: ", parsed_args
self.add_message('W9003', node=node, line=node.fromlineno)
self.add_message(
'W9003',
node=node,
line=node.fromlineno,
args=list(args_not_documented))
return False
for t in args:
if t not in parsed_args:
print t, " with (type) not in ", parsed_args
self.add_message('W9003', node=node, line=node.fromlineno)
self.add_message(
'W9003', node=node, line=node.fromlineno, args=[t, ])
return False
return True
......@@ -7,13 +7,13 @@ DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" && pwd )"
export PYTHONPATH=$DIR:$PYTHONPATH
# The trick to remove deleted files: https://stackoverflow.com/a/2413151
for file in $(git diff --cached --name-status | awk '$1 != "D" {print $2}'); do
for file in $(git diff --name-status | awk '$1 != "D" {print $2}'); do
pylint --disable=all --load-plugins=docstring_checker \
--enable=doc-string-one-line,doc-string-end-with,doc-string-with-all-args,doc-string-triple-quotes,doc-string-missing,doc-string-indent-error,doc-string-with-returns,doc-string-with-raises $file;
TOTAL_ERRORS=$(expr $TOTAL_ERRORS + $?);
done
#exit $TOTAL_ERRORS
exit $TOTAL_ERRORS
#For now, just warning:
exit 0
#exit 0
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册