From 0a8a86e0c9733dd85e82c58d2042d1abb7c85b73 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=AD=A6=E6=AF=85?= Date: Tue, 28 Nov 2017 11:02:24 +0800 Subject: [PATCH] Send recv op (#5520) * WIP send recv op * WIP send recv * put grpc impl in details * put grpc impl in details * update wip * update proto * update proto * update proto * clean cmake * wip on op implementations * wip on op implementations * compile ok adding ut * wip unitest * add extern cares for linking * wip add ut * working version send recv * revert optimizer.py * update test cmake * add libtool to dockerfile * update cmake dependency * update cmake depends * update cmake grpc depends * fix cmake dependency * fix compile error * fix compile * follow comments * update * update copyfrom --- .clang-format | 1 - CMakeLists.txt | 2 + Dockerfile | 2 +- cmake/external/cares.cmake | 45 +++++ cmake/external/grpc.cmake | 58 +++++++ cmake/external/zlib.cmake | 2 + cmake/generic.cmake | 47 ++++++ paddle/framework/lod_tensor.cc | 163 +++++++++++++++++-- paddle/framework/lod_tensor.h | 9 + paddle/operators/CMakeLists.txt | 25 ++- paddle/operators/detail/CMakeLists.txt | 1 + paddle/operators/detail/recv_impl.cc | 44 +++++ paddle/operators/detail/send_impl.cc | 54 ++++++ paddle/operators/detail/send_recv.proto | 37 +++++ paddle/operators/detail/send_recv_impl.h | 87 ++++++++++ paddle/operators/detail/simple_block_queue.h | 52 ++++++ paddle/operators/load_op.cc | 56 +------ paddle/operators/recv_op.cc | 121 ++++++++++++++ paddle/operators/save_op.cc | 68 +------- paddle/operators/send_op.cc | 84 ++++++++++ paddle/operators/send_recv_op_test.cc | 125 ++++++++++++++ 21 files changed, 941 insertions(+), 142 deletions(-) create mode 100644 cmake/external/cares.cmake create mode 100644 cmake/external/grpc.cmake create mode 100644 paddle/operators/detail/CMakeLists.txt create mode 100644 paddle/operators/detail/recv_impl.cc create mode 100644 paddle/operators/detail/send_impl.cc create mode 100644 paddle/operators/detail/send_recv.proto create mode 100644 paddle/operators/detail/send_recv_impl.h create mode 100644 paddle/operators/detail/simple_block_queue.h create mode 100644 paddle/operators/recv_op.cc create mode 100644 paddle/operators/send_op.cc create mode 100644 paddle/operators/send_recv_op_test.cc diff --git a/.clang-format b/.clang-format index 9ba433b1736..aff93435f58 100644 --- a/.clang-format +++ b/.clang-format @@ -25,4 +25,3 @@ AllowAllParametersOfDeclarationOnNextLine: true BinPackParameters: false BinPackArguments: false ... - diff --git a/CMakeLists.txt b/CMakeLists.txt index 65164b8472b..e76512166fc 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -133,6 +133,8 @@ include(external/any) # download libn::any include(external/eigen) # download eigen3 include(external/pybind11) # download pybind11 include(external/nccl) +include(external/cares) +include(external/grpc) include(cudnn) # set cudnn libraries, must before configure include(configure) # add paddle env configuration diff --git a/Dockerfile b/Dockerfile index 150344a8116..857d3f3e5f6 100644 --- a/Dockerfile +++ b/Dockerfile @@ -29,7 +29,7 @@ RUN apt-get update && \ automake locales clang-format swig doxygen cmake \ liblapack-dev liblapacke-dev libboost-dev \ clang-3.8 llvm-3.8 libclang-3.8-dev \ - net-tools && \ + net-tools libtool && \ apt-get clean -y # Install Go and glide diff --git a/cmake/external/cares.cmake b/cmake/external/cares.cmake new file mode 100644 index 00000000000..e05111ee18e --- /dev/null +++ b/cmake/external/cares.cmake @@ -0,0 +1,45 @@ +# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +IF(MOBILE_INFERENCE) + return() +ENDIF() + +include (ExternalProject) + +# NOTE: c-ares is needed when linking with grpc. + +SET(CARES_SOURCES_DIR ${THIRD_PARTY_PATH}/cares) +SET(CARES_INSTALL_DIR ${THIRD_PARTY_PATH}/install/cares) +SET(CARES_INCLUDE_DIR "${CARES_INSTALL_DIR}/include/" CACHE PATH "cares include directory." FORCE) + +ExternalProject_Add( + extern_cares + GIT_REPOSITORY "https://github.com/c-ares/c-ares.git" + GIT_TAG "cares-1_13_0" + PREFIX ${CARES_SOURCES_DIR} + UPDATE_COMMAND "" + CONFIGURE_COMMAND ./buildconf && ./configure --disable-shared --prefix=${CARES_INSTALL_DIR} + BUILD_IN_SOURCE 1 + BUILD_COMMAND make + INSTALL_COMMAND make install +) + +ADD_LIBRARY(cares STATIC IMPORTED GLOBAL) +SET_PROPERTY(TARGET cares PROPERTY IMPORTED_LOCATION + "${CARES_INSTALL_DIR}/lib/libcares.a") + +include_directories(${CARES_INCLUDE_DIR}) +ADD_DEPENDENCIES(cares extern_cares) diff --git a/cmake/external/grpc.cmake b/cmake/external/grpc.cmake new file mode 100644 index 00000000000..f431c037fd5 --- /dev/null +++ b/cmake/external/grpc.cmake @@ -0,0 +1,58 @@ +# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +IF(MOBILE_INFERENCE) + return() +ENDIF() + +include (ExternalProject) + +SET(GRPC_SOURCES_DIR ${THIRD_PARTY_PATH}/grpc) +SET(GRPC_INSTALL_DIR ${THIRD_PARTY_PATH}/install/grpc) +SET(GRPC_INCLUDE_DIR "${GRPC_INSTALL_DIR}/include/" CACHE PATH "grpc include directory." FORCE) +SET(GRPC_CPP_PLUGIN "${GRPC_INSTALL_DIR}/bin/grpc_cpp_plugin" CACHE FILEPATH "GRPC_CPP_PLUGIN" FORCE) + +ExternalProject_Add( + extern_grpc + DEPENDS protobuf zlib + GIT_REPOSITORY "https://github.com/grpc/grpc.git" + GIT_TAG "v1.7.x" + PREFIX ${GRPC_SOURCES_DIR} + UPDATE_COMMAND "" + CONFIGURE_COMMAND "" + BUILD_IN_SOURCE 1 + BUILD_COMMAND make + INSTALL_COMMAND make prefix=${GRPC_INSTALL_DIR} install +) + +# FIXME(typhoonzero): hack to get static lib path, try a better way like merge them. +ADD_LIBRARY(grpc++_unsecure STATIC IMPORTED GLOBAL) +SET_PROPERTY(TARGET grpc++_unsecure PROPERTY IMPORTED_LOCATION + "${GRPC_INSTALL_DIR}/lib/libgrpc++_unsecure.a") + +ADD_LIBRARY(grpc++ STATIC IMPORTED GLOBAL) +SET_PROPERTY(TARGET grpc++ PROPERTY IMPORTED_LOCATION + "${GRPC_INSTALL_DIR}/lib/libgrpc++.a") +ADD_LIBRARY(gpr STATIC IMPORTED GLOBAL) +SET_PROPERTY(TARGET gpr PROPERTY IMPORTED_LOCATION + "${GRPC_INSTALL_DIR}/lib/libgpr.a") + +ADD_LIBRARY(grpc_unsecure STATIC IMPORTED GLOBAL) +SET_PROPERTY(TARGET grpc_unsecure PROPERTY IMPORTED_LOCATION + "${GRPC_INSTALL_DIR}/lib/libgrpc_unsecure.a") + +include_directories(${GRPC_INCLUDE_DIR}) +ADD_DEPENDENCIES(grpc++_unsecure extern_grpc) + diff --git a/cmake/external/zlib.cmake b/cmake/external/zlib.cmake index a98e069b7cd..1638cd8fdfc 100644 --- a/cmake/external/zlib.cmake +++ b/cmake/external/zlib.cmake @@ -50,6 +50,8 @@ ExternalProject_Add( ) LIST(APPEND external_project_dependencies zlib) +ADD_LIBRARY(zlib_target STATIC IMPORTED GLOBAL) +SET_PROPERTY(TARGET zlib_target PROPERTY IMPORTED_LOCATION ${ZLIB_LIBRARIES}) IF(WITH_C_API) INSTALL(DIRECTORY ${ZLIB_INCLUDE_DIR} DESTINATION third_party/zlib) diff --git a/cmake/generic.cmake b/cmake/generic.cmake index 7b82d409a3b..c917ca0ff4e 100644 --- a/cmake/generic.cmake +++ b/cmake/generic.cmake @@ -467,3 +467,50 @@ function(py_test TARGET_NAME) WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}) endif() endfunction() + +# grpc_library generate grpc code using grpc_cpp_plugin and protoc +# then build the generated protobuf code and grpc code with your +# implementation source codes together. Use SRCS argument for your +# implementation source files and PROTO argument for your .proto +# files. +# +# Usage: grpc_library(my_target SRCS my_client.cc PROTO my_target.proto DEPS my_dep) + +function(grpc_library TARGET_NAME) + set(oneValueArgs PROTO) + set(multiValueArgs SRCS DEPS) + set(options "") + cmake_parse_arguments(grpc_library "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN}) + + message(STATUS "generating grpc ${grpc_library_PROTO}") + + get_filename_component(ABS_PROTO ${grpc_library_PROTO} ABSOLUTE) + get_filename_component(PROTO_WE ${grpc_library_PROTO} NAME_WE) + get_filename_component(PROTO_PATH ${ABS_PROTO} PATH) + + protobuf_generate_cpp(grpc_proto_srcs grpc_proto_hdrs "${ABS_PROTO}") + set(grpc_grpc_srcs "${CMAKE_CURRENT_BINARY_DIR}/${PROTO_WE}.grpc.pb.cc") + set(grpc_grpc_hdrs "${CMAKE_CURRENT_BINARY_DIR}/${PROTO_WE}.grpc.pb.h") + cc_library("${TARGET_NAME}_proto" SRCS "${grpc_proto_srcs}") + + add_custom_command( + OUTPUT "${grpc_grpc_srcs}" "${grpc_grpc_hdrs}" + COMMAND ${PROTOBUF_PROTOC_EXECUTABLE} + ARGS --grpc_out "${CMAKE_CURRENT_BINARY_DIR}" -I "${PROTO_PATH}" + --plugin=protoc-gen-grpc="${GRPC_CPP_PLUGIN}" "${ABS_PROTO}" + DEPENDS "${ABS_PROTO}" ${PROTOBUF_PROTOC_EXECUTABLE} extern_grpc) + + # FIXME(typhoonzero): grpc generated code do not generate virtual-dtor, mark it + # as compiler warnings instead of error. Should try remove the warnings also. + set_source_files_properties( + ${grpc_grpc_srcs} + PROPERTIES + COMPILE_FLAGS "-Wno-error=non-virtual-dtor -Wno-error=delete-non-virtual-dtor") + cc_library("${TARGET_NAME}_grpc" SRCS "${grpc_grpc_srcs}") + + set_source_files_properties( + ${grpc_library_SRCS} + PROPERTIES + COMPILE_FLAGS "-Wno-error=non-virtual-dtor -Wno-error=delete-non-virtual-dtor") + cc_library("${TARGET_NAME}" SRCS "${grpc_library_SRCS}" DEPS "${TARGET_NAME}_grpc" "${TARGET_NAME}_proto" "${grpc_library_DEPS}") +endfunction() diff --git a/paddle/framework/lod_tensor.cc b/paddle/framework/lod_tensor.cc index a0f2906c749..fdf6de4babf 100644 --- a/paddle/framework/lod_tensor.cc +++ b/paddle/framework/lod_tensor.cc @@ -13,6 +13,8 @@ limitations under the License. */ #include "paddle/framework/lod_tensor.h" +#include "paddle/framework/data_type.h" +#include "paddle/framework/framework.pb.h" #include "paddle/memory/memcpy.h" #include "paddle/memory/memory.h" @@ -27,11 +29,11 @@ namespace paddle { namespace framework { -std::ostream& operator<<(std::ostream& os, const LoD& lod) { +std::ostream &operator<<(std::ostream &os, const LoD &lod) { os << "{"; - for (auto& v : lod) { + for (auto &v : lod) { os << "{"; - for (auto& i : v) { + for (auto &i : v) { os << i << ","; } os << "}"; @@ -41,7 +43,7 @@ std::ostream& operator<<(std::ostream& os, const LoD& lod) { return os; } -LoD SliceLevels(const LoD& in, size_t level_begin, size_t level_end) { +LoD SliceLevels(const LoD &in, size_t level_begin, size_t level_end) { LoD new_lod; new_lod.reserve(level_end - level_begin); for (size_t i = level_begin; i < level_end; i++) { @@ -53,7 +55,7 @@ LoD SliceLevels(const LoD& in, size_t level_begin, size_t level_end) { return new_lod; } -LoD SliceInLevel(const LoD& in, size_t level, size_t elem_begin, +LoD SliceInLevel(const LoD &in, size_t level, size_t elem_begin, size_t elem_end) { PADDLE_ENFORCE_LT(level, in.size()); PADDLE_ENFORCE_LT(elem_end, in[level].size()); @@ -64,9 +66,9 @@ LoD SliceInLevel(const LoD& in, size_t level, size_t elem_begin, res[0].assign(in[level].begin() + elem_begin, in[level].begin() + elem_end + 1); for (size_t lvl = 1; lvl < res.size(); lvl++) { - const auto& in_level = in[level + lvl]; - const auto& above_level = res[lvl - 1]; - auto& out_level = res[lvl]; + const auto &in_level = in[level + lvl]; + const auto &above_level = res[lvl - 1]; + auto &out_level = res[lvl]; out_level.assign(in_level.begin() + above_level.front(), in_level.begin() + above_level.back() + 1); } @@ -74,33 +76,33 @@ LoD SliceInLevel(const LoD& in, size_t level, size_t elem_begin, // to make the first offset equals 0, all the elements minus the first // element size_t front = res[lvl].front(); - for (auto& ele : res[lvl]) { + for (auto &ele : res[lvl]) { ele -= front; } } return res; } -LoD ToAbsOffset(const LoD& in) { +LoD ToAbsOffset(const LoD &in) { // the lowest level stores relative offsets if (in.empty() || in.size() == 1) return in; LoD result = in; for (int level = result.size() - 2; level >= 0; level--) { - for (auto& ele : result[level]) { + for (auto &ele : result[level]) { ele = result[level + 1][ele]; } } return result; } -bool operator==(const LoD& a, const LoD& b) { +bool operator==(const LoD &a, const LoD &b) { if (a.size() != b.size()) { return false; } for (size_t i = 0; i < a.size(); i++) { - const auto& a_level = a[i]; - const auto& b_level = b[i]; + const auto &a_level = a[i]; + const auto &b_level = b[i]; if (a_level.size() != b_level.size()) { return false; } @@ -151,7 +153,7 @@ void LoDTensor::ShrinkInLevel(size_t level, size_t elem_begin, } using LoDAndOffset = std::pair>; -LoDAndOffset GetSubLoDAndAbsoluteOffset(const LoD& lod, size_t start_idx, +LoDAndOffset GetSubLoDAndAbsoluteOffset(const LoD &lod, size_t start_idx, size_t end_idx, size_t start_level) { LoD sub_lod; @@ -170,7 +172,7 @@ LoDAndOffset GetSubLoDAndAbsoluteOffset(const LoD& lod, size_t start_idx, return LoDAndOffset{sub_lod, {start_idx, end_idx}}; } -void AppendLoD(LoD* lod, const LoD& lod_length) { +void AppendLoD(LoD *lod, const LoD &lod_length) { PADDLE_ENFORCE( lod->empty() || lod->size() == lod_length.size(), "The lod_length should has the same size with the appended lod."); @@ -178,12 +180,139 @@ void AppendLoD(LoD* lod, const LoD& lod_length) { *lod = LoD(lod_length.size(), std::vector({0})); } for (size_t i = 0; i < lod->size(); ++i) { - auto& level = (*lod)[i]; + auto &level = (*lod)[i]; for (size_t len : lod_length[i]) { level.push_back(level.back() + len); } } } +void SerializeToStream(std::ostream &os, const LoDTensor &tensor, + const platform::DeviceContext &dev_ctx) { + // TODO(typhoonzero): serialize to ostream + { // the 1st field, uint32_t version + constexpr uint32_t version = 0; + os.write(reinterpret_cast(&version), sizeof(version)); + } + { // the 2nd field, tensor description + // int32_t size + // void* protobuf message + framework::TensorDesc desc; + desc.set_data_type(framework::ToDataType(tensor.type())); + auto dims = framework::vectorize(tensor.dims()); + auto *pb_dims = desc.mutable_dims(); + pb_dims->Resize(static_cast(dims.size()), 0); + std::copy(dims.begin(), dims.end(), pb_dims->begin()); + int32_t size = desc.ByteSize(); + os.write(reinterpret_cast(&size), sizeof(size)); + auto out = desc.SerializeAsString(); + os.write(out.data(), size); + } + { // the 3rd field, tensor data + uint64_t size = tensor.memory_size(); + auto *data_ptr = tensor.data(); + PADDLE_ENFORCE(size < std::numeric_limits::max(), + "Index overflow when writing tensor"); + if (platform::is_gpu_place(tensor.place())) { +#ifdef PADDLE_WITH_CUDA + constexpr size_t kBufSize = 1024 * 1024 * 64; // 64MB + std::unique_ptr buf(new char[kBufSize]); + auto &gpu_dev_ctx = + static_cast(dev_ctx); + platform::CPUPlace cpu; + uintptr_t data = reinterpret_cast(data_ptr); + while (size != 0) { + size_t size_to_write = std::min(kBufSize, static_cast(size)); + memory::Copy(cpu, buf.get(), + boost::get(tensor.place()), + reinterpret_cast(data), size_to_write, + gpu_dev_ctx.stream()); + gpu_dev_ctx.Wait(); + os.write(buf.get(), size_to_write); + data += size_to_write; + size -= size_to_write; + } +#else + PADDLE_THROW("Unexpected branch"); +#endif + } else { + os.write(static_cast(data_ptr), + static_cast(size)); + } + } + { // the 4th field, lod information + // uint64_t lod_level + // uint64_t lod_level_1 size in byte. + // int* lod_level_1 data + // ... + auto lod = tensor.lod(); + uint64_t size = lod.size(); + os.write(reinterpret_cast(&size), sizeof(size)); + + for (auto &each : lod) { + size = each.size() * sizeof(framework::LoD::value_type::value_type); + os.write(reinterpret_cast(&size), sizeof(size)); + os.write(reinterpret_cast(each.data()), + static_cast(size)); + } + } +} + +void DeserializeFromStream(std::istream &is, LoDTensor *tensor) { + uint32_t version; + is.read(reinterpret_cast(&version), sizeof(version)); + PADDLE_ENFORCE_EQ(version, 0U, "Only version 0 is supported"); + framework::TensorDesc desc; + { // int32_t size + // proto buffer + int32_t size; + is.read(reinterpret_cast(&size), sizeof(size)); + std::unique_ptr buf(new char[size]); + is.read(reinterpret_cast(buf.get()), size); + PADDLE_ENFORCE(desc.ParseFromArray(buf.get(), size), + "Cannot parse tensor desc"); + } + { // read tensor + std::vector dims; + dims.reserve(static_cast(desc.dims().size())); + std::copy(desc.dims().begin(), desc.dims().end(), std::back_inserter(dims)); + tensor->Resize(framework::make_ddim(dims)); + + void *buf; + platform::Place cpu = platform::CPUPlace(); + switch (desc.data_type()) { + case framework::FP32: + buf = tensor->mutable_data(cpu); + break; + case framework::FP64: + buf = tensor->mutable_data(cpu); + break; + case framework::INT32: + buf = tensor->mutable_data(cpu); + break; + case framework::INT64: + buf = tensor->mutable_data(cpu); + break; + default: + PADDLE_THROW("DataType %d not supported", desc.data_type()); + } + is.read(static_cast(buf), tensor->memory_size()); + } + { // read lod + uint64_t lod_level; + is.read(reinterpret_cast(&lod_level), sizeof(lod_level)); + auto &lod = *tensor->mutable_lod(); + lod.resize(lod_level); + for (uint64_t i = 0; i < lod_level; ++i) { + uint64_t size; + is.read(reinterpret_cast(&size), sizeof(size)); + std::vector tmp(size / sizeof(size_t)); + is.read(reinterpret_cast(tmp.data()), + static_cast(size)); + lod[i] = tmp; + } + } +} + } // namespace framework } // namespace paddle diff --git a/paddle/framework/lod_tensor.h b/paddle/framework/lod_tensor.h index 21bdfca1111..9411c96aea4 100644 --- a/paddle/framework/lod_tensor.h +++ b/paddle/framework/lod_tensor.h @@ -189,5 +189,14 @@ std::pair> GetSubLoDAndAbsoluteOffset( void AppendLoD(LoD* lod, const LoD& lod_length); +/* + * Serialize/Desiralize LoDTensor to std::ostream + * You can pass ofstream or ostringstream to serilize to file + * or to a in memory string. GPU tensor will be copied to CPU. + */ +void SerializeToStream(std::ostream& os, const LoDTensor& tensor, + const platform::DeviceContext& dev_ctx); +void DeserializeFromStream(std::istream& is, LoDTensor* tensor); + } // namespace framework } // namespace paddle diff --git a/paddle/operators/CMakeLists.txt b/paddle/operators/CMakeLists.txt index a4c4374cf2f..7e5d4fd640f 100644 --- a/paddle/operators/CMakeLists.txt +++ b/paddle/operators/CMakeLists.txt @@ -205,8 +205,24 @@ set(DEPS_OPS tensor_array_read_write_op gru_op adagrad_op - sgd_op) + sgd_op + save_op + load_op + send_op + recv_op) +add_subdirectory(detail) +op_library(send_op SRCS send_op.cc DEPS sendrecvop_grpc grpc++_unsecure grpc_unsecure gpr cares zlib_target protobuf) +set_source_files_properties( + send_op.cc + PROPERTIES + COMPILE_FLAGS "-Wno-error=non-virtual-dtor -Wno-error=delete-non-virtual-dtor") + +op_library(recv_op SRCS recv_op.cc DEPS sendrecvop_grpc grpc++_unsecure grpc_unsecure gpr cares zlib_target protobuf) +set_source_files_properties( + recv_op.cc + PROPERTIES + COMPILE_FLAGS "-Wno-error=non-virtual-dtor -Wno-error=delete-non-virtual-dtor") op_library(cond_op SRCS cond_op.cc DEPS framework_proto tensor operator net_op) op_library(cross_entropy_op DEPS cross_entropy) @@ -235,6 +251,10 @@ op_library(conv_transpose_op DEPS vol2col) op_library(gru_op DEPS sequence2batch gru_compute) op_library(recurrent_op SRCS recurrent_op.cc DEPS executor) +# FIXME(typhoonzero): save/load depends lodtensor serialization functions +op_library(save_op DEPS lod_tensor) +op_library(load_op DEPS lod_tensor) + list(REMOVE_ITEM GENERAL_OPS ${DEPS_OPS}) foreach(src ${GENERAL_OPS}) op_library(${src}) @@ -242,6 +262,8 @@ endforeach() set(GLOB_OP_LIB ${OP_LIBRARY} CACHE INTERNAL "Global OP library") + + cc_test(gather_test SRCS gather_test.cc DEPS tensor) cc_test(net_op_test SRCS net_op_test.cc DEPS net_op) cc_test(scatter_test SRCS scatter_test.cc DEPS tensor) @@ -251,3 +273,4 @@ if(WITH_GPU) cc_test(nccl_op_test SRCS nccl_op_test.cu.cc DEPS nccl_op gpu_info device_context) endif() cc_test(save_load_op_test SRCS save_load_op_test.cc DEPS save_op load_op) +cc_test(test_send_recv SRCS send_recv_op_test.cc DEPS send_op recv_op sum_op executor) diff --git a/paddle/operators/detail/CMakeLists.txt b/paddle/operators/detail/CMakeLists.txt new file mode 100644 index 00000000000..f6bdc63cc2c --- /dev/null +++ b/paddle/operators/detail/CMakeLists.txt @@ -0,0 +1 @@ +grpc_library(sendrecvop_grpc SRCS recv_impl.cc send_impl.cc PROTO send_recv.proto DEPS lod_tensor selected_rows) diff --git a/paddle/operators/detail/recv_impl.cc b/paddle/operators/detail/recv_impl.cc new file mode 100644 index 00000000000..89dc5045221 --- /dev/null +++ b/paddle/operators/detail/recv_impl.cc @@ -0,0 +1,44 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ + +#include "send_recv_impl.h" + +namespace paddle { +namespace operators { +namespace detail { + +Status SendRecvServerImpl::SendVariable(ServerContext *context, + const VariableMessage *in_var, + VariableMessage *out_var) { + framework::LoDTensor t; + // TODO(typhoonzero): desirealize in_tensor and run pserver network. + std::istringstream iss(in_var->serialized()); + framework::DeserializeFromStream(iss, &t); + lodtensor_queue_.Push(std::move(t)); + // Block util the sub graph is done. + t = lodtensor_return_queue_.Pop(); + std::ostringstream oss; + // FIXME(typhoonzero): get context from op. + framework::SerializeToStream(oss, t, platform::CPUDeviceContext()); + std::string *varname = out_var->mutable_varname(); + *varname = in_var->varname(); + std::string *serialized = out_var->mutable_serialized(); + *serialized = oss.str(); + + return Status::OK; +} + +} // namespace detail +} // namespace operators +} // namespace paddle diff --git a/paddle/operators/detail/send_impl.cc b/paddle/operators/detail/send_impl.cc new file mode 100644 index 00000000000..da1ddf75d2a --- /dev/null +++ b/paddle/operators/detail/send_impl.cc @@ -0,0 +1,54 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ + +#include "send_recv_impl.h" + +namespace paddle { +namespace operators { +namespace detail { + +bool RPCClient::SendVariable(const framework::Scope& scope, + const std::string& inname, + const std::string& outname) { + ClientContext context; + VariableMessage msg, out_msg; + // FIXME(typhoonzero): pass device context to here. + auto ctx = platform::CPUDeviceContext(); + auto* var = scope.FindVar(inname); + PADDLE_ENFORCE(var); + // TODO(typhoonzero): support SelectedRows + PADDLE_ENFORCE(var->IsType(), + "Only support LoDTensor, %s has wrong type", inname); + const framework::LoDTensor& tensor = var->Get(); + std::ostringstream oss; + framework::SerializeToStream(oss, tensor, ctx); + msg.set_varname(inname); + msg.set_serialized(oss.str()); + Status status = stub_->SendVariable(&context, msg, &out_msg); + if (!status.ok()) { + return false; + } + std::istringstream iss(out_msg.serialized()); + framework::LoDTensor ret_tensor; + framework::DeserializeFromStream(iss, &ret_tensor); + auto* outvar = scope.FindVar(outname); + framework::LoDTensor* out_tensor = outvar->GetMutable(); + // FIXME(typhoonzero): do not copy. + framework::CopyFrom(ret_tensor, ctx.GetPlace(), ctx, out_tensor); + return true; +} + +} // namespace detail +} // namespace operators +} // namespace paddle diff --git a/paddle/operators/detail/send_recv.proto b/paddle/operators/detail/send_recv.proto new file mode 100644 index 00000000000..66f84678b3c --- /dev/null +++ b/paddle/operators/detail/send_recv.proto @@ -0,0 +1,37 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ + +syntax = "proto3"; + +package sendrecv; + +service SendRecvService { + // For parameter server round-robin like hashing, do not split tensors. + // Send and recv only one tensor + rpc SendVariable(VariableMessage) returns (VariableMessage) {} +} + +// VariableMessage is serialized paddle variable message. +// It can be: +// Tensor +// LoDTensor +// SelectedRows +message VariableMessage { + string varname = 1; + bytes serialized = 2; +} + +message VoidMessage { + +} \ No newline at end of file diff --git a/paddle/operators/detail/send_recv_impl.h b/paddle/operators/detail/send_recv_impl.h new file mode 100644 index 00000000000..b9a5340a863 --- /dev/null +++ b/paddle/operators/detail/send_recv_impl.h @@ -0,0 +1,87 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ + +#pragma once + +#include "paddle/framework/data_type.h" +#include "paddle/framework/lod_tensor.h" +#include "paddle/framework/scope.h" +#include "paddle/framework/selected_rows.h" +#include "paddle/operators/detail/simple_block_queue.h" + +// #include +// #include +// #include +// #include +#include "paddle/operators/detail/send_recv.grpc.pb.h" +#include "paddle/operators/detail/send_recv.pb.h" + +#include + +using grpc::Channel; +using grpc::Server; +using grpc::ServerContext; +using grpc::ServerReader; +using grpc::ServerBuilder; + +using grpc::ClientContext; +using grpc::ClientReader; +using grpc::ClientReaderWriter; +using grpc::ClientWriter; +using grpc::Status; +using sendrecv::SendRecvService; +using sendrecv::VariableMessage; +using sendrecv::VoidMessage; + +namespace paddle { +namespace operators { +namespace detail { + +class SendRecvServerImpl final : public SendRecvService::Service { + public: + explicit SendRecvServerImpl() {} + + Status SendVariable(ServerContext *context, const VariableMessage *in_var, + VariableMessage *out_var) override; + + const framework::LoDTensor Get() { return this->lodtensor_queue_.Pop(); } + + void Push(const framework::LoDTensor &tensor) { + this->lodtensor_return_queue_.Push(tensor); + } + + private: + SimpleBlockQueue lodtensor_queue_; + SimpleBlockQueue lodtensor_return_queue_; + SimpleBlockQueue selected_rows_queue_; + SimpleBlockQueue selected_rows_return_queue_; +}; + +// RPCClient is a class to send tensors to pserver sub-network +// using different hashing methods. +class RPCClient { + public: + RPCClient(std::shared_ptr channel) + : stub_(SendRecvService::NewStub(channel)) {} + + bool SendVariable(const framework::Scope &scope, const std::string &inname, + const std::string &outname); + + private: + std::unique_ptr stub_; +}; + +} // namespace detail +} // namespace operators +} // namespace paddle diff --git a/paddle/operators/detail/simple_block_queue.h b/paddle/operators/detail/simple_block_queue.h new file mode 100644 index 00000000000..44899217579 --- /dev/null +++ b/paddle/operators/detail/simple_block_queue.h @@ -0,0 +1,52 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ + +#pragma once + +#include +#include +#include + +namespace paddle { +namespace operators { +namespace detail { + +template +class SimpleBlockQueue { + private: + std::mutex mutex_; + std::condition_variable condition_; + std::deque queue_; + + public: + void Push(T const& value) { + { + std::unique_lock lock(this->mutex_); + queue_.push_front(value); + } + this->condition_.notify_one(); + } + + T Pop() { + std::unique_lock lock(this->mutex_); + this->condition_.wait(lock, [=] { return !this->queue_.empty(); }); + T rc(std::move(this->queue_.back())); + this->queue_.pop_back(); + return rc; + } +}; + +} // namespace detail +} // namespace operators +} // namespace paddle diff --git a/paddle/operators/load_op.cc b/paddle/operators/load_op.cc index b0838eed161..4e58b84430f 100644 --- a/paddle/operators/load_op.cc +++ b/paddle/operators/load_op.cc @@ -38,61 +38,7 @@ class LoadOp : public framework::OperatorBase { out_var_name); auto *tensor = out_var->GetMutable(); - - uint32_t version; - fin.read(reinterpret_cast(&version), sizeof(version)); - PADDLE_ENFORCE_EQ(version, 0U, "Only version 0 is supported"); - framework::TensorDesc desc; - { // int32_t size - // proto buffer - int32_t size; - fin.read(reinterpret_cast(&size), sizeof(size)); - std::unique_ptr buf(new char[size]); - fin.read(reinterpret_cast(buf.get()), size); - PADDLE_ENFORCE(desc.ParseFromArray(buf.get(), size), - "Cannot parse tensor desc"); - } - { // read tensor - std::vector dims; - dims.reserve(static_cast(desc.dims().size())); - std::copy(desc.dims().begin(), desc.dims().end(), - std::back_inserter(dims)); - tensor->Resize(framework::make_ddim(dims)); - - void *buf; - platform::Place cpu = platform::CPUPlace(); - switch (desc.data_type()) { - case framework::FP32: - buf = tensor->mutable_data(cpu); - break; - case framework::FP64: - buf = tensor->mutable_data(cpu); - break; - case framework::INT32: - buf = tensor->mutable_data(cpu); - break; - case framework::INT64: - buf = tensor->mutable_data(cpu); - break; - default: - PADDLE_THROW("DataType %d not supported", desc.data_type()); - } - fin.read(static_cast(buf), tensor->memory_size()); - } - { // read lod - uint64_t lod_level; - fin.read(reinterpret_cast(&lod_level), sizeof(lod_level)); - auto &lod = *tensor->mutable_lod(); - lod.resize(lod_level); - for (uint64_t i = 0; i < lod_level; ++i) { - uint64_t size; - fin.read(reinterpret_cast(&size), sizeof(size)); - std::vector tmp(size / sizeof(size_t)); - fin.read(reinterpret_cast(tmp.data()), - static_cast(size)); - lod[i] = tmp; - } - } + framework::DeserializeFromStream(fin, tensor); auto place = dev_ctx.GetPlace(); if (platform::is_gpu_place(place)) { diff --git a/paddle/operators/recv_op.cc b/paddle/operators/recv_op.cc new file mode 100644 index 00000000000..c69e416e10f --- /dev/null +++ b/paddle/operators/recv_op.cc @@ -0,0 +1,121 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ + +#include +#include +#include +#include + +#include + +#include "paddle/framework/data_type.h" +#include "paddle/framework/executor.h" +#include "paddle/framework/framework.pb.h" +#include "paddle/framework/lod_tensor.h" +#include "paddle/framework/op_registry.h" +#include "paddle/operators/detail/send_recv_impl.h" +#include "paddle/operators/detail/simple_block_queue.h" + +namespace paddle { +namespace operators { + +void RunServer(Server **rpc_server, + std::shared_ptr service, + const std::string &server_address) { + ServerBuilder builder; + builder.AddListeningPort(server_address, grpc::InsecureServerCredentials()); + builder.RegisterService(service.get()); + std::unique_ptr server(builder.BuildAndStart()); + *rpc_server = server.get(); + LOG(INFO) << "Server listening on " << server_address << std::endl; + server->Wait(); +} + +class RecvOp : public framework::OperatorBase { + public: + RecvOp(const std::string &type, const framework::VariableNameMap &inputs, + const framework::VariableNameMap &outputs, + const framework::AttributeMap &attrs) + : OperatorBase(type, inputs, outputs, attrs) { + if (!rpc_service_) { + rpc_service_.reset(new detail::SendRecvServerImpl()); + std::string endpoint = Attr("endpoint"); + server_thread_.reset( + new std::thread(RunServer, &rpc_server_, rpc_service_, endpoint)); + } + } + + virtual ~RecvOp() { + rpc_server_->Shutdown(); + server_thread_->join(); + } + + void Run(const framework::Scope &scope, + const platform::DeviceContext &dev_ctx) const override { + // blocking get one var from client. + const framework::LoDTensor &t = rpc_service_->Get(); + framework::Scope &recv_scope = scope.NewScope(); + // set graph input var + auto *var = recv_scope.Var(Input("RX")); + auto *tensor = var->GetMutable(); + // FIXME(typhoonzero): do not copy + framework::CopyFrom(t, dev_ctx.GetPlace(), dev_ctx, tensor); + + auto *block = Attr("OptimizeBlock"); + auto *program = block->Program(); + framework::Executor executor(dev_ctx); + // Run sub graph to get optimized tensor + executor.Run(*program, &recv_scope, block->ID(), + false /*create_local_scope*/); + + auto *out_var = recv_scope.FindVar("Out"); + // push back + rpc_service_->Push(out_var->Get()); + } + + protected: + // grpc server instance to track status and gracefully shutdown. + // borrow an pointer from server thread. + Server *rpc_server_{nullptr}; + // grpc send/recv service implement to register. + std::shared_ptr rpc_service_; + std::shared_ptr server_thread_; +}; + +class RecvOpMaker : public framework::OpProtoAndCheckerMaker { + public: + RecvOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker) + : OpProtoAndCheckerMaker(proto, op_checker) { + AddInput("RX", "(Tensor) Input tensor to be saved"); + AddComment(R"DOC( +Recv operator + +This operator will recv tensor from send_op +)DOC"); + AddAttr("endpoint", + "(string, default 127.0.0.1:6164)" + "IP address to listen on.") + .SetDefault("127.0.0.1:6164") + .AddCustomChecker([](const std::string &ip) { return !ip.empty(); }); + AddAttr("OptimizeBlock", "type BlockDescBind*", + "optimize network run in server"); + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; + +REGISTER_OPERATOR(recv, ops::RecvOp, ops::RecvOpMaker); diff --git a/paddle/operators/save_op.cc b/paddle/operators/save_op.cc index 56909fb65f4..d4921cb80c8 100644 --- a/paddle/operators/save_op.cc +++ b/paddle/operators/save_op.cc @@ -88,73 +88,7 @@ class SaveOp : public framework::OperatorBase { "SaveOp only support LoDTensor, %s has wrong type", iname); auto &tensor = var->Get(); - - { // the 1st field, uint32_t version - constexpr uint32_t version = 0; - fout.write(reinterpret_cast(&version), sizeof(version)); - } - { // the 2nd field, tensor description - // int32_t size - // void* protobuf message - framework::TensorDesc desc; - desc.set_data_type(framework::ToDataType(tensor.type())); - auto dims = framework::vectorize(tensor.dims()); - auto *pb_dims = desc.mutable_dims(); - pb_dims->Resize(static_cast(dims.size()), 0); - std::copy(dims.begin(), dims.end(), pb_dims->begin()); - int32_t size = desc.ByteSize(); - fout.write(reinterpret_cast(&size), sizeof(size)); - auto out = desc.SerializeAsString(); - fout.write(out.data(), size); - } - { // the 3rd field, tensor data - uint64_t size = tensor.memory_size(); - auto *data_ptr = tensor.data(); - PADDLE_ENFORCE(size < std::numeric_limits::max(), - "Index overflow when writing tensor"); - if (platform::is_gpu_place(tensor.place())) { -#ifdef PADDLE_WITH_CUDA - constexpr size_t kBufSize = 1024 * 1024 * 64; // 64MB - std::unique_ptr buf(new char[kBufSize]); - auto &gpu_dev_ctx = - static_cast(dev_ctx); - platform::CPUPlace cpu; - uintptr_t data = reinterpret_cast(data_ptr); - while (size != 0) { - size_t size_to_write = std::min(kBufSize, static_cast(size)); - memory::Copy(cpu, buf.get(), - boost::get(tensor.place()), - reinterpret_cast(data), size_to_write, - gpu_dev_ctx.stream()); - gpu_dev_ctx.Wait(); - fout.write(buf.get(), size_to_write); - data += size_to_write; - size -= size_to_write; - } -#else - PADDLE_THROW("Unexpected branch"); -#endif - } else { - fout.write(static_cast(data_ptr), - static_cast(size)); - } - } - { // the 4th field, lod information - // uint64_t lod_level - // uint64_t lod_level_1 size in byte. - // int* lod_level_1 data - // ... - auto lod = tensor.lod(); - uint64_t size = lod.size(); - fout.write(reinterpret_cast(&size), sizeof(size)); - - for (auto &each : lod) { - size = each.size() * sizeof(framework::LoD::value_type::value_type); - fout.write(reinterpret_cast(&size), sizeof(size)); - fout.write(reinterpret_cast(each.data()), - static_cast(size)); - } - } + framework::SerializeToStream(fout, tensor, dev_ctx); } }; diff --git a/paddle/operators/send_op.cc b/paddle/operators/send_op.cc new file mode 100644 index 00000000000..a3059847f2d --- /dev/null +++ b/paddle/operators/send_op.cc @@ -0,0 +1,84 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ + +#include + +#include "paddle/framework/data_type.h" +#include "paddle/framework/framework.pb.h" +#include "paddle/framework/lod_tensor.h" +#include "paddle/framework/op_registry.h" + +#include "paddle/operators/detail/send_recv_impl.h" +#include "paddle/operators/detail/simple_block_queue.h" + +namespace paddle { +namespace operators { + +// TODO(typhoonzero): this is a simple implementation which only send +// one tensor +class SendOp : public framework::OperatorBase { + public: + SendOp(const std::string &type, const framework::VariableNameMap &inputs, + const framework::VariableNameMap &outputs, + const framework::AttributeMap &attrs) + : OperatorBase(type, inputs, outputs, attrs) { + // init client when the operator is created at runtime. + if (!client_) { + std::string endpoint = Attr("endpoint"); + client_.reset(new detail::RPCClient( + grpc::CreateChannel(endpoint, grpc::InsecureChannelCredentials()))); + // TODO(typhoonzero): how to call InitVariables + } + } + void Run(const framework::Scope &scope, + const platform::DeviceContext &dev_ctx) const override { + auto iname = Input("X"); + auto oname = Output("Out"); + // TODO(typhoonzero): currently it's non-blocking, + // should block until server responds. + bool ret = client_->SendVariable(scope, iname, oname); + if (!ret) { + LOG(ERROR) << "send variable error"; + } + } + + protected: + std::shared_ptr client_{nullptr}; +}; + +class SendOpMaker : public framework::OpProtoAndCheckerMaker { + public: + SendOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker) + : OpProtoAndCheckerMaker(proto, op_checker) { + AddInput("X", "(Tensor) Input tensor to be saved"); + AddOutput("Out", "(Tensor) Output fetched from server"); + AddComment(R"DOC( +Recv operator + +This operator will recv tensor from send_op +)DOC"); + AddAttr("endpoint", + "(string, default 127.0.0.1:6164)" + "IP address to listen on.") + .SetDefault("127.0.0.1:6164") + .AddCustomChecker([](const std::string &ip) { return !ip.empty(); }); + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; + +REGISTER_OPERATOR(send, ops::SendOp, ops::SendOpMaker); diff --git a/paddle/operators/send_recv_op_test.cc b/paddle/operators/send_recv_op_test.cc new file mode 100644 index 00000000000..ac03eb3752e --- /dev/null +++ b/paddle/operators/send_recv_op_test.cc @@ -0,0 +1,125 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ + +// TODO(typhoonzero): add python bindings for this test as +// a RemoteOptimizer. + +#include +#include + +#include "gtest/gtest.h" +#include "paddle/framework/op_registry.h" +#include "paddle/framework/operator.h" +#include "paddle/framework/program_desc.h" + +USE_NO_KERNEL_OP(send); +USE_NO_KERNEL_OP(recv); +USE_OP(sum); + +// global for simplicity. +std::unique_ptr recv_op; + +void InitTensorsInScope(paddle::framework::Scope &scope, + paddle::platform::CPUPlace &place) { + paddle::platform::CPUDeviceContext ctx(place); + auto var = scope.Var("X"); + auto tensor = var->GetMutable(); + tensor->Resize({10, 10}); + float *expect = tensor->mutable_data(place); + for (int64_t i = 0; i < tensor->numel(); ++i) { + expect[i] = static_cast(i); + } + + auto out_var = scope.Var("Out"); + auto out_tensor = out_var->GetMutable(); + out_tensor->Resize({10, 10}); + tensor->mutable_data(place); // allocate +} + +void AddOp(const std::string &type, + const paddle::framework::VariableNameMap &inputs, + const paddle::framework::VariableNameMap &outputs, + paddle::framework::AttributeMap attrs, + paddle::framework::BlockDescBind *block) { + // insert output + for (auto kv : outputs) { + for (auto v : kv.second) { + auto var = block->Var(v); + var->SetDataType(paddle::framework::DataType::FP32); + } + } + + // insert op + auto op = block->AppendOp(); + op->SetType(type); + for (auto &kv : inputs) { + op->SetInput(kv.first, kv.second); + } + for (auto &kv : outputs) { + op->SetOutput(kv.first, kv.second); + } + op->SetAttrMap(attrs); +} + +void StartServerNet() { + paddle::framework::Scope scope; + paddle::platform::CPUPlace place; + InitTensorsInScope(scope, place); + + // sub program run in recv_op, for simple test we use sum + paddle::framework::ProgramDescBind program; + paddle::framework::BlockDescBind *block = program.MutableBlock(0); + // X for server side tensors, RX for received tensers, must be of same shape. + AddOp("sum", {{"X", {"X", "RX"}}}, {{"Out", {"Out"}}}, {}, block); + + paddle::framework::AttributeMap attrs; + attrs.insert({"endpoint", std::string("127.0.0.1:6174")}); + attrs.insert({"OptimizeBlock", block}); + recv_op = paddle::framework::OpRegistry::CreateOp("recv", {{"RX", {"RX"}}}, + {{"Out", {"Out"}}}, attrs); + paddle::platform::CPUDeviceContext ctx(place); + recv_op->Run(scope, ctx); +} + +TEST(SendRecvOp, CPU) { + std::thread server_thread(StartServerNet); + sleep(5); // wait server to start + // local net + paddle::framework::Scope scope; + paddle::platform::CPUPlace place; + InitTensorsInScope(scope, place); + + paddle::framework::AttributeMap attrs; + attrs.insert({"endpoint", std::string("127.0.0.1:6174")}); + + auto send_op = paddle::framework::OpRegistry::CreateOp( + "send", {{"X", {"X"}}}, {{"Out", {"Out"}}}, attrs); + paddle::platform::CPUDeviceContext ctx(place); + send_op->Run(scope, ctx); + + auto in_var = scope.Var("X"); + auto tensor = in_var->GetMutable(); + float *expected = tensor->data(); + + auto out_var = scope.Var("Out"); + auto target = out_var->GetMutable(); + // send fail cause output is none. + EXPECT_NE(target->memory_size(), size_t(0)); + float *actual = target->data(); + for (int64_t i = 0; i < target->numel(); ++i) { + EXPECT_EQ(expected[i] * 2, actual[i]); + } + recv_op.reset(); // dtor can shutdown and join server thread. + server_thread.join(); +} -- GitLab