server.cc 3.5 KB
Newer Older
T
tangwei12 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/distributed/service/server.h"
#include "glog/logging.h"
#include "paddle/fluid/distributed/service/brpc_ps_server.h"
#include "paddle/fluid/distributed/table/table.h"

namespace paddle {
namespace distributed {

// Class registrations for the server and service implementations.
// NOTE(review): presumably these entries are what allow
// CREATE_CLASS(PSServer, <name>) in PSServerFactory::create() to resolve a
// class by its string name - confirm against the REGISTER_CLASS macro.
REGISTER_CLASS(PSServer, BrpcPsServer);
REGISTER_CLASS(PsBaseService, PsService);

// Builds the PSServer implementation named by
// ps_config.server_param().downpour_server_param().service_param()
//     .server_class().
//
// Each level of the nested configuration is validated in turn; the first
// missing level is logged and a null pointer is returned. A null pointer is
// also returned (after logging) when CREATE_CLASS yields no instance for the
// requested class name. On success TableManager::instance().initialize() is
// invoked before the newly created server is returned as a raw pointer.
PSServer *PSServerFactory::create(const PSParameter &ps_config) {
  const auto &server_param = ps_config.server_param();
  if (!server_param.has_downpour_server_param()) {
    LOG(ERROR) << "miss downpour_server_param in ServerParameter";
    return nullptr;
  }

  const auto &downpour_param = server_param.downpour_server_param();
  if (!downpour_param.has_service_param()) {
    LOG(ERROR) << "miss service_param in ServerParameter.downpour_server_param";
    return nullptr;
  }

  const auto &service = downpour_param.service_param();
  if (!service.has_server_class()) {
    LOG(ERROR) << "miss server_class in "
                  "ServerParameter.downpour_server_param.service_param";
    return nullptr;
  }

  PSServer *created = CREATE_CLASS(PSServer, service.server_class());
  if (created != nullptr) {
    // Only set up the table manager once a server instance actually exists.
    TableManager::instance().initialize();
    return created;
  }

  LOG(ERROR) << "server is not registered, server_name:"
             << service.server_class();
  return nullptr;
}

56 57 58 59
int32_t PSServer::configure(
    const PSParameter &config, PSEnvironment &env, size_t server_rank,
    const std::vector<framework::ProgramDesc> &server_sub_program) {
  scope_.reset(new framework::Scope());
T
tangwei12 已提交
60 61 62 63 64
  _config = config.server_param();
  _rank = server_rank;
  _environment = &env;
  _shuffled_ins =
      paddle::framework::MakeChannel<std::pair<uint64_t, std::string>>();
T
tangwei12 已提交
65 66
  size_t shard_num = env.get_ps_servers().size();

T
tangwei12 已提交
67 68 69
  const auto &downpour_param = _config.downpour_server_param();

  uint32_t barrier_table = UINT32_MAX;
70
  uint32_t global_step_table = UINT32_MAX;
T
tangwei12 已提交
71 72 73 74 75 76 77 78 79

  for (size_t i = 0; i < downpour_param.downpour_table_param_size(); ++i) {
    auto *table = CREATE_CLASS(
        Table, downpour_param.downpour_table_param(i).table_class());

    if (downpour_param.downpour_table_param(i).table_class() ==
        "BarrierTable") {
      barrier_table = downpour_param.downpour_table_param(i).table_id();
    }
80 81 82 83 84 85
    if (downpour_param.downpour_table_param(i).table_class() ==
        "GlobalStepTable") {
      global_step_table = downpour_param.downpour_table_param(i).table_id();
    }

    table->set_program_env(scope_.get(), place_, &server_sub_program);
T
tangwei12 已提交
86
    table->set_shard(_rank, shard_num);
T
tangwei12 已提交
87 88 89 90 91 92 93 94
    table->initialize(downpour_param.downpour_table_param(i),
                      config.fs_client_param());
    _table_map[downpour_param.downpour_table_param(i).table_id()].reset(table);
  }

  if (barrier_table != UINT32_MAX) {
    _table_map[barrier_table]->set_table_map(&_table_map);
  }
95 96 97
  if (global_step_table != UINT32_MAX) {
    _table_map[global_step_table]->set_table_map(&_table_map);
  }
T
tangwei12 已提交
98 99 100 101 102

  return initialize();
}
}  // namespace distributed
}  // namespace paddle