// nccl_gpu_common.h (598 bytes)
// Committed by dongzhihong
#pragma once
#include <nccl.h>

#include <vector>

#include "paddle/platform/device_context.h"

namespace paddle {
namespace platform {

// Holder for NCCL communicator handles, shared through Get().
//
// The constructor remains public so existing code that constructs an
// NCCLManager directly keeps compiling, but copying is disabled: a copy would
// duplicate the raw ncclComm_t handles stored in _comms, leaving two objects
// referring to the same communicators.
class NCCLManager {
 public:
  // Returns a pointer to the function-local static instance, which is
  // constructed the first time this function is called.
  static NCCLManager* Get() {
    static NCCLManager m;
    return &m;
  }

  // Sizes _comms to match the number of entries in _gpu_worlds.
  // NOTE(review): _gpu_worlds is default-constructed (empty) by the time this
  // body runs, so the resize currently leaves _comms empty. Confirm whether
  // _gpu_worlds was meant to be populated before this point.
  NCCLManager() { _comms.resize(_gpu_worlds.size()); }

  // Empty body: nothing stored in _comms is released here.
  ~NCCLManager() {}

  // Non-copyable (see the class comment).
  NCCLManager(const NCCLManager&) = delete;
  NCCLManager& operator=(const NCCLManager&) = delete;

 private:
  // NCCL communicator handles; sized from _gpu_worlds in the constructor.
  std::vector<ncclComm_t> _comms;
  // Presumably the GPU ids taking part in communication; nothing in this
  // header fills it in -- TODO confirm against the implementation.
  std::vector<int> _gpu_worlds;
};

// DeviceContext subclass constructed from a GPUPlace.
//
// Only the declaration lives in this header; the constructor and destructor
// are defined elsewhere.
class NCCLContext : public DeviceContext {
 public:
  // Builds the context for `place`. Marked explicit so a GPUPlace is never
  // implicitly converted into an NCCLContext.
  explicit NCCLContext(GPUPlace place);
  virtual ~NCCLContext();

 private:
  // Presumably the GPU device ids covered by this context -- TODO confirm
  // against the constructor's definition.
  std::vector<int> _gpu_ids;
  // CUDA stream handles. NOTE(review): looks like one stream per entry of
  // _gpu_ids -- confirm.
  std::vector<cudaStream_t> _streams;
  // Has no initializer in this header; presumably the id of the GPU acting as
  // root for collective operations -- TODO confirm and make sure the
  // constructor sets it.
  int root_gpu;
};
}
}