fleet_py.cc 18.8 KB
Newer Older
T
tangwei12 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <fcntl.h>

#ifdef _POSIX_C_SOURCE
#undef _POSIX_C_SOURCE
#endif

#ifdef _XOPEN_SOURCE
#undef _XOPEN_SOURCE
#endif

#include <map>
#include <memory>
#include <string>
#include <vector>

1
123malin 已提交
26 27
#include "paddle/fluid/distributed/index_dataset/index_sampler.h"
#include "paddle/fluid/distributed/index_dataset/index_wrapper.h"
28 29 30 31 32 33 34
#include "paddle/fluid/distributed/ps/service/communicator/communicator.h"
#include "paddle/fluid/distributed/ps/service/communicator/communicator_common.h"
#include "paddle/fluid/distributed/ps/service/env.h"
#include "paddle/fluid/distributed/ps/service/graph_brpc_client.h"
#include "paddle/fluid/distributed/ps/service/heter_client.h"
#include "paddle/fluid/distributed/ps/service/ps_service/graph_py_service.h"
#include "paddle/fluid/distributed/ps/wrapper/fleet.h"
35
#include "paddle/fluid/framework/fleet/heter_ps/graph_gpu_wrapper.h"
36
#include "paddle/fluid/pybind/fleet_py.h"
T
tangwei12 已提交
37 38 39 40

namespace py = pybind11;
using paddle::distributed::CommContext;
using paddle::distributed::Communicator;
41
using paddle::distributed::FeatureNode;
T
tangwei12 已提交
42
using paddle::distributed::FleetWrapper;
S
seemingwang 已提交
43 44
using paddle::distributed::GraphNode;
using paddle::distributed::GraphPyClient;
45 46 47
using paddle::distributed::GraphPyServer;
using paddle::distributed::GraphPyService;
using paddle::distributed::HeterClient;
T
tangwei12 已提交
48 49 50 51 52 53 54 55

namespace paddle {
namespace pybind {
// Registers FleetWrapper with Python under the name "DistFleetWrapper".
//
// The Python constructor does not build a new FleetWrapper: it returns
// whatever FleetWrapper::GetInstance() yields, held by std::shared_ptr.
// Every other method forwards directly to the FleetWrapper member written
// next to it. The Python-visible names are the public interface of this
// binding and must stay stable even where they differ from the C++ names.
void BindDistFleetWrapper(py::module* m) {
  py::class_<FleetWrapper, std::shared_ptr<FleetWrapper>> fleet_cls(
      *m, "DistFleetWrapper");
  fleet_cls.def(py::init([]() { return FleetWrapper::GetInstance(); }));

  // Loading and server start-up.
  fleet_cls.def("load_sparse", &FleetWrapper::LoadSparseOnServer);
  fleet_cls.def("load_model", &FleetWrapper::LoadModel);
  fleet_cls.def("load_one_table", &FleetWrapper::LoadModelOneTable);
  fleet_cls.def("init_server", &FleetWrapper::InitServer);
  fleet_cls.def("run_server", &FleetWrapper::RunServer);

  // Worker start-up and dense parameter push/pull.
  fleet_cls.def("init_worker", &FleetWrapper::InitWorker);
  fleet_cls.def("push_dense_params", &FleetWrapper::PushDenseParamSync);
  fleet_cls.def("pull_dense_params", &FleetWrapper::PullDenseVarsSync);

  // Saving.
  fleet_cls.def("save_all_model", &FleetWrapper::SaveModel);
  fleet_cls.def("save_one_model", &FleetWrapper::SaveModelOneTable);
  fleet_cls.def("recv_and_save_model", &FleetWrapper::RecvAndSaveTable);

  // Table statistics, shutdown and barrier.
  fleet_cls.def("sparse_table_stat", &FleetWrapper::PrintTableStat);
  fleet_cls.def("stop_server", &FleetWrapper::StopServer);
  fleet_cls.def("stop_worker", &FleetWrapper::FinalizeWorker);
  fleet_cls.def("barrier", &FleetWrapper::BarrierWithTable);
  fleet_cls.def("shrink_sparse_table", &FleetWrapper::ShrinkSparseTable);

  // Client bookkeeping and client-to-client connections.
  fleet_cls.def("set_clients", &FleetWrapper::SetClients);
  fleet_cls.def("get_client_info", &FleetWrapper::GetClientsInfo);
  fleet_cls.def("create_client2client_connection",
                &FleetWrapper::CreateClient2ClientConnection);
  fleet_cls.def("client_flush", &FleetWrapper::ClientFlush);

  // Cache handling.
  fleet_cls.def("get_cache_threshold", &FleetWrapper::GetCacheThreshold);
  fleet_cls.def("cache_shuffle", &FleetWrapper::CacheShuffle);
  fleet_cls.def("save_cache", &FleetWrapper::SaveCache);

  // FL-related entry points.
  fleet_cls.def("init_fl_worker", &FleetWrapper::InitFlWorker);
  fleet_cls.def("push_fl_client_info_sync",
                &FleetWrapper::PushFLClientInfoSync);
  fleet_cls.def("pull_fl_strategy", &FleetWrapper::PullFlStrategy);

  // Patch-related helpers.
  fleet_cls.def("revert", &FleetWrapper::Revert);
  fleet_cls.def("check_save_pre_patch_done",
                &FleetWrapper::CheckSavePrePatchDone);
}
T
tangwei12 已提交
85 86 87 88

// Registers distributed::PSHost with Python as "PSHost".
//
// The constructor takes (std::string, uint32_t, uint32_t). NOTE(review):
// presumably (ip, port, rank); confirm against PSHost. The remaining methods
// expose PSHost's string / uint64 (de)serialization helpers and ToString().
void BindPSHost(py::module* m) {
  using distributed::PSHost;
  py::class_<PSHost> host_cls(*m, "PSHost");
  host_cls.def(py::init<const std::string&, uint32_t, uint32_t>());
  host_cls.def("serialize_to_string", &PSHost::SerializeToString);
  host_cls.def("parse_from_string", &PSHost::ParseFromString);
  host_cls.def("to_uint64", &PSHost::SerializeToUint64);
  host_cls.def("from_uint64", &PSHost::ParseFromUint64);
  host_cls.def("to_string", &PSHost::ToString);
}

// Registers CommContext with Python.
//
// The constructor mirrors CommContext's C++ constructor argument for
// argument, so the type list below must stay in the same order as that
// constructor. Every other method is a read-only accessor that returns a copy
// of one public CommContext field. Several Python names differ from the field
// they read: "split_varnames" -> splited_varnames, "split_endpoints" -> epmap,
// "sections" -> height_sections, "aggregate" -> merge_add. "__str__" returns
// CommContext::print().
void BindCommunicatorContext(py::module* m) {
  py::class_<CommContext> ctx_cls(*m, "CommContext");
  ctx_cls.def(py::init<const std::string&,
                       const std::vector<std::string>&,
                       const std::vector<std::string>&,
                       const std::vector<int64_t>&,
                       const std::vector<std::string>&,
                       int,
                       bool,
                       bool,
                       bool,
                       int,
                       bool,
                       bool,
                       int64_t,
                       const std::vector<int32_t>&>());

  ctx_cls.def("var_name",
              [](const CommContext& ctx) { return ctx.var_name; });
  ctx_cls.def("remote_sparse_ids",
              [](const CommContext& ctx) { return ctx.remote_sparse_ids; });
  ctx_cls.def("trainer_id",
              [](const CommContext& ctx) { return ctx.trainer_id; });
  ctx_cls.def("table_id",
              [](const CommContext& ctx) { return ctx.table_id; });
  ctx_cls.def("program_id",
              [](const CommContext& ctx) { return ctx.program_id; });
  ctx_cls.def("split_varnames",
              [](const CommContext& ctx) { return ctx.splited_varnames; });
  ctx_cls.def("split_endpoints",
              [](const CommContext& ctx) { return ctx.epmap; });
  ctx_cls.def("sections",
              [](const CommContext& ctx) { return ctx.height_sections; });
  ctx_cls.def("aggregate",
              [](const CommContext& ctx) { return ctx.merge_add; });
  ctx_cls.def("is_sparse",
              [](const CommContext& ctx) { return ctx.is_sparse; });
  ctx_cls.def("is_distributed",
              [](const CommContext& ctx) { return ctx.is_distributed; });
  ctx_cls.def("origin_varnames",
              [](const CommContext& ctx) { return ctx.origin_varnames; });
  ctx_cls.def("is_tensor_table",
              [](const CommContext& ctx) { return ctx.is_tensor_table; });
  ctx_cls.def("is_datanorm_table",
              [](const CommContext& ctx) { return ctx.is_datanorm_table; });
  ctx_cls.def("__str__", [](const CommContext& ctx) { return ctx.print(); });
}

using paddle::distributed::AsyncCommunicator;
140
using paddle::distributed::FLCommunicator;
T
tangwei12 已提交
141 142 143 144 145 146 147 148 149 150
using paddle::distributed::GeoCommunicator;
using paddle::distributed::RecvCtxMap;
using paddle::distributed::RpcCtxMap;
using paddle::distributed::SyncCommunicator;
using paddle::framework::Scope;

// Registers the distributed Communicator with Python as "DistCommunicator"
// (the name "Communicator" is already used by the NCCL binding).
//
// The Python constructor selects a concrete communicator type from `mode`:
//   "ASYNC"            -> AsyncCommunicator
//   "SYNC"             -> SyncCommunicator
//   "GEO"              -> GeoCommunicator
//   "WITH_COORDINATOR" -> FLCommunicator
// It initializes it through the static Communicator::InitInstance<T>() and
// returns Communicator::GetInstantcePtr(). NOTE(review): this presumably makes
// every construction return the same shared instance; confirm in
// communicator.h. Any other mode throws InvalidArgument naming the rejected
// mode and the accepted values.
void BindDistCommunicator(py::module* m) {
  // Communicator is already used by nccl, change to DistCommunicator
  py::class_<Communicator, std::shared_ptr<Communicator>>(*m,
                                                          "DistCommunicator")
      .def(py::init([](const std::string& mode,
                       const std::string& dist_desc,
                       const std::vector<std::string>& host_sign_list,
                       const RpcCtxMap& send_ctx,
                       const RecvCtxMap& recv_ctx,
                       Scope* param_scope,
                       std::map<std::string, std::string>& envs) {
        // All implementations are initialized with identical arguments; only
        // the template type depends on `mode`.
        if (mode == "ASYNC") {
          Communicator::InitInstance<AsyncCommunicator>(
              send_ctx, recv_ctx, dist_desc, host_sign_list, param_scope, envs);
        } else if (mode == "SYNC") {
          Communicator::InitInstance<SyncCommunicator>(
              send_ctx, recv_ctx, dist_desc, host_sign_list, param_scope, envs);
        } else if (mode == "GEO") {
          Communicator::InitInstance<GeoCommunicator>(
              send_ctx, recv_ctx, dist_desc, host_sign_list, param_scope, envs);
        } else if (mode == "WITH_COORDINATOR") {
          Communicator::InitInstance<FLCommunicator>(
              send_ctx, recv_ctx, dist_desc, host_sign_list, param_scope, envs);
        } else {
          // Report the rejected value and the accepted set so a
          // misconfigured mode can be diagnosed from the error alone.
          PADDLE_THROW(platform::errors::InvalidArgument(
              "Unsupported communicator mode: %s. Expected one of ASYNC, "
              "SYNC, GEO or WITH_COORDINATOR.",
              mode));
        }
        return Communicator::GetInstantcePtr();
      }))
      .def("stop", &Communicator::Stop)
      .def("start", &Communicator::Start)
      .def("push_sparse_param", &Communicator::RpcSendSparseParam)
      .def("is_running", &Communicator::IsRunning)
      .def("init_params", &Communicator::InitParams)
      .def("pull_dense", &Communicator::PullDense)
      .def("create_client_to_client_connection",
           &Communicator::CreateC2CConnection)
      .def("get_client_info", &Communicator::GetClientInfo)
      .def("set_clients", &Communicator::SetClients)
      .def("start_coordinator", &Communicator::StartCoordinator)
      .def("query_fl_clients_info", &Communicator::QueryFLClientsInfo)
      .def("save_fl_strategy", &Communicator::SaveFLStrategy);
}

// Registers HeterClient with Python.
//
// There is no plain constructor: constructing the Python object calls
// HeterClient::GetInstance(endpoints, previous_endpoints, trainer_id) and
// wraps whatever that factory returns in the std::shared_ptr holder.
// "stop" forwards to HeterClient::Stop.
void BindHeterClient(py::module* m) {
  py::class_<HeterClient, std::shared_ptr<HeterClient>> client_cls(
      *m, "HeterClient");
  client_cls.def(py::init([](const std::vector<std::string>& endpoints,
                             const std::vector<std::string>& previous_endpoints,
                             const int& trainer_id) {
    return HeterClient::GetInstance(endpoints, previous_endpoints, trainer_id);
  }));
  client_cls.def("stop", &HeterClient::Stop);
}

S
seemingwang 已提交
202 203 204
// Registers GraphNode with Python. It is default-constructible. "get_id"
// forwards to get_py_id() and "get_feature" forwards to get_feature().
void BindGraphNode(py::module* m) {
  py::class_<GraphNode> node_cls(*m, "GraphNode");
  node_cls.def(py::init<>());
  node_cls.def("get_id", &GraphNode::get_py_id);
  node_cls.def("get_feature", &GraphNode::get_feature);
}
// Registers FeatureNode with Python as "FeatureNode". It is
// default-constructible and has "get_id" and "get_feature" methods.
//
// NOTE(review): both methods are bound through GraphNode member pointers on
// the FeatureNode class. pybind11 only accepts a member pointer whose class
// is an accessible base of the bound type. This therefore appears to compile
// only because get_py_id/get_feature are declared in a base shared by
// GraphNode and FeatureNode. Consider spelling them via FeatureNode for
// clarity once that is confirmed.
void BindGraphPyFeatureNode(py::module* m) {
  py::class_<FeatureNode> feature_cls(*m, "FeatureNode");
  feature_cls.def(py::init<>());
  feature_cls.def("get_id", &GraphNode::get_py_id);
  feature_cls.def("get_feature", &GraphNode::get_feature);
}

// Registers GraphPyService with Python, exposing only its default
// constructor.
void BindGraphPyService(py::module* m) {
  py::class_<GraphPyService> service_cls(*m, "GraphPyService");
  service_cls.def(py::init<>());
}

// Registers GraphPyServer with Python. It has a default constructor plus
// start_server, set_up and add_table_feat_conf, each forwarding to the
// GraphPyServer member of the same name.
void BindGraphPyServer(py::module* m) {
  py::class_<GraphPyServer> server_cls(*m, "GraphPyServer");
  server_cls.def(py::init<>());
  server_cls.def("start_server", &GraphPyServer::start_server);
  server_cls.def("set_up", &GraphPyServer::set_up);
  server_cls.def("add_table_feat_conf", &GraphPyServer::add_table_feat_conf);
}
// Registers GraphPyClient with Python.
//
// Most methods forward directly to the GraphPyClient member of the same
// name. "batch_sample_neighboors" (misspelled) is registered as an alias of
// "batch_sample_neighbors"; it is presumably kept for backward compatibility
// with existing Python callers.
//
// Node features cross the boundary as Python `bytes`, not `str`. pybind11
// converts std::string to `str` by UTF-8 decoding, which would fail or
// corrupt non-UTF-8 payloads.
//   get_node_feat(node_type, node_ids, feature_names)
//       returns the nested feature table from GraphPyClient::get_node_feat,
//       with every entry wrapped in py::bytes.
//   set_node_feat(node_type, node_ids, feature_names, bytes_feats)
//       converts every py::bytes entry to std::string, then calls
//       GraphPyClient::set_node_feat.
void BindGraphPyClient(py::module* m) {
  py::class_<GraphPyClient>(*m, "GraphPyClient")
      .def(py::init<>())
      .def("load_edge_file", &GraphPyClient::load_edge_file)
      .def("load_node_file", &GraphPyClient::load_node_file)
      .def("set_up", &GraphPyClient::set_up)
      .def("add_table_feat_conf", &GraphPyClient::add_table_feat_conf)
      .def("pull_graph_list", &GraphPyClient::pull_graph_list)
      .def("start_client", &GraphPyClient::start_client)
      .def("batch_sample_neighboors", &GraphPyClient::batch_sample_neighbors)
      .def("batch_sample_neighbors", &GraphPyClient::batch_sample_neighbors)
      // .def("use_neighbors_sample_cache",
      //      &GraphPyClient::use_neighbors_sample_cache)
      .def("remove_graph_node", &GraphPyClient::remove_graph_node)
      .def("random_sample_nodes", &GraphPyClient::random_sample_nodes)
      .def("stop_server", &GraphPyClient::StopServer)
      .def("get_node_feat",
           [](GraphPyClient& self,
              std::string node_type,
              std::vector<int64_t> node_ids,
              std::vector<std::string> feature_names) {
             auto feats =
                 self.get_node_feat(node_type, node_ids, feature_names);
             std::vector<std::vector<py::bytes>> bytes_feats(feats.size());
             for (size_t i = 0; i < feats.size(); ++i) {
               // Size each row once instead of growing it per element.
               bytes_feats[i].reserve(feats[i].size());
               for (size_t j = 0; j < feats[i].size(); ++j) {
                 bytes_feats[i].push_back(py::bytes(feats[i][j]));
               }
             }
             return bytes_feats;
           })
      .def("set_node_feat",
           [](GraphPyClient& self,
              std::string node_type,
              std::vector<int64_t> node_ids,
              std::vector<std::string> feature_names,
              std::vector<std::vector<py::bytes>> bytes_feats) {
             std::vector<std::vector<std::string>> feats(bytes_feats.size());
             for (size_t i = 0; i < bytes_feats.size(); ++i) {
               // Size each row once instead of growing it per element.
               feats[i].reserve(bytes_feats[i].size());
               for (size_t j = 0; j < bytes_feats[i].size(); ++j) {
                 feats[i].push_back(std::string(bytes_feats[i][j]));
               }
             }
             self.set_node_feat(node_type, node_ids, feature_names, feats);
           })
      .def("bind_local_server", &GraphPyClient::bind_local_server);
}

1
123malin 已提交
275
using paddle::distributed::IndexNode;
276 277
using paddle::distributed::IndexWrapper;
using paddle::distributed::TreeIndex;
278 279
#ifdef PADDLE_WITH_HETERPS
using paddle::framework::GraphGpuWrapper;
280
using paddle::framework::NeighborSampleQuery;
281
using paddle::framework::NeighborSampleResult;
282
using paddle::framework::NodeQueryResult;
283
#endif
1
123malin 已提交
284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333

// Registers IndexNode with Python. It is default-constructible. "id",
// "is_leaf" and "probability" call the IndexNode member of the same name.
void BindIndexNode(py::module* m) {
  py::class_<IndexNode> node_cls(*m, "IndexNode");
  node_cls.def(py::init<>());
  node_cls.def("id", [](IndexNode& node) { return node.id(); });
  node_cls.def("is_leaf", [](IndexNode& node) { return node.is_leaf(); });
  node_cls.def("probability",
               [](IndexNode& node) { return node.probability(); });
}

// Registers TreeIndex with Python.
//
// Constructing TreeIndex(name, path) from Python works in two steps:
//   1. Call insert_tree_index(name, path) on the instance returned by
//      IndexWrapper::GetInstancePtr().
//   2. Return that wrapper's get_tree_index(name).
// The remaining methods forward to the corresponding TreeIndex accessors.
void BindTreeIndex(py::module* m) {
  py::class_<TreeIndex, std::shared_ptr<TreeIndex>>(*m, "TreeIndex")
      // Take the strings by const reference; by-value `const std::string`
      // parameters copied each argument for no benefit.
      .def(py::init([](const std::string& name, const std::string& path) {
        auto index_wrapper = IndexWrapper::GetInstancePtr();
        index_wrapper->insert_tree_index(name, path);
        return index_wrapper->get_tree_index(name);
      }))
      .def("height", [](TreeIndex& self) { return self.Height(); })
      .def("branch", [](TreeIndex& self) { return self.Branch(); })
      .def("total_node_nums",
           [](TreeIndex& self) { return self.TotalNodeNums(); })
      .def("emb_size", [](TreeIndex& self) { return self.EmbSize(); })
      .def("get_all_leafs", [](TreeIndex& self) { return self.GetAllLeafs(); })
      .def("get_nodes",
           [](TreeIndex& self, const std::vector<uint64_t>& codes) {
             return self.GetNodes(codes);
           })
      .def("get_layer_codes",
           [](TreeIndex& self, int level) { return self.GetLayerCodes(level); })
      .def("get_ancestor_codes",
           [](TreeIndex& self, const std::vector<uint64_t>& ids, int level) {
             return self.GetAncestorCodes(ids, level);
           })
      .def("get_children_codes",
           [](TreeIndex& self, uint64_t ancestor, int level) {
             return self.GetChildrenCodes(ancestor, level);
           })
      .def("get_travel_codes",
           [](TreeIndex& self, uint64_t id, int start_level) {
             return self.GetTravelCodes(id, start_level);
           });
}

// Registers IndexWrapper with Python. The Python constructor returns
// IndexWrapper::GetInstancePtr() instead of constructing a new wrapper.
// insert_tree_index, get_tree_index and clear_tree forward to the members of
// the same name.
void BindIndexWrapper(py::module* m) {
  py::class_<IndexWrapper, std::shared_ptr<IndexWrapper>> wrapper_cls(
      *m, "IndexWrapper");
  wrapper_cls.def(py::init([]() { return IndexWrapper::GetInstancePtr(); }));
  wrapper_cls.def("insert_tree_index", &IndexWrapper::insert_tree_index);
  wrapper_cls.def("get_tree_index", &IndexWrapper::get_tree_index);
  wrapper_cls.def("clear_tree", &IndexWrapper::clear_tree);
}

334
#ifdef PADDLE_WITH_HETERPS
335 336 337 338 339 340 341 342 343 344 345 346 347 348 349
// Registers NodeQueryResult with Python. It is default-constructible, and
// each method forwards to the NodeQueryResult member of the same name.
void BindNodeQueryResult(py::module* m) {
  py::class_<NodeQueryResult> result_cls(*m, "NodeQueryResult");
  result_cls.def(py::init<>());
  result_cls.def("initialize", &NodeQueryResult::initialize);
  result_cls.def("display", &NodeQueryResult::display);
  result_cls.def("get_val", &NodeQueryResult::get_val);
  result_cls.def("get_len", &NodeQueryResult::get_len);
}
// Registers NeighborSampleQuery with Python. It is default-constructible
// and exposes "initialize" and "display", forwarding to the same-named
// members.
void BindNeighborSampleQuery(py::module* m) {
  py::class_<NeighborSampleQuery> query_cls(*m, "NeighborSampleQuery");
  query_cls.def(py::init<>());
  query_cls.def("initialize", &NeighborSampleQuery::initialize);
  query_cls.def("display", &NeighborSampleQuery::display);
}

350 351 352
// Registers NeighborSampleResult with Python. It is default-constructible.
// Methods forward to same-named members, except that Python's "get_val" is
// bound to get_actual_val().
void BindNeighborSampleResult(py::module* m) {
  py::class_<NeighborSampleResult> result_cls(*m, "NeighborSampleResult");
  result_cls.def(py::init<>());
  result_cls.def("initialize", &NeighborSampleResult::initialize);
  result_cls.def("get_len", &NeighborSampleResult::get_len);
  result_cls.def("get_val", &NeighborSampleResult::get_actual_val);
  result_cls.def("get_sampled_graph",
                 &NeighborSampleResult::get_sampled_graph);
  result_cls.def("display", &NeighborSampleResult::display);
}

// Registers GraphGpuWrapper with Python.
//
// The Python constructor returns GraphGpuWrapper::GetInstance() rather than
// building a new wrapper. "neighbor_sample" is bound to
// graph_neighbor_sample_v3. "graph_neighbor_sample", "upload_batch" and
// "get_all_id" each have two C++ overloads, selected with py::overload_cast.
// pybind11 tries overloads of one Python name in registration order, so the
// relative order of each pair below is kept as-is.
void BindGraphGpuWrapper(py::module* m) {
  py::class_<GraphGpuWrapper, std::shared_ptr<GraphGpuWrapper>> gpu_graph(
      *m, "GraphGpuWrapper");
  gpu_graph.def(py::init([]() { return GraphGpuWrapper::GetInstance(); }));

  gpu_graph.def("neighbor_sample", &GraphGpuWrapper::graph_neighbor_sample_v3);
  gpu_graph.def("graph_neighbor_sample",
                py::overload_cast<int, uint64_t*, int, int>(
                    &GraphGpuWrapper::graph_neighbor_sample));
  gpu_graph.def("graph_neighbor_sample",
                py::overload_cast<int, int, std::vector<uint64_t>&, int>(
                    &GraphGpuWrapper::graph_neighbor_sample));

  gpu_graph.def("set_device", &GraphGpuWrapper::set_device);
  gpu_graph.def("set_feature_separator",
                &GraphGpuWrapper::set_feature_separator);
  gpu_graph.def("init_service", &GraphGpuWrapper::init_service);
  gpu_graph.def("set_up_types", &GraphGpuWrapper::set_up_types);
  gpu_graph.def("query_node_list", &GraphGpuWrapper::query_node_list);
  gpu_graph.def("add_table_feat_conf", &GraphGpuWrapper::add_table_feat_conf);
  gpu_graph.def("load_edge_file", &GraphGpuWrapper::load_edge_file);
  gpu_graph.def("load_node_and_edge", &GraphGpuWrapper::load_node_and_edge);

  gpu_graph.def("upload_batch",
                py::overload_cast<int, int, int, const std::string&>(
                    &GraphGpuWrapper::upload_batch));
  gpu_graph.def(
      "upload_batch",
      py::overload_cast<int, int, int>(&GraphGpuWrapper::upload_batch));

  gpu_graph.def(
      "get_all_id",
      py::overload_cast<int, int, int, std::vector<std::vector<uint64_t>>*>(
          &GraphGpuWrapper::get_all_id));
  gpu_graph.def(
      "get_all_id",
      py::overload_cast<int, int, std::vector<std::vector<uint64_t>>*>(
          &GraphGpuWrapper::get_all_id));

  gpu_graph.def("load_next_partition", &GraphGpuWrapper::load_next_partition);
  gpu_graph.def("make_partitions", &GraphGpuWrapper::make_partitions);
  gpu_graph.def("make_complementary_graph",
                &GraphGpuWrapper::make_complementary_graph);
  gpu_graph.def("set_search_level", &GraphGpuWrapper::set_search_level);
  gpu_graph.def("init_search_level", &GraphGpuWrapper::init_search_level);
  gpu_graph.def("get_partition_num", &GraphGpuWrapper::get_partition_num);
  gpu_graph.def("get_partition", &GraphGpuWrapper::get_partition);
  gpu_graph.def("load_node_weight", &GraphGpuWrapper::load_node_weight);
  gpu_graph.def("export_partition_files",
                &GraphGpuWrapper::export_partition_files);
  gpu_graph.def("load_node_file", &GraphGpuWrapper::load_node_file);
  gpu_graph.def("finalize", &GraphGpuWrapper::finalize);
}
#endif

1
123malin 已提交
406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422
using paddle::distributed::IndexSampler;
using paddle::distributed::LayerWiseSampler;

// Registers IndexSampler with Python.
//
// The Python constructor is a factory keyed on `mode`. Only "by_layerwise"
// is supported; it returns IndexSampler::Init<LayerWiseSampler>(name). Any
// other mode throws InvalidArgument naming the rejected mode.
// init_layerwise_conf, init_beamsearch_conf and sample forward to the
// IndexSampler members of the same name.
void BindIndexSampler(py::module* m) {
  py::class_<IndexSampler, std::shared_ptr<IndexSampler>>(*m, "IndexSampler")
      .def(py::init([](const std::string& mode, const std::string& name) {
        if (mode != "by_layerwise") {
          // Include the offending value so misconfiguration is diagnosable
          // from the error message alone.
          PADDLE_THROW(platform::errors::InvalidArgument(
              "Unsupported IndexSampler type: %s. Only \"by_layerwise\" is "
              "supported.",
              mode));
        }
        return IndexSampler::Init<LayerWiseSampler>(name);
      }))
      .def("init_layerwise_conf", &IndexSampler::init_layerwise_conf)
      .def("init_beamsearch_conf", &IndexSampler::init_beamsearch_conf)
      .def("sample", &IndexSampler::sample);
}
T
tangwei12 已提交
423 424
}  // end namespace pybind
}  // namespace paddle