// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/distributed/ps/service/server.h"
#include "glog/logging.h"
#include "paddle/fluid/distributed/ps/service/brpc_ps_server.h"
#include "paddle/fluid/distributed/ps/service/graph_brpc_server.h"
#include "paddle/fluid/distributed/ps/service/ps_local_server.h"
#include "paddle/fluid/distributed/ps/table/table.h"

namespace paddle {
namespace distributed {

// Concrete server classes registered under the PSServer base, so that
// CREATE_PSCORE_CLASS(PSServer, <server_class>) in PSServerFactory::Create
// can instantiate them by class name.
REGISTER_PSCORE_CLASS(PSServer, BrpcPsServer);
REGISTER_PSCORE_CLASS(PSServer, PsLocalServer);
REGISTER_PSCORE_CLASS(PSServer, GraphBrpcServer);

// Concrete service classes registered under the PsBaseService base.
REGISTER_PSCORE_CLASS(PsBaseService, BrpcPsService);
REGISTER_PSCORE_CLASS(PsBaseService, GraphBrpcService);
/// Instantiates the PSServer subclass named by
/// ServerParameter.downpour_server_param.service_param.server_class.
///
/// @param ps_config Parameter-server configuration; only its server_param()
///        section is read.
/// @return The instance produced by CREATE_PSCORE_CLASS, or nullptr when
///         downpour_server_param, service_param or server_class is missing,
///         or when the named class is not registered. Every failure is
///         logged. TableManager::Instance().Initialize() runs only on success.
///         NOTE(review): ownership of the returned pointer is presumably
///         transferred to the caller — confirm against CREATE_PSCORE_CLASS.
PSServer *PSServerFactory::Create(const PSParameter &ps_config) {
  const auto &config = ps_config.server_param();

  if (!config.has_downpour_server_param()) {
    LOG(ERROR) << "miss downpour_server_param in ServerParameter";
    return nullptr;
  }
  const auto &downpour_param = config.downpour_server_param();

  if (!downpour_param.has_service_param()) {
    LOG(ERROR) << "miss service_param in ServerParameter.downpour_server_param";
    return nullptr;
  }
  const auto &service_param = downpour_param.service_param();

  if (!service_param.has_server_class()) {
    LOG(ERROR) << "miss server_class in "
                  "ServerParameter.downpour_server_param.service_param";
    return nullptr;
  }

  // CREATE_PSCORE_CLASS yields null for a class name that was never
  // registered with REGISTER_PSCORE_CLASS.
  PSServer *server =
      CREATE_PSCORE_CLASS(PSServer, service_param.server_class());
  if (server == nullptr) {
    LOG(ERROR) << "server is not registered, server_name:"
               << service_param.server_class();
    return nullptr;
  }

  TableManager::Instance().Initialize();
  return server;
}

int32_t PSServer::Configure(
64 65 66
    const PSParameter &config, PSEnvironment &env, size_t server_rank,
    const std::vector<framework::ProgramDesc> &server_sub_program) {
  scope_.reset(new framework::Scope());
T
tangwei12 已提交
67 68 69
  _config = config.server_param();
  _rank = server_rank;
  _environment = &env;
Z
zhaocaibei123 已提交
70 71
  _shuffled_ins =
      paddle::framework::MakeChannel<std::pair<uint64_t, std::string>>();
Z
zhaocaibei123 已提交
72
  size_t shard_num = env.GetPsServers().size();
T
tangwei12 已提交
73

T
tangwei12 已提交
74 75 76
  const auto &downpour_param = _config.downpour_server_param();

  uint32_t barrier_table = UINT32_MAX;
77
  uint32_t global_step_table = UINT32_MAX;
T
tangwei12 已提交
78 79

  for (size_t i = 0; i < downpour_param.downpour_table_param_size(); ++i) {
T
tangwei12 已提交
80
    auto *table = CREATE_PSCORE_CLASS(
T
tangwei12 已提交
81 82 83 84 85 86
        Table, downpour_param.downpour_table_param(i).table_class());

    if (downpour_param.downpour_table_param(i).table_class() ==
        "BarrierTable") {
      barrier_table = downpour_param.downpour_table_param(i).table_id();
    }
87 88 89 90 91
    if (downpour_param.downpour_table_param(i).table_class() ==
        "GlobalStepTable") {
      global_step_table = downpour_param.downpour_table_param(i).table_id();
    }

Z
zhaocaibei123 已提交
92 93 94
    table->SetProgramEnv(scope_.get(), place_, &server_sub_program);
    table->SetShard(_rank, shard_num);
    table->Initialize(downpour_param.downpour_table_param(i),
T
tangwei12 已提交
95 96 97 98 99
                      config.fs_client_param());
    _table_map[downpour_param.downpour_table_param(i).table_id()].reset(table);
  }

  if (barrier_table != UINT32_MAX) {
Z
zhaocaibei123 已提交
100
    _table_map[barrier_table]->SetTableMap(&_table_map);
T
tangwei12 已提交
101
  }
102
  if (global_step_table != UINT32_MAX) {
Z
zhaocaibei123 已提交
103
    _table_map[global_step_table]->SetTableMap(&_table_map);
104
  }
T
tangwei12 已提交
105

Z
zhaocaibei123 已提交
106
  return Initialize();
T
tangwei12 已提交
107 108 109
}
}  // namespace distributed
}  // namespace paddle