Unverified commit 4a8b8b45, authored by liuyuhui, committed by GitHub

[Kunlun] add gen_bkcl_id_op, support multi XPU cards training using multiprocess (#30858)

Parent 39f41cb4
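This commit adds a c_gen_bkcl_id/gen_bkcl_id op pair that generates and broadcasts a BKCLUniqueId, teaches c_comm_init to build a BKCL communicator on an XPU place, and updates the Python fleet/CollectiveHelper glue. For orientation only, the sketch below is not taken from the diff; it mirrors the CollectiveHelper change further down, assumes a PaddlePaddle build with WITH_XPU_BKCL and one trainer process per XPU card, and the helper name init_bkcl_communicator is illustrative.

# Minimal sketch, not part of this commit: mirrors the CollectiveHelper change
# below. Assumes a WITH_XPU_BKCL build and one trainer process per XPU card;
# the function name init_bkcl_communicator is illustrative only.
import paddle
from paddle.fluid import core, unique_name

paddle.enable_static()


def init_bkcl_communicator(program, rank, nranks, ring_id, current_endpoint,
                           other_endpoints):
    block = program.global_block()
    # RAW variable that will hold the BKCLUniqueId broadcast by rank 0.
    comm_id_var = block.create_var(
        name=unique_name.generate('bkcl_id'),
        persistable=True,
        type=core.VarDesc.VarType.RAW)
    # Rank 0 generates the id and sends it to other_endpoints; every other
    # rank listens on current_endpoint until it receives the id.
    block.append_op(
        type='c_gen_bkcl_id',
        inputs={},
        outputs={'Out': comm_id_var},
        attrs={
            'rank': rank,
            'endpoint': current_endpoint,
            'other_endpoints': other_endpoints,
        })
    # c_comm_init consumes the id and creates the BKCL communicator for
    # ring_id on the XPU card selected for this process.
    block.append_op(
        type='c_comm_init',
        inputs={'X': comm_id_var},
        outputs={},
        attrs={
            'nranks': nranks,
            'rank': rank,
            'ring_id': ring_id,
        })
    return comm_id_var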
......@@ -11,7 +11,7 @@ foreach(src ${OPS})
set_source_files_properties(${src} PROPERTIES COMPILE_FLAGS ${COLLECTIVE_COMPILE_FLAGS})
endforeach()
register_operators(EXCLUDES c_gen_nccl_id_op gen_nccl_id_op DEPS ${COLLECTIVE_DEPS})
register_operators(EXCLUDES c_gen_bkcl_id_op gen_bkcl_id_op c_gen_nccl_id_op gen_nccl_id_op DEPS ${COLLECTIVE_DEPS})
if(WITH_NCCL)
set(COLLECTIVE_DEPS ${COLLECTIVE_DEPS} nccl_common collective_helper)
......@@ -19,13 +19,15 @@ if(WITH_NCCL)
op_library(gen_nccl_id_op DEPS ${COLLECTIVE_DEPS})
endif()
if(WITH_XPU_BKCL)
set(COLLECTIVE_DEPS ${COLLECTIVE_DEPS} collective_helper)
endif()
if(WITH_GLOO)
set(COLLECTIVE_DEPS ${COLLECTIVE_DEPS} gloo_wrapper)
endif()
if(WITH_XPU_BKCL)
set(COLLECTIVE_DEPS ${COLLECTIVE_DEPS} collective_helper)
op_library(c_gen_bkcl_id_op DEPS ${COLLECTIVE_DEPS})
op_library(gen_bkcl_id_op DEPS ${COLLECTIVE_DEPS})
endif()
set(OPERATOR_DEPS ${OPERATOR_DEPS} ${COLLECTIVE_DEPS} PARENT_SCOPE)
set(GLOB_COLLECTIVE_DEPS ${COLLECTIVE_DEPS} CACHE INTERNAL "collective dependency")
......@@ -14,6 +14,9 @@ limitations under the License. */
#if defined(PADDLE_WITH_NCCL)
#include <nccl.h>
#endif
#if defined(PADDLE_WITH_XPU_BKCL)
#include "xpu/bkcl.h"
#endif
#include <string>
#include "paddle/fluid/framework/op_registry.h"
......@@ -23,7 +26,7 @@ namespace framework {
class Scope;
} // namespace framework
} // namespace paddle
#if defined(PADDLE_WITH_NCCL)
#if (defined PADDLE_WITH_NCCL) || (defined PADDLE_WITH_XPU_BKCL)
#include "paddle/fluid/platform/collective_helper.h"
#endif
......@@ -39,29 +42,56 @@ class CCommInitOp : public framework::OperatorBase {
void RunImpl(const framework::Scope& scope,
const platform::Place& place) const override {
PADDLE_ENFORCE_EQ(is_gpu_place(place), true,
PADDLE_ENFORCE_EQ(is_gpu_place(place) || is_xpu_place(place), true,
platform::errors::PreconditionNotMet(
"CCommInitOp can run on gpu place only."));
"CCommInitOp can run on gpu or xpu place only."));
auto var = scope.FindVar(Input("X"));
PADDLE_ENFORCE_NOT_NULL(
var, platform::errors::InvalidArgument("Input cannot be empty."));
if (is_gpu_place(place)) {
#if defined(PADDLE_WITH_NCCL)
ncclUniqueId* nccl_id = var->GetMutable<ncclUniqueId>();
int nranks = Attr<int>("nranks");
int rank_id = Attr<int>("rank");
int rid = Attr<int>("ring_id");
int device_id = BOOST_GET_CONST(platform::CUDAPlace, place).device;
if (Attr<int>("device_id") >= 0) {
device_id = Attr<int>("device_id");
}
platform::NCCLCommContext::Instance().CreateNCCLComm(
nccl_id, nranks, rank_id, device_id, rid);
ncclUniqueId* nccl_id = var->GetMutable<ncclUniqueId>();
int nranks = Attr<int>("nranks");
int rank_id = Attr<int>("rank");
int rid = Attr<int>("ring_id");
int device_id = BOOST_GET_CONST(platform::CUDAPlace, place).device;
if (Attr<int>("device_id") >= 0) {
device_id = Attr<int>("device_id");
}
platform::NCCLCommContext::Instance().CreateNCCLComm(
nccl_id, nranks, rank_id, device_id, rid);
#else
PADDLE_THROW(platform::errors::PreconditionNotMet(
"PaddlePaddle should compile with GPU."));
PADDLE_THROW(platform::errors::PreconditionNotMet(
"PaddlePaddle should compile with GPU."));
#endif
} else if (is_xpu_place(place)) {
#if defined(PADDLE_WITH_XPU_BKCL)
BKCLUniqueId* bkcl_id = var->GetMutable<BKCLUniqueId>();
int nranks = Attr<int>("nranks");
int rank_id = Attr<int>("rank");
int rid = Attr<int>("ring_id");
PADDLE_ENFORCE_EQ(
rid, 0,
platform::errors::OutOfRange(
"Ring id must equal 0 when training with multiple Kunlun cards, but got %d",
rid));
int device_id = BOOST_GET_CONST(platform::XPUPlace, place).device;
if (Attr<int>("device_id") >= 0) {
device_id = Attr<int>("device_id");
}
platform::BKCLCommContext::Instance().CreateBKCLComm(
bkcl_id, nranks, rank_id, device_id, rid);
#else
PADDLE_THROW(platform::errors::PreconditionNotMet(
"PaddlePaddle should compile with XPU."));
#endif
} else {
PADDLE_THROW(platform::errors::PreconditionNotMet(
"CCommInitOp can run on gpu or xpu place only."));
}
}
};
......
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <string>
#include "glog/logging.h"
#include "paddle/fluid/framework/op_proto_maker.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/var_type_traits.h"
#include "paddle/fluid/platform/device_context.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/place.h"
#include "paddle/fluid/platform/gen_comm_id_helper.h"
namespace paddle {
namespace operators {
static void GenBKCLID(std::vector<BKCLUniqueId>* bkcl_ids) {
for (size_t i = 0; i < bkcl_ids->size(); ++i) {
BKCLResult_t ret = bkcl_get_unique_id(&(*bkcl_ids)[i]);
PADDLE_ENFORCE_EQ(BKCL_SUCCESS, ret,
platform::errors::PreconditionNotMet(
"bkcl get unique id failed [%d]", ret));
}
}
static void CopyBKCLIDToVar(const std::vector<BKCLUniqueId>& bkcl_ids,
std::function<std::string(size_t)> func,
const framework::Scope& scope) {
for (size_t i = 0; i < bkcl_ids.size(); ++i) {
std::string var_name = func(i);
auto var = scope.FindVar(var_name);
PADDLE_ENFORCE_NOT_NULL(
var, platform::errors::NotFound("Variable with name %s is not found",
var_name.c_str()));
auto bkcl_id = var->GetMutable<BKCLUniqueId>();
memcpy(bkcl_id, &bkcl_ids[i], sizeof(BKCLUniqueId));
}
}
class CGenBKCLIdOp : public framework::OperatorBase {
public:
CGenBKCLIdOp(const std::string& type,
const framework::VariableNameMap& inputs,
const framework::VariableNameMap& outputs,
const framework::AttributeMap& attrs)
: OperatorBase(type, inputs, outputs, attrs) {}
void RunImpl(const framework::Scope& scope,
const platform::Place& dev_place) const override {
int rank = Attr<int>("rank");
framework::Scope& local_scope = scope.NewScope();
std::function<std::string(size_t)> func = [&](size_t i) -> std::string {
return Output("Out");
};
std::vector<BKCLUniqueId> bkcl_ids;
bkcl_ids.resize(1);
if (rank == 0) {
GenBKCLID(&bkcl_ids);
std::vector<std::string> endpoint_list =
Attr<std::vector<std::string>>("other_endpoints");
platform::SendBroadCastCommID(endpoint_list, &bkcl_ids);
} else {
std::string endpoint = Attr<std::string>("endpoint");
platform::RecvBroadCastCommID(endpoint, &bkcl_ids);
}
CopyBKCLIDToVar(bkcl_ids, func, scope);
scope.DeleteScope(&local_scope);
}
};
class CGenBKCLIdOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddOutput("Out", "Raw variable contains a BKCL UniqueId instaces.");
AddComment(R"DOC(
CGenBKCLId operator
For trainer 0: generate a new UniqueId and send it to all the other trainers.
For trainer 1~n: start a socket server to get the UniqueId, once got, stop the server.
)DOC");
AddAttr<std::string>("endpoint",
"(string), e.g. 127.0.0.1:6175 "
"current listen endpoint");
AddAttr<std::vector<std::string>>(
"other_endpoints",
"['trainer1_ip:port', 'trainer2_ip:port', ...] "
"list of other trainer endpoints")
.SetDefault({});
AddAttr<int>("rank",
"(int default 0) "
"The rank of the trainer in distributed training.")
.SetDefault(0);
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OPERATOR(c_gen_bkcl_id, ops::CGenBKCLIdOp, ops::CGenBKCLIdOpMaker);
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <ostream>
#include <string>
#include "glog/logging.h"
#include "paddle/fluid/framework/op_proto_maker.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/var_type_traits.h"
#include "paddle/fluid/platform/bkcl_helper.h"
#include "paddle/fluid/platform/device_context.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/place.h"
#include "paddle/fluid/string/split.h"
#include "paddle/fluid/platform/gen_comm_id_helper.h"
namespace paddle {
namespace operators {
static void GenBKCLID(std::vector<BKCLUniqueId>* bkcl_ids) {
for (size_t i = 0; i < bkcl_ids->size(); ++i) {
BKCLResult_t ret = bkcl_get_unique_id(&(*bkcl_ids)[i]);
PADDLE_ENFORCE_EQ(BKCL_SUCCESS, ret,
platform::errors::PreconditionNotMet(
"bkcl get unique id failed [%d]", ret));
}
}
static void CopyBKCLIDToVar(const std::vector<BKCLUniqueId>& bkcl_ids,
std::function<std::string(size_t)> func,
const framework::Scope& scope) {
for (size_t i = 0; i < bkcl_ids.size(); ++i) {
std::string var_name = func(i);
auto var = scope.FindVar(var_name);
PADDLE_ENFORCE_NOT_NULL(
var, platform::errors::NotFound("Variable with name %s is not found",
var_name.c_str()));
auto bkcl_id = var->GetMutable<BKCLUniqueId>();
memcpy(bkcl_id, &bkcl_ids[i], sizeof(BKCLUniqueId));
}
}
class GenBKCLIdOp : public framework::OperatorBase {
public:
GenBKCLIdOp(const std::string& type, const framework::VariableNameMap& inputs,
const framework::VariableNameMap& outputs,
const framework::AttributeMap& attrs)
: OperatorBase(type, inputs, outputs, attrs) {}
void RunImpl(const framework::Scope& scope,
const platform::Place& dev_place) const override {
std::vector<std::string> trainers =
Attr<std::vector<std::string>>("trainers");
int trainer_id = Attr<int>("trainer_id");
std::string endpoint = trainers[trainer_id];
PADDLE_ENFORCE_GE(trainer_id, 0, platform::errors::InvalidArgument(
"trainer_id %d is less than 0. Its "
"valid range is [0, trainer_size)",
trainer_id));
PADDLE_ENFORCE_LT(
trainer_id, static_cast<int>(trainers.size()),
platform::errors::OutOfRange("trainer_id %d is out of range. Its valid "
"range is [0, trainer_size)",
trainer_id));
int bkcl_comm_num = Attr<int>("bkcl_comm_num");
int use_hierarchical_allreduce = Attr<bool>("use_hierarchical_allreduce");
int inter_nranks = Attr<int>("hierarchical_allreduce_inter_nranks");
int inter_trainer_id = -1;
int exter_trainer_id = -1;
if (use_hierarchical_allreduce) {
PADDLE_ENFORCE_GT(
trainers.size(), 1,
platform::errors::PreconditionNotMet(
"The number of collective trainers %llu <= 1", trainers.size()));
PADDLE_ENFORCE_GT(
inter_nranks, 1,
platform::errors::PreconditionNotMet(
"inter_nranks %d <= 1 while in hierarchical allreduce mode",
inter_nranks));
PADDLE_ENFORCE_EQ(
trainers.size() % inter_nranks, 0,
platform::errors::PreconditionNotMet(
"The number of trainers %llu mod inter_nranks %d is not equal 0",
trainers.size(), inter_nranks));
inter_trainer_id = trainer_id % inter_nranks;
if (trainer_id % inter_nranks == 0) {
exter_trainer_id = trainer_id / inter_nranks;
}
}
std::ostringstream ss;
for (size_t i = 0; i < trainers.size(); i++) {
ss << trainers[i] << ",";
}
VLOG(1) << "trainer_id:" << trainer_id
<< ", use_hierarchical_allreduce:" << use_hierarchical_allreduce
<< ", bkcl_comm_num:" << bkcl_comm_num
<< ", inter_nranks:" << inter_nranks
<< ", inter_trainer_id:" << inter_trainer_id
<< ", exter_trainer_id:" << exter_trainer_id
<< ", trainers:" << ss.str();
int server_fd = -1;
std::vector<BKCLUniqueId> bkcl_ids;
bkcl_ids.resize(bkcl_comm_num);
/// 1. init flat
std::function<std::string(size_t)> func = platform::GetFlatBKCLVarName;
// broadcast unique id
if (trainer_id == 0) {
GenBKCLID(&bkcl_ids);
// server endpoints
std::vector<std::string> flat_endpoints;
flat_endpoints.insert(flat_endpoints.begin(), trainers.begin() + 1,
trainers.end());
platform::SendBroadCastCommID(flat_endpoints, &bkcl_ids);
} else {
server_fd = platform::CreateListenSocket(endpoint);
platform::RecvBroadCastCommID(server_fd, endpoint, &bkcl_ids);
}
CopyBKCLIDToVar(bkcl_ids, func, scope);
/* TODO(liuyuhui): Baidu Kunlun Communication Library (BKCL) does not yet
support hierarchical communication the way NVIDIA Collective Communications
Library (NCCL) does across multiple NVIDIA GPU cards; it will be supported
later. */
// close socket server
if (trainer_id != 0) {
platform::CloseSocket(server_fd);
}
}
};
class GenBKCLIdOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddOutput("BKCLID", "Raw variable contains a BKCL UniqueId instaces.");
AddComment(R"DOC(
GenBKCLId operator
For trainer 0: generate a new UniqueId and send it to all the other trainers.
For trainer 1~n: start a socket server to get the UniqueId, once got, stop the server.
)DOC");
AddAttr<std::vector<std::string>>(
"trainers",
"['trainer0_ip:port', 'trainer1_ip:port', ...] "
"list of all trainer endpoints")
.SetDefault({});
AddAttr<int>("trainer_id",
"(int) "
"The index of the trainer in distributed training.");
AddAttr<int>("bkcl_comm_num",
"(int default 1) "
"The number of bkcl communicator num.")
.SetDefault(1);
AddAttr<bool>("use_hierarchical_allreduce",
"(bool default false) "
"Wheter to use hierarchical allreduce.")
.SetDefault(false);
AddAttr<int>("hierarchical_allreduce_inter_nranks",
"(int default 1) "
"Wheter to use hierarchical allreduce.")
.SetDefault(-1);
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OPERATOR(gen_bkcl_id, ops::GenBKCLIdOp, ops::GenBKCLIdOpMaker);
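For orientation, a hedged Python sketch of driving the gen_bkcl_id op directly follows; it mirrors the unit test added at the end of this commit, assumes a WITH_XPU_BKCL build with one process per trainer endpoint, and the helper name append_gen_bkcl_id is illustrative only.

# Sketch only: mirrors test_gen_bkcl_id_op.py added in this commit. Assumes a
# WITH_XPU_BKCL build; append_gen_bkcl_id is an illustrative helper name.
import paddle
from paddle.fluid import core

paddle.enable_static()


def append_gen_bkcl_id(startup_program, trainers, trainer_id):
    block = startup_program.global_block()
    # RAW variable that receives the broadcast BKCLUniqueId.
    bkcl_id_var = block.create_var(
        name="BKCLID", persistable=True, type=core.VarDesc.VarType.RAW)
    block.append_op(
        type="gen_bkcl_id",
        inputs={},
        outputs={"BKCLID": bkcl_id_var},
        attrs={
            "trainers": trainers,      # all trainer endpoints, e.g. ["127.0.0.1:6175", ...]
            "trainer_id": trainer_id,  # index of this process in `trainers`
            "bkcl_comm_num": 1,
            "use_hierarchical_allreduce": False,
            "hierarchical_allreduce_inter_nranks": -1,
        })
    return bkcl_id_var


# Each trainer process appends the op with its own trainer_id and then runs
# the startup program on CPU; trainer 0 generates the id, the others receive it.
# exe = paddle.static.Executor(paddle.CPUPlace())
# exe.run(startup_program)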
......@@ -74,30 +74,60 @@ class CollectiveHelper(object):
wait_server_ready(other_endpoints)
block = program.global_block()
nccl_id_var = block.create_var(
name=unique_name.generate('nccl_id'),
persistable=True,
type=core.VarDesc.VarType.RAW)
block.append_op(
type='c_gen_nccl_id',
inputs={},
outputs={'Out': nccl_id_var},
attrs={
'rank': rank,
'endpoint': current_endpoint,
'other_endpoints': other_endpoints,
OP_ROLE_KEY: OpRole.Forward
})
block.append_op(
type='c_comm_init',
inputs={'X': nccl_id_var},
outputs={},
attrs={
'nranks': nranks,
'rank': rank,
'ring_id': ring_id,
OP_ROLE_KEY: OpRole.Forward
})
if core.is_compiled_with_cuda():
comm_id_var = block.create_var(
name=unique_name.generate('nccl_id'),
persistable=True,
type=core.VarDesc.VarType.RAW)
block.append_op(
type='c_gen_nccl_id',
inputs={},
outputs={'Out': comm_id_var},
attrs={
'rank': rank,
'endpoint': current_endpoint,
'other_endpoints': other_endpoints,
OP_ROLE_KEY: OpRole.Forward
})
block.append_op(
type='c_comm_init',
inputs={'X': comm_id_var},
outputs={},
attrs={
'nranks': nranks,
'rank': rank,
'ring_id': ring_id,
OP_ROLE_KEY: OpRole.Forward
})
elif core.is_compiled_with_xpu():
comm_id_var = block.create_var(
name=unique_name.generate('bkcl_id'),
persistable=True,
type=core.VarDesc.VarType.RAW)
block.append_op(
type='c_gen_bkcl_id',
inputs={},
outputs={'Out': comm_id_var},
attrs={
'rank': rank,
'endpoint': current_endpoint,
'other_endpoints': other_endpoints,
OP_ROLE_KEY: OpRole.Forward
})
block.append_op(
type='c_comm_init',
inputs={'X': comm_id_var},
outputs={},
attrs={
'nranks': nranks,
'rank': rank,
'ring_id': ring_id,
OP_ROLE_KEY: OpRole.Forward
})
else:
raise ValueError(
"comm_id must be generated in paddlepaddle-xpu or paddlepaddle-xpu."
)
def _wait(self, current_endpoint, endpoints):
assert (self.wait_port)
......
......@@ -64,39 +64,70 @@ class GraphExecutionOptimizer(MetaOptimizerBase):
if trainer_id == 0:
wait_server_ready(other_trainers)
nccl_id_var = startup_program.global_block().create_var(
name="NCCLID", persistable=True, type=core.VarDesc.VarType.RAW)
if core.is_compiled_with_cuda():
comm_id_var = startup_program.global_block().create_var(
name="NCCLID", persistable=True, type=core.VarDesc.VarType.RAW)
for i in range(1, build_strategy.nccl_comm_num):
startup_program.global_block().create_var(
name="NCCLID_{}".format(i),
persistable=True,
type=core.VarDesc.VarType.RAW)
if build_strategy.use_hierarchical_allreduce:
for i in range(0, build_strategy.nccl_comm_num):
for i in range(1, build_strategy.nccl_comm_num):
startup_program.global_block().create_var(
name="Hierarchical_inter_NCCLID_{}".format(i),
name="NCCLID_{}".format(i),
persistable=True,
type=core.VarDesc.VarType.RAW)
if build_strategy.use_hierarchical_allreduce:
for i in range(0, build_strategy.nccl_comm_num):
startup_program.global_block().create_var(
name="Hierarchical_inter_NCCLID_{}".format(i),
persistable=True,
type=core.VarDesc.VarType.RAW)
startup_program.global_block().create_var(
name="Hierarchical_exter_NCCLID_{}".format(i),
persistable=True,
type=core.VarDesc.VarType.RAW)
startup_program.global_block().append_op(
type="gen_nccl_id",
inputs={},
outputs={"NCCLID": comm_id_var},
attrs={
"trainers": trainer_endpoints,
"trainer_id": trainer_id,
"nccl_comm_num": build_strategy.nccl_comm_num,
"use_hierarchical_allreduce":
build_strategy.use_hierarchical_allreduce,
"hierarchical_allreduce_inter_ranks":
build_strategy.hierarchical_allreduce_inter_nranks
})
elif core.is_compiled_with_xpu():
comm_id_var = startup_program.global_block().create_var(
name="BKCLID", persistable=True, type=core.VarDesc.VarType.RAW)
# NOTE(liuyuhui): Baidu Kunlun Communication Library (BKCL) currently does not support multiple machines.
assert build_strategy.bkcl_comm_num == 1, \
"Baidu Kunlun Communication Library (BKCL) currently does not support multiple machines."
for i in range(1, build_strategy.bkcl_comm_num):
startup_program.global_block().create_var(
name="Hierarchical_exter_NCCLID_{}".format(i),
name="BKCLID_{}".format(i),
persistable=True,
type=core.VarDesc.VarType.RAW)
startup_program.global_block().append_op(
type="gen_nccl_id",
inputs={},
outputs={"NCCLID": nccl_id_var},
attrs={
"trainers": trainer_endpoints,
"trainer_id": trainer_id,
"nccl_comm_num": build_strategy.nccl_comm_num,
"use_hierarchical_allreduce":
build_strategy.use_hierarchical_allreduce,
"hierarchical_allreduce_inter_ranks":
build_strategy.hierarchical_allreduce_inter_nranks
})
startup_program.global_block().append_op(
type="gen_bkcl_id",
inputs={},
outputs={"BKCLID": comm_id_var},
attrs={
"trainers": trainer_endpoints,
"trainer_id": trainer_id,
"nccl_comm_num": build_strategy.nccl_comm_num,
"use_hierarchical_allreduce":
build_strategy.use_hierarchical_allreduce,
"hierarchical_allreduce_inter_ranks":
build_strategy.hierarchical_allreduce_inter_nranks
})
else:
raise ValueError(
"comm_id must be generated in paddlepaddle-xpu or paddlepaddle-gpu."
)
def _try_to_compile(self, startup_program, main_program, loss):
dist_strategy = self.user_defined_strategy
......
......@@ -2057,9 +2057,9 @@ class Operator(object):
'feed', 'fetch', 'recurrent', 'go', 'rnn_memory_helper_grad',
'conditional_block', 'while', 'send', 'recv', 'listen_and_serv',
'fl_listen_and_serv', 'ncclInit', 'select', 'checkpoint_notify',
'gen_nccl_id', 'c_gen_nccl_id', 'c_comm_init', 'c_sync_calc_stream',
'c_sync_comm_stream', 'queue_generator', 'dequeue', 'enqueue',
'heter_listen_and_serv'
'gen_bkcl_id', 'c_gen_bkcl_id', 'gen_nccl_id', 'c_gen_nccl_id',
'c_comm_init', 'c_sync_calc_stream', 'c_sync_comm_stream',
'queue_generator', 'dequeue', 'enqueue', 'heter_listen_and_serv'
}
def __init__(self,
......
......@@ -186,8 +186,8 @@ class TestDistRunnerBase(object):
fleet.save_inference_model(exe, infer_save_dir_fleet,
feeded_var_names, [avg_cost])
def run_gpu_fleet_api_trainer(self, args):
assert args.update_method == "nccl2"
def run_use_fleet_api_trainer(self, args):
assert args.update_method == "nccl2" or "bkcl"
self.lr = args.lr
......@@ -207,7 +207,7 @@ class TestDistRunnerBase(object):
role = role_maker.PaddleCloudRoleMaker(is_collective=True)
fleet.init(role)
print_to_err("gpu_fleet", "fleet.node_num:")
print_to_err("use_fleet", "fleet.node_num:")
# "fleet.node_id:", fleet.node_id(),
# "fleet.trainer_num:", fleet.worker_num())
......@@ -217,8 +217,16 @@ class TestDistRunnerBase(object):
trainer_prog = fleet._origin_program
dist_prog = fleet.main_program
device_id = int(os.getenv("FLAGS_selected_gpus", "0"))
place = fluid.CUDAPlace(device_id)
if fluid.core.is_compiled_with_cuda():
device_id = int(os.getenv("FLAGS_selected_gpus", "0"))
place = fluid.CUDAPlace(device_id)
elif fluid.core.is_compiled_with_xpu():
device_id = int(os.getenv("FLAGS_selected_xpus", "0"))
place = fluid.XPUPlace(device_id)
else:
raise ValueError(
"fleet dygraph api must in paddlepaddle-xpu or paddlepaddle-gpu."
)
exe = fluid.Executor(place)
exe.run(fluid.default_startup_program())
......@@ -550,7 +558,7 @@ class TestParallelDyGraphRunnerBase(object):
model.clear_gradients()
return out_losses
def run_gpu_fleet_api_trainer(self, args):
def run_use_fleet_api_trainer(self, args):
import paddle.distributed.fleet as fleet
import paddle.distributed.fleet.base.role_maker as role_maker
# 1. enable dygraph
......@@ -566,12 +574,12 @@ class TestParallelDyGraphRunnerBase(object):
args.trainer_id = paddle.distributed.get_rank()
# 3. init parallel env
if args.update_method == "nccl2":
if args.update_method == "nccl2" or "bkcl":
fleet.init(is_collective=True)
# 4. train model
model, train_reader, opt = self.get_model()
if args.update_method == "nccl2":
if args.update_method == "nccl2" or "bkcl":
opt = fleet.distributed_optimizer(opt)
model = fleet.distributed_model(model)
......@@ -606,7 +614,7 @@ def runtime_main(test_class):
parser.add_argument('--enable_backward_deps', action='store_true')
parser.add_argument('--use_hallreduce', action='store_true')
parser.add_argument('--use_pipeline', action='store_true')
parser.add_argument('--gpu_fleet_api', action='store_true')
parser.add_argument('--use_fleet_api', action='store_true')
parser.add_argument('--use_local_sgd', action='store_true')
parser.add_argument('--ut4grad_allreduce', action='store_true')
parser.add_argument(
......@@ -644,8 +652,8 @@ def runtime_main(test_class):
model = test_class()
if args.role == "pserver" and args.update_method == "pserver":
model.run_pserver(args)
elif args.gpu_fleet_api:
model.run_gpu_fleet_api_trainer(args)
elif args.use_fleet_api:
model.run_use_fleet_api_trainer(args)
elif args.use_pipeline:
model.run_pipeline_trainer(args)
else:
......@@ -708,7 +716,7 @@ class TestDistBase(unittest.TestCase):
self._dygraph = False
self._nccl_comm_num = 1
self._enable_backward_deps = False
self._gpu_fleet_api = False
self._use_fleet_api = False
self._use_local_sgd = False
self._ut4grad_allreduce = False
self._use_hallreduce = False
......@@ -1020,8 +1028,8 @@ class TestDistBase(unittest.TestCase):
if self._fuse_all_reduce is not None:
tr_cmd += " --fuse_all_reduce {}".format(self._fuse_all_reduce)
if self._gpu_fleet_api:
tr_cmd += " --gpu_fleet_api"
if self._use_fleet_api:
tr_cmd += " --use_fleet_api"
if self._use_local_sgd:
tr_cmd += " --use_local_sgd"
if self._ut4grad_allreduce:
......
......@@ -28,7 +28,7 @@ class TestDistMnistFleetSave(TestDistBase):
self._use_reduce = False
self._use_reader_alloc = False
self._nccl2_mode = True
self._gpu_fleet_api = True
self._use_fleet_api = True
self._save_model = True
def _rm_temp_files(self, dirname):
......
......@@ -26,7 +26,7 @@ class TestDistMnistNCCL2FleetApi(TestDistBase):
self._use_reduce = False
self._use_reader_alloc = False
self._nccl2_mode = True
self._gpu_fleet_api = True
self._use_fleet_api = True
self._sync_batch_norm = True
def test_dist_train(self):
......
......@@ -26,7 +26,7 @@ class TestDistMnistLocalSGDFleetApi(TestDistBase):
self._use_reduce = False
self._use_reader_alloc = False
self._nccl2_mode = True
self._gpu_fleet_api = True
self._use_fleet_api = True
self._use_local_sgd = True
def test_dist_train(self):
......@@ -41,7 +41,7 @@ class TestDistMnistGradAllReduceFleetApi(TestDistBase):
self._use_reduce = False
self._use_reader_alloc = False
self._nccl2_mode = True
self._gpu_fleet_api = True
self._use_fleet_api = True
self._ut4grad_allreduce = True
def test_dist_train(self):
......
......@@ -28,7 +28,7 @@ class TestDistMnistFleetSave(TestDistBase):
self._use_reduce = False
self._use_reader_alloc = False
self._nccl2_mode = True
self._gpu_fleet_api = True
self._use_fleet_api = True
self._sharding_save = True
self._enforce_place = "GPU"
......
......@@ -71,7 +71,7 @@ class TestFleetDygraphMnist(TestDistBase):
self._sync_mode = False
self._nccl2_mode = True
self._dygraph = True
self._gpu_fleet_api = True
self._use_fleet_api = True
def test_mnist(self):
if fluid.core.is_compiled_with_cuda():
......@@ -82,5 +82,22 @@ class TestFleetDygraphMnist(TestDistBase):
log_name=flag_name)
class TestFleetDygraphMnistXPU(TestDistBase):
def _setup_config(self):
self._sync_mode = False
self._bkcl_mode = True
self._dygraph = True
self._enforce_place = "XPU"
self._use_fleet_api = True
def test_mnist(self):
if fluid.core.is_compiled_with_xpu():
self.check_with_place(
"parallel_dygraph_mnist.py",
delta=1e-1,
check_error_log=True,
log_name=flag_name)
if __name__ == "__main__":
unittest.main()
......@@ -53,7 +53,7 @@ class TestFleetDygraphMnist(TestDistBase):
self._sync_mode = False
self._nccl2_mode = True
self._dygraph = True
self._gpu_fleet_api = True
self._use_fleet_api = True
def test_mnist(self):
if fluid.core.is_compiled_with_cuda():
......
file(GLOB TEST_OPS RELATIVE "${CMAKE_CURRENT_SOURCE_DIR}" "test_*.py")
string(REPLACE ".py" "" TEST_OPS "${TEST_OPS}")
if (WITH_XPU_BKCL)
list(REMOVE_ITEM TEST_OPS "test_gen_bkcl_id_op")
endif()
file(GLOB DIST_TEST_OPS RELATIVE "${CMAKE_CURRENT_SOURCE_DIR}" "test_dist_*.py")
if (WITH_XPU_BKCL)
list(APPEND DIST_TEST_OPS test_gen_bkcl_id_op)
endif()
list(REMOVE_ITEM TEST_OPS test_concat_op_xpu)
list(REMOVE_ITEM TEST_OPS test_mean_op_xpu)
......@@ -8,5 +17,9 @@ foreach(TEST_OP ${TEST_OPS})
py_test_modules(${TEST_OP} MODULES ${TEST_OP})
endforeach(TEST_OP)
foreach(TEST_OP ${DIST_TEST_OPS})
py_test_modules(${TEST_OP} MODULES ${TEST_OP})
endforeach(TEST_OP)
set_tests_properties(test_mul_op_xpu PROPERTIES TIMEOUT 120)
set_tests_properties(test_conv2d_op_xpu PROPERTIES TIMEOUT 120)
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import os
import copy
import sys
sys.path.append("..")
from launch_function_helper import wait, _find_free_port
from multiprocessing import Pool, Process
from threading import Thread
os.environ['GLOG_vmodule'] = str("gen_bkcl_id_op*=10,gen_comm_id*=10")
import paddle
from paddle.fluid import core
paddle.enable_static()
def run_gen_bkc_id(attr):
bkcl_comm_num = attr['bkcl_comm_num']
use_hallreduce = attr['use_hierarchical_allreduce']
startup_program = paddle.static.default_startup_program()
main_program = paddle.static.default_main_program()
with paddle.static.program_guard(main_program, startup_program):
bkcl_id_var = startup_program.global_block().create_var(
name="BKCLID", persistable=True, type=core.VarDesc.VarType.RAW)
for i in range(1, bkcl_comm_num):
startup_program.global_block().create_var(
name="BKCLID_{}".format(i),
persistable=True,
type=core.VarDesc.VarType.RAW)
if use_hallreduce:
for i in range(0, bkcl_comm_num):
startup_program.global_block().create_var(
name="Hierarchical_inter_BKCLID_{}".format(i),
persistable=True,
type=core.VarDesc.VarType.RAW)
startup_program.global_block().create_var(
name="Hierarchical_exter_BKCLID_{}".format(i),
persistable=True,
type=core.VarDesc.VarType.RAW)
startup_program.global_block().append_op(
type="gen_bkcl_id",
inputs={},
outputs={"BKCLID": bkcl_id_var},
attrs=attr)
place = paddle.CPUPlace()
exe = paddle.static.Executor(place)
exe.run(startup_program)
class TestGenBKCLIdOp(unittest.TestCase):
def setUp(self):
try:
self._dist_ut_port_0 = int(os.environ["PADDLE_DIST_UT_PORT"])
except Exception as e:
self._dist_ut_port_0 = _find_free_port(set())
def gen_bkcl_id(self, nranks=2):
bkcl_comm_num = 1
if nranks == 2:
use_hallreduce = False
hallreduce_inter_nranks = -1
elif nranks == 4:
use_hallreduce = True
hallreduce_inter_nranks = 2
port = self._dist_ut_port_0
trainers = []
for i in range(nranks):
trainers.append('127.0.0.1:{}'.format(port + i))
attr = {
"trainers": trainers,
"trainer_id": 0,
"bkcl_comm_num": bkcl_comm_num,
"use_hierarchical_allreduce": use_hallreduce,
"hierarchical_allreduce_inter_nranks": hallreduce_inter_nranks,
}
procs = []
for i in range(nranks):
attr['trainer_id'] = i
# NOTE: multiprocessing cannot be covered by coverage
p = Process(target=run_gen_bkc_id, args=(attr, ))
p.start()
procs.append(p)
wait(procs, timeout=120)
def test_flat(self):
print(">>> test gen flat bkcl id")
self.gen_bkcl_id(2)
print("<<< end test gen flat bkcl id")
print()
def test_hierarchical(self):
print(">>> test gen hierarchical bkcl id")
self.gen_bkcl_id(4)
print("<<< end test gen hierarchical bkcl id")
if __name__ == "__main__":
unittest.main()