未验证 提交 0a8a86e0 编写于 作者: 武毅 提交者: GitHub

Send recv op (#5520)

* WIP send recv op

* WIP send recv

* put grpc impl in details

* put grpc impl in details

* update wip

* update proto

* update proto

* update proto

* clean cmake

* wip on op implementations

* wip on op implementations

* compile ok adding ut

* wip unitest

* add extern cares for linking

* wip add ut

* working version send recv

* revert optimizer.py

* update test cmake

* add libtool to dockerfile

* update cmake dependency

* update cmake depends

* update cmake grpc depends

* fix cmake dependency

* fix compile error

* fix compile

* follow comments

* update

* update copyfrom
上级 dc82a309
......@@ -25,4 +25,3 @@ AllowAllParametersOfDeclarationOnNextLine: true
BinPackParameters: false
BinPackArguments: false
...
......@@ -133,6 +133,8 @@ include(external/any) # download libn::any
include(external/eigen) # download eigen3
include(external/pybind11) # download pybind11
include(external/nccl)
include(external/cares)
include(external/grpc)
include(cudnn) # set cudnn libraries, must before configure
include(configure) # add paddle env configuration
......
......@@ -29,7 +29,7 @@ RUN apt-get update && \
automake locales clang-format swig doxygen cmake \
liblapack-dev liblapacke-dev libboost-dev \
clang-3.8 llvm-3.8 libclang-3.8-dev \
net-tools && \
net-tools libtool && \
apt-get clean -y
# Install Go and glide
......
# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
IF(MOBILE_INFERENCE)
return()
ENDIF()
include (ExternalProject)
# NOTE: c-ares is needed when linking with grpc.
SET(CARES_SOURCES_DIR ${THIRD_PARTY_PATH}/cares)
SET(CARES_INSTALL_DIR ${THIRD_PARTY_PATH}/install/cares)
SET(CARES_INCLUDE_DIR "${CARES_INSTALL_DIR}/include/" CACHE PATH "cares include directory." FORCE)
ExternalProject_Add(
extern_cares
GIT_REPOSITORY "https://github.com/c-ares/c-ares.git"
GIT_TAG "cares-1_13_0"
PREFIX ${CARES_SOURCES_DIR}
UPDATE_COMMAND ""
CONFIGURE_COMMAND ./buildconf && ./configure --disable-shared --prefix=${CARES_INSTALL_DIR}
BUILD_IN_SOURCE 1
BUILD_COMMAND make
INSTALL_COMMAND make install
)
ADD_LIBRARY(cares STATIC IMPORTED GLOBAL)
SET_PROPERTY(TARGET cares PROPERTY IMPORTED_LOCATION
"${CARES_INSTALL_DIR}/lib/libcares.a")
include_directories(${CARES_INCLUDE_DIR})
ADD_DEPENDENCIES(cares extern_cares)
# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
IF(MOBILE_INFERENCE)
return()
ENDIF()
include (ExternalProject)
SET(GRPC_SOURCES_DIR ${THIRD_PARTY_PATH}/grpc)
SET(GRPC_INSTALL_DIR ${THIRD_PARTY_PATH}/install/grpc)
SET(GRPC_INCLUDE_DIR "${GRPC_INSTALL_DIR}/include/" CACHE PATH "grpc include directory." FORCE)
SET(GRPC_CPP_PLUGIN "${GRPC_INSTALL_DIR}/bin/grpc_cpp_plugin" CACHE FILEPATH "GRPC_CPP_PLUGIN" FORCE)
ExternalProject_Add(
extern_grpc
DEPENDS protobuf zlib
GIT_REPOSITORY "https://github.com/grpc/grpc.git"
GIT_TAG "v1.7.x"
PREFIX ${GRPC_SOURCES_DIR}
UPDATE_COMMAND ""
CONFIGURE_COMMAND ""
BUILD_IN_SOURCE 1
BUILD_COMMAND make
INSTALL_COMMAND make prefix=${GRPC_INSTALL_DIR} install
)
# FIXME(typhoonzero): hack to get static lib path, try a better way like merge them.
ADD_LIBRARY(grpc++_unsecure STATIC IMPORTED GLOBAL)
SET_PROPERTY(TARGET grpc++_unsecure PROPERTY IMPORTED_LOCATION
"${GRPC_INSTALL_DIR}/lib/libgrpc++_unsecure.a")
ADD_LIBRARY(grpc++ STATIC IMPORTED GLOBAL)
SET_PROPERTY(TARGET grpc++ PROPERTY IMPORTED_LOCATION
"${GRPC_INSTALL_DIR}/lib/libgrpc++.a")
ADD_LIBRARY(gpr STATIC IMPORTED GLOBAL)
SET_PROPERTY(TARGET gpr PROPERTY IMPORTED_LOCATION
"${GRPC_INSTALL_DIR}/lib/libgpr.a")
ADD_LIBRARY(grpc_unsecure STATIC IMPORTED GLOBAL)
SET_PROPERTY(TARGET grpc_unsecure PROPERTY IMPORTED_LOCATION
"${GRPC_INSTALL_DIR}/lib/libgrpc_unsecure.a")
include_directories(${GRPC_INCLUDE_DIR})
ADD_DEPENDENCIES(grpc++_unsecure extern_grpc)
......@@ -50,6 +50,8 @@ ExternalProject_Add(
)
LIST(APPEND external_project_dependencies zlib)
ADD_LIBRARY(zlib_target STATIC IMPORTED GLOBAL)
SET_PROPERTY(TARGET zlib_target PROPERTY IMPORTED_LOCATION ${ZLIB_LIBRARIES})
IF(WITH_C_API)
INSTALL(DIRECTORY ${ZLIB_INCLUDE_DIR} DESTINATION third_party/zlib)
......
......@@ -467,3 +467,50 @@ function(py_test TARGET_NAME)
WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR})
endif()
endfunction()
# grpc_library generate grpc code using grpc_cpp_plugin and protoc
# then build the generated protobuf code and grpc code with your
# implementation source codes together. Use SRCS argument for your
# implementation source files and PROTO argument for your .proto
# files.
#
# Usage: grpc_library(my_target SRCS my_client.cc PROTO my_target.proto DEPS my_dep)
function(grpc_library TARGET_NAME)
set(oneValueArgs PROTO)
set(multiValueArgs SRCS DEPS)
set(options "")
cmake_parse_arguments(grpc_library "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN})
message(STATUS "generating grpc ${grpc_library_PROTO}")
get_filename_component(ABS_PROTO ${grpc_library_PROTO} ABSOLUTE)
get_filename_component(PROTO_WE ${grpc_library_PROTO} NAME_WE)
get_filename_component(PROTO_PATH ${ABS_PROTO} PATH)
protobuf_generate_cpp(grpc_proto_srcs grpc_proto_hdrs "${ABS_PROTO}")
set(grpc_grpc_srcs "${CMAKE_CURRENT_BINARY_DIR}/${PROTO_WE}.grpc.pb.cc")
set(grpc_grpc_hdrs "${CMAKE_CURRENT_BINARY_DIR}/${PROTO_WE}.grpc.pb.h")
cc_library("${TARGET_NAME}_proto" SRCS "${grpc_proto_srcs}")
add_custom_command(
OUTPUT "${grpc_grpc_srcs}" "${grpc_grpc_hdrs}"
COMMAND ${PROTOBUF_PROTOC_EXECUTABLE}
ARGS --grpc_out "${CMAKE_CURRENT_BINARY_DIR}" -I "${PROTO_PATH}"
--plugin=protoc-gen-grpc="${GRPC_CPP_PLUGIN}" "${ABS_PROTO}"
DEPENDS "${ABS_PROTO}" ${PROTOBUF_PROTOC_EXECUTABLE} extern_grpc)
# FIXME(typhoonzero): grpc generated code do not generate virtual-dtor, mark it
# as compiler warnings instead of error. Should try remove the warnings also.
set_source_files_properties(
${grpc_grpc_srcs}
PROPERTIES
COMPILE_FLAGS "-Wno-error=non-virtual-dtor -Wno-error=delete-non-virtual-dtor")
cc_library("${TARGET_NAME}_grpc" SRCS "${grpc_grpc_srcs}")
set_source_files_properties(
${grpc_library_SRCS}
PROPERTIES
COMPILE_FLAGS "-Wno-error=non-virtual-dtor -Wno-error=delete-non-virtual-dtor")
cc_library("${TARGET_NAME}" SRCS "${grpc_library_SRCS}" DEPS "${TARGET_NAME}_grpc" "${TARGET_NAME}_proto" "${grpc_library_DEPS}")
endfunction()
......@@ -13,6 +13,8 @@
limitations under the License. */
#include "paddle/framework/lod_tensor.h"
#include "paddle/framework/data_type.h"
#include "paddle/framework/framework.pb.h"
#include "paddle/memory/memcpy.h"
#include "paddle/memory/memory.h"
......@@ -27,11 +29,11 @@
namespace paddle {
namespace framework {
std::ostream& operator<<(std::ostream& os, const LoD& lod) {
std::ostream &operator<<(std::ostream &os, const LoD &lod) {
os << "{";
for (auto& v : lod) {
for (auto &v : lod) {
os << "{";
for (auto& i : v) {
for (auto &i : v) {
os << i << ",";
}
os << "}";
......@@ -41,7 +43,7 @@ std::ostream& operator<<(std::ostream& os, const LoD& lod) {
return os;
}
LoD SliceLevels(const LoD& in, size_t level_begin, size_t level_end) {
LoD SliceLevels(const LoD &in, size_t level_begin, size_t level_end) {
LoD new_lod;
new_lod.reserve(level_end - level_begin);
for (size_t i = level_begin; i < level_end; i++) {
......@@ -53,7 +55,7 @@ LoD SliceLevels(const LoD& in, size_t level_begin, size_t level_end) {
return new_lod;
}
LoD SliceInLevel(const LoD& in, size_t level, size_t elem_begin,
LoD SliceInLevel(const LoD &in, size_t level, size_t elem_begin,
size_t elem_end) {
PADDLE_ENFORCE_LT(level, in.size());
PADDLE_ENFORCE_LT(elem_end, in[level].size());
......@@ -64,9 +66,9 @@ LoD SliceInLevel(const LoD& in, size_t level, size_t elem_begin,
res[0].assign(in[level].begin() + elem_begin,
in[level].begin() + elem_end + 1);
for (size_t lvl = 1; lvl < res.size(); lvl++) {
const auto& in_level = in[level + lvl];
const auto& above_level = res[lvl - 1];
auto& out_level = res[lvl];
const auto &in_level = in[level + lvl];
const auto &above_level = res[lvl - 1];
auto &out_level = res[lvl];
out_level.assign(in_level.begin() + above_level.front(),
in_level.begin() + above_level.back() + 1);
}
......@@ -74,33 +76,33 @@ LoD SliceInLevel(const LoD& in, size_t level, size_t elem_begin,
// to make the first offset equals 0, all the elements minus the first
// element
size_t front = res[lvl].front();
for (auto& ele : res[lvl]) {
for (auto &ele : res[lvl]) {
ele -= front;
}
}
return res;
}
LoD ToAbsOffset(const LoD& in) {
LoD ToAbsOffset(const LoD &in) {
// the lowest level stores relative offsets
if (in.empty() || in.size() == 1) return in;
LoD result = in;
for (int level = result.size() - 2; level >= 0; level--) {
for (auto& ele : result[level]) {
for (auto &ele : result[level]) {
ele = result[level + 1][ele];
}
}
return result;
}
bool operator==(const LoD& a, const LoD& b) {
bool operator==(const LoD &a, const LoD &b) {
if (a.size() != b.size()) {
return false;
}
for (size_t i = 0; i < a.size(); i++) {
const auto& a_level = a[i];
const auto& b_level = b[i];
const auto &a_level = a[i];
const auto &b_level = b[i];
if (a_level.size() != b_level.size()) {
return false;
}
......@@ -151,7 +153,7 @@ void LoDTensor::ShrinkInLevel(size_t level, size_t elem_begin,
}
using LoDAndOffset = std::pair<LoD, std::pair<size_t, size_t>>;
LoDAndOffset GetSubLoDAndAbsoluteOffset(const LoD& lod, size_t start_idx,
LoDAndOffset GetSubLoDAndAbsoluteOffset(const LoD &lod, size_t start_idx,
size_t end_idx, size_t start_level) {
LoD sub_lod;
......@@ -170,7 +172,7 @@ LoDAndOffset GetSubLoDAndAbsoluteOffset(const LoD& lod, size_t start_idx,
return LoDAndOffset{sub_lod, {start_idx, end_idx}};
}
void AppendLoD(LoD* lod, const LoD& lod_length) {
void AppendLoD(LoD *lod, const LoD &lod_length) {
PADDLE_ENFORCE(
lod->empty() || lod->size() == lod_length.size(),
"The lod_length should has the same size with the appended lod.");
......@@ -178,12 +180,139 @@ void AppendLoD(LoD* lod, const LoD& lod_length) {
*lod = LoD(lod_length.size(), std::vector<size_t>({0}));
}
for (size_t i = 0; i < lod->size(); ++i) {
auto& level = (*lod)[i];
auto &level = (*lod)[i];
for (size_t len : lod_length[i]) {
level.push_back(level.back() + len);
}
}
}
void SerializeToStream(std::ostream &os, const LoDTensor &tensor,
const platform::DeviceContext &dev_ctx) {
// TODO(typhoonzero): serialize to ostream
{ // the 1st field, uint32_t version
constexpr uint32_t version = 0;
os.write(reinterpret_cast<const char *>(&version), sizeof(version));
}
{ // the 2nd field, tensor description
// int32_t size
// void* protobuf message
framework::TensorDesc desc;
desc.set_data_type(framework::ToDataType(tensor.type()));
auto dims = framework::vectorize(tensor.dims());
auto *pb_dims = desc.mutable_dims();
pb_dims->Resize(static_cast<int>(dims.size()), 0);
std::copy(dims.begin(), dims.end(), pb_dims->begin());
int32_t size = desc.ByteSize();
os.write(reinterpret_cast<const char *>(&size), sizeof(size));
auto out = desc.SerializeAsString();
os.write(out.data(), size);
}
{ // the 3rd field, tensor data
uint64_t size = tensor.memory_size();
auto *data_ptr = tensor.data<void>();
PADDLE_ENFORCE(size < std::numeric_limits<std::streamsize>::max(),
"Index overflow when writing tensor");
if (platform::is_gpu_place(tensor.place())) {
#ifdef PADDLE_WITH_CUDA
constexpr size_t kBufSize = 1024 * 1024 * 64; // 64MB
std::unique_ptr<char[]> buf(new char[kBufSize]);
auto &gpu_dev_ctx =
static_cast<const platform::CUDADeviceContext &>(dev_ctx);
platform::CPUPlace cpu;
uintptr_t data = reinterpret_cast<uintptr_t>(data_ptr);
while (size != 0) {
size_t size_to_write = std::min(kBufSize, static_cast<size_t>(size));
memory::Copy(cpu, buf.get(),
boost::get<platform::GPUPlace>(tensor.place()),
reinterpret_cast<const void *>(data), size_to_write,
gpu_dev_ctx.stream());
gpu_dev_ctx.Wait();
os.write(buf.get(), size_to_write);
data += size_to_write;
size -= size_to_write;
}
#else
PADDLE_THROW("Unexpected branch");
#endif
} else {
os.write(static_cast<const char *>(data_ptr),
static_cast<std::streamsize>(size));
}
}
{ // the 4th field, lod information
// uint64_t lod_level
// uint64_t lod_level_1 size in byte.
// int* lod_level_1 data
// ...
auto lod = tensor.lod();
uint64_t size = lod.size();
os.write(reinterpret_cast<const char *>(&size), sizeof(size));
for (auto &each : lod) {
size = each.size() * sizeof(framework::LoD::value_type::value_type);
os.write(reinterpret_cast<const char *>(&size), sizeof(size));
os.write(reinterpret_cast<const char *>(each.data()),
static_cast<std::streamsize>(size));
}
}
}
void DeserializeFromStream(std::istream &is, LoDTensor *tensor) {
uint32_t version;
is.read(reinterpret_cast<char *>(&version), sizeof(version));
PADDLE_ENFORCE_EQ(version, 0U, "Only version 0 is supported");
framework::TensorDesc desc;
{ // int32_t size
// proto buffer
int32_t size;
is.read(reinterpret_cast<char *>(&size), sizeof(size));
std::unique_ptr<char[]> buf(new char[size]);
is.read(reinterpret_cast<char *>(buf.get()), size);
PADDLE_ENFORCE(desc.ParseFromArray(buf.get(), size),
"Cannot parse tensor desc");
}
{ // read tensor
std::vector<int64_t> dims;
dims.reserve(static_cast<size_t>(desc.dims().size()));
std::copy(desc.dims().begin(), desc.dims().end(), std::back_inserter(dims));
tensor->Resize(framework::make_ddim(dims));
void *buf;
platform::Place cpu = platform::CPUPlace();
switch (desc.data_type()) {
case framework::FP32:
buf = tensor->mutable_data<float>(cpu);
break;
case framework::FP64:
buf = tensor->mutable_data<double>(cpu);
break;
case framework::INT32:
buf = tensor->mutable_data<int>(cpu);
break;
case framework::INT64:
buf = tensor->mutable_data<int64_t>(cpu);
break;
default:
PADDLE_THROW("DataType %d not supported", desc.data_type());
}
is.read(static_cast<char *>(buf), tensor->memory_size());
}
{ // read lod
uint64_t lod_level;
is.read(reinterpret_cast<char *>(&lod_level), sizeof(lod_level));
auto &lod = *tensor->mutable_lod();
lod.resize(lod_level);
for (uint64_t i = 0; i < lod_level; ++i) {
uint64_t size;
is.read(reinterpret_cast<char *>(&size), sizeof(size));
std::vector<size_t> tmp(size / sizeof(size_t));
is.read(reinterpret_cast<char *>(tmp.data()),
static_cast<std::streamsize>(size));
lod[i] = tmp;
}
}
}
} // namespace framework
} // namespace paddle
......@@ -189,5 +189,14 @@ std::pair<LoD, std::pair<size_t, size_t>> GetSubLoDAndAbsoluteOffset(
void AppendLoD(LoD* lod, const LoD& lod_length);
/*
* Serialize/Desiralize LoDTensor to std::ostream
* You can pass ofstream or ostringstream to serilize to file
* or to a in memory string. GPU tensor will be copied to CPU.
*/
void SerializeToStream(std::ostream& os, const LoDTensor& tensor,
const platform::DeviceContext& dev_ctx);
void DeserializeFromStream(std::istream& is, LoDTensor* tensor);
} // namespace framework
} // namespace paddle
......@@ -205,8 +205,24 @@ set(DEPS_OPS
tensor_array_read_write_op
gru_op
adagrad_op
sgd_op)
sgd_op
save_op
load_op
send_op
recv_op)
add_subdirectory(detail)
op_library(send_op SRCS send_op.cc DEPS sendrecvop_grpc grpc++_unsecure grpc_unsecure gpr cares zlib_target protobuf)
set_source_files_properties(
send_op.cc
PROPERTIES
COMPILE_FLAGS "-Wno-error=non-virtual-dtor -Wno-error=delete-non-virtual-dtor")
op_library(recv_op SRCS recv_op.cc DEPS sendrecvop_grpc grpc++_unsecure grpc_unsecure gpr cares zlib_target protobuf)
set_source_files_properties(
recv_op.cc
PROPERTIES
COMPILE_FLAGS "-Wno-error=non-virtual-dtor -Wno-error=delete-non-virtual-dtor")
op_library(cond_op SRCS cond_op.cc DEPS framework_proto tensor operator net_op)
op_library(cross_entropy_op DEPS cross_entropy)
......@@ -235,6 +251,10 @@ op_library(conv_transpose_op DEPS vol2col)
op_library(gru_op DEPS sequence2batch gru_compute)
op_library(recurrent_op SRCS recurrent_op.cc DEPS executor)
# FIXME(typhoonzero): save/load depends lodtensor serialization functions
op_library(save_op DEPS lod_tensor)
op_library(load_op DEPS lod_tensor)
list(REMOVE_ITEM GENERAL_OPS ${DEPS_OPS})
foreach(src ${GENERAL_OPS})
op_library(${src})
......@@ -242,6 +262,8 @@ endforeach()
set(GLOB_OP_LIB ${OP_LIBRARY} CACHE INTERNAL "Global OP library")
cc_test(gather_test SRCS gather_test.cc DEPS tensor)
cc_test(net_op_test SRCS net_op_test.cc DEPS net_op)
cc_test(scatter_test SRCS scatter_test.cc DEPS tensor)
......@@ -251,3 +273,4 @@ if(WITH_GPU)
cc_test(nccl_op_test SRCS nccl_op_test.cu.cc DEPS nccl_op gpu_info device_context)
endif()
cc_test(save_load_op_test SRCS save_load_op_test.cc DEPS save_op load_op)
cc_test(test_send_recv SRCS send_recv_op_test.cc DEPS send_op recv_op sum_op executor)
grpc_library(sendrecvop_grpc SRCS recv_impl.cc send_impl.cc PROTO send_recv.proto DEPS lod_tensor selected_rows)
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "send_recv_impl.h"
namespace paddle {
namespace operators {
namespace detail {
Status SendRecvServerImpl::SendVariable(ServerContext *context,
const VariableMessage *in_var,
VariableMessage *out_var) {
framework::LoDTensor t;
// TODO(typhoonzero): desirealize in_tensor and run pserver network.
std::istringstream iss(in_var->serialized());
framework::DeserializeFromStream(iss, &t);
lodtensor_queue_.Push(std::move(t));
// Block util the sub graph is done.
t = lodtensor_return_queue_.Pop();
std::ostringstream oss;
// FIXME(typhoonzero): get context from op.
framework::SerializeToStream(oss, t, platform::CPUDeviceContext());
std::string *varname = out_var->mutable_varname();
*varname = in_var->varname();
std::string *serialized = out_var->mutable_serialized();
*serialized = oss.str();
return Status::OK;
}
} // namespace detail
} // namespace operators
} // namespace paddle
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "send_recv_impl.h"
namespace paddle {
namespace operators {
namespace detail {
bool RPCClient::SendVariable(const framework::Scope& scope,
const std::string& inname,
const std::string& outname) {
ClientContext context;
VariableMessage msg, out_msg;
// FIXME(typhoonzero): pass device context to here.
auto ctx = platform::CPUDeviceContext();
auto* var = scope.FindVar(inname);
PADDLE_ENFORCE(var);
// TODO(typhoonzero): support SelectedRows
PADDLE_ENFORCE(var->IsType<framework::LoDTensor>(),
"Only support LoDTensor, %s has wrong type", inname);
const framework::LoDTensor& tensor = var->Get<framework::LoDTensor>();
std::ostringstream oss;
framework::SerializeToStream(oss, tensor, ctx);
msg.set_varname(inname);
msg.set_serialized(oss.str());
Status status = stub_->SendVariable(&context, msg, &out_msg);
if (!status.ok()) {
return false;
}
std::istringstream iss(out_msg.serialized());
framework::LoDTensor ret_tensor;
framework::DeserializeFromStream(iss, &ret_tensor);
auto* outvar = scope.FindVar(outname);
framework::LoDTensor* out_tensor = outvar->GetMutable<framework::LoDTensor>();
// FIXME(typhoonzero): do not copy.
framework::CopyFrom(ret_tensor, ctx.GetPlace(), ctx, out_tensor);
return true;
}
} // namespace detail
} // namespace operators
} // namespace paddle
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
syntax = "proto3";
package sendrecv;
service SendRecvService {
// For parameter server round-robin like hashing, do not split tensors.
// Send and recv only one tensor
rpc SendVariable(VariableMessage) returns (VariableMessage) {}
}
// VariableMessage is serialized paddle variable message.
// It can be:
// Tensor
// LoDTensor
// SelectedRows
message VariableMessage {
string varname = 1;
bytes serialized = 2;
}
message VoidMessage {
}
\ No newline at end of file
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/framework/data_type.h"
#include "paddle/framework/lod_tensor.h"
#include "paddle/framework/scope.h"
#include "paddle/framework/selected_rows.h"
#include "paddle/operators/detail/simple_block_queue.h"
// #include <grpc++/channel.h>
// #include <grpc++/client_context.h>
// #include <grpc++/create_channel.h>
// #include <grpc++/security/credentials.h>
#include "paddle/operators/detail/send_recv.grpc.pb.h"
#include "paddle/operators/detail/send_recv.pb.h"
#include <grpc++/grpc++.h>
using grpc::Channel;
using grpc::Server;
using grpc::ServerContext;
using grpc::ServerReader;
using grpc::ServerBuilder;
using grpc::ClientContext;
using grpc::ClientReader;
using grpc::ClientReaderWriter;
using grpc::ClientWriter;
using grpc::Status;
using sendrecv::SendRecvService;
using sendrecv::VariableMessage;
using sendrecv::VoidMessage;
namespace paddle {
namespace operators {
namespace detail {
class SendRecvServerImpl final : public SendRecvService::Service {
public:
explicit SendRecvServerImpl() {}
Status SendVariable(ServerContext *context, const VariableMessage *in_var,
VariableMessage *out_var) override;
const framework::LoDTensor Get() { return this->lodtensor_queue_.Pop(); }
void Push(const framework::LoDTensor &tensor) {
this->lodtensor_return_queue_.Push(tensor);
}
private:
SimpleBlockQueue<framework::LoDTensor> lodtensor_queue_;
SimpleBlockQueue<framework::LoDTensor> lodtensor_return_queue_;
SimpleBlockQueue<framework::SelectedRows> selected_rows_queue_;
SimpleBlockQueue<framework::SelectedRows> selected_rows_return_queue_;
};
// RPCClient is a class to send tensors to pserver sub-network
// using different hashing methods.
class RPCClient {
public:
RPCClient(std::shared_ptr<Channel> channel)
: stub_(SendRecvService::NewStub(channel)) {}
bool SendVariable(const framework::Scope &scope, const std::string &inname,
const std::string &outname);
private:
std::unique_ptr<SendRecvService::Stub> stub_;
};
} // namespace detail
} // namespace operators
} // namespace paddle
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <condition_variable>
#include <deque>
#include <mutex>
namespace paddle {
namespace operators {
namespace detail {
template <typename T>
class SimpleBlockQueue {
private:
std::mutex mutex_;
std::condition_variable condition_;
std::deque<T> queue_;
public:
void Push(T const& value) {
{
std::unique_lock<std::mutex> lock(this->mutex_);
queue_.push_front(value);
}
this->condition_.notify_one();
}
T Pop() {
std::unique_lock<std::mutex> lock(this->mutex_);
this->condition_.wait(lock, [=] { return !this->queue_.empty(); });
T rc(std::move(this->queue_.back()));
this->queue_.pop_back();
return rc;
}
};
} // namespace detail
} // namespace operators
} // namespace paddle
......@@ -38,61 +38,7 @@ class LoadOp : public framework::OperatorBase {
out_var_name);
auto *tensor = out_var->GetMutable<framework::LoDTensor>();
uint32_t version;
fin.read(reinterpret_cast<char *>(&version), sizeof(version));
PADDLE_ENFORCE_EQ(version, 0U, "Only version 0 is supported");
framework::TensorDesc desc;
{ // int32_t size
// proto buffer
int32_t size;
fin.read(reinterpret_cast<char *>(&size), sizeof(size));
std::unique_ptr<char[]> buf(new char[size]);
fin.read(reinterpret_cast<char *>(buf.get()), size);
PADDLE_ENFORCE(desc.ParseFromArray(buf.get(), size),
"Cannot parse tensor desc");
}
{ // read tensor
std::vector<int64_t> dims;
dims.reserve(static_cast<size_t>(desc.dims().size()));
std::copy(desc.dims().begin(), desc.dims().end(),
std::back_inserter(dims));
tensor->Resize(framework::make_ddim(dims));
void *buf;
platform::Place cpu = platform::CPUPlace();
switch (desc.data_type()) {
case framework::FP32:
buf = tensor->mutable_data<float>(cpu);
break;
case framework::FP64:
buf = tensor->mutable_data<double>(cpu);
break;
case framework::INT32:
buf = tensor->mutable_data<int>(cpu);
break;
case framework::INT64:
buf = tensor->mutable_data<int64_t>(cpu);
break;
default:
PADDLE_THROW("DataType %d not supported", desc.data_type());
}
fin.read(static_cast<char *>(buf), tensor->memory_size());
}
{ // read lod
uint64_t lod_level;
fin.read(reinterpret_cast<char *>(&lod_level), sizeof(lod_level));
auto &lod = *tensor->mutable_lod();
lod.resize(lod_level);
for (uint64_t i = 0; i < lod_level; ++i) {
uint64_t size;
fin.read(reinterpret_cast<char *>(&size), sizeof(size));
std::vector<size_t> tmp(size / sizeof(size_t));
fin.read(reinterpret_cast<char *>(tmp.data()),
static_cast<std::streamsize>(size));
lod[i] = tmp;
}
}
framework::DeserializeFromStream(fin, tensor);
auto place = dev_ctx.GetPlace();
if (platform::is_gpu_place(place)) {
......
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <stdint.h>
#include <sys/stat.h>
#include <ostream>
#include <thread>
#include <unistd.h>
#include "paddle/framework/data_type.h"
#include "paddle/framework/executor.h"
#include "paddle/framework/framework.pb.h"
#include "paddle/framework/lod_tensor.h"
#include "paddle/framework/op_registry.h"
#include "paddle/operators/detail/send_recv_impl.h"
#include "paddle/operators/detail/simple_block_queue.h"
namespace paddle {
namespace operators {
void RunServer(Server **rpc_server,
std::shared_ptr<detail::SendRecvServerImpl> service,
const std::string &server_address) {
ServerBuilder builder;
builder.AddListeningPort(server_address, grpc::InsecureServerCredentials());
builder.RegisterService(service.get());
std::unique_ptr<Server> server(builder.BuildAndStart());
*rpc_server = server.get();
LOG(INFO) << "Server listening on " << server_address << std::endl;
server->Wait();
}
class RecvOp : public framework::OperatorBase {
public:
RecvOp(const std::string &type, const framework::VariableNameMap &inputs,
const framework::VariableNameMap &outputs,
const framework::AttributeMap &attrs)
: OperatorBase(type, inputs, outputs, attrs) {
if (!rpc_service_) {
rpc_service_.reset(new detail::SendRecvServerImpl());
std::string endpoint = Attr<std::string>("endpoint");
server_thread_.reset(
new std::thread(RunServer, &rpc_server_, rpc_service_, endpoint));
}
}
virtual ~RecvOp() {
rpc_server_->Shutdown();
server_thread_->join();
}
void Run(const framework::Scope &scope,
const platform::DeviceContext &dev_ctx) const override {
// blocking get one var from client.
const framework::LoDTensor &t = rpc_service_->Get();
framework::Scope &recv_scope = scope.NewScope();
// set graph input var
auto *var = recv_scope.Var(Input("RX"));
auto *tensor = var->GetMutable<framework::LoDTensor>();
// FIXME(typhoonzero): do not copy
framework::CopyFrom(t, dev_ctx.GetPlace(), dev_ctx, tensor);
auto *block = Attr<framework::BlockDescBind *>("OptimizeBlock");
auto *program = block->Program();
framework::Executor executor(dev_ctx);
// Run sub graph to get optimized tensor
executor.Run(*program, &recv_scope, block->ID(),
false /*create_local_scope*/);
auto *out_var = recv_scope.FindVar("Out");
// push back
rpc_service_->Push(out_var->Get<framework::LoDTensor>());
}
protected:
// grpc server instance to track status and gracefully shutdown.
// borrow an pointer from server thread.
Server *rpc_server_{nullptr};
// grpc send/recv service implement to register.
std::shared_ptr<detail::SendRecvServerImpl> rpc_service_;
std::shared_ptr<std::thread> server_thread_;
};
class RecvOpMaker : public framework::OpProtoAndCheckerMaker {
public:
RecvOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("RX", "(Tensor) Input tensor to be saved");
AddComment(R"DOC(
Recv operator
This operator will recv tensor from send_op
)DOC");
AddAttr<std::string>("endpoint",
"(string, default 127.0.0.1:6164)"
"IP address to listen on.")
.SetDefault("127.0.0.1:6164")
.AddCustomChecker([](const std::string &ip) { return !ip.empty(); });
AddAttr<framework::BlockDescBind *>("OptimizeBlock", "type BlockDescBind*",
"optimize network run in server");
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OPERATOR(recv, ops::RecvOp, ops::RecvOpMaker);
......@@ -88,73 +88,7 @@ class SaveOp : public framework::OperatorBase {
"SaveOp only support LoDTensor, %s has wrong type", iname);
auto &tensor = var->Get<framework::LoDTensor>();
{ // the 1st field, uint32_t version
constexpr uint32_t version = 0;
fout.write(reinterpret_cast<const char *>(&version), sizeof(version));
}
{ // the 2nd field, tensor description
// int32_t size
// void* protobuf message
framework::TensorDesc desc;
desc.set_data_type(framework::ToDataType(tensor.type()));
auto dims = framework::vectorize(tensor.dims());
auto *pb_dims = desc.mutable_dims();
pb_dims->Resize(static_cast<int>(dims.size()), 0);
std::copy(dims.begin(), dims.end(), pb_dims->begin());
int32_t size = desc.ByteSize();
fout.write(reinterpret_cast<const char *>(&size), sizeof(size));
auto out = desc.SerializeAsString();
fout.write(out.data(), size);
}
{ // the 3rd field, tensor data
uint64_t size = tensor.memory_size();
auto *data_ptr = tensor.data<void>();
PADDLE_ENFORCE(size < std::numeric_limits<std::streamsize>::max(),
"Index overflow when writing tensor");
if (platform::is_gpu_place(tensor.place())) {
#ifdef PADDLE_WITH_CUDA
constexpr size_t kBufSize = 1024 * 1024 * 64; // 64MB
std::unique_ptr<char[]> buf(new char[kBufSize]);
auto &gpu_dev_ctx =
static_cast<const platform::CUDADeviceContext &>(dev_ctx);
platform::CPUPlace cpu;
uintptr_t data = reinterpret_cast<uintptr_t>(data_ptr);
while (size != 0) {
size_t size_to_write = std::min(kBufSize, static_cast<size_t>(size));
memory::Copy(cpu, buf.get(),
boost::get<platform::GPUPlace>(tensor.place()),
reinterpret_cast<const void *>(data), size_to_write,
gpu_dev_ctx.stream());
gpu_dev_ctx.Wait();
fout.write(buf.get(), size_to_write);
data += size_to_write;
size -= size_to_write;
}
#else
PADDLE_THROW("Unexpected branch");
#endif
} else {
fout.write(static_cast<const char *>(data_ptr),
static_cast<std::streamsize>(size));
}
}
{ // the 4th field, lod information
// uint64_t lod_level
// uint64_t lod_level_1 size in byte.
// int* lod_level_1 data
// ...
auto lod = tensor.lod();
uint64_t size = lod.size();
fout.write(reinterpret_cast<const char *>(&size), sizeof(size));
for (auto &each : lod) {
size = each.size() * sizeof(framework::LoD::value_type::value_type);
fout.write(reinterpret_cast<const char *>(&size), sizeof(size));
fout.write(reinterpret_cast<const char *>(each.data()),
static_cast<std::streamsize>(size));
}
}
framework::SerializeToStream(fout, tensor, dev_ctx);
}
};
......
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <ostream>
#include "paddle/framework/data_type.h"
#include "paddle/framework/framework.pb.h"
#include "paddle/framework/lod_tensor.h"
#include "paddle/framework/op_registry.h"
#include "paddle/operators/detail/send_recv_impl.h"
#include "paddle/operators/detail/simple_block_queue.h"
namespace paddle {
namespace operators {
// TODO(typhoonzero): this is a simple implementation which only send
// one tensor
class SendOp : public framework::OperatorBase {
public:
SendOp(const std::string &type, const framework::VariableNameMap &inputs,
const framework::VariableNameMap &outputs,
const framework::AttributeMap &attrs)
: OperatorBase(type, inputs, outputs, attrs) {
// init client when the operator is created at runtime.
if (!client_) {
std::string endpoint = Attr<std::string>("endpoint");
client_.reset(new detail::RPCClient(
grpc::CreateChannel(endpoint, grpc::InsecureChannelCredentials())));
// TODO(typhoonzero): how to call InitVariables
}
}
void Run(const framework::Scope &scope,
const platform::DeviceContext &dev_ctx) const override {
auto iname = Input("X");
auto oname = Output("Out");
// TODO(typhoonzero): currently it's non-blocking,
// should block until server responds.
bool ret = client_->SendVariable(scope, iname, oname);
if (!ret) {
LOG(ERROR) << "send variable error";
}
}
protected:
std::shared_ptr<detail::RPCClient> client_{nullptr};
};
class SendOpMaker : public framework::OpProtoAndCheckerMaker {
public:
SendOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "(Tensor) Input tensor to be saved");
AddOutput("Out", "(Tensor) Output fetched from server");
AddComment(R"DOC(
Recv operator
This operator will recv tensor from send_op
)DOC");
AddAttr<std::string>("endpoint",
"(string, default 127.0.0.1:6164)"
"IP address to listen on.")
.SetDefault("127.0.0.1:6164")
.AddCustomChecker([](const std::string &ip) { return !ip.empty(); });
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OPERATOR(send, ops::SendOp, ops::SendOpMaker);
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
// TODO(typhoonzero): add python bindings for this test as
// a RemoteOptimizer.
#include <unistd.h>
#include <thread>
#include "gtest/gtest.h"
#include "paddle/framework/op_registry.h"
#include "paddle/framework/operator.h"
#include "paddle/framework/program_desc.h"
USE_NO_KERNEL_OP(send);
USE_NO_KERNEL_OP(recv);
USE_OP(sum);
// global for simplicity.
std::unique_ptr<paddle::framework::OperatorBase> recv_op;
void InitTensorsInScope(paddle::framework::Scope &scope,
paddle::platform::CPUPlace &place) {
paddle::platform::CPUDeviceContext ctx(place);
auto var = scope.Var("X");
auto tensor = var->GetMutable<paddle::framework::LoDTensor>();
tensor->Resize({10, 10});
float *expect = tensor->mutable_data<float>(place);
for (int64_t i = 0; i < tensor->numel(); ++i) {
expect[i] = static_cast<float>(i);
}
auto out_var = scope.Var("Out");
auto out_tensor = out_var->GetMutable<paddle::framework::LoDTensor>();
out_tensor->Resize({10, 10});
tensor->mutable_data<float>(place); // allocate
}
void AddOp(const std::string &type,
const paddle::framework::VariableNameMap &inputs,
const paddle::framework::VariableNameMap &outputs,
paddle::framework::AttributeMap attrs,
paddle::framework::BlockDescBind *block) {
// insert output
for (auto kv : outputs) {
for (auto v : kv.second) {
auto var = block->Var(v);
var->SetDataType(paddle::framework::DataType::FP32);
}
}
// insert op
auto op = block->AppendOp();
op->SetType(type);
for (auto &kv : inputs) {
op->SetInput(kv.first, kv.second);
}
for (auto &kv : outputs) {
op->SetOutput(kv.first, kv.second);
}
op->SetAttrMap(attrs);
}
void StartServerNet() {
paddle::framework::Scope scope;
paddle::platform::CPUPlace place;
InitTensorsInScope(scope, place);
// sub program run in recv_op, for simple test we use sum
paddle::framework::ProgramDescBind program;
paddle::framework::BlockDescBind *block = program.MutableBlock(0);
// X for server side tensors, RX for received tensers, must be of same shape.
AddOp("sum", {{"X", {"X", "RX"}}}, {{"Out", {"Out"}}}, {}, block);
paddle::framework::AttributeMap attrs;
attrs.insert({"endpoint", std::string("127.0.0.1:6174")});
attrs.insert({"OptimizeBlock", block});
recv_op = paddle::framework::OpRegistry::CreateOp("recv", {{"RX", {"RX"}}},
{{"Out", {"Out"}}}, attrs);
paddle::platform::CPUDeviceContext ctx(place);
recv_op->Run(scope, ctx);
}
TEST(SendRecvOp, CPU) {
std::thread server_thread(StartServerNet);
sleep(5); // wait server to start
// local net
paddle::framework::Scope scope;
paddle::platform::CPUPlace place;
InitTensorsInScope(scope, place);
paddle::framework::AttributeMap attrs;
attrs.insert({"endpoint", std::string("127.0.0.1:6174")});
auto send_op = paddle::framework::OpRegistry::CreateOp(
"send", {{"X", {"X"}}}, {{"Out", {"Out"}}}, attrs);
paddle::platform::CPUDeviceContext ctx(place);
send_op->Run(scope, ctx);
auto in_var = scope.Var("X");
auto tensor = in_var->GetMutable<paddle::framework::LoDTensor>();
float *expected = tensor->data<float>();
auto out_var = scope.Var("Out");
auto target = out_var->GetMutable<paddle::framework::LoDTensor>();
// send fail cause output is none.
EXPECT_NE(target->memory_size(), size_t(0));
float *actual = target->data<float>();
for (int64_t i = 0; i < target->numel(); ++i) {
EXPECT_EQ(expected[i] * 2, actual[i]);
}
recv_op.reset(); // dtor can shutdown and join server thread.
server_thread.join();
}
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册