server.cc 3.9 KB
Newer Older
T
tangwei12 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/distributed/service/server.h"
T
tangwei12 已提交
16

T
tangwei12 已提交
17 18
#include "glog/logging.h"
#include "paddle/fluid/distributed/service/brpc_ps_server.h"
S
seemingwang 已提交
19
#include "paddle/fluid/distributed/service/graph_brpc_server.h"
T
Thunderbrook 已提交
20
#include "paddle/fluid/distributed/service/ps_local_server.h"
T
tangwei12 已提交
21 22 23 24 25
#include "paddle/fluid/distributed/table/table.h"

namespace paddle {
namespace distributed {

T
tangwei12 已提交
26
// Server implementations registered under the PSServer base; they are later
// instantiated by name via CREATE_PSCORE_CLASS(PSServer, <server_class>).
REGISTER_PSCORE_CLASS(PSServer, BrpcPsServer);
REGISTER_PSCORE_CLASS(PSServer, PsLocalServer);
REGISTER_PSCORE_CLASS(PSServer, GraphBrpcServer);

// Service implementations registered under the PsBaseService base.
REGISTER_PSCORE_CLASS(PsBaseService, BrpcPsService);
REGISTER_PSCORE_CLASS(PsBaseService, GraphBrpcService);
T
tangwei12 已提交
31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51

/// Creates the server whose class name is given by
/// `server_param.downpour_server_param.service_param.server_class`.
///
/// Every level of that nested path is checked for presence first, so a
/// misconfigured PSParameter is reported with the exact missing field.
///
/// @param ps_config Full PS configuration; only `server_param()` is read.
/// @return The server produced by CREATE_PSCORE_CLASS, or nullptr when a
///         required field is missing or the class is not registered.
///         NOTE(review): presumably the caller takes ownership — confirm.
PSServer *PSServerFactory::create(const PSParameter &ps_config) {
  const auto &config = ps_config.server_param();

  if (!config.has_downpour_server_param()) {
    LOG(ERROR) << "miss downpour_server_param in ServerParameter";
    return nullptr;
  }
  const auto &downpour_param = config.downpour_server_param();

  if (!downpour_param.has_service_param()) {
    LOG(ERROR) << "miss service_param in ServerParameter.downpour_server_param";
    return nullptr;
  }
  const auto &service_param = downpour_param.service_param();

  if (!service_param.has_server_class()) {
    LOG(ERROR) << "miss server_class in "
                  "ServerParameter.downpour_server_param.service_param";
    return nullptr;
  }

  PSServer *server =
      CREATE_PSCORE_CLASS(PSServer, service_param.server_class());
  if (server == nullptr) {
    LOG(ERROR) << "server is not registered, server_name:"
               << service_param.server_class();
    return nullptr;
  }

  // The table manager is initialized only after a server was created.
  TableManager::instance().initialize();
  return server;
}

63 64 65 66
int32_t PSServer::configure(
    const PSParameter &config, PSEnvironment &env, size_t server_rank,
    const std::vector<framework::ProgramDesc> &server_sub_program) {
  scope_.reset(new framework::Scope());
T
tangwei12 已提交
67 68 69 70 71
  _config = config.server_param();
  _rank = server_rank;
  _environment = &env;
  _shuffled_ins =
      paddle::framework::MakeChannel<std::pair<uint64_t, std::string>>();
T
tangwei12 已提交
72 73
  size_t shard_num = env.get_ps_servers().size();

T
tangwei12 已提交
74 75 76
  const auto &downpour_param = _config.downpour_server_param();

  uint32_t barrier_table = UINT32_MAX;
77
  uint32_t global_step_table = UINT32_MAX;
T
tangwei12 已提交
78 79

  for (size_t i = 0; i < downpour_param.downpour_table_param_size(); ++i) {
T
tangwei12 已提交
80
    auto *table = CREATE_PSCORE_CLASS(
T
tangwei12 已提交
81 82 83 84 85 86
        Table, downpour_param.downpour_table_param(i).table_class());

    if (downpour_param.downpour_table_param(i).table_class() ==
        "BarrierTable") {
      barrier_table = downpour_param.downpour_table_param(i).table_id();
    }
87 88 89 90 91 92
    if (downpour_param.downpour_table_param(i).table_class() ==
        "GlobalStepTable") {
      global_step_table = downpour_param.downpour_table_param(i).table_id();
    }

    table->set_program_env(scope_.get(), place_, &server_sub_program);
T
tangwei12 已提交
93
    table->set_shard(_rank, shard_num);
T
tangwei12 已提交
94 95 96 97 98 99 100 101
    table->initialize(downpour_param.downpour_table_param(i),
                      config.fs_client_param());
    _table_map[downpour_param.downpour_table_param(i).table_id()].reset(table);
  }

  if (barrier_table != UINT32_MAX) {
    _table_map[barrier_table]->set_table_map(&_table_map);
  }
102 103 104
  if (global_step_table != UINT32_MAX) {
    _table_map[global_step_table]->set_table_map(&_table_map);
  }
T
tangwei12 已提交
105 106 107 108 109

  return initialize();
}
}  // namespace distributed
}  // namespace paddle