Unverified commit c4f279fe, authored by T Thunderbrook, committed by GitHub

support multi node in heterps (#31102)

* push multi node

* multi node

* MultiThread

* remove log

* solve bug in 30829
Parent ae2be49f
......@@ -29,6 +29,7 @@ limitations under the License. */
#include "paddle/fluid/framework/fleet/fleet_wrapper.h"
#include "glog/logging.h"
#include "paddle/fluid/framework/op_registry.h"
namespace paddle {
namespace framework {
......
......@@ -21,6 +21,7 @@ limitations under the License. */
#include "paddle/fluid/framework/fleet/heter_ps/optimizer.cuh.h"
#include "paddle/fluid/memory/memory.h"
#include "paddle/fluid/platform/cuda_device_guard.h"
#include "paddle/fluid/platform/dynload/nccl.h"
#include "paddle/fluid/platform/place.h"
#include "thrust/pair.h"
......@@ -68,7 +69,30 @@ class HeterComm {
void push_sparse(int num, KeyType* d_keys, GradType* d_grads, size_t len,
Sgd& sgd);
template <typename Sgd>
void push_sparse_multi_node(int num, KeyType* d_keys, GradType* d_grads,
size_t len, Sgd& sgd);
template <typename Sgd>
void update_one_table(int num, KeyType* d_keys, GradType* d_grads, size_t len,
Sgd& sgd);
int gather_one_node_grad(int num, KeyType* d_keys, GradType* d_grads,
int len);
int gather_multi_node_grad(int num, KeyType* d_keys, GradType* d_grads,
int len);
int log2i(int x);
void set_nccl_comm_and_size(const std::vector<ncclComm_t>& inner_comms,
const std::vector<ncclComm_t>& inter_comms,
int comm_size) {
nccl_inner_comms_ = inner_comms;
nccl_inter_comms_ = inter_comms;
node_size_ = comm_size;
}
bool need_transfer(int send_id, int receive_id) {
return ((send_id / 4 != receive_id / 4) && (send_id + 4) % 8 != receive_id);
}
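need_transfer() encodes what looks like an 8-GPU, two-island topology: GPUs 0-3 and 4-7 form separate groups, with GPU i wired directly to GPU (i + 4) % 8, so a cross-group transfer is staged through that bridge peer. A standalone sketch (not part of this diff, plain host code) that enumerates which destinations require staging under that assumption:

// Standalone sketch, assuming the 8-GPU / two-island topology that
// need_transfer() encodes: list the destinations that must be staged
// through the bridge peer (send_id + 4) % 8.
#include <cstdio>

static bool need_transfer(int send_id, int receive_id) {
  return (send_id / 4 != receive_id / 4) && (send_id + 4) % 8 != receive_id;
}

int main() {
  for (int src = 0; src < 8; ++src) {
    for (int dst = 0; dst < 8; ++dst) {
      if (src != dst && need_transfer(src, dst)) {
        std::printf("GPU%d -> GPU%d staged via GPU%d\n", src, dst,
                    (src + 4) % 8);
      }
    }
  }
  return 0;
}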
......@@ -94,6 +118,44 @@ class HeterComm {
std::vector<Node> nodes_;
};
struct LocalStorage {
LocalStorage() {}
void init(int size, int dev_id) {
place_ = platform::CUDAPlace(dev_id);
alloc(size, true);
}
// `size` is an entry count while Allocation::size() returns bytes, so the
// regrow check compares against the byte capacity.
void alloc(int size, bool force = false) {
if (force || size * sizeof(KeyType) > all_keys_mem->size()) {
all_keys_mem.reset();
all_grads_mem.reset();
all_keys_mem = memory::AllocShared(place_, size * sizeof(KeyType));
all_grads_mem = memory::AllocShared(place_, size * sizeof(GradType));
all_keys = reinterpret_cast<KeyType*>(all_keys_mem->ptr());
all_grads = reinterpret_cast<GradType*>(all_grads_mem->ptr());
}
if (force || size * sizeof(KeyType) > local_keys_mem->size()) {
local_keys_mem.reset();
local_grads_mem.reset();
local_keys_mem = memory::AllocShared(place_, size * sizeof(KeyType));
local_grads_mem = memory::AllocShared(place_, size * sizeof(GradType));
local_keys = reinterpret_cast<KeyType*>(local_keys_mem->ptr());
local_grads = reinterpret_cast<GradType*>(local_grads_mem->ptr());
}
}
platform::CUDAPlace place_;
std::shared_ptr<memory::Allocation> all_keys_mem;
std::shared_ptr<memory::Allocation> all_grads_mem;
KeyType* all_keys;
GradType* all_grads;
std::shared_ptr<memory::Allocation> local_keys_mem;
std::shared_ptr<memory::Allocation> local_grads_mem;
KeyType* local_keys;
GradType* local_grads;
};
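A minimal host-side sketch of the sizing policy LocalStorage implements, with std::vector standing in for memory::AllocShared (an assumption for illustration only): the all_* pair receives the padded allgather output, the local_* pair holds the compacted result, and both are regrown only when a pass needs more entries than currently allocated.

// Host-side sketch of LocalStorage's grow-on-demand policy (std::vector in
// place of memory::AllocShared; not the real implementation).
#include <cstdint>
#include <vector>

template <typename KeyType, typename GradType>
struct LocalStorageSketch {
  std::vector<KeyType> all_keys;    // padded allgather destination
  std::vector<GradType> all_grads;
  std::vector<KeyType> local_keys;  // compacted keys/grads for this shard
  std::vector<GradType> local_grads;

  void alloc(size_t size, bool force = false) {
    if (force || size > all_keys.size()) {
      all_keys.resize(size);
      all_grads.resize(size);
    }
    if (force || size > local_keys.size()) {
      local_keys.resize(size);
      local_grads.resize(size);
    }
  }
};

// Usage: 8 participants whose longest shard is 1000 entries need
// 8 * 1000 padded slots.
//   LocalStorageSketch<uint64_t, float> s;
//   s.alloc(8 * 1000);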
void init_path();
void create_storage(
int start_index, int end_index, int keylen, int vallen,
......@@ -111,6 +173,12 @@ class HeterComm {
CustomGradMerger merger_;
int topo_aware_{1};
std::vector<std::vector<Path>> path_;
std::vector<LocalStorage> storage_;
int feanum_{1800 * 2048};
int multi_node_{1};
std::vector<ncclComm_t> nccl_inner_comms_;
std::vector<ncclComm_t> nccl_inter_comms_;
int node_size_;
};
} // end namespace framework
......
......@@ -95,10 +95,14 @@ template <typename KeyType, typename ValType, typename GradType>
HeterComm<KeyType, ValType, GradType>::HeterComm(
size_t capacity, std::shared_ptr<HeterPsResource> resource) {
resource_ = resource;
storage_.resize(resource_->total_gpu());
for (int i = 0; i < resource_->total_gpu(); ++i) {
platform::CUDADeviceGuard guard(resource_->dev_id(i));
auto table = new Table(capacity / load_factor_);
tables_.push_back(table);
if (multi_node_) {
storage_[i].init(feanum_, resource_->dev_id(i));
}
}
init_path();
}
......@@ -595,6 +599,186 @@ void HeterComm<KeyType, ValType, GradType>::push_sparse(int gpu_num,
}
}
template <typename KeyType, typename ValType, typename GradType>
template <typename Sgd>
void HeterComm<KeyType, ValType, GradType>::update_one_table(
int gpu_num, KeyType* d_keys, GradType* d_grads, size_t len, Sgd& sgd) {
if (len == 0) {
return;
}
int dev_id = resource_->dev_id(gpu_num);
platform::CUDADeviceGuard guard(dev_id);
tables_[gpu_num]->update(d_keys, d_grads, len, sgd,
resource_->remote_stream(gpu_num));
cudaStreamSynchronize(resource_->remote_stream(gpu_num));
}
template <typename KeyType, typename ValType, typename GradType>
template <typename Sgd>
void HeterComm<KeyType, ValType, GradType>::push_sparse_multi_node(
int gpu_num, KeyType* d_keys, GradType* d_grads, size_t len, Sgd& sgd) {
if (len == 0) {
return;
}
int uniq_len = len;
merge_grad(gpu_num, d_keys, d_grads, len, uniq_len);
uniq_len = gather_one_node_grad(gpu_num, d_keys, d_grads, uniq_len);
uniq_len = gather_multi_node_grad(gpu_num, storage_[gpu_num].local_keys,
storage_[gpu_num].local_grads, uniq_len);
update_one_table(gpu_num, storage_[gpu_num].local_keys,
storage_[gpu_num].local_grads, uniq_len, sgd);
}
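The multi-node push is a four-step funnel: dedup the local gradients, aggregate across the GPUs of one node, aggregate across nodes, then update the shard owned by this GPU. A toy host-side sketch of what that funnel computes, with std::map standing in for the GPU hash table and float gradients assumed:

// Host-side sketch: duplicate keys are summed locally, then across the GPUs
// of one node; the final cross-node step repeats the same aggregation.
#include <cstdint>
#include <cstdio>
#include <map>
#include <vector>

using Key = uint64_t;

std::map<Key, float> merge_grad(const std::vector<std::pair<Key, float>>& in) {
  std::map<Key, float> out;
  for (const auto& kv : in) out[kv.first] += kv.second;
  return out;
}

int main() {
  // Two "GPUs" on one node push overlapping keys.
  auto gpu0 = merge_grad({{7, 0.1f}, {7, 0.2f}, {9, 1.0f}});
  auto gpu1 = merge_grad({{9, 0.5f}, {3, 2.0f}});
  // gather_one_node_grad + merge_grad: node-level aggregation.
  std::vector<std::pair<Key, float>> node;
  for (const auto& kv : gpu0) node.push_back(kv);
  for (const auto& kv : gpu1) node.push_back(kv);
  auto node_agg = merge_grad(node);
  for (const auto& kv : node_agg)
    std::printf("key %llu -> grad %.2f\n",
                static_cast<unsigned long long>(kv.first), kv.second);
  // Expected: key 3 -> 2.00, key 7 -> 0.30, key 9 -> 1.50
  return 0;
}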
template <typename KeyType, typename ValType, typename GradType>
int HeterComm<KeyType, ValType, GradType>::gather_one_node_grad(
int gpu_num, KeyType* d_keys, GradType* d_grads, int len) {
int total_gpu = resource_->total_gpu();
int dev_id = resource_->dev_id(gpu_num);
auto& storage = storage_[gpu_num];
platform::CUDAPlace place = platform::CUDAPlace(dev_id);
platform::CUDADeviceGuard guard(dev_id);
auto stream = resource_->local_stream(gpu_num, 0);
int max_size = 0;
ncclComm_t nccl_inner_comm = nccl_inner_comms_[gpu_num];
// alloc for size
int h_node_len[total_gpu];
auto d_node_len_mem = memory::AllocShared(place, total_gpu * sizeof(int));
int* d_node_len = reinterpret_cast<int*>(d_node_len_mem->ptr());
h_node_len[gpu_num] = len;
cudaMemcpy(d_node_len + gpu_num, h_node_len + gpu_num, sizeof(int),
cudaMemcpyHostToDevice);
// allgather grad len
PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::ncclGroupStart());
PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::ncclAllGather(
(const void*)(d_node_len + gpu_num), (void*)d_node_len, 1, ncclInt,
nccl_inner_comm, stream));
PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::ncclGroupEnd());
PADDLE_ENFORCE_CUDA_SUCCESS(cudaStreamSynchronize(stream));
cudaMemcpy(h_node_len, d_node_len, sizeof(int) * total_gpu,
cudaMemcpyDeviceToHost);
for (int i = 0; i < total_gpu; ++i) {
if (h_node_len[i] > max_size) {
max_size = h_node_len[i];
}
}
storage.alloc(max_size * total_gpu);
// allgather keys and grads
PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::ncclGroupStart());
PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::ncclAllGather(
d_keys, storage.all_keys, max_size, ncclUint64, nccl_inner_comm, stream));
PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::ncclAllGather(
d_grads, storage.all_grads, max_size * sizeof(GradType), ncclUint8,
nccl_inner_comm, stream));
PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::ncclGroupEnd());
PADDLE_ENFORCE_CUDA_SUCCESS(cudaStreamSynchronize(stream));
int h_left[total_gpu];
int h_right[total_gpu];
auto d_left = memory::AllocShared(place, total_gpu * sizeof(int));
auto d_right = memory::AllocShared(place, total_gpu * sizeof(int));
int* d_left_ptr = reinterpret_cast<int*>(d_left->ptr());
int* d_right_ptr = reinterpret_cast<int*>(d_right->ptr());
int merge_num = 0;
for (int i = 0; i < total_gpu; ++i) {
int index = i * max_size;
auto d_idx = memory::AllocShared(place, h_node_len[i] * sizeof(int));
int* d_idx_ptr = reinterpret_cast<int*>(d_idx->ptr());
cudaMemset(d_left_ptr, -1, total_gpu * sizeof(int));
cudaMemset(d_right_ptr, -1, total_gpu * sizeof(int));
split_input_to_shard(storage.all_keys + index, d_idx_ptr, h_node_len[i],
d_left_ptr, d_right_ptr, gpu_num);
cudaMemcpy(h_left, d_left_ptr, total_gpu * sizeof(int),
cudaMemcpyDeviceToHost);
cudaMemcpy(h_right, d_right_ptr, total_gpu * sizeof(int),
cudaMemcpyDeviceToHost);
int grid_size = (h_node_len[i] - 1) / block_size_ + 1;
fill_shard_grads<<<grid_size, block_size_, 0, stream>>>(
storage.local_keys + merge_num, storage.all_keys + index,
storage.local_grads + merge_num, storage.all_grads + index,
d_idx_ptr + h_left[gpu_num], h_right[gpu_num] - h_left[gpu_num] + 1);
merge_num = merge_num + h_right[gpu_num] - h_left[gpu_num] + 1;
}
int ret = merge_num;
merge_grad(gpu_num, storage.local_keys, storage.local_grads, merge_num, ret);
return ret;
}
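The bookkeeping above is easier to see with concrete numbers: every GPU gathers max_size entries from each peer (shorter lists are padded), peer i's slice starts at offset i * max_size, and only the rows whose shard is this GPU (the h_left[gpu_num]..h_right[gpu_num] window) are compacted into local storage. A small sketch with made-up lengths and kept-row counts:

// Sketch with hypothetical per-GPU gradient counts: compute the padded
// allgather layout and the compacted length that feeds the final merge_grad.
#include <algorithm>
#include <cstdio>
#include <vector>

int main() {
  std::vector<int> h_node_len = {5, 3, 8, 6};   // gradients per GPU
  std::vector<int> kept = {2, 1, 4, 3};         // rows owned by this shard
  int max_size = *std::max_element(h_node_len.begin(), h_node_len.end());
  std::printf("padded buffer: %zu slices of %d entries\n", h_node_len.size(),
              max_size);
  int merge_num = 0;
  for (size_t i = 0; i < h_node_len.size(); ++i) {
    std::printf("slice %zu starts at offset %zu, contributes %d rows\n", i,
                i * max_size, kept[i]);
    merge_num += kept[i];
  }
  std::printf("compacted length before merge_grad: %d\n", merge_num);
  return 0;
}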
template <typename KeyType, typename ValType, typename GradType>
int HeterComm<KeyType, ValType, GradType>::gather_multi_node_grad(
int gpu_num, KeyType* d_keys, GradType* d_grads, int len) {
int dev_id = resource_->dev_id(gpu_num);
auto& storage = storage_[gpu_num];
platform::CUDAPlace place = platform::CUDAPlace(dev_id);
platform::CUDADeviceGuard guard(dev_id);
auto stream = resource_->local_stream(gpu_num, 0);
int max_size = 0;
ncclComm_t nccl_inter_comm = nccl_inter_comms_[gpu_num];
// alloc for size
int h_node_len[node_size_];
auto d_node_len_mem = memory::AllocShared(place, node_size_ * sizeof(int));
int* d_node_len = reinterpret_cast<int*>(d_node_len_mem->ptr());
h_node_len[0] = len;
cudaMemcpy(d_node_len, h_node_len, sizeof(int), cudaMemcpyHostToDevice);
// allgather grad len
PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::ncclGroupStart());
PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::ncclAllGather(
d_node_len, d_node_len, 1, ncclInt, nccl_inter_comm, stream));
PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::ncclGroupEnd());
PADDLE_ENFORCE_CUDA_SUCCESS(cudaStreamSynchronize(stream));
cudaMemcpy(h_node_len, d_node_len, sizeof(int) * node_size_,
cudaMemcpyDeviceToHost);
for (int i = 0; i < node_size_; ++i) {
if (h_node_len[i] > max_size) {
max_size = h_node_len[i];
}
}
storage.alloc(max_size * node_size_);
// allgather keys and grads
PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::ncclGroupStart());
PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::ncclAllGather(
d_keys, storage.all_keys, max_size, ncclUint64, nccl_inter_comm, stream));
PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::ncclAllGather(
d_grads, storage.all_grads, max_size * sizeof(GradType), ncclUint8,
nccl_inter_comm, stream));
PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::ncclGroupEnd());
PADDLE_ENFORCE_CUDA_SUCCESS(cudaStreamSynchronize(stream));
int merge_num = 0;
for (int i = 0; i < node_size_; ++i) {
int index = i * max_size;
// cudaMemcpyAsync takes byte counts, so scale the entry counts by the
// element sizes.
cudaMemcpyAsync(storage.local_keys + merge_num, storage.all_keys + index,
h_node_len[i] * sizeof(KeyType), cudaMemcpyDefault, stream);
cudaMemcpyAsync(storage.local_grads + merge_num, storage.all_grads + index,
h_node_len[i] * sizeof(GradType), cudaMemcpyDefault, stream);
merge_num += h_node_len[i];
}
int ret = merge_num;
merge_grad(gpu_num, storage.local_keys, storage.local_grads, merge_num, ret);
return ret;
}
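A back-of-envelope sketch of the inter-node traffic this produces: each node receives node_size_ * max_size padded entries for both keys and gradients. The sizes below are assumptions for illustration, including the gradient-struct width:

// Rough traffic estimate under assumed sizes: 2 nodes, 64-bit keys, a
// 36-byte gradient struct, and a 1M-entry longest shard.
#include <cstdint>
#include <cstdio>

int main() {
  const int node_size = 2;
  const size_t max_size = 1u << 20;
  const size_t key_bytes = sizeof(uint64_t);
  const size_t grad_bytes = 36;  // assumed gradient struct size
  const size_t per_node_bytes = node_size * max_size * (key_bytes + grad_bytes);
  std::printf("each node receives about %.1f MiB per pass\n",
              per_node_bytes / (1024.0 * 1024.0));
  return 0;
}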
template <typename KeyType, typename ValType, typename GradType>
void HeterComm<KeyType, ValType, GradType>::end_pass() {
int total_gpu = resource_->total_gpu();
......
......@@ -54,7 +54,14 @@ void HeterPs::show_one_table(int gpu_num) { comm_->show_one_table(gpu_num); }
void HeterPs::push_sparse(int num, FeatureKey* d_keys,
FeaturePushValue* d_grads, size_t len) {
comm_->push_sparse(num, d_keys, d_grads, len, opt_);
// comm_->push_sparse(num, d_keys, d_grads, len, opt_);
comm_->push_sparse_multi_node(num, d_keys, d_grads, len, opt_);
}
void HeterPs::set_nccl_comm_and_size(const std::vector<ncclComm_t>& inner_comms,
const std::vector<ncclComm_t>& inter_comms,
int comm_size) {
comm_->set_nccl_comm_and_size(inner_comms, inter_comms, comm_size);
}
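The two communicator vectors handed over here form orthogonal planes: inner_comms[i] is GPU i's handle to the communicator spanning the GPUs inside one machine, while inter_comms[i] connects GPU i of every machine, so gather_one_node_grad reduces within a node and gather_multi_node_grad reduces across nodes. A small sketch of that layout under an assumed 2-node x 8-GPU setup:

// Illustrative layout of the two NCCL planes, assuming 2 nodes with 8 GPUs
// each (numbers are hypothetical, not read from the cluster).
#include <cstdio>

int main() {
  const int nodes = 2, gpus_per_node = 8;
  std::printf("inner comm (one per node): gpu0..gpu%d of that node\n",
              gpus_per_node - 1);
  for (int g = 0; g < gpus_per_node; ++g) {
    std::printf("inter comm %d:", g);
    for (int n = 0; n < nodes; ++n) std::printf(" node%d/gpu%d", n, g);
    std::printf("\n");
  }
  return 0;
}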
} // end namespace framework
......
......@@ -35,6 +35,9 @@ class HeterPs : public HeterPsBase {
size_t len) override;
virtual void build_ps(int num, FeatureKey* h_keys, FeatureValue* h_vals,
size_t len, size_t chunk_size, int stream_num) override;
virtual void set_nccl_comm_and_size(
const std::vector<ncclComm_t>& inner_comms,
const std::vector<ncclComm_t>& inter_comms, int comm_size) override;
virtual void end_pass() override;
virtual int get_index_by_devid(int devid) override;
virtual void show_one_table(int gpu_num) override;
......
......@@ -35,6 +35,9 @@ class HeterPsBase {
virtual void build_ps(int num, FeatureKey* h_keys, FeatureValue* h_vals,
size_t len, size_t chunk_size, int stream_num) = 0;
virtual int get_index_by_devid(int devid) = 0;
virtual void set_nccl_comm_and_size(
const std::vector<ncclComm_t>& inner_comms,
const std::vector<ncclComm_t>& inter_comms, int comm_size) = 0;
virtual void end_pass() = 0;
virtual void show_one_table(int gpu_num) = 0;
virtual void push_sparse(int num, FeatureKey* d_keys,
......
......@@ -28,6 +28,7 @@ limitations under the License. */
#include "paddle/fluid/framework/fleet/heter_wrapper.h"
#ifdef PADDLE_WITH_PSLIB
#include "paddle/fluid/framework/device_worker.h"
namespace paddle {
namespace framework {
......
......@@ -233,6 +233,7 @@ void PSGPUWrapper::BuildGPUPS(uint64_t table_id, int feature_dim) {
}
std::vector<std::thread> threads(device_num);
HeterPs_ = HeterPsBase::get_instance(size_max, resource_);
HeterPs_->set_nccl_comm_and_size(inner_comms_, inter_comms_, node_size_);
auto build_func = [this, &gpu_task, &feature_keys_count](int i) {
std::cout << "building table: " << i << std::endl;
this->HeterPs_->build_ps(i, gpu_task->device_keys_[i].data(),
......
......@@ -27,6 +27,10 @@ limitations under the License. */
#include <unordered_set>
#include <vector>
#ifdef PADDLE_WITH_GLOO
#include <gloo/broadcast.h>
#include "paddle/fluid/framework/fleet/gloo_wrapper.h"
#endif
#include "paddle/fluid/framework/data_set.h"
#include "paddle/fluid/framework/fleet/heter_context.h"
#include "paddle/fluid/framework/fleet/heter_ps/heter_ps_base.h"
......@@ -34,6 +38,7 @@ limitations under the License. */
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/framework/variable_helper.h"
#include "paddle/fluid/platform/dynload/nccl.h"
#include "paddle/fluid/platform/gpu_info.h"
#include "paddle/fluid/platform/macros.h" // for DISABLE_COPY_AND_ASSIGN
#include "paddle/fluid/platform/place.h"
......@@ -80,11 +85,48 @@ class PSGPUWrapper {
void BuildTask(std::shared_ptr<HeterContext> gpu_task, uint64_t table_id,
int feature_dim);
void InitializeGPU(const std::vector<int>& dev_ids) {
if (s_instance_ != NULL) {
if (s_instance_ != NULL && is_initialized_ == false) {
VLOG(3) << "PSGPUWrapper Begin InitializeGPU";
is_initialized_ = true;
resource_ = std::make_shared<HeterPsResource>(dev_ids);
resource_->enable_p2p();
keys_tensor.resize(resource_->total_gpu());
if (multi_node_) {
int dev_size = dev_ids.size();
// init inner comm
inner_comms_.resize(dev_size);
inter_ncclids_.resize(dev_size);
platform::dynload::ncclCommInitAll(&(inner_comms_[0]), dev_size,
&dev_ids[0]);
// init inter comm
#ifdef PADDLE_WITH_GLOO
inter_comms_.resize(dev_size);
auto gloo = paddle::framework::GlooWrapper::GetInstance();
if (gloo->Rank() == 0) {
for (int i = 0; i < dev_size; ++i) {
platform::dynload::ncclGetUniqueId(&inter_ncclids_[i]);
}
}
PADDLE_ENFORCE_EQ(
gloo->IsInitialized(), true,
platform::errors::PreconditionNotMet(
"You must initialize the gloo environment first to use it."));
gloo::BroadcastOptions opts(gloo->GetContext());
opts.setOutput(&inter_ncclids_[0], dev_size);
opts.setRoot(0);
gloo::broadcast(opts);
for (int i = 0; i < dev_size; ++i) {
platform::dynload::ncclCommInitRank(&inter_comms_[i], gloo->Size(),
inter_ncclids_[i], gloo->Rank());
}
node_size_ = gloo->Size();
#else
PADDLE_THROW(platform::errors::Unavailable(
"HeterPs needs to be compiled with GLOO for multi-node mode"));
#endif
}
heter_devices_ = dev_ids;
}
}
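The Gloo-assisted NCCL bootstrap above follows a standard pattern: rank 0 creates one ncclUniqueId per local GPU, every trainer learns those ids through a broadcast, and each GPU then joins inter communicator i with its node rank. A standalone sketch of that sequence using raw NCCL calls and a stubbed broadcast; the names and single-node defaults below are assumptions, and a real run needs one process per node plus a working broadcast:

// Sketch of the two-plane NCCL bootstrap: ncclCommInitAll for the intra-node
// plane, ncclGetUniqueId + broadcast + ncclCommInitRank for the inter-node
// plane. Error checking is omitted for brevity.
#include <vector>
#include <cuda_runtime.h>
#include <nccl.h>

void broadcast_ids(std::vector<ncclUniqueId>* ids, int root) {
  // Placeholder: in the diff this is gloo::broadcast over the trainer group.
  (void)ids; (void)root;
}

int main() {
  int node_rank = 0, node_size = 1;       // assumed single node here
  std::vector<int> dev_ids = {0};         // local GPU ids
  int dev_size = static_cast<int>(dev_ids.size());

  // Intra-node plane: one communicator spanning all local GPUs.
  std::vector<ncclComm_t> inner(dev_size);
  ncclCommInitAll(inner.data(), dev_size, dev_ids.data());

  // Inter-node plane: one communicator per local GPU, ranks == node ranks.
  std::vector<ncclUniqueId> ids(dev_size);
  if (node_rank == 0)
    for (int i = 0; i < dev_size; ++i) ncclGetUniqueId(&ids[i]);
  broadcast_ids(&ids, /*root=*/0);
  std::vector<ncclComm_t> inter(dev_size);
  for (int i = 0; i < dev_size; ++i) {
    cudaSetDevice(dev_ids[i]);
    ncclCommInitRank(&inter[i], node_size, ids[i], node_rank);
  }
  return 0;
}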
......@@ -177,6 +219,11 @@ class PSGPUWrapper {
std::shared_ptr<HeterPsResource> resource_;
int32_t sleep_seconds_before_fail_exit_;
std::vector<int> slot_vector_;
int multi_node_{1};
int node_size_;
std::vector<ncclComm_t> inner_comms_;
std::vector<ncclComm_t> inter_comms_;
std::vector<ncclUniqueId> inter_ncclids_;
std::vector<int> heter_devices_;
std::unordered_set<std::string> gpu_ps_config_keys_;
HeterObjectPool<HeterContext> gpu_task_pool_;
......
......@@ -12,6 +12,14 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <cstdlib>
#include <string>
#include <vector>
#include "io/fs.h"
#include "paddle/fluid/framework/data_feed_factory.h"
#include "paddle/fluid/framework/data_set.h"
#include "paddle/fluid/framework/device_worker_factory.h"
#include "paddle/fluid/framework/fleet/fleet_wrapper.h"
#include "paddle/fluid/framework/trainer.h"
#if (defined PADDLE_WITH_CUDA || defined PADDLE_WITH_XPU) && \
(defined PADDLE_WITH_PSLIB)
......
......@@ -12,6 +12,13 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/framework/device_worker.h"
#include "paddle/fluid/framework/device_worker_factory.h"
#include "paddle/fluid/framework/fleet/fleet_wrapper.h"
#include "paddle/fluid/framework/fleet/heter_wrapper.h"
#include "paddle/fluid/platform/cpu_helper.h"
#include "paddle/fluid/string/string_helper.h"
#if (defined PADDLE_WITH_CUDA || defined PADDLE_WITH_XPU) && \
(defined PADDLE_WITH_PSLIB)
#include "paddle/fluid/platform/cuda_device_guard.h"
......
......@@ -12,6 +12,13 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/framework/device_worker.h"
#include "paddle/fluid/framework/device_worker_factory.h"
#include "paddle/fluid/framework/fleet/fleet_wrapper.h"
#include "paddle/fluid/framework/fleet/heter_wrapper.h"
#include "paddle/fluid/platform/cpu_helper.h"
#include "paddle/fluid/string/string_helper.h"
#ifdef PADDLE_WITH_PSLIB
#if defined _WIN32 || defined __APPLE__
......
......@@ -12,6 +12,16 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <cstdlib>
#include <ctime>
#include <string>
#include <vector>
#include "io/fs.h"
#include "paddle/fluid/framework/data_feed_factory.h"
#include "paddle/fluid/framework/data_set.h"
#include "paddle/fluid/framework/device_worker_factory.h"
#include "paddle/fluid/framework/fleet/fleet_wrapper.h"
#include "paddle/fluid/framework/trainer.h"
#if (defined PADDLE_WITH_CUDA || defined PADDLE_WITH_XPU) && \
(defined PADDLE_WITH_PSLIB)
#ifdef PADDLE_WITH_CUDA
......
......@@ -599,6 +599,7 @@ class GeneralRoleMaker(RoleMakerBase):
self._init_timeout_seconds = kwargs.get("init_timeout_seconds", 3600)
self._run_timeout_seconds = kwargs.get("run_timeout_seconds", 9999999)
ip_port = kwargs.get("http_ip_port", "")
self._use_ps_gpu = kwargs.get("use_ps_gpu", False)
self._http_ip_port = []
self._http_server = None
# if ip_port is not empty, it will use http instead of hdfs
......@@ -666,6 +667,18 @@ class GeneralRoleMaker(RoleMakerBase):
self._hdfs_name, self._hdfs_ugi)
gloo.init()
self._node_type_comm = gloo
if self._use_ps_gpu:
Gloo_strategy = fluid.core.GlooParallelStrategy()
Gloo_strategy.rank = current_id
Gloo_strategy.rank_num = len(worker_endpoints)
Gloo_strategy.ip_address = self._http_ip_port[0]
Gloo_strategy.ip_port = int(self._http_ip_port[1])
Default_init_timeout_seconds = 3600
Default_run_timeout_seconds = 9999999
Gloo_strategy.init_seconds = Default_init_timeout_seconds
Gloo_strategy.run_seconds = Default_run_timeout_seconds
Gloo = fluid.core.GlooParallelContext(Gloo_strategy)
Gloo.init()
else:
self._all_comm = MockBarrier()
elif training_role == "PSERVER":
......
......@@ -386,3 +386,27 @@ class SingleProcessMultiThread(GradAllReduce):
def _transpile_startup_program(self):
block = self.startup_program.global_block()
block.append_op(type='c_comm_init_all', attrs={'ring_id': 0})
class MultiThread(GradAllReduce):
'''
Multi-node gradient all-reduce: when more than one endpoint is configured,
initialize one NCCL communicator per ring across nodes in the startup
program; otherwise fall back to the single-node c_comm_init_all path.
'''
def __init__(self, nrings=1):
GradAllReduce.__init__(self, nrings)
self.mode = "box"
def _transpile_startup_program(self):
if len(self.endpoints) > 1:
print("begin to _transpile_startup_program for multi-node")
print("current_endpoint: ", self.current_endpoint)
print("total endpoints: ", self.endpoints)
print("rank: %d, ring_id: %d" % (self.rank, self.nrings))
for ring_id in range(self.nrings):
self._init_communicator(
self.startup_program, self.current_endpoint, self.endpoints,
self.rank, ring_id, self.wait_port, True)
else:
print("begin to _transpile_startup_program for single-node")
block = self.startup_program.global_block()
block.append_op(type='c_comm_init_all', attrs={'ring_id': 0})