gloo_wrapper.cc 12.4 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
  http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/framework/fleet/gloo_wrapper.h"

#include <chrono>
#include <cstring>
#include <exception>
#include <functional>
#include <memory>
#include <sstream>
#include <string>
#include <thread>
#include <vector>

#include "paddle/fluid/framework/io/fs.h"
#include "paddle/fluid/string/string_helper.h"
15

16 17 18 19 20 21
// Forward declaration of gloo::transport::Device, which
// ParallelConnectContext::connectFullMesh (below) receives as a
// std::shared_ptr<transport::Device>.
namespace gloo {
namespace transport {
class Device;
}  // namespace transport
}  // namespace gloo

22 23 24
namespace gloo {
namespace rendezvous {

25 26 27
class Store;
class HTTPStore;

// Byte length of one peer-address entry in the record each rank publishes
// from connectFullMesh(): a complete record is kNodeSize * (size - 1) bytes,
// one entry per other rank.
// NOTE(review): presumably equals sizeof(ParallelConnectContext::Impl) —
// confirm against the header before changing either.
constexpr int kNodeSize{136};

30 31
HdfsStore::HdfsStore(const std::string& path) {
  // Root directory; every key is stored as its own file beneath it.
  path_ = path;
  // Maximum number of write attempts made by set().
  retry_times_ = 100;
  // Pause between polls / retries, in milliseconds.
  wait_sleep_ms_ = 10000;
  // 999999999 s (~31 years): effectively unbounded until SetTimeoutSeconds()
  // replaces it.
  wait_timeout_ = std::chrono::seconds(999999999);
}

void HdfsStore::set(const std::string& key, const std::vector<char>& data) {
#ifdef PADDLE_WITH_GLOO
  // Publishes `data` under `key`. The payload is first written to
  // TmpPath(key) and then moved to ObjectPath(key), which is the file that
  // wait()/get() poll for, so that path only appears after the write has
  // finished (assuming fs_mv behaves like a rename).
  auto tmp = TmpPath(key);
  auto path = ObjectPath(key);
  if (paddle::framework::fs_exists(path)) {
    LOG(WARNING) << "path exists, will be removed: " << path;
    paddle::framework::fs_remove(path);
  }

  // Always make at least one attempt. With retry_times_ <= 0 nothing would
  // be written and the fs_mv loop below would spin until the timeout.
  const int max_attempts = retry_times_ > 0 ? retry_times_ : 1;
  int err_no = 0;
  for (int i = 1; i <= max_attempts; ++i) {
    err_no = 0;
    std::shared_ptr<FILE> fp =
        paddle::framework::fs_open_write(tmp, &err_no, "");
    if (fp == nullptr) {
      // Never hand a null FILE* to fwrite_unlocked.
      VLOG(0) << "fs_open_write returned null, retry times " << i;
      err_no = -1;
    } else {
      size_t write_count =
          fwrite_unlocked(data.data(), 1, data.size(), fp.get());
      if (write_count != data.size()) {
        VLOG(0) << "fwrite_unlocked failed, retry times " << i
                << " write_count " << write_count << " data.size() "
                << data.size();
        err_no = -1;
      }
    }
    // Close the handle before inspecting err_no.
    // NOTE(review): fs_open_write's deleter may also report a close/pipe
    // failure through err_no — confirm in io/fs.cc.
    fp.reset();
    if (err_no == 0) {
      break;
    }
    VLOG(0) << "fs_open_write failed, retry times " << i << " err no "
            << err_no;
    paddle::framework::fs_remove(tmp);
    if (i == max_attempts) {
      VLOG(0) << "fs_open_write failed, retry times reaches limit";
      PADDLE_THROW(paddle::platform::errors::PreconditionNotMet(
          "fs_open_write failed, retry times reaches %d limit.",
          max_attempts));
    }
    // sleep_for keeps millisecond precision; POSIX sleep() would truncate
    // wait_sleep_ms_ to whole seconds (and to 0 below one second).
    std::this_thread::sleep_for(std::chrono::milliseconds(wait_sleep_ms_));
  }

  paddle::framework::fs_mv(tmp, path);
  // Keep re-issuing the move until the object is visible, bounded by
  // wait_timeout_ unless it equals gloo::kNoTimeout.
  auto start = std::chrono::steady_clock::now();
  while (!paddle::framework::fs_exists(path)) {
    VLOG(0) << "HdfsStore::set fs_mv retrying...";
    paddle::framework::fs_mv(tmp, path);
    auto elapsed = std::chrono::duration_cast<std::chrono::seconds>(
        std::chrono::steady_clock::now() - start);
    if (wait_timeout_ != gloo::kNoTimeout && elapsed > wait_timeout_) {
      PADDLE_THROW(paddle::platform::errors::ExecutionTimeout(
          "fs_mv failed, tmp: %s, path: %s", tmp, path));
    }
    std::this_thread::sleep_for(std::chrono::milliseconds(wait_sleep_ms_));
  }
#endif
}

89 90 91 92 93 94 95 96 97 98 99 100 101 102 103
#ifdef PADDLE_WITH_GLOO
// Calls `func` up to `max_try_time` times, stopping at the first call that
// returns 0.
//
// Returns 0 as soon as an attempt succeeds, or -1 if every attempt failed
// (including when max_try_time == 0, in which case `func` is never called).
// Waits `retry_interval_ms` milliseconds between consecutive attempts. There
// is no wait after the last attempt, so a final failure is reported at once.
//
// std::this_thread::sleep_for is used instead of usleep():
//  - it works on every platform, not only when _LINUX is defined;
//  - POSIX allows usleep() to reject intervals of 1,000,000 us or more,
//    which the default 10000 ms interval exceeds.
int retry_do_func(std::function<int(void)> func, uint32_t max_try_time,
                  uint32_t retry_interval_ms) {
  for (uint32_t attempt = 0; attempt < max_try_time; ++attempt) {
    if (func() == 0) {
      return 0;
    }
    const bool has_next_attempt = attempt + 1 < max_try_time;
    if (has_next_attempt) {
      std::this_thread::sleep_for(std::chrono::milliseconds(retry_interval_ms));
    }
  }
  return -1;
}
#endif

104 105 106 107 108 109
std::vector<char> HdfsStore::get(const std::string& key) {
  // Returns the bytes stored under `key`.
  //
  // Blocks in wait() until the key exists (wait() may throw
  // ExecutionTimeout). Then throws NotFound if the object is still missing
  // after 5 existence checks, or Fatal if it cannot be read after 5 attempts.
  // Without PADDLE_WITH_GLOO an empty vector is returned.
  auto path = ObjectPath(key);
  std::vector<char> result;
#ifdef PADDLE_WITH_GLOO
  // block until key is set
  wait({key});
  // Re-check existence a few times (wait_sleep_ms_ apart) before giving up.
  int ret = retry_do_func(
      [&path]() { return paddle::framework::fs_exists(path) ? 0 : -1; }, 5,
      wait_sleep_ms_);
  bool is_exists = (ret == 0);
  // Pass `path` as a %s argument instead of concatenating it into the format
  // string, so a '%' in the path cannot corrupt the error message.
  PADDLE_ENFORCE_EQ(is_exists, true,
                    paddle::platform::errors::NotFound(
                        "HdfsStore::get, path not exists: %s", path));

  int read_status = retry_do_func(
      [&path, &result]() {
        result.clear();
        int err_no = 0;
        {
          // Scoped so the handle is closed before err_no is returned.
          // NOTE(review): the close may report pipe errors through err_no.
          std::shared_ptr<FILE> fp =
              paddle::framework::fs_open_read(path, &err_no, "");
          if (fp == nullptr) {
            // Never hand a null FILE* to fread; make sure this counts as a
            // failure so the attempt is retried.
            if (err_no == 0) {
              err_no = -1;
            }
          } else {
            // Read in chunks rather than issuing one fread() per byte.
            char chunk[4096];
            size_t n = 0;
            while ((n = fread(chunk, 1, sizeof(chunk), fp.get())) > 0) {
              result.insert(result.end(), chunk, chunk + n);
            }
            // A stream error must not be mistaken for a short, valid file.
            if (ferror(fp.get()) && err_no == 0) {
              err_no = -1;
            }
            VLOG(3) << "HdfsStore::get read_count " << result.size();
          }
        }
        return err_no;
      },
      5, wait_sleep_ms_);
  PADDLE_ENFORCE_EQ(read_status, 0,
                    paddle::platform::errors::Fatal(
                        "HdfsStore::get, path read failed: %s", path));
#endif
  return result;
}

void HdfsStore::wait(const std::vector<std::string>& keys) {
#ifdef PADDLE_WITH_GLOO
  // No explicit timeout supplied: defer to the timed overload, handing it
  // this store's configured wait_timeout_.
  const std::chrono::milliseconds& default_timeout = wait_timeout_;
  wait(keys, default_timeout);  // NOLINT
#endif
}

void HdfsStore::wait(const std::vector<std::string>& keys,
                     const std::chrono::milliseconds&) {  // NOLINT
#ifdef PADDLE_WITH_GLOO
  // Polls Check() until every key exists, sleeping wait_sleep_ms_ between
  // rounds. The timeout argument is unnamed and ignored: the deadline is
  // wait_timeout_ (set via SetTimeoutSeconds()), and none applies when it
  // equals gloo::kNoTimeout. Throws ExecutionTimeout once the deadline passes.
  const auto begin = std::chrono::steady_clock::now();
  std::vector<bool> found(keys.size(), false);
  for (;;) {
    if (Check(keys, &found)) {
      return;
    }
    VLOG(0) << "HdfsStore::wait checking repeatedly...";
    const auto waited = std::chrono::duration_cast<std::chrono::seconds>(
        std::chrono::steady_clock::now() - begin);
    const bool bounded = (wait_timeout_ != gloo::kNoTimeout);
    if (bounded && waited > wait_timeout_) {
      // Report the index of the first key still missing (-1 if none).
      // NOTE(review): this is an index into `keys`; it equals a rank only
      // when the keys are rank ids.
      size_t missing = 0;
      while (missing < found.size() && found[missing]) {
        ++missing;
      }
      const int32_t last_check_rank =
          missing < found.size() ? static_cast<int32_t>(missing) : -1;
      PADDLE_THROW(paddle::platform::errors::ExecutionTimeout(
          "TIMEOUT self_rank = %d pair_rank = %d", self_rank_,
          last_check_rank));
    }
    std::this_thread::sleep_for(std::chrono::milliseconds(wait_sleep_ms_));
  }
#endif
}

175 176 177 178
void HdfsStore::SetTimeoutSeconds(int timeout_seconds) {
  // Replaces the deadline used by wait() and by set()'s fs_mv loop.
  // NOTE(review): 0 presumably disables the deadline, since those loops skip
  // the check when wait_timeout_ == gloo::kNoTimeout — confirm kNoTimeout
  // is zero.
  const std::chrono::seconds timeout{timeout_seconds};
  wait_timeout_ = timeout;
}

179
std::string HdfsStore::EncodeName(const std::string& name) {
180
  return ::paddle::string::erase_spaces(name);
181 182 183 184 185 186 187 188 189 190
}

std::string HdfsStore::TmpPath(const std::string& name) {
  // Staging location "<path_>/<encoded name>_tmp" that set() writes before
  // moving the data to the object path.
  std::string tmp_path(path_);
  tmp_path += '/';
  tmp_path += EncodeName(name);
  tmp_path += "_tmp";
  return tmp_path;
}

std::string HdfsStore::ObjectPath(const std::string& name) {
  // Final location "<path_>/<encoded name>" polled by Check()/get().
  std::string object_path(path_);
  object_path += '/';
  object_path += EncodeName(name);
  return object_path;
}

191 192
bool HdfsStore::Check(const std::vector<std::string>& keys,
                      std::vector<bool>* keys_check_status) {
193
#ifdef PADDLE_WITH_GLOO
194
  bool ret = true;
195 196 197 198
  std::vector<std::string> paths;
  for (const auto& key : keys) {
    paths.push_back(ObjectPath(key));
  }
199 200 201 202 203
  for (size_t i = 0; i < paths.size(); ++i) {
    if ((*keys_check_status)[i]) {
      continue;
    }
    const auto& path = paths[i];
204 205 206
    bool is_exists = paddle::framework::fs_exists(path);
    VLOG(3) << "HdfsStore::Check " << is_exists << " path " << path;
    if (!is_exists) {
207
      ret = false;
208
    }
209
    (*keys_check_status)[i] = is_exists;
210
  }
211 212 213
  return ret;
#else
  VLOG(0) << "HdfsStore::Check does nothing when no gloo";
214 215 216 217
#endif
  return true;
}

218 219 220 221 222 223 224
#ifdef PADDLE_WITH_GLOO
void ParallelConnectContext::connectFullMesh(
    Store& store, std::shared_ptr<transport::Device>& dev) {
  // Rendezvous through `store`: publish this rank's pair addresses, then
  // connect to every other rank using worker threads. The first exception
  // raised by any worker is rethrown to the caller after all workers join.
  std::vector<char> allBytes;
  // Create one pair per peer and append its local address bytes.
  auto transportContext = dev->createContext(rank, size);
  transportContext->setTimeout(getTimeout());
  VLOG(0) << "transportContext timeout: " << getTimeout().count()
          << ", curr rank: " << rank;
  for (int i = 0; i < size; i++) {
    if (i == rank) {
      continue;
    }
    auto& pair = transportContext->createPair(i);
    auto addrBytes = pair->address().bytes();
    allBytes.insert(allBytes.end(), addrBytes.begin(), addrBytes.end());
  }
  // Publish under the key "<rank>" so every peer can look it up.
  std::ostringstream storeKey;
  storeKey << rank;
  store.set(storeKey.str(), allBytes);

  // A complete peer record holds one kNodeSize entry per other rank.
  auto total_add_size = kNodeSize * (size - 1);

  // With zero workers no pair would ever be connected, so use at least one.
  const size_t num_threads =
      thread_num_ > 0 ? static_cast<size_t>(thread_num_) : 1;
  // An exception escaping a std::thread calls std::terminate. Each worker
  // instead stores its exception here; it is rethrown after join().
  std::vector<std::exception_ptr> thread_errors(num_threads);
  std::vector<std::thread> connect_threads;
  connect_threads.reserve(num_threads);
  VLOG(0) << "connect_thread_num: " << num_threads << ", size: " << size;
  for (size_t t = 0; t < num_threads; ++t) {
    connect_threads.emplace_back(
        [&store, &transportContext, &thread_errors, total_add_size, this](
            size_t thread_idx, size_t thread_num) -> void {
          try {
            // Worker k handles ranks k, k + thread_num, k + 2 * thread_num...
            for (int i = static_cast<int>(thread_idx); i < size;
                 i += static_cast<int>(thread_num)) {
              if (i == rank) {
                continue;
              }
              // Wait for address of other side of this pair to become
              // available.
              std::string key = std::to_string(i);
              store.wait({key}, getTimeout());

              // Re-read the record until it has the expected size
              // (at most 10 reads, 5 s apart).
              std::vector<char> allAddrs;
              auto max_retry_times = 10;
              while (max_retry_times > 0) {
                allAddrs = store.get(key);
                VLOG(3) << "store get all address size: " << allAddrs.size()
                        << " expect: " << total_add_size;
                if (allAddrs.size() == static_cast<size_t>(total_add_size)) {
                  break;
                }
                std::this_thread::sleep_for(std::chrono::seconds(5));
                --max_retry_times;
              }
              if (allAddrs.size() != static_cast<size_t>(total_add_size)) {
                LOG(WARNING) << "peer " << i << " address record size "
                             << allAddrs.size() << " != expected "
                             << total_add_size;
              }

              // Connect to other side of this pair.
              auto addr = extractAddress(allAddrs, i);
              if (addr.empty()) {
                VLOG(0) << "peer address is null";
                PADDLE_THROW(paddle::platform::errors::PreconditionNotMet(
                    "peer %d address is empty, cannot connect.", i));
              }
              // Decode the peer's socket address for logging only. Copy at
              // most sizeof(Impl) bytes so a short address cannot over-read.
              Impl peer_impl{};
              const size_t copy_bytes = addr.size() < sizeof(peer_impl)
                                            ? addr.size()
                                            : sizeof(peer_impl);
              memcpy(&peer_impl, addr.data(), copy_bytes);
              const auto* sa =
                  reinterpret_cast<const struct sockaddr_in*>(&peer_impl.ss);
              std::string ip = getCharIpAddr(sa->sin_addr.s_addr);
              VLOG(0) << "peer " << i << " ip addr: " << ip
                      << ", port: " << sa->sin_port;

              transportContext->getPair(i)->connect(addr);
            }
            VLOG(0) << "peer connected success";
          } catch (...) {
            thread_errors[thread_idx] = std::current_exception();
          }
        },
        t, num_threads);
  }
  for (auto& worker : connect_threads) {
    worker.join();
  }
  for (const auto& error : thread_errors) {
    if (error) {
      std::rethrow_exception(error);
    }
  }
  device_ = dev;
  transportContext_ = std::move(transportContext);
  VLOG(0) << "ParallelConnectContext::connectFullMesh() is over";
}
#endif
295 296 297 298 299 300
}  // namespace rendezvous
}  // namespace gloo

namespace paddle {
namespace framework {

301
void GlooWrapper::Init() {
302 303 304 305 306
  if (is_initialized_) {
    return;
  }
#ifdef PADDLE_WITH_GLOO
  gloo::transport::tcp::attr attr;
307 308 309
  attr.iface = iface_;
  std::shared_ptr<gloo::rendezvous::HdfsStore> file_store = nullptr;
  std::shared_ptr<gloo::rendezvous::HTTPStore> http_store = nullptr;
310
  auto dev = gloo::transport::tcp::CreateDevice(attr);
311

312 313
  switch (store_type_) {
    case GlooStoreType::HDFS: {
314 315 316
      auto context = std::make_shared<gloo::rendezvous::ParallelConnectContext>(
          rank_, size_);
      context->setTimeout(run_timeout_);
317 318 319 320 321 322 323 324 325
      std::string cmd = std::string("${HADOOP_HOME}/bin/hadoop fs");
      cmd += " -D fs.default.name=" + hdfs_name_;
      cmd += " -D hadoop.job.ugi=" + hdfs_ugi_;
      paddle::framework::hdfs_set_command(cmd);
      file_store = std::make_shared<gloo::rendezvous::HdfsStore>(hdfs_path_);
      file_store->SetTimeoutSeconds(init_timeout_.count());
      auto prefix_store =
          std::make_shared<gloo::rendezvous::PrefixStore>(prefix_, *file_store);
      context->connectFullMesh(*prefix_store, dev);
326
      context_ = std::move(context);
327 328 329
      break;
    }
    case GlooStoreType::HTTP: {
330 331
      auto context = std::make_shared<gloo::rendezvous::Context>(rank_, size_);
      context->setTimeout(run_timeout_);
332 333 334 335 336
      http_store = std::make_shared<gloo::rendezvous::HTTPStore>(
          http_ip_, http_port_, prefix_ + "_" + http_scope_, rank_);
      http_store->SetTimeoutSeconds(init_timeout_.count());
      context->connectFullMesh(*http_store, dev);
      http_store->Finalize();
L
lilong12 已提交
337
      VLOG(3) << "after calling http_store->Finalize.";
338
      context_ = std::move(context);
339 340 341 342 343 344
      break;
    }
    default:
      LOG(ERROR) << "unknown store type " << store_type_;
      exit(-1);
  }
345 346
#endif
  is_initialized_ = true;
L
lilong12 已提交
347
  VLOG(3) << "gloo initialized done.";
348 349
}

X
xujiaqi01 已提交
350
// Explicit instantiation definitions: these emit GlooWrapper::AllReduce for
// the element types below so other translation units can link against them.
template std::vector<double> GlooWrapper::AllReduce<double>(
    std::vector<double>& sendbuf,  // NOLINT
    const std::string& mode);
template std::vector<float> GlooWrapper::AllReduce<float>(
    std::vector<float>& sendbuf,  // NOLINT
    const std::string& mode);
template std::vector<int64_t> GlooWrapper::AllReduce<int64_t>(
    std::vector<int64_t>& sendbuf,  // NOLINT
    const std::string& mode);
template std::vector<uint64_t> GlooWrapper::AllReduce<uint64_t>(
    std::vector<uint64_t>& sendbuf,  // NOLINT
    const std::string& mode);

// Same for GlooWrapper::AllGather, over the same four element types.
template std::vector<double> GlooWrapper::AllGather<double>(
    double& input);  // NOLINT
template std::vector<float> GlooWrapper::AllGather<float>(
    float& input);  // NOLINT
template std::vector<int64_t> GlooWrapper::AllGather<int64_t>(
    int64_t& input);  // NOLINT
template std::vector<uint64_t> GlooWrapper::AllGather<uint64_t>(
    uint64_t& input);  // NOLINT

}  // namespace framework
}  // namespace paddle