提交 13ecb5e5 编写于 作者: Q qiaolongfei

Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into...

Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into lookup_table_support_SelectedRows_as_parameter
......@@ -17,8 +17,6 @@ limitations under the License. */
#include <string>
#include <vector>
#include "paddle/fluid/framework/threadpool.h"
#ifdef PADDLE_WITH_CUDA
#include "paddle/fluid/platform/nccl_helper.h"
#endif
......
......@@ -2,7 +2,7 @@ if(WITH_DISTRIBUTE)
grpc_library(sendrecvop_grpc SRCS bytebuffer_stream.cc sendrecvop_utils.cc grpc_client.cc
grpc_server.cc variable_response.cc PROTO send_recv.proto DEPS lod_tensor selected_rows)
set(DISTRIBUTE_COMPILE_FLAGS "-Wno-non-virtual-dtor -Wno-error=non-virtual-dtor -Wno-error=delete-non-virtual-dtor")
set_source_files_properties(serde_test.cc grpc_server_test PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
set_source_files_properties(serde_test.cc grpc_server_test.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
cc_test(serde_test SRCS serde_test.cc variable_response.cc DEPS grpc++_unsecure grpc_unsecure gpr
cares zlib protobuf sendrecvop_grpc)
cc_test(grpc_server_test SRCS grpc_server_test.cc DEPS sendrecvop_grpc grpc++_unsecure grpc_unsecure gpr cares zlib protobuf)
......
......@@ -12,8 +12,10 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "grpc_client.h"
#include <sys/time.h>
#include "paddle/fluid/operators/detail/grpc_client.h"
#include <limits>
#include "paddle/fluid/framework/threadpool.h"
namespace paddle {
......@@ -52,7 +54,7 @@ bool RPCClient::AsyncSendVariable(const std::string& ep,
auto call = s->stub_g_.PrepareUnaryCall(
s->context_.get(), "/sendrecv.SendRecvService/SendVariable", req, &cq_);
call->StartCall();
call->Finish(&s->reply_, &s->status_, (void*)s);
call->Finish(&s->reply_, &s->status_, static_cast<void*>(s));
});
req_count_++;
......@@ -70,8 +72,7 @@ void ProcGetResponse(const VarHandle& var_h,
template <typename T>
void RequestToByteBuffer(const T& proto, ::grpc::ByteBuffer* result) {
::grpc::Slice slice(proto.ByteSizeLong());
proto.SerializeWithCachedSizesToArray(
const_cast<uint8_t*>(reinterpret_cast<const uint8_t*>(slice.begin())));
proto.SerializeWithCachedSizesToArray(const_cast<uint8_t*>(slice.begin()));
::grpc::ByteBuffer tmp(&slice, 1);
result->Swap(&tmp);
}
......@@ -109,7 +110,7 @@ bool RPCClient::AsyncGetVariable(const std::string& ep,
auto call = s->stub_g_.PrepareUnaryCall(
s->context_.get(), "/sendrecv.SendRecvService/GetVariable", buf, &cq_);
call->StartCall();
call->Finish(&s->reply_, &s->status_, (void*)s);
call->Finish(&s->reply_, &s->status_, static_cast<void*>(s));
});
req_count_++;
......@@ -153,7 +154,7 @@ bool RPCClient::AsyncPrefetchVariable(const std::string& ep,
s->context_.get(), "/sendrecv.SendRecvService/PrefetchVariable", req,
&cq_);
call->StartCall();
call->Finish(&s->reply_, &s->status_, (void*)s);
call->Finish(&s->reply_, &s->status_, static_cast<void*>(s));
});
req_count_++;
......@@ -169,7 +170,7 @@ void RPCClient::AsyncSendBatchBarrier(const std::string& ep, int64_t time_out) {
sendrecv::VariableMessage req;
req.set_varname(BATCH_BARRIER_MESSAGE);
auto rpc = s->stub_->AsyncSendVariable(s->context_.get(), req, &cq_);
rpc->Finish(&s->reply_, &s->status_, (void*)s);
rpc->Finish(&s->reply_, &s->status_, static_cast<void*>(s));
req_count_++;
}
......@@ -181,7 +182,7 @@ void RPCClient::AsyncSendFetchBarrier(const std::string& ep, int64_t time_out) {
sendrecv::VariableMessage req;
req.set_varname(FETCH_BARRIER_MESSAGE);
auto rpc = s->stub_->AsyncGetVariable(s->context_.get(), req, &cq_);
rpc->Finish(&s->reply_, &s->status_, (void*)s);
rpc->Finish(&s->reply_, &s->status_, static_cast<void*>(s));
req_count_++;
}
......
......@@ -14,6 +14,9 @@ limitations under the License. */
#include "paddle/fluid/operators/detail/grpc_server.h"
#include <limits>
#include <string>
using ::grpc::ServerAsyncResponseWriter;
namespace paddle {
......@@ -156,6 +159,8 @@ class RequestPrefetch final : public RequestBase {
::grpc::ByteBuffer reply;
// TODO(Yancey1989): execute the Block which containers prefetch ops
VLOG(3) << "RequestPrefetch Process in";
responder_.Finish(reply, ::grpc::Status::OK, this);
status_ = FINISH;
}
......@@ -221,6 +226,7 @@ void AsyncGRPCServer::ShutdownQueue() {
std::unique_lock<std::mutex> lock(cq_mutex_);
cq_send_->Shutdown();
cq_get_->Shutdown();
cq_prefetch_->Shutdown();
}
// This URL explains why shutdown is complicate:
......@@ -233,6 +239,7 @@ void AsyncGRPCServer::ShutDown() {
void AsyncGRPCServer::TryToRegisterNewSendOne() {
std::unique_lock<std::mutex> lock(cq_mutex_);
if (is_shut_down_) {
VLOG(3) << "shutdown, do not TryToRegisterNewSendOne";
return;
}
RequestSend* send = new RequestSend(&service_, cq_send_.get(), scope_,
......@@ -243,6 +250,7 @@ void AsyncGRPCServer::TryToRegisterNewSendOne() {
void AsyncGRPCServer::TryToRegisterNewGetOne() {
std::unique_lock<std::mutex> lock(cq_mutex_);
if (is_shut_down_) {
VLOG(3) << "shutdown, do not TryToRegisterNewGetOne";
return;
}
RequestGet* get = new RequestGet(&service_, cq_get_.get(), scope_, dev_ctx_,
......@@ -253,6 +261,7 @@ void AsyncGRPCServer::TryToRegisterNewGetOne() {
void AsyncGRPCServer::TryToRegisterNewPrefetchOne() {
std::unique_lock<std::mutex> lock(cq_mutex_);
if (is_shut_down_) {
VLOG(3) << "shutdown, do not TryToRegisterNewPrefetchOne";
return;
}
RequestPrefetch* prefetch =
......@@ -270,25 +279,28 @@ void AsyncGRPCServer::HandleRequest(::grpc::ServerCompletionQueue* cq,
void* tag = NULL;
bool ok = false;
while (true) {
VLOG(3) << "HandleRequest for " << cq_name << " while in";
if (!cq->Next(&tag, &ok)) {
LOG(INFO) << cq_name << " CompletionQueue shutdown!";
break;
}
VLOG(3) << "HandleRequest for " << cq_name << " while after Next";
PADDLE_ENFORCE(tag);
// FIXME(typhoonzero): de-couple the barriers with recv_op
if (!is_shut_down_ && cq_name == "cq_get") WaitCond(1);
if (!is_shut_down_ && cq_name == "cq_send") WaitCond(0);
RequestBase* base = (RequestBase*)tag;
RequestBase* base = reinterpret_cast<RequestBase*>(tag);
// reference:
// https://github.com/tensorflow/tensorflow/issues/5596
// https://groups.google.com/forum/#!topic/grpc-io/xftlRy-IQwM
// https://groups.google.com/forum/#!topic/grpc-io/ywATt88Ef_I
if (!ok) {
LOG(WARNING) << cq_name << " recv no regular event:argument name"
<< base->GetReqName();
LOG(WARNING) << cq_name << " recv no regular event:argument name["
<< base->GetReqName() << "]";
TryToRegisterNewOne();
delete base;
continue;
......
......@@ -15,7 +15,8 @@ limitations under the License. */
#pragma once
#include <grpc++/grpc++.h>
#include <thread>
#include <string>
#include <utility>
#include "paddle/fluid/framework/executor.h"
#include "paddle/fluid/framework/lod_tensor.h"
......@@ -93,6 +94,7 @@ class AsyncGRPCServer final {
// received variable from RPC, operators fetch variable from this queue.
SimpleBlockQueue<MessageWithName> var_get_queue_;
// client send variable to this queue.
ReceivedQueue var_recv_queue_;
// condition of the sub program
......
......@@ -28,6 +28,7 @@ std::unique_ptr<detail::AsyncGRPCServer> rpc_service_;
void StartServer(const std::string& endpoint) {
rpc_service_.reset(new detail::AsyncGRPCServer(endpoint));
rpc_service_->RunSyncUpdate();
}
TEST(PREFETCH, CPU) {
......@@ -39,13 +40,23 @@ TEST(PREFETCH, CPU) {
platform::CPUPlace place;
platform::CPUDeviceContext ctx(place);
// create var on local scope
std::string var_name("tmp_0");
auto var = scope.Var(var_name);
auto tensor = var->GetMutable<framework::LoDTensor>();
tensor->Resize({10, 10});
std::string in_var_name("in");
std::string out_var_name("out");
auto* in_var = scope.Var(in_var_name);
auto* in_tensor = in_var->GetMutable<framework::LoDTensor>();
in_tensor->Resize({10, 10});
VLOG(3) << "before mutable_data";
in_tensor->mutable_data<int>(place);
scope.Var(out_var_name);
VLOG(3) << "before fetch";
detail::RPCClient client;
client.AsyncPrefetchVariable("127.0.0.1:8889", ctx, scope, var_name, "");
client.AsyncPrefetchVariable("127.0.0.1:8889", ctx, scope, in_var_name,
out_var_name);
client.Wait();
rpc_service_->ShutDown();
server_thread.join();
rpc_service_.reset(nullptr);
}
......@@ -80,7 +80,7 @@ enum class GrpcMethod {
};
static const int kGrpcNumMethods =
static_cast<int>(GrpcMethod::kGetVariable) + 1;
static_cast<int>(GrpcMethod::kPrefetchVariable) + 1;
inline const char* GrpcMethodName(GrpcMethod id) {
switch (id) {
......@@ -89,7 +89,7 @@ inline const char* GrpcMethodName(GrpcMethod id) {
case GrpcMethod::kGetVariable:
return "/sendrecv.SendRecvService/GetVariable";
case GrpcMethod::kPrefetchVariable:
return "/sendrecv.SendREcvService/PrefetchVariable";
return "/sendrecv.SendRecvService/PrefetchVariable";
}
// Shouldn't be reached.
......@@ -117,5 +117,5 @@ class GrpcService final {
};
} // namespace detail
} // namespace operator
} // namespace operators
} // namespace paddle
......@@ -13,22 +13,13 @@ See the License for the specific language governing permissions and
limitations under the License. */
#include <stdint.h>
#include <sys/stat.h>
#include <ostream>
#include <thread>
#include <unistd.h>
#include "paddle/fluid/framework/executor.h"
#include "paddle/fluid/framework/framework.pb.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/proto_desc.h"
#include "paddle/fluid/framework/threadpool.h"
#include "paddle/fluid/operators/detail/grpc_server.h"
#include "paddle/fluid/operators/detail/sendrecvop_utils.h"
#include "paddle/fluid/operators/detail/simple_block_queue.h"
#include "paddle/fluid/string/printf.h"
namespace paddle {
namespace operators {
......@@ -111,6 +102,11 @@ class ListenAndServOp : public framework::OperatorBase {
framework::Executor executor(dev_place);
// TODO(qiao) set proper fields for table lookup and update
rpc_service_->SetExecutor(&executor);
rpc_service_->SetPrefetchBlkdId(0);
rpc_service_->SetProgram(program);
// TODO(typhoonzero): change this to a while_op for every cluster-batch.
bool exit_flag = false;
// Record received sparse variables, so that
......@@ -173,7 +169,8 @@ class ListenAndServOp : public framework::OperatorBase {
}
ParallelExecuteBlocks(parallel_blkids, &executor, program, &recv_scope);
VLOG(2) << "run all blocks spent (ms) " << detail::GetTimestamp() - ts;
VLOG(3) << "run all blocks spent " << detail::GetTimestamp() - ts
<< "(ms)";
// Reset the received sparse variables, the sum operator would not
// sum the input sparse variables which rows is empty at the next
......
......@@ -20,12 +20,29 @@ namespace paddle {
namespace operators {
namespace reader {
static constexpr size_t kDoubleBufferSize = 2;
// 'Double buffer' means we shall maintain two batches of input data at the same
// time. So the kCacheSize shoul be at least 2.
static constexpr size_t kCacheSize = 2;
// There will be two bacthes out of the channel during training:
// 1. the one waiting to be sent to the channel
// 2. the one just be received from the channel, which is also being used by
// subsequent operators.
// So the channel size should be kChacheSize - 2
static constexpr size_t kChannelSize = 0; // kCacheSize - 2
class DoubleBufferReader : public framework::DecoratedReader {
public:
struct Item {
Item() : ctx_(nullptr) {}
Item(Item&& b) {
payloads_ = std::move(b.payloads_);
ctx_ = std::move(b.ctx_);
}
Item& operator=(Item&& b) {
payloads_ = std::move(b.payloads_);
ctx_ = std::move(b.ctx_);
return *this;
}
std::vector<framework::LoDTensor> payloads_;
platform::DeviceContext* ctx_;
......@@ -34,42 +51,44 @@ class DoubleBufferReader : public framework::DecoratedReader {
explicit DoubleBufferReader(
ReaderBase* reader, platform::Place target_place = platform::CPUPlace())
: DecoratedReader(reader), place_(target_place) {
for (size_t i = 0; i < kDoubleBufferSize; ++i) {
if (platform::is_gpu_place(place_)) {
#ifdef PADDLE_WITH_CUDA
for (size_t i = 0; i < kCacheSize; ++i) {
if (platform::is_gpu_place(place_)) {
ctxs_.emplace_back(new platform::CUDADeviceContext(
boost::get<platform::CUDAPlace>(place_)));
#endif
}
}
start_thread();
}
void start_thread() {
buffer_ = framework::MakeChannel<Item>(kDoubleBufferSize);
prefetcher_ = std::thread([this] { PrefetchThreadFunc(); });
#endif
StartPrefetcher();
}
bool HasNext() const override;
void ReadNext(std::vector<framework::LoDTensor>* out) override;
void ReInit() override;
~DoubleBufferReader() {
buffer_->Close();
prefetcher_.join();
delete buffer_;
~DoubleBufferReader() { EndPrefetcher(); }
private:
void StartPrefetcher() {
channel_ = framework::MakeChannel<Item>(kChannelSize);
prefetcher_ = std::thread([this] { PrefetchThreadFunc(); });
}
bool HasNext() const override;
void EndPrefetcher() {
channel_->Close();
if (prefetcher_.joinable()) {
prefetcher_.join();
}
delete channel_;
channel_ = nullptr;
}
private:
void PrefetchThreadFunc();
std::thread prefetcher_;
framework::Channel<Item>* buffer_;
framework::Channel<Item>* channel_;
platform::Place place_;
std::vector<std::unique_ptr<platform::DeviceContext>> ctxs_;
mutable Item local_buffer_;
};
class CreateDoubleBufferReaderOp : public framework::OperatorBase {
......@@ -123,70 +142,70 @@ class CreateDoubleBufferReaderOpMaker : public DecoratedReaderMakerBase {
}
};
bool DoubleBufferReader::HasNext() const {
while (!channel_->IsClosed() && !channel_->CanReceive()) {
}
return channel_->CanReceive();
}
void DoubleBufferReader::ReadNext(std::vector<framework::LoDTensor>* out) {
if (!HasNext()) {
PADDLE_THROW("There is no next data!");
}
if (local_buffer_.payloads_.empty()) {
buffer_->Receive(&local_buffer_);
}
*out = local_buffer_.payloads_;
local_buffer_.payloads_.clear();
if (local_buffer_.ctx_) {
local_buffer_.ctx_->Wait();
Item batch;
channel_->Receive(&batch);
*out = batch.payloads_;
if (batch.ctx_) {
batch.ctx_->Wait();
}
}
void DoubleBufferReader::ReInit() {
reader_->ReInit();
buffer_->Close();
prefetcher_.join();
delete buffer_;
start_thread();
EndPrefetcher();
StartPrefetcher();
}
void DoubleBufferReader::PrefetchThreadFunc() {
VLOG(5) << "A new prefetch thread starts.";
size_t gpu_ctx_offset = 0;
std::vector<std::vector<framework::LoDTensor>> cpu_tensor_cache(kCacheSize);
std::vector<std::vector<framework::LoDTensor>> gpu_tensor_cache(kCacheSize);
size_t cached_tensor_id = 0;
while (reader_->HasNext()) {
Item batch;
reader_->ReadNext(&batch.payloads_);
auto& cpu_batch = cpu_tensor_cache[cached_tensor_id];
reader_->ReadNext(&cpu_batch);
if (platform::is_gpu_place(place_)) {
std::vector<framework::LoDTensor> gpu_batch;
auto& gpu_ctx = this->ctxs_[gpu_ctx_offset++];
gpu_ctx_offset %= this->ctxs_.size();
gpu_batch.resize(batch.payloads_.size());
for (size_t i = 0; i < batch.payloads_.size(); ++i) {
framework::TensorCopy(batch.payloads_[i], place_, *gpu_ctx,
&gpu_batch[i]);
gpu_batch[i].set_lod(batch.payloads_[i].lod());
}
batch.ctx_ = gpu_ctx.get();
std::swap(gpu_batch, batch.payloads_);
auto& gpu_batch = gpu_tensor_cache[cached_tensor_id];
auto* gpu_ctx = ctxs_[cached_tensor_id].get();
gpu_batch.resize(cpu_batch.size());
for (size_t i = 0; i < cpu_batch.size(); ++i) {
framework::TensorCopy(cpu_batch[i], place_, *gpu_ctx, &gpu_batch[i]);
gpu_batch[i].set_lod(cpu_batch[i].lod());
}
batch.payloads_ = gpu_batch;
batch.ctx_ = gpu_ctx;
} else {
// CPUPlace
batch.payloads_ = cpu_batch;
}
++cached_tensor_id;
cached_tensor_id %= kCacheSize;
try {
buffer_->Send(&batch);
channel_->Send(&batch);
} catch (paddle::platform::EnforceNotMet e) {
VLOG(5) << "WARNING: The double buffer channel has been closed. The "
"prefetch thread will terminate.";
break;
}
}
buffer_->Close();
channel_->Close();
VLOG(5) << "Prefetch thread terminates.";
}
bool DoubleBufferReader::HasNext() const {
if (local_buffer_.payloads_.empty()) {
bool ok = buffer_->Receive(&local_buffer_);
return ok;
} else {
return true;
}
}
} // namespace reader
} // namespace operators
} // namespace paddle
......
......@@ -276,12 +276,15 @@ class DistributeTranspiler:
suff_idx = v.name.find(".trainer_")
if suff_idx >= 0:
orig_var_name = v.name[:suff_idx]
pserver_program.global_block().create_var(
else:
orig_var_name = v.name
single_trainer_var = pserver_program.global_block().create_var(
name=orig_var_name,
persistable=True,
type=v.type,
dtype=v.dtype,
shape=v.shape)
if self.trainers > 1:
for trainer_id in xrange(self.trainers):
var = pserver_program.global_block().create_var(
name="%s.trainer_%d" % (orig_var_name, trainer_id),
......@@ -290,6 +293,8 @@ class DistributeTranspiler:
dtype=v.dtype,
shape=v.shape)
recv_inputs.append(var)
else:
recv_inputs.append(single_trainer_var)
# step3
optimize_block = pserver_program.create_block(0)
......@@ -511,8 +516,11 @@ class DistributeTranspiler:
def _append_split_op(self, program, gradblocks):
# Split variables that need to be split and append respective ops
add_suffix = False
if self.trainers > 1:
add_suffix = True
var_mapping = self._create_vars_from_blocklist(
program, gradblocks, add_trainer_suffix=True)
program, gradblocks, add_trainer_suffix=add_suffix)
for varname, splited_vars in var_mapping.iteritems():
# variable that don't need to split have empty splited_vars
if len(splited_vars) <= 1:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册