/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <fcntl.h>

#ifdef _POSIX_C_SOURCE
#undef _POSIX_C_SOURCE
#endif

#ifdef _XOPEN_SOURCE
#undef _XOPEN_SOURCE
#endif

#include "paddle/fluid/pybind/fleet_py.h"

#include <map>
#include <memory>
#include <string>
#include <utility>
#include <vector>

#include "paddle/fluid/distributed/index_dataset/index_sampler.h"
#include "paddle/fluid/distributed/index_dataset/index_wrapper.h"
#include "paddle/fluid/distributed/ps/service/communicator/communicator.h"
#include "paddle/fluid/distributed/ps/service/communicator/communicator_common.h"
#include "paddle/fluid/distributed/ps/service/env.h"
#include "paddle/fluid/distributed/ps/service/graph_brpc_client.h"
#include "paddle/fluid/distributed/ps/service/heter_client.h"
#include "paddle/fluid/distributed/ps/service/ps_service/graph_py_service.h"
#include "paddle/fluid/distributed/ps/wrapper/fleet.h"
#include "paddle/fluid/framework/fleet/heter_ps/graph_gpu_wrapper.h"

// Short alias for pybind11 plus the distributed types bound in this file.
namespace py = pybind11;
using paddle::distributed::CommContext;
using paddle::distributed::Communicator;
using paddle::distributed::FleetWrapper;
using paddle::distributed::HeterClient;
using paddle::distributed::GraphPyService;
using paddle::distributed::GraphNode;
using paddle::distributed::GraphPyServer;
using paddle::distributed::GraphPyClient;
using paddle::distributed::FeatureNode;

namespace paddle {
namespace pybind {
// Registers FleetWrapper with Python under the name `DistFleetWrapper`.
//
// The Python constructor does not build a new object: it returns whatever
// FleetWrapper::GetInstance() yields, held through std::shared_ptr. Every
// other entry maps a snake_case Python method name directly onto the
// FleetWrapper member function of the corresponding purpose.
//
// @param m  pybind11 module that receives the class.
void BindDistFleetWrapper(py::module* m) {
  py::class_<FleetWrapper, std::shared_ptr<FleetWrapper>>(*m,
                                                          "DistFleetWrapper")
      .def(py::init([]() { return FleetWrapper::GetInstance(); }))
      .def("load_sparse", &FleetWrapper::LoadSparseOnServer)
      .def("load_model", &FleetWrapper::LoadModel)
      .def("load_one_table", &FleetWrapper::LoadModelOneTable)
      .def("init_server", &FleetWrapper::InitServer)
      .def("run_server", &FleetWrapper::RunServer)
      .def("init_worker", &FleetWrapper::InitWorker)
      .def("push_dense_params", &FleetWrapper::PushDenseParamSync)
      .def("pull_dense_params", &FleetWrapper::PullDenseVarsSync)
      .def("save_all_model", &FleetWrapper::SaveModel)
      .def("save_one_model", &FleetWrapper::SaveModelOneTable)
      .def("recv_and_save_model", &FleetWrapper::RecvAndSaveTable)
      .def("sparse_table_stat", &FleetWrapper::PrintTableStat)
      .def("stop_server", &FleetWrapper::StopServer)
      .def("stop_worker", &FleetWrapper::FinalizeWorker)
      .def("barrier", &FleetWrapper::BarrierWithTable)
      .def("shrink_sparse_table", &FleetWrapper::ShrinkSparseTable)
      .def("set_clients", &FleetWrapper::SetClients)
      .def("get_client_info", &FleetWrapper::GetClientsInfo)
      .def("create_client2client_connection",
           &FleetWrapper::CreateClient2ClientConnection)
      .def("client_flush", &FleetWrapper::ClientFlush)
      .def("get_cache_threshold", &FleetWrapper::GetCacheThreshold)
      .def("cache_shuffle", &FleetWrapper::CacheShuffle)
      .def("save_cache", &FleetWrapper::SaveCache);
}

// Registers distributed::PSHost with Python as `PSHost`.
//
// The Python constructor takes (str, uint32, uint32); the meaning of the three
// arguments is defined by the PSHost constructor, which is not visible here.
// The remaining methods expose the string / uint64 (de)serialization helpers.
//
// @param m  pybind11 module that receives the class.
void BindPSHost(py::module* m) {
  py::class_<distributed::PSHost>(*m, "PSHost")
      .def(py::init<const std::string&, uint32_t, uint32_t>())
      .def("serialize_to_string", &distributed::PSHost::SerializeToString)
      .def("parse_from_string", &distributed::PSHost::ParseFromString)
      .def("to_uint64", &distributed::PSHost::SerializeToUint64)
      .def("from_uint64", &distributed::PSHost::ParseFromUint64)
      .def("to_string", &distributed::PSHost::ToString);
}

// Registers CommContext with Python as `CommContext`.
//
// The constructor forwards 13 positional arguments to the matching CommContext
// constructor. Each accessor below is a read-only getter that returns a copy
// of one CommContext data member; note that several Python names differ from
// the member they read (e.g. `split_endpoints` -> epmap, `sections` ->
// height_sections, `aggregate` -> merge_add).
//
// @param m  pybind11 module that receives the class.
void BindCommunicatorContext(py::module* m) {
  py::class_<CommContext>(*m, "CommContext")
      .def(
          py::init<const std::string&, const std::vector<std::string>&,
                   const std::vector<std::string>&, const std::vector<int64_t>&,
                   const std::vector<std::string>&, int, bool, bool, bool, int,
                   bool, bool, int64_t>())
      .def("var_name", [](const CommContext& self) { return self.var_name; })
      .def("trainer_id",
           [](const CommContext& self) { return self.trainer_id; })
      .def("table_id", [](const CommContext& self) { return self.table_id; })
      .def("program_id",
           [](const CommContext& self) { return self.program_id; })
      .def("split_varnames",
           [](const CommContext& self) { return self.splited_varnames; })
      .def("split_endpoints",
           [](const CommContext& self) { return self.epmap; })
      .def("sections",
           [](const CommContext& self) { return self.height_sections; })
      .def("aggregate", [](const CommContext& self) { return self.merge_add; })
      .def("is_sparse", [](const CommContext& self) { return self.is_sparse; })
      .def("is_distributed",
           [](const CommContext& self) { return self.is_distributed; })
      .def("origin_varnames",
           [](const CommContext& self) { return self.origin_varnames; })
      .def("is_tensor_table",
           [](const CommContext& self) { return self.is_tensor_table; })
      .def("is_datanorm_table",
           [](const CommContext& self) { return self.is_datanorm_table; })
      .def("__str__", [](const CommContext& self) { return self.print(); });
}

// Communicator implementations, context-map types and Scope referenced by
// BindDistCommunicator.
using paddle::distributed::AsyncCommunicator;
using paddle::distributed::GeoCommunicator;
using paddle::distributed::RecvCtxMap;
using paddle::distributed::RpcCtxMap;
using paddle::distributed::SyncCommunicator;
using paddle::framework::Scope;

// Registers Communicator with Python as `DistCommunicator`.
//
// The Python constructor selects the concrete communicator from `mode`
// ("ASYNC", "SYNC" or "GEO"), initialises it through
// Communicator::InitInstance<T>() and returns
// Communicator::GetInstantcePtr(). Any other mode raises InvalidArgument.
//
// @param m  pybind11 module that receives the class.
void BindDistCommunicator(py::module* m) {
  // Communicator is already used by nccl, change to DistCommunicator
  py::class_<Communicator, std::shared_ptr<Communicator>>(*m,
                                                          "DistCommunicator")
      .def(py::init([](const std::string& mode, const std::string& dist_desc,
                       const std::vector<std::string>& host_sign_list,
                       const RpcCtxMap& send_ctx, const RecvCtxMap& recv_ctx,
                       Scope* param_scope,
                       std::map<std::string, std::string>& envs) {
        if (mode == "ASYNC") {
          Communicator::InitInstance<AsyncCommunicator>(
              send_ctx, recv_ctx, dist_desc, host_sign_list, param_scope, envs);
        } else if (mode == "SYNC") {
          Communicator::InitInstance<SyncCommunicator>(
              send_ctx, recv_ctx, dist_desc, host_sign_list, param_scope, envs);
        } else if (mode == "GEO") {
          Communicator::InitInstance<GeoCommunicator>(
              send_ctx, recv_ctx, dist_desc, host_sign_list, param_scope, envs);
        } else {
          // Report the rejected value so a misconfigured caller can see it.
          PADDLE_THROW(platform::errors::InvalidArgument(
              "Unsupported communicator mode '%s', expected one of ASYNC, "
              "SYNC or GEO.",
              mode));
        }
        return Communicator::GetInstantcePtr();
      }))
      .def("stop", &Communicator::Stop)
      .def("start", &Communicator::Start)
      .def("push_sparse_param", &Communicator::RpcSendSparseParam)
      .def("is_running", &Communicator::IsRunning)
      .def("init_params", &Communicator::InitParams)
      .def("pull_dense", &Communicator::PullDense)
      .def("create_client_to_client_connection",
           &Communicator::CreateC2CConnection)
      .def("get_client_info", &Communicator::GetClientInfo)
      .def("set_clients", &Communicator::SetClients);
}

// Registers HeterClient with Python as `HeterClient`.
//
// The Python constructor forwards (endpoints, previous_endpoints, trainer_id)
// to HeterClient::GetInstance() and wraps the returned pointer; `stop` maps to
// HeterClient::Stop.
//
// @param m  pybind11 module that receives the class.
void BindHeterClient(py::module* m) {
  py::class_<HeterClient, std::shared_ptr<HeterClient>>(*m, "HeterClient")
      .def(py::init([](const std::vector<std::string>& endpoints,
                       const std::vector<std::string>& previous_endpoints,
                       const int& trainer_id) {
        return HeterClient::GetInstance(endpoints, previous_endpoints,
                                        trainer_id);
      }))
      .def("stop", &HeterClient::Stop);
}

// Registers GraphNode with Python as a default-constructible `GraphNode`
// class offering the `get_id` and `get_feature` methods.
void BindGraphNode(py::module* m) {
  py::class_<GraphNode> graph_node(*m, "GraphNode");
  graph_node.def(py::init<>());
  graph_node.def("get_id", &GraphNode::get_id);
  graph_node.def("get_feature", &GraphNode::get_feature);
}
// Registers FeatureNode with Python as a default-constructible `FeatureNode`
// class offering the `get_id` and `get_feature` methods.
//
// NOTE(review): the member pointers are spelled through GraphNode rather than
// FeatureNode. Presumably both members are inherited from a shared base so the
// pointers apply to FeatureNode as well -- confirm against the node headers.
void BindGraphPyFeatureNode(py::module* m) {
  py::class_<FeatureNode> feature_node(*m, "FeatureNode");
  feature_node.def(py::init<>());
  feature_node.def("get_id", &GraphNode::get_id);
  feature_node.def("get_feature", &GraphNode::get_feature);
}

// Registers GraphPyService with Python as `GraphPyService`; only the default
// constructor is exposed.
void BindGraphPyService(py::module* m) {
  py::class_<GraphPyService> service(*m, "GraphPyService");
  service.def(py::init<>());
}

// Registers GraphPyServer with Python as `GraphPyServer`: a default
// constructor plus `start_server`, `set_up` and `add_table_feat_conf`.
void BindGraphPyServer(py::module* m) {
  py::class_<GraphPyServer> server(*m, "GraphPyServer");
  server.def(py::init<>());
  server.def("start_server", &GraphPyServer::start_server);
  server.def("set_up", &GraphPyServer::set_up);
  server.def("add_table_feat_conf", &GraphPyServer::add_table_feat_conf);
}
// Registers GraphPyClient with Python as `GraphPyClient`.
//
// Most methods map one-to-one onto GraphPyClient member functions. Two are
// wrapped in lambdas because they exchange raw feature payloads with Python:
//   - get_node_feat: converts every returned feature into py::bytes, so the
//     payload reaches Python as `bytes` rather than being decoded as text.
//   - set_node_feat: converts the incoming `bytes` values back into
//     std::string before calling GraphPyClient::set_node_feat.
//
// @param m  pybind11 module that receives the class.
void BindGraphPyClient(py::module* m) {
  py::class_<GraphPyClient>(*m, "GraphPyClient")
      .def(py::init<>())
      .def("load_edge_file", &GraphPyClient::load_edge_file)
      .def("load_node_file", &GraphPyClient::load_node_file)
      .def("set_up", &GraphPyClient::set_up)
      .def("add_table_feat_conf", &GraphPyClient::add_table_feat_conf)
      .def("pull_graph_list", &GraphPyClient::pull_graph_list)
      .def("start_client", &GraphPyClient::start_client)
      // Both spellings are bound to the same member function; the misspelled
      // "neighboors" name is retained so existing Python callers keep working.
      .def("batch_sample_neighboors", &GraphPyClient::batch_sample_neighbors)
      .def("batch_sample_neighbors", &GraphPyClient::batch_sample_neighbors)
      // .def("use_neighbors_sample_cache",
      //      &GraphPyClient::use_neighbors_sample_cache)
      .def("remove_graph_node", &GraphPyClient::remove_graph_node)
      .def("random_sample_nodes", &GraphPyClient::random_sample_nodes)
      .def("stop_server", &GraphPyClient::StopServer)
      .def("get_node_feat",
           [](GraphPyClient& self, std::string node_type,
              std::vector<int64_t> node_ids,
              std::vector<std::string> feature_names) {
             const auto feats =
                 self.get_node_feat(node_type, node_ids, feature_names);
             // Mirror the nested layout of `feats`, one py::bytes per entry.
             std::vector<std::vector<py::bytes>> bytes_feats;
             bytes_feats.reserve(feats.size());
             for (const auto& feat_row : feats) {
               std::vector<py::bytes> bytes_row;
               bytes_row.reserve(feat_row.size());
               for (const auto& feat : feat_row) {
                 bytes_row.push_back(py::bytes(feat));
               }
               bytes_feats.push_back(std::move(bytes_row));
             }
             return bytes_feats;
           })
      .def("set_node_feat",
           [](GraphPyClient& self, std::string node_type,
              std::vector<int64_t> node_ids,
              std::vector<std::string> feature_names,
              std::vector<std::vector<py::bytes>> bytes_feats) {
             // Mirror the nested layout of `bytes_feats` as std::string.
             std::vector<std::vector<std::string>> feats;
             feats.reserve(bytes_feats.size());
             for (const auto& bytes_row : bytes_feats) {
               std::vector<std::string> feat_row;
               feat_row.reserve(bytes_row.size());
               for (const auto& raw : bytes_row) {
                 feat_row.push_back(std::string(raw));
               }
               feats.push_back(std::move(feat_row));
             }
             self.set_node_feat(node_type, node_ids, feature_names, feats);
           })
      .def("bind_local_server", &GraphPyClient::bind_local_server);
}

// Tree-index types bound by BindIndexNode, BindTreeIndex and BindIndexWrapper.
using paddle::distributed::TreeIndex;
using paddle::distributed::IndexWrapper;
using paddle::distributed::IndexNode;
// GPU graph types are only referenced when PADDLE_WITH_HETERPS is defined.
#ifdef PADDLE_WITH_HETERPS
using paddle::framework::GraphGpuWrapper;
using paddle::framework::NeighborSampleResult;
#endif

// Registers IndexNode with Python as `IndexNode`: a default constructor and
// three getters (`id`, `is_leaf`, `probability`) that forward to the member
// functions of the same name.
void BindIndexNode(py::module* m) {
  py::class_<IndexNode> index_node(*m, "IndexNode");
  index_node.def(py::init<>());
  index_node.def("id", [](IndexNode& node) { return node.id(); });
  index_node.def("is_leaf", [](IndexNode& node) { return node.is_leaf(); });
  index_node.def("probability",
                 [](IndexNode& node) { return node.probability(); });
}

// Registers TreeIndex with Python as `TreeIndex`.
//
// The Python constructor takes (name, path): it calls insert_tree_index(name,
// path) on IndexWrapper::GetInstancePtr() and returns the result of
// get_tree_index(name) from that same wrapper. All other methods are thin
// forwards to the PascalCase TreeIndex member functions.
void BindTreeIndex(py::module* m) {
  py::class_<TreeIndex, std::shared_ptr<TreeIndex>> tree_index(*m,
                                                               "TreeIndex");
  tree_index.def(py::init([](const std::string name, const std::string path) {
    auto wrapper = IndexWrapper::GetInstancePtr();
    wrapper->insert_tree_index(name, path);
    return wrapper->get_tree_index(name);
  }));
  tree_index.def("height", [](TreeIndex& tree) { return tree.Height(); });
  tree_index.def("branch", [](TreeIndex& tree) { return tree.Branch(); });
  tree_index.def("total_node_nums",
                 [](TreeIndex& tree) { return tree.TotalNodeNums(); });
  tree_index.def("emb_size", [](TreeIndex& tree) { return tree.EmbSize(); });
  tree_index.def("get_all_leafs",
                 [](TreeIndex& tree) { return tree.GetAllLeafs(); });
  tree_index.def("get_nodes",
                 [](TreeIndex& tree, const std::vector<uint64_t>& codes) {
                   return tree.GetNodes(codes);
                 });
  tree_index.def("get_layer_codes", [](TreeIndex& tree, int level) {
    return tree.GetLayerCodes(level);
  });
  tree_index.def(
      "get_ancestor_codes",
      [](TreeIndex& tree, const std::vector<uint64_t>& ids, int level) {
        return tree.GetAncestorCodes(ids, level);
      });
  tree_index.def("get_children_codes",
                 [](TreeIndex& tree, uint64_t ancestor, int level) {
                   return tree.GetChildrenCodes(ancestor, level);
                 });
  tree_index.def("get_travel_codes",
                 [](TreeIndex& tree, uint64_t id, int start_level) {
                   return tree.GetTravelCodes(id, start_level);
                 });
}

// Registers IndexWrapper with Python as `IndexWrapper`.
//
// The Python constructor returns IndexWrapper::GetInstancePtr() instead of
// building a new object; the three methods forward to the member functions of
// the same name.
//
// @param m  pybind11 module that receives the class.
void BindIndexWrapper(py::module* m) {
  py::class_<IndexWrapper, std::shared_ptr<IndexWrapper>>(*m, "IndexWrapper")
      .def(py::init([]() { return IndexWrapper::GetInstancePtr(); }))
      .def("insert_tree_index", &IndexWrapper::insert_tree_index)
      .def("get_tree_index", &IndexWrapper::get_tree_index)
      .def("clear_tree", &IndexWrapper::clear_tree);
}

#ifdef PADDLE_WITH_HETERPS
// Registers NeighborSampleResult with Python as `NeighborSampleResult`: a
// default constructor plus `initialize`.
void BindNeighborSampleResult(py::module* m) {
  py::class_<NeighborSampleResult> sample_result(*m, "NeighborSampleResult");
  sample_result.def(py::init<>());
  sample_result.def("initialize", &NeighborSampleResult::initialize);
}

// Registers GraphGpuWrapper with Python as `GraphGpuWrapper`. Every Python
// method has the same name as the GraphGpuWrapper member function it forwards
// to.
void BindGraphGpuWrapper(py::module* m) {
  py::class_<GraphGpuWrapper> gpu_wrapper(*m, "GraphGpuWrapper");
  gpu_wrapper.def(py::init<>());
  gpu_wrapper.def("test", &GraphGpuWrapper::test);
  gpu_wrapper.def("initialize", &GraphGpuWrapper::initialize);
  gpu_wrapper.def("graph_neighbor_sample",
                  &GraphGpuWrapper::graph_neighbor_sample);
  gpu_wrapper.def("set_device", &GraphGpuWrapper::set_device);
  gpu_wrapper.def("init_service", &GraphGpuWrapper::init_service);
  gpu_wrapper.def("set_up_types", &GraphGpuWrapper::set_up_types);
  gpu_wrapper.def("add_table_feat_conf",
                  &GraphGpuWrapper::add_table_feat_conf);
  gpu_wrapper.def("load_edge_file", &GraphGpuWrapper::load_edge_file);
  gpu_wrapper.def("upload_batch", &GraphGpuWrapper::upload_batch);
  gpu_wrapper.def("load_node_file", &GraphGpuWrapper::load_node_file);
}
#endif

// Sampler types bound by BindIndexSampler.
using paddle::distributed::IndexSampler;
using paddle::distributed::LayerWiseSampler;

// Registers IndexSampler with Python as `IndexSampler`.
//
// The Python constructor takes (mode, name). Only mode "by_layerwise" is
// accepted, which returns IndexSampler::Init<LayerWiseSampler>(name); any
// other mode raises InvalidArgument. The remaining methods forward to the
// member functions of the same name.
//
// @param m  pybind11 module that receives the class.
void BindIndexSampler(py::module* m) {
  py::class_<IndexSampler, std::shared_ptr<IndexSampler>>(*m, "IndexSampler")
      .def(py::init([](const std::string& mode, const std::string& name) {
        // Reject unknown modes up front so the lambda has one return path.
        if (mode != "by_layerwise") {
          PADDLE_THROW(platform::errors::InvalidArgument(
              "Unsupported IndexSampler Type!"));
        }
        return IndexSampler::Init<LayerWiseSampler>(name);
      }))
      .def("init_layerwise_conf", &IndexSampler::init_layerwise_conf)
      .def("init_beamsearch_conf", &IndexSampler::init_beamsearch_conf)
      .def("sample", &IndexSampler::sample);
}
}  // end namespace pybind
}  // namespace paddle