From 22ea4c30c2679a89e6800b7c1cdc2c68ff4e55cb Mon Sep 17 00:00:00 2001 From: tianshuo78520a <707759223@qq.com> Date: Wed, 14 Apr 2021 11:21:49 +0800 Subject: [PATCH] Delete grpc.cmake/distribeted/distributed_ops (#32166) * Delete grpc.cmake/distribeted/distributed_ops * reset operators/CMakeLists.txt * rm test_transpiler_ops.py * del test_transpiler_ops.py --- cmake/external/grpc.cmake | 77 -- .../operators/collective/allreduce_op.cc | 2 +- .../operators/collective/allreduce_op.cu.cc | 2 +- .../operators/distributed/CMakeLists.txt | 76 -- .../async_sparse_param_update_recorder.cc | 27 - .../async_sparse_param_update_recorder.h | 186 ---- ...async_sparse_param_update_recorder_test.cc | 97 -- .../operators/distributed/brpc/brpc_client.cc | 462 -------- .../operators/distributed/brpc/brpc_client.h | 174 --- .../distributed/brpc/brpc_rdma_pool.cc | 86 -- .../distributed/brpc/brpc_rdma_pool.h | 56 - .../distributed/brpc/brpc_sendrecvop_utils.cc | 224 ---- .../distributed/brpc/brpc_sendrecvop_utils.h | 49 - .../distributed/brpc/brpc_serde_test.cc | 175 ---- .../operators/distributed/brpc/brpc_server.cc | 417 -------- .../operators/distributed/brpc/brpc_server.h | 53 - .../brpc/brpc_variable_response.cc | 75 -- .../distributed/brpc/brpc_variable_response.h | 67 -- .../distributed/collective_client.cc | 57 - .../operators/distributed/collective_client.h | 104 -- .../distributed/collective_server.cc | 68 -- .../operators/distributed/collective_server.h | 116 -- .../distributed/collective_server_test.cc | 131 --- .../operators/distributed/communicator.cc | 989 ------------------ .../operators/distributed/communicator.h | 490 --------- .../distributed/communicator_common.h | 91 -- .../distributed/communicator_test.cc | 106 -- .../fluid/operators/distributed/distributed.h | 36 - .../operators/distributed/distributed_pb.h | 30 - .../grpc/grpc_bytebuffer_stream.cc | 92 -- .../distributed/grpc/grpc_bytebuffer_stream.h | 174 --- .../operators/distributed/grpc/grpc_client.cc | 671 ------------ .../operators/distributed/grpc/grpc_client.h | 321 ------ .../operators/distributed/grpc/grpc_serde.cc | 190 ---- .../operators/distributed/grpc/grpc_serde.h | 69 -- .../distributed/grpc/grpc_serde_test.cc | 224 ---- .../operators/distributed/grpc/grpc_server.cc | 720 ------------- .../operators/distributed/grpc/grpc_server.h | 93 -- .../operators/distributed/grpc/grpc_service.h | 145 --- .../grpc/grpc_variable_response.cc | 344 ------ .../distributed/grpc/grpc_variable_response.h | 67 -- .../distributed/heart_beat_monitor.cc | 97 -- .../distributed/heart_beat_monitor.h | 127 --- .../distributed/heart_beat_monitor_test.cc | 54 - .../operators/distributed/large_scale_kv.cc | 26 - .../operators/distributed/large_scale_kv.h | 848 --------------- .../distributed/parameter_prefetch.cc | 311 ------ .../distributed/parameter_prefetch.h | 53 - .../operators/distributed/parameter_recv.cc | 248 ----- .../operators/distributed/parameter_recv.h | 37 - .../operators/distributed/parameter_send.cc | 331 ------ .../operators/distributed/parameter_send.h | 35 - .../distributed/proto_encoder_helper.h | 146 --- .../operators/distributed/request_handler.h | 261 ----- .../distributed/request_handler_impl.cc | 354 ------- .../distributed/request_handler_impl.h | 198 ---- .../fluid/operators/distributed/rpc_client.cc | 32 - .../fluid/operators/distributed/rpc_client.h | 143 --- .../fluid/operators/distributed/rpc_server.cc | 242 ----- .../fluid/operators/distributed/rpc_server.h | 149 --- .../operators/distributed/rpc_server_test.cc | 344 ------ .../operators/distributed/send_recv.proto.in | 88 -- .../operators/distributed/sendrecvop_utils.cc | 117 --- .../operators/distributed/sendrecvop_utils.h | 110 -- .../operators/distributed/varhandle_test.cc | 50 - .../distributed/variable_response.cc | 271 ----- .../operators/distributed/variable_response.h | 155 --- .../operators/distributed_ops/CMakeLists.txt | 38 - .../operators/distributed_ops/allreduce_op.cc | 80 -- .../distributed_ops/allreduce_op.cu.cc | 25 - .../operators/distributed_ops/allreduce_op.h | 90 -- .../operators/distributed_ops/broadcast_op.cc | 79 -- .../distributed_ops/broadcast_op.cu.cc | 91 -- .../distributed_ops/checkpoint_notify_op.cc | 117 --- .../distributed_lookup_table_op.cc | 156 --- .../distributed_lookup_table_op.cu.cc | 22 - .../distributed_lookup_table_op.h | 66 -- .../operators/distributed_ops/fake_init_op.cc | 81 -- .../distributed_ops/fetch_barrier_op.cc | 105 -- .../distributed_ops/fl_listen_and_serv_op.cc | 284 ----- .../distributed_ops/fl_listen_and_serv_op.h | 107 -- .../distributed_ops/gen_nccl_id_op.cc | 313 ------ .../distributed_ops/listen_and_serv_op.cc | 636 ----------- .../distributed_ops/listen_and_serv_op.h | 135 --- .../lookup_sparse_table_fuse_adam_op.cc | 158 --- .../lookup_sparse_table_fuse_adam_op.h | 142 --- .../lookup_sparse_table_fuse_sgd_op.cc | 125 --- .../lookup_sparse_table_fuse_sgd_op.h | 105 -- .../lookup_sparse_table_grad_split_op.cc | 79 -- .../lookup_sparse_table_grad_split_op.h | 97 -- .../lookup_sparse_table_init_op.cc | 147 --- .../lookup_sparse_table_merge_op.cc | 84 -- .../lookup_sparse_table_merge_op.h | 78 -- .../lookup_sparse_table_read_op.cc | 133 --- .../lookup_sparse_table_write_op.cc | 116 -- .../operators/distributed_ops/merge_ids_op.cc | 134 --- .../operators/distributed_ops/merge_ids_op.h | 112 -- .../operators/distributed_ops/prefetch_op.cc | 119 --- .../operators/distributed_ops/recv_op.cc | 153 --- .../operators/distributed_ops/recv_save_op.cc | 328 ------ .../distributed_ops/ref_by_trainer_id_op.cc | 99 -- .../ref_by_trainer_id_op.cu.cc | 26 - .../distributed_ops/ref_by_trainer_id_op.h | 53 - .../distributed_ops/send_and_recv_op.cc | 98 -- .../distributed_ops/send_barrier_op.cc | 120 --- .../operators/distributed_ops/send_op.cc | 160 --- .../distributed_ops/send_recv_op_test.cc | 257 ----- .../distributed_ops/send_recv_util.h | 73 -- .../distributed_ops/sparse_tensor_load_op.cc | 217 ---- .../distributed_ops/split_byref_op.cc | 103 -- .../distributed_ops/split_byref_op.cu.cc | 19 - .../distributed_ops/split_byref_op.h | 43 - .../operators/distributed_ops/split_ids_op.cc | 96 -- .../operators/distributed_ops/split_ids_op.h | 127 --- .../distributed_ops/test_send_nccl_id.cc | 107 -- .../fluid/operators/split_selected_rows_op.cc | 99 -- .../fluid/operators/split_selected_rows_op.cu | 19 - .../fluid/operators/split_selected_rows_op.h | 108 -- .../fluid/tests/unittests/CMakeLists.txt | 2 - .../unittests/test_split_selected_rows_op.py | 137 --- .../tests/unittests/test_transpiler_ops.py | 143 --- 121 files changed, 2 insertions(+), 19099 deletions(-) delete mode 100644 cmake/external/grpc.cmake delete mode 100644 paddle/fluid/operators/distributed/CMakeLists.txt delete mode 100644 paddle/fluid/operators/distributed/async_sparse_param_update_recorder.cc delete mode 100644 paddle/fluid/operators/distributed/async_sparse_param_update_recorder.h delete mode 100644 paddle/fluid/operators/distributed/async_sparse_param_update_recorder_test.cc delete mode 100644 paddle/fluid/operators/distributed/brpc/brpc_client.cc delete mode 100644 paddle/fluid/operators/distributed/brpc/brpc_client.h delete mode 100644 paddle/fluid/operators/distributed/brpc/brpc_rdma_pool.cc delete mode 100644 paddle/fluid/operators/distributed/brpc/brpc_rdma_pool.h delete mode 100644 paddle/fluid/operators/distributed/brpc/brpc_sendrecvop_utils.cc delete mode 100644 paddle/fluid/operators/distributed/brpc/brpc_sendrecvop_utils.h delete mode 100644 paddle/fluid/operators/distributed/brpc/brpc_serde_test.cc delete mode 100644 paddle/fluid/operators/distributed/brpc/brpc_server.cc delete mode 100644 paddle/fluid/operators/distributed/brpc/brpc_server.h delete mode 100644 paddle/fluid/operators/distributed/brpc/brpc_variable_response.cc delete mode 100644 paddle/fluid/operators/distributed/brpc/brpc_variable_response.h delete mode 100644 paddle/fluid/operators/distributed/collective_client.cc delete mode 100644 paddle/fluid/operators/distributed/collective_client.h delete mode 100644 paddle/fluid/operators/distributed/collective_server.cc delete mode 100644 paddle/fluid/operators/distributed/collective_server.h delete mode 100644 paddle/fluid/operators/distributed/collective_server_test.cc delete mode 100644 paddle/fluid/operators/distributed/communicator.cc delete mode 100644 paddle/fluid/operators/distributed/communicator.h delete mode 100644 paddle/fluid/operators/distributed/communicator_common.h delete mode 100644 paddle/fluid/operators/distributed/communicator_test.cc delete mode 100644 paddle/fluid/operators/distributed/distributed.h delete mode 100644 paddle/fluid/operators/distributed/distributed_pb.h delete mode 100644 paddle/fluid/operators/distributed/grpc/grpc_bytebuffer_stream.cc delete mode 100644 paddle/fluid/operators/distributed/grpc/grpc_bytebuffer_stream.h delete mode 100644 paddle/fluid/operators/distributed/grpc/grpc_client.cc delete mode 100644 paddle/fluid/operators/distributed/grpc/grpc_client.h delete mode 100644 paddle/fluid/operators/distributed/grpc/grpc_serde.cc delete mode 100644 paddle/fluid/operators/distributed/grpc/grpc_serde.h delete mode 100644 paddle/fluid/operators/distributed/grpc/grpc_serde_test.cc delete mode 100644 paddle/fluid/operators/distributed/grpc/grpc_server.cc delete mode 100644 paddle/fluid/operators/distributed/grpc/grpc_server.h delete mode 100644 paddle/fluid/operators/distributed/grpc/grpc_service.h delete mode 100644 paddle/fluid/operators/distributed/grpc/grpc_variable_response.cc delete mode 100644 paddle/fluid/operators/distributed/grpc/grpc_variable_response.h delete mode 100644 paddle/fluid/operators/distributed/heart_beat_monitor.cc delete mode 100644 paddle/fluid/operators/distributed/heart_beat_monitor.h delete mode 100644 paddle/fluid/operators/distributed/heart_beat_monitor_test.cc delete mode 100644 paddle/fluid/operators/distributed/large_scale_kv.cc delete mode 100644 paddle/fluid/operators/distributed/large_scale_kv.h delete mode 100644 paddle/fluid/operators/distributed/parameter_prefetch.cc delete mode 100644 paddle/fluid/operators/distributed/parameter_prefetch.h delete mode 100644 paddle/fluid/operators/distributed/parameter_recv.cc delete mode 100644 paddle/fluid/operators/distributed/parameter_recv.h delete mode 100644 paddle/fluid/operators/distributed/parameter_send.cc delete mode 100644 paddle/fluid/operators/distributed/parameter_send.h delete mode 100644 paddle/fluid/operators/distributed/proto_encoder_helper.h delete mode 100644 paddle/fluid/operators/distributed/request_handler.h delete mode 100644 paddle/fluid/operators/distributed/request_handler_impl.cc delete mode 100644 paddle/fluid/operators/distributed/request_handler_impl.h delete mode 100644 paddle/fluid/operators/distributed/rpc_client.cc delete mode 100644 paddle/fluid/operators/distributed/rpc_client.h delete mode 100644 paddle/fluid/operators/distributed/rpc_server.cc delete mode 100644 paddle/fluid/operators/distributed/rpc_server.h delete mode 100644 paddle/fluid/operators/distributed/rpc_server_test.cc delete mode 100644 paddle/fluid/operators/distributed/send_recv.proto.in delete mode 100644 paddle/fluid/operators/distributed/sendrecvop_utils.cc delete mode 100644 paddle/fluid/operators/distributed/sendrecvop_utils.h delete mode 100644 paddle/fluid/operators/distributed/varhandle_test.cc delete mode 100644 paddle/fluid/operators/distributed/variable_response.cc delete mode 100644 paddle/fluid/operators/distributed/variable_response.h delete mode 100644 paddle/fluid/operators/distributed_ops/CMakeLists.txt delete mode 100644 paddle/fluid/operators/distributed_ops/allreduce_op.cc delete mode 100644 paddle/fluid/operators/distributed_ops/allreduce_op.cu.cc delete mode 100644 paddle/fluid/operators/distributed_ops/allreduce_op.h delete mode 100644 paddle/fluid/operators/distributed_ops/broadcast_op.cc delete mode 100644 paddle/fluid/operators/distributed_ops/broadcast_op.cu.cc delete mode 100644 paddle/fluid/operators/distributed_ops/checkpoint_notify_op.cc delete mode 100644 paddle/fluid/operators/distributed_ops/distributed_lookup_table_op.cc delete mode 100644 paddle/fluid/operators/distributed_ops/distributed_lookup_table_op.cu.cc delete mode 100644 paddle/fluid/operators/distributed_ops/distributed_lookup_table_op.h delete mode 100644 paddle/fluid/operators/distributed_ops/fake_init_op.cc delete mode 100644 paddle/fluid/operators/distributed_ops/fetch_barrier_op.cc delete mode 100644 paddle/fluid/operators/distributed_ops/fl_listen_and_serv_op.cc delete mode 100644 paddle/fluid/operators/distributed_ops/fl_listen_and_serv_op.h delete mode 100644 paddle/fluid/operators/distributed_ops/gen_nccl_id_op.cc delete mode 100644 paddle/fluid/operators/distributed_ops/listen_and_serv_op.cc delete mode 100644 paddle/fluid/operators/distributed_ops/listen_and_serv_op.h delete mode 100644 paddle/fluid/operators/distributed_ops/lookup_sparse_table_fuse_adam_op.cc delete mode 100644 paddle/fluid/operators/distributed_ops/lookup_sparse_table_fuse_adam_op.h delete mode 100644 paddle/fluid/operators/distributed_ops/lookup_sparse_table_fuse_sgd_op.cc delete mode 100644 paddle/fluid/operators/distributed_ops/lookup_sparse_table_fuse_sgd_op.h delete mode 100644 paddle/fluid/operators/distributed_ops/lookup_sparse_table_grad_split_op.cc delete mode 100644 paddle/fluid/operators/distributed_ops/lookup_sparse_table_grad_split_op.h delete mode 100644 paddle/fluid/operators/distributed_ops/lookup_sparse_table_init_op.cc delete mode 100644 paddle/fluid/operators/distributed_ops/lookup_sparse_table_merge_op.cc delete mode 100644 paddle/fluid/operators/distributed_ops/lookup_sparse_table_merge_op.h delete mode 100644 paddle/fluid/operators/distributed_ops/lookup_sparse_table_read_op.cc delete mode 100644 paddle/fluid/operators/distributed_ops/lookup_sparse_table_write_op.cc delete mode 100644 paddle/fluid/operators/distributed_ops/merge_ids_op.cc delete mode 100644 paddle/fluid/operators/distributed_ops/merge_ids_op.h delete mode 100644 paddle/fluid/operators/distributed_ops/prefetch_op.cc delete mode 100644 paddle/fluid/operators/distributed_ops/recv_op.cc delete mode 100644 paddle/fluid/operators/distributed_ops/recv_save_op.cc delete mode 100644 paddle/fluid/operators/distributed_ops/ref_by_trainer_id_op.cc delete mode 100644 paddle/fluid/operators/distributed_ops/ref_by_trainer_id_op.cu.cc delete mode 100644 paddle/fluid/operators/distributed_ops/ref_by_trainer_id_op.h delete mode 100644 paddle/fluid/operators/distributed_ops/send_and_recv_op.cc delete mode 100644 paddle/fluid/operators/distributed_ops/send_barrier_op.cc delete mode 100644 paddle/fluid/operators/distributed_ops/send_op.cc delete mode 100644 paddle/fluid/operators/distributed_ops/send_recv_op_test.cc delete mode 100644 paddle/fluid/operators/distributed_ops/send_recv_util.h delete mode 100644 paddle/fluid/operators/distributed_ops/sparse_tensor_load_op.cc delete mode 100644 paddle/fluid/operators/distributed_ops/split_byref_op.cc delete mode 100644 paddle/fluid/operators/distributed_ops/split_byref_op.cu.cc delete mode 100644 paddle/fluid/operators/distributed_ops/split_byref_op.h delete mode 100644 paddle/fluid/operators/distributed_ops/split_ids_op.cc delete mode 100644 paddle/fluid/operators/distributed_ops/split_ids_op.h delete mode 100644 paddle/fluid/operators/distributed_ops/test_send_nccl_id.cc delete mode 100644 paddle/fluid/operators/split_selected_rows_op.cc delete mode 100644 paddle/fluid/operators/split_selected_rows_op.cu delete mode 100644 paddle/fluid/operators/split_selected_rows_op.h delete mode 100644 python/paddle/fluid/tests/unittests/test_split_selected_rows_op.py delete mode 100644 python/paddle/fluid/tests/unittests/test_transpiler_ops.py diff --git a/cmake/external/grpc.cmake b/cmake/external/grpc.cmake deleted file mode 100644 index 536e95c1dc..0000000000 --- a/cmake/external/grpc.cmake +++ /dev/null @@ -1,77 +0,0 @@ -# Copyright (c) 2017 PaddlePaddle Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# - -include (ExternalProject) - -SET(GRPC_SOURCES_DIR ${THIRD_PARTY_PATH}/grpc) -SET(GRPC_INSTALL_DIR ${THIRD_PARTY_PATH}/install/grpc) -SET(GRPC_INCLUDE_DIR "${GRPC_INSTALL_DIR}/include/" CACHE PATH "grpc include directory." FORCE) -SET(GRPC_CPP_PLUGIN "${GRPC_INSTALL_DIR}/bin/grpc_cpp_plugin" CACHE FILEPATH "GRPC_CPP_PLUGIN" FORCE) - -include(ProcessorCount) -ProcessorCount(NUM_OF_PROCESSOR) - -IF(APPLE) - SET(BUILD_CMD make -n HAS_SYSTEM_PROTOBUF=false -s -j ${NUM_OF_PROCESSOR} static grpc_cpp_plugin | sed "s/-Werror//g" | sh) - SET(GRPC_INSTALL_CMD make prefix=${GRPC_INSTALL_DIR} install) -ELSE() - SET(GRPC_CFLAGS "-Wno-error -std=c11 ${CLFAGS}") - SET(GRPC_CXXFLAGS "-Wno-error -std=c++11 ${CXXFLAGS}") - SET(BUILD_CMD make CFLAGS=${GRPC_CFLAGS} CXXFLAGS=${GRPC_CXXFLAGS} HAS_SYSTEM_PROTOBUF=false -s -j ${NUM_OF_PROCESSOR} static grpc_cpp_plugin) - SET(GRPC_INSTALL_CMD make prefix=${GRPC_INSTALL_DIR} install CFLAGS=${GRPC_CFLAGS} CXXFLAGS=${GRPC_CXXFLAGS}) -ENDIF() - -# FIXME(wuyi): do not build zlib cares protobuf twice, find a way to build grpc with them -ExternalProject_Add( - extern_grpc - DEPENDS protobuf zlib - # NOTE(wuyi): - # this package is generated by following steps: - # 1. git clone -b v1.8.x https://github.com/grpc/grpc.git - # 2. git submodule update --init - # 3. keep only zlib, cares, protobuf, boringssl under "third_party", - # checkout and clean other dirs under third_party - # 4. remove .git, and package the directory. - URL http://paddlepaddledeps.bj.bcebos.com/grpc-v1.10.x_paddle.tar.gz - URL_MD5 f5442d137ddccee252e194b1bc90f98c - PREFIX ${GRPC_SOURCES_DIR} - UPDATE_COMMAND "" - CONFIGURE_COMMAND "" - BUILD_IN_SOURCE 1 - # NOTE(yuyang18): - # Disable -Werror, otherwise the compile will fail in MacOS. - # It seems that we cannot configure that by make command. - # Just dry run make command and remove `-Werror`, then use a shell to run make commands - BUILD_COMMAND ${BUILD_CMD} - INSTALL_COMMAND ${GRPC_INSTALL_CMD} -) - -ADD_LIBRARY(grpc++_unsecure STATIC IMPORTED GLOBAL) -SET_PROPERTY(TARGET grpc++_unsecure PROPERTY IMPORTED_LOCATION - "${GRPC_INSTALL_DIR}/lib/libgrpc++_unsecure.a") - -ADD_LIBRARY(grpc++ STATIC IMPORTED GLOBAL) -SET_PROPERTY(TARGET grpc++ PROPERTY IMPORTED_LOCATION - "${GRPC_INSTALL_DIR}/lib/libgrpc++.a") -ADD_LIBRARY(gpr STATIC IMPORTED GLOBAL) -SET_PROPERTY(TARGET gpr PROPERTY IMPORTED_LOCATION - "${GRPC_INSTALL_DIR}/lib/libgpr.a") - -ADD_LIBRARY(grpc_unsecure STATIC IMPORTED GLOBAL) -SET_PROPERTY(TARGET grpc_unsecure PROPERTY IMPORTED_LOCATION - "${GRPC_INSTALL_DIR}/lib/libgrpc_unsecure.a") - -include_directories(${GRPC_INCLUDE_DIR}) -ADD_DEPENDENCIES(grpc++_unsecure extern_grpc) diff --git a/paddle/fluid/operators/collective/allreduce_op.cc b/paddle/fluid/operators/collective/allreduce_op.cc index 86f1c28a9d..63b135a74c 100644 --- a/paddle/fluid/operators/collective/allreduce_op.cc +++ b/paddle/fluid/operators/collective/allreduce_op.cc @@ -15,7 +15,7 @@ limitations under the License. */ #include // NOLINT #include -#include "paddle/fluid/operators/distributed_ops/allreduce_op.h" +#include "paddle/fluid/operators/collective/allreduce_op.h" namespace paddle { namespace operators { diff --git a/paddle/fluid/operators/collective/allreduce_op.cu.cc b/paddle/fluid/operators/collective/allreduce_op.cu.cc index 9b70f78399..fe2e491055 100644 --- a/paddle/fluid/operators/collective/allreduce_op.cu.cc +++ b/paddle/fluid/operators/collective/allreduce_op.cu.cc @@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -#include "paddle/fluid/operators/distributed_ops/allreduce_op.h" +#include "paddle/fluid/operators/collective/allreduce_op.h" namespace ops = paddle::operators; namespace plat = paddle::platform; diff --git a/paddle/fluid/operators/distributed/CMakeLists.txt b/paddle/fluid/operators/distributed/CMakeLists.txt deleted file mode 100644 index c9db6148bc..0000000000 --- a/paddle/fluid/operators/distributed/CMakeLists.txt +++ /dev/null @@ -1,76 +0,0 @@ -return() - -if(WITH_GRPC) - set(cc_generic_services "false") -else() - set(cc_generic_services "true") -endif() -configure_file(send_recv.proto.in ${CMAKE_CURRENT_SOURCE_DIR}/send_recv.proto @ONLY) - -cc_library(async_sparse_param_update_recorder SRCS async_sparse_param_update_recorder.cc DEPS enforce simple_threadpool) -cc_test(async_sparse_param_update_recorder_test SRCS async_sparse_param_update_recorder_test.cc DEPS async_sparse_param_update_recorder) - -cc_library(heart_beat_monitor SRCS heart_beat_monitor.cc DEPS enforce simple_threadpool) -cc_library(large_scale_kv SRCS large_scale_kv.cc DEPS enforce simple_threadpool device_context) -cc_test(heart_beat_monitor_test SRCS heart_beat_monitor_test.cc DEPS heart_beat_monitor) - -# FIXME(typhoonzero): use add_subdirectory once we clean the dependency of these files -set(DISTRIBUTE_COMPILE_FLAGS "-Wno-non-virtual-dtor -Wno-error=non-virtual-dtor -Wno-error=delete-non-virtual-dtor") -if(WITH_GRPC) - set(GRPC_DEPS grpc++_unsecure grpc_unsecure gpr zlib protobuf) - set(GRPC_SRCS grpc/grpc_client.cc grpc/grpc_server.cc grpc/grpc_serde.cc grpc/grpc_bytebuffer_stream.cc grpc/grpc_variable_response.cc) - grpc_library(sendrecvop_rpc SRCS sendrecvop_utils.cc - request_handler_impl.cc rpc_client.cc rpc_server.cc - variable_response.cc - collective_client.cc collective_server.cc - ${GRPC_SRCS} - PROTO send_recv.proto - DEPS lod_tensor selected_rows_functor memory scope ${GRPC_DEPS} async_sparse_param_update_recorder heart_beat_monitor large_scale_kv) - - set_source_files_properties(grpc_serde_test.cc rpc_server_test.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS}) - set(RPC_DEPS sendrecvop_rpc ${GRPC_DEPS}) - - cc_test(grpc_serde_test SRCS grpc/grpc_serde_test.cc - DEPS ${RPC_DEPS} scope profiler math_function) - -else() - set(BRPC_SRCS brpc/brpc_client.cc brpc/brpc_server.cc brpc/brpc_sendrecvop_utils.cc brpc/brpc_variable_response.cc brpc/brpc_rdma_pool.cc) - set_source_files_properties(${BRPC_SRCS} parameter_prefetch.cc parameter_send.cc parameter_recv.cc communicator.cc rpc_server_test.cc brpc/brpc_serde_test.cc collective_server.cc collective_server_test.cc collective_client.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS}) - - set(BRPC_DEPS brpc ssl crypto protobuf leveldb zlib) - - brpc_library(sendrecvop_rpc SRCS sendrecvop_utils.cc - request_handler_impl.cc rpc_client.cc rpc_server.cc - variable_response.cc - collective_client.cc collective_server.cc - ${BRPC_SRCS} - PROTO send_recv.proto - DEPS lod_tensor selected_rows memory scope ${BRPC_DEPS}) - - set(RPC_DEPS sendrecvop_rpc ${BRPC_DEPS}) - cc_test(brpc_serde_test SRCS brpc/brpc_serde_test.cc - DEPS ${RPC_DEPS} gflags glog executor proto_desc lookup_sparse_table_read_op) -endif() - - -cc_test(rpc_server_test SRCS rpc_server_test.cc - DEPS ${RPC_DEPS} executor scope proto_desc lookup_sparse_table_read_op checkpoint_notify_op scale_op ) -cc_test(varhandle_test SRCS varhandle_test.cc DEPS profiler scope) -cc_library(parameter_prefetch SRCS parameter_prefetch.cc DEPS sendrecvop_rpc memory node) -cc_library(parameter_send SRCS parameter_send.cc DEPS sendrecvop_rpc memory) -cc_library(parameter_recv SRCS parameter_recv.cc DEPS sendrecvop_rpc memory) -cc_library(communicator SRCS communicator.cc DEPS scope selected_rows tensor variable_helper selected_rows_functor simple_threadpool parameter_send parameter_recv generator) -cc_test(communicator_test SRCS communicator_test.cc DEPS communicator) -if(WITH_GPU OR WITH_ROCM) - cc_test(collective_server_test SRCS collective_server_test.cc - DEPS sendrecvop_rpc executor ${RPC_DEPS} - selected_rows_functor scope math_function) -endif() -if(WITH_TESTING) - if(TEST rpc_server_test) - set_tests_properties(rpc_server_test PROPERTIES TIMEOUT 120) - endif() - if(TEST heart_beat_monitor_test) - set_tests_properties(heart_beat_monitor_test PROPERTIES TIMEOUT 120) - endif() -endif() diff --git a/paddle/fluid/operators/distributed/async_sparse_param_update_recorder.cc b/paddle/fluid/operators/distributed/async_sparse_param_update_recorder.cc deleted file mode 100644 index 3f3b6b959e..0000000000 --- a/paddle/fluid/operators/distributed/async_sparse_param_update_recorder.cc +++ /dev/null @@ -1,27 +0,0 @@ -// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "paddle/fluid/operators/distributed/async_sparse_param_update_recorder.h" - -namespace paddle { -namespace operators { -namespace distributed { - -std::once_flag AsyncSparseParamUpdateRecorder::init_flag_; -std::unique_ptr - AsyncSparseParamUpdateRecorder::recorder_(nullptr); - -} // namespace distributed -} // namespace operators -} // namespace paddle diff --git a/paddle/fluid/operators/distributed/async_sparse_param_update_recorder.h b/paddle/fluid/operators/distributed/async_sparse_param_update_recorder.h deleted file mode 100644 index 28a5f2ad6c..0000000000 --- a/paddle/fluid/operators/distributed/async_sparse_param_update_recorder.h +++ /dev/null @@ -1,186 +0,0 @@ -// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#pragma once - -#include -#include -#include // NOLINT -#include -#include -#include -#include -#include -#include - -#include "paddle/fluid/platform/enforce.h" - -namespace paddle { -namespace operators { -namespace distributed { - -class ConcurrentSet { - public: - ConcurrentSet() : pool_(new ::ThreadPool(1)) {} - ~ConcurrentSet() {} - - std::future Update(const std::vector& rows) { - auto task = [this, rows] { - if (VLOG_IS_ON(3)) { - std::ostringstream sstream; - sstream << "["; - for (auto& id : rows) { - sstream << id << ", "; - } - sstream << "]"; - VLOG(3) << "update ids -> " << sstream.str(); - } - for (auto row : rows) { - set_.insert(row); - } - }; - return pool_->enqueue(std::move(task)); - } - - std::future GetAndClear(std::vector* result) { - auto task = [this, &result] { - result->clear(); - for (auto& id : set_) { - result->push_back(id); - } - if (VLOG_IS_ON(3)) { - std::ostringstream sstream; - sstream << "["; - for (auto& id : *result) { - sstream << id << ", "; - } - sstream << "]"; - VLOG(3) << "result ids size: " << result->size() << " " - << sstream.str(); - } - set_.clear(); - }; - return pool_->enqueue(std::move(task)); - } - - private: - std::unordered_set set_; - std::unique_ptr<::ThreadPool> pool_{nullptr}; -}; - -class AsyncSparseParamUpdateRecorder { - using TrainerToRows = std::vector>; - - public: - AsyncSparseParamUpdateRecorder( - int trainer_num, - const std::unordered_map& grad_to_param) - : trainer_num_(trainer_num), grad_to_param_(grad_to_param) { - if (VLOG_IS_ON(3)) { - std::ostringstream sstream; - sstream << "["; - for (auto& item : grad_to_param) { - sstream << item.first << ":" << item.second << ", "; - } - sstream << "]"; - VLOG(3) << "trainer_num: " << trainer_num - << " grad_to_param_: " << sstream.str(); - } - for (auto& iter : grad_to_param) { - param_to_grad_[iter.second] = iter.first; - auto& param_name = iter.second; - param_to_updated_rows_[param_name] = TrainerToRows(); - auto& trainer_to_rows = param_to_updated_rows_[param_name]; - for (auto i = 0; i < trainer_num; ++i) { - trainer_to_rows.emplace_back(new ConcurrentSet()); - } - } - } - - ~AsyncSparseParamUpdateRecorder() = default; - - void Update(const std::string& grad_name, - const std::vector& update_rows) { - VLOG(3) << "update grad: " << grad_name - << " row size: " << update_rows.size(); - auto& param_name = grad_to_param_.at(grad_name); - auto& trainer_to_rows = param_to_updated_rows_.at(param_name); - - std::vector> fs; - for (auto& set : trainer_to_rows) { - fs.push_back(set->Update(update_rows)); - } - for (auto& f : fs) { - f.wait(); - } - } - - void GetAndClear(const std::string& param_name, int trainer_id, - std::vector* result) { - VLOG(3) << "GetAndClear param: " << param_name - << " for trainer: " << trainer_id; - PADDLE_ENFORCE_LT( - trainer_id, trainer_num_, - platform::errors::InvalidArgument( - "The value of trainer_id: %s should less than trainer_num: %s.", - trainer_id, trainer_num_)); - param_to_updated_rows_.at(param_name)[trainer_id] - ->GetAndClear(result) - .wait(); - } - - bool HasParam(const std::string& param_name) { - return param_to_grad_.find(param_name) != param_to_grad_.end(); - } - - bool HasGrad(const std::string& grad_name) { - return grad_to_param_.find(grad_name) != grad_to_param_.end(); - } - - private: - const int trainer_num_; - std::unordered_map grad_to_param_; - std::unordered_map param_to_grad_; - std::unordered_map param_to_updated_rows_; - - // init recorder - public: - static void Init( - int trainer_num, - const std::unordered_map& grad_to_param) { - InitImpl(trainer_num, grad_to_param); - } - - static AsyncSparseParamUpdateRecorder* GetInstance() { - return recorder_.get(); - } - - private: - // Init is called by GetInstance. - static void InitImpl( - int trainer_num, - const std::unordered_map& grad_to_param) { - if (recorder_ == nullptr) { - recorder_.reset( - new AsyncSparseParamUpdateRecorder(trainer_num, grad_to_param)); - } - } - - static std::once_flag init_flag_; - static std::unique_ptr recorder_; -}; - -} // namespace distributed -} // namespace operators -} // namespace paddle diff --git a/paddle/fluid/operators/distributed/async_sparse_param_update_recorder_test.cc b/paddle/fluid/operators/distributed/async_sparse_param_update_recorder_test.cc deleted file mode 100644 index 2d78559625..0000000000 --- a/paddle/fluid/operators/distributed/async_sparse_param_update_recorder_test.cc +++ /dev/null @@ -1,97 +0,0 @@ -// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "paddle/fluid/operators/distributed/async_sparse_param_update_recorder.h" -#include -#include "gtest/gtest.h" - -namespace paddle { -namespace operators { -namespace distributed { - -TEST(ConcurrentSet, All) { - ConcurrentSet concurrent_set; - std::vector in1 = {1, 2, 3, 4}; - std::vector in2 = {2, 3, 5, 6}; - - std::vector> futures; - futures.push_back(concurrent_set.Update(in1)); - futures.push_back(concurrent_set.Update(in2)); - - for (auto &f : futures) { - f.wait(); - } - - std::unordered_set in; - std::copy(in1.begin(), in1.end(), std::inserter(in, in.begin())); - std::copy(in2.begin(), in2.end(), std::inserter(in, in.begin())); - - std::vector ret; - concurrent_set.GetAndClear(&ret).wait(); - - std::unordered_set out; - std::copy(ret.begin(), ret.end(), std::inserter(out, out.begin())); - - EXPECT_EQ(in, out); - - concurrent_set.GetAndClear(&ret).wait(); - EXPECT_EQ(ret.size(), 0UL); -} - -TEST(AsyncSparseParamUpdateRecorder, All) { - std::unordered_map grad_to_param; - grad_to_param["grad1"] = "param1"; - grad_to_param["grad2"] = "param2"; - - int trainer_num = 10; - - AsyncSparseParamUpdateRecorder recorder(trainer_num, grad_to_param); - std::vector in1 = {1, 2, 3, 4}; - std::vector in2 = {2, 3, 5, 6}; - - std::unordered_set in; - std::copy(in1.begin(), in1.end(), std::inserter(in, in.begin())); - std::copy(in2.begin(), in2.end(), std::inserter(in, in.begin())); - - recorder.Update("grad1", in1); - recorder.Update("grad1", in2); - - EXPECT_TRUE(recorder.HasParam("param1")); - EXPECT_TRUE(recorder.HasParam("param2")); - EXPECT_FALSE(recorder.HasParam("param3")); - - EXPECT_TRUE(recorder.HasGrad("grad1")); - EXPECT_TRUE(recorder.HasGrad("grad2")); - EXPECT_FALSE(recorder.HasGrad("grad3")); - - std::vector ret; - EXPECT_ANY_THROW(recorder.GetAndClear("param1", trainer_num, &ret)); - - for (int i = 0; i < trainer_num; ++i) { - std::vector ret; - std::unordered_set out; - - recorder.GetAndClear("param1", i, &ret); - std::copy(ret.begin(), ret.end(), std::inserter(out, out.begin())); - - EXPECT_EQ(in, out); - - recorder.GetAndClear("param1", i, &ret); - EXPECT_EQ(ret.size(), 0UL); - } -} - -} // namespace distributed -} // namespace operators -} // namespace paddle diff --git a/paddle/fluid/operators/distributed/brpc/brpc_client.cc b/paddle/fluid/operators/distributed/brpc/brpc_client.cc deleted file mode 100644 index b2a26089c8..0000000000 --- a/paddle/fluid/operators/distributed/brpc/brpc_client.cc +++ /dev/null @@ -1,462 +0,0 @@ -// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "paddle/fluid/operators/distributed/brpc/brpc_client.h" -#include "paddle/fluid/framework/threadpool.h" -#include "paddle/fluid/operators/distributed/brpc/brpc_sendrecvop_utils.h" -#include "paddle/fluid/platform/profiler.h" - -namespace paddle { -namespace operators { -namespace distributed { - -DEFINE_int32(timeout_ms, 30000, "RPC timeout in milliseconds"); -DEFINE_int32(max_retry, 3, "Max retries(not including the first RPC)"); - -BRPCClient::~BRPCClient() { Wait(); } - -void HandleSendResponse(brpc::Controller* cntl, sendrecv::VoidMessage* response, - VarHandlePtr var_h, ChannelQueuePtr ch_ptr, - ChannelContextPtr ch_ctx, BRPCClient* cls) { - // std::unique_ptr makes sure cntl/response will be deleted before returning. - std::unique_ptr cntl_guard(cntl); - std::unique_ptr response_guard(response); - - // this channel can be used by other now. - ch_ptr->Push(ch_ctx); - - if (cntl->Failed()) { - PADDLE_THROW(platform::errors::Unavailable( - "Failed to send variable %s, error text is %s.", var_h->name(), - cntl->ErrorText())); - var_h->Finish(false); - cls->DecreaseReqCount(); - return; - } - var_h->Finish(true); - cls->DecreaseReqCount(); - - VLOG(4) << "HandleSendResponse from: " << cntl->remote_side() - << ", varname: " << var_h->name() - << ", latency: " << cntl->latency_us() << "us"; - VLOG(4) << "Finish HandleSendResponse"; -} - -VarHandlePtr BRPCClient::AsyncSendVar(const std::string& ep, - const platform::DeviceContext& ctx, - const framework::Scope& scope, - const std::string& var_name, - int64_t time_out) { - const platform::DeviceContext* p_ctx = &ctx; - const std::string ep_val = ep; - const std::string var_name_val = var_name; - const framework::Scope* p_scope = &scope; - const auto ch_ptr = GetChannel(ep_val); - const std::string method = kSendRPC; - VarHandlePtr var_h(new VarHandle(ep, method, var_name_val, p_ctx, p_scope)); - - framework::AsyncIO([=] { - auto ch_ctx = ch_ptr->Pop(); - brpc::Controller* cntl = new brpc::Controller(); - sendrecv::VoidMessage* response = new sendrecv::VoidMessage(); - cntl->set_timeout_ms(time_out); - - auto* var = p_scope->FindVar(var_name_val); - sendrecv::VariableMessage request; - distributed::SerializeToIOBuf(var_name_val, var, *p_ctx, &request, - &cntl->request_attachment(), "", false, - trainer_id_); - - google::protobuf::Closure* done = brpc::NewCallback( - &HandleSendResponse, cntl, response, var_h, ch_ptr, ch_ctx, this); - - platform::RecordRPCEvent record_event(method); - - ch_ctx->stub->SendVariable(cntl, &request, response, done); - - if (UNLIKELY(platform::IsProfileEnabled())) { - var_h->Wait(); - } - }); - req_count_++; - - return var_h; -} -void HandleFetchBarrierResponse(brpc::Controller* cntl, - sendrecv::VariableMessage* response, - VarHandlePtr var_h, ChannelQueuePtr ch_ptr, - ChannelContextPtr ch_ctx, BRPCClient* cls) { - // std::unique_ptr makes sure cntl/response will be deleted before returning. - std::unique_ptr cntl_guard(cntl); - std::unique_ptr response_guard(response); - - // this channel can be used other now. - ch_ptr->Push(ch_ctx); - - if (cntl->Failed()) { - PADDLE_THROW(platform::errors::Unavailable( - "Failed to get HandleFetchBarrierResponse %s, error text is %s.", - var_h->name(), cntl->ErrorText())); - var_h->Finish(false); - cls->DecreaseReqCount(); - return; - } - - var_h->Finish(true); - cls->DecreaseReqCount(); - - VLOG(4) << "HandleFetchBarrierResponse from: " << cntl->remote_side() - << ", varname: " << var_h->name() - << ", latency: " << cntl->latency_us() << "us"; - VLOG(4) << "Finish HandleFetchBarrierResponse"; -} -void HandleGetResponse(brpc::Controller* cntl, - sendrecv::VariableMessage* response, VarHandlePtr var_h, - ChannelQueuePtr ch_ptr, ChannelContextPtr ch_ctx, - BRPCClient* cls) { - // std::unique_ptr makes sure cntl/response will be deleted before returning. - std::unique_ptr cntl_guard(cntl); - std::unique_ptr response_guard(response); - - // this channel can be used other now. - ch_ptr->Push(ch_ctx); - - if (cntl->Failed()) { - PADDLE_THROW(platform::errors::Unavailable( - "Failed to get variable %s, error text is %s.", var_h->name(), - cntl->ErrorText())); - cls->DecreaseReqCount(); - var_h->Finish(false); - return; - } - - VLOG(4) << "HandleGetResponse from: " << cntl->remote_side() - << ", varname: " << var_h->name() - << ", latency: " << cntl->latency_us() << "us"; - - framework::Variable* outvar = nullptr; - int trainer_id; - distributed::DeserializeFromIOBuf(*response, cntl->response_attachment(), - *var_h->ctx(), var_h->scope(), &outvar, - &trainer_id); - VLOG(4) << "Finish HandleGetResponse"; - cls->DecreaseReqCount(); - var_h->Finish(true); -} - -VarHandlePtr BRPCClient::_AsyncGetVar(const std::string& ep, - const platform::DeviceContext& ctx, - const framework::Scope& scope, - const std::string& var_name, - const std::string& out_var_name, - const std::string& method_name, - int64_t time_out) { - const platform::DeviceContext* p_ctx = &ctx; - const std::string ep_val = ep; - const std::string var_name_val = var_name; - const std::string out_varname_val = out_var_name; - const framework::Scope* p_scope = &scope; - const auto ch_ptr = GetChannel(ep_val); - const std::string method = kGetRPC; - VarHandlePtr var_h( - new VarHandle(ep, method, out_varname_val, p_ctx, p_scope)); - - framework::AsyncIO([=] { - auto ch_ctx = ch_ptr->Pop(); - - brpc::Controller* cntl = new brpc::Controller(); - sendrecv::VariableMessage* response = new sendrecv::VariableMessage(); - cntl->set_timeout_ms(time_out); - - sendrecv::VariableMessage req; - req.set_varname(var_name_val); - req.set_out_varname(out_varname_val); - req.set_trainer_id(trainer_id_); - - google::protobuf::Closure* done = brpc::NewCallback( - &HandleGetResponse, cntl, response, var_h, ch_ptr, ch_ctx, this); - - platform::RecordRPCEvent record_event(method); - - if (method_name == kGetMonomerRPC) { - ch_ctx->stub->GetMonomerVariable(cntl, &req, response, done); - } else if (method_name == kGetNoBarrierRPC) { - ch_ctx->stub->GetVariableNoBarrier(cntl, &req, response, done); - } else { - ch_ctx->stub->GetVariable(cntl, &req, response, done); - } - - if (UNLIKELY(platform::IsProfileEnabled())) { - var_h->Wait(); - } - }); - - req_count_++; - - return var_h; -} - -VarHandlePtr BRPCClient::AsyncGetVarNoBarrier( - const std::string& ep, const platform::DeviceContext& ctx, - const framework::Scope& scope, const std::string& var_name, - const std::string& out_var_name, int64_t time_out) { - std::string var_name_no_barrier = - string::Sprintf("%s%s", var_name, WITHOUT_BARRIER_MESSAGE); - - return _AsyncGetVar(ep, ctx, scope, var_name_no_barrier, out_var_name, - kGetNoBarrierRPC, time_out); -} - -VarHandlePtr BRPCClient::AsyncGetMonomerVariable( - const std::string& ep, const platform::DeviceContext& ctx, - const framework::Scope& scope, const std::string& var_name, - int64_t time_out) { - return _AsyncGetVar(ep, ctx, scope, var_name, var_name, kGetMonomerRPC, - time_out); -} - -VarHandlePtr BRPCClient::AsyncGetMonomerBarrier(const std::string& ep, - const std::string& var_name, - int64_t time_out) { - return AsyncSendMessage(ep, kSendMonomerFetchBarrierRPC, var_name, time_out); -} - -VarHandlePtr BRPCClient::AsyncGetVar(const std::string& ep, - const platform::DeviceContext& ctx, - const framework::Scope& scope, - const std::string& var_name, - const std::string& out_var_name, - const std::string& table_name, - int64_t time_out) { - return _AsyncGetVar(ep, ctx, scope, var_name, out_var_name, kGetRPC, - time_out); -} - -VarHandlePtr BRPCClient::AsyncPrefetchVar(const std::string& ep, - const platform::DeviceContext& ctx, - const framework::Scope& scope, - const std::string& in_var_name, - const std::string& out_var_name, - const std::string& table_name, - int64_t time_out) { - const platform::DeviceContext* p_ctx = &ctx; - const std::string ep_val = ep; - const std::string in_var_name_val = in_var_name; - const std::string out_var_name_val = out_var_name; - const std::string table_name_val = table_name; - const framework::Scope* p_scope = &scope; - const auto ch_ptr = GetChannel(ep_val); - - const std::string method = kPrefetchRPC; - - VarHandlePtr var_h( - new VarHandle(ep, method, out_var_name_val, p_ctx, p_scope)); - - framework::AsyncIO([=] { - auto ch_ctx = ch_ptr->Pop(); - - brpc::Controller* cntl = new brpc::Controller(); - sendrecv::VariableMessage* response = new sendrecv::VariableMessage(); - cntl->set_timeout_ms(time_out); - - auto* var = p_scope->FindVar(in_var_name_val); - sendrecv::VariableMessage req; - distributed::SerializeToIOBuf(in_var_name_val, var, *p_ctx, &req, - &cntl->request_attachment(), out_var_name_val, - false, 0, table_name_val); - - platform::RecordRPCEvent record_event(method); - - google::protobuf::Closure* done = brpc::NewCallback( - &HandleGetResponse, cntl, response, var_h, ch_ptr, ch_ctx, this); - - ch_ctx->stub->PrefetchVariable(cntl, &req, response, done); - - if (UNLIKELY(platform::IsProfileEnabled())) { - var_h->Wait(); - } - }); - - req_count_++; - return var_h; -} - -VarHandlePtr BRPCClient::AsyncSendBatchBarrier(const std::string& ep, - int64_t time_out) { - return AsyncSendMessage(ep, kBatchBarrierRPC, BATCH_BARRIER_MESSAGE, - time_out); -} - -VarHandlePtr BRPCClient::AsyncSendFetchBarrier(const std::string& ep, - int64_t time_out) { - auto ch_ptr = GetChannel(ep); - auto ch_ctx = ch_ptr->Pop(); - - brpc::Controller* cntl = new brpc::Controller(); - sendrecv::VariableMessage* response = new sendrecv::VariableMessage(); - cntl->set_timeout_ms(time_out); - - sendrecv::VariableMessage req; - req.set_varname(FETCH_BARRIER_MESSAGE); - - const std::string method = kFetchBarrierRPC; - // var handle - VarHandlePtr var_h( - new VarHandle(ep, method, FETCH_BARRIER_MESSAGE, nullptr, nullptr)); - - platform::RecordRPCEvent record_event(method); - - google::protobuf::Closure* done = brpc::NewCallback( - &HandleFetchBarrierResponse, cntl, response, var_h, ch_ptr, ch_ctx, this); - - ch_ctx->stub->GetVariable(cntl, &req, response, done); - - req_count_++; - - if (UNLIKELY(platform::IsProfileEnabled())) { - var_h->Wait(); - } - - return var_h; -} - -bool BRPCClient::Wait() { - VLOG(9) << "begin to brpcclient wait"; - { - std::unique_lock lk(sync_mutex_); - sync_cond_.wait(lk, [this] { return req_count_ == 0; }); - } - VLOG(9) << "end to brpcclient wait"; - return true; -} - -ChannelQueuePtr BRPCClient::GetChannel(const std::string& ep) { - VLOG(4) << "begin to GetChannel:" << ep; - { - std::lock_guard guard(chan_mutex_); - auto it = channels_.find(ep); - if (it != channels_.end()) { - VLOG(4) << "end to GetChannel:" << ep; - return it->second; - } - } - - ChannelQueuePtr q(new framework::BlockingQueue()); - - brpc::ChannelOptions options; -#ifdef PADDLE_WITH_BRPC_RDMA - options.use_rdma = true; -#endif - options.protocol = "baidu_std"; - // don't use pooled type. the server can't afford that. - options.connection_type = "single"; - options.connect_timeout_ms = 1000; - options.timeout_ms = FLAGS_timeout_ms /*milliseconds*/; - options.max_retry = FLAGS_max_retry; - - VLOG(1) << "create " << brpc_channel_num_per_server_ - << " brpc channels to pserver:" << ep; - - for (int i = 0; i < brpc_channel_num_per_server_; ++i) { - std::shared_ptr c(new ChannelContext()); - if (c->channel.Init(ep.c_str(), &options) != 0) { - PADDLE_THROW( - platform::errors::Unavailable("Failed to initialize channel.")); - return nullptr; - } - - c->stub.reset(new sendrecv::SendRecvService_Stub( - static_cast(&c->channel))); - q->Push(c); - } - - { - std::lock_guard guard(chan_mutex_); - channels_[ep] = q; - } - - VLOG(4) << "end to GetChannel:" << ep; - return q; -} - -VarHandlePtr BRPCClient::AsyncSendComplete(const std::string& ep, - int64_t time_out) { - return AsyncSendMessage(ep, kSendCompleteRPC, COMPLETE_MESSAGE, time_out); -} - -void BRPCClient::SendComplete() { - for (auto& kv : channels_) { - AsyncSendComplete(kv.first); - } -} - -VarHandlePtr BRPCClient::AsyncSendVarMessage( - const std::string& ep, const std::string& method_name, - const sendrecv::VariableMessage& req, int64_t time_out) { - auto ch_ptr = GetChannel(ep); - auto ch_ctx = ch_ptr->Pop(); - - brpc::Controller* cntl = new brpc::Controller(); - sendrecv::VoidMessage* response = new sendrecv::VoidMessage(); - cntl->set_timeout_ms(time_out); - - platform::RecordRPCEvent record_event(method_name); - - VarHandlePtr var_h( - new VarHandle(ep, method_name, req.varname(), nullptr, nullptr)); - - google::protobuf::Closure* done = brpc::NewCallback( - &HandleSendResponse, cntl, response, var_h, ch_ptr, ch_ctx, this); - - if (method_name == kCheckPointNotifyRPC) { - ch_ctx->stub->CheckpointNotify(cntl, &req, response, done); - } else if (method_name == kSendMonomerFetchBarrierRPC) { - ch_ctx->stub->GetMonomerBarrier(cntl, &req, response, done); - } else { - ch_ctx->stub->SendVariable(cntl, &req, response, done); - } - req_count_++; - - if (UNLIKELY(platform::IsProfileEnabled())) { - var_h->Wait(); - } - - return var_h; -} - -VarHandlePtr BRPCClient::AsyncSendMessage(const std::string& ep, - const std::string& method_name, - const std::string& message, - int64_t time_out) { - sendrecv::VariableMessage req; - req.set_varname(message); - - return AsyncSendVarMessage(ep, method_name, req, time_out); -} - -VarHandlePtr BRPCClient::AsyncCheckpointNotify(const std::string& ep, - const std::string& dirname, - const std::string& varname, - const int mode, - int64_t time_out) { - sendrecv::VariableMessage req; - req.set_varname(varname); - req.set_out_varname(dirname); - - return AsyncSendVarMessage(ep, "CheckPointNotifyRPC", req, time_out); -} - -} // namespace distributed -} // namespace operators -} // namespace paddle diff --git a/paddle/fluid/operators/distributed/brpc/brpc_client.h b/paddle/fluid/operators/distributed/brpc/brpc_client.h deleted file mode 100644 index 91f94b4c9d..0000000000 --- a/paddle/fluid/operators/distributed/brpc/brpc_client.h +++ /dev/null @@ -1,174 +0,0 @@ -/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#pragma once - -#include - -#include // NOLINT -#include -#include -#include -#include -#include -#include // NOLINT -#include -#include -#include - -#include "brpc/channel.h" -#include "paddle/fluid/framework/blocking_queue.h" -#include "paddle/fluid/framework/data_type.h" -#include "paddle/fluid/framework/lod_tensor.h" -#include "paddle/fluid/framework/scope.h" -#include "paddle/fluid/framework/selected_rows.h" -#include "paddle/fluid/operators/distributed/brpc/brpc_sendrecvop_utils.h" -#include "paddle/fluid/operators/distributed/distributed_pb.h" -#include "paddle/fluid/operators/distributed/request_handler.h" -#include "paddle/fluid/operators/distributed/rpc_client.h" -#include "paddle/fluid/platform/macros.h" // for DISABLE_COPY_AND_ASSIGN - -namespace paddle { -namespace operators { -namespace distributed { - -struct ChannelContext { - brpc::Channel channel; - std::shared_ptr stub; -}; - -typedef std::shared_ptr ChannelContextPtr; -typedef std::shared_ptr> - ChannelQueuePtr; - -class BRPCClient : public RPCClient { - public: - BRPCClient() {} - virtual ~BRPCClient(); - - VarHandlePtr AsyncSendVar(const std::string& ep, - const platform::DeviceContext& ctx, - const framework::Scope& scope, - const std::string& var_name, - int64_t time_out = FLAGS_rpc_deadline) override; - - VarHandlePtr AsyncGetVar(const std::string& ep, - const platform::DeviceContext& ctx, - const framework::Scope& scope, - const std::string& var_name, - const std::string& out_var_name, - const std::string& table_name = "", - int64_t time_out = FLAGS_rpc_deadline) override; - - VarHandlePtr AsyncGetMonomerBarrier( - const std::string& ep, const std::string& var_name, - int64_t time_out = FLAGS_rpc_deadline) override; - - VarHandlePtr AsyncGetMonomerVariable( - const std::string& ep, const platform::DeviceContext& ctx, - const framework::Scope& scope, const std::string& var_name, - int64_t time_out = FLAGS_rpc_deadline) override; - - VarHandlePtr AsyncGetVarNoBarrier(const std::string& ep, - const platform::DeviceContext& ctx, - const framework::Scope& scope, - const std::string& var_name, - const std::string& out_varname, - int64_t time_out = FLAGS_rpc_deadline); - - VarHandlePtr AsyncPrefetchVar(const std::string& ep, - const platform::DeviceContext& ctx, - const framework::Scope& scope, - const std::string& in_var_name, - const std::string& out_var_name, - const std::string& table_name = "", - int64_t time_out = FLAGS_rpc_deadline) override; - - VarHandlePtr AsyncSendBatchBarrier( - const std::string& ep, int64_t time_out = FLAGS_rpc_deadline) override; - - VarHandlePtr AsyncSendFetchBarrier( - const std::string& ep, int64_t time_out = FLAGS_rpc_deadline) override; - - VarHandlePtr AsyncCheckpointNotify( - const std::string& ep, const std::string& dirname, - const std::string& varname, const int mode, - int64_t time_out = FLAGS_rpc_deadline) override; - - bool Wait() override; - - void SendComplete() override; - - private: - VarHandlePtr _AsyncGetVar( - const std::string& ep, const platform::DeviceContext& ctx, - const framework::Scope& scope, const std::string& var_name, - const std::string& out_var_name, const std::string& method_name, - const std::string& table_name, int64_t time_out = FLAGS_rpc_deadline); - - void Proceed(); - ChannelQueuePtr GetChannel(const std::string& ep); - - VarHandlePtr AsyncSendComplete(const std::string& ep, - int64_t time_out = FLAGS_rpc_deadline); - - VarHandlePtr AsyncSendMessage(const std::string& ep, - const std::string& method_name, - const std::string& message, int64_t time_out); - - VarHandlePtr AsyncSendVarMessage(const std::string& ep, - const std::string& method_name, - const sendrecv::VariableMessage& req, - int64_t time_out); - - friend void HandleSendResponse(brpc::Controller* cntl, - sendrecv::VoidMessage* response, - VarHandlePtr var_h, ChannelQueuePtr ch_ptr, - ChannelContextPtr ch_ctx, BRPCClient* cls); - - friend void HandleGetResponse(brpc::Controller* cntl, - sendrecv::VariableMessage* response, - VarHandlePtr var_h, ChannelQueuePtr ch_ptr, - ChannelContextPtr ch_ctx, BRPCClient* cls); - - friend void HandleFetchBarrierResponse(brpc::Controller* cntl, - sendrecv::VariableMessage* response, - VarHandlePtr var_h, - ChannelQueuePtr ch_ptr, - ChannelContextPtr ch_ctx, - BRPCClient* cls); - void DecreaseReqCount() { - if (--req_count_ <= 0) { - sync_cond_.notify_all(); - } - } - - private: - std::unordered_map channels_; - - // mutex for Wait client sync - std::mutex sync_mutex_; - std::condition_variable sync_cond_; - std::atomic req_count_{0}; - - static constexpr int brpc_channel_num_per_server_ = 4; - - // mutex for GetChannel thread safety - std::mutex chan_mutex_; - DISABLE_COPY_AND_ASSIGN(BRPCClient); -}; - -} // namespace distributed -} // namespace operators -} // namespace paddle diff --git a/paddle/fluid/operators/distributed/brpc/brpc_rdma_pool.cc b/paddle/fluid/operators/distributed/brpc/brpc_rdma_pool.cc deleted file mode 100644 index 94f0b9919a..0000000000 --- a/paddle/fluid/operators/distributed/brpc/brpc_rdma_pool.cc +++ /dev/null @@ -1,86 +0,0 @@ -// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifdef PADDLE_WITH_BRPC_RDMA - -#include "paddle/fluid/operators/distributed/brpc/brpc_rdma_pool.h" -#include "brpc/channel.h" -#include "brpc/rdma/rdma_helper.h" -#include "paddle/fluid/platform/enforce.h" - -namespace paddle { -namespace operators { -namespace distributed { - -RdmaMemPool& RdmaMemPool::Instance() { - static RdmaMemPool* g_rdma_mem_pool = new RdmaMemPool(); - return *g_rdma_mem_pool; -} - -void* RdmaMemPool::Find(const std::string& varname, int64_t size) { - pthread_rwlock_rdlock(&access_); - auto it = pool_.find(varname); - if (it == pool_.end()) { - pthread_rwlock_unlock(&access_); - return nullptr; - } - - auto info = it->second; - if (info.data_size != size) { - pthread_rwlock_unlock(&access_); - PADDLE_THROW(platform::errors::InvalidArgument( - "var:%s size:%ld != %ld", varname, size, info.data_size)); - return nullptr; - } - - pthread_rwlock_unlock(&access_); - return info.data; -} - -void RdmaMemPool::Register(const std::string& varname, void* data, - int64_t data_size) { - void* old = Find(varname, data_size); - if (old != nullptr) { - PADDLE_ENFORCE_EQ( - data, old, platform::errors::InvalidArgument("var:%s data:%ld != %ld", - varname, data, old)); - VLOG(7) << "Find on rdma:" << varname << " data:" << data - << " data_size:" << data_size; - return; - } - - VarInfo info; - info.data = data; - info.data_size = data_size; - - pthread_rwlock_wrlock(&access_); - pool_[varname] = info; - pthread_rwlock_unlock(&access_); - - if (brpc::rdma::RegisterMemoryForRdma(data, data_size)) { - PADDLE_THROW(platform::errors::Unavailable( - "Register memory for RDMA failed. Register %s data: %s data size %d " - "error.", - varname, data, data_size)); - } - - VLOG(4) << "register on rdma:" << varname << " data:" << data - << " data_size:" << data_size; -} - -} // namespace distributed -} // namespace operators -} // namespace paddle - -#endif diff --git a/paddle/fluid/operators/distributed/brpc/brpc_rdma_pool.h b/paddle/fluid/operators/distributed/brpc/brpc_rdma_pool.h deleted file mode 100644 index 156a93ec57..0000000000 --- a/paddle/fluid/operators/distributed/brpc/brpc_rdma_pool.h +++ /dev/null @@ -1,56 +0,0 @@ -// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#pragma once -#ifdef PADDLE_WITH_BRPC_RDMA - -#include // NOLINT -#include -#include - -namespace paddle { -namespace operators { -namespace distributed { - -/* - * This class is used to avoid duplicated registion of brpc::rdma. - */ -class RdmaMemPool { - public: - static RdmaMemPool& Instance(); - RdmaMemPool() : access_(PTHREAD_RWLOCK_INITIALIZER) {} - - virtual ~RdmaMemPool() { pthread_rwlock_destroy(&access_); } - - void Register(const std::string& varname, void* data, int64_t size); - void* Find(const std::string& varname, int64_t size); - - private: - struct VarInfo { - void* data; - int64_t data_size; - - VarInfo() : data(nullptr), data_size(0) {} - }; - - private: - std::unordered_map pool_; - pthread_rwlock_t access_; -}; - -} // namespace distributed -} // namespace operators -} // namespace paddle - -#endif diff --git a/paddle/fluid/operators/distributed/brpc/brpc_sendrecvop_utils.cc b/paddle/fluid/operators/distributed/brpc/brpc_sendrecvop_utils.cc deleted file mode 100644 index 411c0f36de..0000000000 --- a/paddle/fluid/operators/distributed/brpc/brpc_sendrecvop_utils.cc +++ /dev/null @@ -1,224 +0,0 @@ -/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#ifdef PADDLE_WITH_NCCL -#include -#endif -#ifdef PADDLE_WITH_RCCL -#include -#endif -#include -#include -#include -#include // NOLINT - -#include "paddle/fluid/framework/data_type.h" -#include "paddle/fluid/operators/distributed/brpc/brpc_rdma_pool.h" -#include "paddle/fluid/operators/distributed/brpc/brpc_sendrecvop_utils.h" -#include "paddle/fluid/operators/distributed/brpc/brpc_variable_response.h" -#include "paddle/fluid/operators/distributed/distributed_pb.h" -#include "paddle/fluid/platform/profiler.h" - -namespace paddle { -namespace operators { -namespace distributed { - -class IOBufWriter { - public: - static void Append(const std::string& varname, butil::IOBuf* iobuf, int k, - const char* v, int64_t vlen) { - if (vlen >= std::numeric_limits::max() || vlen < 0) { - PADDDLE_THROW(platform::errors::Unavailable( - "Variable lenght is invalid. Variable name is %s, length is %d.", - varname, vlen)); - } - - iobuf->append(reinterpret_cast(&k), 4); - iobuf->append(reinterpret_cast(&vlen), 8); - iobuf->append(v, vlen); - } - - static void AppendTCPZeroCopy(butil::IOBuf* iobuf, int k, const char* v, - int64_t vlen, bool in_cuda_pinned, - void (*destroy)(void*), void* user_data) { - VLOG(7) << "AppendTCPZeroCopy " - << " k:" << k - << " data:" << static_cast(const_cast(v)) - << " data_size:" << vlen << " in_cuda_pinned:" << in_cuda_pinned; - - iobuf->append(reinterpret_cast(&k), 4); - iobuf->append(reinterpret_cast(&vlen), 8); - - // FIXME(gongwb): use append_zerocopy - /* - if (in_cuda_pinned) { - iobuf->append_zerocopy(v, vlen, IOBufWriter::FreeMemory); - } else { - iobuf->append_zerocopy(v, vlen, nullptr); - } - */ - iobuf->append(v, vlen); - destroy(user_data); - } - -#ifdef PADDLE_WITH_BRPC_RDMA - static void AppendRdmaZeroCopy(const std::string varname, butil::IOBuf* iobuf, - int k, const char* v, int64_t vlen, - bool in_cuda_pinned, void (*destroy)(void*), - void* user_data) { - VLOG(7) << "AppendRdmaZeroCopy varname:" << varname << " k:" << k - << " data:" << static_cast(const_cast(v)) - << " data_size:" << vlen << " in_cuda_pinned:" << in_cuda_pinned; - - iobuf->append(reinterpret_cast(&k), 4); - iobuf->append(reinterpret_cast(&vlen), 8); - - RdmaMemPool::Instance().Register( - varname, static_cast(const_cast(v)), vlen); - - // FIXME(gongwb): use append_zerocopy - // iobuf->append_zerocopy(v, vlen, nullptr); - iobuf->append(v, vlen); - destroy(user_data); - return; - } -#endif - - static void AppendZeroCopy(const std::string varname, butil::IOBuf* iobuf, - int k, const char* v, int64_t vlen, - bool in_cuda_pinned, void (*destroy)(void*), - void* user_data) { - if (vlen >= std::numeric_limits::max() || vlen < 0) { - PADDDLE_THROW(platform::errors::Unavailable( - "Variable lenght is invalid. Variable name is %s, length is %d.", - varname, vlen)); - } - -#ifdef PADDLE_WITH_BRPC_RDMA - IOBufWriter::AppendRdmaZeroCopy(varname, iobuf, k, v, vlen, in_cuda_pinned, - destroy, user_data); -#else - IOBufWriter::AppendTCPZeroCopy(iobuf, k, v, vlen, in_cuda_pinned, destroy, - user_data); -#endif - } -}; - -void SerializeToIOBuf(const std::string& name, framework::Variable* var, - const platform::DeviceContext& ctx, VarMsg* request, - butil::IOBuf* iobuf, const std::string& out_varname, - bool var_is_not_stable, int trainer_id, - const std::string& table_name) { - std::unique_ptr payload; - - request->set_varname(name); - request->set_trainer_id(trainer_id); - // Note: normally the profiler is enabled in 1 trainer, hence only - // 1 trainer returns true for ShouldSendProfileState(). It tells PS - // servers the trainer's profiling state so that PS can follow the - // trainer. - if (platform::ShouldSendProfileState()) { - if (platform::IsProfileEnabled()) { - request->set_profile(platform::kEnableProfiler); - } else { - request->set_profile(platform::kDisableProfiler); - } - } - if (!out_varname.empty()) { - request->set_out_varname(out_varname); - } - if (!table_name.empty()) { - request->set_table_name(table_name); - } - if (var->IsType()) { - request->set_type(::sendrecv::LOD_TENSOR); - payload.reset(new TensorPayload(GetTensorPayload(var, ctx, request))); - } else if (var->IsType()) { - request->set_type(::sendrecv::SELECTED_ROWS); - payload.reset(new TensorPayload(GetSelectedRowsPayload(var, ctx, request))); -#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL) - } else if (var->IsType()) { - request->set_type(::sendrecv::NCCL_ID); - const ncclUniqueId& uid = var->Get(); - // TODO(gongwb): use append_zero to avoid data copy. - IOBufWriter::Append(name, iobuf, - sendrecv::VariableMessage::kSerializedFieldNumber, - uid.internal, NCCL_UNIQUE_ID_BYTES); - return; -#endif - } else { - PADDLE_THROW(platform::errors::InvalidArgument( - "Serialize does not support type: %s", typeid(var->Type()).name())); - } - - PADDLE_ENFORCE_NOT_NULL( - payload, - platform::errors::InvalidArgument( - "Not support type: %s, need to be LOD_TENSOR or SELECTED_ROWS.", - var->Type())); - - // FIXME(gongwb): it seems that can use zero copy. - if (var_is_not_stable) { - IOBufWriter::Append( - name, iobuf, ::sendrecv::VariableMessage::kSerializedFieldNumber, - static_cast(payload->ptr()), payload->memory_size()); - } else { - if (platform::is_gpu_place(ctx.GetPlace())) { -#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) - IOBufWriter::AppendZeroCopy( - name, iobuf, ::sendrecv::VariableMessage::kSerializedFieldNumber, - static_cast(payload->ptr()), payload->memory_size(), - true, SerializeDestroyCallback, static_cast(payload.get())); - payload.release(); -#endif - } else { - IOBufWriter::AppendZeroCopy( - name, iobuf, ::sendrecv::VariableMessage::kSerializedFieldNumber, - static_cast(payload->ptr()), payload->memory_size(), - false, SerializeDestroyCallback, static_cast(payload.get())); - payload.release(); - } - } - - if (var->IsType()) { - auto* slr = var->GetMutable(); - PADDLE_ENFORCE_EQ(VectorElemName(slr->rows()), typeid(int64_t).name(), - platform::errors::InvalidArgument( - "Got wrong type: %s, expect type: int64_t", - VectorElemName(slr->rows()))); - size_t rows_memory_size = slr->rows().size() * sizeof(int64_t); - - IOBufWriter::Append(name, iobuf, - ::sendrecv::VariableMessage::kRowsFieldNumber, - reinterpret_cast(slr->rows().data()), - static_cast(rows_memory_size)); - } -} - -void DeserializeFromIOBuf(const ::sendrecv::VariableMessage& meta, - const butil::IOBuf& iobuf, - const platform::DeviceContext& ctx, - const framework::Scope* scope, - framework::Variable** var, int* trainer_id) { - operators::distributed::BRPCVariableResponse resp(scope, &ctx); - PADDLE_ENFORCE_EQ( - resp.Parse(iobuf, meta), 0, - platform::errors::InvalidArgument("parse iobuf to tensor error!")); - *var = resp.GetVar(); - *trainer_id = resp.GetTrainerId(); -} - -} // namespace distributed -} // namespace operators -} // namespace paddle diff --git a/paddle/fluid/operators/distributed/brpc/brpc_sendrecvop_utils.h b/paddle/fluid/operators/distributed/brpc/brpc_sendrecvop_utils.h deleted file mode 100644 index a5bdc331eb..0000000000 --- a/paddle/fluid/operators/distributed/brpc/brpc_sendrecvop_utils.h +++ /dev/null @@ -1,49 +0,0 @@ -/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#pragma once - -#include -#include -#include -#include - -#include "brpc/channel.h" -#include "paddle/fluid/framework/data_type.h" -#include "paddle/fluid/framework/lod_tensor.h" -#include "paddle/fluid/framework/scope.h" -#include "paddle/fluid/framework/selected_rows.h" -#include "paddle/fluid/framework/tensor_util.h" -#include "paddle/fluid/framework/var_type.h" -#include "paddle/fluid/operators/distributed/distributed_pb.h" -#include "paddle/fluid/operators/distributed/sendrecvop_utils.h" - -namespace paddle { -namespace operators { -namespace distributed { - -void SerializeToIOBuf(const std::string& name, framework::Variable* var, - const platform::DeviceContext& ctx, VarMsg* request, - butil::IOBuf* iobuf, const std::string& out_varname, - bool var_is_not_stable, const int trainer_id = 0, - const std::string& table_name = std::string()); - -void DeserializeFromIOBuf(const VarMsg& meta, const butil::IOBuf& iobuf, - const platform::DeviceContext& ctx, - const framework::Scope* scope, - framework::Variable** var, int* trainer_id); - -} // namespace distributed -} // namespace operators -} // namespace paddle diff --git a/paddle/fluid/operators/distributed/brpc/brpc_serde_test.cc b/paddle/fluid/operators/distributed/brpc/brpc_serde_test.cc deleted file mode 100644 index bcf20ad076..0000000000 --- a/paddle/fluid/operators/distributed/brpc/brpc_serde_test.cc +++ /dev/null @@ -1,175 +0,0 @@ -/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#include -#include -#include // NOLINT - -#include "brpc/channel.h" -#include "google/protobuf/text_format.h" -#include "gtest/gtest.h" -#include "paddle/fluid/framework/lod_tensor.h" -#include "paddle/fluid/framework/tensor_util.h" -#include "paddle/fluid/framework/variable.h" -#include "paddle/fluid/operators/distributed/brpc/brpc_sendrecvop_utils.h" -#include "paddle/fluid/operators/distributed/brpc/brpc_variable_response.h" -#include "paddle/fluid/operators/distributed/sendrecvop_utils.h" -#include "paddle/fluid/operators/distributed/variable_response.h" -#include "paddle/fluid/operators/math/math_function.h" -#include "paddle/fluid/platform/place.h" -#include "paddle/fluid/string/printf.h" - -namespace framework = paddle::framework; -namespace platform = paddle::platform; -namespace operators = paddle::operators; -namespace math = paddle::operators::math; -namespace memory = paddle::memory; - -void RunSerdeTestSelectedRows(platform::Place place) { - platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance(); - auto& ctx = *pool.Get(place); - - butil::IOBuf iobuf; - sendrecv::VariableMessage msg; - int tensor_numel = 564 * 128; - - // serialize var to IOBuf - { - framework::Variable var; - auto* slr = var.GetMutable(); - slr->set_height(1000); - auto* tensor = slr->mutable_value(); - auto* rows = slr->mutable_rows(); - tensor->Resize(framework::make_ddim({564, 128})); - tensor->mutable_data(place); - math::set_constant(ctx, tensor, 32.7); - for (int i = 0; i < 564; ++i) rows->push_back(i); - - operators::distributed::SerializeToIOBuf("myvar", &var, ctx, &msg, &iobuf, - "", false); - } - - // desrialize - { - framework::Scope scope; - scope.Var("myvar"); - operators::distributed::BRPCVariableResponse resp(&scope, &ctx); - EXPECT_EQ(resp.Parse(iobuf, msg), 0); - - framework::Variable* var2 = resp.GetVar(); - - auto* slr2 = var2->GetMutable(); - auto* tensor2 = slr2->mutable_value(); - auto* rows2 = slr2->mutable_rows(); - float* tensor_data2 = nullptr; - framework::Tensor tmp_tensor; - - if (platform::is_gpu_place(ctx.GetPlace())) { - platform::CPUPlace cpu; - framework::TensorCopy(*tensor2, cpu, &tmp_tensor); - tensor_data2 = tmp_tensor.data(); - } else { - tensor_data2 = const_cast(tensor2->data()); - } - const int64_t* rows_data2 = rows2->data(); - - for (int i = 0; i < tensor_numel; ++i) { - EXPECT_FLOAT_EQ(tensor_data2[i], 32.7); - } - for (size_t i = 0; i < rows2->size(); ++i) { - EXPECT_EQ(rows_data2[i], static_cast(i)); - } - EXPECT_EQ(slr2->height(), 1000); - } -} - -void RunTestLodTensor(platform::Place place) { - platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance(); - auto& ctx = *pool.Get(place); - - // serialize var to ByteBuffer - butil::IOBuf iobuf; - sendrecv::VariableMessage msg; - int tensor_numel = 512 * 8 * 4 * 2; - { - framework::Variable var; - auto* tensor = var.GetMutable(); - tensor->Resize(framework::make_ddim({512, 8, 4, 2})); - framework::LoD lod; - lod.push_back(framework::Vector({1, 3, 8})); - tensor->set_lod(lod); - tensor->mutable_data(place); - math::set_constant(ctx, tensor, 31.9); - - operators::distributed::SerializeToIOBuf("myvar", &var, ctx, &msg, &iobuf, - "", false); - } - - // check sendrecv::VariableMessage meta data - { - EXPECT_EQ(msg.varname(), "myvar"); - EXPECT_EQ(msg.type(), 0); - EXPECT_EQ(msg.dims()[0], 512); - EXPECT_EQ(msg.dims()[1], 8); - EXPECT_EQ(msg.dims()[2], 4); - EXPECT_EQ(msg.dims()[3], 2); - EXPECT_EQ(msg.lod_level(), 1); - EXPECT_EQ(msg.lod(0).lod_data(0), 1); - EXPECT_EQ(msg.lod(0).lod_data(1), 3); - EXPECT_EQ(msg.lod(0).lod_data(2), 8); - } - - // deserialize - { - framework::Scope scope; - scope.Var("myvar"); - operators::distributed::BRPCVariableResponse resp(&scope, &ctx); - EXPECT_EQ(resp.Parse(iobuf, msg), 0); - - framework::Variable* var2 = resp.GetVar(); - - auto tensor2 = var2->Get(); - float* tensor_data2 = nullptr; - framework::Tensor tmp_tensor; - - if (platform::is_gpu_place(ctx.GetPlace())) { - platform::CPUPlace cpu; - framework::TensorCopy(tensor2, cpu, &tmp_tensor); - tensor_data2 = tmp_tensor.data(); - } else { - tensor_data2 = const_cast(tensor2.data()); - } - - for (int i = 0; i < tensor_numel; ++i) - EXPECT_FLOAT_EQ(tensor_data2[i], 31.9); - } -} - -TEST(LodTensor, Run) { - platform::CPUPlace place; - RunTestLodTensor(place); -#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) - platform::CUDAPlace gpu(0); - RunTestLodTensor(gpu); -#endif -} - -TEST(SelectedRows, Run) { - platform::CPUPlace place; - RunSerdeTestSelectedRows(place); -#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) - platform::CUDAPlace gpu; - RunSerdeTestSelectedRows(gpu); -#endif -} diff --git a/paddle/fluid/operators/distributed/brpc/brpc_server.cc b/paddle/fluid/operators/distributed/brpc/brpc_server.cc deleted file mode 100644 index 5ca26f006b..0000000000 --- a/paddle/fluid/operators/distributed/brpc/brpc_server.cc +++ /dev/null @@ -1,417 +0,0 @@ -// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "paddle/fluid/operators/distributed/brpc/brpc_server.h" -#include -#include -#include "paddle/fluid/framework/threadpool.h" -#include "paddle/fluid/operators/distributed/brpc/brpc_sendrecvop_utils.h" -#include "paddle/fluid/operators/distributed/brpc/brpc_variable_response.h" -#include "paddle/fluid/operators/distributed/request_handler.h" - -namespace sendrecv { - -namespace distributed = paddle::operators::distributed; - -typedef std::unordered_map - HandlerMap; - -class BRPCServiceImpl : public SendRecvService { - public: - explicit BRPCServiceImpl(const HandlerMap& rpc_call_map, - distributed::RPCServer* rpc_server) - : rpc_server_(rpc_server) { - VLOG(3) << "BRPCServiceImpl size: " << rpc_call_map.size(); - auto it = rpc_call_map.find(distributed::kRequestSend); - if (it != rpc_call_map.end()) { - request_send_h_ = it->second; - send_threads_.reset(new paddle::framework::ThreadPool( - rpc_server_->GetThreadNum(distributed::kRequestSend))); - } - - it = rpc_call_map.find(distributed::kRequestGet); - if (it != rpc_call_map.end()) { - request_get_h_ = it->second; - get_threads_.reset(new paddle::framework::ThreadPool( - rpc_server_->GetThreadNum(distributed::kRequestGet))); - } - - it = rpc_call_map.find(distributed::kRequestGetNoBarrier); - if (it != rpc_call_map.end()) { - request_getnobarrier_h_ = it->second; - getnobarrier_threads_.reset(new paddle::framework::ThreadPool( - rpc_server_->GetThreadNum(distributed::kRequestGetNoBarrier))); - } - - it = rpc_call_map.find(distributed::kRequestPrefetch); - if (it != rpc_call_map.end()) { - request_prefetch_h_ = it->second; - prefetch_threads_.reset(new paddle::framework::ThreadPool( - rpc_server_->GetThreadNum(distributed::kRequestPrefetch))); - } - - it = rpc_call_map.find(distributed::kRequestCheckpoint); - if (it != rpc_call_map.end()) { - request_checkpoint_h_ = it->second; - checkpoint_notify_threads_.reset(new paddle::framework::ThreadPool( - rpc_server_->GetThreadNum(distributed::kRequestPrefetch))); - } - - it = rpc_call_map.find(distributed::kRequestGetMonomerVariable); - if (it != rpc_call_map.end()) { - request_get_monomer_handler_h_ = it->second; - } - - it = rpc_call_map.find(distributed::kRequestGetMonomerBarrier); - if (it != rpc_call_map.end()) { - request_get_monomer_barrier_handler_h_ = it->second; - } - } - - virtual ~BRPCServiceImpl() {} - void SendVariable(google::protobuf::RpcController* cntl_butil, - const VariableMessage* request, VoidMessage* response, - google::protobuf::Closure* done) override { - send_threads_->Run( - [=] { _SendVariable(cntl_butil, request, response, done); }); - } - - void _SendVariable(google::protobuf::RpcController* cntl_butil, - const VariableMessage* request, VoidMessage* response, - google::protobuf::Closure* done) { - PADDLE_ENFORCE_NOT_NULL( - request_send_h_, platform::errors::PreconditionNotMet( - "RequestSend handler should be registed first!")); - brpc::ClosureGuard done_guard(done); - brpc::Controller* cntl = static_cast(cntl_butil); - - std::string varname = request->varname(); - VLOG(3) << "RequestSend var_name:" << varname - << ", trainer_id:" << request->trainer_id() - << ", from:" << cntl->remote_side(); - - distributed::BRPCVariableResponse resp(request_send_h_->scope(), - request_send_h_->dev_ctx(), - request_send_h_->distributed_mode()); - PADDLE_ENFORCE_EQ( - resp.Parse(cntl->request_attachment(), *request), 0, - platform::errors::InvalidArgument("parse iobuf to tensor error!")); - - auto scope = resp.GetMutableLocalScope(); - auto invar = resp.GetVar(); - int trainer_id = request->trainer_id(); - paddle::framework::Variable* outvar = nullptr; - - request_send_h_->Handle(varname, scope, invar, &outvar, trainer_id); - } - - void GetVariable(google::protobuf::RpcController* cntl_butil, - const VariableMessage* request, VariableMessage* response, - google::protobuf::Closure* done) override { - get_threads_->Run( - [=] { _GetVariable(cntl_butil, request, response, done); }); - } - - void GetVariableNoBarrier(google::protobuf::RpcController* cntl_butil, - const VariableMessage* request, - VariableMessage* response, - google::protobuf::Closure* done) override { - getnobarrier_threads_->Run( - [=] { _GetVariableNoBarrier(cntl_butil, request, response, done); }); - } - - void _GetVariable(google::protobuf::RpcController* cntl_butil, - const VariableMessage* request, VariableMessage* response, - google::protobuf::Closure* done) { - PADDLE_ENFORCE_NOT_NULL( - request_get_h_, platform::errors::PreconditionNotMet( - "RequestGet handler should be registed first!")); - - brpc::ClosureGuard done_guard(done); - brpc::Controller* cntl = static_cast(cntl_butil); - - std::string varname = request->varname(); - std::string out_varname = request->out_varname(); - VLOG(3) << "RequestGet varname:" << varname - << ", out_varname:" << out_varname - << ", trainer_id:" << request->trainer_id() - << ", from:" << cntl->remote_side(); - - auto scope = request_get_h_->scope(); - paddle::framework::Variable* invar = nullptr; - int trainer_id = request->trainer_id(); - paddle::framework::Variable* outvar = nullptr; - - request_get_h_->Handle(varname, scope, invar, &outvar, trainer_id, - out_varname); - - if (outvar) { - distributed::SerializeToIOBuf(out_varname, outvar, - *request_get_h_->dev_ctx(), response, - &cntl->response_attachment(), "", false); - } - } - - void _GetVariableNoBarrier(google::protobuf::RpcController* cntl_butil, - const VariableMessage* request, - VariableMessage* response, - google::protobuf::Closure* done) { - PADDLE_ENFORCE_NOT_NULL( - request_getnobarrier_h_, - platform::errors::PreconditionNotMet( - "RequestGetNoBarrier handler should be registed first!")); - - brpc::ClosureGuard done_guard(done); - brpc::Controller* cntl = static_cast(cntl_butil); - - std::string varname = request->varname(); - std::string out_varname = request->out_varname(); - int trainer_id = request->trainer_id(); - - VLOG(3) << "RequestGetNoBarrier varname:" << varname - << ", out_varname:" << out_varname << ", trainer_id:" << trainer_id - << ", from:" << cntl->remote_side(); - - auto scope = request_getnobarrier_h_->scope(); - paddle::framework::Variable* invar = nullptr; - paddle::framework::Variable* outvar = nullptr; - - request_getnobarrier_h_->Handle(varname, scope, invar, &outvar, trainer_id, - out_varname); - - if (outvar) { - distributed::SerializeToIOBuf( - out_varname, outvar, *request_getnobarrier_h_->dev_ctx(), response, - &cntl->response_attachment(), "", false); - } - } - - void PrefetchVariable(google::protobuf::RpcController* cntl_butil, - const VariableMessage* request, - VariableMessage* response, - google::protobuf::Closure* done) override { - prefetch_threads_->Run( - [=] { _PrefetchVariable(cntl_butil, request, response, done); }); - } - - void _PrefetchVariable(google::protobuf::RpcController* cntl_butil, - const VariableMessage* request, - VariableMessage* response, - google::protobuf::Closure* done) { - PADDLE_ENFORCE_NOT_NULL(request_prefetch_h_, - platform::errors::PreconditionNotMet( - "kRequestPrefetch handler should be registed first!"); - - brpc::ClosureGuard done_guard(done); - brpc::Controller* cntl = static_cast(cntl_butil); - - // prefetch process... - std::string in_var_name = request->varname(); - std::string out_var_name = request->out_varname(); - VLOG(3) << "RequestPrefetch, in_var_name: " << in_var_name - << ", out_var_name: " << out_var_name - << ", trainer_id:" << request->trainer_id() - << ", from:" << cntl->remote_side(); - - distributed::BRPCVariableResponse resp( - request_prefetch_h_->scope(), request_prefetch_h_->dev_ctx(), true); - - PADDLE_ENFORCE_EQ(resp.Parse(cntl->request_attachment(), *request), 0, - platform::errors::InvalidArgument( - "parse iobuf to tensor error!")); - - auto scope = resp.GetMutableLocalScope(); - auto invar = scope->FindVar(in_var_name); - std::string table_name = request->table_name(); - int trainer_id = request->trainer_id(); - paddle::framework::Variable* outvar = scope->Var(out_var_name); - - request_prefetch_h_->Handle(in_var_name, scope, invar, &outvar, trainer_id, - out_var_name, table_name); - - distributed::SerializeToIOBuf(out_var_name, outvar, - *request_prefetch_h_->dev_ctx(), response, - &cntl->response_attachment(), "", true); - } - - void CheckpointNotify(google::protobuf::RpcController* cntl_butil, - const VariableMessage* request, VoidMessage* response, - google::protobuf::Closure* done) override { - checkpoint_notify_threads_->Run( - [=] { _CheckpointNotify(cntl_butil, request, response, done); }); - } - - void _CheckpointNotify(google::protobuf::RpcController* cntl_butil, - const VariableMessage* request, VoidMessage* response, - google::protobuf::Closure* done) { - PADDLE_ENFORCE_NOT_NULL( - request_checkpoint_h_, - platform::errors::PreconditionNotMet( - "kRequestCheckpointNotify handler should be registed first!")); - - brpc::ClosureGuard done_guard(done); - brpc::Controller* cntl = static_cast(cntl_butil); - - distributed::BRPCVariableResponse resp(request_checkpoint_h_->scope(), - request_checkpoint_h_->dev_ctx()); - - auto scope = resp.GetMutableLocalScope(); - - std::string checkpoint_notify = request->varname(); - std::string checkpoint_dir = request->out_varname(); - int trainer_id = request->trainer_id(); - - VLOG(4) << "RequestCheckpointNotify notify: " << checkpoint_notify - << ", dir: " << checkpoint_dir - << ", trainer_id:" << request->trainer_id() - << ", from:" << cntl->remote_side(); - - request_checkpoint_h_->Handle(checkpoint_notify, scope, nullptr, nullptr, - trainer_id, checkpoint_dir); - } - - void GetMonomerVariable(google::protobuf::RpcController* cntl_butil, - const VariableMessage* request, - VariableMessage* response, - google::protobuf::Closure* done) override { - PADDLE_ENFORCE_NOT_NULL( - request_get_monomer_handler_h_, - platform::errors::PreconditionNotMet( - "kRequestGetMonomerVariable handler should be registed first!")); - - brpc::ClosureGuard done_guard(done); - brpc::Controller* cntl = static_cast(cntl_butil); - - // proc request. - std::string varname = request->varname(); - VLOG(3) << "GetMonomerVariable " << varname - << ", trainer_id:" << request->trainer_id() - << ", from:" << cntl->remote_side(); - - rpc_server_->WaitVarCond(varname); - distributed::MonomerHandle h = rpc_server_->GetMonomer(varname); - - auto scope = h.scope_; - auto invar = scope->FindVar(varname); - paddle::framework::Variable* outvar = nullptr; - - request_get_monomer_handler_h_->Handle(varname, scope, invar, &outvar, - request->trainer_id()); - - if (outvar) { - distributed::SerializeToIOBuf(varname, outvar, *h.dev_ctx_, response, - &cntl->response_attachment(), "", false); - } - } - - void GetMonomerBarrier(google::protobuf::RpcController* cntl_butil, - const VariableMessage* request, VoidMessage* response, - google::protobuf::Closure* done) override { - PADDLE_ENFORCE_NOT_NULL( - request_get_monomer_barrier_handler_h_, - platform::errors::PreconditionNotMet( - "RequestGetMonomerBarrier handler should be registed first!")); - - brpc::ClosureGuard done_guard(done); - brpc::Controller* cntl = static_cast(cntl_butil); - - std::string varname = request->varname(); - VLOG(3) << "RequestGetMonomerBarrier var_name:" << varname - << ", trainer_id:" << request->trainer_id() - << ", from:" << cntl->remote_side(); - - rpc_server_->WaitVarCond(varname); - distributed::MonomerHandle h = rpc_server_->GetMonomer(varname); - - paddle::framework::Scope* scope = nullptr; - paddle::framework::Variable* invar = nullptr; - paddle::framework::Variable* outvar = nullptr; - - request_get_monomer_barrier_handler_h_->Handle( - varname, scope, invar, &outvar, request->trainer_id()); - } - - private: - distributed::RequestHandler* request_send_h_{nullptr}; - distributed::RequestHandler* request_get_h_{nullptr}; - distributed::RequestHandler* request_getnobarrier_h_{nullptr}; - distributed::RequestHandler* request_prefetch_h_{nullptr}; - distributed::RequestHandler* request_checkpoint_h_{nullptr}; - distributed::RequestHandler* request_get_monomer_handler_h_{nullptr}; - distributed::RequestHandler* request_get_monomer_barrier_handler_h_{nullptr}; - - distributed::RPCServer* rpc_server_{nullptr}; - - // FIXME(gongwb): brpc should support process one rpc use one threadpool. - std::unique_ptr send_threads_; - std::unique_ptr get_threads_; - std::unique_ptr getnobarrier_threads_; - std::unique_ptr prefetch_threads_; - std::unique_ptr checkpoint_notify_threads_; -}; -} // namespace sendrecv - -namespace paddle { -namespace operators { -namespace distributed { - -void AsyncBRPCServer::StartServer() { - // Instance of your service. - sendrecv::BRPCServiceImpl service_impl(rpc_call_map_, this); - - // Add the service into server. Notice the second parameter, because the - // service is put on stack, we don't want server to delete it, otherwise - // use brpc::SERVER_OWNS_SERVICE. - if (server_.AddService(&service_impl, brpc::SERVER_DOESNT_OWN_SERVICE) != 0) { - PADDDLE_THROW(platform::errors::Unavailable( - "Failed to add service into BRPC server.")); - return; - } - - brpc::ServerOptions options; -#ifdef PADDLE_WITH_BRPC_RDMA - options.use_rdma = true; -#endif - options.idle_timeout_sec = idle_timeout_s_; - options.max_concurrency = max_concurrency_; - if (server_.Start(bind_address_.c_str(), &options) != 0) { - PADDDLE_THROW(platform::errors::Unavailable( - "Failed to start EchoServer %s.", bind_address_)); - return; - } - - butil::EndPoint ep = server_.listen_address(); - selected_port_ = ep.port; - - { - std::lock_guard lock(this->mutex_ready_); - ready_ = 1; - } - condition_ready_.notify_all(); - - server_.Join(); -} - -void AsyncBRPCServer::ShutDownImpl() { server_.Stop(1000); } - -void AsyncBRPCServer::WaitServerReady() { - VLOG(3) << "AsyncGRPCServer is wait server ready"; - std::unique_lock lock(this->mutex_ready_); - condition_ready_.wait(lock, [=] { return this->ready_ == 1; }); - VLOG(3) << "AsyncGRPCServer WaitSeverReady"; -} - -}; // namespace distributed -}; // namespace operators -}; // namespace paddle diff --git a/paddle/fluid/operators/distributed/brpc/brpc_server.h b/paddle/fluid/operators/distributed/brpc/brpc_server.h deleted file mode 100644 index 78bbe5adc0..0000000000 --- a/paddle/fluid/operators/distributed/brpc/brpc_server.h +++ /dev/null @@ -1,53 +0,0 @@ -/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#pragma once - -#include // NOLINT -#include // NOLINT -#include - -#include "brpc/server.h" -#include "paddle/fluid/operators/distributed/distributed_pb.h" -#include "paddle/fluid/operators/distributed/rpc_server.h" - -namespace paddle { -namespace operators { -namespace distributed { - -class AsyncBRPCServer final : public RPCServer { - public: - explicit AsyncBRPCServer(const std::string& address, int client_num) - : RPCServer(address, client_num), ready_(0) {} - - virtual ~AsyncBRPCServer() {} - void StartServer() override; - void WaitServerReady() override; - - private: - void ShutDownImpl() override; - - brpc::Server server_; - - static constexpr int idle_timeout_s_ = -1; - static constexpr int max_concurrency_ = 0; - - std::mutex mutex_ready_; - std::condition_variable condition_ready_; - int ready_; -}; - -}; // namespace distributed -}; // namespace operators -}; // namespace paddle diff --git a/paddle/fluid/operators/distributed/brpc/brpc_variable_response.cc b/paddle/fluid/operators/distributed/brpc/brpc_variable_response.cc deleted file mode 100644 index 49521e8a77..0000000000 --- a/paddle/fluid/operators/distributed/brpc/brpc_variable_response.cc +++ /dev/null @@ -1,75 +0,0 @@ -// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// - -#include "paddle/fluid/operators/distributed/brpc/brpc_variable_response.h" -#include "paddle/fluid/operators/distributed/send_recv.pb.h" - -namespace paddle { -namespace operators { -namespace distributed { - -namespace pb = ::google::protobuf; -using vr = ::sendrecv::VariableMessage; - -int BRPCVariableResponse::Parse(Source* source) { - pb::io::ZeroCopyInputStream* input_stream = source->contents(); - pb::io::CodedInputStream input(input_stream); - input.SetTotalBytesLimit(INT_MAX, INT_MAX); - - while (1) { - unsigned int tag = 0; - if (!input.ReadLittleEndian32(&tag)) { - break; - } - - uint64_t num_bytes = 0; - if (!input.ReadLittleEndian64(&num_bytes)) { - break; - } - - int field = static_cast(tag); - int ret = field == 0 ? -1 : field; - switch (field) { - case vr::kSerializedFieldNumber: { - if (!ProcSerializedField(field, &input, num_bytes)) { - return ret; - } - break; - } - case vr::kRowsFieldNumber: { - PADDLE_ENFORCE((meta_.type() == sendrecv::SELECTED_ROWS || - meta_.type() == sendrecv::LOD_TENSOR) && - meta_.varname() != "", - platform::errors::PreconditionNotMet( - "meta info should be got first!")); - - if (!CopySelectRowsData(&input, *dev_ctx_, num_bytes)) { - return ret; - } - break; - } - default: { - PADDLE_THROW(platform::errors::Unavailable( - "not surpported %u fieldnumber", field)); - return ret; - } - } - } - - return 0; -} -} // namespace distributed -} // namespace operators -} // namespace paddle diff --git a/paddle/fluid/operators/distributed/brpc/brpc_variable_response.h b/paddle/fluid/operators/distributed/brpc/brpc_variable_response.h deleted file mode 100644 index 6282f08a72..0000000000 --- a/paddle/fluid/operators/distributed/brpc/brpc_variable_response.h +++ /dev/null @@ -1,67 +0,0 @@ -// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#pragma once - -#include - -#include "brpc/channel.h" -#include "paddle/fluid/framework/data_type.h" -#include "paddle/fluid/framework/lod_tensor.h" -#include "paddle/fluid/framework/scope.h" -#include "paddle/fluid/framework/selected_rows.h" -#include "paddle/fluid/framework/var_type.h" - -#include "paddle/fluid/operators/distributed/distributed_pb.h" - -#include "google/protobuf/io/coded_stream.h" -#include "google/protobuf/io/zero_copy_stream.h" -#include "paddle/fluid/framework/tensor.h" -#include "paddle/fluid/operators/distributed/variable_response.h" - -namespace paddle { -namespace operators { -namespace distributed { - -class BRPCSourceWrapper : public Source { - public: - explicit BRPCSourceWrapper(const butil::IOBuf& iobuf) : source_(iobuf) {} - ::google::protobuf::io::ZeroCopyInputStream* contents() override { - return &source_; - } - - private: - butil::IOBufAsZeroCopyInputStream source_; -}; - -class BRPCVariableResponse : public VariableResponse { - public: - BRPCVariableResponse(const framework::Scope* scope, - const platform::DeviceContext* dev_ctx, - bool create_scope = false) - : VariableResponse(scope, dev_ctx, create_scope) {} - - virtual ~BRPCVariableResponse() {} - - // parse attachment from iobuf - int Parse(Source* source) override; - int Parse(const butil::IOBuf& iobuf, const sendrecv::VariableMessage& meta) { - BRPCSourceWrapper wrapper(iobuf); - return VariableResponse::Parse(&wrapper, meta); - } -}; - -}; // namespace distributed -}; // namespace operators -}; // namespace paddle diff --git a/paddle/fluid/operators/distributed/collective_client.cc b/paddle/fluid/operators/distributed/collective_client.cc deleted file mode 100644 index fcd3e6abea..0000000000 --- a/paddle/fluid/operators/distributed/collective_client.cc +++ /dev/null @@ -1,57 +0,0 @@ -// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "paddle/fluid/operators/distributed/collective_client.h" -#include -#include "gflags/gflags.h" - -DECLARE_int32(rpc_deadline); - -namespace paddle { -namespace operators { -namespace distributed { -std::once_flag CollectiveClient::init_flag_; -std::unique_ptr CollectiveClient::client_(nullptr); - -bool CollectiveClient::Gather(const std::vector& remote_vars, - std::vector* dst, - const platform::DeviceContext& ctx, - framework::Scope* scope, int64_t time_out) { - for (auto r : remote_vars) { - VLOG(50) << "begin gather from ep:" << r.String(); - scope->Var(r.var_name_)->GetMutable(); - VarHandlePtr ptr = rpc_client_->AsyncGetMonomerVariable( - r.ep_, ctx, *scope, r.var_name_, time_out); - } - - rpc_client_->Wait(); - - for (auto r : remote_vars) { - auto select_rows = - scope->FindVar(r.var_name_)->GetMutable(); - dst->push_back(select_rows); - - VLOG(4) << "gather from ep:" << r.String() - << ", select_rows:" << GetSelectedRowsInfo(*select_rows); - - rpc_client_->AsyncGetMonomerBarrier(r.ep_, r.var_name_); - } - - rpc_client_->Wait(); - return true; -} - -} // namespace distributed -} // namespace operators -} // namespace paddle diff --git a/paddle/fluid/operators/distributed/collective_client.h b/paddle/fluid/operators/distributed/collective_client.h deleted file mode 100644 index e7d8bb8df9..0000000000 --- a/paddle/fluid/operators/distributed/collective_client.h +++ /dev/null @@ -1,104 +0,0 @@ -// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#pragma once - -#include // NOLINT -#include -#include -#include - -#include "gflags/gflags.h" -#include "paddle/fluid/framework/data_type.h" -#include "paddle/fluid/framework/lod_tensor.h" -#include "paddle/fluid/framework/scope.h" -#include "paddle/fluid/operators/distributed/distributed.h" -#include "paddle/fluid/operators/distributed/request_handler.h" - -namespace paddle { -namespace framework { -class Scope; -class SelectedRows; -} // namespace framework -namespace platform { -class DeviceContext; -} // namespace platform -} // namespace paddle - -DECLARE_int32(rpc_deadline); - -namespace paddle { -namespace operators { -namespace distributed { - -inline std::string GetSelectedRowsInfo(const framework::SelectedRows& slr) { - std::stringstream ss; - ss << ", height:" << slr.height() << ", rows:["; - for (unsigned int i = 0; i < slr.rows().size(); i++) { - if (i != slr.rows().size() - 1) { - ss << slr.rows()[i] << ","; - } else { - ss << slr.rows()[i]; - } - } - ss << "], dims:" << slr.value().dims(); - return ss.str(); -} - -struct RemoteVar { - std::string ep_; - std::string var_name_; - int trainer_id_{0}; - - std::string String() { - std::stringstream ss; - ss << "ep:" << ep_ << ", var_name:" << var_name_ - << ", trainer_id:" << trainer_id_; - - return ss.str(); - } -}; - -class CollectiveClient { - public: - CollectiveClient() { - rpc_client_.reset(new RPCCLIENT_T()); - rpc_client_->InitImpl(); - } - virtual ~CollectiveClient() {} - - // note this function will retain the rank order. - bool Gather(const std::vector& remote_vars, - std::vector* dst, - const platform::DeviceContext& ctx, framework::Scope* scope, - int64_t time_out = FLAGS_rpc_deadline); - - static CollectiveClient* GetInstance() { - std::call_once(init_flag_, [&]() { - if (client_.get() == nullptr) { - client_.reset(new CollectiveClient()); - } - }); - return client_.get(); - } - - private: - std::unique_ptr rpc_client_; - - static std::once_flag init_flag_; - static std::unique_ptr client_; -}; -} // namespace distributed -} // namespace operators -} // namespace paddle diff --git a/paddle/fluid/operators/distributed/collective_server.cc b/paddle/fluid/operators/distributed/collective_server.cc deleted file mode 100644 index cdd37742d2..0000000000 --- a/paddle/fluid/operators/distributed/collective_server.cc +++ /dev/null @@ -1,68 +0,0 @@ -/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#include "paddle/fluid/operators/distributed/collective_server.h" -#include - -DEFINE_int32(collective_get_thread_num, 5, "number of threads for rpc get"); - -namespace paddle { -namespace operators { -namespace distributed { - -std::once_flag CollectiveServer::init_flag_; -std::shared_ptr CollectiveServer::collective_server_(nullptr); - -CollectiveServer::CollectiveServer(const std::string& end_point, int fan_in) { - VLOG(1) << "Create colllective server:" << end_point << ", fan_in:" << fan_in; - rpc_server_.reset(new RPCSERVER_T(end_point, fan_in)); -} - -void CollectiveServer::Stop() { - rpc_server_->ShutDown(); - server_thread_->join(); - loop_thread_->join(); -} - -void CollectiveServer::StartServer() { - get_monomer_handler_.reset(new GetMonomerHandler()); - get_monomer_handler_->SetRPCServer(rpc_server_.get()); - - get_barrier_handler_.reset(new GetMonomerBarrierHandler()); - get_barrier_handler_->SetRPCServer(rpc_server_.get()); - - rpc_server_->RegisterRPC(distributed::kRequestGetMonomerVariable, - get_monomer_handler_.get(), - FLAGS_collective_get_thread_num); - rpc_server_->RegisterRPC(distributed::kRequestGetMonomerBarrier, - get_barrier_handler_.get(), 1); - - server_thread_.reset(new std::thread([&]() { rpc_server_->StartServer(); })); - rpc_server_->WaitServerReady(); - - loop_thread_.reset(new std::thread([&]() { - while (true) { - if (rpc_server_->IsExit()) { - LOG(WARNING) << "get exit!rpc_processor break!"; - break; - } - sleep(1); - } - VLOG(1) << "CollectiveServer loop_thread end"; - })); -} - -}; // namespace distributed -}; // namespace operators -}; // namespace paddle diff --git a/paddle/fluid/operators/distributed/collective_server.h b/paddle/fluid/operators/distributed/collective_server.h deleted file mode 100644 index 4964923286..0000000000 --- a/paddle/fluid/operators/distributed/collective_server.h +++ /dev/null @@ -1,116 +0,0 @@ -/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#pragma once - -#include -#include -#include -#include -#include // NOLINT -#include -#include -#include "gflags/gflags.h" -#include "paddle/fluid/operators/distributed/distributed.h" -#include "paddle/fluid/operators/distributed/request_handler.h" -#include "paddle/fluid/operators/distributed/request_handler_impl.h" -#include "paddle/fluid/operators/distributed/rpc_server.h" - -namespace paddle { -namespace framework { -class Variable; -} // namespace framework -} // namespace paddle - -namespace paddle { -namespace operators { -namespace distributed { - -class CollectiveServer; - -class GetMonomerHandler final : public RequestHandler { - public: - GetMonomerHandler() : RequestHandler(true) {} - virtual ~GetMonomerHandler() {} - bool Handle(const std::string& var_name, framework::Scope* scope, - framework::Variable* var, framework::Variable** outvar, - const int trainer_id, const std::string& out_var_name = "", - const std::string& table_name = "") override { - VLOG(50) << "GetMonomerHandler recv " << var_name; - - *outvar = scope->FindVar(var_name); - PADDLE_ENFORCE_NOT_NULL( - outvar, platform::errors::NotFound("var: %s is not found.", var_name)); - - return true; - } -}; - -class GetMonomerBarrierHandler final : public RequestHandler { - public: - GetMonomerBarrierHandler() : RequestHandler(true) {} - virtual ~GetMonomerBarrierHandler() {} - bool Handle(const std::string& var_name, framework::Scope* scope, - framework::Variable* var, framework::Variable** outvar, - const int trainer_id, const std::string& out_var_name = "", - const std::string& table_name = "") override { - VLOG(50) << "GetMonomerHandler recv " << var_name; - - rpc_server_->IncreaseVarBarrier(var_name); - - return true; - } -}; - -class CollectiveServer final { - public: - explicit CollectiveServer(const std::string& end_point, int fan_in); - - virtual ~CollectiveServer() {} - - void StartServer(); - - static CollectiveServer* GetInstance(const std::string& end_point, - int fan_in) { - std::call_once(init_flag_, [&]() { - if (collective_server_.get() == nullptr) { - collective_server_.reset(new CollectiveServer(end_point, fan_in)); - collective_server_->StartServer(); - } - }); - - return collective_server_.get(); - } - - std::shared_ptr GetRPCServer() { return rpc_server_; } - - void Stop(); - - private: - std::unique_ptr get_monomer_handler_; - std::unique_ptr get_barrier_handler_; - - std::shared_ptr rpc_server_; - std::shared_ptr server_thread_; - std::shared_ptr loop_thread_; - - bool ready_{false}; - - static std::once_flag init_flag_; - static std::shared_ptr collective_server_; -}; - -}; // namespace distributed -}; // namespace operators -}; // namespace paddle diff --git a/paddle/fluid/operators/distributed/collective_server_test.cc b/paddle/fluid/operators/distributed/collective_server_test.cc deleted file mode 100644 index 92b2eb4b51..0000000000 --- a/paddle/fluid/operators/distributed/collective_server_test.cc +++ /dev/null @@ -1,131 +0,0 @@ -/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#include -#include -#include - -#include "gtest/gtest.h" -#include "paddle/fluid/operators/distributed/collective_client.h" -#include "paddle/fluid/operators/distributed/collective_server.h" - -namespace paddle { -namespace framework { -class Variable; -} // namespace framework -} // namespace paddle - -namespace framework = paddle::framework; -namespace platform = paddle::platform; -namespace distributed = paddle::operators::distributed; - -std::unique_ptr StartServer( - const std::string& ep, int fan_in, framework::Scope* scope, - platform::DeviceContext* dev_ctx) { - distributed::CollectiveServer* server = - distributed::CollectiveServer::GetInstance(ep, fan_in); - - auto rpc_server = server->GetRPCServer(); - rpc_server->RegisterVar("var1", distributed::kRequestGetMonomerVariable, - scope, dev_ctx); - - std::cout << "StartServer return" << std::endl; - return std::unique_ptr(server); -} - -std::unique_ptr GenerateVars(platform::Place place) { - platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance(); - auto& ctx = *pool.Get(place); - - framework::Scope* scope = new framework::Scope(); - framework::Variable* var = scope->Var("var1"); - auto* slr = var->GetMutable(); - slr->set_height(20000); - - auto* tensor = slr->mutable_value(); - auto* rows = slr->mutable_rows(); - - tensor->Resize(framework::make_ddim({3, 1024})); - tensor->mutable_data(place); - - paddle::operators::math::set_constant(ctx, tensor, 32.7); - for (int i = 0; i < 3; ++i) rows->push_back(i); - - std::cout << "src:" << distributed::GetSelectedRowsInfo(*slr); - - return std::unique_ptr(scope); -} - -void Gather(const std::vector& vars, - platform::DeviceContext* dev_ctx) { - distributed::CollectiveClient* client = - distributed::CollectiveClient::GetInstance(); - - framework::Scope* scope = new framework::Scope(); - framework::Variable* var = scope->Var("var1"); - var->GetMutable(); - - std::vector dst; - client->Gather(vars, &dst, *dev_ctx, scope); - std::cout << "dst:" << distributed::GetSelectedRowsInfo(*dst[0]); - dev_ctx->Wait(); - - ASSERT_EQ(dst[0]->value().dims(), framework::make_ddim({3, 1024})); - ASSERT_EQ(dst[0]->height(), 20000); - ASSERT_EQ(dst[0]->rows().size(), static_cast(3)); - for (int i = 0; i < 3; i++) { - ASSERT_EQ(dst[0]->rows()[i], i); - } - - std::vector vec; - TensorToVector(dst[0]->value(), *dev_ctx, &vec); - for (size_t i = 0; i < 3 * 1024; i++) { - ASSERT_FLOAT_EQ(vec[i], 32.7); - } -} - -TEST(CollectiveServer, GPU) { - setenv("http_proxy", "", 1); - setenv("https_proxy", "", 1); - - platform::CUDAPlace place; - platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance(); - auto& ctx = *pool.Get(place); - - std::string ep = "127.0.0.1:7164"; - auto scope = GenerateVars(place); - - auto* v1 = scope->FindVar("var1"); - std::cout << "var1:" << v1 << std::endl; - - auto server = StartServer(ep, 2, scope.get(), &ctx); - auto rpc_server = server->GetRPCServer(); - - distributed::RemoteVar var; - var.ep_ = ep; - var.var_name_ = "var1"; - var.trainer_id_ = 0; - - std::vector vars{var}; - Gather(vars, &ctx); - Gather(vars, &ctx); - - std::cout << "begin WaitVarBarrier" << std::endl; - rpc_server->WaitVarBarrier("var1"); - rpc_server->ClearRegisteredVars(); - server->Stop(); - - scope.release(); - server.release(); -} diff --git a/paddle/fluid/operators/distributed/communicator.cc b/paddle/fluid/operators/distributed/communicator.cc deleted file mode 100644 index 4ee27a6414..0000000000 --- a/paddle/fluid/operators/distributed/communicator.cc +++ /dev/null @@ -1,989 +0,0 @@ -/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#include "paddle/fluid/operators/distributed/communicator.h" - -#include - -#include -#include // NOLINT -#include -#include // NOLINT -#include - -#include "gflags/gflags.h" -#include "paddle/fluid/framework/eigen.h" -#include "paddle/fluid/framework/selected_rows.h" -#include "paddle/fluid/framework/tensor_util.h" -#include "paddle/fluid/framework/threadpool.h" -#include "paddle/fluid/framework/variable_helper.h" -#include "paddle/fluid/operators/distributed/distributed.h" -#include "paddle/fluid/operators/distributed/parameter_recv.h" -#include "paddle/fluid/operators/distributed/parameter_send.h" -#include "paddle/fluid/string/printf.h" -#include "paddle/fluid/string/split.h" - -namespace paddle { -namespace operators { -namespace distributed { - -using Tree = - std::map>>; -using RpcCtxMap = operators::distributed::RpcCtxMap; - -inline double GetCurrentUS() { - struct timeval time; - gettimeofday(&time, NULL); - return 1e+6 * time.tv_sec + time.tv_usec; -} - -Communicator::Communicator() {} - -std::once_flag Communicator::init_flag_; -std::shared_ptr Communicator::communicator_(nullptr); - -void AsyncCommunicator::InitImpl(const RpcCtxMap &send_varname_to_ctx, - const RpcCtxMap &recv_varname_to_ctx, - Scope *recv_scope) { - send_varname_to_ctx_ = std::move(send_varname_to_ctx); - recv_varname_to_ctx_ = std::move(recv_varname_to_ctx); - recv_scope_ = std::move(recv_scope); - - if (send_varname_to_ctx.size() == 0) { - VLOG(0) << "nothing need to be send, will not start send_thread"; - } else { - send_scope_.reset(new Scope()); - for (auto &iter : send_varname_to_ctx_) { - if (iter.first == STEP_COUNTER && !need_global_step_) continue; - send_varname_to_queue_[iter.first] = - std::make_shared>>( - send_queue_size_); - } - send_threadpool_.reset(new ::ThreadPool(thread_pool_size_)); - } - - if (recv_varname_to_ctx.size() == 0) { - VLOG(0) << "nothing need to be received, will not start recv_thread"; - } else { - recv_threadpool_.reset(new ::ThreadPool(thread_pool_size_)); - } - - InitParams(); -} - -void AsyncCommunicator::InitParams() { RecvNoBarrier(); } - -AsyncCommunicator::~AsyncCommunicator() { - running_ = false; - if (main_thread_) main_thread_->join(); -} - -void AsyncCommunicator::SendGlobalStep(int batches) { - if (!need_global_step_) { - return; - } - - if (batches == 0) { - return; - } - - auto &var_name = STEP_COUNTER; - auto *out_var = send_scope_->Var(var_name); - auto *out_t = out_var->GetMutable(); - auto *data = out_t->mutable_data({1}, platform::CPUPlace()); - data[0] = static_cast(batches); - - auto &ctx = send_varname_to_ctx_.at(var_name); - auto send_functor = distributed::ParameterSend(); - send_functor(ctx, *send_scope_, true, 1); -} - -void AsyncCommunicator::SendByCommunicator() { - std::vector> task_futures; - task_futures.reserve(send_varname_to_ctx_.size()); - VLOG(3) << "run send graph"; - - auto before_run_send_graph = GetCurrentUS(); - for (auto &iter : send_varname_to_queue_) { - auto &var_name = iter.first; - auto &var_queue = iter.second; - - auto send_task = [this, &var_name, &var_queue] { - VLOG(3) << var_name << " merge and send; "; - std::vector> vars; - - int merged_var_num = 0; - int wait_times = 0; - while (merged_var_num < max_merge_var_num_) { - if (var_queue->Size() == 0) { - VLOG(4) << "wait_times -> " << wait_times; - if (wait_times >= send_wait_times_) { - break; - } - std::this_thread::sleep_for(std::chrono::milliseconds(10)); - wait_times++; - continue; - } else { - wait_times = 0; - - vars.push_back(var_queue->Pop()); - merged_var_num++; - } - } - auto before_merge = GetCurrentUS(); - if (var_name == STEP_COUNTER) { - SendGlobalStep(merged_var_num); - auto after_merge = GetCurrentUS(); - VLOG(3) << "merge and send " << merged_var_num << " " << var_name - << " use time " << after_merge - before_merge; - return; - } - - auto &ctx = send_varname_to_ctx_.at(var_name); - - MergeVars(var_name, vars, send_scope_.get(), ctx.merge_add); - auto after_merge = GetCurrentUS(); - VLOG(3) << "merge " << merged_var_num << " " << var_name << " use time " - << after_merge - before_merge; - - auto send_functor = distributed::ParameterSend(); - send_functor(ctx, *send_scope_, true, 1); - auto after_send = GetCurrentUS(); - VLOG(3) << "send " << var_name << " use time " - << after_send - after_merge; - - if (var_name.rfind("@GRAD") != var_name.size() - 5) return; - - auto recv_param = var_name.substr(0, var_name.size() - 5); - if (recv_varname_to_ctx_.find(recv_param) == recv_varname_to_ctx_.end()) - return; - - auto recv_functor = distributed::ParameterRecv(); - recv_functor(recv_varname_to_ctx_.at(recv_param), *recv_scope_); - auto after_recv = GetCurrentUS(); - VLOG(3) << "recv " << recv_param << " use time " - << after_recv - after_send; - }; - task_futures.emplace_back(send_threadpool_->enqueue(std::move(send_task))); - } - for (auto &task_f : task_futures) { - task_f.wait(); - } - auto after_run_send_graph = GetCurrentUS(); - - VLOG(3) << "run send graph use time " - << (after_run_send_graph - before_run_send_graph); -} - -void HalfAsyncCommunicator::SendByCommunicator() { - std::vector> task_futures; - task_futures.reserve(send_varname_to_ctx_.size()); - VLOG(3) << "run send graph"; - - int batches = BatchesCounter(); - if (batches <= 0) return; - - auto before_run_send_graph = GetCurrentUS(); - for (auto &iter : send_varname_to_queue_) { - auto &var_name = iter.first; - auto &var_queue = iter.second; - - auto send_task = [this, batches, &var_name, &var_queue] { - VLOG(3) << var_name << " merge and send; "; - auto before_task = GetCurrentUS(); - std::vector> vars; - vars.reserve(batches); - - for (int i = 0; i < batches; ++i) { - vars.push_back(var_queue->Pop()); - } - - if (var_name == STEP_COUNTER) { - SendGlobalStep(batches); - auto end_task = GetCurrentUS(); - VLOG(3) << "merge " << batches << " " << var_name << " use time " - << end_task - before_task; - return; - } - - auto &ctx = send_varname_to_ctx_.at(var_name); - - auto before_merge = GetCurrentUS(); - MergeVars(var_name, vars, send_scope_.get(), ctx.merge_add); - auto after_merge = GetCurrentUS(); - VLOG(3) << "merge " << batches << " " << var_name << " use time " - << after_merge - before_merge; - - auto send_functor = distributed::ParameterSend(); - send_functor(ctx, *send_scope_, true, 1); - auto after_send = GetCurrentUS(); - VLOG(3) << "send " << var_name << " use time " - << after_send - before_task; - - if (var_name.rfind("@GRAD") != var_name.size() - 5) return; - - auto recv_param = var_name.substr(0, var_name.size() - 5); - if (recv_varname_to_ctx_.find(recv_param) == recv_varname_to_ctx_.end()) - return; - - auto recv_functor = distributed::ParameterRecv(); - recv_functor(recv_varname_to_ctx_.at(recv_param), *recv_scope_); - auto after_recv = GetCurrentUS(); - VLOG(3) << "recv " << recv_param << " use time " - << after_recv - after_send; - return; - }; - task_futures.emplace_back(send_threadpool_->enqueue(std::move(send_task))); - } - for (auto &task_f : task_futures) { - task_f.wait(); - } - auto after_run_send_graph = GetCurrentUS(); - - VLOG(3) << "run send graph use time " - << (after_run_send_graph - before_run_send_graph); -} - -void AsyncCommunicator::MainThread() { - VLOG(3) << "MainThread start and wait"; - - while (waiting_ && running_) { - std::this_thread::sleep_for(std::chrono::milliseconds(100)); - VLOG(3) << "wait for running"; - } - - while (running_) { - SendByCommunicator(); - BarrierSend(); - } - VLOG(3) << "communicator stopped, send thread exit"; -} - -void HalfAsyncCommunicator::MainThread() { - VLOG(3) << "MainThread start and wait"; - - while (waiting_ && running_) { - std::this_thread::sleep_for(std::chrono::milliseconds(100)); - VLOG(3) << "wait for running"; - } - - while (running_) { - SendByCommunicator(); - BarrierSend(); - RecvByCommunicator(); - BarrierRecv(); - BarrierWeakUp(); - } - VLOG(3) << "communicator stopped, send thread exit"; -} - -void AsyncCommunicator::RecvByCommunicator() { - VLOG(3) << "parallel run recv graph"; - if (!running_) return; - RecvNoBarrier(); - VLOG(3) << "run recv graph use time"; -} - -void AsyncCommunicator::RecvNoBarrier() { - std::vector> task_futures; - task_futures.reserve(recv_varname_to_ctx_.size()); - - for (auto &iter : recv_varname_to_ctx_) { - auto recv_task = [this, &iter] { - auto before_task = GetCurrentUS(); - auto &var_name = iter.first; - auto recv_functor = distributed::ParameterRecv(); - recv_functor(iter.second, *recv_scope_); - auto end_task = GetCurrentUS(); - VLOG(1) << "recv var " << var_name << " use time " - << (end_task - before_task); - }; - task_futures.emplace_back(recv_threadpool_->enqueue(std::move(recv_task))); - } - - for (auto &task : task_futures) { - task.wait(); - } -} - -void AsyncCommunicator::Start() { - VLOG(3) << "Communicator start"; - if (!communicator_) { - VLOG(0) << "Communicator is not inited, do nothing"; - } else { - VLOG(3) << "start send thread and recv thread"; - waiting_ = true; - running_ = true; - BarrierTriggerReset(max_merge_var_num_); - // start send and recv thread - main_thread_.reset( - new std::thread(std::bind(&AsyncCommunicator::MainThread, this))); - } -} - -void AsyncCommunicator::Stop() { - VLOG(3) << "Communicator stop"; - running_ = false; - if (!communicator_) { - VLOG(0) << "Communicator is not inited, do nothing"; - } else { - if (main_thread_) { - VLOG(3) << "stop send thread"; - main_thread_->join(); - main_thread_.reset(nullptr); - } - } - VLOG(3) << "Communicator stop done"; -} - -void AsyncCommunicator::Send(const std::vector &var_names, - const std::vector &var_tables, - const framework::Scope &scope) { - waiting_ = false; - - PADDLE_ENFORCE_EQ( - var_tables.size(), 1, - platform::errors::InvalidArgument("var_tables.size() == 1 is permitted")); - - auto table_name = var_tables[0]; - - if (table_name == STEP_COUNTER && !need_global_step_) return; - - auto before_send_op = GetCurrentUS(); - auto &queue = send_varname_to_queue_.at(table_name); - - if (table_name == STEP_COUNTER) { - auto tmp_var = std::make_shared(); - auto *tensor = tmp_var->GetMutable(); - tensor->Resize(framework::make_ddim({1})); - auto *out_d = tensor->mutable_data(platform::CPUPlace()); - out_d[0] = 1; - queue->Push(tmp_var); - } else { - PADDLE_ENFORCE_GE(var_names.size(), 1, - platform::errors::InvalidArgument( - "var_names.size() >= 1 is permitted")); - - auto *var = scope.FindVar(var_names[0]); - - PADDLE_ENFORCE_EQ( - var->IsInitialized(), true, - platform::errors::InvalidArgument("grad var should be inited")); - - auto tmp_var = std::make_shared(); - if (var->IsType()) { - framework::CopyVariable(*var, tmp_var.get()); - queue->Push(tmp_var); - } else if (var->IsType()) { - // push var into send queue by var_name - auto var_name = var_names[0]; - framework::CopyVariable(*var, tmp_var.get()); - queue->Push(tmp_var); - } else { - PADDLE_THROW(platform::errors::InvalidArgument( - "unknown var type to copy, only support LoDTensor/SelectedRows")); - } - } - auto after_send_op = GetCurrentUS(); - VLOG(3) << "send to " << table_name << " with queue size " << queue->Size() - << ", use time " << (after_send_op - before_send_op); -} - -void HalfAsyncCommunicator::Clean() { - for (auto &iter : send_varname_to_queue_) { - auto &var_name = iter.first; - auto &var_queue = iter.second; - - while (var_queue->Size() > 0) { - var_queue->Pop(); - } - - VLOG(3) << "clean var: " << var_name << " done"; - } -} - -int HalfAsyncCommunicator::BatchesCounter() { - while (running_) { - if (barrier_counter_.load() >= barrier_trigger_.load() && - barrier_trigger_.load() != 0) { - break; - } else { - std::this_thread::sleep_for(std::chrono::milliseconds(10)); - } - } - - return barrier_counter_.load(); -} - -void HalfAsyncCommunicator::Barrier() { - barrier_counter_++; - - if (!running_) { - VLOG(3) << "Communicator is not running, release barrier"; - return; - } - - { - std::unique_lock lk(barrier_mutex_); - barrier_cond_.wait(lk, [this] { return (barrier_counter_ == 0); }); - } -} - -void HalfAsyncCommunicator::BarrierTriggerDecrement() { - barrier_trigger_--; - VLOG(3) << "BarrierTriggerDecrement decrement barrier trigger to " - << barrier_trigger_.load(); -} - -void HalfAsyncCommunicator::BarrierTriggerReset(int initial_val) { - barrier_trigger_.store(initial_val); - - VLOG(3) << "BarrierTriggerReset reset barrier trigger to " - << barrier_trigger_.load(); -} - -void HalfAsyncCommunicator::BarrierWeakUp() { - barrier_counter_.store(0); - barrier_cond_.notify_all(); -} - -void SyncCommunicator::BarrierSend() { - if (!running_) return; - - distributed::RPCClient *rpc_client = - distributed::RPCClient::GetInstance(trainer_id_); - - std::vector rets; - - for (auto &ep : pserver_endpoints_) { - rets.push_back(rpc_client->AsyncSendBatchBarrier(ep)); - } - - for (size_t i = 0; i < rets.size(); i++) { - PADDLE_ENFORCE_NE(rets[i]->Wait(), 0U, platform::errors::External( - "internal error in RPCClient")); - } - - VLOG(4) << "BarrierSend with SyncCommunicator"; -} - -void SyncCommunicator::BarrierRecv() { - if (!running_) return; - - distributed::RPCClient *rpc_client = - distributed::RPCClient::GetInstance(trainer_id_); - - std::vector rets; - for (auto &ep : pserver_endpoints_) { - rets.push_back(rpc_client->AsyncSendFetchBarrier(ep)); - } - - for (size_t i = 0; i < rets.size(); i++) { - PADDLE_ENFORCE_NE(rets[i]->Wait(), 0U, platform::errors::External( - "internal error in RPCClient")); - } - - VLOG(4) << "BarrierRecv with SyncCommunicator"; -} - -void GeoCommunicator::InitImpl(const RpcCtxMap &send_varname_to_ctx, - const RpcCtxMap &recv_varname_to_ctx, - Scope *recv_scope) { - send_varname_to_ctx_ = std::move(send_varname_to_ctx); - recv_varname_to_ctx_ = std::move(recv_varname_to_ctx); - recv_scope_ = std::move(recv_scope); - - PADDLE_ENFORCE_GT( - send_varname_to_ctx.size(), 0, - platform::errors::InvalidArgument("send var contexts can not be zero")); - - send_scope_.reset(new Scope()); - for (auto &iter : send_varname_to_ctx_) { - auto &varname = iter.first; - - if (varname == STEP_COUNTER) { - send_varname_to_queue_[varname] = - std::make_shared>>( - send_queue_size_); - } else { - auto &send_ctx = iter.second; - - send_var_nums_ += send_ctx.splited_varnames.size(); - if (!send_ctx.is_sparse) { - continue; - } - int pserver_num = static_cast(send_ctx.epmap.size()); - for (int ep_idx = 0; ep_idx < pserver_num; ep_idx++) { - sparse_id_queues_.insert( - std::pair>>>>( - send_ctx.splited_varnames[ep_idx], - std::make_shared< - BlockingQueue>>>( - send_queue_size_))); - } - } - } - send_threadpool_.reset(new ::ThreadPool(thread_pool_size_)); - - if (recv_varname_to_ctx.size() == 0) { - VLOG(0) << "nothing need to be received, will not start recv_thread"; - } else { - recv_threadpool_.reset(new ::ThreadPool(thread_pool_size_)); - } - - delta_scope_.reset(new Scope()); - old_scope_.reset(new Scope()); - pserver_scope_.reset(new Scope()); - - InitParams(); -} - -void GeoCommunicator::Send(const std::vector &var_names, - const std::vector &var_tables, - const framework::Scope &scope) { - waiting_ = false; - PADDLE_ENFORCE_EQ( - var_tables.size(), 1, - platform::errors::InvalidArgument("var_tables.size() == 1 is permitted")); - - auto table_name = var_tables[0]; - if (table_name == STEP_COUNTER) return; - - auto before_send = GetCurrentUS(); - size_t splited_var_nums = - send_varname_to_ctx_[table_name].splited_varnames.size(); - - std::unordered_map> ids_table; - - for (size_t j = 0; j < splited_var_nums; j++) { - ids_table.insert(std::pair>( - send_varname_to_ctx_[table_name].splited_varnames[j], - std::unordered_set())); - } - auto *var = scope.FindVar(var_names[0]); - auto &rows = var->Get().rows(); - - // insert ids which has not been record - for (size_t j = 0; j < rows.size(); j++) { - auto ep_idx = rows[j] % splited_var_nums; - ids_table.at(send_varname_to_ctx_[table_name].splited_varnames[ep_idx]) - .insert(rows[j]); - } - - auto before_push = GetCurrentUS(); - for (auto &iter : ids_table) { - auto &key = iter.first; - auto &sparse_ids_set = iter.second; - auto sparse_ids_vec = std::make_shared>(); - sparse_ids_vec->assign(sparse_ids_set.begin(), sparse_ids_set.end()); - sparse_id_queues_.at(key)->Push(sparse_ids_vec); - VLOG(3) << "push " << sparse_ids_vec->size() << " ids to " << key - << "'s queue"; - } - auto after_send = GetCurrentUS(); - VLOG(3) << "run send " << table_name << " op finish. using " - << (before_push - before_send) << "; " << (after_send - before_push); -} - -void GeoCommunicator::MainThread() { - VLOG(3) << "MainThread start and wait"; - - while (waiting_ && running_) { - std::this_thread::sleep_for(std::chrono::milliseconds(100)); - VLOG(3) << "wait for running"; - } - - while (running_) { - std::vector> tasks; - tasks.reserve(send_var_nums_); - - for (auto &iter : send_varname_to_ctx_) { - auto &var_name = iter.first; - auto &send_ctx = iter.second; - int pserver_num = static_cast(send_ctx.epmap.size()); - if (send_ctx.is_sparse) { - for (int ep_idx = 0; ep_idx < pserver_num; ep_idx++) { - auto send_recv_task = [this, ep_idx, &var_name] { - auto before_send_sparse = GetCurrentUS(); - if (var_name == STEP_COUNTER) { - return; - } - auto send_varname = - send_varname_to_ctx_.at(var_name).splited_varnames[ep_idx]; - auto sparse_ids = MergeSparseIds(send_varname); - if (sparse_ids.size() == 0) { - return; - } - SendSparse(var_name, ep_idx, sparse_ids); - auto after_send_sparse = GetCurrentUS(); - RecvSparse(var_name, ep_idx); - auto after_recv_sparse = GetCurrentUS(); - VLOG(3) - << "send recv " - << send_varname_to_ctx_.at(var_name).splited_varnames[ep_idx] - << " finish, using " << (after_send_sparse - before_send_sparse) - << " and " << (after_recv_sparse - after_send_sparse) - << "; total = " << (after_recv_sparse - before_send_sparse); - }; - tasks.emplace_back( - send_threadpool_->enqueue(std::move(send_recv_task))); - } - } else { - auto send_recv_task = [this, &var_name, &send_ctx] { - if (var_name == STEP_COUNTER) { - return; - } - SendDense(var_name); - RecvDense(var_name); - }; - tasks.emplace_back( - send_threadpool_->enqueue(std::move(send_recv_task))); - } - } - for (auto &task : tasks) { - task.wait(); - } - } -} - -std::vector GeoCommunicator::MergeSparseIds( - const std::string &send_varname) { - size_t merge_num = 0, wait_times = 0; - std::unordered_set sparse_ids; - while (merge_num < static_cast(max_merge_var_num_)) { - VLOG(3) << "Merge Number of " << send_varname << " = " << merge_num; - if (sparse_id_queues_.at(send_varname)->Size() > 0) { - wait_times = 0; - std::shared_ptr> pop_ids = - sparse_id_queues_.at(send_varname)->Pop(); - for (size_t j = 0; j < pop_ids->size(); j++) { - sparse_ids.insert(pop_ids->at(j)); - } - merge_num += 1; - VLOG(3) << "sparse_id_queues_(" << send_varname << ") pushed"; - } else if (sparse_id_queues_.at(send_varname)->Size() == 0) { - VLOG(3) << "wait_times -> " << wait_times; - if (wait_times >= static_cast(send_wait_times_)) { - break; - } - std::this_thread::sleep_for(std::chrono::milliseconds(10)); - wait_times++; - continue; - } - } - std::vector res; - res.assign(sparse_ids.begin(), sparse_ids.end()); - return res; -} -void GeoCommunicator::SendSparse(const std::string &varname, int ep_idx, - const std::vector &sparse_ids) { - auto &rpc_ctx = send_varname_to_ctx_.at(varname); - auto send_varname = rpc_ctx.splited_varnames[ep_idx]; - auto trainer_id = rpc_ctx.trainer_id; - auto endpoint = rpc_ctx.epmap[ep_idx]; - auto pserver_num = rpc_ctx.epmap.size(); - - auto *var_latest = recv_scope_->FindVar(varname); - - PADDLE_ENFORCE_EQ(var_latest->IsInitialized(), true, - platform::errors::Unavailable( - "%s is not initialized, please check", varname)); - auto &t_latest = var_latest->Get(); - - auto dims1 = t_latest.dims()[1]; - - auto cpu_ctx = paddle::platform::CPUDeviceContext(); - auto *var_delta = delta_scope_->Var(send_varname); - auto *t_delta = var_delta->GetMutable(); - - auto *t_value = t_delta->mutable_value(); - t_value->mutable_data( - framework::make_ddim({static_cast(sparse_ids.size()), dims1}), - cpu_ctx.GetPlace()); - - std::vector *>> values; - auto *ins = distributed::LargeScaleKV::GetInstance(); - ins->Get(varname)->Get(sparse_ids, {"Param"}, &values); - - auto blas = math::GetBlas(cpu_ctx); - float coefficient = 1.0 / static_cast(trainers_); - - for (auto j = 0; j < static_cast(sparse_ids.size()); ++j) { - blas.VSUB(dims1, t_latest.data() + sparse_ids[j] * dims1, - values[j][0]->data(), t_value->data() + j * dims1); - blas.SCAL(dims1, coefficient, t_value->data() + j * dims1); - blas.VADD(dims1, values[j][0]->data(), t_value->data() + j * dims1, - values[j][0]->data()); - } - - std::vector send_rows; - send_rows.reserve(sparse_ids.size()); - for (auto idx : sparse_ids) { - send_rows.push_back(idx / pserver_num); - } - t_delta->set_height(rpc_ctx.height_sections[ep_idx]); - t_delta->set_rows(send_rows); - - platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance(); - auto &cpu_ctx_send = *pool.Get(platform::CPUPlace()); - distributed::RPCClient *rpc_client = - distributed::RPCClient::GetInstance(trainer_id); - - auto ret = rpc_client->AsyncSendVar(endpoint, cpu_ctx_send, - *delta_scope_.get(), send_varname); - ret->Wait(); -} - -void GeoCommunicator::SendDense(const std::string &varname) { - auto *var_latest = recv_scope_->FindVar(varname); - auto *var_timestamp = old_scope_->FindVar(varname); - - PADDLE_ENFORCE_EQ(var_latest->IsInitialized(), true, - platform::errors::Unavailable( - "%s is not initialized, please check", varname)); - PADDLE_ENFORCE_EQ(var_timestamp->IsInitialized(), true, - platform::errors::Unavailable( - "%s is not initialized, please check", varname)); - - auto &t_latest = var_latest->Get(); - auto t_timestamp = var_timestamp->GetMutable(); - - auto cpu_ctx = paddle::platform::CPUDeviceContext(); - auto *var_delta = delta_scope_->Var(varname); - auto *t_delta = var_delta->GetMutable(); - t_delta->mutable_data(t_latest.dims(), cpu_ctx.GetPlace()); - - auto blas = math::GetBlas(cpu_ctx); - blas.VSUB(t_latest.numel(), t_latest.data(), - t_timestamp->data(), t_delta->data()); - - float coefficient = 1.0 / static_cast(trainers_); - blas.SCAL(t_latest.numel(), coefficient, t_delta->data()); - - blas.VADD(t_latest.numel(), t_timestamp->data(), - t_delta->data(), t_timestamp->data()); - - auto &ctx = send_varname_to_ctx_.at(varname); - auto send = distributed::ParameterSend(); - send(ctx, *delta_scope_, true, 1); -} - -void GeoCommunicator::RecvByCommunicator() { return; } - -void GeoCommunicator::RecvSparse(const std::string &varname, int ep_idx) { - auto train_id = recv_varname_to_ctx_.at(varname).trainer_id; - auto endpoint = recv_varname_to_ctx_.at(varname).epmap[ep_idx]; - auto splited_var_name = - recv_varname_to_ctx_.at(varname).splited_varnames[ep_idx]; - auto pserver_num = recv_varname_to_ctx_.at(varname).epmap.size(); - - platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance(); - auto &cpu_ctx_recv = *pool.Get(platform::CPUPlace()); - distributed::RPCClient *rpc_client = - distributed::RPCClient::GetInstance(train_id); - - auto *var_psrever = pserver_scope_->Var(splited_var_name); - auto handle = rpc_client->AsyncGetVar(endpoint, cpu_ctx_recv, - *pserver_scope_.get(), splited_var_name, - splited_var_name, splited_var_name); - handle->Wait(); - - auto *var_latest = recv_scope_->FindVar(varname); - - PADDLE_ENFORCE_EQ( - var_psrever->IsInitialized(), true, - platform::errors::Unavailable( - "%s in pserver scope is not initialized, please check", varname)); - - std::vector ids; - ids.assign(var_psrever->Get().rows().begin(), - var_psrever->Get().rows().end()); - - for (size_t j = 0; j < ids.size(); j++) { - ids[j] = ids[j] * pserver_num + ep_idx; - } - - VLOG(3) << "RecvSparse receive var: " << splited_var_name - << " ids Size: " << ids.size(); - - auto t_psrever = var_psrever->Get().value(); - - std::vector *>> old_values; - - auto *ins = distributed::LargeScaleKV::GetInstance(); - ins->Get(varname)->Get(ids, {"Param"}, &old_values); - - auto *t_latest = var_latest->GetMutable(); - - auto dims1 = t_latest->dims()[1]; - auto numel = ids.size() * dims1; - - std::vector v_delta; - v_delta.resize(numel); - - auto cpu_ctx = paddle::platform::CPUDeviceContext(); - auto blas = math::GetBlas(cpu_ctx); - - for (auto j = 0; j < static_cast(ids.size()); ++j) { - blas.VSUB(dims1, t_psrever.data() + j * dims1, - old_values[j][0]->data(), v_delta.data() + j * dims1); - blas.VADD(dims1, t_latest->data() + ids[j] * dims1, - v_delta.data() + j * dims1, - t_latest->data() + ids[j] * dims1); - blas.VCOPY(dims1, t_psrever.data() + j * dims1, - old_values[j][0]->data()); - } -} - -void GeoCommunicator::RecvDense(const std::string &varname) { - auto *var_latest = recv_scope_->FindVar(varname); - auto *var_timestamp = old_scope_->FindVar(varname); - auto *var_psrever = pserver_scope_->Var(varname); - - auto &ctx = recv_varname_to_ctx_.at(varname); - auto recv = distributed::ParameterRecv(); - recv(ctx, *pserver_scope_); - - PADDLE_ENFORCE_EQ( - var_psrever->IsInitialized(), true, - platform::errors::Unavailable( - "%s in pserver scope is not initialized, please check", varname)); - - auto t_psrever = var_psrever->Get(); - auto t_latest = var_latest->GetMutable(); - auto t_timestamp = var_timestamp->GetMutable(); - - auto cpu_ctx = paddle::platform::CPUDeviceContext(); - auto *var_delta = delta_scope_->Var(varname); - auto *t_delta = var_delta->GetMutable(); - t_delta->mutable_data(t_latest->dims(), cpu_ctx.GetPlace()); - - auto blas = math::GetBlas(cpu_ctx); - blas.VSUB(t_latest->numel(), t_psrever.data(), - t_timestamp->data(), t_delta->data()); - blas.VADD(t_latest->numel(), t_latest->data(), t_delta->data(), - t_latest->data()); - blas.VCOPY(t_latest->numel(), t_psrever.data(), - t_timestamp->data()); -} - -void GeoCommunicator::InitParams() { - std::vector> tasks; - tasks.reserve(recv_varname_to_ctx_.size()); - - for (auto &iter : recv_varname_to_ctx_) { - auto &var_name = iter.first; - auto &recv_ctx = iter.second; - - auto recv_task = [this, &var_name, &recv_ctx] { - if (!recv_ctx.is_sparse) { - InitDense(var_name); - } - }; - tasks.emplace_back(send_threadpool_->enqueue(std::move(recv_task))); - } - - for (auto &task : tasks) { - task.wait(); - } - InitSparse(); -} - -void GeoCommunicator::InitDense(const std::string varname) { - auto &ctx = recv_varname_to_ctx_.at(varname); - auto recv = distributed::ParameterRecv(); - recv(ctx, *recv_scope_); - - auto *global_var = recv_scope_->FindVar(varname); - global_var->GetMutable(); - - auto *old_var = old_scope_->Var(varname); - old_var->GetMutable(); - - framework::CopyVariable(*global_var, old_var); - VLOG(1) << "init dense variable " << varname << " done"; -} - -void GeoCommunicator::InitSparse() { - auto sparse_metas = string::split_string(sparse_attrs_, "#"); - - std::vector metas; - std::vector dicts; - - for (auto &sparse_meta : sparse_metas) { - auto attrs = string::split_string(sparse_meta, ":"); - - auto meta = distributed::SparseMeta(); - meta.name = attrs[0]; - meta.value_names = {"Param"}; - - auto dic = string::split_string(attrs[1], ","); - dicts.push_back(std::stoi(dic[0])); - meta.value_dims = {std::stoi(dic[1])}; - meta.mode = distributed::Mode::training; - meta.grad_name = "none"; - meta.cached_varnames = {}; - meta.initializer_attrs = string::split_string(attrs[2]); - meta.entry = "none"; - - VLOG(3) << "add sparse meta: " << meta.ToString(); - metas.push_back(meta); - } - - LargeScaleKV::Init(metas); - - for (auto &meta : metas) { - auto &ctx = recv_varname_to_ctx_.at(meta.name); - auto recv = distributed::ParameterRecv(); - - auto *global_var = recv_scope_->FindVar(meta.name); - auto global_value = global_var->Get(); - auto rows = global_value.dims()[0]; - auto dim1 = global_value.dims()[1]; - - recv(ctx, *recv_scope_); - VLOG(1) << "recv " << meta.name << " with global scope for init"; - - auto n_rows = global_var->Get().dims()[0]; - - PADDLE_ENFORCE_EQ( - rows, n_rows, - platform::errors::InvalidArgument( - "global var: %s origin dim must equal recved rows", meta.name)); - - std::vector ids(rows); - std::iota(ids.begin(), ids.end(), 0); - - auto *ins = distributed::LargeScaleKV::GetInstance(); - std::vector *>> values; - - ins->Get(meta.name)->Init(ids); - ins->Get(meta.name)->Get(ids, {"Param"}, &values); - - auto blas = math::GetBlas( - paddle::platform::CPUDeviceContext()); - - for (auto &id : ids) { - blas.VCOPY(dim1, global_value.data() + id * dim1, - values[id][0]->data()); - } - } - - VLOG(3) << "init sparse variable done"; -} - -} // namespace distributed -} // namespace operators -} // namespace paddle diff --git a/paddle/fluid/operators/distributed/communicator.h b/paddle/fluid/operators/distributed/communicator.h deleted file mode 100644 index 4be3253d39..0000000000 --- a/paddle/fluid/operators/distributed/communicator.h +++ /dev/null @@ -1,490 +0,0 @@ -/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#pragma once - -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include - -#include "gflags/gflags.h" -#include "paddle/fluid/framework/scope.h" -#include "paddle/fluid/framework/variable.h" -#include "paddle/fluid/framework/variable_helper.h" -#include "paddle/fluid/operators/distributed/communicator_common.h" -#include "paddle/fluid/operators/distributed/distributed.h" -#include "paddle/fluid/operators/distributed/large_scale_kv.h" -#include "paddle/fluid/operators/distributed/rpc_client.h" -#include "paddle/fluid/operators/distributed_ops/send_recv_util.h" -#include "paddle/fluid/operators/math/blas.h" -#include "paddle/fluid/operators/math/math_function.h" -#include "paddle/fluid/operators/math/selected_rows_functor.h" -#include "paddle/fluid/platform/device_context.h" -#include "paddle/fluid/platform/enforce.h" -#include "paddle/fluid/platform/place.h" -#include "paddle/fluid/string/split.h" - -DECLARE_bool(communicator_is_sgd_optimizer); - -namespace paddle { -namespace operators { -namespace distributed { - -using Scope = framework::Scope; -using Variable = framework::Variable; - -template -class BlockingQueue { - public: - explicit BlockingQueue(size_t capacity) : capacity_(capacity) { - PADDLE_ENFORCE_GT(capacity_, 0, - platform::errors::InvalidArgument( - "The capacity must be greater than 0.")); - } - - bool Push(const T &elem) { - { - std::unique_lock lock(mutex_); - cv_.wait(lock, [&] { return queue_.size() < capacity_; }); - PADDLE_ENFORCE_LT( - queue_.size(), capacity_, - platform::errors::OutOfRange("The queue size: %s out of capacity:%s", - queue_.size(), capacity_)); - queue_.push_back(elem); - } - cv_.notify_one(); - return true; - } - - bool Push(T &&elem) { - { - std::unique_lock lock(mutex_); - cv_.wait(lock, [&] { return queue_.size() < capacity_; }); - PADDLE_ENFORCE_LT( - queue_.size(), capacity_, - platform::errors::OutOfRange("The queue size: %s out of capacity:%s", - queue_.size(), capacity_)); - queue_.emplace_back(std::move(elem)); - } - cv_.notify_one(); - return true; - } - - T Pop() { - std::unique_lock lock(mutex_); - cv_.wait(lock, [=] { return !queue_.empty(); }); - T rc(std::move(queue_.front())); - queue_.pop_front(); - cv_.notify_one(); - return rc; - } - - size_t Cap() const { - std::lock_guard lock(mutex_); - return capacity_; - } - - size_t Size() const { - std::lock_guard lock(mutex_); - return queue_.size(); - } - - private: - const size_t capacity_; - std::deque queue_; - - mutable std::mutex mutex_; - std::condition_variable cv_; -}; - -template -using EigenVector = framework::EigenVector; - -template -inline void MergeVars(const std::string &var_name, - const std::vector> &vars, - Scope *scope, bool merge_add = true) { - PADDLE_ENFORCE_NE(vars.empty(), true, platform::errors::InvalidArgument( - "vector vars are empty.")); - auto cpu_place = platform::CPUPlace(); - auto &var0 = vars[0]; - auto *out_var = scope->Var(var_name); - if (var0->IsType()) { - auto dims = var0->Get().dims(); - VLOG(3) << "merge " << var_name << " LoDTensor dims " << dims - << "; merge add: " << merge_add; - // init output tensor - auto *out_t = out_var->GetMutable(); - out_t->mutable_data(dims, cpu_place); - // check the input dims - for (auto &var : vars) { - auto &var_t = var->Get(); - PADDLE_ENFORCE_EQ( - var_t.dims(), dims, - platform::errors::InvalidArgument("vars should have the same dims")); - } - - // set output tensor to 0. - auto cpu_ctx = paddle::platform::CPUDeviceContext(); - math::SetConstant constant_functor; - constant_functor(cpu_ctx, out_t, static_cast(0)); - // sum all vars to out - auto result = EigenVector::Flatten(*out_t); - for (auto &var : vars) { - auto &in_t = var->Get(); - auto in = EigenVector::Flatten(in_t); - result.device(*cpu_ctx.eigen_device()) = result + in; - } - if (!merge_add) { - result.device(*cpu_ctx.eigen_device()) = - result / static_cast(vars.size()); - } - } else if (var0->IsType()) { - auto &slr0 = var0->Get(); - auto *out_slr = out_var->GetMutable(); - out_slr->mutable_rows()->clear(); - out_slr->mutable_value()->mutable_data({{}}, cpu_place); - std::vector inputs; - inputs.reserve(vars.size()); - for (auto &var : vars) { - inputs.push_back(&var->Get()); - } - auto dev_ctx = paddle::platform::CPUDeviceContext(); - if (merge_add) { - math::scatter::MergeAdd merge_add; - merge_add(dev_ctx, inputs, out_slr); - } else { - math::scatter::MergeAverage - merge_average; - merge_average(dev_ctx, inputs, out_slr); - } - - VLOG(3) << "merge " << var_name << " SelectedRows height: " << slr0.height() - << " dims: " << slr0.value().dims() << "; merge add: " << merge_add; - } else { - PADDLE_THROW(platform::errors::InvalidArgument("unsupported var type: %s!", - var0->Type())); - } -} - -using RpcCtxMap = std::unordered_map; -using SparseValue = std::unordered_map>; - -class Communicator { - public: - Communicator(); - - explicit Communicator(const std::map &envs_) { - for (auto &iter : envs_) { - envs[iter.first] = iter.second; - } - } - - virtual ~Communicator() {} - - virtual void Start() = 0; - - virtual void Stop() = 0; - - virtual bool IsRunning() { return running_; } - - virtual void Clean() {} - - virtual void Send(const std::vector &var_names, - const std::vector &var_tables, - const framework::Scope &scope) = 0; - - virtual void RecvNoBarrier() {} - - virtual void Barrier() {} - - virtual void BarrierTriggerDecrement() {} - - virtual void BarrierTriggerReset(int init_counter) {} - - virtual void InitEnvs() = 0; - - virtual void InitImpl(const RpcCtxMap &send_varname_to_ctx, - const RpcCtxMap &recv_varname_to_ctx, - Scope *recv_scope) {} - - static Communicator *GetInstance() { return communicator_.get(); } - - static std::shared_ptr GetInstantcePtr() { - return communicator_; - } - - template - static Communicator *InitInstance( - const RpcCtxMap &send_ctx, const RpcCtxMap &recv_ctx, Scope *recv_scope, - const std::map &envs) { - std::call_once(init_flag_, &Communicator::InitWithRpcCtx, send_ctx, - recv_ctx, recv_scope, std::ref(envs)); - return communicator_.get(); - } - - // Init is called by InitInstance. - template - static void InitWithRpcCtx(const RpcCtxMap &send_ctx, - const RpcCtxMap &recv_ctx, Scope *recv_scope, - const std::map &envs) { - if (communicator_.get() == nullptr) { - communicator_.reset(new T(std::ref(envs))); - communicator_->InitEnvs(); - communicator_->InitImpl(send_ctx, recv_ctx, recv_scope); - } - } - - protected: - bool running_ = false; - bool waiting_ = true; - static std::shared_ptr communicator_; - static std::once_flag init_flag_; - std::unordered_map envs; -}; - -class AsyncCommunicator : public Communicator { - public: - AsyncCommunicator() : Communicator() {} - - explicit AsyncCommunicator(const std::map &envs) - : Communicator(envs) {} - - ~AsyncCommunicator(); - - void InitEnvs() { - min_send_grad_num_before_recv_ = - std::stoi(envs.at("communicator_min_send_grad_num_before_recv")); - thread_pool_size_ = std::stoi(envs.at("communicator_thread_pool_size")); - max_merge_var_num_ = std::stoi(envs.at("communicator_max_merge_var_num")); - send_wait_times_ = std::stoi(envs.at("communicator_send_wait_times")); - send_queue_size_ = std::stoi(envs.at("communicator_send_queue_size")); - need_global_step_ = - static_cast(std::stoi(envs.at("need_global_step"))); - VLOG(0) << "AsyncCommunicator Initialized"; - } - - void Start() override; - - void Stop() override; - - void InitImpl(const RpcCtxMap &send_varname_to_ctx, - const RpcCtxMap &recv_varname_to_ctx, - Scope *recv_scope) override; - - void InitParams(); - - virtual void MainThread(); - - void Send(const std::vector &var_names, - const std::vector &var_tables, - const framework::Scope &scope) override; - - virtual void SendByCommunicator(); - virtual void SendGlobalStep(int batches); - - virtual void RecvByCommunicator(); - - virtual void RecvNoBarrier(); - - virtual void BarrierSend() {} - - virtual void BarrierRecv() {} - - virtual void BarrierWeakUp() {} - - protected: - int min_send_grad_num_before_recv_; - int thread_pool_size_; - int max_merge_var_num_; - int send_wait_times_; - int send_queue_size_; - int trainer_id_ = 0; - bool need_global_step_ = false; - - std::unordered_map>>> - send_varname_to_queue_; - RpcCtxMap send_varname_to_ctx_; - RpcCtxMap recv_varname_to_ctx_; - std::unique_ptr main_thread_{nullptr}; - Scope *recv_scope_; // should be global scope - std::unique_ptr send_scope_; // an independent scope - std::unique_ptr<::ThreadPool> send_threadpool_{nullptr}; - std::unique_ptr<::ThreadPool> recv_threadpool_{nullptr}; - std::atomic_uint grad_num_{0}; // the num of gradient sent since last recv -}; - -class HalfAsyncCommunicator : public AsyncCommunicator { - public: - HalfAsyncCommunicator() {} - - explicit HalfAsyncCommunicator(const std::map &envs) - : AsyncCommunicator(envs) {} - - void InitEnvs() { - min_send_grad_num_before_recv_ = 0; - - max_merge_var_num_ = std::stoi(envs.at("communicator_max_merge_var_num")); - send_wait_times_ = std::stoi(envs.at("communicator_send_wait_times")); - thread_pool_size_ = std::stoi(envs.at("communicator_thread_pool_size")); - send_queue_size_ = std::stoi(envs.at("communicator_send_queue_size")); - need_global_step_ = - static_cast(std::stoi(envs.at("need_global_step"))); - VLOG(0) << "HalfAsyncCommunicator Initialized"; - } - - void MainThread() override; - - void SendByCommunicator() override; - - void Clean() override; - - void Barrier() override; - - void BarrierTriggerDecrement() override; - - void BarrierTriggerReset(int initial_val) override; - - int BatchesCounter(); - - void BarrierWeakUp(); - - protected: - // mutex for Wait for barrier - std::mutex barrier_mutex_; - std::condition_variable barrier_cond_; - std::atomic barrier_trigger_{0}; - std::atomic barrier_counter_{0}; -}; - -class SyncCommunicator : public HalfAsyncCommunicator { - public: - SyncCommunicator() : HalfAsyncCommunicator() {} - - explicit SyncCommunicator(const std::map &envs) - : HalfAsyncCommunicator(envs) {} - - void InitEnvs() { - min_send_grad_num_before_recv_ = 0; - - max_merge_var_num_ = std::stoi(envs.at("communicator_max_merge_var_num")); - send_wait_times_ = std::stoi(envs.at("communicator_send_wait_times")); - thread_pool_size_ = std::stoi(envs.at("communicator_thread_pool_size")); - send_queue_size_ = std::stoi(envs.at("communicator_send_queue_size")); - need_global_step_ = - static_cast(std::stoi(envs.at("need_global_step"))); - - trainer_id_ = std::stoi(envs.at("trainer_id")); - auto pserver_strings = envs.at("pserver_endpoints"); - pserver_endpoints_ = paddle::string::Split(pserver_strings, ','); - VLOG(0) << "SyncCommunicator Initialized"; - } - - void BarrierSend(); - - void BarrierRecv(); - - private: - std::vector pserver_endpoints_{}; -}; - -class GeoCommunicator : public AsyncCommunicator { - public: - GeoCommunicator() : AsyncCommunicator() {} - - explicit GeoCommunicator(const std::map &envs) - : AsyncCommunicator(envs) {} - - void InitImpl(const RpcCtxMap &send_varname_to_ctx, - const RpcCtxMap &recv_varname_to_ctx, - Scope *recv_scope) override; - void MainThread() override; - void InitEnvs() { - min_send_grad_num_before_recv_ = 0; - - max_merge_var_num_ = std::stoi(envs.at("communicator_max_merge_var_num")); - send_wait_times_ = std::stoi(envs.at("communicator_send_wait_times")); - thread_pool_size_ = std::stoi(envs.at("communicator_thread_pool_size")); - - send_queue_size_ = max_merge_var_num_; - trainers_ = std::stoi(envs.at("trainers")); - sparse_attrs_ = envs.at("sparse_attrs"); - VLOG(0) << "GeoCommunicator Initialized"; - } - - void Send(const std::vector &var_names, - const std::vector &var_tables, - const framework::Scope &scope) override; - - void SendByCommunicator() { return; } - - std::vector MergeSparseIds(const std::string &send_varname); - - void SendSparse(const std::string &varname, int ep_idx, - const std::vector &sparse_ids); - - void SendDense(const std::string &varname); - - void SendGlobalStep(int batches) override {} - - void RecvByCommunicator() override; - - void RecvSparse(const std::string &varname, int ep_idx); - - void RecvDense(const std::string &varname); - - void InitParams(); - - void InitSparse(); - - void InitDense(const std::string varname); - - private: - int trainers_; - std::string sparse_attrs_; - - // parameter for delta calc and send - std::shared_ptr delta_scope_; - - // parameter for storage the pserver param after last recv - std::shared_ptr old_scope_; - - // parameter on pserver - std::shared_ptr pserver_scope_; - - int send_var_nums_ = 0; - - std::unordered_map> old_sparses_; - - std::unordered_map< - std::string, - std::shared_ptr>>>> - sparse_id_queues_; -}; - -} // namespace distributed -} // namespace operators -} // namespace paddle diff --git a/paddle/fluid/operators/distributed/communicator_common.h b/paddle/fluid/operators/distributed/communicator_common.h deleted file mode 100644 index 122d904eba..0000000000 --- a/paddle/fluid/operators/distributed/communicator_common.h +++ /dev/null @@ -1,91 +0,0 @@ -/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#pragma once - -#include -#include -#include - -namespace paddle { -namespace operators { -namespace distributed { - -struct CommContext { - CommContext() = default; - - CommContext(const std::string &name, const std::vector &names, - const std::vector &emap, - const std::vector §ions, - const std::vector &origin_names, int id, - bool merge_add_ = true, bool is_sparse_ = true, - bool is_distributed_ = false) - : var_name(name), - splited_varnames(names), - epmap(emap), - height_sections(sections), - origin_varnames(origin_names), - trainer_id(id), - merge_add(merge_add_), - is_sparse(is_sparse_), - is_distributed(is_distributed_) {} - - CommContext(const CommContext &ctx) { - var_name = ctx.var_name; - splited_varnames = ctx.splited_varnames; - epmap = ctx.epmap; - height_sections = ctx.height_sections; - trainer_id = ctx.trainer_id; - merge_add = ctx.merge_add; - is_sparse = ctx.is_sparse; - origin_varnames = ctx.origin_varnames; - is_distributed = ctx.is_distributed; - } - - std::string print() const { - std::stringstream ss; - - ss << "varname: " << var_name << " trainer_id: " << trainer_id << " "; - - for (size_t i = 0; i < splited_varnames.size(); i++) { - ss << "slice varname: " << splited_varnames[i] << " ep: " << epmap[i] - << " section: " << height_sections[i] << " "; - } - - ss << "origin varnames: "; - for (size_t i = 0; i < origin_varnames.size(); i++) { - ss << origin_varnames[i] << " "; - } - - ss << " aggregation->add: " << merge_add << " "; - ss << " is_sparse: " << is_sparse << "\n"; - ss << " is_distributed: " << is_distributed << "\n"; - - return ss.str(); - } - - std::string var_name; - std::vector splited_varnames; - std::vector epmap; - std::vector height_sections; - std::vector origin_varnames; - int trainer_id; - bool merge_add; - bool is_sparse; - bool is_distributed; -}; - -} // namespace distributed -} // namespace operators -} // namespace paddle diff --git a/paddle/fluid/operators/distributed/communicator_test.cc b/paddle/fluid/operators/distributed/communicator_test.cc deleted file mode 100644 index 38b7c8b003..0000000000 --- a/paddle/fluid/operators/distributed/communicator_test.cc +++ /dev/null @@ -1,106 +0,0 @@ -// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include - -#include "paddle/fluid/operators/distributed/communicator.h" - -namespace paddle { -namespace operators { -namespace distributed { - -using LoDTensor = framework::LoDTensor; -using SelectedRows = framework::SelectedRows; - -TEST(communicator, merge_lod_tensors) { - auto cpu_place = platform::CPUPlace(); - auto dims = framework::make_ddim({2, 3}); - std::vector> in_vars; - float out_value = 0; - for (auto i = 0; i < 10; ++i) { - auto var = std::make_shared(); - in_vars.emplace_back(var); - auto *tensor = var->GetMutable(); - auto *data = tensor->mutable_data(dims, cpu_place); - for (auto j = 0; j < tensor->numel(); ++j) { - data[j] = static_cast(i); - } - out_value += static_cast(i); - } - const std::string out_name = "Out"; - std::unique_ptr scope; - scope.reset(new framework::Scope()); - scope->Var(out_name); - for (auto i = 0; i < 10; ++i) { - MergeVars(out_name, in_vars, scope.get()); - } - auto &out_tensor = scope->FindVar(out_name)->Get(); - auto *out_data = out_tensor.data(); - ASSERT_EQ(out_tensor.dims(), dims); - for (auto i = 0; i < out_tensor.numel(); ++i) { - ASSERT_EQ(out_data[i], out_value); - } -} - -TEST(communicator, merge_selected_rows) { - auto cpu_place = platform::CPUPlace(); - int64_t width = 10; - std::vector> in_vars; - const int64_t height = 100; - for (auto i = 0; i < 10; ++i) { - std::vector rows; - for (auto k = 0; k <= i; ++k) { - rows.push_back(k); - } - auto var = std::make_shared(); - in_vars.emplace_back(var); - auto *slr = var->GetMutable(); - slr->set_height(height); - slr->set_rows(rows); - auto dims = - framework::make_ddim({static_cast(rows.size()), width}); - auto *data = slr->mutable_value()->mutable_data(dims, cpu_place); - for (size_t i = 0; i < rows.size(); ++i) { - for (auto j = 0; j < width; ++j) { - data[i * width + j] = static_cast(rows[i]); - } - } - } - const std::string out_name = "Out"; - std::unique_ptr scope; - scope.reset(new framework::Scope()); - scope->Var(out_name); - for (auto i = 0; i < 10; ++i) { - MergeVars(out_name, in_vars, scope.get()); - } - auto &out_slr = scope->FindVar(out_name)->Get(); - auto &out_t = out_slr.value(); - auto *out_data = out_t.data(); - ASSERT_EQ(out_t.dims(), framework::make_ddim({10, width})); - std::vector out_values; - out_values.reserve(10); - for (auto i = 0; i < 10; ++i) { - out_values.push_back(static_cast(i * (10 - i))); - } - for (size_t i = 0; i < out_slr.rows().size(); ++i) { - ASSERT_EQ(out_slr.rows()[i], static_cast(i)); - for (auto j = 0; j < width; ++j) { - ASSERT_EQ(out_data[i * width + j], out_values[i]); - } - } -} - -} // namespace distributed -} // namespace operators -} // namespace paddle diff --git a/paddle/fluid/operators/distributed/distributed.h b/paddle/fluid/operators/distributed/distributed.h deleted file mode 100644 index 5917c18fb0..0000000000 --- a/paddle/fluid/operators/distributed/distributed.h +++ /dev/null @@ -1,36 +0,0 @@ -// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#pragma once - -#ifdef PADDLE_WITH_DISTRIBUTE - -#ifdef PADDLE_WITH_GRPC -#include "paddle/fluid/operators/distributed/communicator.h" - -#include "paddle/fluid/operators/distributed/grpc/grpc_client.h" -#include "paddle/fluid/operators/distributed/grpc/grpc_server.h" -#define RPCSERVER_T paddle::operators::distributed::AsyncGRPCServer -#define RPCCLIENT_T paddle::operators::distributed::GRPCClient - -#else // PADDLE_WITH_GRPC - -#include "paddle/fluid/operators/distributed/brpc/brpc_client.h" -#include "paddle/fluid/operators/distributed/brpc/brpc_server.h" -#define RPCSERVER_T paddle::operators::distributed::AsyncBRPCServer -#define RPCCLIENT_T paddle::operators::distributed::BRPCClient - -#endif // PADDLE_WITH_GRPC - -#endif // PADDLE_WITH_DISTRIBUTE diff --git a/paddle/fluid/operators/distributed/distributed_pb.h b/paddle/fluid/operators/distributed/distributed_pb.h deleted file mode 100644 index f1c662be9a..0000000000 --- a/paddle/fluid/operators/distributed/distributed_pb.h +++ /dev/null @@ -1,30 +0,0 @@ -// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#pragma once - -#ifdef PADDLE_WITH_DISTRIBUTE - -#ifdef PADDLE_WITH_GRPC - -#include "paddle/fluid/operators/distributed/send_recv.grpc.pb.h" -#include "paddle/fluid/operators/distributed/send_recv.pb.h" - -#else // PADDLE_WITH_GRPC - -#include "paddle/fluid/operators/distributed/send_recv.pb.h" - -#endif // PADDLE_WITH_GRPC - -#endif // PADDLE_WITH_DISTRIBUTE diff --git a/paddle/fluid/operators/distributed/grpc/grpc_bytebuffer_stream.cc b/paddle/fluid/operators/distributed/grpc/grpc_bytebuffer_stream.cc deleted file mode 100644 index 7d6756b413..0000000000 --- a/paddle/fluid/operators/distributed/grpc/grpc_bytebuffer_stream.cc +++ /dev/null @@ -1,92 +0,0 @@ -/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -// NOTE: This file was originally created by tensorflow -// (https://github.com/tensorflow/tensorflow/) we borrow this -// file and did some modifications so that we can send gRPC -// requests without too much copying of the tensor data. - -#include "paddle/fluid/operators/distributed/grpc/grpc_bytebuffer_stream.h" - -namespace grpc { -class ByteBuffer; -} // namespace grpc - -namespace paddle { -namespace operators { -namespace distributed { - -GrpcByteBufferSource::GrpcByteBufferSource() {} - -bool GrpcByteBufferSource::Init(const grpc::ByteBuffer& src) { - cur_ = -1; - left_ = 0; - ptr_ = nullptr; - byte_count_ = 0; - bool ok = src.Dump(&slices_).ok(); - if (!ok) { - slices_.clear(); - } - return ok; -} - -bool GrpcByteBufferSource::Next(const void** data, int* size) { - // Use loop instead of if in case buffer contained empty slices. - while (left_ == 0) { - // Advance to next slice. - cur_++; - if (cur_ >= slices_.size()) { - return false; - } - const ::grpc::Slice& s = slices_[cur_]; - left_ = s.size(); - ptr_ = reinterpret_cast(s.begin()); - } - - *data = ptr_; - *size = left_; - byte_count_ += left_; - ptr_ += left_; - left_ = 0; - return true; -} - -void GrpcByteBufferSource::BackUp(int count) { - ptr_ -= count; - left_ += count; - byte_count_ -= count; -} - -bool GrpcByteBufferSource::Skip(int count) { - const void* data; - int size; - while (Next(&data, &size)) { - if (size >= count) { - BackUp(size - count); - return true; - } - // size < count; - count -= size; - } - // error or we have too large count; - return false; -} - -google::protobuf::int64 GrpcByteBufferSource::ByteCount() const { - return byte_count_; -} - -} // namespace distributed -} // namespace operators -} // namespace paddle diff --git a/paddle/fluid/operators/distributed/grpc/grpc_bytebuffer_stream.h b/paddle/fluid/operators/distributed/grpc/grpc_bytebuffer_stream.h deleted file mode 100644 index 486870de7a..0000000000 --- a/paddle/fluid/operators/distributed/grpc/grpc_bytebuffer_stream.h +++ /dev/null @@ -1,174 +0,0 @@ -/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -// NOTE: This file was originally created by tensorflow -// (https://github.com/tensorflow/tensorflow/) we borrow this -// file and did some modifications so that we can send gRPC -// requests without too much copying of the tensor data. - -#pragma once - -#include - -#include "google/protobuf/io/coded_stream.h" -#include "google/protobuf/io/zero_copy_stream.h" -#include "grpc++/grpc++.h" -#include "paddle/fluid/operators/distributed/variable_response.h" - -struct grpc_byte_buffer; - -namespace grpc { -// A ZeroCopyInputStream that reads from grpc_byte_buffer -class ByteBuffer; - -class GrpcBufferReader final - : public ::google::protobuf::io::ZeroCopyInputStream { - typedef void (CoreCodegenInterface::*OldReaderInitAPI)( - grpc_byte_buffer_reader* reader, grpc_byte_buffer* buffer); - typedef int (CoreCodegenInterface::*NewReaderInitAPI)( - grpc_byte_buffer_reader* reader, grpc_byte_buffer* buffer); - void ReaderInit(OldReaderInitAPI ptr, grpc_byte_buffer_reader* reader, - grpc_byte_buffer* buffer) { - (g_core_codegen_interface->*ptr)(reader, buffer); - } - void ReaderInit(NewReaderInitAPI ptr, grpc_byte_buffer_reader* reader, - grpc_byte_buffer* buffer) { - int result = (g_core_codegen_interface->*ptr)(reader, buffer); - (void)result; - } - - public: - explicit GrpcBufferReader(grpc_byte_buffer* buffer) - : byte_count_(0), backup_count_(0) { - ReaderInit(&CoreCodegenInterface::grpc_byte_buffer_reader_init, &reader_, - buffer); - } - ~GrpcBufferReader() override { - g_core_codegen_interface->grpc_byte_buffer_reader_destroy(&reader_); - } - - bool Next(const void** data, int* size) override { - if (backup_count_ > 0) { - *data = GRPC_SLICE_START_PTR(slice_) + GRPC_SLICE_LENGTH(slice_) - - backup_count_; - GPR_CODEGEN_ASSERT(backup_count_ <= INT_MAX); - *size = static_cast(backup_count_); - backup_count_ = 0; - return true; - } - if (!g_core_codegen_interface->grpc_byte_buffer_reader_next(&reader_, - &slice_)) { - return false; - } - g_core_codegen_interface->grpc_slice_unref(slice_); - *data = GRPC_SLICE_START_PTR(slice_); - // On win x64, int is only 32bit - GPR_CODEGEN_ASSERT(GRPC_SLICE_LENGTH(slice_) <= INT_MAX); - byte_count_ += * size = static_cast(GRPC_SLICE_LENGTH(slice_)); - return true; - } - - void BackUp(int count) override { backup_count_ = count; } - - bool Skip(int count) override { - const void* data; - int size; - while (Next(&data, &size)) { - if (size >= count) { - BackUp(size - count); - return true; - } - // size < count; - count -= size; - } - // error or we have too large count; - return false; - } - - ::google::protobuf::int64 ByteCount() const override { - return byte_count_ - backup_count_; - } - - private: - int64_t byte_count_; - int64_t backup_count_; - grpc_byte_buffer_reader reader_; - grpc_slice slice_; -}; - -}; // namespace grpc - -namespace paddle { -namespace operators { -namespace distributed { - -// A ZeroCopyInputStream that reads from a grpc::ByteBuffer. -class GrpcByteBufferSource - : public ::google::protobuf::io::ZeroCopyInputStream { - public: - GrpcByteBufferSource(); - bool Init(const ::grpc::ByteBuffer& src); // Can be called multiple times. - bool Next(const void** data, int* size) override; - void BackUp(int count) override; - bool Skip(int count) override; - ::google::protobuf::int64 ByteCount() const override; - - private: - std::vector<::grpc::Slice> slices_; - size_t cur_; // Current slice index. - int left_; // Number of bytes in slices_[cur_] left to yield. - const char* ptr_; // Address of next byte in slices_[cur_] to yield. - ::google::protobuf::int64 byte_count_; -}; - -class GrpcByteBufferSourceWrapper : public Source { - public: - explicit GrpcByteBufferSourceWrapper(GrpcByteBufferSource* source) - : source_(source) {} - ::google::protobuf::io::ZeroCopyInputStream* contents() override { - return source_; - } - - private: - GrpcByteBufferSource* source_; -}; - -class GrpcByteSource : public Source { - public: - explicit GrpcByteSource(grpc_byte_buffer* buffer) : buffer_(buffer) {} - ~GrpcByteSource() override { DeleteStream(); } - - typedef ::grpc::GrpcBufferReader Reader; - - ::google::protobuf::io::ZeroCopyInputStream* contents() override { - DeleteStream(); - stream_ = new (&space_) Reader(buffer_); - return stream_; - } - - private: - void DeleteStream() { - if (stream_) { - stream_->~Reader(); - } - } - - grpc_byte_buffer* buffer_; // Not owned - Reader* stream_ = nullptr; // Points into space_ if non-nullptr - char space_[sizeof(Reader)]; -}; - -} // namespace distributed -} // namespace operators -} // namespace paddle diff --git a/paddle/fluid/operators/distributed/grpc/grpc_client.cc b/paddle/fluid/operators/distributed/grpc/grpc_client.cc deleted file mode 100644 index 97a9c14e4f..0000000000 --- a/paddle/fluid/operators/distributed/grpc/grpc_client.cc +++ /dev/null @@ -1,671 +0,0 @@ -/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#include -#include - -#include "glog/logging.h" // For VLOG -#include "paddle/fluid/framework/threadpool.h" -#include "paddle/fluid/operators/distributed/grpc/grpc_client.h" -#include "paddle/fluid/operators/distributed/grpc/grpc_serde.h" -#include "paddle/fluid/operators/distributed/request_handler.h" -#include "paddle/fluid/platform/port.h" -#include "paddle/fluid/platform/profiler.h" - -DEFINE_int32(rpc_client_threads, 2, ""); -DECLARE_bool(rpc_disable_reuse_port); - -namespace paddle { -namespace operators { -namespace distributed { - -void GRPCClient::InitImpl() { - // start the client process thread - // TODO(wuyi): can make this in a threadpool - client_threads_.resize(FLAGS_rpc_client_threads); - for (int i = 0; i < FLAGS_rpc_client_threads; i++) { - client_threads_[i].reset( - new std::thread(std::bind(&GRPCClient::Proceed, this))); - } -} - -void GRPCClient::SendComplete() { - std::unique_lock lk(completed_mutex_); - if (!completed_) { - for (auto& it : channels_) { - VLOG(3) << "send complete message to " << it.first; - this->AsyncSendComplete(it.first); - } - PADDLE_ENFORCE_EQ(this->Wait(), true, platform::errors::PreconditionNotMet( - "internal grpc service error.")); - completed_ = true; - } -} - -GRPCClient::~GRPCClient() { - stopped_ = true; - Wait(); - cq_.Shutdown(); - { - std::lock_guard guard(chan_mutex_); - for (auto& it : channels_) { - it.second.reset(); - } - channels_.clear(); - } - for (size_t i = 0; i < client_threads_.size(); i++) - client_threads_[i]->join(); -} - -VarHandlePtr GRPCClient::AsyncSendVar(const std::string& ep, - const platform::DeviceContext& ctx, - const framework::Scope& scope, - const std::string& var_name, - int64_t time_out) { - const platform::DeviceContext* p_ctx = &ctx; - const std::string ep_val = ep; - const std::string var_name_val = var_name; - const framework::Scope* p_scope = &scope; - const auto ch = GetChannel(ep_val); - const std::string method = kSendRPC; - - int retry_times_ = 0; - - while (true) { - SendProcessor* s = new SendProcessor(ch); - VarHandlePtr h(new VarHandle(ep, method, var_name_val, p_ctx, p_scope)); - s->Prepare(h, time_out); - - framework::Async([var_name_val, p_scope, p_ctx, s, method, h, this] { - auto* var = p_scope->FindVar(var_name_val); - - ::grpc::ByteBuffer req; - SerializeToByteBuffer(var_name_val, var, *p_ctx, &req, "", trainer_id_); - - VLOG(3) << s->GetVarHandlePtr()->String() << " begin"; - - // stub context - s->response_call_back_ = nullptr; - - platform::RecordRPCEvent record_event(method); - - auto call = s->stub_g_.PrepareUnaryCall( - s->context_.get(), "/sendrecv.SendRecvService/SendVariable", req, - &cq_); - call->StartCall(); - call->Finish(&s->reply_, &s->status_, reinterpret_cast(s)); - - if (UNLIKELY(platform::IsProfileEnabled())) { - h->Wait(); - } - }); - req_count_++; - - if (FLAGS_rpc_retry_times > 0 && retry_times_ < FLAGS_rpc_retry_times) { - h->Wait(); - if (h->should_retry) { - VLOG(3) << "rpc call failed, retry times " << retry_times_; - retry_times_++; - std::random_device rd; - std::this_thread::sleep_for(std::chrono::milliseconds(rd() % 5)); - continue; - } - } - - return h; - } -} - -void ProcGetResponse(const VarHandle& var_h, - const ::grpc::ByteBuffer& ret_msg) { - VLOG(4) << "ProcGetResponse"; - framework::Variable* outvar = nullptr; - // get response's trainer_id is not used - int trainer_id; - DeserializeFromByteBuffer(ret_msg, *var_h.ctx(), var_h.scope(), &outvar, - &trainer_id); -} - -void ProcGetRecvResponse(const VarHandle& var_h, - const ::grpc::ByteBuffer& ret_msg) { - VLOG(4) << "ProcGetRecvResponse"; - framework::Variable* outvar = nullptr; - int trainer_id; - DeserializeRecvFromByteBuffer(ret_msg, *var_h.ctx(), var_h.scope(), &outvar, - &trainer_id); -} - -template -void RequestToByteBuffer(const T& proto, ::grpc::ByteBuffer* result) { - ::grpc::Slice slice(proto.ByteSizeLong()); - proto.SerializeWithCachedSizesToArray(const_cast(slice.begin())); - ::grpc::ByteBuffer tmp(&slice, 1); - result->Swap(&tmp); -} - -VarHandlePtr GRPCClient::AsyncGetVar(const std::string& ep, - const platform::DeviceContext& ctx, - const framework::Scope& scope, - const std::string& var_name, - const std::string& out_varname, - const std::string& table_name, - int64_t time_out) { - return _AsyncGetVar(ep, ctx, scope, kGetRPC, var_name, out_varname, - "/sendrecv.SendRecvService/GetVariable", table_name, - time_out); -} - -VarHandlePtr GRPCClient::AsyncGetVarNoBarrier( - const std::string& ep, const platform::DeviceContext& ctx, - const framework::Scope& scope, const std::string& var_name, - const std::string& out_varname, int64_t time_out) { - std::string var_name_no_barrier = - string::Sprintf("%s%s", var_name, WITHOUT_BARRIER_MESSAGE); - - return _AsyncGetVar( - ep, ctx, scope, kGetNoBarrierRPC, var_name_no_barrier, out_varname, - "/sendrecv.SendRecvService/GetVariableNoBarrier", "", time_out); -} - -VarHandlePtr GRPCClient::AsyncGetMonomerVariable( - const std::string& ep, const platform::DeviceContext& ctx, - const framework::Scope& scope, const std::string& var_name, - int64_t time_out) { - return _AsyncGetVar(ep, ctx, scope, kGetMonomerRPC, var_name, var_name, - "/sendrecv.SendRecvService/GetMonomerVariable", "", - time_out); -} - -VarHandlePtr GRPCClient::_AsyncGetVar( - const std::string& ep, const platform::DeviceContext& ctx, - const framework::Scope& scope, const std::string& method, - const std::string& var_name, const std::string& out_varname, - const std::string& rpc_path, const std::string& table_name, - int64_t time_out) { - const platform::DeviceContext* p_ctx = &ctx; - const std::string ep_val = ep; - const std::string var_name_val = var_name; - const std::string out_varname_val = out_varname; - const std::string table_name_val = table_name; - const framework::Scope* p_scope = &scope; - const auto ch = GetChannel(ep_val); - - int retry_times_ = 0; - - while (true) { - GetProcessor* s = new GetProcessor(ch); - - VarHandlePtr h(new VarHandle(ep, method, out_varname_val, p_ctx, p_scope)); - s->Prepare(h, time_out); - - framework::Async([var_name_val, out_varname_val, table_name_val, s, method, - p_ctx, h, rpc_path, this] { - // prepare input - sendrecv::VariableMessage req; - req.set_varname(var_name_val); - req.set_out_varname(out_varname_val); - req.set_trainer_id(trainer_id_); - req.set_table_name(table_name_val); - ::grpc::ByteBuffer buf; - RequestToByteBuffer(req, &buf); - - VLOG(3) << s->GetVarHandlePtr()->String() << " begin"; - - // stub context - s->response_call_back_ = ProcGetResponse; - - platform::RecordRPCEvent record_event(method); - - auto call = - s->stub_g_.PrepareUnaryCall(s->context_.get(), rpc_path, buf, &cq_); - call->StartCall(); - call->Finish(&s->reply_, &s->status_, reinterpret_cast(s)); - - if (UNLIKELY(platform::IsProfileEnabled())) { - h->Wait(); - } - }); - req_count_++; - - if (FLAGS_rpc_retry_times > 0 && retry_times_ < FLAGS_rpc_retry_times) { - h->Wait(); - if (h->should_retry) { - VLOG(3) << "rpc call failed, retry times " << retry_times_; - retry_times_++; - std::random_device rd; - std::this_thread::sleep_for(std::chrono::milliseconds(rd() % 5)); - continue; - } - } - - return h; - } -} - -VarHandlePtr GRPCClient::AsyncPrefetchVar(const std::string& ep, - const platform::DeviceContext& ctx, - const framework::Scope& scope, - const std::string& in_var_name, - const std::string& out_var_name, - const std::string& table_name, - int64_t time_out) { - const platform::DeviceContext* p_ctx = &ctx; - const std::string ep_val = ep; - const std::string in_var_name_val = in_var_name; - const std::string out_var_name_val = out_var_name; - const std::string table_name_val = table_name; - const framework::Scope* p_scope = &scope; - const auto ch = GetChannel(ep_val); - - const std::string method = kPrefetchRPC; - int retry_times_ = 0; - - while (true) { - GetProcessor* s = new GetProcessor(ch); - VarHandlePtr h(new VarHandle(ep, method, out_var_name_val, p_ctx, p_scope)); - s->Prepare(h, kPrefetchTimeout); - - auto* var = p_scope->FindVar(in_var_name_val); - - ::grpc::ByteBuffer req; - SerializeToByteBuffer(in_var_name_val, var, *p_ctx, &req, out_var_name_val, - 0, table_name_val); - - VLOG(3) << s->GetVarHandlePtr()->String() << " begin"; - - // stub context - s->response_call_back_ = ProcGetResponse; - - platform::RecordRPCEvent record_event(method); - - auto call = s->stub_g_.PrepareUnaryCall( - s->context_.get(), "/sendrecv.SendRecvService/PrefetchVariable", req, - &cq_); - call->StartCall(); - call->Finish(&s->reply_, &s->status_, static_cast(s)); - - if (UNLIKELY(platform::IsProfileEnabled())) { - h->Wait(); - } - - req_count_++; - - if (FLAGS_rpc_retry_times > 0 && retry_times_ < FLAGS_rpc_retry_times) { - h->Wait(); - if (h->should_retry) { - VLOG(3) << "rpc call failed, retry times " << retry_times_; - retry_times_++; - std::random_device rd; - std::this_thread::sleep_for(std::chrono::milliseconds(rd() % 5)); - continue; - } - } - - return h; - } -} - -VarHandlePtr GRPCClient::AsyncSendBatchBarrier(const std::string& ep, - int64_t time_out) { - const auto ch = GetChannel(ep); - - BatchBarrierProcessor* s = new BatchBarrierProcessor(ch); - const std::string method = kBatchBarrierRPC; - VarHandlePtr h( - new VarHandle(ep, method, BATCH_BARRIER_MESSAGE, nullptr, nullptr)); - s->Prepare(h, time_out); - - sendrecv::VariableMessage req; - req.set_varname(BATCH_BARRIER_MESSAGE); - - platform::RecordRPCEvent record_event(method); - - auto rpc = s->stub_->AsyncSendVariable(s->context_.get(), req, &cq_); - rpc->Finish(&s->reply_, &s->status_, reinterpret_cast(s)); - req_count_++; - - if (UNLIKELY(platform::IsProfileEnabled())) { - h->Wait(); - } - - return h; -} - -VarHandlePtr GRPCClient::AsyncSendFetchBarrier(const std::string& ep, - int64_t time_out) { - const auto ch = GetChannel(ep); - FetchBarrierProcessor* s = new FetchBarrierProcessor(ch); - const std::string method = kFetchBarrierRPC; - VarHandlePtr h( - new VarHandle(ep, method, FETCH_BARRIER_MESSAGE, nullptr, nullptr)); - s->Prepare(h, time_out); - - sendrecv::VariableMessage req; - req.set_varname(FETCH_BARRIER_MESSAGE); - - platform::RecordRPCEvent record_event(method); - - auto rpc = s->stub_->AsyncGetVariable(s->context_.get(), req, &cq_); - rpc->Finish(&s->reply_, &s->status_, reinterpret_cast(s)); - req_count_++; - - if (UNLIKELY(platform::IsProfileEnabled())) { - h->Wait(); - } - - return h; -} - -VarHandlePtr GRPCClient::AsyncGetMonomerBarrier(const std::string& ep, - const std::string& var_name, - int64_t time_out) { - const auto ch = GetChannel(ep); - BatchBarrierProcessor* s = new BatchBarrierProcessor(ch); - const std::string method = kSendMonomerFetchBarrierRPC; - VarHandlePtr h(new VarHandle(ep, method, var_name, nullptr, nullptr)); - s->Prepare(h, time_out); - - VLOG(30) << s->GetVarHandlePtr()->String() << " begin"; - - sendrecv::VariableMessage req; - req.set_varname(var_name); - - platform::RecordRPCEvent record_event(method); - - auto rpc = s->stub_->AsyncGetMonomerBarrier(s->context_.get(), req, &cq_); - rpc->Finish(&s->reply_, &s->status_, reinterpret_cast(s)); - req_count_++; - - if (UNLIKELY(platform::IsProfileEnabled())) { - h->Wait(); - } - - return h; -} - -VarHandlePtr GRPCClient::AsyncSendComplete(const std::string& ep, - int64_t time_out) { - const auto ch = GetChannel(ep); - - BatchBarrierProcessor* s = new BatchBarrierProcessor(ch); - const std::string method = kSendCompleteRPC; - VarHandlePtr h(new VarHandle(ep, method, COMPLETE_MESSAGE, nullptr, nullptr)); - s->Prepare(h, time_out); - - sendrecv::VariableMessage req; - req.set_trainer_id(trainer_id_); - req.set_varname(COMPLETE_MESSAGE); - - platform::RecordRPCEvent record_event(method); - - auto rpc = s->stub_->AsyncSendVariable(s->context_.get(), req, &cq_); - rpc->Finish(&s->reply_, &s->status_, reinterpret_cast(s)); - req_count_++; - - if (UNLIKELY(platform::IsProfileEnabled())) { - h->Wait(); - } - - return h; -} - -VarHandlePtr GRPCClient::AsyncCheckpointNotify(const std::string& ep, - const std::string& dirname, - const std::string& varname, - const int mode, - int64_t time_out) { - const auto ch = GetChannel(ep); - - CheckpointNotifyProcessor* s = new CheckpointNotifyProcessor(ch); - - const std::string method = kCheckPointNotifyRPC; - - VarHandlePtr h( - new VarHandle(ep, method, CHECKPOINT_SAVE_MESSAGE, nullptr, nullptr)); - s->Prepare(h, time_out); - - sendrecv::VariableMessage req; - req.set_varname(varname); - req.set_table_name(std::to_string(mode)); - req.set_out_varname(dirname); - - platform::RecordRPCEvent record_event(method); - - auto rpc = s->stub_->AsyncCheckpointNotify(s->context_.get(), req, &cq_); - rpc->Finish(&s->reply_, &s->status_, reinterpret_cast(s)); - req_count_++; - - if (UNLIKELY(platform::IsProfileEnabled())) { - h->Wait(); - } - - return h; -} - -VarHandlePtr GRPCClient::AsyncDistributeNotify( - const std::string& ep, const platform::DeviceContext& ctx, - const framework::Scope& scope, const std::string& var_name, - int64_t time_out) { - const platform::DeviceContext* p_ctx = &ctx; - const std::string ep_val = ep; - const std::string var_name_val = var_name; - const framework::Scope* p_scope = &scope; - const auto ch = GetChannel(ep_val); - const std::string method = kRequestNotify; - - SendProcessor* s = new SendProcessor(ch); - VarHandlePtr h(new VarHandle(ep, method, var_name_val, p_ctx, p_scope)); - s->Prepare(h, time_out); - - framework::Async([var_name_val, p_scope, p_ctx, s, method, h, this] { - auto* var = p_scope->FindVar(var_name_val); - - ::grpc::ByteBuffer req; - SerializeToByteBuffer(var_name_val, var, *p_ctx, &req, "", trainer_id_); - - VLOG(3) << s->GetVarHandlePtr()->String() << " begin"; - - // stub context - s->response_call_back_ = nullptr; - - platform::RecordRPCEvent record_event(method); - - auto call = s->stub_g_.PrepareUnaryCall( - s->context_.get(), "/sendrecv.SendRecvService/DistributeNotify", req, - &cq_); - call->StartCall(); - call->Finish(&s->reply_, &s->status_, reinterpret_cast(s)); - }); - req_count_++; - - if (UNLIKELY(platform::IsProfileEnabled())) { - h->Wait(); - } - - return h; -} - -VarHandlePtr GRPCClient::AsyncSendAndRecv(const std::string& ep, - const platform::DeviceContext& ctx, - const framework::Scope& scope, - const std::string& send_var_name, - const std::string& recv_var_name, - const std::string& table_name, - int64_t time_out) { - const platform::DeviceContext* p_ctx = &ctx; - const std::string ep_val = ep; - const std::string send_var_name_val = send_var_name; - const std::string recv_var_name_val = recv_var_name; - const std::string table_name_val = table_name; - const framework::Scope* p_scope = &scope; - const auto ch = GetChannel(ep_val); - const std::string method = kSendAndRecvRPC; - VLOG(4) << "GRPCClient::SendAndRecv Begin ,Send_var_name: " - << send_var_name_val << " Recv_var_name: " << recv_var_name_val; - int retry_times_ = 0; - - while (true) { - SendAndRecvProcessor* s = new SendAndRecvProcessor(ch); - VarHandlePtr h( - new VarHandle(ep, method, send_var_name_val, p_ctx, p_scope)); - VarHandlePtr h_recv( - new VarHandle(ep, method, recv_var_name_val, p_ctx, p_scope)); - s->Prepare(h, time_out); - s->RecvPrepare(h_recv); - - framework::Async([send_var_name_val, recv_var_name_val, table_name_val, - p_scope, p_ctx, s, method, h, this] { - auto* send_var = p_scope->FindVar(send_var_name_val); - send_var->GetMutable()->set_lod({}); - ::grpc::ByteBuffer buf; - VLOG(4) << "SerializeToByteBuffer: send_var_name_val: " - << send_var_name_val - << " recv_var_name_val: " << recv_var_name_val; - SerializeToByteBuffer(send_var_name_val, send_var, *p_ctx, &buf, - recv_var_name_val, trainer_id_, table_name_val); - - VLOG(3) << s->GetVarHandlePtr()->String() << " begin"; - - // stub context - s->response_call_back_ = ProcGetRecvResponse; - - platform::RecordRPCEvent record_event(method); - - auto call = s->stub_g_.PrepareUnaryCall( - s->context_.get(), "/sendrecv.SendRecvService/SendAndRecvVariable", - buf, &cq_); - call->StartCall(); - call->Finish(&s->reply_, &s->status_, reinterpret_cast(s)); - - if (UNLIKELY(platform::IsProfileEnabled())) { - h->Wait(); - } - }); - req_count_++; - - if (FLAGS_rpc_retry_times > 0 && retry_times_ < FLAGS_rpc_retry_times) { - h->Wait(); - if (h->should_retry) { - VLOG(3) << "rpc call failed, retry times " << retry_times_; - retry_times_++; - std::random_device rd; - std::this_thread::sleep_for(std::chrono::milliseconds(rd() % 5)); - continue; - } - } - - return h; - } -} - -bool GRPCClient::Wait() { - std::unique_lock lk(sync_mutex_); - sync_cond_.wait(lk, [this] { return (req_count_ == 0 || ok_ == false); }); - return ok_; -} - -inline bool ShouldRetry(const std::string& method, int error_code) { - if (method == kPrefetchRPC) { - return true; - } - - if (error_code == grpc::StatusCode::DEADLINE_EXCEEDED) { - return true; - } - - return false; -} - -void GRPCClient::Proceed() { - void* tag = nullptr; - bool ok = false; - - VLOG(3) << "GRPCClient Proceed begin"; - while (!stopped_ && cq_.Next(&tag, &ok)) { - BaseProcessor* c = static_cast(tag); - GPR_ASSERT(ok); - PADDLE_ENFORCE_NOT_NULL( - c, platform::errors::PreconditionNotMet("Make BaseProcessor failed.")); - - if (c->status_.ok()) { - VLOG(3) << c->GetVarHandlePtr()->String() << " process"; - c->Process(); - } else if (ShouldRetry(c->GetVarHandlePtr()->method(), - c->status_.error_code())) { - VLOG(0) << c->GetVarHandlePtr()->String() - << " meets grpc error, error_code:" << c->status_.error_code() - << " error_message:" << c->status_.error_message() - << " error_details:" << c->status_.error_details() - << " should retry!"; - c->GetVarHandlePtr()->should_retry = true; - c->Finish(false); - } else { - PADDLE_THROW(platform::errors::External( - "%s meets grpc error, error_code is %d, error message is %s, error " - "details is %s.", - c->GetVarHandlePtr()->String(), c->status_.error_code(), - c->status_.error_message(), c->status_.error_details())); - c->Finish(false); - } - - bool notify = false; - { - std::lock_guard lk(sync_mutex_); - req_count_--; - notify = (req_count_ <= 0 || !c->status_.ok()); - } - - delete c; - - if (notify) { - sync_cond_.notify_all(); - } - } - - // Last log message - // Avoid using VLOG() and LOG(): in the destructor of google::LogMessage() a - // static Mutex log_mutex is used for synchronization, which might have been - // destructed at this moment. - if (FLAGS_v >= 3) { - std::string msg("GRPCClient Proceed end"); - fwrite(msg.c_str(), msg.length(), 1, stderr); - } -} - -std::shared_ptr GRPCClient::GetChannel(const std::string& ep) { - std::lock_guard guard(chan_mutex_); - auto it = channels_.find(ep); - if (it != channels_.end()) { - return it->second; - } - - // Channel configurations: - grpc::ChannelArguments args; - args.SetInt(GRPC_ARG_MAX_RECONNECT_BACKOFF_MS, 2000); - if (FLAGS_rpc_disable_reuse_port) { - args.SetInt(GRPC_ARG_ALLOW_REUSEPORT, 0); - } - args.SetCompressionAlgorithm(GRPC_COMPRESS_NONE); - args.SetMaxSendMessageSize(std::numeric_limits::max()); - args.SetMaxReceiveMessageSize(std::numeric_limits::max()); - - auto ch = - grpc::CreateCustomChannel(ep, grpc::InsecureChannelCredentials(), args); - channels_[ep] = ch; - return ch; -} - -} // namespace distributed -} // namespace operators -} // namespace paddle diff --git a/paddle/fluid/operators/distributed/grpc/grpc_client.h b/paddle/fluid/operators/distributed/grpc/grpc_client.h deleted file mode 100644 index 5885f944b6..0000000000 --- a/paddle/fluid/operators/distributed/grpc/grpc_client.h +++ /dev/null @@ -1,321 +0,0 @@ -/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#pragma once - -#include -#include -#include // NOLINT -#include // NOLINT -#include -#include -#include -#include -#include -#include // NOLINT -#include -#include // NOLINT -#include -#include - -#include "grpc++/channel.h" -#include "grpc++/generic/generic_stub.h" -#include "grpc++/grpc++.h" -#include "grpc++/support/byte_buffer.h" -#include "grpc++/support/slice.h" -#include "grpc/support/log.h" -#include "paddle/fluid/framework/blocking_queue.h" -#include "paddle/fluid/framework/data_type.h" -#include "paddle/fluid/framework/lod_tensor.h" -#include "paddle/fluid/framework/scope.h" -#include "paddle/fluid/framework/selected_rows.h" -#include "paddle/fluid/operators/distributed/distributed_pb.h" -#include "paddle/fluid/operators/distributed/request_handler.h" -#include "paddle/fluid/operators/distributed/rpc_client.h" -#include "paddle/fluid/operators/distributed/sendrecvop_utils.h" -#include "paddle/fluid/platform/macros.h" // for DISABLE_COPY_AND_ASSIGN - -namespace grpc { -class Channel; -} // namespace grpc -namespace paddle { -namespace framework { -class Scope; -} // namespace framework -namespace platform { -class DeviceContext; -} // namespace platform -} // namespace paddle - -namespace paddle { -namespace operators { -namespace distributed { - -void ProcGetResponse(const VarHandle& var_h, const grpc::ByteBuffer& msg); - -void ProcGetRecvResponse(const VarHandle& var_h, const grpc::ByteBuffer& msg); - -class BaseProcessor { - public: - BaseProcessor() { context_ = nullptr; } - - virtual ~BaseProcessor() {} - - virtual void Prepare(VarHandlePtr h, int64_t time_out) { - var_h_ = h; - - context_.reset(new grpc::ClientContext()); - context_->set_wait_for_ready(true); - if (time_out) { - std::chrono::system_clock::time_point deadline = - std::chrono::system_clock::now() + - std::chrono::milliseconds(time_out); - context_->set_deadline(deadline); - } - } - - void Process() { - ProcessImpl(); - var_h_->Finish(true); - } - - VarHandlePtr GetVarHandlePtr() { return var_h_; } - bool Wait() { return var_h_->Wait(); } - void Finish(bool ok) { return var_h_->Finish(ok); } - virtual void ProcessImpl() = 0; - - std::unique_ptr context_; - grpc::Status status_; - - protected: - VarHandlePtr var_h_; -}; - -typedef std::function - RequestSendCallBack; - -class SendProcessor : public BaseProcessor { - public: - explicit SendProcessor(std::shared_ptr ch) - : BaseProcessor(), stub_g_(ch) {} - - virtual ~SendProcessor() {} - - void ProcessImpl() override { - if (response_call_back_) { - response_call_back_(*var_h_.get(), reply_); - } - } - - ::grpc::GenericStub stub_g_; - ::grpc::ByteBuffer reply_; - RequestSendCallBack response_call_back_ = nullptr; -}; - -typedef std::function - RequestGetCallBack; - -class GetProcessor : public BaseProcessor { - public: - explicit GetProcessor(std::shared_ptr ch) - : BaseProcessor(), stub_g_(ch) {} - - virtual ~GetProcessor() {} - - void ProcessImpl() override { - if (response_call_back_) { - response_call_back_(*var_h_.get(), reply_); - } - } - - ::grpc::ByteBuffer reply_; - ::grpc::GenericStub stub_g_; - RequestGetCallBack response_call_back_ = ProcGetResponse; -}; - -class SendAndRecvProcessor : public BaseProcessor { - public: - explicit SendAndRecvProcessor(std::shared_ptr ch) - : BaseProcessor(), stub_g_(ch) {} - - virtual ~SendAndRecvProcessor() {} - - void ProcessImpl() override { - if (response_call_back_) { - response_call_back_(*var_h_recv_.get(), reply_); - var_h_recv_->Finish(true); - } - } - - void RecvPrepare(VarHandlePtr h_recv) { var_h_recv_ = h_recv; } - - ::grpc::ByteBuffer reply_; - ::grpc::GenericStub stub_g_; - RequestGetCallBack response_call_back_ = ProcGetResponse; - VarHandlePtr var_h_recv_; -}; - -class BatchBarrierProcessor : public BaseProcessor { - public: - explicit BatchBarrierProcessor(std::shared_ptr ch) - : BaseProcessor() { - stub_ = sendrecv::SendRecvService::NewStub(ch); - } - - virtual ~BatchBarrierProcessor() {} - - void ProcessImpl() override {} - sendrecv::VoidMessage reply_; - std::unique_ptr stub_; -}; - -class FetchBarrierProcessor : public BaseProcessor { - public: - explicit FetchBarrierProcessor(std::shared_ptr ch) - : BaseProcessor() { - stub_ = sendrecv::SendRecvService::NewStub(ch); - } - - virtual ~FetchBarrierProcessor() {} - - void ProcessImpl() override {} - sendrecv::VariableMessage reply_; - std::unique_ptr stub_; -}; - -class CheckpointNotifyProcessor : public BaseProcessor { - public: - explicit CheckpointNotifyProcessor(std::shared_ptr ch) - : BaseProcessor() { - stub_ = sendrecv::SendRecvService::NewStub(ch); - } - - virtual ~CheckpointNotifyProcessor() {} - - void ProcessImpl() override {} - sendrecv::VoidMessage reply_; - std::unique_ptr stub_; -}; - -class GRPCClient : public RPCClient { - public: - GRPCClient() : ok_(true), completed_(false), stopped_(false) {} - virtual ~GRPCClient(); - - VarHandlePtr AsyncSendVar(const std::string& ep, - const platform::DeviceContext& ctx, - const framework::Scope& scope, - const std::string& var_name, - int64_t time_out = FLAGS_rpc_deadline) override; - - VarHandlePtr AsyncGetVar(const std::string& ep, - const platform::DeviceContext& ctx, - const framework::Scope& scope, - const std::string& var_name, - const std::string& out_varname, - const std::string& table_name = "", - int64_t time_out = FLAGS_rpc_deadline) override; - - VarHandlePtr AsyncGetVarNoBarrier( - const std::string& ep, const platform::DeviceContext& ctx, - const framework::Scope& scope, const std::string& var_name, - const std::string& out_varname, - int64_t time_out = FLAGS_rpc_deadline) override; - - VarHandlePtr AsyncGetMonomerVariable( - const std::string& ep, const platform::DeviceContext& ctx, - const framework::Scope& scope, const std::string& var_name, - int64_t time_out = FLAGS_rpc_deadline) override; - - VarHandlePtr AsyncPrefetchVar(const std::string& ep, - const platform::DeviceContext& ctx, - const framework::Scope& scope, - const std::string& in_var_name, - const std::string& out_var_name, - const std::string& table_name = "", - int64_t time_out = FLAGS_rpc_deadline) override; - - VarHandlePtr AsyncSendBatchBarrier( - const std::string& ep, int64_t time_out = FLAGS_rpc_deadline) override; - - VarHandlePtr AsyncSendFetchBarrier(const std::string& ep, - int64_t time_out) override; - - VarHandlePtr AsyncGetMonomerBarrier( - const std::string& ep, const std::string& var_name, - int64_t time_out = FLAGS_rpc_deadline) override; - - VarHandlePtr AsyncCheckpointNotify( - const std::string& ep, const std::string& dirname, - const std::string& varname, const int mode, - int64_t time_out = FLAGS_rpc_deadline) override; - - VarHandlePtr AsyncDistributeNotify( - const std::string& ep, const platform::DeviceContext& ctx, - const framework::Scope& scope, const std::string& var_name, - int64_t time_out = FLAGS_rpc_deadline) override; - - VarHandlePtr AsyncSendAndRecv(const std::string& ep, - const platform::DeviceContext& ctx, - const framework::Scope& scope, - const std::string& send_var_name, - const std::string& recv_var_name, - const std::string& table_name = "", - int64_t time_out = FLAGS_rpc_deadline) override; - - VarHandlePtr AsyncSendComplete( - const std::string& ep, int64_t time_out = FLAGS_rpc_deadline) override; - - bool Wait() override; - - void SendComplete() override; - - void InitImpl() override; - - private: - void Proceed(); - - std::shared_ptr GetChannel(const std::string& ep); - VarHandlePtr _AsyncGetVar( - const std::string& ep, const platform::DeviceContext& ctx, - const framework::Scope& scope, const std::string& method, - const std::string& var_name, const std::string& out_varname, - const std::string& rpc_path, const std::string& table_name = "", - int64_t time_out = FLAGS_rpc_deadline); - - private: - grpc::CompletionQueue cq_; - std::unordered_map> channels_; - std::vector> client_threads_; - - // mutex for Wait client sync - std::mutex sync_mutex_; - std::condition_variable sync_cond_; - std::atomic req_count_{0}; - bool ok_; - - // mutex for GetChannel thread safety - std::mutex chan_mutex_; - DISABLE_COPY_AND_ASSIGN(GRPCClient); - - // mutex for sending complete message only once - std::mutex completed_mutex_; - bool completed_; - - volatile bool stopped_; -}; - -} // namespace distributed -} // namespace operators -} // namespace paddle diff --git a/paddle/fluid/operators/distributed/grpc/grpc_serde.cc b/paddle/fluid/operators/distributed/grpc/grpc_serde.cc deleted file mode 100644 index 0fc9b69577..0000000000 --- a/paddle/fluid/operators/distributed/grpc/grpc_serde.cc +++ /dev/null @@ -1,190 +0,0 @@ -/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#ifdef PADDLE_WITH_NCCL -#include -#endif -#ifdef PADDLE_WITH_RCCL -#include -#endif -#include -#include -#include "grpcpp/impl/codegen/byte_buffer.h" -#include "grpcpp/impl/codegen/slice.h" -#include "paddle/fluid/operators/distributed/grpc/grpc_serde.h" -#include "paddle/fluid/operators/distributed/grpc/grpc_variable_response.h" -#include "paddle/fluid/operators/distributed/proto_encoder_helper.h" -#include "paddle/fluid/operators/distributed/send_recv.pb.h" -#include "paddle/fluid/operators/distributed/sendrecvop_utils.h" -#include "paddle/fluid/platform/enforce.h" -#include "paddle/fluid/platform/profiler.h" - -namespace paddle { -namespace framework { -class Scope; -class Variable; -} // namespace framework -namespace platform { -class DeviceContext; -} // namespace platform -} // namespace paddle - -namespace paddle { -namespace operators { -namespace distributed { - -void SerializeToByteBuffer(const std::string& name, framework::Variable* var, - const platform::DeviceContext& ctx, - ::grpc::ByteBuffer* msg, const std::string& out_name, - const int trainer_id, - const std::string& table_name) { - platform::RecordRPCEvent record_event("serial"); - VarMsg request; - TensorPayload* payload = nullptr; - - request.set_varname(name); - request.set_trainer_id(trainer_id); - // Note: normally the profiler is enabled in 1 trainer, hence only - // 1 trainer returns true for ShouldSendProfileState(). It tells PS - // servers the trainer's profiling state so that PS can follow the - // trainer. - if (platform::ShouldSendProfileState()) { - if (platform::IsProfileEnabled()) { - request.set_profile(platform::kEnableProfiler); - } else { - request.set_profile(platform::kDisableProfiler); - } - } - if (!out_name.empty()) { - request.set_out_varname(out_name); - } - if (!table_name.empty()) { - request.set_table_name(table_name); - } - if (var->IsType()) { - request.set_type(::sendrecv::LOD_TENSOR); - payload = new TensorPayload(GetTensorPayload(var, ctx, &request)); - } else if (var->IsType()) { - request.set_type(::sendrecv::SELECTED_ROWS); - payload = new TensorPayload(GetSelectedRowsPayload(var, ctx, &request)); -#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL) - } else if (var->IsType()) { - request.set_type(::sendrecv::NCCL_ID); -#endif - } else { - PADDLE_THROW(platform::errors::InvalidArgument( - "Serialize does not support type: %s", typeid(var->Type()).name())); - } - std::string header; - request.AppendToString(&header); - auto buffer = std::unique_ptr(new char[1024]); - void* buf = buffer.get(); - ProtoEncodeHelper e(static_cast(buf), 1024); - e.WriteRawBytes(std::string(header.data(), header.size())); -// NCCLID is copied directly to the message, return bytebuffer -// with only one slice if serializing NCCLID. -#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL) - if (var->IsType()) { - e.WriteVarlengthBeginning(VarMsg::kSerializedFieldNumber, - NCCL_UNIQUE_ID_BYTES); - const ncclUniqueId& uid = var->Get(); - e.WriteRawBytes(std::string(uid.internal, NCCL_UNIQUE_ID_BYTES)); - - // for serialize NCCL_ID - ::grpc::Slice slices(e.size()); - memcpy(const_cast(slices.begin()), e.data(), e.size()); - ::grpc::ByteBuffer tmp(&slices, 1); - msg->Swap(&tmp); - return; - } -#endif - PADDLE_ENFORCE_NOT_NULL( - payload, - platform::errors::InvalidArgument( - "Not support type: %s, need to be LOD_TENSOR or SELECTED_ROWS", - var->Type())); - e.WriteVarlengthBeginning(VarMsg::kSerializedFieldNumber, - payload->memory_size()); - if (payload->memory_size() >= std::numeric_limits::max()) { - PADDLE_THROW(platform::errors::InvalidArgument( - "Variable %s length %d should less than %d.", name, - payload->memory_size(), std::numeric_limits::max())); - } - // steal reference of tensor data - ::grpc::Slice slices[4]; // metadata, tensor, rows meta, rows - int num_slices = 2; // only SelectedRows have rows buffer - slices[0] = ::grpc::Slice(e.size()); - memcpy(const_cast(slices[0].begin()), e.data(), e.size()); - slices[1] = ::grpc::Slice( - grpc_slice_new_with_user_data(payload->ptr(), payload->memory_size(), - SerializeDestroyCallback, payload), - ::grpc::Slice::STEAL_REF); - - if (var->IsType()) { - auto* slr = var->GetMutable(); - ProtoEncodeHelper e2(static_cast(buf), 128); - - PADDLE_ENFORCE_EQ(VectorElemName(slr->rows()), typeid(int64_t).name(), - platform::errors::InvalidArgument( - "Got wrong type %s, expect type: int64_t", - VectorElemName(slr->rows()))); - size_t rows_memory_size = slr->rows().size() * sizeof(int64_t); - - e2.WriteVarlengthBeginning(VarMsg::kRowsFieldNumber, rows_memory_size); - slices[2] = ::grpc::Slice(e2.size()); - memcpy(const_cast(slices[2].begin()), e2.data(), e2.size()); - - slices[3] = ::grpc::Slice( - grpc_slice_new_with_user_data( - const_cast( - reinterpret_cast(slr->rows().data())), - rows_memory_size, [](void* backing) {}, - const_cast( - reinterpret_cast(slr->rows().data()))), - ::grpc::Slice::STEAL_REF); - num_slices = 4; - } - ::grpc::ByteBuffer tmp(&slices[0], num_slices); - msg->Swap(&tmp); -} - -void DeserializeFromByteBuffer(const ::grpc::ByteBuffer& msg, - const platform::DeviceContext& ctx, - const framework::Scope* scope, - framework::Variable** var, int* trainer_id) { - platform::RecordRPCEvent record_event("deserial"); - operators::distributed::GRPCVariableResponse resp(scope, &ctx); - PADDLE_ENFORCE_EQ( - resp.Parse(msg), 0, - platform::errors::InvalidArgument("parse bytebuffer to tensor error!")); - *var = resp.GetVar(); - *trainer_id = resp.GetTrainerId(); -} - -void DeserializeRecvFromByteBuffer(const ::grpc::ByteBuffer& msg, - const platform::DeviceContext& ctx, - const framework::Scope* scope, - framework::Variable** var, int* trainer_id) { - platform::RecordRPCEvent record_event("deserial"); - operators::distributed::GRPCVariableResponse resp(scope, &ctx); - PADDLE_ENFORCE_EQ( - resp.Parse(msg), 0, - platform::errors::InvalidArgument("parse bytebuffer to tensor error!")); - *var = resp.GetRecvVar(); - *trainer_id = resp.GetTrainerId(); -} - -} // namespace distributed -} // namespace operators -} // namespace paddle diff --git a/paddle/fluid/operators/distributed/grpc/grpc_serde.h b/paddle/fluid/operators/distributed/grpc/grpc_serde.h deleted file mode 100644 index 932f3e2f06..0000000000 --- a/paddle/fluid/operators/distributed/grpc/grpc_serde.h +++ /dev/null @@ -1,69 +0,0 @@ -/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#pragma once - -#include -#include -#include - -#include "paddle/fluid/framework/data_type.h" -#include "paddle/fluid/framework/lod_tensor.h" -#include "paddle/fluid/framework/scope.h" -#include "paddle/fluid/framework/selected_rows.h" -#include "paddle/fluid/framework/tensor_util.h" -#include "paddle/fluid/framework/var_type.h" -#include "paddle/fluid/operators/distributed/distributed_pb.h" -#include "paddle/fluid/operators/distributed/sendrecvop_utils.h" -#include "paddle/fluid/platform/port.h" - -namespace grpc { -class ByteBuffer; -} // namespace grpc -namespace paddle { -namespace framework { -class Scope; -class Variable; -} // namespace framework -namespace platform { -class DeviceContext; -} // namespace platform -} // namespace paddle - -namespace paddle { -namespace operators { -namespace distributed { - -typedef void (*DestroyCallback)(void*); - -void SerializeToByteBuffer(const std::string& name, framework::Variable* var, - const platform::DeviceContext& ctx, - ::grpc::ByteBuffer* msg, - const std::string& out_varname = std::string(), - const int trainer_id = 0, - const std::string& table_name = std::string()); - -void DeserializeFromByteBuffer(const ::grpc::ByteBuffer& msg, - const platform::DeviceContext& ctx, - const framework::Scope* scope, - framework::Variable** var, int* trainer_id); - -void DeserializeRecvFromByteBuffer(const ::grpc::ByteBuffer& msg, - const platform::DeviceContext& ctx, - const framework::Scope* scope, - framework::Variable** var, int* trainer_id); - -} // namespace distributed -} // namespace operators -} // namespace paddle diff --git a/paddle/fluid/operators/distributed/grpc/grpc_serde_test.cc b/paddle/fluid/operators/distributed/grpc/grpc_serde_test.cc deleted file mode 100644 index d407a72938..0000000000 --- a/paddle/fluid/operators/distributed/grpc/grpc_serde_test.cc +++ /dev/null @@ -1,224 +0,0 @@ -/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#include -#include -#include // NOLINT - -#include "google/protobuf/text_format.h" -#include "gtest/gtest.h" -#include "paddle/fluid/framework/lod_tensor.h" -#include "paddle/fluid/framework/tensor_util.h" -#include "paddle/fluid/framework/variable.h" -#include "paddle/fluid/operators/distributed/distributed.h" -#include "paddle/fluid/operators/distributed/grpc/grpc_serde.h" -#include "paddle/fluid/operators/distributed/grpc/grpc_variable_response.h" -#include "paddle/fluid/operators/distributed/sendrecvop_utils.h" -#include "paddle/fluid/operators/math/math_function.h" -#include "paddle/fluid/platform/place.h" -#include "paddle/fluid/string/printf.h" - -namespace framework = paddle::framework; -namespace platform = paddle::platform; -namespace operators = paddle::operators; -namespace math = paddle::operators::math; -namespace memory = paddle::memory; - -void RunSerdeTestSelectedRows(platform::Place place) { - platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance(); - auto& ctx = *pool.Get(place); - - // serialize var to ByteBuffer - framework::Variable var; - auto* slr = var.GetMutable(); - slr->set_height(1000); - auto* tensor = slr->mutable_value(); - auto* rows = slr->mutable_rows(); - tensor->Resize(framework::make_ddim({564, 128})); - tensor->mutable_data(place); - int tensor_numel = 564 * 128; - math::set_constant(ctx, tensor, 32.7); - for (int i = 0; i < 564; ++i) rows->push_back(i); - - ::grpc::ByteBuffer msg; - operators::distributed::SerializeToByteBuffer("myvar", &var, ctx, &msg); - EXPECT_GT(msg.Length(), static_cast(0)); - - // deserialize - std::vector<::grpc::Slice> slices; - (void)msg.Dump(&slices); - std::string tmp; - for (const auto& s : slices) { - tmp.append(reinterpret_cast(s.begin()), s.size()); - } - - sendrecv::VariableMessage varmsg; - EXPECT_TRUE(varmsg.ParseFromString(tmp)); - - // deserialize bytebuffer - EXPECT_EQ(varmsg.varname(), "myvar"); - EXPECT_EQ(varmsg.type(), 1); - - const float* tensor_data = - reinterpret_cast(varmsg.serialized().data()); - const int64_t* rows_data = - reinterpret_cast(varmsg.rows().data()); - for (int i = 0; i < tensor_numel; ++i) { - EXPECT_FLOAT_EQ(tensor_data[i], 32.7); - } - for (int i = 0; i < 564; ++i) { - EXPECT_EQ(rows_data[i], i); - } - - // deserialize zero-copy - // framework::Variable var2; - // operators::distributed::DeserializeFromByteBuffer(msg, ctx, &var2); - framework::Scope scope; - scope.Var("myvar"); - operators::distributed::GRPCVariableResponse resp(&scope, &ctx); - EXPECT_EQ(resp.Parse(msg), 0); - - framework::Variable* var2 = resp.GetVar(); - - auto* slr2 = var2->GetMutable(); - auto* tensor2 = slr2->mutable_value(); - auto* rows2 = slr2->mutable_rows(); - float* tensor_data2 = nullptr; - framework::Tensor tmp_tensor; - - if (platform::is_gpu_place(ctx.GetPlace())) { - platform::CPUPlace cpu; - framework::TensorCopy(*tensor2, cpu, &tmp_tensor); - tensor_data2 = tmp_tensor.data(); - } else { - tensor_data2 = const_cast(tensor2->data()); - } - const int64_t* rows_data2 = rows2->data(); - - for (int i = 0; i < tensor_numel; ++i) { - EXPECT_FLOAT_EQ(tensor_data2[i], 32.7); - } - for (size_t i = 0; i < rows2->size(); ++i) { - EXPECT_EQ(rows_data2[i], static_cast(i)); - } - EXPECT_EQ(slr2->height(), 1000); -} - -void RunTestLodTensor(platform::Place place, int from_type = 0) { - // serialize var to ByteBuffer - framework::Variable var; - auto* tensor = var.GetMutable(); - tensor->Resize(framework::make_ddim({512, 8, 4, 2})); - framework::LoD lod; - lod.push_back(framework::Vector({1, 3, 8})); - tensor->set_lod(lod); - int tensor_numel = 512 * 8 * 4 * 2; - platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance(); - auto& ctx = *pool.Get(place); - tensor->mutable_data(place); - math::set_constant(ctx, tensor, 31.9); - - ::grpc::ByteBuffer msg; - operators::distributed::SerializeToByteBuffer("myvar", &var, ctx, &msg, - "outvar", 0, "table_name"); - EXPECT_GT(msg.Length(), static_cast(0)); - - // deserialize - std::vector<::grpc::Slice> slices; - (void)msg.Dump(&slices); - std::string tmp; - for (const auto& s : slices) { - tmp.append(reinterpret_cast(s.begin()), s.size()); - } - sendrecv::VariableMessage varmsg; - EXPECT_TRUE(varmsg.ParseFromString(tmp)); - EXPECT_EQ(varmsg.varname(), "myvar"); - EXPECT_EQ(varmsg.type(), 0); - EXPECT_EQ(varmsg.dims()[0], 512); - EXPECT_EQ(varmsg.dims()[1], 8); - EXPECT_EQ(varmsg.dims()[2], 4); - EXPECT_EQ(varmsg.dims()[3], 2); - EXPECT_EQ(varmsg.lod_level(), 1); - EXPECT_EQ(varmsg.lod(0).lod_data(0), 1); - EXPECT_EQ(varmsg.lod(0).lod_data(1), 3); - EXPECT_EQ(varmsg.lod(0).lod_data(2), 8); - - const float* tensor_data = - reinterpret_cast(varmsg.serialized().data()); - for (int i = 0; i < tensor_numel; ++i) { - EXPECT_FLOAT_EQ(tensor_data[i], 31.9); - } - - // message binary - std::string str; - varmsg.SerializeToString(&str); - - // message bytebuffer - ::grpc::Slice slices_2[1]; - int num_slices = 1; - slices_2[0] = ::grpc::Slice(str.length()); - memcpy(const_cast(slices_2[0].begin()), str.c_str(), str.length()); - ::grpc::ByteBuffer bytebuffer2(&slices_2[0], num_slices); - - // deserialize zero-copy - framework::Scope scope; - scope.Var("myvar"); - operators::distributed::GRPCVariableResponse resp(&scope, &ctx); - if (from_type == 0) { - EXPECT_EQ(resp.Parse(msg), 0); - } else { - EXPECT_EQ(resp.Parse(bytebuffer2), 0); - } - - framework::Variable* var2 = resp.GetVar(); - - auto tensor2 = var2->Get(); - float* tensor_data2 = nullptr; - framework::Tensor tmp_tensor; - - if (platform::is_gpu_place(ctx.GetPlace())) { - platform::CPUPlace cpu; - framework::TensorCopy(tensor2, cpu, &tmp_tensor); - tensor_data2 = tmp_tensor.data(); - } else { - tensor_data2 = const_cast(tensor2.data()); - } - - EXPECT_EQ(varmsg.lod_level(), 1); - EXPECT_EQ(varmsg.lod(0).lod_data(0), 1); - EXPECT_EQ(varmsg.lod(0).lod_data(1), 3); - EXPECT_EQ(varmsg.lod(0).lod_data(2), 8); - for (int i = 0; i < tensor_numel; ++i) EXPECT_FLOAT_EQ(tensor_data2[i], 31.9); -} - -TEST(LodTensor, Run) { - platform::CPUPlace place; - RunTestLodTensor(place); - RunTestLodTensor(place, 1); -#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) - platform::CUDAPlace gpu(0); - RunTestLodTensor(gpu); - RunTestLodTensor(gpu, 1); -#endif -} - -TEST(SelectedRows, Run) { - platform::CPUPlace place; - RunSerdeTestSelectedRows(place); - -#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) - platform::CUDAPlace gpu; - RunSerdeTestSelectedRows(gpu); -#endif -} diff --git a/paddle/fluid/operators/distributed/grpc/grpc_server.cc b/paddle/fluid/operators/distributed/grpc/grpc_server.cc deleted file mode 100644 index 912520d782..0000000000 --- a/paddle/fluid/operators/distributed/grpc/grpc_server.cc +++ /dev/null @@ -1,720 +0,0 @@ -/*Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#include -#include -#include -#include - -#include "paddle/fluid/operators/distributed/grpc/grpc_serde.h" -#include "paddle/fluid/operators/distributed/grpc/grpc_server.h" - -namespace grpc { -class ChannelArguments; -} // namespace grpc -namespace paddle { -namespace framework { -class Variable; -} // namespace framework -namespace operators { -namespace distributed { -class GRPCVariableResponse; -} // namespace distributed -} // namespace operators -} // namespace paddle - -using ::grpc::ServerAsyncResponseWriter; - -DECLARE_bool(rpc_disable_reuse_port); -DECLARE_int32(rpc_retry_bind_port); - -namespace paddle { -namespace operators { -namespace distributed { - -enum CallStatus { PROCESS = 0, FINISH }; - -// reference: -// https://stackoverflow.com/questions/41732884/grpc-multiple-services-in-cpp-async-server -class RequestBase { - public: - explicit RequestBase(GrpcService::AsyncService* service, - ::grpc::ServerCompletionQueue* cq, - RequestHandler* request_handler, int req_id) - : service_(service), - cq_(cq), - status_(PROCESS), - request_handler_(request_handler), - req_id_(req_id) { - PADDLE_ENFORCE_NOT_NULL(cq_, platform::errors::InvalidArgument( - "ServerCompletionQueue cq are empty")); - } - virtual ~RequestBase() {} - virtual void Process() = 0; - - std::string Status2String(const std::string& method) { - std::string status = "Process"; - if (status_ == FINISH) { - status = "Finish"; - } - - std::ostringstream s; - s << method << " name:[" << GetReqName() << "]" - << ", ep:[" << ctx_.peer() << "]" - << " " << status << " using req_id:" << req_id_; - return s.str(); - } - - CallStatus Status() const { - std::lock_guard l(status_mu_); - return status_; - } - - template - void Finish(const T& reply, ServerAsyncResponseWriter* responder) { - std::lock_guard l(status_mu_); - status_ = FINISH; - responder->Finish(reply, ::grpc::Status::OK, - reinterpret_cast(static_cast(req_id_))); - } - virtual std::string GetReqName() = 0; - - protected: - mutable std::mutex status_mu_; - ::grpc::ServerContext ctx_; - GrpcService::AsyncService* service_; - ::grpc::ServerCompletionQueue* cq_; - CallStatus status_; - RequestHandler* request_handler_; - int req_id_; -}; - -class RequestSend final : public RequestBase { - public: - explicit RequestSend(GrpcService::AsyncService* service, - ::grpc::ServerCompletionQueue* cq, - RequestHandler* request_handler, int req_id) - : RequestBase(service, cq, request_handler, req_id), responder_(&ctx_) { - request_.reset(new GRPCVariableResponse(request_handler->scope(), - request_handler->dev_ctx(), true)); - int method_id = static_cast(distributed::GrpcMethod::kSendVariable); - service_->RequestAsyncUnary( - method_id, &ctx_, request_.get(), &responder_, cq_, cq_, - reinterpret_cast(static_cast(req_id))); - } - virtual ~RequestSend() {} - std::string GetReqName() override { return request_->Varname(); } - - void Process() override { - std::string varname = GetReqName(); - - auto scope = request_->GetMutableLocalScope(); - auto invar = request_->GetVar(); - int trainer_id = request_->GetTrainerId(); - - VLOG(4) << "RequestSend var_name:" << varname << " trainer: " << trainer_id; - - framework::Variable* outvar = nullptr; - request_handler_->Handle(varname, scope, invar, &outvar, trainer_id); - Finish(reply_, &responder_); - } - - protected: - sendrecv::VoidMessage reply_; - std::shared_ptr request_; - ServerAsyncResponseWriter responder_; -}; - -class RequestGet final : public RequestBase { - public: - explicit RequestGet(GrpcService::AsyncService* service, - ::grpc::ServerCompletionQueue* cq, - RequestHandler* request_handler, int req_id) - : RequestBase(service, cq, request_handler, req_id), responder_(&ctx_) { - auto method_id = static_cast(distributed::GrpcMethod::kGetVariable); - service_->RequestAsyncUnary( - method_id, &ctx_, &request_, &responder_, cq_, cq_, - reinterpret_cast(static_cast(req_id))); - } - - virtual ~RequestGet() {} - - std::string GetReqName() override { return request_.varname(); } - - void Process() override { - // proc request. - std::string varname = request_.varname(); - std::string out_varname = request_.out_varname(); - std::string table_name = request_.table_name(); - int trainer_id = request_.trainer_id(); - - VLOG(4) << "RequestGet " << out_varname << " from " << varname; - - auto scope = request_handler_->scope(); - framework::Variable* invar = nullptr; - framework::Variable* outvar = nullptr; - - tmp_scope_ = std::move(scope->NewTmpScope()); - request_handler_->Handle(varname, tmp_scope_.get(), invar, &outvar, - trainer_id, out_varname, table_name); - - VLOG(1) << "before SerializeToByteBuffer"; - if (outvar) { - SerializeToByteBuffer(out_varname, outvar, *request_handler_->dev_ctx(), - &reply_); - } - VLOG(1) << "after SerializeToByteBuffer"; - Finish(reply_, &responder_); - } - - protected: - sendrecv::VariableMessage request_; - ::grpc::ByteBuffer reply_; - std::unique_ptr tmp_scope_; - ServerAsyncResponseWriter<::grpc::ByteBuffer> responder_; -}; - -class RequestGetNoBarrier final : public RequestBase { - public: - explicit RequestGetNoBarrier(GrpcService::AsyncService* service, - ::grpc::ServerCompletionQueue* cq, - RequestHandler* request_handler, int req_id) - : RequestBase(service, cq, request_handler, req_id), responder_(&ctx_) { - auto method_id = - static_cast(distributed::GrpcMethod::kGetVariableNoBarrier); - service_->RequestAsyncUnary( - method_id, &ctx_, &request_, &responder_, cq_, cq_, - reinterpret_cast(static_cast(req_id))); - } - - virtual ~RequestGetNoBarrier() {} - - std::string GetReqName() override { return request_.varname(); } - - void Process() override { - // proc request. - std::string varname = request_.varname(); - std::string out_varname = request_.out_varname(); - int trainer_id = request_.trainer_id(); - - VLOG(4) << "RequestGetNoBarrier " << out_varname << " from " << varname; - - auto scope = request_handler_->scope(); - framework::Variable* invar = nullptr; - framework::Variable* outvar = nullptr; - - request_handler_->Handle(varname, scope, invar, &outvar, trainer_id, - out_varname); - - if (outvar) { - SerializeToByteBuffer(out_varname, outvar, *request_handler_->dev_ctx(), - &reply_); - } - Finish(reply_, &responder_); - } - - protected: - sendrecv::VariableMessage request_; - ::grpc::ByteBuffer reply_; - ServerAsyncResponseWriter<::grpc::ByteBuffer> responder_; -}; - -class RequestGetMonomerVariable final : public RequestBase { - public: - explicit RequestGetMonomerVariable(GrpcService::AsyncService* service, - ::grpc::ServerCompletionQueue* cq, - RequestHandler* request_handler, - int req_id, RPCServer* rpc_server) - : RequestBase(service, cq, request_handler, req_id), - responder_(&ctx_), - rpc_server_(rpc_server) { - auto method_id = - static_cast(distributed::GrpcMethod::kGetMonomerVariable); - service_->RequestAsyncUnary( - method_id, &ctx_, &request_, &responder_, cq_, cq_, - reinterpret_cast(static_cast(req_id))); - } - - virtual ~RequestGetMonomerVariable() {} - - std::string GetReqName() override { return request_.varname(); } - - void Process() override { - // proc request. - std::string varname = request_.varname(); - - rpc_server_->WaitVarCond(varname); - MonomerHandle h = rpc_server_->GetMonomer(varname); - - auto scope = h.scope_; - auto invar = scope->FindVar(varname); - framework::Variable* outvar = nullptr; - - request_handler_->Handle(varname, scope, invar, &outvar, - request_.trainer_id()); - - if (outvar) { - SerializeToByteBuffer(varname, outvar, *h.dev_ctx_, &reply_); - } - Finish(reply_, &responder_); - } - - protected: - sendrecv::VariableMessage request_; - ::grpc::ByteBuffer reply_; - ServerAsyncResponseWriter<::grpc::ByteBuffer> responder_; - RPCServer* rpc_server_{nullptr}; -}; - -class RequestGetMonomerBarrier final : public RequestBase { - public: - explicit RequestGetMonomerBarrier(GrpcService::AsyncService* service, - ::grpc::ServerCompletionQueue* cq, - RequestHandler* request_handler, int req_id, - RPCServer* rpc_server) - : RequestBase(service, cq, request_handler, req_id), - responder_(&ctx_), - rpc_server_(rpc_server) { - auto method_id = - static_cast(distributed::GrpcMethod::kGetMonomerBarrier); - service_->RequestAsyncUnary( - method_id, &ctx_, &request_, &responder_, cq_, cq_, - reinterpret_cast(static_cast(req_id))); - } - - virtual ~RequestGetMonomerBarrier() {} - - std::string GetReqName() override { return request_.varname(); } - - void Process() override { - // proc request. - std::string varname = request_.varname(); - VLOG(4) << "RequestGetMonomerBarrier " << varname; - - rpc_server_->WaitVarCond(varname); - MonomerHandle h = rpc_server_->GetMonomer(varname); - - framework::Scope* scope = nullptr; - framework::Variable* invar = nullptr; - framework::Variable* outvar = nullptr; - - request_handler_->Handle(varname, scope, invar, &outvar, - request_.trainer_id()); - - Finish(reply_, &responder_); - } - - protected: - sendrecv::VariableMessage request_; - sendrecv::VoidMessage reply_; - ServerAsyncResponseWriter responder_; - RPCServer* rpc_server_{nullptr}; -}; - -class RequestPrefetch final : public RequestBase { - public: - explicit RequestPrefetch(GrpcService::AsyncService* service, - ::grpc::ServerCompletionQueue* cq, - RequestHandler* request_handler, int req_id) - : RequestBase(service, cq, request_handler, req_id), - responder_(&ctx_), - local_scope_(nullptr) { - request_.reset(new GRPCVariableResponse(request_handler->scope(), - request_handler->dev_ctx(), true)); - int method_id = - static_cast(distributed::GrpcMethod::kPrefetchVariable); - service_->RequestAsyncUnary( - method_id, &ctx_, request_.get(), &responder_, cq_, cq_, - reinterpret_cast(static_cast(req_id))); - } - - virtual ~RequestPrefetch() {} - - std::string GetReqName() override { return request_->Varname(); } - - void Process() override { - // prefetch process... - std::string in_var_name = request_->Varname(); - std::string out_var_name = request_->OutVarname(); - std::string table_name = request_->TableName(); - int trainer_id = request_->GetTrainerId(); - - VLOG(4) << "RequestPrefetch, in_var_name: " << in_var_name - << " out_var_name: " << out_var_name << " trainer: " << trainer_id; - - auto scope = request_->GetMutableLocalScope(); - auto invar = scope->FindVar(in_var_name); - // out var must be created in local scope! - framework::Variable* outvar = scope->Var(out_var_name); - - request_handler_->Handle(in_var_name, scope, invar, &outvar, trainer_id, - out_var_name, table_name); - - SerializeToByteBuffer(out_var_name, outvar, *request_handler_->dev_ctx(), - &reply_); - Finish(reply_, &responder_); - } - - protected: - std::shared_ptr request_; - ::grpc::ByteBuffer reply_; - ServerAsyncResponseWriter<::grpc::ByteBuffer> responder_; - framework::Scope* local_scope_; -}; - -class RequestCheckpointNotify final : public RequestBase { - public: - explicit RequestCheckpointNotify(GrpcService::AsyncService* service, - ::grpc::ServerCompletionQueue* cq, - RequestHandler* request_handler, int req_id) - : RequestBase(service, cq, request_handler, req_id), responder_(&ctx_) { - request_.reset(new GRPCVariableResponse(request_handler->scope(), - request_handler->dev_ctx())); - int method_id = - static_cast(distributed::GrpcMethod::kCheckpointNotify); - service_->RequestAsyncUnary( - method_id, &ctx_, request_.get(), &responder_, cq_, cq_, - reinterpret_cast(static_cast(req_id))); - } - - virtual ~RequestCheckpointNotify() {} - - std::string GetReqName() override { return request_->Varname(); } - - void Process() override { - auto scope = request_->GetMutableLocalScope(); - - std::string checkpoint_notify = request_->Varname(); - std::string checkpoint_dir = request_->OutVarname(); - int trainer_id = request_->GetTrainerId(); - std::string table_name = request_->TableName(); - - VLOG(4) << "RequestCheckpointNotify notify: " << checkpoint_notify - << ", dir: " << checkpoint_dir; - - request_handler_->Handle(checkpoint_notify, scope, nullptr, nullptr, - trainer_id, checkpoint_dir, table_name); - Finish(reply_, &responder_); - } - - protected: - std::shared_ptr request_; - sendrecv::VoidMessage reply_; - ServerAsyncResponseWriter responder_; -}; - -class RequestNotify final : public RequestBase { - public: - explicit RequestNotify(GrpcService::AsyncService* service, - ::grpc::ServerCompletionQueue* cq, - RequestHandler* request_handler, int req_id) - : RequestBase(service, cq, request_handler, req_id), responder_(&ctx_) { - request_.reset(new GRPCVariableResponse(request_handler->scope(), - request_handler->dev_ctx(), true)); - int method_id = static_cast(distributed::GrpcMethod::kRequestNotify); - service_->RequestAsyncUnary( - method_id, &ctx_, request_.get(), &responder_, cq_, cq_, - reinterpret_cast(static_cast(req_id))); - } - virtual ~RequestNotify() {} - std::string GetReqName() override { return request_->Varname(); } - - void Process() override { - std::string varname = GetReqName(); - VLOG(4) << "RequestNotify var_name:" << varname; - - auto scope = request_->GetMutableLocalScope(); - auto invar = request_->GetVar(); - int trainer_id = request_->GetTrainerId(); - framework::Variable* outvar = nullptr; - request_handler_->Handle(varname, scope, invar, &outvar, trainer_id); - Finish(reply_, &responder_); - } - - protected: - sendrecv::VoidMessage reply_; - std::shared_ptr request_; - ServerAsyncResponseWriter responder_; -}; - -class RequestSendAndRecv final : public RequestBase { - public: - explicit RequestSendAndRecv(GrpcService::AsyncService* service, - ::grpc::ServerCompletionQueue* cq, - RequestHandler* request_handler, int req_id) - : RequestBase(service, cq, request_handler, req_id), responder_(&ctx_) { - request_.reset(new GRPCVariableResponse(request_handler->scope(), - request_handler->dev_ctx(), true)); - - int method_id = - static_cast(distributed::GrpcMethod::kRequestSendAndRecv); - - service_->RequestAsyncUnary( - method_id, &ctx_, request_.get(), &responder_, cq_, cq_, - reinterpret_cast(static_cast(req_id))); - } - - virtual ~RequestSendAndRecv() {} - std::string GetReqName() override { return request_->Varname(); } - - void Process() override { - std::string in_var_name = request_->Varname(); - std::string out_var_name = request_->OutVarname(); - std::string table_name = request_->TableName(); - int trainer_id = request_->GetTrainerId(); - - VLOG(4) << "RequestSendAndRecv, in_var_name: " << in_var_name - << " out_var_name: " << out_var_name << " trainer: " << trainer_id; - auto scope = request_->GetMutableLocalScope(); - auto invar = scope->FindVar(in_var_name); - framework::Variable* outvar = nullptr; - request_handler_->Handle(in_var_name, scope, invar, &outvar, trainer_id, - out_var_name, table_name); - SerializeToByteBuffer(out_var_name, outvar, *request_handler_->dev_ctx(), - &reply_); - Finish(reply_, &responder_); - } - - protected: - std::shared_ptr request_; - ::grpc::ByteBuffer reply_; - ServerAsyncResponseWriter<::grpc::ByteBuffer> responder_; -}; - -void AsyncGRPCServer::WaitServerReady() { - VLOG(4) << "AsyncGRPCServer is waiting server ready"; - std::unique_lock lock(this->mutex_ready_); - condition_ready_.wait(lock, [=] { return this->ready_ == 1; }); - VLOG(4) << "AsyncGRPCServer WaitSeverReady"; -} - -// Define an option subclass in order to disable SO_REUSEPORT for the -// server socket. -// Come from: -// https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/distributed_runtime/rpc/grpc_server_lib.cc -class NoReusePortOption : public ::grpc::ServerBuilderOption { - public: - void UpdateArguments(::grpc::ChannelArguments* args) override { - args->SetInt(GRPC_ARG_ALLOW_REUSEPORT, 0); - } - - void UpdatePlugins(std::vector>* - plugins) override {} -}; - -void AsyncGRPCServer::StartServer() { - for (int i = 0; i < FLAGS_rpc_retry_bind_port; i++) { - ::grpc::ServerBuilder builder; - std::unique_ptr service( - new GrpcService::AsyncService()); - builder.AddListeningPort(bind_address_, ::grpc::InsecureServerCredentials(), - &selected_port_); - - builder.SetMaxSendMessageSize(std::numeric_limits::max()); - builder.SetMaxReceiveMessageSize(std::numeric_limits::max()); - if (FLAGS_rpc_disable_reuse_port) { - builder.SetOption( - std::unique_ptr<::grpc::ServerBuilderOption>(new NoReusePortOption)); - LOG(INFO) << "set FLAGS_rpc_disable_reuse_port"; - } - builder.RegisterService(service.get()); - - for (auto t : rpc_call_map_) { - rpc_cq_[t.first].reset(builder.AddCompletionQueue().release()); - } - - server_ = builder.BuildAndStart(); - if (selected_port_ != 0) { - LOG(INFO) << "Server listening on " << bind_address_ - << " successful, selected port: " << selected_port_; - service_.reset(service.release()); - break; - } - - LOG(WARNING) << "Server listening on " << bind_address_ - << " failed, selected port: " << selected_port_ - << ", retry after 3 seconds!"; - - sleep(3); - } - - PADDLE_ENFORCE_NE( - selected_port_, 0, - platform::errors::Unavailable("can't bind to address:%s", bind_address_)); - - std::function f = - std::bind(&AsyncGRPCServer::TryToRegisterNewOne, this, - std::placeholders::_1, std::placeholders::_2); - - for (auto& t : rpc_call_map_) { - auto& rpc_name = t.first; - auto& cq = rpc_cq_[rpc_name]; - auto threadnum = rpc_thread_num_[rpc_name]; - auto& reqs = rpc_reqs_[rpc_name]; - - reqs.reserve(kRequestBufSize); - - for (int i = 0; i < kRequestBufSize; i++) { - VLOG(6) << "TryToRegisterNewOne on RPC NAME: " << rpc_name << " I: " << i; - TryToRegisterNewOne(rpc_name, i); - } - - for (int i = 0; i < threadnum; i++) { - rpc_threads_[rpc_name].emplace_back(new std::thread(std::bind( - &AsyncGRPCServer::HandleRequest, this, cq.get(), rpc_name, f))); - VLOG(4) << t.first << " creates threads!"; - } - } - - { - std::lock_guard lock(this->mutex_ready_); - ready_ = 1; - } - condition_ready_.notify_all(); - - // wait server - server_->Wait(); - - for (auto& t : rpc_threads_) { - auto& threads = t.second; - for (size_t i = 0; i < threads.size(); ++i) { - threads[i]->join(); - VLOG(4) << t.first << " threads ends!"; - } - } -} - -void AsyncGRPCServer::ShutdownQueue() { - for (auto& t : rpc_cq_) { - t.second->Shutdown(); - VLOG(4) << t.first << " queue shutdown!"; - } -} - -void AsyncGRPCServer::ShutDownImpl() { - std::unique_lock lock(cq_mutex_); - is_shut_down_ = true; - ShutdownQueue(); - - VLOG(4) << "server_ shutdown!"; - server_->Shutdown(); -} - -void AsyncGRPCServer::TryToRegisterNewOne(const std::string& rpc_name, - int req_id) { - std::unique_lock lock(cq_mutex_); - if (is_shut_down_) { - VLOG(4) << "shutdown, do not TryToRegisterNewSendOne"; - return; - } - - VLOG(4) << "TryToRegisterNewOne on RPC NAME: " << rpc_name - << " REQ ID: " << req_id; - - auto& reqs = rpc_reqs_[rpc_name]; - auto& handler = rpc_call_map_[rpc_name]; - auto& cq = rpc_cq_[rpc_name]; - - RequestBase* b = nullptr; - if (rpc_name == kRequestSend) { - b = new RequestSend(service_.get(), cq.get(), handler, req_id); - } else if (rpc_name == kRequestGet) { - b = new RequestGet(service_.get(), cq.get(), handler, req_id); - - } else if (rpc_name == kRequestGetNoBarrier) { - b = new RequestGetNoBarrier(service_.get(), cq.get(), handler, req_id); - } else if (rpc_name == kRequestGetMonomerVariable) { - b = new RequestGetMonomerVariable(service_.get(), cq.get(), handler, req_id, - this); - } else if (rpc_name == kRequestGetMonomerBarrier) { - b = new RequestGetMonomerBarrier(service_.get(), cq.get(), handler, req_id, - this); - } else if (rpc_name == kRequestPrefetch) { - b = new RequestPrefetch(service_.get(), cq.get(), handler, req_id); - } else if (rpc_name == kRequestCheckpoint) { - b = new RequestCheckpointNotify(service_.get(), cq.get(), handler, req_id); - } else if (rpc_name == kRequestNotify) { - b = new RequestNotify(service_.get(), cq.get(), handler, req_id); - } else if (rpc_name == kRequestSendAndRecv) { - b = new RequestSendAndRecv(service_.get(), cq.get(), handler, req_id); - } else { - PADDLE_THROW( - platform::errors::InvalidArgument("not supported rpc: %s", rpc_name)); - } - - reqs[req_id] = b; - - VLOG(4) << "TryToRegisterNewOne status:" << b->Status(); -} - -void AsyncGRPCServer::HandleRequest( - ::grpc::ServerCompletionQueue* cq, const std::string& rpc_name, - std::function TryToRegisterNewOne) { - void* tag = NULL; - bool ok = false; - - while (true) { - VLOG(4) << "HandleRequest " << rpc_name << " wait next"; - if (!cq->Next(&tag, &ok)) { - VLOG(4) << "CompletionQueue " << rpc_name << " shutdown!"; - break; - } - - int req_id = static_cast(reinterpret_cast(tag)); - VLOG(4) << "HandleRequest " << rpc_name << ", req_id:" << req_id - << " get next"; - - auto& reqs = rpc_reqs_[rpc_name]; - RequestBase* base = nullptr; - { - PADDLE_ENFORCE_EQ( - (req_id >= 0 && req_id < kRequestBufSize), true, - platform::errors::OutOfRange("request id: %s out of bounds: [0, %s)", - req_id, kRequestBufSize)); - std::unique_lock lock(cq_mutex_); - base = reqs[req_id]; - } - - VLOG(3) << base->Status2String(rpc_name); - - // reference: - // https://github.com/tensorflow/tensorflow/issues/5596 - // https://groups.google.com/forum/#!topic/grpc-io/xftlRy-IQwM - // https://groups.google.com/forum/#!topic/grpc-io/ywATt88Ef_I - if (!ok) { - VLOG(4) << "completion queue:" << rpc_name << " recv no regular event" - << " context:" << base->Status2String(rpc_name); - TryToRegisterNewOne(rpc_name, req_id); - delete base; - continue; - } - - switch (base->Status()) { - case PROCESS: { - base->Process(); - break; - } - case FINISH: { - TryToRegisterNewOne(rpc_name, req_id); - delete base; - break; - } - default: { assert(false); } - } - } -} - -} // namespace distributed -} // namespace operators -} // namespace paddle diff --git a/paddle/fluid/operators/distributed/grpc/grpc_server.h b/paddle/fluid/operators/distributed/grpc/grpc_server.h deleted file mode 100644 index 3d68b7e8ce..0000000000 --- a/paddle/fluid/operators/distributed/grpc/grpc_server.h +++ /dev/null @@ -1,93 +0,0 @@ -/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#pragma once - -#include -#include -#include -#include -#include // NOLINT -#include -#include - -#include "grpc++/grpc++.h" -#include "paddle/fluid/framework/blocking_queue.h" -#include "paddle/fluid/framework/executor.h" -#include "paddle/fluid/framework/lod_tensor.h" -#include "paddle/fluid/framework/program_desc.h" -#include "paddle/fluid/framework/scope.h" -#include "paddle/fluid/framework/selected_rows.h" -#include "paddle/fluid/framework/var_type.h" -#include "paddle/fluid/operators/distributed/distributed_pb.h" -#include "paddle/fluid/operators/distributed/grpc/grpc_service.h" -#include "paddle/fluid/operators/distributed/request_handler.h" -#include "paddle/fluid/operators/distributed/rpc_server.h" -#include "paddle/fluid/operators/distributed/sendrecvop_utils.h" -#include "paddle/fluid/platform/profiler.h" - -namespace grpc { -class ServerCompletionQueue; -} // namespace grpc - -namespace paddle { -namespace operators { -namespace distributed { - -class RequestBase; - -class AsyncGRPCServer final : public RPCServer { - public: - explicit AsyncGRPCServer(const std::string& address, int client_num) - : RPCServer(address, client_num), ready_(0) {} - - virtual ~AsyncGRPCServer() {} - void WaitServerReady() override; - void StartServer() override; - - private: - // HandleRequest needs to be thread-safe. - void HandleRequest( - ::grpc::ServerCompletionQueue* cq, const std::string& rpc_name, - std::function TryToRegisterNewOne); - - void TryToRegisterNewOne(const std::string& rpc_name, int req_id); - void ShutdownQueue(); - void ShutDownImpl() override; - - private: - static const int kRequestBufSize = 100; - - std::mutex cq_mutex_; - volatile bool is_shut_down_ = false; - - std::unique_ptr service_; - std::unique_ptr<::grpc::Server> server_; - - // condition of the sub program - std::condition_variable barrier_condition_; - - std::mutex mutex_ready_; - std::condition_variable condition_ready_; - - int ready_; - - std::map> rpc_cq_; - std::map>> rpc_threads_; - std::map> rpc_reqs_; -}; - -}; // namespace distributed -}; // namespace operators -}; // namespace paddle diff --git a/paddle/fluid/operators/distributed/grpc/grpc_service.h b/paddle/fluid/operators/distributed/grpc/grpc_service.h deleted file mode 100644 index 10037c9085..0000000000 --- a/paddle/fluid/operators/distributed/grpc/grpc_service.h +++ /dev/null @@ -1,145 +0,0 @@ -// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#pragma once - -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include "paddle/fluid/operators/distributed/grpc/grpc_variable_response.h" -#include "paddle/fluid/platform/profiler.h" - -// NOTE: This method was originally created by tensorflow -// (https://github.com/tensorflow/tensorflow/) we borrow this -// method and did some modifications so that we can parse gRPC -// requests without too much copying of the tensor data. - -namespace grpc { -class CompletionQueue; -class Channel; -class RpcService; -class ServerCompletionQueue; -class ServerContext; - -// Support parsing/unparsing of tensorflow::VariableResponse. -// Wire-format is identical to RecvVariableResponse. -template <> -class SerializationTraits< - paddle::operators::distributed::GRPCVariableResponse> { - public: - static Status Serialize( - const paddle::operators::distributed::GRPCVariableResponse& msg, - grpc_byte_buffer** bp, bool* own_buffer) { - PADDLE_THROW(paddle::platform::errors::Unimplemented( - "SerializationTraits::Serialize not implemented!")); - return Status(); - } - static Status Deserialize( - grpc_byte_buffer* buffer, - paddle::operators::distributed::GRPCVariableResponse* msg, - int max_message_size = INT_MAX) { - if (buffer == nullptr) { - return Status(StatusCode::INTERNAL, "No payload"); - } - - Status result = g_core_codegen_interface->ok(); - if (result.ok()) { - paddle::operators::distributed::GrpcByteSource source(buffer); - int ret = msg->Parse(&source); - if (ret != 0) { - result = Status(StatusCode::INTERNAL, "VariableResponse parse error"); - } - } - g_core_codegen_interface->grpc_byte_buffer_destroy(buffer); - return result; - } -}; -} // namespace grpc - -namespace paddle { -namespace operators { -namespace distributed { - -enum class GrpcMethod { - kSendVariable, - kGetVariable, - kPrefetchVariable, - kCheckpointNotify, - kGetVariableNoBarrier, - kGetMonomerVariable, - kGetMonomerBarrier, - kRequestNotify, - kRequestSendAndRecv, - // when you add new handler, change kGrpcNumMethods at the same time! -}; - -static const int kGrpcNumMethods = - static_cast(GrpcMethod::kRequestSendAndRecv) + 1; - -inline const char* GrpcMethodName(GrpcMethod id) { - switch (id) { - case GrpcMethod::kSendVariable: - return "/sendrecv.SendRecvService/SendVariable"; - case GrpcMethod::kGetVariable: - return "/sendrecv.SendRecvService/GetVariable"; - case GrpcMethod::kGetVariableNoBarrier: - return "/sendrecv.SendRecvService/GetVariableNoBarrier"; - case GrpcMethod::kGetMonomerVariable: - return "/sendrecv.SendRecvService/GetMonomerVariable"; - case GrpcMethod::kGetMonomerBarrier: - return "/sendrecv.SendRecvService/GetMonomerBarrier"; - case GrpcMethod::kPrefetchVariable: - return "/sendrecv.SendRecvService/PrefetchVariable"; - case GrpcMethod::kCheckpointNotify: - return "/sendrecv.SendRecvService/CheckpointNotify"; - case GrpcMethod::kRequestNotify: - return "/sendrecv.SendRecvService/DistributeNotify"; - case GrpcMethod::kRequestSendAndRecv: - return "/sendrecv.SendRecvService/SendAndRecvVariable"; - } - - // Shouldn't be reached. - PADDLE_THROW(platform::errors::InvalidArgument( - "Invalid id: not found valid method name")); - return nullptr; -} - -class GrpcService final { - public: - class AsyncService : public ::grpc::Service { - public: - AsyncService() { - for (int i = 0; i < kGrpcNumMethods; ++i) { - AddMethod(new ::grpc::internal::RpcServiceMethod( - GrpcMethodName(static_cast(i)), - ::grpc::internal::RpcMethod::NORMAL_RPC, nullptr)); - ::grpc::Service::MarkMethodAsync(i); - } - } - virtual ~AsyncService() {} - - // Make RequestAsyncUnary public for grpc_call.h - using ::grpc::Service::RequestAsyncUnary; - }; -}; - -} // namespace distributed -} // namespace operators -} // namespace paddle diff --git a/paddle/fluid/operators/distributed/grpc/grpc_variable_response.cc b/paddle/fluid/operators/distributed/grpc/grpc_variable_response.cc deleted file mode 100644 index f7679e9fc9..0000000000 --- a/paddle/fluid/operators/distributed/grpc/grpc_variable_response.cc +++ /dev/null @@ -1,344 +0,0 @@ -// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include -#include -#include -#include - -#include "google/protobuf/io/coded_stream.h" -#include "paddle/fluid/operators/distributed/grpc/grpc_bytebuffer_stream.h" -#include "paddle/fluid/operators/distributed/grpc/grpc_variable_response.h" -#include "paddle/fluid/operators/distributed/send_recv.pb.h" -#include "paddle/fluid/platform/enforce.h" -#include "paddle/fluid/platform/profiler.h" - -namespace google { -namespace protobuf { -namespace io { -class ZeroCopyInputStream; -} // namespace io -} // namespace protobuf -} // namespace google -namespace grpc { -class ByteBuffer; -} // namespace grpc - -namespace paddle { -namespace operators { -namespace distributed { - -enum WireType { - WIRETYPE_VARINT = 0, - WIRETYPE_LENGTH_DELIMITED = 2, -}; - -inline int GetTagFieldNumber(uint32_t tag) { return tag >> 3; } - -inline WireType GetTagWireType(uint32_t tag) { - return static_cast(tag & 0x7); -} - -bool ReadVarintSizeAsInt(::google::protobuf::io::CodedInputStream* input, - int* result) { - uint64_t v; - if (input->ReadVarint64(&v) && v <= static_cast(INT_MAX)) { - *result = static_cast(v); - return true; - } else { - return false; - } -} - -int GRPCVariableResponse::Parse(const ::grpc::ByteBuffer& byte_buffer) { - GrpcByteBufferSource source; - source.Init(byte_buffer); - GrpcByteBufferSourceWrapper r(&source); - - return Parse(&r); -} - -bool ParseLodData(::google::protobuf::io::CodedInputStream* input, - std::vector* lod) { - while (true) { - auto p = input->ReadTagWithCutoff(127); - int tag = GetTagFieldNumber(p.first); - WireType wt = GetTagWireType(p.first); - - if (!p.second) { - return (tag == 0); - } - - switch (tag) { - case sendrecv::VariableMessage_LodData::kLodDataFieldNumber: { - uint64_t v; - if (wt == WIRETYPE_VARINT) { - if (!input->ReadVarint64(&v)) { - return false; - } - lod->push_back(v); - break; - } - - if (wt == WIRETYPE_LENGTH_DELIMITED) { - int num_bytes = 0; - if (!input->ReadVarintSizeAsInt(&num_bytes)) { - return tag; - } - int start_pos = input->CurrentPosition(); - while (input->CurrentPosition() - start_pos < num_bytes) { - uint64_t v; - if (!input->ReadVarint64(&v)) { - return tag; - } - lod->push_back(v); - } - break; - } - - return false; - } - default: { return false; } - } - } - - return true; -} - -int GRPCVariableResponse::Parse(Source* source) { - ::google::protobuf::io::ZeroCopyInputStream* input_stream = - source->contents(); - ::google::protobuf::io::CodedInputStream input(input_stream); - input.SetTotalBytesLimit(INT_MAX, INT_MAX); - - while (true) { - auto p = input.ReadTagWithCutoff(127); - int tag = GetTagFieldNumber(p.first); - WireType wt = GetTagWireType(p.first); - if (!p.second) { - if (tag != 0) { - return -1; - } - return 0; - } - - switch (tag) { - case sendrecv::VariableMessage::kVarnameFieldNumber: { - uint32_t length; - if ((wt != WIRETYPE_LENGTH_DELIMITED) || !input.ReadVarint32(&length)) { - return tag; - } - - std::string temp; - if (!input.ReadString(&temp, length)) { - return tag; - } - - meta_.set_varname(temp); - break; - } - case sendrecv::VariableMessage::kTypeFieldNumber: { - uint32_t v; - if ((wt != WIRETYPE_VARINT) || !input.ReadVarint32(&v)) { - return tag; - } - - meta_.set_type(static_cast<::sendrecv::VarType>(v)); - break; - } - case sendrecv::VariableMessage::kDataTypeFieldNumber: { - uint32_t v = 0; - if ((wt != WIRETYPE_VARINT) || !input.ReadVarint32(&v)) { - return tag; - } - - meta_.set_data_type(static_cast<::sendrecv::VariableMessage_Type>(v)); - break; - } - case sendrecv::VariableMessage::kDimsFieldNumber: { - // not packed - if (wt == WIRETYPE_VARINT) { - uint64_t v; - if (!input.ReadVarint64(&v)) { - return tag; - } - meta_.add_dims(v); - break; - } - - // packed - if (wt == WIRETYPE_LENGTH_DELIMITED) { - int num_bytes = 0; - if (!input.ReadVarintSizeAsInt(&num_bytes)) { - return tag; - } - int start_pos = input.CurrentPosition(); - while (input.CurrentPosition() - start_pos < num_bytes) { - uint64_t v; - if (!input.ReadVarint64(&v)) { - return tag; - } - meta_.add_dims(v); - } - break; - } - return tag; - } - case sendrecv::VariableMessage::kLodLevelFieldNumber: { - uint64_t v = 0; - if ((wt != WIRETYPE_VARINT) || !input.ReadVarint64(&v)) { - return tag; - } - meta_.set_lod_level(static_cast(v)); - break; - } - case sendrecv::VariableMessage::kLodFieldNumber: { - int length = 0; - if (wt != WIRETYPE_LENGTH_DELIMITED || - !ReadVarintSizeAsInt(&input, &length)) { - return tag; - } - - std::pair<::google::protobuf::io::CodedInputStream::Limit, int> p = - input.IncrementRecursionDepthAndPushLimit(length); - - std::vector lod_data; - if (p.second < 0 || !ParseLodData(&input, &lod_data)) { - return tag; - } - - if (!input.DecrementRecursionDepthAndPopLimit(p.first)) { - return tag; - } - - if (lod_data.size() == 0) { - break; - } - - auto lod = meta_.add_lod(); - for (uint32_t i = 0; i < lod_data.size(); i++) { - lod->add_lod_data(lod_data[i]); - } - break; - } - case sendrecv::VariableMessage::kSlrHeightFieldNumber: { - uint64_t v = 0; - if ((wt != WIRETYPE_VARINT) || !input.ReadVarint64(&v)) { - return tag; - } - meta_.set_slr_height(static_cast(v)); - break; - } - case sendrecv::VariableMessage::kSerializedFieldNumber: { - int num_bytes = 0; - if (wt != WIRETYPE_LENGTH_DELIMITED || - !ReadVarintSizeAsInt(&input, &num_bytes)) { - return tag; - } - - if (!ProcSerializedField(tag, &input, num_bytes)) { - return tag; - } - - break; - } - case sendrecv::VariableMessage::kRowsFieldNumber: { - PADDLE_ENFORCE((meta_.type() == sendrecv::SELECTED_ROWS || - meta_.type() == sendrecv::LOD_TENSOR) && - meta_.varname() != "", - platform::errors::PreconditionNotMet( - "meta info should be got first!")); - - int num_bytes = 0; - if (wt != WIRETYPE_LENGTH_DELIMITED || - !ReadVarintSizeAsInt(&input, &num_bytes)) { - return tag; - } - - if (!CopySelectRowsData(&input, *dev_ctx_, num_bytes)) { - return tag; - } - break; - } - case sendrecv::VariableMessage::kOutVarnameFieldNumber: { - uint32_t length; - if ((wt != WIRETYPE_LENGTH_DELIMITED) || !input.ReadVarint32(&length)) { - return tag; - } - - std::string temp; - if (!input.ReadString(&temp, length)) { - return tag; - } - - meta_.set_out_varname(temp); - break; - } - case sendrecv::VariableMessage::kProfileFieldNumber: { - uint64_t profiling = 0; - if (!input.ReadVarint64(&profiling)) { - return tag; - } - meta_.set_profile(profiling); - int64_t listener_id = platform::ListenerId(); - if (listener_id <= 0) { - break; - } - if (profiling == platform::kEnableProfiler && - !platform::IsProfileEnabled()) { - platform::EnableProfiler(platform::ProfilerState::kCPU); - } else if (profiling == platform::kDisableProfiler && - platform::IsProfileEnabled()) { - platform::DisableProfiler( - platform::EventSortingKey::kDefault, - string::Sprintf("%s_%lld", FLAGS_rpc_server_profile_path, - listener_id)); - } - break; - } - case sendrecv::VariableMessage::kTrainerIdFieldNumber: { - uint64_t trainer_id = 0; - if (!input.ReadVarint64(&trainer_id)) { - return tag; - } - meta_.set_trainer_id(trainer_id); - break; - } - case sendrecv::VariableMessage::kTableNameFieldNumber: { - uint32_t length; - if ((wt != WIRETYPE_LENGTH_DELIMITED) || !input.ReadVarint32(&length)) { - return tag; - } - - std::string temp; - if (!input.ReadString(&temp, length)) { - return tag; - } - - meta_.set_table_name(temp); - break; - } - default: { - // Unknown tag, return unknown error. - return -1; - } - } - } - - return 0; -} - -}; // namespace distributed -}; // namespace operators -}; // namespace paddle diff --git a/paddle/fluid/operators/distributed/grpc/grpc_variable_response.h b/paddle/fluid/operators/distributed/grpc/grpc_variable_response.h deleted file mode 100644 index 4d12b4a4ba..0000000000 --- a/paddle/fluid/operators/distributed/grpc/grpc_variable_response.h +++ /dev/null @@ -1,67 +0,0 @@ -// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#pragma once - -#include - -#include "google/protobuf/io/coded_stream.h" -#include "google/protobuf/io/zero_copy_stream.h" -#include "paddle/fluid/framework/data_type.h" -#include "paddle/fluid/framework/lod_tensor.h" -#include "paddle/fluid/framework/scope.h" -#include "paddle/fluid/framework/selected_rows.h" -#include "paddle/fluid/framework/tensor.h" -#include "paddle/fluid/framework/var_type.h" -#include "paddle/fluid/operators/distributed/distributed_pb.h" -#include "paddle/fluid/operators/distributed/grpc/grpc_bytebuffer_stream.h" -#include "paddle/fluid/operators/distributed/variable_response.h" - -namespace grpc { -class ByteBuffer; -} // namespace grpc -namespace paddle { -namespace framework { -class Scope; -} // namespace framework -namespace platform { -class DeviceContext; -} // namespace platform -} // namespace paddle - -namespace paddle { -namespace operators { -namespace distributed { - -class GRPCVariableResponse : public VariableResponse { - public: - GRPCVariableResponse(const framework::Scope* scope, - const platform::DeviceContext* dev_ctx, - bool create_scope = false) - : VariableResponse(scope, dev_ctx, create_scope) {} - - virtual ~GRPCVariableResponse() {} - - int Parse(Source* source) override; - - // return: - // 0:ok. - // -1: unkown error. - // other: number of error field. - int Parse(const ::grpc::ByteBuffer& byte_buffer); -}; - -}; // namespace distributed -}; // namespace operators -}; // namespace paddle diff --git a/paddle/fluid/operators/distributed/heart_beat_monitor.cc b/paddle/fluid/operators/distributed/heart_beat_monitor.cc deleted file mode 100644 index 9f537f5334..0000000000 --- a/paddle/fluid/operators/distributed/heart_beat_monitor.cc +++ /dev/null @@ -1,97 +0,0 @@ -// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "paddle/fluid/operators/distributed/heart_beat_monitor.h" - -#include - -namespace paddle { -namespace operators { -namespace distributed { - -DEFINE_int32(worker_update_interval_secs, 900, - " the longest time interval between the worker update variables"); - -inline int GetCurrentUS() { - // current date/time based on current system - time_t t = std::time(0); - int now = static_cast(t); - return now; -} - -void HeartBeatMonitor::Update(const int worker_id, std::string be_monitored_var, - WorkerStatus status) { - if (status == UNINITED) { - LOG(WARNING) << "HeartBeatMonitor receive UNINITED status can not be used " - "in Update, something error"; - } - - if (!is_chief_) { - return; - } - - if ((be_monitored_var == be_monitored_var_ && status == RUNNING) || - status == COMPLETED) { - auto timestamp = GetCurrentUS(); - UnderMonitoredWorker& worker = worker_status_map_.at(worker_id); - - if (worker.status != COMPLETED) { - worker.status = status; - } - worker.timestamp = timestamp; - return; - } -} - -void HeartBeatMonitor::LostWorkerMonitor() { - VLOG(1) << "worker heartbeat monitor start at No.0 parameter server"; - while (running_) { - for (int id = 0; id < workers_; ++id) { - auto& worker = worker_status_map_.at(id); - - if (worker.status == UNINITED) { - VLOG(4) << "worker " << worker.id << " is under UNINITED"; - continue; - } - if (worker.status == COMPLETED) { - VLOG(4) << "worker " << worker.id << " is under COMPLETED"; - continue; - } - - auto timestamp = GetCurrentUS(); - - VLOG(4) << "worker " << worker.id << " status is " << worker.status - << " timestamp is " << worker.timestamp << " the interval is " - << timestamp - worker.timestamp; - - if (timestamp - worker.timestamp >= FLAGS_worker_update_interval_secs) { - PADDLE_THROW(platform::errors::ExecutionTimeout( - "the latest update of worker %d is %d secs ago, we doubt the " - "the worker is not alive and this may have a bad effect on the " - "fitting result, please check", - worker.id, FLAGS_worker_update_interval_secs)); - } - } - - std::this_thread::sleep_for(std::chrono::milliseconds(10 * 1000)); - } - VLOG(1) << "worker heartbeat monitor stopped, thread exit"; -} - -std::once_flag HeartBeatMonitor::init_flag_; -std::unique_ptr HeartBeatMonitor::monitor_(nullptr); - -} // namespace distributed -} // namespace operators -} // namespace paddle diff --git a/paddle/fluid/operators/distributed/heart_beat_monitor.h b/paddle/fluid/operators/distributed/heart_beat_monitor.h deleted file mode 100644 index d96433c318..0000000000 --- a/paddle/fluid/operators/distributed/heart_beat_monitor.h +++ /dev/null @@ -1,127 +0,0 @@ -// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#pragma once - -#include -#include -#include // NOLINT -#include -#include -#include // NOLINT -#include -#include -#include -#include -#include "gflags/gflags.h" - -#include "paddle/fluid/platform/enforce.h" - -namespace paddle { -namespace operators { -namespace distributed { - -enum WorkerStatus { UNINITED = 0, RUNNING, COMPLETED }; - -struct UnderMonitoredWorker { - int id; - WorkerStatus status; - int timestamp; - - UnderMonitoredWorker() {} - - explicit UnderMonitoredWorker(int worker_id) { - this->id = worker_id; - this->status = UNINITED; - this->timestamp = 0; - } -}; - -class HeartBeatMonitor { - public: - explicit HeartBeatMonitor(int workers, bool is_chief, - std::string be_monitored_var) - : workers_(workers), - is_chief_(is_chief), - be_monitored_var_(be_monitored_var), - running_(true) { - PADDLE_ENFORCE_GT(workers, 0, platform::errors::InvalidArgument( - "workers must greater than 0.")); - - for (auto worker_id = 0; worker_id < workers; worker_id++) { - UnderMonitoredWorker worker(worker_id); - worker_status_map_[worker_id] = std::move(worker); - } - - // we define the No.0 pserver is the first parameter server - // only No.0 will check the heartbeat of all trainers - if (is_chief) { - monitor_thread_.reset(new std::thread( - std::bind(&HeartBeatMonitor::LostWorkerMonitor, this))); - } - } - - ~HeartBeatMonitor() { - running_ = false; - if (monitor_thread_) monitor_thread_->join(); - } - - static void Init(int workers, bool is_chief, std::string be_monitored_var) { - std::call_once(init_flag_, &HeartBeatMonitor::InitImpl, workers, is_chief, - be_monitored_var); - } - - static HeartBeatMonitor* GetInstance() { return monitor_.get(); } - - void Stop() { - running_ = false; - if (!monitor_) { - VLOG(0) << "HeartBeatMonitor is not inited, do nothing"; - } else { - if (monitor_thread_) { - monitor_thread_->join(); - monitor_thread_.reset(nullptr); - } - } - } - - void Update(const int worker_id, std::string be_monitored_var, - WorkerStatus status); - - void LostWorkerMonitor(); - - private: - // Init is called by GetInstance. - static void InitImpl(int workers, bool is_chief, - std::string be_monitored_var) { - if (monitor_ == nullptr) { - monitor_.reset(new HeartBeatMonitor(workers, is_chief, be_monitored_var)); - } - } - - static std::once_flag init_flag_; - static std::unique_ptr monitor_; - - int workers_; - bool is_chief_; - std::string be_monitored_var_; - std::unordered_map worker_status_map_; - std::unique_ptr monitor_thread_{nullptr}; - std::mutex mutex_; - bool running_ = false; -}; - -} // namespace distributed -} // namespace operators -} // namespace paddle diff --git a/paddle/fluid/operators/distributed/heart_beat_monitor_test.cc b/paddle/fluid/operators/distributed/heart_beat_monitor_test.cc deleted file mode 100644 index 8505023f63..0000000000 --- a/paddle/fluid/operators/distributed/heart_beat_monitor_test.cc +++ /dev/null @@ -1,54 +0,0 @@ -// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "paddle/fluid/operators/distributed/heart_beat_monitor.h" - -#include "gtest/gtest.h" - -namespace paddle { -namespace operators { -namespace distributed { - -void run(HeartBeatMonitor* monitor) { monitor->LostWorkerMonitor(); } - -TEST(HeartBeatMonitor, All) { - int trainers = 10; - int pserver_id = 0; - std::string var = "nce_w@GRAD.block0"; - std::string var2 = "nce_w@GRAD.block2"; - - HeartBeatMonitor::Init(trainers, pserver_id == 0, var); - - auto* monitor = HeartBeatMonitor::GetInstance(); - - std::vector ids{1, 3, 5, 7}; - - for (auto& id : ids) { - monitor->Update(id, var, RUNNING); - } - - monitor->Update(9, var2, RUNNING); - monitor->Update(2, var, COMPLETED); - - std::thread t(run, monitor); - t.detach(); - - std::this_thread::sleep_for(std::chrono::milliseconds(15 * 1000)); - - monitor->Stop(); -} - -} // namespace distributed -} // namespace operators -} // namespace paddle diff --git a/paddle/fluid/operators/distributed/large_scale_kv.cc b/paddle/fluid/operators/distributed/large_scale_kv.cc deleted file mode 100644 index d2673ed6ff..0000000000 --- a/paddle/fluid/operators/distributed/large_scale_kv.cc +++ /dev/null @@ -1,26 +0,0 @@ -// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "paddle/fluid/operators/distributed/large_scale_kv.h" - -namespace paddle { -namespace operators { -namespace distributed { - -std::once_flag LargeScaleKV::init_flag_; -std::shared_ptr LargeScaleKV::scale_kv_(nullptr); - -} // namespace distributed -} // namespace operators -} // namespace paddle diff --git a/paddle/fluid/operators/distributed/large_scale_kv.h b/paddle/fluid/operators/distributed/large_scale_kv.h deleted file mode 100644 index da2281231f..0000000000 --- a/paddle/fluid/operators/distributed/large_scale_kv.h +++ /dev/null @@ -1,848 +0,0 @@ -// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#pragma once - -#include -#include -#include // NOLINT -#include -#include -#include // NOLINT -#include -#include -#include -#include -#include "gflags/gflags.h" - -#include "paddle/fluid/framework/generator.h" -#include "paddle/fluid/framework/lod_tensor.h" -#include "paddle/fluid/framework/rw_lock.h" -#include "paddle/fluid/framework/selected_rows.h" -#include "paddle/fluid/framework/tensor.h" -#include "paddle/fluid/framework/threadpool.h" -#include "paddle/fluid/framework/variable.h" -#include "paddle/fluid/platform/device_context.h" -#include "paddle/fluid/platform/enforce.h" -#include "paddle/fluid/platform/place.h" -#include "paddle/fluid/platform/port.h" -#include "paddle/fluid/string/printf.h" -#include "paddle/fluid/string/string_helper.h" - -namespace paddle { -namespace operators { -namespace distributed { - -enum Mode { training, infer }; -enum InitType { uniform_random, fill_constant, gaussian_random }; - -inline std::vector bucket(const int v_size, const int b_size) { - int remainder = v_size % b_size; - int bucket = v_size / b_size; - std::vector ret_vec(b_size, bucket); - for (int i = 0; i < remainder; ++i) { - ret_vec[i] = ret_vec[i] + 1; - } - int cur_bucket = 0; - for (int &j : ret_vec) { - int tmp = j; - j = cur_bucket; - cur_bucket += tmp; - } - ret_vec.push_back(cur_bucket); - return ret_vec; -} - -class Initializer { - public: - Initializer() {} - - explicit Initializer(const std::vector &attrs) {} - - virtual float GetValue() = 0; - - virtual ~Initializer() {} - - protected: - std::string name_; - unsigned int seed_; -}; - -class UniformInitializer : public Initializer { - public: - explicit UniformInitializer(const std::vector &attrs) { - name_ = attrs[0]; - seed_ = static_cast(std::stoi(attrs[1])); - min_ = std::stof(attrs[2]); - max_ = std::stof(attrs[3]); - - dist_ = std::uniform_real_distribution(min_, max_); - random_engine_ = framework::GetCPURandomEngine(seed_); - } - - float GetValue() override { return dist_(*random_engine_); } - - private: - float min_; - float max_; - - std::shared_ptr random_engine_; - std::uniform_real_distribution dist_; -}; - -template -inline bool entry(const int count, const T threshold); - -template <> -inline bool entry(const int count, const std::string threshold) { - return true; -} - -template <> -inline bool entry(const int count, const int threshold) { - return count >= threshold; -} - -template <> -inline bool entry(const int count, const float threshold) { - UniformInitializer uniform = UniformInitializer({"0", "0", "1"}); - return uniform.GetValue() >= threshold; -} - -class GaussianInitializer : public Initializer { - public: - explicit GaussianInitializer(const std::vector &attrs) { - name_ = attrs[0]; - seed_ = static_cast(std::stoi(attrs[1])); - mean_ = std::stof(attrs[2]); - std_ = std::stof(attrs[3]); - - random_engine_ = framework::GetCPURandomEngine(seed_); - - dist_ = std::normal_distribution(mean_, std_); - } - - float GetValue() override { return dist_(*random_engine_); } - - private: - float std_; - float mean_; - - std::shared_ptr random_engine_; - std::normal_distribution dist_; -}; - -class FillConstantInitializer : public Initializer { - public: - explicit FillConstantInitializer(const std::vector &attrs) { - name_ = attrs[0]; - value_ = std::stof(attrs[1]); - } - - float GetValue() override { return value_; } - - private: - float value_; -}; - -struct SparseMeta { - std::string name; - std::string grad_name; - std::vector value_names; - std::vector value_dims; - std::vector cached_varnames; - std::vector initializer_attrs; - std::string entry; - Mode mode; - - std::string ToString() { - std::stringstream ss; - ss << "name: " << name << " "; - ss << "mode: " << mode << " "; - - for (int i = 0; i < static_cast(value_names.size()); i++) { - ss << "value_name: " << value_names[i] << " dim: " << value_dims[i] - << " "; - } - - ss << " grad var: " << grad_name; - - ss << " cached varnames: "; - for (int i = 0; i < static_cast(cached_varnames.size()); i++) { - ss << cached_varnames[i] << " "; - } - - ss << " initializer attrs: "; - for (int i = 0; i < static_cast(initializer_attrs.size()); i++) { - ss << initializer_attrs[i] << " "; - } - - ss << " entry attrs: " << entry; - - return ss.str(); - } -}; - -struct VALUE { - explicit VALUE(const std::vector &names) - : names_(names), count_(0), unseen_days_(0) { - values_.resize(names.size()); - for (int i = 0; i < static_cast(names.size()); i++) { - places[names[i]] = i; - } - } - - void set(std::vector> *values) { - values_ = std::move(*values); - } - - void set(const std::vector &names, - const std::vector> &values) { - for (int i = 0; i < static_cast(names.size()); i++) { - auto idx = places[names[i]]; - auto value = values[i]; - values_[idx].assign(value.begin(), value.end()); - } - } - - std::vector *> get() { - auto pts = std::vector *>(); - pts.reserve(values_.size()); - - for (auto &value : values_) { - pts.push_back(&value); - } - return pts; - } - - int fetch_count() { return ++count_; } - void reset_unseen_days() { unseen_days_ = 0; } - - void set_entry(bool is_entry) { is_entry_ = is_entry; } - - bool get_entry() { return is_entry_; } - - std::vector *> get(const std::vector names) { - auto pts = std::vector *>(); - pts.reserve(values_.size()); - - for (int i = 0; i < static_cast(names.size()); i++) { - pts.push_back(&(values_[places[names[i]]])); - } - return pts; - } - - std::vector names_; - int count_; - bool seen_after_last_save_; - int unseen_days_; - bool is_entry_; - std::vector> values_; - std::unordered_map places; -}; - -class ValueBlock { - public: - explicit ValueBlock(const std::vector value_names, - const std::vector value_dims, const Mode &mode, - const std::vector &init_attrs, - const std::string &entry_attr) - : value_names_(value_names), value_dims_(value_dims), mode_(mode) { - // for Initializer - for (size_t i = 0; i < value_names.size(); i++) { - auto name = value_names[i]; - auto slices = string::split_string(init_attrs[i], "&"); - - if (slices[0] == "gaussian_random") { - initializers_[name] = new GaussianInitializer(slices); - } else if (slices[0] == "fill_constant") { - initializers_[name] = new FillConstantInitializer(slices); - } else if (slices[0] == "uniform_random") { - initializers_[name] = new UniformInitializer(slices); - } else { - PADDLE_THROW( - platform::errors::InvalidArgument("%s can not be supported", name)); - } - } - - // for Entry - { - if (entry_attr == "none") { - entry_func_ = - std::bind(entry, std::placeholders::_1, "none"); - } else { - auto slices = string::split_string(entry_attr, "&"); - if (slices[0] == "count_filter") { - int threshold = std::stoi(slices[1]); - entry_func_ = std::bind(entry, std::placeholders::_1, threshold); - } else if (slices[0] == "probability") { - float threshold = std::stof(slices[1]); - entry_func_ = - std::bind(entry, std::placeholders::_1, threshold); - } - } - } - - rwlock_.reset(new framework::RWLock); - } - - ~ValueBlock() { - // for (auto init : initializers_) { - // delete init.second; - // initializers_.erase(init.first); - // } - // - // for (auto value : values_) { - // delete value.second; - // values_.erase(value.first); - // } - } - - void Init(const int64_t &id, std::vector> *values, - int count) { - if (Has(id)) { - PADDLE_THROW(platform::errors::AlreadyExists("id already exist, error")); - } - - if (values->size() != value_names_.size()) { - PADDLE_THROW( - platform::errors::AlreadyExists("values can not match, error")); - } - - auto value = new VALUE(value_names_); - value->set(values); - value->seen_after_last_save_ = true; - value->count_ = count; - values_[id] = value; - } - - std::vector *> Get( - const int64_t &id, const std::vector &value_names) { - rwlock_->RDLock(); - auto ret_values = values_.at(id)->get(value_names); - rwlock_->UNLock(); - return ret_values; - } - - void InitFromInitializer(const int64_t &id, - const std::vector &value_names) { - rwlock_->WRLock(); - - if (Has(id)) { - Update(id); - rwlock_->UNLock(); - return; - } - - auto rets = std::vector>(); - rets.resize(value_names_.size()); - - for (int i = 0; i < static_cast(value_names_.size()); i++) { - auto name = value_names_[i]; - auto *init = initializers_.at(name); - - auto dim = value_dims_[i]; - rets[i].resize(dim); - - for (int j = 0; j < static_cast(dim); j++) { - rets[i][j] = init->GetValue(); - } - } - - Init(id, &rets, 0); - Update(id); - rwlock_->UNLock(); - } - - bool GetEntry(const int64_t &id) { - rwlock_->RDLock(); - auto value = values_.at(id); - auto entry = value->get_entry(); - rwlock_->UNLock(); - return entry; - } - - void Set(const int64_t &id, const std::vector &value_names, - const std::vector> &values) { - rwlock_->WRLock(); - auto value = values_.at(id); - value->set(value_names, values); - rwlock_->UNLock(); - } - - void Update(const int64_t id) { - auto *value = values_.at(id); - value->reset_unseen_days(); - auto count = value->fetch_count(); - - if (!value->get_entry()) { - value->set_entry(entry_func_(count)); - } - } - - private: - bool Has(const int64_t id) { - auto got = values_.find(id); - if (got == values_.end()) { - return false; - } else { - return true; - } - } - - public: - std::unordered_map values_; - - private: - std::vector value_names_; - std::vector value_dims_; - Mode mode_; - std::function entry_func_; - std::unordered_map initializers_; - std::unique_ptr rwlock_{nullptr}; -}; - -class SparseVariable { - public: - explicit SparseVariable(const SparseMeta &meta) { - meta_.name = meta.name; - meta_.mode = meta.mode; - meta_.value_names = meta.value_names; - meta_.value_dims = meta.value_dims; - meta_.grad_name = meta.grad_name; - meta_.cached_varnames = meta.cached_varnames; - meta_.initializer_attrs = meta.initializer_attrs; - meta_.entry = meta.entry; - - for (int i = 0; i < static_cast(meta_.value_names.size()); i++) { - values_dims_[meta_.value_names[i]] = meta_.value_dims[i]; - } - - for (size_t i = 0; i < shard_num_; i++) { - auto block = std::make_shared( - meta.value_names, meta.value_dims, meta.mode, meta.initializer_attrs, - meta.entry); - shard_blocks_.emplace_back(block); - } - - rwlock_.reset(new framework::RWLock); - } - - void Init(const std::vector &ids) { - rwlock_->RDLock(); - for (auto &id : ids) { - auto *block = GetShard(id); - block->InitFromInitializer(id, meta_.value_names); - } - rwlock_->UNLock(); - } - - void Get(const std::vector &ids, - const std::vector &value_names, - std::vector *>> *values) { - values->resize(ids.size()); - - auto buckets = bucket(ids.size(), 8); - std::vector> fs; - - for (int j = 0; j < 8; ++j) { - auto begin = buckets[j]; - auto end = buckets[j + 1]; - - fs.push_back( - framework::Async([begin, end, &values, &ids, &value_names, this]() { - for (int x = begin; x < end; x++) { - auto id = ids[x]; - auto *block = GetShard(id); - auto id_values = block->Get(id, value_names); - (*values)[x] = id_values; - } - })); - } - - for (size_t i = 0; i < fs.size(); ++i) fs[i].wait(); - } - - void GetEntry(const std::vector &ids, std::vector *values) { - auto buckets = bucket(ids.size(), 8); - std::vector> fs; - - for (int j = 0; j < 8; ++j) { - auto begin = buckets[j]; - auto end = buckets[j + 1]; - - fs.push_back(framework::Async([begin, end, &values, &ids, this]() { - for (int x = begin; x < end; x++) { - auto id = ids[x]; - auto *block = GetShard(id); - auto is_entry = block->GetEntry(id); - - if (!is_entry) { - values->push_back(id); - } - } - })); - } - for (size_t i = 0; i < fs.size(); ++i) fs[i].wait(); - } - - void Set(const std::vector &ids, - const std::vector &value_names, - const std::vector>> &values) { - for (int i = 0; i < static_cast(ids.size()); i++) { - GetShard(ids[i])->Set(ids[i], value_names, values[i]); - } - } - - void Dims(std::vector value_names, std::vector *dims) { - for (auto &name : value_names) { - dims->push_back(values_dims_.at(name)); - } - } - - std::vector CachedVarnames() const { - return meta_.cached_varnames; - } - - void Load(const std::string &dirname) { - rwlock_->WRLock(); - VLOG(1) << "load " << meta_.name << " from dir: " << dirname << " begin"; - - std::vector filenames; - for (auto &value_name : meta_.value_names) { - auto filename = string::Sprintf("%s/%s", dirname, value_name); - filenames.push_back(filename); - } - - LoadFromSelectedRows(filenames, meta_.value_names); - VLOG(1) << "load " << meta_.name << " in dir: " << dirname << " done"; - rwlock_->UNLock(); - } - - void LoadFromSelectedRows(const std::vector &filenames, - const std::vector &valuenames) { - std::vector> variables; - auto place = platform::CPUPlace(); - - for (int i = 0; i < static_cast(filenames.size()); i++) { - auto var = std::make_shared(); - variables.push_back(var); - auto &filename = filenames[i]; - std::ifstream fin(filename, std::ios::binary); - auto *selectedRows = var->GetMutable(); - - platform::DeviceContextPool &pool = - platform::DeviceContextPool::Instance(); - auto &dev_ctx = *pool.Get(place); - - framework::DeserializeFromStream(fin, selectedRows, dev_ctx); - selectedRows->SyncIndex(); - } - - std::vector tensors; - - for (int i = 0; i < static_cast(filenames.size()); i++) { - auto &slr = variables[i]->Get(); - auto src_t = slr.value(); - const auto *value = src_t.data(); - tensors.push_back(value); - } - - for (int i = 1; i < static_cast(filenames.size()); i++) { - auto rows_0 = variables[0]->Get().rows(); - auto rows_i = variables[i]->Get().rows(); - - bool is_equal = std::equal(rows_0.begin(), rows_0.end(), rows_i.begin()); - - if (!is_equal) { - PADDLE_THROW(platform::errors::InvalidArgument( - "%s and %s are not equal, can not be load rightly", filenames[0], - filenames[i])); - } - } - - auto rows = variables[0]->Get().rows(); - - for (auto i = 0; i < static_cast(rows.size()); i++) { - auto id = rows[i]; - std::vector> values; - values.resize(filenames.size()); - - for (int j = 0; j < static_cast(filenames.size()); ++j) { - values[j].resize(meta_.value_dims[j]); - std::memcpy(values[j].data(), tensors[j] + i * meta_.value_dims[j], - sizeof(float) * meta_.value_dims[j]); - } - - auto *block = GetShard(id); - block->Init(id, &values, 0); - block->Update(id); - } - } - - void Save(const std::string &dirname, const int mode = 0) { - rwlock_->WRLock(); - VLOG(3) << "save " << meta_.name << " in dir: " << dirname << " begin"; - - MkDirRecursively(dirname.c_str()); - - std::vector filenames; - for (auto &value_name : meta_.value_names) { - auto filename = string::Sprintf("%s/%s", dirname, value_name); - filenames.push_back(filename); - } - - SaveToSelectedRows(filenames, meta_.value_names, mode); - VLOG(3) << "save " << meta_.name << " in dir: " << dirname << " done"; - rwlock_->UNLock(); - } - - void SaveToSelectedRows(const std::vector &filenames, - const std::vector &valuenames, - const int mode) { - for (auto &value_name : valuenames) { - auto it = std::find(meta_.value_names.begin(), meta_.value_names.end(), - value_name); - if (it == meta_.value_names.end()) { - PADDLE_THROW(platform::errors::InvalidArgument( - "[%s] is invalid param for [%s]", value_name, meta_.name)); - } - } - - auto place = platform::CPUPlace(); - platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance(); - auto &dev_ctx = *pool.Get(place); - - std::vector ids; - - for (auto &block : shard_blocks_) { - for (auto value : block->values_) { - if (mode == 0) { - ids.push_back(value.first); - } else { - bool id_need_save = false; - // save all params - if (mode == 1) { - id_need_save = true; - } else { - id_need_save = value.second->seen_after_last_save_; - } - - if (id_need_save) { - ids.push_back(value.first); - } - value.second->seen_after_last_save_ = false; - } - } - } - - VLOG(3) << "save " << ids.size() << " feasigns for " << meta_.name - << " with mode: " << mode; - - std::vector> variables; - std::vector tensors; - std::vector dims; - - for (int i = 0; i < static_cast(filenames.size()); i++) { - auto dim = values_dims_.at(valuenames[i]); - auto var = std::make_shared(); - auto *slr = var->GetMutable(); - auto *src_t = slr->mutable_value(); - - src_t->Resize({static_cast(ids.size()), dim}); - auto *value = src_t->mutable_data(place); - - dims.push_back(dim); - variables.push_back(var); - tensors.push_back(value); - } - - std::vector *>> values; - Get(ids, valuenames, &values); - - int64_t offset = 0; - for (auto &vss : values) { - for (int i = 0; i < static_cast(vss.size()); i++) { - auto &vs = vss[i]; - std::memcpy(tensors[i] + offset * dims[i], vs->data(), - sizeof(float) * dims[i]); - } - offset += 1; - } - - for (auto &var : variables) { - auto *slr = var->GetMutable(); - slr->set_rows(ids); - slr->set_height(ids.size()); - } - - for (int i = 0; i < static_cast(filenames.size()); i++) { - auto &filename = filenames[i]; - auto &selectedRows = variables[i]->Get(); - - std::ofstream fout(filename, std::ios::binary); - PADDLE_ENFORCE_EQ(static_cast(fout), true, - platform::errors::Unavailable( - "Cannot open %s to save variables.", filename)); - - framework::SerializeToStream(fout, selectedRows, dev_ctx); - fout.close(); - } - } - - void SaveToText(const std::vector &filenames, - const std::vector &valuenames) { - for (auto &value_name : valuenames) { - auto it = std::find(meta_.value_names.begin(), meta_.value_names.end(), - value_name); - if (it == meta_.value_names.end()) { - PADDLE_THROW(platform::errors::InvalidArgument( - "[%s] is invalid param for [%s]", value_name, meta_.name)); - } - } - - std::vector> fouts; - - for (auto filename : filenames) { - std::unique_ptr fout(new std::ofstream(filename)); - fouts.push_back(std::move(fout)); - } - - for (auto &block : shard_blocks_) { - for (auto value : block->values_) { - std::vector *> vss = value.second->get(valuenames); - - auto id = value.first; - - for (int i = 0; i < static_cast(vss.size()); i++) { - auto &vs = vss[i]; - std::stringstream ss; - ss << id << "\t"; - ss << vs->size() << "\t"; - for (auto v : (*vs)) { - ss << v << " "; - } - ss << "\n"; - - fouts[i]->write(ss.str().c_str(), sizeof(char) * ss.str().size()); - } - } - } - - for (int i = 0; i < static_cast(fouts.size()); i++) { - fouts[i]->close(); - } - } - - int64_t Size() { - int64_t cnt = 0; - - for (auto &block : shard_blocks_) { - cnt += block->values_.size(); - } - return cnt; - } - - ValueBlock *GetShard(const int64_t id) { - return shard_blocks_[id & shard_mask_].get(); - } - - SparseMeta *GetMeta() { return &meta_; } - - private: - std::unique_ptr rwlock_{nullptr}; - - SparseMeta meta_; - std::unordered_map values_dims_; - const size_t shard_mask_ = 127; - const size_t shard_num_ = 128; - std::vector> shard_blocks_; -}; - -class LargeScaleKV { - public: - LargeScaleKV() {} - - explicit LargeScaleKV(const std::vector &table_metas) { - for (auto &sparse_meta : table_metas) { - auto table_name = sparse_meta.name; - auto meta = std::shared_ptr( - new SparseVariable(std::move(sparse_meta))); - sparse_variables[table_name] = meta; - grad_to_variables[sparse_meta.grad_name] = table_name; - grad_names_.push_back(sparse_meta.grad_name); - } - } - - ~LargeScaleKV() {} - - static std::shared_ptr GetInstantcePtr() { return scale_kv_; } - - static LargeScaleKV *GetInstance() { return scale_kv_.get(); } - - static LargeScaleKV *InitInstance( - const std::vector &table_metas) { - std::call_once(init_flag_, &LargeScaleKV::Init, table_metas); - return scale_kv_.get(); - } - - static void Init(const std::vector &table_metas) { - if (scale_kv_.get() == nullptr) { - scale_kv_.reset(new LargeScaleKV(table_metas)); - } - } - - SparseVariable *Get(const std::string &name) { - auto variable = sparse_variables.at(name); - return variable.get(); - } - - bool ParamInLargeScale(const std::string &name) { - auto got = sparse_variables.find(name); - - if (got == sparse_variables.end()) { - return false; - } - - return true; - } - - bool GradInLargeScale(const std::string &name) { - auto got = grad_to_variables.find(name); - - if (got == grad_to_variables.end()) { - return false; - } - - return true; - } - - SparseVariable *GetByGrad(const std::string &name) { - return Get(grad_to_variables[name]); - } - - const std::vector &GetAllGrads() { return grad_names_; } - - private: - std::unordered_map> - sparse_variables; - std::unordered_map grad_to_variables; - std::vector grad_names_; - static std::shared_ptr scale_kv_; - static std::once_flag init_flag_; -}; - -} // namespace distributed -} // namespace operators -} // namespace paddle diff --git a/paddle/fluid/operators/distributed/parameter_prefetch.cc b/paddle/fluid/operators/distributed/parameter_prefetch.cc deleted file mode 100644 index 558d70e5c3..0000000000 --- a/paddle/fluid/operators/distributed/parameter_prefetch.cc +++ /dev/null @@ -1,311 +0,0 @@ -// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "paddle/fluid/operators/distributed/parameter_prefetch.h" -#include -#include -#include -#include -#include "paddle/fluid/framework/lod_tensor.h" -#include "paddle/fluid/framework/selected_rows.h" -#include "paddle/fluid/operators/distributed/distributed.h" - -namespace paddle { -namespace framework { -class ExecutionContext; -class Scope; -} // namespace framework -} // namespace paddle - -namespace paddle { -namespace operators { -namespace distributed { - -class RPCClient; - -using LoDTensor = framework::LoDTensor; -using LoDTensor = framework::LoDTensor; -using SelectedRows = framework::SelectedRows; -using DDim = framework::DDim; - -static void SplitIdsIntoMultipleVarsBySection( - const std::vector &in_ids, - const std::vector &in_varnames, const int tables, - const int pservers, const bool is_distibuted, framework::Scope *scope, - std::vector> *splited_ids, - std::vector> *origin_ids) { - PADDLE_ENFORCE_EQ( - in_varnames.size(), tables, - platform::errors::OutOfRange( - "send varnames size: %d not equal table number: %d, internal error", - in_varnames.size(), tables)); - - PADDLE_ENFORCE_LE( - tables, pservers, - platform::errors::OutOfRange("table number %d not equal or less than " - "pserver number: %d, internal error", - tables, pservers)); - - auto place = platform::CPUPlace(); - - std::set st(in_ids.begin(), in_ids.end()); - std::vector all_ids; - all_ids.assign(st.begin(), st.end()); - - splited_ids->resize(tables); - origin_ids->resize(tables); - - if (is_distibuted) { - for (auto &id : all_ids) { - auto pserver_id = id % pservers; - (*splited_ids)[pserver_id].push_back(id); - (*origin_ids)[pserver_id].push_back(id); - } - } else { - for (auto &id : all_ids) { - auto pserver_id = id % pservers; - (*origin_ids)[pserver_id].push_back(id); - id = id / pservers; - (*splited_ids)[pserver_id].push_back(id); - } - } - - for (size_t i = 0; i < in_varnames.size(); ++i) { - auto *id_tensor = - scope->Var(in_varnames[i])->GetMutable(); - - auto &ids = (*splited_ids)[i]; - if (!ids.empty()) { - auto *id_tensor_data = id_tensor->mutable_data( - framework::make_ddim({static_cast(ids.size()), 1}), place); - memcpy(id_tensor_data, ids.data(), sizeof(int64_t) * ids.size()); - } - } -} - -typedef std::vector> TableAndEndpoints; - -void prefetch_core( - const std::vector &ids, const TableAndEndpoints &tables, - const framework::ExecutionContext &context, const framework::Scope &scope, - const bool is_distributed, - std::unordered_map> *recved_vec_map) { - distributed::RPCClient *rpc_client = - distributed::RPCClient::GetInstance( - context.Attr("trainer_id")); - - int pservers = context.Attr("pserver_num"); - - platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance(); - auto &actual_ctx = *pool.Get(platform::CPUPlace()); - - std::unique_ptr local_scope = scope.NewTmpScope(); - - std::vector in_var_names; - std::vector out_var_names; - for (size_t i = 0; i < tables.size(); ++i) { - in_var_names.push_back("prefetch_send@" + tables[i].second); - out_var_names.push_back("prefetch_recv@" + tables[i].second); - } - - std::vector> split_ids; - std::vector> origin_ids; - SplitIdsIntoMultipleVarsBySection(ids, in_var_names, tables.size(), pservers, - is_distributed, local_scope.get(), - &split_ids, &origin_ids); - - // create output var in local scope - for (auto &name : out_var_names) { - local_scope->Var(name)->GetMutable(); - } - - std::vector rets; - for (size_t i = 0; i < in_var_names.size(); i++) { - if (NeedSend(*local_scope.get(), in_var_names[i])) { - VLOG(3) << "sending " << in_var_names[i] << " to " << tables[i].second - << " to get " << out_var_names[i] << " back"; - rets.push_back(rpc_client->AsyncPrefetchVar( - tables[i].second, actual_ctx, *local_scope.get(), in_var_names[i], - out_var_names[i], tables[i].first)); - } else { - VLOG(3) << "don't send no-initialied variable: " << out_var_names[i]; - } - } - for (size_t i = 0; i < rets.size(); i++) { - PADDLE_ENFORCE_NE(rets[i]->Wait(), 0U, platform::errors::ExecutionTimeout( - "internal error in RPCClient")); - } - - for (size_t o_idx = 0; o_idx < out_var_names.size(); ++o_idx) { - auto &ids_in_this_section = origin_ids[o_idx]; - - if (!ids_in_this_section.empty()) { - auto &prefetch_out_var = - local_scope->Var(out_var_names[o_idx])->Get(); - const auto *out_var_data = prefetch_out_var.data(); - auto &dims = prefetch_out_var.dims(); - - PADDLE_ENFORCE_EQ(dims.size(), 2, - platform::errors::InvalidArgument( - "The size of Tensor dims must be 2.")); - PADDLE_ENFORCE_EQ(ids_in_this_section.size(), dims[0], - platform::errors::InvalidArgument( - "The size of ids in this section must equal to " - "dims[0]: %s, but got %s", - dims[0], ids_in_this_section.size())); - - auto row_numel = dims[1]; - - for (int64_t i = 0; i < dims[0]; ++i) { - auto origin_id = ids_in_this_section[i]; - std::vector vecs(row_numel); - - std::copy_n(out_var_data + i * row_numel, row_numel, vecs.begin()); - (*recved_vec_map)[origin_id] = vecs; - } - } else { - VLOG(3) << "ids in this section is empty"; - } - } -} - -void prefetch(const std::string &id_name, const std::string &out_name, - const std::string &persistable_var_name, - const bool is_distributed, - const std::vector &table_names, - const std::vector &endpoints, - const framework::ExecutionContext &context, - const framework::Scope &scope) { - prefetchs({id_name}, {out_name}, persistable_var_name, is_distributed, - table_names, endpoints, context, scope); -} - -void prefetchs(const std::vector &id_var_names, - const std::vector &out_var_names, - const std::string &persistable_var_name, - const bool is_distributed, - const std::vector &table_names, - const std::vector &endpoints, - const framework::ExecutionContext &context, - const framework::Scope &scope) { - auto vec_dim_1 = 0; - auto vec_dim_0 = 0; - framework::Variable *var = scope.FindVar(persistable_var_name); - - if (var->IsType()) { - vec_dim_1 = var->Get().value().dims()[1]; - } else { - vec_dim_0 = var->Get().dims()[0]; - vec_dim_1 = var->Get().dims()[1]; - } - - PADDLE_ENFORCE_GT(vec_dim_1, 0, - platform::errors::InvalidArgument( - "lookup table var's dim must gather than 0")); - - const auto place = - scope.FindVar(id_var_names[0])->Get().place(); - - std::vector> ids_group; - std::vector ids_union; - std::vector ids_lods; - TableAndEndpoints tables; - - for (auto &id_name : id_var_names) { - auto &id_tensor = scope.FindVar(id_name)->Get(); - std::vector ids; - TensorToVector(id_tensor, context.device_context(), &ids); - ids_union.insert(ids_union.end(), ids.begin(), ids.end()); - ids_group.push_back(ids); - ids_lods.push_back(id_tensor.lod()); - } - - std::unordered_set s(ids_union.begin(), ids_union.end()); - ids_union.assign(s.begin(), s.end()); - - for (auto &i : ids_union) { - PADDLE_ENFORCE_GE( - i, 0, platform::errors::OutOfRange( - "each element in embedding should be larger or equal 0")); - if (!is_distributed) { - PADDLE_ENFORCE_LT( - i, vec_dim_0, - platform::errors::OutOfRange( - "embedding id must in [0, %d) when is_distributed False", - vec_dim_0)); - } - } - - for (size_t i = 0; i < table_names.size(); i++) { - tables.push_back(std::make_pair(table_names[i], endpoints[i])); - } - std::unordered_map> recved_vec_map; - prefetch_core(ids_union, tables, context, scope, is_distributed, - &recved_vec_map); - - auto padding_idx = distributed::kNoPadding; - - if (context.HasAttr("padding_idx")) { - padding_idx = context.Attr("padding_idx"); - } - - for (size_t i = 0; i < out_var_names.size(); i++) { - std::vector ids = ids_group[i]; - auto ids_size = ids.size(); - auto *out_t = - scope.FindVar(out_var_names[i])->GetMutable(); - out_t->set_lod(ids_lods[i]); - out_t->Resize( - framework::make_ddim({static_cast(ids_size), vec_dim_1})); - auto *out_d = out_t->mutable_data(place); - - if (platform::is_cpu_place(out_t->place())) { - for (auto idx = 0; idx < static_cast(ids_size); idx++) { - const auto &id = ids[idx]; - if (padding_idx != distributed::kNoPadding && id == padding_idx) { - memset(out_d + idx * vec_dim_1, 0, sizeof(float) * vec_dim_1); - } else { - std::copy_n(recved_vec_map[id].begin(), vec_dim_1, - out_d + idx * vec_dim_1); - } - } - } else { -#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) - std::vector ids_value_vec(ids_size * vec_dim_1); - for (auto idx = 0; idx < static_cast(ids_size); idx++) { - const auto &id = ids[idx]; - if (padding_idx != distributed::kNoPadding && id == padding_idx) { - memset(&ids_value_vec[idx * vec_dim_1], 0, sizeof(float) * vec_dim_1); - } else { - memcpy(&ids_value_vec[idx * vec_dim_1], &recved_vec_map[id][0], - sizeof(float) * vec_dim_1); - } - } - auto &gpu_place = BOOST_GET_CONST(platform::CUDAPlace, out_t->place()); - auto &cpu_place = BOOST_GET_CONST( - platform::CPUPlace, paddle::platform::CPUDeviceContext().GetPlace()); - auto stream = context.cuda_device_context().stream(); - memory::Copy(gpu_place, out_d, cpu_place, &ids_value_vec[0], - sizeof(float) * ids_size * vec_dim_1, stream); -#else - PADDLE_ENFORCE(true, platform::errors::PermissionDenied( - "Paddle is not compiled with GPU!")); -#endif - } - } -} - -}; // namespace distributed -}; // namespace operators -}; // namespace paddle diff --git a/paddle/fluid/operators/distributed/parameter_prefetch.h b/paddle/fluid/operators/distributed/parameter_prefetch.h deleted file mode 100644 index 6fd3a99881..0000000000 --- a/paddle/fluid/operators/distributed/parameter_prefetch.h +++ /dev/null @@ -1,53 +0,0 @@ -// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#pragma once - -#include -#include -#include - -#include "paddle/fluid/framework/operator.h" - -namespace paddle { -namespace framework { -class ExecutionContext; -class Scope; -} // namespace framework -} // namespace paddle - -namespace paddle { -namespace operators { -namespace distributed { - -constexpr int64_t kNoPadding = -1; - -void prefetchs(const std::vector& id_var_names, - const std::vector& out_var_names, - const std::string& persistable_var_name, const bool backfill, - const std::vector& table_names, - const std::vector& endpoints, - const framework::ExecutionContext& context, - const framework::Scope& scope); - -void prefetch(const std::string& id_name, const std::string& out_name, - const std::string& persistable_var_name, const bool backfill, - const std::vector& table_names, - const std::vector& endpoints, - const framework::ExecutionContext& context, - const framework::Scope& scope); - -}; // namespace distributed -}; // namespace operators -}; // namespace paddle diff --git a/paddle/fluid/operators/distributed/parameter_recv.cc b/paddle/fluid/operators/distributed/parameter_recv.cc deleted file mode 100644 index d5d3c9c3c7..0000000000 --- a/paddle/fluid/operators/distributed/parameter_recv.cc +++ /dev/null @@ -1,248 +0,0 @@ -// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include -#include -#include - -#include "glog/logging.h" -#include "paddle/fluid/framework/ddim.h" -#include "paddle/fluid/framework/lod_tensor.h" -#include "paddle/fluid/framework/scope.h" -#include "paddle/fluid/framework/selected_rows.h" -#include "paddle/fluid/operators/distributed/communicator_common.h" -#include "paddle/fluid/operators/distributed/distributed.h" -#include "paddle/fluid/operators/distributed/parameter_recv.h" -#include "paddle/fluid/platform/device_context.h" -#include "paddle/fluid/platform/enforce.h" -#include "paddle/fluid/platform/place.h" - -namespace paddle { -namespace operators { -namespace distributed { - -class RPCClient; - -using LoDTensor = framework::LoDTensor; -using LoDTensor = framework::LoDTensor; -using SelectedRows = framework::SelectedRows; -using DDim = framework::DDim; - -template -void RecvSparseLodTensor(const CommContext &rpc_ctx, - const framework::Scope &scope) { - platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance(); - auto cpu_place = platform::CPUPlace(); - auto &cpu_ctx = *pool.Get(cpu_place); - - distributed::RPCClient *rpc_client = - distributed::RPCClient::GetInstance(rpc_ctx.trainer_id); - - std::unique_ptr local_scope = scope.NewTmpScope(); - std::vector tensors; - std::vector rets; - std::vector recv_varnames; - for (size_t i = 0; i < rpc_ctx.splited_varnames.size(); i++) { - auto &recv_var_name = rpc_ctx.splited_varnames[i]; - VLOG(4) << "recv " << recv_var_name << " from " << rpc_ctx.epmap[i]; - local_scope->Var(recv_var_name); - // sparse param in recv_scope is LoDTensor - rets.push_back(rpc_client->AsyncGetVarNoBarrier( - rpc_ctx.epmap[i], cpu_ctx, *local_scope.get(), recv_var_name, - recv_var_name)); - recv_varnames.push_back(recv_var_name); - } - - for (size_t i = 0; i < rets.size(); i++) { - PADDLE_ENFORCE_NE(rets[i]->Wait(), 0U, platform::errors::ExecutionTimeout( - "internal error in RPCClient")); - auto &recv_var_name = recv_varnames[i]; - auto *local_var = local_scope->FindVar(recv_var_name); - const auto *value = local_var->Get().data(); - tensors.push_back(value); - } - - auto *merged_var = scope.FindVar(rpc_ctx.var_name); - - if (merged_var == nullptr || !merged_var->IsInitialized()) { - PADDLE_THROW( - platform::errors::InvalidArgument("%s must initialized at first.")); - } - auto dims1 = merged_var->Get().dims()[1]; - int64_t height = 0; - for (size_t i = 0; i < rpc_ctx.splited_varnames.size(); i++) { - auto *splited_var = local_scope->FindVar(rpc_ctx.splited_varnames[i]); - height += splited_var->Get().dims()[0]; - } - - PADDLE_ENFORCE_EQ( - merged_var->Get().dims()[0], height, - platform::errors::InvalidArgument( - "Received variable must has same dimension with local variable.")); - - auto *merged_t = merged_var->GetMutable(); - auto *merged_d = merged_t->mutable_data(cpu_place); - - auto pserver_num = rpc_ctx.splited_varnames.size(); - for (int x = 0; x < height; ++x) { - auto id = x % pserver_num; - auto idx = x / pserver_num; - std::memcpy(merged_d + x * dims1, tensors[id] + idx * dims1, - sizeof(float) * dims1); - } -} - -template -void RecvGeoSparseRecords(const CommContext &rpc_ctx, - const framework::Scope &scope) { - platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance(); - auto cpu_place = platform::CPUPlace(); - auto &cpu_ctx = *pool.Get(cpu_place); - - distributed::RPCClient *rpc_client = - distributed::RPCClient::GetInstance(rpc_ctx.trainer_id); - - std::unique_ptr local_scope = scope.NewTmpScope(); - - std::vector rets; - for (size_t i = 0; i < rpc_ctx.splited_varnames.size(); i++) { - auto &recv_var_name = rpc_ctx.splited_varnames[i]; - local_scope->Var(recv_var_name); - VLOG(4) << "recv " << recv_var_name << " from " << rpc_ctx.epmap[i]; - // sparse param in recv_scope is LoDTensor - rets.push_back(rpc_client->AsyncGetVar(rpc_ctx.epmap[i], cpu_ctx, - *local_scope.get(), recv_var_name, - recv_var_name, recv_var_name)); - } - - for (size_t i = 0; i < rets.size(); i++) { - PADDLE_ENFORCE_NE(rets[i]->Wait(), 0U, platform::errors::ExecutionTimeout( - "internal error in RPCClient")); - } - - int64_t height = 0; - int64_t ids_num = 0; - int64_t width = 0; - - std::vector all_ids; - auto pserver_num = rpc_ctx.splited_varnames.size(); - - for (size_t i = 0; i < rpc_ctx.splited_varnames.size(); i++) { - auto &recv_var_name = rpc_ctx.splited_varnames[i]; - auto *recv_var = local_scope->FindVar(recv_var_name); - auto &recv_t = recv_var->Get(); - - height += recv_t.height(); - ids_num += recv_t.rows().size(); - width = recv_t.value().dims()[1]; - - if (rpc_ctx.is_distributed) { - std::copy(recv_t.rows().begin(), recv_t.rows().end(), - std::back_inserter(all_ids)); - } else { - std::transform(recv_t.rows().begin(), recv_t.rows().end(), - std::back_inserter(all_ids), - [&](int64_t id) { return id * pserver_num + i; }); - } - } - - auto *var = scope.FindVar(rpc_ctx.var_name); - auto *t_ = var->GetMutable(); - T *out_data = - t_->mutable_value()->mutable_data({ids_num, width}, cpu_place); - t_->set_height(height); - t_->set_rows(all_ids); - - int64_t cnt = 0; - for (size_t i = 0; i < rpc_ctx.splited_varnames.size(); i++) { - auto &recv_var_name = rpc_ctx.splited_varnames[i]; - auto *recv_var = local_scope->FindVar(recv_var_name); - auto &recv_t = recv_var->Get(); - - auto rows = recv_t.rows().size(); - const T *in_data = recv_t.value().data(); - std::copy_n(in_data, rows * width, out_data + cnt); - cnt += rows * width; - } - t_->SyncIndex(); -} - -template -void RecvLodTensor(const CommContext &rpc_ctx, const framework::Scope &scope) { - distributed::RPCClient *rpc_client = - distributed::RPCClient::GetInstance(rpc_ctx.trainer_id); - - std::vector rets; - - // variable do not spilt - if (rpc_ctx.origin_varnames.size() == 1 && - rpc_ctx.splited_varnames.size() == 1) { - auto varname = rpc_ctx.origin_varnames[0]; - const auto place = - scope.FindVar(varname)->Get().place(); - platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance(); - auto &ctx = *pool.Get(place); - VLOG(4) << "recv " << varname << " from " << rpc_ctx.epmap[0] << " in gpu? " - << platform::is_gpu_place(place); - rets.push_back(rpc_client->AsyncGetVarNoBarrier(rpc_ctx.epmap[0], ctx, - scope, varname, varname)); - - for (size_t i = 0; i < rets.size(); i++) { - PADDLE_ENFORCE_NE( - rets[i]->Wait(), 0U, - platform::errors::ExecutionTimeout("internal error in RPCClient")); - } - - VLOG(3) << "ParameterRecv out " << rpc_ctx.var_name; - return; - } else { - PADDLE_ENFORCE(false, platform::errors::Unimplemented( - "ParameterRecv can not recv dense with multi " - "parts now, add it soon.")); - } -} - -template -void ParameterRecv::operator()(const CommContext &rpc_ctx, - const framework::Scope &scope, - bool geo_records) { - VLOG(3) << "ParameterRecv in " << rpc_ctx.var_name; - - PADDLE_ENFORCE_GE(rpc_ctx.origin_varnames.size(), 1, - platform::errors::InvalidArgument( - "origin_varnames.size() >= 1 is permitted")); - - if (rpc_ctx.is_sparse) { - if (geo_records) { - RecvGeoSparseRecords(rpc_ctx, scope); - } else { - RecvSparseLodTensor(rpc_ctx, scope); - } - } else { - RecvLodTensor(rpc_ctx, scope); - } - - VLOG(3) << "ParameterRecv out " << rpc_ctx.var_name; -} -template -void ParameterRecv::operator()(const CommContext &rpc_ctx, - const framework::Scope &scope) { - this->operator()(rpc_ctx, scope, false); -} - -template struct ParameterRecv; - -}; // namespace distributed -}; // namespace operators -}; // namespace paddle diff --git a/paddle/fluid/operators/distributed/parameter_recv.h b/paddle/fluid/operators/distributed/parameter_recv.h deleted file mode 100644 index c30d21aa79..0000000000 --- a/paddle/fluid/operators/distributed/parameter_recv.h +++ /dev/null @@ -1,37 +0,0 @@ -// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#pragma once - -#include -#include - -#include "paddle/fluid/framework/operator.h" -#include "paddle/fluid/operators/distributed/communicator_common.h" - -namespace paddle { -namespace operators { -namespace distributed { - -template -struct ParameterRecv { - void operator()(const CommContext &rpc_ctx, const framework::Scope &scope, - bool barrier); - - void operator()(const CommContext &rpc_ctx, const framework::Scope &scope); -}; - -}; // namespace distributed -}; // namespace operators -}; // namespace paddle diff --git a/paddle/fluid/operators/distributed/parameter_send.cc b/paddle/fluid/operators/distributed/parameter_send.cc deleted file mode 100644 index 109514ca25..0000000000 --- a/paddle/fluid/operators/distributed/parameter_send.cc +++ /dev/null @@ -1,331 +0,0 @@ -// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "paddle/fluid/operators/distributed/parameter_send.h" -#include -#include -#include "glog/logging.h" -#include "paddle/fluid/framework/ddim.h" -#include "paddle/fluid/framework/lod_tensor.h" -#include "paddle/fluid/framework/selected_rows.h" -#include "paddle/fluid/operators/distributed/communicator_common.h" -#include "paddle/fluid/operators/distributed/distributed.h" -#include "paddle/fluid/operators/distributed/request_handler.h" -#include "paddle/fluid/platform/device_context.h" -#include "paddle/fluid/platform/enforce.h" -#include "paddle/fluid/platform/place.h" - -namespace paddle { -namespace framework { -class Scope; -class Tensor; -} // namespace framework -} // namespace paddle - -namespace paddle { -namespace operators { -namespace distributed { - -class RPCClient; - -using LoDTensor = framework::LoDTensor; -using LoDTensor = framework::LoDTensor; -using SelectedRows = framework::SelectedRows; -using DDim = framework::DDim; - -typedef std::vector> EP_SPLIT_TABLE_PAIRS; - -inline EP_SPLIT_TABLE_PAIRS GetMultiFieldCommContext( - const CommContext &rpc_ctx, const framework::Scope &scope, - int multi_parts) { - EP_SPLIT_TABLE_PAIRS table_pairs; - - auto *send_var = scope.FindVar(rpc_ctx.var_name); - if (send_var->IsType()) { - PADDLE_ENFORCE_GE(multi_parts, 1, - platform::errors::InvalidArgument( - "multi_parts must == 1 in parameter send, now is: %d", - multi_parts)); - - for (size_t i = 0; i < rpc_ctx.splited_varnames.size(); i++) { - table_pairs.push_back( - std::make_pair(rpc_ctx.epmap[i], rpc_ctx.splited_varnames[i])); - } - - } else { - PADDLE_THROW(platform::errors::InvalidArgument( - "GetMultiFieldCommContext unsupported LoDTensor current!")); - } - - return table_pairs; -} // namespace distributed - -void SendByNotifyRPC(const CommContext &rpc_ctx, - const framework::Scope &scope) { - auto cpu_ctx = paddle::platform::CPUDeviceContext(); - auto &send_var_name = rpc_ctx.var_name; - std::vector rets; - - distributed::RPCClient *rpc_client = - distributed::RPCClient::GetInstance(rpc_ctx.trainer_id); - - if (NeedSend(scope, send_var_name)) { - for (size_t j = 0; j < rpc_ctx.epmap.size(); j++) { - auto &endpoint = rpc_ctx.epmap[j]; - VLOG(4) << "sending " << send_var_name << " to " << endpoint; - rets.push_back(rpc_client->AsyncDistributeNotify(endpoint, cpu_ctx, scope, - send_var_name)); - VLOG(4) << "send var " << send_var_name << " by notify RPC done"; - } - } else { - VLOG(3) << "don't send non-initialized variable: " << rpc_ctx.var_name; - } - - for (auto &handle : rets) { - PADDLE_ENFORCE_NE(handle->Wait(), 0U, platform::errors::ExecutionTimeout( - "internal error in RPCClient")); - } -} - -template -void ParameterSend::operator()(const CommContext &rpc_ctx, - const framework::Scope &scope, bool sync, - int multi_parts) { - if (rpc_ctx.var_name == STEP_COUNTER) { - SendByNotifyRPC(rpc_ctx, scope); - return; - } - - std::unique_ptr local_scope = scope.NewTmpScope(); - - platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance(); - auto &cpu_ctx = *pool.Get(platform::CPUPlace()); - - distributed::RPCClient *rpc_client = - distributed::RPCClient::GetInstance(rpc_ctx.trainer_id); - - std::vector rets; - auto *send_var = scope.FindVar(rpc_ctx.var_name); - - if (send_var->IsType()) { - size_t out_num = rpc_ctx.splited_varnames.size(); - if (out_num > 1) { - auto &send_tensor = send_var->Get(); - auto &send_tensor_dims = send_tensor.dims(); - std::vector outs_dims; - outs_dims.reserve(out_num); - - // infer output shape - PADDLE_ENFORCE_EQ( - rpc_ctx.height_sections.size(), out_num, - platform::errors::InvalidArgument("tensor split sections size" - "should be equal to output size.")); - for (size_t i = 0; i < out_num; ++i) { - auto dim = send_tensor_dims; - dim[0] = rpc_ctx.height_sections[i]; - outs_dims.push_back(dim); - } - - // create output var in local scope - size_t row_offset = 0; - for (size_t i = 0; i < out_num; ++i) { - framework::Tensor *out = local_scope->Var(rpc_ctx.splited_varnames[i]) - ->GetMutable(); - *out = send_tensor.Slice(row_offset, row_offset + outs_dims[i][0]); - row_offset += outs_dims[i][0]; - } - } else { - auto &send_tensor = send_var->Get(); - framework::Tensor *out = local_scope->Var(rpc_ctx.splited_varnames[0]) - ->GetMutable(); - out->ShareDataWith(send_tensor); - } - - for (size_t i = 0; i < rpc_ctx.splited_varnames.size(); i++) { - auto &send_var_name = rpc_ctx.splited_varnames[i]; - auto &endpoint = rpc_ctx.epmap[i]; - VLOG(4) << " send var name: " << send_var_name - << "endpoint: " << endpoint; - if (NeedSend(*local_scope.get(), send_var_name)) { - VLOG(3) << "sending " << send_var_name << " to " << endpoint; - rets.push_back(rpc_client->AsyncSendVar( - endpoint, cpu_ctx, *local_scope.get(), send_var_name)); - VLOG(4) << "send var " << send_var_name << " async handle done"; - } else { - VLOG(3) << "don't send non-initialized variable: " - << rpc_ctx.splited_varnames[i]; - } - } - } else if (send_var->IsType()) { - auto &send_slr = send_var->Get(); - - auto &send_rows = send_slr.rows(); - if (send_rows.size() == 0) { - LOG(WARNING) - << "WARNING: The variable sent to pserver is empty, which " - "may cause an unknown error. Please check the state of " - "use_double_buffer in pyreader/dataloader async mode, you need to " - "turn it false."; - } - - std::vector> outs_rows_idx; - std::vector> outs_dense_idx; - - auto table_pairs = GetMultiFieldCommContext(rpc_ctx, scope, 1); - outs_rows_idx.resize(table_pairs.size()); - outs_dense_idx.resize(table_pairs.size()); - - auto row_numel = send_slr.value().numel() / send_slr.value().dims()[0]; - auto *src = send_slr.value().data(); - - // create output var in local scope - std::vector outs; - for (auto &table : table_pairs) { - auto *out = - local_scope->Var(table.second)->GetMutable(); - outs.push_back(out); - } - - if (!rpc_ctx.is_distributed) { - auto pserver_num = rpc_ctx.epmap.size(); - - // split rows index into output sparse vars - for (size_t i = 0; i < send_rows.size(); ++i) { - auto ep_idx = send_rows[i] % pserver_num; - auto id = send_rows[i] / pserver_num; - outs_rows_idx[ep_idx].push_back(id); - outs_dense_idx[ep_idx].push_back(i); - } - - auto place = platform::CPUPlace(); - - for (size_t out_idx = 0; out_idx < rpc_ctx.splited_varnames.size(); - out_idx++) { - auto rows_idx = outs_rows_idx[out_idx]; - - auto dims = send_slr.GetCompleteDims(); - dims[0] = rows_idx.size(); - outs[out_idx]->set_height(rpc_ctx.height_sections[out_idx]); - outs[out_idx]->mutable_rows()->clear(); - outs[out_idx]->mutable_value()->mutable_data(dims, send_slr.place()); - - if (rows_idx.size() > 0) { - for (auto idx : rows_idx) { - outs[out_idx]->mutable_rows()->push_back(idx); - } - auto dst = outs[out_idx]->mutable_value()->mutable_data(place); - for (size_t j = 0; j < rows_idx.size(); j++) { - if (platform::is_cpu_place(place)) { - memory::Copy(platform::CPUPlace(), dst + j * row_numel, - platform::CPUPlace(), - src + outs_dense_idx[out_idx][j] * row_numel, - sizeof(T) * row_numel); - } else { - PADDLE_THROW( - platform::errors::Unimplemented("do not support GPU now")); - } - } - } - PADDLE_ENFORCE_EQ( - rows_idx.size(), outs[out_idx]->rows().size(), - platform::errors::InvalidArgument( - "rows should has the same size with tensor dim 0")); - } - } else { - auto pserver_num = rpc_ctx.epmap.size(); - - // split rows index into output sparse vars - for (size_t i = 0; i < send_rows.size(); ++i) { - auto out_idx = send_rows[i] % pserver_num; - outs_rows_idx[out_idx].push_back(send_rows[i]); - outs_dense_idx[out_idx].push_back(i); - } - - auto place = platform::CPUPlace(); - - for (size_t out_idx = 0; out_idx < rpc_ctx.splited_varnames.size(); - out_idx++) { - auto rows_idx = outs_rows_idx[out_idx]; - - auto dims = send_slr.GetCompleteDims(); - dims[0] = rows_idx.size(); - - outs[out_idx]->set_height(rpc_ctx.height_sections[out_idx]); - outs[out_idx]->mutable_rows()->clear(); - outs[out_idx]->mutable_value()->mutable_data(dims, send_slr.place()); - - if (rows_idx.size() > 0) { - for (auto idx : rows_idx) { - outs[out_idx]->mutable_rows()->push_back(idx); - } - auto dst = outs[out_idx]->mutable_value()->mutable_data(place); - for (size_t j = 0; j < rows_idx.size(); j++) { - if (platform::is_cpu_place(place)) { - memory::Copy(platform::CPUPlace(), dst + j * row_numel, - platform::CPUPlace(), - src + outs_dense_idx[out_idx][j] * row_numel, - sizeof(T) * row_numel); - } else { - PADDLE_THROW( - platform::errors::Unimplemented("do not support GPU now")); - } - } - } - PADDLE_ENFORCE_EQ( - rows_idx.size(), outs[out_idx]->rows().size(), - platform::errors::InvalidArgument( - "rows should has the same size with tensor dim 0")); - } - } - - for (size_t i = 0; i < table_pairs.size(); i++) { - auto &send_var_name = table_pairs[i].second; - auto &endpoint = table_pairs[i].first; - auto need_send = NeedSend(*local_scope.get(), send_var_name); - - VLOG(4) << "send var name: " << send_var_name - << " send var endpoint: " << endpoint - << " need send: " << need_send; - - if (need_send) { - VLOG(4) << "sending " << send_var_name << " to " << endpoint; - - rets.push_back(rpc_client->AsyncSendVar( - endpoint, cpu_ctx, *local_scope.get(), send_var_name)); - VLOG(4) << "send var " << send_var_name << " async handle done"; - } else { - VLOG(4) << "don't send non-initialized variable: " - << rpc_ctx.splited_varnames[i]; - } - } - } else { - PADDLE_THROW(platform::errors::InvalidArgument( - "unsupported var type: %s to send!", send_var->Type())); - } - - VLOG(4) << "Prepare to send var " << rpc_ctx.var_name; - if (sync) { - for (auto &handle : rets) { - VLOG(4) << "Wait send var to pserver handle: " << handle; - PADDLE_ENFORCE_NE(handle->Wait(), 0U, platform::errors::ExecutionTimeout( - "internal error in RPCClient")); - } - } -} - -template struct ParameterSend; - -}; // namespace distributed -}; // namespace operators -}; // namespace paddle diff --git a/paddle/fluid/operators/distributed/parameter_send.h b/paddle/fluid/operators/distributed/parameter_send.h deleted file mode 100644 index 4335ef8c73..0000000000 --- a/paddle/fluid/operators/distributed/parameter_send.h +++ /dev/null @@ -1,35 +0,0 @@ -// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#pragma once - -#include -#include - -#include "paddle/fluid/framework/operator.h" -#include "paddle/fluid/operators/distributed/communicator_common.h" - -namespace paddle { -namespace operators { -namespace distributed { - -template -struct ParameterSend { - void operator()(const CommContext &rpc_ctx, const framework::Scope &scope, - bool sync, int multi_parts); -}; - -}; // namespace distributed -}; // namespace operators -}; // namespace paddle diff --git a/paddle/fluid/operators/distributed/proto_encoder_helper.h b/paddle/fluid/operators/distributed/proto_encoder_helper.h deleted file mode 100644 index cedc98b1fc..0000000000 --- a/paddle/fluid/operators/distributed/proto_encoder_helper.h +++ /dev/null @@ -1,146 +0,0 @@ -/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -// NOTE: This file was originally created by tensorflow -// (https://github.com/tensorflow/tensorflow/) we borrow this -// file and did some modifications so that we can send gRPC -// requests without too much copying of the tensor data. - -#pragma once - -#include - -#include "grpc++/grpc++.h" -#include "paddle/fluid/platform/enforce.h" - -namespace paddle { -namespace operators { -namespace distributed { - -char* EncodeVarint32(char* dst, uint32_t v) { - // Operate on characters as unsigneds - unsigned char* ptr = reinterpret_cast(dst); - static const int B = 128; - if (v < (1 << 7)) { - *(ptr++) = v; - } else if (v < (1 << 14)) { - *(ptr++) = v | B; - *(ptr++) = v >> 7; - } else if (v < (1 << 21)) { - *(ptr++) = v | B; - *(ptr++) = (v >> 7) | B; - *(ptr++) = v >> 14; - } else if (v < (1 << 28)) { - *(ptr++) = v | B; - *(ptr++) = (v >> 7) | B; - *(ptr++) = (v >> 14) | B; - *(ptr++) = v >> 21; - } else { - *(ptr++) = v | B; - *(ptr++) = (v >> 7) | B; - *(ptr++) = (v >> 14) | B; - *(ptr++) = (v >> 21) | B; - *(ptr++) = v >> 28; - } - return reinterpret_cast(ptr); -} - -char* EncodeVarint64(char* dst, uint64_t v) { - static const int B = 128; - unsigned char* ptr = reinterpret_cast(dst); - while (v >= B) { - *(ptr++) = (v & (B - 1)) | B; - v >>= 7; - } - *(ptr++) = static_cast(v); - return reinterpret_cast(ptr); -} - -int VarintLength(uint64_t v) { - int len = 1; - while (v >= 128) { - v >>= 7; - len++; - } - return len; -} - -class ProtoEncodeHelper { - public: - ProtoEncodeHelper(char* buf, int max_size) - : base_(buf), p_(buf), limit_(base_ + max_size) {} - - ~ProtoEncodeHelper() {} - - const char* data() const { return base_; } - size_t size() const { return p_ - base_; } - - void WriteUint64(int tag, uint64_t v) { - Encode32(combine(tag, WIRETYPE_VARINT)); - Encode64(v); - } - void WriteBool(int tag, bool v) { - Encode32(combine(tag, WIRETYPE_VARINT)); - EncodeBool(v); - } - void WriteString(int tag, const std::string& v) { - Encode32(combine(tag, WIRETYPE_LENGTH_DELIMITED)); - Encode32(v.size()); - EncodeBytes(v.data(), v.size()); - } - void WriteVarlengthBeginning(int tag, uint32_t len) { - Encode32(combine(tag, WIRETYPE_LENGTH_DELIMITED)); - Encode32(len); - } - void WriteRawBytes(const std::string& v) { EncodeBytes(v.data(), v.size()); } - - private: - // Note: this module's behavior must match the protocol buffer wire encoding - // format. - enum { - WIRETYPE_VARINT = 0, - WIRETYPE_LENGTH_DELIMITED = 2, - }; - static uint32_t combine(uint32_t tag, uint32_t type) { - return ((tag << 3) | type); - } - inline void Encode32(uint32_t v) { - if (v < 128) { - // Fast path for single-byte values. Many of the calls will use a - // constant value for v, so the comparison will get optimized away - // when Encode32 is inlined into the caller. - *p_ = v; - p_++; - } else { - p_ = EncodeVarint32(p_, v); - } - } - void Encode64(uint64_t v) { p_ = EncodeVarint64(p_, v); } - void EncodeBool(bool v) { - *p_ = (v ? 1 : 0); // Equal to varint32 encoding of 0 or 1 - p_++; - } - void EncodeBytes(const char* bytes, int N) { - memcpy(p_, bytes, N); - p_ += N; - } - - char* base_; - char* p_; - char* limit_; // Just for CHECKs -}; - -} // namespace distributed -} // namespace operators -} // namespace paddle diff --git a/paddle/fluid/operators/distributed/request_handler.h b/paddle/fluid/operators/distributed/request_handler.h deleted file mode 100644 index 44359af1b1..0000000000 --- a/paddle/fluid/operators/distributed/request_handler.h +++ /dev/null @@ -1,261 +0,0 @@ -// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#pragma once - -#include -#include // NOLINT - -#include -#include -#include -#include -#include -#include - -#include "paddle/fluid/framework/data_type.h" -#include "paddle/fluid/framework/executor.h" -#include "paddle/fluid/framework/lod_tensor.h" -#include "paddle/fluid/framework/program_desc.h" -#include "paddle/fluid/framework/scope.h" -#include "paddle/fluid/framework/selected_rows.h" -#include "paddle/fluid/framework/var_type.h" -#include "paddle/fluid/platform/macros.h" - -namespace paddle { -namespace operators { -namespace distributed { - -constexpr char kRequestSend[] = "RequestSend"; -constexpr char kRequestGet[] = "RequestGet"; -constexpr char kRequestGetMonomerVariable[] = "RequestGetMonomerVariable"; -constexpr char kRequestGetMonomerBarrier[] = "RequestGetMonomerBarrier"; -constexpr char kRequestPrefetch[] = "RequestPrefetch"; -constexpr char kRequestCheckpoint[] = "RequestCheckpoint"; -constexpr char kRequestPassBarrier[] = "RequestPassBarrier"; -constexpr char kRequestGetNoBarrier[] = "GetVariableNoBarrier"; -constexpr char kRequestNotify[] = "RequestNotify"; -constexpr char kRequestSendAndRecv[] = "RequestSendAndRecv"; - -constexpr char kSendRPC[] = "SendRPC"; -constexpr char kGetRPC[] = "GetRPC"; -constexpr char kGetNoBarrierRPC[] = "GetNoBarrierRPC"; -constexpr char kGetMonomerRPC[] = "GetMonomerRPC"; -constexpr char kPrefetchRPC[] = "PrefetchRPC"; -constexpr char kBatchBarrierRPC[] = "BatchBarrierRPC"; -constexpr char kFetchBarrierRPC[] = "FetchBarrierRPC"; -constexpr char kSendMonomerFetchBarrierRPC[] = "SendMonomerFetchBarrierRPC"; -constexpr char kSendCompleteRPC[] = "SendCompleteRPC"; -constexpr char kCheckPointNotifyRPC[] = "CheckPointNotifyRPC"; -constexpr char kSendAndRecvRPC[] = "SendAndRecvRPC"; -constexpr int64_t kPrefetchTimeout = 60000; - -#define LISTEN_TERMINATE_MESSAGE "TERMINATE@RECV" -#define BATCH_BARRIER_MESSAGE "BATCH_BARRIER@RECV" -#define FETCH_BARRIER_MESSAGE "FETCH_BARRIER@RECV" -#define COMPLETE_MESSAGE "COMPLETE@RECV" -#define WITHOUT_BARRIER_MESSAGE "@WITHOUT_BARRIER@RECV" -#define LEARNING_RATE_DECAY_COUNTER "@LR_DECAY_COUNTER@" -#define STEP_COUNTER "@PS_STEP_COUNTER@" - -#define CHECKPOINT_SAVE_MESSAGE "SAVE@CHECKPOINTNOTIFY" -#define CHECKPOINT_LOAD_MESSAGE "LOAD@CHECKPOINTNOTIFY" - -enum DistributedMode { kSync = 0, kAsync = 1, kHalfAsync = 2, kGeo = 3 }; - -class RPCServer; - -class VarHandle { - public: - VarHandle(const std::string ep, const std::string& method, - const std::string& name, - const platform::DeviceContext* p_ctx = nullptr, - const framework::Scope* p_scope = nullptr) - : status_(kDefaultState) { - ep_ = ep; - ctx_ = p_ctx; - scope_ = p_scope; - name_ = name; - method_ = method; - } - - virtual ~VarHandle() {} - - public: - bool should_retry = false; - - bool Wait() { - int ret = kDefaultState; - { - std::unique_lock lk(sync_mutex_); - wait_cond_.wait(lk, [this] { return status_ != kDefaultState; }); - ret = status_; - } - VLOG(7) << "VarHandle wait:" << ret; - return ret != kErrorState; - } - - void Finish(bool ok) { - { - std::unique_lock lk(sync_mutex_); - status_ = ok ? kFinishState : kErrorState; - } - VLOG(7) << "VarHandle finish:" << ok; - wait_cond_.notify_all(); - } - - std::string String() const { - std::ostringstream s; - s << method_ << " name:[" << name_ << "], ep:[" << ep_ << "], status:[" - << status_ << "]"; - return s.str(); - } - - std::string ep() const { return ep_; } - const platform::DeviceContext* ctx() const { return ctx_; } - const framework::Scope* scope() const { return scope_; } - std::string name() const { return name_; } - std::string method() const { return method_; } - - protected: - // RPC endpoint. - std::string ep_; - const platform::DeviceContext* ctx_; - const framework::Scope* scope_; - // Variable name. - std::string name_; - // RPC method name. - std::string method_; - - protected: - std::mutex sync_mutex_; - std::condition_variable wait_cond_; - - enum VarHandleStatus { - kDefaultState = -1, - kErrorState = 0, - kFinishState = 1, - }; - VarHandleStatus status_; - - private: - DISABLE_COPY_AND_ASSIGN(VarHandle); -}; - -typedef std::shared_ptr VarHandlePtr; - -class RequestHandler { - public: - explicit RequestHandler(int distributed_mode) - : distributed_mode_(distributed_mode), - dev_ctx_(nullptr), - executor_(nullptr), - scope_(nullptr), - program_(nullptr), - rpc_server_(nullptr) {} - - virtual ~RequestHandler() {} - - // Set attributes. - void SetScope(framework::Scope* scope) { scope_ = scope; } - void SetDevCtx(const platform::DeviceContext* dev_ctx) { dev_ctx_ = dev_ctx; } - void SetProgram(framework::ProgramDesc* program) { program_ = program; } - void SetExecutor(framework::Executor* executor) { executor_ = executor; } - - // Used for dist lookup table prefetch - void SetPrefetchPreparedCtx( - std::unordered_map< - std::string, std::shared_ptr>* g) { - prefetch_var_name_to_prepared_ctx_ = g; - } - - void SetCheckpointNotifyPreparedCtx( - std::shared_ptr g) { - checkpoint_prepared_ctx_ = g; - } - - // Used for async. - void SetGradToPreparedCtx( - std::unordered_map< - std::string, std::shared_ptr>* g) { - grad_to_prepared_ctx_ = g; - } - - void SetSparseGradToParam(std::unordered_map* g) { - sparse_grad_to_param_ = g; - } - - void SetLrDecayPreparedCtx( - std::shared_ptr g) { - lr_decay_prepared_ctx_ = g; - } - - void SetRPCServer(RPCServer* rpc_server) { rpc_server_ = rpc_server; } - - // Get attributes. - int distributed_mode() { return distributed_mode_; } - framework::Scope* scope() { return scope_; } - const platform::DeviceContext* dev_ctx() { return dev_ctx_; } - framework::ProgramDesc* program() { return program_; } - framework::Executor* executor() { return executor_; } - - // This function processes user's rpc request. - // The implemention is in request_handler_impl. - // example: - // std::string varname = request_.varname(); - // - // auto scope = request_handler_->scope(); - // auto invar = scope->FindVar(varname); - // framework::Variable* outvar = nullptr; - // - // request_handler_->Handle(varname, scope, invar, &outvar); - // if (outvar) { - // SerializeToByteBuffer(varname, outvar, - // *request_handler_->dev_ctx(), &reply_); - // } - virtual bool Handle(const std::string& varname, framework::Scope* scope, - framework::Variable* var, framework::Variable** outvar, - const int trainer_id, - const std::string& out_var_name = "", - const std::string& table_name = "") = 0; - - protected: - const int distributed_mode_; - - const platform::DeviceContext* dev_ctx_; - framework::Executor* executor_; - framework::Scope* scope_; - framework::ProgramDesc* program_; - - // used for distribute lookup table prefetch - std::unordered_map>* - prefetch_var_name_to_prepared_ctx_; - // used for checkpoint notify - std::shared_ptr checkpoint_prepared_ctx_; - - // Used for async. - std::unordered_map>* - grad_to_prepared_ctx_; - std::unordered_map* sparse_grad_to_param_; - - // used for lr decay - std::shared_ptr lr_decay_prepared_ctx_; - RPCServer* rpc_server_; -}; - -} // namespace distributed -} // namespace operators -} // namespace paddle diff --git a/paddle/fluid/operators/distributed/request_handler_impl.cc b/paddle/fluid/operators/distributed/request_handler_impl.cc deleted file mode 100644 index 8c4f2ef57a..0000000000 --- a/paddle/fluid/operators/distributed/request_handler_impl.cc +++ /dev/null @@ -1,354 +0,0 @@ -// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "paddle/fluid/operators/distributed/request_handler_impl.h" -#include -#include -#include - -#include "paddle/fluid/framework/data_type.h" -#include "paddle/fluid/framework/lod_tensor.h" -#include "paddle/fluid/framework/scope.h" -#include "paddle/fluid/framework/selected_rows.h" -#include "paddle/fluid/framework/variable_helper.h" -#include "paddle/fluid/operators/distributed/rpc_server.h" -#include "paddle/fluid/string/piece.h" -#include "paddle/fluid/string/printf.h" -#include "paddle/fluid/string/split.h" - -#include "paddle/fluid/operators/distributed/async_sparse_param_update_recorder.h" -#include "paddle/fluid/operators/distributed/heart_beat_monitor.h" -#include "paddle/fluid/operators/distributed/large_scale_kv.h" - -namespace paddle { -namespace operators { -namespace distributed { - -// define LOOKUP_TABLE_PATH for checkpoint notify to save lookup table variables -// to directory specified. -constexpr char LOOKUP_TABLE_PATH[] = "kLookupTablePath"; - -bool RequestSendHandler::Handle(const std::string &varname, - framework::Scope *scope, - framework::Variable *invar, - framework::Variable **outvar, - const int trainer_id, - const std::string &out_var_name, - const std::string &table_name) { - VLOG(4) << "RequestSendHandler:" << varname; - - // Sync - if (varname == BATCH_BARRIER_MESSAGE) { - VLOG(3) << "sync: recv BATCH_BARRIER_MESSAGE"; - rpc_server_->IncreaseBatchBarrier(kRequestSend); - } else if (varname == COMPLETE_MESSAGE) { - VLOG(3) << "sync: recv complete message"; - - if (HeartBeatMonitor::GetInstance() != nullptr) { - HeartBeatMonitor::GetInstance()->Update(trainer_id, "", COMPLETED); - } - - rpc_server_->Complete(); - } else { - // Async - if (distributed_mode_ != DistributedMode::kSync) { - VLOG(3) << "async process var: " << varname; - if (varname == BATCH_BARRIER_MESSAGE) { - PADDLE_THROW(platform::errors::InvalidArgument( - "async mode should not recv BATCH_BARRIER_MESSAGE or " - "COMPLETE_MESSAGE")); - } - HeartBeatMonitor::GetInstance()->Update(trainer_id, varname, RUNNING); - - std::string run_varname = varname; - - string::Piece part_piece("@PIECE"); - string::Piece var_name_piece = string::Piece(varname); - - if (string::Contains(var_name_piece, part_piece)) { - auto varname_splits = paddle::string::Split(varname, '@'); - PADDLE_ENFORCE_EQ( - varname_splits.size(), 3, - platform::errors::InvalidArgument( - "varname: %s should be separated into 3 parts by @", varname)); - run_varname = varname_splits[0]; - scope->Rename(varname, run_varname); - } - - auto *var = scope->FindVar(run_varname); - - // for sparse ids - if (var->IsType()) { - if (distributed_mode_ == DistributedMode::kAsync || - distributed_mode_ == DistributedMode::kHalfAsync) { - auto *ins = distributed::LargeScaleKV::GetInstance(); - if (ins->GradInLargeScale(run_varname)) { - auto *large_scale_var = ins->GetByGrad(run_varname); - - for (auto name : large_scale_var->CachedVarnames()) { - scope->Var(name); - } - } - } - if (distributed_mode_ == DistributedMode::kGeo) { - if (AsyncSparseParamUpdateRecorder::GetInstance()->HasGrad( - run_varname)) { - auto &grad_slr = - scope->FindVar(run_varname)->Get(); - AsyncSparseParamUpdateRecorder::GetInstance()->Update( - run_varname, grad_slr.rows()); - } - } - } - - executor_->RunPreparedContext((*grad_to_prepared_ctx_)[run_varname].get(), - scope); - return true; - } else { // sync - rpc_server_->WaitCond(kRequestSend); - VLOG(3) << "sync: processing received var: " << varname; - PADDLE_ENFORCE_NOT_NULL( - invar, platform::errors::NotFound( - "sync: Can not find server side var %s.", varname)); - } - } - return true; -} - -bool RequestGetHandler::Handle(const std::string &varname, - framework::Scope *scope, - framework::Variable *invar, - framework::Variable **outvar, - const int trainer_id, - const std::string &out_var_name, - const std::string &table_name) { - VLOG(3) << "RequestGetHandler:" << varname - << " out_var_name: " << out_var_name << " trainer_id: " << trainer_id - << " table_name: " << table_name; - - if (distributed_mode_ == DistributedMode::kSync) { - if (varname == FETCH_BARRIER_MESSAGE) { - VLOG(3) << "sync: recv fetch barrier message"; - rpc_server_->IncreaseBatchBarrier(kRequestGet); - } else { - rpc_server_->WaitCond(kRequestGet); - *outvar = scope_->FindVar(varname); - } - } else { - if (varname != FETCH_BARRIER_MESSAGE && varname != COMPLETE_MESSAGE) { - if (enable_dc_asgd_) { - // NOTE: the format is determined by distribute_transpiler.py - std::string param_bak_name = - string::Sprintf("%s.trainer_%d_bak", varname, trainer_id); - VLOG(3) << "getting " << param_bak_name << " trainer_id " << trainer_id; - auto var = scope_->FindVar(varname); - auto t_orig = var->Get(); - auto param_bak = scope_->Var(param_bak_name); - auto t = param_bak->GetMutable(); - t->mutable_data(dev_ctx_->GetPlace(), t_orig.type()); - VLOG(3) << "copying " << varname << " to " << param_bak_name; - framework::TensorCopy(t_orig, dev_ctx_->GetPlace(), t); - } - - if (distributed_mode_ == DistributedMode::kGeo && - AsyncSparseParamUpdateRecorder::GetInstance()->HasParam(varname) && - !table_name.empty()) { - VLOG(3) << "AsyncSparseParamUpdateRecorder " << varname << " exist "; - - std::vector updated_rows; - AsyncSparseParamUpdateRecorder::GetInstance()->GetAndClear( - varname, trainer_id, &updated_rows); - - if (VLOG_IS_ON(3)) { - std::ostringstream sstream; - sstream << "["; - for (auto &row_id : updated_rows) { - sstream << row_id << ", "; - } - sstream << "]"; - VLOG(3) << "updated_rows size: " << updated_rows.size() << " " - << sstream.str(); - } - - auto &origin_tensor = - scope_->FindVar(varname)->Get(); - auto *origin_tensor_data = origin_tensor.data(); - auto &dims = origin_tensor.dims(); - *outvar = scope->Var(); - auto *out_slr = (*outvar)->GetMutable(); - out_slr->set_rows(updated_rows); - out_slr->set_height(dims[0]); - auto out_dims = framework::make_ddim( - {static_cast(updated_rows.size()), dims[1]}); - auto *data = out_slr->mutable_value()->mutable_data( - out_dims, origin_tensor.place()); - auto width = dims[1]; - for (size_t i = 0; i < updated_rows.size(); ++i) { - PADDLE_ENFORCE_LT( - updated_rows[i], dims[0], - platform::errors::OutOfRange( - "The value of updated_rows: %s out of Tensor %s dims[0]: %s", - updated_rows[i], varname, dims[0])); - memcpy(data + i * width, origin_tensor_data + updated_rows[i] * width, - sizeof(float) * width); - } - } else { - *outvar = scope_->FindVar(varname); - } - } - } - return true; -} - -bool RequestGetNoBarrierHandler::Handle(const std::string &varname, - framework::Scope *scope, - framework::Variable *invar, - framework::Variable **outvar, - const int trainer_id, - const std::string &out_var_name, - const std::string &table_name) { - VLOG(4) << "RequestGetNoBarrierHandler:" << varname - << " out_var_name: " << out_var_name; - - // get var from pserver immediately without barriers - string::Piece without_barrier_piece(WITHOUT_BARRIER_MESSAGE); - string::Piece var_name_piece = string::Piece(varname); - - if (string::Contains(var_name_piece, without_barrier_piece)) { - var_name_piece = string::TrimSuffix(var_name_piece, without_barrier_piece); - VLOG(4) << "Get var " << var_name_piece << " with " - << WITHOUT_BARRIER_MESSAGE; - *outvar = scope_->FindVar(var_name_piece.ToString()); - return true; - } else { - PADDLE_THROW(platform::errors::InvalidArgument( - "GetNoBarrier must contain %s", WITHOUT_BARRIER_MESSAGE)); - } - return true; -} - -bool RequestPrefetchHandler::Handle(const std::string &varname, - framework::Scope *scope, - framework::Variable *invar, - framework::Variable **outvar, - const int trainer_id, - const std::string &out_var_name, - const std::string &table_name) { - VLOG(4) << "RequestPrefetchHandler " << varname; - - (*outvar)->GetMutable(); - - VLOG(1) << "Prefetch " - << "tablename: " << table_name << " ids:" << varname - << " out: " << out_var_name; - paddle::platform::CPUPlace cpu_place; - auto *ins = distributed::LargeScaleKV::GetInstance(); - - if (ins->ParamInLargeScale(table_name)) { - auto lookup_table_op = PullLargeScaleOp(table_name, varname, out_var_name); - lookup_table_op->Run(*scope, cpu_place); - } else { - auto lookup_table_op = - BuildLookupTableOp(table_name, varname, out_var_name); - lookup_table_op->Run(*scope, cpu_place); - } - - return true; -} - -bool RequestCheckpointHandler::Handle(const std::string &varname, - framework::Scope *scope, - framework::Variable *invar, - framework::Variable **outvar, - const int trainer_id, - const std::string &out_var_name, - const std::string &table_name) { - VLOG(4) << "receive save var " << varname << " with path " << out_var_name - << " mode " << table_name; - - int mode = std::stoi(table_name); - - auto *ins = distributed::LargeScaleKV::GetInstance(); - ins->Get(varname)->Save(out_var_name, mode); - return true; -} - -bool RequestNotifyHandler::Handle(const std::string &varname, - framework::Scope *scope, - framework::Variable *invar, - framework::Variable **outvar, - const int trainer_id, - const std::string &out_var_name, - const std::string &table_name) { - VLOG(3) << "RequestNotifyHandler: " << varname - << ", trainer_id: " << trainer_id; - - string::Piece decay_piece(STEP_COUNTER); - string::Piece var_name_piece = string::Piece(varname); - if (string::Contains(var_name_piece, decay_piece)) { - VLOG(3) << "LearningRate Decay Counter Update"; - - auto *send_var = scope->FindVar(varname); - auto send_var_tensor = send_var->Get(); - auto *send_value = - send_var_tensor.mutable_data(send_var_tensor.place()); - - auto counter = decay_counters.at(trainer_id); - counter += send_value[0]; - decay_counters.at(trainer_id) = counter; - - auto *global_step_var = this->scope()->FindVar(LEARNING_RATE_DECAY_COUNTER); - if (global_step_var == nullptr) { - PADDLE_THROW(platform::errors::InvalidArgument( - "can not find LEARNING_RATE_DECAY_COUNTER ")); - } - - auto *tensor = global_step_var->GetMutable(); - auto *value = tensor->mutable_data(platform::CPUPlace()); - - auto global_counter = 0; - for (auto &trainer_counter : decay_counters) { - global_counter += trainer_counter.second; - } - value[0] = global_counter; - - if (lr_decay_prepared_ctx_.get() == nullptr) { - PADDLE_THROW(platform::errors::InvalidArgument( - "can not find decay block for executor")); - } - - executor_->RunPreparedContext(lr_decay_prepared_ctx_.get(), scope_); - } - return true; -} - -bool RequestSendAndRecvHandler::Handle(const std::string &varname, - framework::Scope *Scope, - framework::Variable *var, - framework::Variable **outvar, - const int trainer_id, - const std::string &out_var_name, - const std::string &table_name) { - VLOG(3) << "SendAndRecvHandle: " << varname - << " out_var_name: " << out_var_name - << " , trainer_id: " << trainer_id; - - executor_->RunPreparedContext((*grad_to_prepared_ctx_)[varname].get(), Scope); - *outvar = Scope->FindVar(out_var_name); - return true; -} - -} // namespace distributed -} // namespace operators -} // namespace paddle diff --git a/paddle/fluid/operators/distributed/request_handler_impl.h b/paddle/fluid/operators/distributed/request_handler_impl.h deleted file mode 100644 index 6d239673f9..0000000000 --- a/paddle/fluid/operators/distributed/request_handler_impl.h +++ /dev/null @@ -1,198 +0,0 @@ -// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#pragma once - -#include -#include -#include -#include -#include -#include -#include - -#include "paddle/fluid/framework/data_type.h" -#include "paddle/fluid/framework/executor.h" -#include "paddle/fluid/framework/lod_tensor.h" -#include "paddle/fluid/framework/op_registry.h" -#include "paddle/fluid/framework/program_desc.h" -#include "paddle/fluid/framework/scope.h" -#include "paddle/fluid/framework/selected_rows.h" -#include "paddle/fluid/framework/var_type.h" -#include "paddle/fluid/operators/distributed/request_handler.h" - -namespace paddle { -namespace framework { -class Scope; -class Variable; -} // namespace framework -} // namespace paddle - -namespace paddle { -namespace operators { -namespace distributed { - -class RequestSendHandler final : public RequestHandler { - public: - explicit RequestSendHandler(int distributed_mode, bool enable_dc_asgd = false) - : RequestHandler(distributed_mode) { - enable_dc_asgd_ = enable_dc_asgd; - } - virtual ~RequestSendHandler() {} - bool Handle(const std::string& varname, framework::Scope* scope, - framework::Variable* var, framework::Variable** outvar, - const int trainer_id, const std::string& out_var_name = "", - const std::string& table_name = "") override; - - private: - bool enable_dc_asgd_; -}; - -class RequestGetHandler final : public RequestHandler { - public: - explicit RequestGetHandler(int distributed_mode, bool enable_dc_asgd = false) - : RequestHandler(distributed_mode) { - enable_dc_asgd_ = enable_dc_asgd; - } - virtual ~RequestGetHandler() {} - bool Handle(const std::string& varname, framework::Scope* scope, - framework::Variable* var, framework::Variable** outvar, - const int trainer_id, const std::string& out_var_name = "", - const std::string& table_name = "") override; - - private: - bool enable_dc_asgd_; -}; - -class RequestGetNoBarrierHandler final : public RequestHandler { - public: - RequestGetNoBarrierHandler() : RequestHandler(false) {} - virtual ~RequestGetNoBarrierHandler() {} - bool Handle(const std::string& varname, framework::Scope* scope, - framework::Variable* var, framework::Variable** outvar, - const int trainer_id, const std::string& out_var_name = "", - const std::string& table_name = "") override; -}; - -static inline void BuildVar(const std::string& param_name, - std::initializer_list arguments, - paddle::framework::proto::OpDesc::Var* var) { - var->set_parameter(param_name); - for (auto& arg_name : arguments) { - *var->mutable_arguments()->Add() = arg_name; - } -} - -class RequestPrefetchHandler final : public RequestHandler { - public: - explicit RequestPrefetchHandler(int distributed_mode) - : RequestHandler(distributed_mode) {} - virtual ~RequestPrefetchHandler() {} - bool Handle(const std::string& varname, framework::Scope* scope, - framework::Variable* var, framework::Variable** outvar, - const int trainer_id, const std::string& out_var_name = "", - const std::string& table_name = "") override; - - private: - std::unique_ptr PullLargeScaleOp( - const std::string& table_name, const std::string& id_name, - const std::string& out_name) { - framework::OpDesc desc; - desc.SetType("lookup_sparse_table_read"); - desc.SetInput("Ids", {id_name}); - desc.SetOutput("Out", std::vector({out_name})); - desc.SetAttr("tablename", {table_name}); - desc.SetAttr("init", true); - desc.SetAttr("value_names", std::vector({"Param"})); - - auto op = paddle::framework::OpRegistry::CreateOp(desc); - return op; - } - - std::unique_ptr BuildLookupTableOp( - const std::string& table_name, const std::string& id_name, - const std::string& out_name) { - paddle::framework::proto::OpDesc op_desc; - op_desc.set_type("lookup_table"); - BuildVar("W", {table_name.data()}, op_desc.add_inputs()); - BuildVar("Ids", {id_name.data()}, op_desc.add_inputs()); - BuildVar("Out", {out_name.data()}, op_desc.add_outputs()); - - auto op = paddle::framework::OpRegistry::CreateOp(op_desc); - return op; - } -}; - -class RequestCheckpointHandler final : public RequestHandler { - public: - explicit RequestCheckpointHandler(int distributed_mode) - : RequestHandler(distributed_mode) {} - - virtual ~RequestCheckpointHandler() {} - bool Handle(const std::string& varname, framework::Scope* scope, - framework::Variable* var, framework::Variable** outvar, - const int trainer_id, const std::string& out_var_name = "", - const std::string& table_name = "") override; - - private: - std::unique_ptr BuildCheckpointOp( - const std::string& varname, const std::string& file_path) { - paddle::framework::proto::OpDesc op_desc; - op_desc.set_type("save"); - BuildVar("X", {varname.data()}, op_desc.add_inputs()); - - auto attr = op_desc.mutable_attrs()->Add(); - attr->set_name("file_path"); - attr->set_type(paddle::framework::proto::AttrType::STRING); - attr->set_s(file_path); - - auto op = paddle::framework::OpRegistry::CreateOp(op_desc); - return op; - } -}; - -class RequestNotifyHandler final : public RequestHandler { - public: - explicit RequestNotifyHandler(int distributed_mode, int trainers) - : RequestHandler(distributed_mode) { - this->trainers = trainers; - for (int i = 0; i < trainers; i++) { - decay_counters[i] = 0; - } - } - virtual ~RequestNotifyHandler() {} - bool Handle(const std::string& varname, framework::Scope* scope, - framework::Variable* var, framework::Variable** outvar, - const int trainer_id, const std::string& out_var_name = "", - const std::string& table_name = "") override; - - private: - int trainers; - std::unordered_map decay_counters; -}; - -class RequestSendAndRecvHandler final : public RequestHandler { - public: - explicit RequestSendAndRecvHandler(int distributed_mode) - : RequestHandler(distributed_mode) {} - virtual ~RequestSendAndRecvHandler() {} - bool Handle(const std::string& varname, framework::Scope* Scope, - framework::Variable* var, framework::Variable** outvar, - const int trainer_id, const std::string& out_var_name = "", - const std::string& table_name = "") override; -}; - -} // namespace distributed -} // namespace operators -} // namespace paddle diff --git a/paddle/fluid/operators/distributed/rpc_client.cc b/paddle/fluid/operators/distributed/rpc_client.cc deleted file mode 100644 index 57ce54870d..0000000000 --- a/paddle/fluid/operators/distributed/rpc_client.cc +++ /dev/null @@ -1,32 +0,0 @@ -// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "paddle/fluid/operators/distributed/rpc_client.h" -#include "gflags/gflags.h" - -// default to 3min to avoid temprary network failures. -DEFINE_int32(rpc_deadline, 180000, "deadline timeouts for rpc"); -DEFINE_int32(rpc_retry_times, 3, "retry times for rpc"); - -namespace paddle { -namespace operators { -namespace distributed { - -std::once_flag RPCClient::init_flag_; -std::unique_ptr RPCClient::rpc_client_(nullptr); -int RPCClient::trainer_id_ = 0; - -} // namespace distributed -} // namespace operators -} // namespace paddle diff --git a/paddle/fluid/operators/distributed/rpc_client.h b/paddle/fluid/operators/distributed/rpc_client.h deleted file mode 100644 index 2c756a6f71..0000000000 --- a/paddle/fluid/operators/distributed/rpc_client.h +++ /dev/null @@ -1,143 +0,0 @@ -// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#pragma once - -#include // NOLINT -#include -#include - -#include "gflags/gflags.h" -#include "paddle/fluid/framework/data_type.h" -#include "paddle/fluid/framework/lod_tensor.h" -#include "paddle/fluid/framework/scope.h" -#include "paddle/fluid/operators/distributed/request_handler.h" - -namespace paddle { -namespace framework { -class Scope; -} // namespace framework -namespace platform { -class DeviceContext; -} // namespace platform -} // namespace paddle - -DECLARE_int32(rpc_deadline); -DECLARE_int32(rpc_retry_times); - -namespace paddle { -namespace operators { -namespace distributed { - -class RPCClient { - public: - RPCClient() {} - virtual ~RPCClient() {} - virtual VarHandlePtr AsyncSendVar(const std::string& ep, - const platform::DeviceContext& ctx, - const framework::Scope& scope, - const std::string& var_name, - int64_t time_out = FLAGS_rpc_deadline) = 0; - - virtual VarHandlePtr AsyncGetVar(const std::string& ep, - const platform::DeviceContext& ctx, - const framework::Scope& scope, - const std::string& var_name, - const std::string& out_varname, - const std::string& table_name = "", - int64_t time_out = FLAGS_rpc_deadline) = 0; - - virtual VarHandlePtr AsyncGetVarNoBarrier( - const std::string& ep, const platform::DeviceContext& ctx, - const framework::Scope& scope, const std::string& var_name, - const std::string& out_varname, - int64_t time_out = FLAGS_rpc_deadline) = 0; - - virtual VarHandlePtr AsyncGetMonomerVariable( - const std::string& ep, const platform::DeviceContext& ctx, - const framework::Scope& scope, const std::string& var_name, - int64_t time_out = FLAGS_rpc_deadline) = 0; - - virtual VarHandlePtr AsyncPrefetchVar( - const std::string& ep, const platform::DeviceContext& ctx, - const framework::Scope& scope, const std::string& in_var_name, - const std::string& out_var_name, const std::string& table_name = "", - int64_t time_out = FLAGS_rpc_deadline) = 0; - - virtual VarHandlePtr AsyncSendBatchBarrier( - const std::string& ep, int64_t time_out = FLAGS_rpc_deadline) = 0; - - virtual VarHandlePtr AsyncSendFetchBarrier( - const std::string& ep, int64_t time_out = FLAGS_rpc_deadline) = 0; - - virtual VarHandlePtr AsyncGetMonomerBarrier( - const std::string& ep, const std::string& var_name, - int64_t time_out = FLAGS_rpc_deadline) = 0; - - virtual VarHandlePtr AsyncCheckpointNotify( - const std::string& ep, const std::string& dirname, - const std::string& varname, const int mode, - int64_t time_out = FLAGS_rpc_deadline) = 0; - - virtual VarHandlePtr AsyncDistributeNotify( - const std::string& ep, const platform::DeviceContext& ctx, - const framework::Scope& scope, const std::string& var_name, - int64_t time_out = FLAGS_rpc_deadline) = 0; - - virtual VarHandlePtr AsyncSendAndRecv( - const std::string& ep, const platform::DeviceContext& ctx, - const framework::Scope& scope, const std::string& send_var_name, - const std::string& recv_var_name, const std::string& table_name = "", - int64_t time_out = FLAGS_rpc_deadline) = 0; - - virtual VarHandlePtr AsyncSendComplete( - const std::string& ep, int64_t time_out = FLAGS_rpc_deadline) = 0; - - // Complete tells all the pserver instances that finishe the training, - // the pserver can reduce it's barrier count, and continue to train - // with other trainers. - virtual void SendComplete() = 0; - - virtual bool Wait() = 0; - - template - static RPCClient* GetInstance(int trainer_id) { - std::call_once(init_flag_, &RPCClient::Init, trainer_id); - return rpc_client_.get(); - } - - // Init is called by GetInstance. - template - static void Init(int trainer_id) { - VLOG(1) << "init rpc client with trainer_id " << trainer_id; - trainer_id_ = trainer_id; - if (rpc_client_.get() == nullptr) { - rpc_client_.reset(new T()); - rpc_client_->InitImpl(); - } - } - - virtual void InitImpl() {} - - protected: - // each trainer have exact one trainer id, it should be static - static int trainer_id_; - - private: - static std::once_flag init_flag_; - static std::unique_ptr rpc_client_; -}; -} // namespace distributed -} // namespace operators -} // namespace paddle diff --git a/paddle/fluid/operators/distributed/rpc_server.cc b/paddle/fluid/operators/distributed/rpc_server.cc deleted file mode 100644 index 37cf0460fb..0000000000 --- a/paddle/fluid/operators/distributed/rpc_server.cc +++ /dev/null @@ -1,242 +0,0 @@ -// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "paddle/fluid/operators/distributed/rpc_server.h" - -#include -#include - -namespace paddle { -namespace framework { -class Scope; -} // namespace framework -namespace platform { -class DeviceContext; -} // namespace platform -} // namespace paddle - -namespace paddle { -namespace operators { -namespace distributed { - -class RequestHandler; - -void RPCServer::ShutDown() { - VLOG(3) << "RPCServer ShutDown "; - ShutDownImpl(); - - exit_flag_ = true; - barrier_cond_.notify_all(); - rpc_cond_.notify_all(); -} - -void RPCServer::SavePort() const { - auto file_path = string::Sprintf("/tmp/paddle.%d.port", ::getpid()); - std::ofstream port_file; - port_file.open(file_path); - port_file << selected_port_; - port_file.close(); - VLOG(3) << "selected port written to " << file_path; -} - -void RPCServer::WaitBarrier(const std::string& rpc_name) { - VLOG(3) << "WaitBarrier in: " << rpc_name; - std::unique_lock lock(this->mutex_); - barrier_cond_.wait(lock, [this, &rpc_name] { - return ((barrier_counter_[rpc_name] == client_num_ && client_num_ != 0) || - exit_flag_.load()); - }); - - VLOG(3) << "WaitBarrier out: " << rpc_name - << " counter: " << barrier_counter_[rpc_name]; -} - -void RPCServer::IncreaseBatchBarrier(const std::string rpc_name) { - VLOG(3) << "RPCServer begin IncreaseBatchBarrier " << rpc_name; - // barrier msg should make sure that it's in the right cond(send|recv) - WaitCond(rpc_name); - int b = 0; - std::unique_lock lock(mutex_); - b = ++barrier_counter_[rpc_name]; - VLOG(3) << rpc_name << " barrier_counter: " << b; - if (b >= client_num_) { - lock.unlock(); - VLOG(3) << "BatchBarrier counter reach " << client_num_ << " for " - << rpc_name; - barrier_cond_.notify_all(); - lock.lock(); - } -} - -void RPCServer::Complete() { - { - std::unique_lock lock(mutex_); - client_num_--; - need_reset_all_vars_ = true; - - VLOG(3) << "decrease client_num to: " << client_num_; - if (cur_cond_.load() == rpc_cond_map_[kRequestGet]) { - barrier_counter_[kRequestGet]--; - } - } - barrier_cond_.notify_all(); -} - -bool RPCServer::NeedResetAllVars() { - std::unique_lock lock(mutex_); - return need_reset_all_vars_; -} - -int RPCServer::GetClientNum() { - std::unique_lock lock(mutex_); - return client_num_; -} - -void RPCServer::ResetBarrierCounter() { - VLOG(3) << "RPCServer ResetBarrierCounter "; - std::unique_lock lock(mutex_); - for (auto& t : barrier_counter_) { - t.second = 0; - } - need_reset_all_vars_ = false; -} - -void RPCServer::RegisterRPC(const std::string& rpc_name, - RequestHandler* handler, int thread_num) { - rpc_call_map_[rpc_name] = handler; - rpc_thread_num_[rpc_name] = thread_num; - - static int cond = -1; - rpc_cond_map_[rpc_name] = ++cond; - VLOG(3) << "RegisterRPC rpc_name: " << rpc_name << ", handler: " << handler - << ", cond: " << rpc_cond_map_[rpc_name]; -} - -void RPCServer::SetCond(const std::string& rpc_name) { - VLOG(3) << "RPCServer SetCond " << rpc_name; - { - std::unique_lock lock(mutex_); - cur_cond_ = rpc_cond_map_[rpc_name]; - } - - rpc_cond_.notify_all(); -} - -void RPCServer::WaitCond(const std::string& rpc_name) { - VLOG(3) << "RPCServer WaitCond in " << rpc_name; - int cond = 0; - { - std::unique_lock lock(mutex_); - cond = rpc_cond_map_[rpc_name]; - } - - std::unique_lock lock(mutex_); - rpc_cond_.wait( - lock, [=] { return (cur_cond_.load() == cond || exit_flag_.load()); }); - VLOG(3) << "RPCServer WaitCond out " << rpc_name; -} - -void RPCServer::RegisterVar(const std::string& var_name, - const std::string& rpc_name, - framework::Scope* scope, - platform::DeviceContext* dev_ctx) { - MonomerHandle h; - h.var_name_ = var_name; - h.rpc_name_ = rpc_name; - h.scope_ = scope; - h.dev_ctx_ = dev_ctx; - - { - std::unique_lock lock(mutex_); - PADDLE_ENFORCE_EQ( - var_map_.find(var_name), var_map_.end(), - platform::errors::AlreadyExists("%s already in var_map.", var_name)); - var_map_[var_name] = h; - } - - rpc_cond_.notify_all(); - VLOG(3) << "RegisterVar context:" << h.String(); -} - -void RPCServer::IncreaseVarBarrier(const std::string& var_name) { - int b = 0; - MonomerHandle h; - { - std::unique_lock lock(mutex_); - b = ++var_map_[var_name].barrier_; - h = var_map_[var_name]; - } - - if (b >= client_num_) { - barrier_cond_.notify_all(); - } - - VLOG(3) << "IncreaseVarBarrier context:" << h.String(); -} - -void RPCServer::WaitVarBarrier(const std::string& var_name) { - VLOG(3) << "WaitVarBarrier var_name:" << var_name; - - std::unique_lock lock(mutex_); - barrier_cond_.wait(lock, [&]() { - return ((var_map_[var_name].barrier_ >= client_num_ && client_num_ != 0) || - exit_flag_.load()); - }); - - VLOG(3) << "WaitVarBarrier context: " << var_map_[var_name].String(); -} - -void RPCServer::SetVarCond(const std::string& var_name) { - VLOG(3) << "SetVarCond var_name:" << var_name; - { - std::unique_lock lock(mutex_); - if (var_map_.find(var_name) != var_map_.end()) { - rpc_cond_.notify_all(); - } - } -} - -void RPCServer::WaitVarCond(const std::string& var_name) { - VLOG(3) << "WaitVarCond var_name:" << var_name; - - std::unique_lock lock(mutex_); - rpc_cond_.wait(lock, [=] { - return (var_map_.find(var_name) != var_map_.end() || exit_flag_.load()); - }); - - VLOG(3) << "WaitVarCond var_name:" << var_name << " end"; -} - -MonomerHandle RPCServer::GetMonomer(const std::string& var_name) { - MonomerHandle h; - { - std::unique_lock lock(mutex_); - h = var_map_[var_name]; - } - - return h; -} - -void RPCServer::ClearRegisteredVars() { - std::unique_lock lock(mutex_); - var_map_.clear(); -} - -void RPCServer::ClearVar(const std::string& var_name) { - std::unique_lock lock(mutex_); - var_map_.erase(var_name); -} -} // namespace distributed -} // namespace operators -} // namespace paddle diff --git a/paddle/fluid/operators/distributed/rpc_server.h b/paddle/fluid/operators/distributed/rpc_server.h deleted file mode 100644 index 2120260515..0000000000 --- a/paddle/fluid/operators/distributed/rpc_server.h +++ /dev/null @@ -1,149 +0,0 @@ -// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#pragma once - -#include -#include -#include -#include // NOLINT -#include -#include -#include - -#include "paddle/fluid/framework/scope.h" -#include "paddle/fluid/operators/distributed/request_handler.h" -#include "paddle/fluid/platform/device_context.h" - -namespace paddle { -namespace framework { -class Scope; -} // namespace framework -namespace platform { -class DeviceContext; -} // namespace platform -} // namespace paddle - -namespace paddle { -namespace operators { -namespace distributed { - -class RequestHandler; - -struct MonomerHandle { - std::string var_name_; - std::string rpc_name_; - framework::Scope* scope_{nullptr}; - platform::DeviceContext* dev_ctx_{nullptr}; - int64_t barrier_{0}; - - std::string String() { - std::stringstream ss; - ss << "var_name:" << var_name_ << ", rpc_name:" << rpc_name_ - << ", scope:" << scope_ << ", dev_ctx:" << dev_ctx_ - << ", barrier_:" << barrier_; - return ss.str(); - } -}; - -class RPCServer { - public: - explicit RPCServer(const std::string& address, int client_num) - : cur_cond_(0), - bind_address_(address), - exit_flag_(false), - selected_port_(0), - client_num_(client_num), - need_reset_all_vars_(false) {} - - virtual ~RPCServer() {} - virtual void StartServer() = 0; - virtual void WaitServerReady() = 0; - - void ShutDown(); - - bool IsExit() { return exit_flag_.load(); } - - int GetSelectedPort() const { return selected_port_; } - - int GetClientNum(); - - void SavePort() const; - - // RegisterRPC, register the rpc method name to a handler - // class, and auto generate a condition id for this call - // to be used for the barrier. - void RegisterRPC(const std::string& rpc_name, RequestHandler* handler, - int thread_num = 1); - - int GetThreadNum(const std::string& rpc_name) { - return rpc_thread_num_[rpc_name]; - } - - // Wait util all the clients have reached the barrier for one - // rpc method. This function should be called in the - // RequestHandler if you want to run the server/client in a - // synchronous mode. - void WaitBarrier(const std::string& rpc_name); - - void SetCond(const std::string& rpc_name); - void WaitCond(const std::string& rpc_name); - void IncreaseBatchBarrier(const std::string rpc_name); - - void RegisterVar(const std::string& var_name, const std::string& rpc_name, - framework::Scope* scope, platform::DeviceContext* dev_ctx); - void IncreaseVarBarrier(const std::string& var_name); - void WaitVarBarrier(const std::string& var_name); - void SetVarCond(const std::string& var_name); - void WaitVarCond(const std::string& var_name); - void ClearRegisteredVars(); - void ClearVar(const std::string& var_name); - MonomerHandle GetMonomer(const std::string& var_name); - - void Complete(); - - void ResetBarrierCounter(); - - bool NeedResetAllVars(); - - protected: - virtual void ShutDownImpl() = 0; - - private: - std::mutex mutex_; - std::unordered_map barrier_counter_; - std::condition_variable barrier_cond_; - - std::unordered_map rpc_cond_map_; - std::atomic cur_cond_; - std::condition_variable rpc_cond_; - - protected: - std::string bind_address_; - std::atomic exit_flag_; - int selected_port_; - int client_num_; - bool need_reset_all_vars_; - - std::unordered_map rpc_call_map_; - std::unordered_map rpc_thread_num_; - friend class RequestHandler; - - // TODO(gongwb): use more cond to notify or wait; - std::unordered_map var_map_; -}; - -}; // namespace distributed -}; // namespace operators -}; // namespace paddle diff --git a/paddle/fluid/operators/distributed/rpc_server_test.cc b/paddle/fluid/operators/distributed/rpc_server_test.cc deleted file mode 100644 index f592854000..0000000000 --- a/paddle/fluid/operators/distributed/rpc_server_test.cc +++ /dev/null @@ -1,344 +0,0 @@ -/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#include -#include -#include // NOLINT -#include -#include -#include // NOLINT -#include - -#include "gtest/gtest.h" -#include "paddle/fluid/framework/block_desc.h" -#include "paddle/fluid/framework/op_registry.h" -#include "paddle/fluid/framework/operator.h" - -#include "paddle/fluid/operators/distributed/distributed.h" -#include "paddle/fluid/operators/distributed/heart_beat_monitor.h" -#include "paddle/fluid/operators/distributed/large_scale_kv.h" -#include "paddle/fluid/operators/distributed/request_handler_impl.h" -#include "paddle/fluid/operators/distributed/rpc_client.h" -#include "paddle/fluid/operators/distributed/rpc_server.h" - -namespace framework = paddle::framework; -namespace platform = paddle::platform; -namespace distributed = paddle::operators::distributed; - -USE_NO_KERNEL_OP(lookup_sparse_table_read); -USE_NO_KERNEL_OP(checkpoint_notify); -USE_OP(scale); - -std::unique_ptr g_rpc_service; -std::unique_ptr g_req_handler; - -framework::BlockDesc* AppendSendAndRecvBlock(framework::ProgramDesc* program) { - auto root_block = program->MutableBlock(0); - auto* block = program->AppendBlock(*root_block); - - framework::OpDesc* op = block->AppendOp(); - op->SetType("scale"); - op->SetInput("X", {"x"}); - op->SetOutput("Out", {"res"}); - op->SetAttr("scale", 0.5f); - - auto& out = *root_block->Var("res"); - out.SetType(framework::proto::VarType::LOD_TENSOR); - out.SetShape({1, 10}); - - return block; -} - -void CreateVarsOnScope(framework::Scope* scope, platform::CPUPlace* place) { - auto w_var = scope->Var("w"); - w_var->GetMutable(); - - auto out_var = scope->Var("out"); - out_var->GetMutable(); - - auto ids_var = scope->Var("ids"); - ids_var->GetMutable(); - - auto x_var = scope->Var("x"); - x_var->GetMutable(); - - auto res_var = scope->Var("res"); - res_var->GetMutable(); -} - -void InitTensorsOnClient(framework::Scope* scope, platform::CPUPlace* place, - int64_t rows_numel) { - CreateVarsOnScope(scope, place); - auto ids_var = scope->Var("ids")->GetMutable(); - int64_t* ids_ptr = - ids_var->mutable_data(framework::DDim({rows_numel, 1}), *place); - for (int64_t i = 0; i < rows_numel; ++i) ids_ptr[i] = i * 2; - - auto x_var = scope->Var("x")->GetMutable(); - float* x_ptr = - x_var->mutable_data(framework::DDim({1, rows_numel}), *place); - for (int64_t i = 0; i < rows_numel; ++i) x_ptr[i] = 1.0; -} - -void InitTensorsOnServer(framework::Scope* scope, platform::CPUPlace* place, - int64_t rows_numel) { - CreateVarsOnScope(scope, place); - auto w = scope->Var("w")->GetMutable(); - auto w_value = w->mutable_value(); - w_value->Resize({rows_numel, 10}); - for (int64_t i = 0; i < rows_numel; ++i) w->AutoGrownIndex(i, true); - - auto ptr = w_value->mutable_data(*place); - - for (int64_t i = 0; i < w_value->numel(); ++i) { - ptr[i] = static_cast(i / 10); - } -} - -void StartServer(const std::string& rpc_name) { - framework::ProgramDesc program; - framework::Scope scope; - platform::CPUPlace place; - framework::Executor exe(place); - platform::CPUDeviceContext ctx(place); - - std::unordered_map> - prefetch_var_name_to_prepared; - - g_req_handler->SetProgram(&program); - g_req_handler->SetPrefetchPreparedCtx(&prefetch_var_name_to_prepared); - g_req_handler->SetDevCtx(&ctx); - g_req_handler->SetScope(&scope); - g_req_handler->SetExecutor(&exe); - - g_rpc_service->RegisterRPC(rpc_name, g_req_handler.get()); - - // distributed::HeartBeatMonitor::Init(1, true, "w@grad"); - - g_req_handler->SetRPCServer(g_rpc_service.get()); - - std::thread server_thread( - std::bind(&distributed::RPCServer::StartServer, g_rpc_service.get())); - - server_thread.join(); -} - -void StartSendAndRecvServer(const std::string& rpc_name) { - framework::ProgramDesc program; - framework::Scope scope; - platform::CPUPlace place; - framework::Executor exe(place); - platform::CPUDeviceContext ctx(place); - auto block = AppendSendAndRecvBlock(&program); - std::string in_var_name("x"); - std::vector prefetch_block_ids{block->ID()}; - auto prepared = exe.Prepare(program, prefetch_block_ids); - InitTensorsOnServer(&scope, &place, 10); - - std::unordered_map> - grad_to_prepared_ctx; - grad_to_prepared_ctx[in_var_name] = prepared[0]; - - g_req_handler->SetProgram(&program); - g_req_handler->SetGradToPreparedCtx(&grad_to_prepared_ctx); - g_req_handler->SetDevCtx(&ctx); - g_req_handler->SetScope(&scope); - g_req_handler->SetExecutor(&exe); - - g_rpc_service->RegisterRPC(rpc_name, g_req_handler.get()); - g_req_handler->SetRPCServer(g_rpc_service.get()); - - std::thread server_thread( - std::bind(&distributed::RPCServer::StartServer, g_rpc_service.get())); - - server_thread.join(); -} - -TEST(COMPLETE, CPU) { - setenv("http_proxy", "", 1); - setenv("https_proxy", "", 1); - g_req_handler.reset( - new distributed::RequestSendHandler(distributed::DistributedMode::kSync)); - g_rpc_service.reset(new RPCSERVER_T("127.0.0.1:0", 2)); - distributed::RPCClient* client = - distributed::RPCClient::GetInstance(0); - PADDLE_ENFORCE_NE(client, nullptr, - platform::errors::InvalidArgument( - "Client Start Fail, Check Your Code & Env")); - std::thread server_thread(StartServer, distributed::kRequestSend); - g_rpc_service->WaitServerReady(); - int port = g_rpc_service->GetSelectedPort(); - std::string ep = paddle::string::Sprintf("127.0.0.1:%d", port); - client->AsyncSendComplete(ep); - client->Wait(); - - EXPECT_EQ(g_rpc_service->GetClientNum(), 1); - - g_rpc_service->ShutDown(); - server_thread.join(); - g_rpc_service.reset(nullptr); - g_req_handler.reset(nullptr); -} - -TEST(SENDANDRECV, CPU) { - setenv("http_proxy", "", 1); - setenv("https_proxy", "", 1); - g_req_handler.reset(new distributed::RequestSendAndRecvHandler( - distributed::DistributedMode::kAsync)); - g_rpc_service.reset(new RPCSERVER_T("127.0.0.1:0", 1)); - distributed::RPCClient* client = - distributed::RPCClient::GetInstance(0); - PADDLE_ENFORCE_NE(client, nullptr, - platform::errors::InvalidArgument( - "Client Start Fail, Check Your Code & Env")); - std::thread server_thread(StartSendAndRecvServer, - distributed::kRequestSendAndRecv); - g_rpc_service->WaitServerReady(); - int port = g_rpc_service->GetSelectedPort(); - std::string ep = paddle::string::Sprintf("127.0.0.1:%d", port); - - framework::Scope scope; - platform::CPUPlace place; - platform::CPUDeviceContext ctx(place); - - // create var on local scope - int64_t rows_numel = 10; - InitTensorsOnClient(&scope, &place, rows_numel); - std::string in_var_name("x"); - std::string out_var_name("res"); - - client->AsyncSendAndRecv(ep, ctx, scope, in_var_name, out_var_name); - client->Wait(); - auto var = scope.Var(out_var_name); - auto value = var->GetMutable(); - auto ptr = value->mutable_data(place); - - for (int64_t i = 0; i < rows_numel; ++i) { - EXPECT_EQ(ptr[i], 0.5); - } - g_rpc_service->ShutDown(); - server_thread.join(); - LOG(INFO) << "begin reset"; - g_rpc_service.reset(nullptr); - g_req_handler.reset(nullptr); -} - -void StartCheckpointServer(const std::string& rpc_name) { - framework::ProgramDesc program; - framework::Scope scope; - platform::CPUPlace place; - framework::Executor exe(place); - platform::CPUDeviceContext ctx(place); - - std::vector metas; - - auto meta = distributed::SparseMeta(); - meta.name = "embedding.block0"; - meta.value_names = {"Param"}; - meta.value_dims = {64}; - meta.mode = distributed::Mode::training; - meta.grad_name = "embedding@Grad"; - meta.cached_varnames = {"kSparseIds"}; - meta.initializer_attrs = {"fill_constant&1.0"}; - meta.entry = "none"; - - metas.push_back(meta); - distributed::LargeScaleKV::Init(metas); - - auto* ins = distributed::LargeScaleKV::GetInstance(); - ins->Get("embedding.block0")->Init({0, 1, 2, 3, 4, 5, 6, 7, 8, 9}); - - std::unordered_map> - prefetch_var_name_to_prepared; - - g_req_handler->SetProgram(&program); - g_req_handler->SetPrefetchPreparedCtx(&prefetch_var_name_to_prepared); - g_req_handler->SetDevCtx(&ctx); - g_req_handler->SetScope(&scope); - g_req_handler->SetExecutor(&exe); - - g_rpc_service->RegisterRPC(rpc_name, g_req_handler.get()); - - g_req_handler->SetRPCServer(g_rpc_service.get()); - - std::thread server_thread( - std::bind(&distributed::RPCServer::StartServer, g_rpc_service.get())); - - server_thread.join(); -} - -TEST(LARGE_SCALE_CHECKPOINT, CPU) { - setenv("http_proxy", "", 1); - setenv("https_proxy", "", 1); - - paddle::framework::Scope scope; - paddle::platform::CPUPlace place; - - g_req_handler.reset(new distributed::RequestCheckpointHandler( - distributed::DistributedMode::kAsync)); - g_rpc_service.reset(new RPCSERVER_T("127.0.0.1:0", 1)); - - distributed::RPCClient* client = - distributed::RPCClient::GetInstance(0); - - PADDLE_ENFORCE_NE(client, nullptr, - platform::errors::InvalidArgument( - "Client Start Fail, Check Your Code & Env")); - - std::thread server_thread(StartCheckpointServer, - distributed::kRequestCheckpoint); - g_rpc_service->WaitServerReady(); - - int port = g_rpc_service->GetSelectedPort(); - std::string ep = paddle::string::Sprintf("127.0.0.1:%d", port); - - auto save_path = - paddle::string::Sprintf("%s/%s/%s", "/tmp/large_scale_table/base", - "embedding", "embedding.block0"); - int mode = 0; - client->AsyncCheckpointNotify(ep, save_path, "embedding.block0", mode); - client->Wait(); - - save_path = - paddle::string::Sprintf("%s/%s/%s", "/tmp/large_scale_table/delta", - "embedding", "embedding.block0"); - mode = 1; - client->AsyncCheckpointNotify(ep, save_path, "embedding.block0", mode); - client->Wait(); - - paddle::framework::AttributeMap attrs; - - std::vector eps = {ep}; - attrs["endpoints"] = eps; - attrs["dirname"] = std::string("/tmp/large_scale_table/delta1"); - attrs["varname"] = std::string("embedding"); - attrs["mode"] = 2; - std::vector slices = {"embedding.block0"}; - attrs["slice_varnames"] = slices; - std::vector remotes = {"embedding.block0"}; - attrs["remote_varnames"] = remotes; - - auto ops = - framework::OpRegistry::CreateOp("checkpoint_notify", {}, {}, attrs, true); - ops->Run(scope, place); - - g_rpc_service->ShutDown(); - server_thread.join(); - LOG(INFO) << "begin reset"; - g_rpc_service.reset(nullptr); - g_req_handler.reset(nullptr); -} diff --git a/paddle/fluid/operators/distributed/send_recv.proto.in b/paddle/fluid/operators/distributed/send_recv.proto.in deleted file mode 100644 index a333642bd1..0000000000 --- a/paddle/fluid/operators/distributed/send_recv.proto.in +++ /dev/null @@ -1,88 +0,0 @@ -/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. Licensed under -the Apache License, Version 2.0 (the "License"); you may not use this file -except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -syntax = "proto3"; -package sendrecv; - -option cc_generic_services = @cc_generic_services@; - -service SendRecvService { - // For parameter server round-robin like hashing, do not split tensors. - // Send and recv only one tensor - // TODO(typhoonzero): add streaming API - rpc SendVariable(VariableMessage) returns (VoidMessage) {} - // Argument VariableMessage for GetVariable should only contain varname. - rpc GetVariable(VariableMessage) returns (VariableMessage) {} - rpc GetVariableNoBarrier(VariableMessage) returns (VariableMessage) {} - // pre-fetch variable by given variable name and Ids - rpc PrefetchVariable(VariableMessage) returns (VariableMessage) {} - - rpc CheckpointNotify(VariableMessage) returns (VoidMessage) {} - rpc DistributeNotify(VariableMessage) returns (VoidMessage) {} - rpc SendAndRecvVariable(VariableMessage) returns (VariableMessage) {} - rpc GetMonomerVariable(VariableMessage) returns (VariableMessage) {} - rpc GetMonomerBarrier(VariableMessage) returns (VoidMessage) {} -} - -// It can be: LoDTensor、SelectedRows or NCCL_ID -enum VarType { - LOD_TENSOR = 0; - SELECTED_ROWS = 1; - NCCL_ID = 2; -} - -// VariableMessage is serialized paddle variable message. -// NOTICE(gongwb):don't modify this proto if you are not -// not familar with how we serialize in sendrecvop_utils.h -// and deserilize it in variable_response.h. -message VariableMessage { - enum Type { - // Pod Types - BOOL = 0; - INT16 = 1; - INT32 = 2; - INT64 = 3; - FP16 = 4; - FP32 = 5; - FP64 = 6; - } - - message LodData { repeated int64 lod_data = 1; } - string varname = 1; - // TODO(Yancey1989): reference framework::proto::VarDesc::VarType - VarType type = 2; - // bool persistable is not needed for sending. - // tensor info: - Type data_type = 3; - repeated int64 dims = 4; - - // lod details: - int64 lod_level = 5; - repeated LodData lod = 6; - // selected_rows height, aka. original dim0 - int64 slr_height = 7; - // tensor data - bytes serialized = 8; - // selected_rows data - bytes rows = 9; - // Look up table block execution output variable name. - string out_varname = 10; - // If 1, the ps server will start profiling, the ps - // server stops profiling and generates a profile to /tmp/profile_ps_* - // when profile switches from 1 to 2. - int64 profile = 11; - int64 trainer_id = 12; - string table_name = 13; -} - -message VoidMessage {} diff --git a/paddle/fluid/operators/distributed/sendrecvop_utils.cc b/paddle/fluid/operators/distributed/sendrecvop_utils.cc deleted file mode 100644 index 107c74eb26..0000000000 --- a/paddle/fluid/operators/distributed/sendrecvop_utils.cc +++ /dev/null @@ -1,117 +0,0 @@ -/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ -#include - -#include "paddle/fluid/operators/distributed/sendrecvop_utils.h" - -namespace paddle { -namespace framework { -class Variable; -} // namespace framework -namespace memory { -namespace allocation { -class Allocation; -} // namespace allocation -} // namespace memory -} // namespace paddle - -DEFINE_bool(rpc_disable_reuse_port, false, "Disable SO_REUSEPORT or not."); -DEFINE_int32(rpc_retry_bind_port, 3, - "Retry to bind the address if address is already used."); - -namespace paddle { -namespace operators { -namespace distributed { - -using VarMsg = sendrecv::VariableMessage; - -static TensorPayload GetCommunicationAllocationFromTensor( - const platform::DeviceContext& ctx, const framework::Tensor& tensor) { - if (is_gpu_place(ctx.GetPlace())) { -#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) - PADDLE_ENFORCE_EQ( - is_gpu_place(tensor.place()), true, - platform::errors::PreconditionNotMet("Please run in gpu place.")); - auto& gpu_dev_ctx = - reinterpret_cast(ctx); - auto copy_size = tensor.numel() * framework::SizeOfType(tensor.type()); - platform::CUDAPinnedPlace cuda_pinned; - auto result = memory::AllocShared(cuda_pinned, copy_size); - - memory::Copy(cuda_pinned, result->ptr(), - BOOST_GET_CONST(platform::CUDAPlace, tensor.place()), - tensor.data(), copy_size, gpu_dev_ctx.stream()); - ctx.Wait(); - return TensorPayload(result); -#else - PADDLE_THROW( - platform::errors::Unavailable("This situation should not be happened")); -#endif - } else { - return TensorPayload(tensor); - } -} -TensorPayload GetTensorPayload(framework::Variable* var, - const platform::DeviceContext& ctx, - VarMsg* request) { - auto tensor = var->Get(); - // FIXME(wuyi): data types in send_recv.proto is copied from - // framework.proto - request->set_data_type(static_cast(tensor.type())); - for (auto& dim : framework::vectorize(tensor.dims())) { - request->add_dims(dim); - } - const framework::LoD lod = tensor.lod(); - if (lod.size() > 0) { - request->set_lod_level(lod.size()); - for (auto& each : lod) { - VarMsg::LodData* lod_inner = request->add_lod(); - for (auto& d : each) { - lod_inner->add_lod_data(d); - } - } - } - return GetCommunicationAllocationFromTensor(ctx, tensor); -} - -TensorPayload GetSelectedRowsPayload(framework::Variable* var, - const platform::DeviceContext& ctx, - VarMsg* request) { - auto* slr = var->GetMutable(); - request->set_data_type(static_cast(slr->value().type())); - request->set_lod_level(0); - request->set_slr_height(slr->height()); - - for (auto& dim : framework::vectorize(slr->value().dims())) { - request->add_dims(dim); - } - - auto* tensor = slr->mutable_value(); - return GetCommunicationAllocationFromTensor(ctx, *tensor); -} - -TensorPayload::TensorPayload(std::shared_ptr allocation) - : allocation_(allocation), offset_(0), memory_size_(allocation->size()) {} -TensorPayload::TensorPayload(const framework::Tensor& tensor) - : allocation_(tensor.Holder()), - offset_(tensor.offset()), - memory_size_(tensor.numel() * framework::SizeOfType(tensor.type())) {} -void* TensorPayload::ptr() const { - return reinterpret_cast( - reinterpret_cast(allocation_->ptr()) + offset_); -} -size_t TensorPayload::memory_size() const { return memory_size_; } -} // namespace distributed -} // namespace operators -} // namespace paddle diff --git a/paddle/fluid/operators/distributed/sendrecvop_utils.h b/paddle/fluid/operators/distributed/sendrecvop_utils.h deleted file mode 100644 index 84ed1ab024..0000000000 --- a/paddle/fluid/operators/distributed/sendrecvop_utils.h +++ /dev/null @@ -1,110 +0,0 @@ -/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#pragma once -#include -#include -#include -#include -#include - -#include "paddle/fluid/framework/data_type.h" -#include "paddle/fluid/framework/lod_tensor.h" -#include "paddle/fluid/framework/scope.h" -#include "paddle/fluid/framework/selected_rows.h" -#include "paddle/fluid/framework/tensor_util.h" -#include "paddle/fluid/framework/var_type.h" -#include "paddle/fluid/operators/distributed/distributed_pb.h" -#include "paddle/fluid/platform/port.h" - -namespace paddle { -namespace framework { -class Tensor; -class Variable; -} // namespace framework -namespace memory { -namespace allocation { -class Allocation; -} // namespace allocation -} // namespace memory -namespace platform { -class DeviceContext; -} // namespace platform -} // namespace paddle - -namespace paddle { -namespace operators { -namespace distributed { - -using VarMsg = sendrecv::VariableMessage; - -class TensorPayload final { - public: - explicit TensorPayload(const framework::Tensor& tensor); - explicit TensorPayload(std::shared_ptr allocation); - - TensorPayload(const TensorPayload& o) = default; - TensorPayload& operator=(const TensorPayload& o) = default; - - void* ptr() const; - size_t memory_size() const; - - private: - std::shared_ptr allocation_; - size_t offset_; - size_t memory_size_; -}; - -inline void SerializeDestroyCallback(void* payload) { - if (payload != nullptr) { - auto* shared_payload = reinterpret_cast(payload); - delete shared_payload; - } -} - -TensorPayload GetTensorPayload(framework::Variable* var, - const platform::DeviceContext& ctx, - VarMsg* request); - -TensorPayload GetSelectedRowsPayload(framework::Variable* var, - const platform::DeviceContext& ctx, - VarMsg* request); - -inline framework::proto::VarType::Type ToVarType( - sendrecv::VariableMessage::Type type) { - switch (type) { - case sendrecv::VariableMessage::FP32: - return framework::proto::VarType::FP32; // NOLINT - case sendrecv::VariableMessage::FP64: - return framework::proto::VarType::FP64; // NOLINT - case sendrecv::VariableMessage::INT32: - return framework::proto::VarType::INT32; // NOLINT - case sendrecv::VariableMessage::INT64: - return framework::proto::VarType::INT64; // NOLINT - case sendrecv::VariableMessage::BOOL: - return framework::proto::VarType::BOOL; // NOLINT - default: - PADDLE_THROW( - platform::errors::InvalidArgument("Not support type id: %d.", type)); - } -} - -template