Unverified commit 7e2a20d5, authored by houj04, committed by GitHub

[XPU] add some collective ops. (#45049)

* [XPU] add some collective ops. test=kunlun

* use XPUOpTestWrapper. test=kunlun

* skip kl1 for collective ops. fix typo: deivce -> device. test=kunlun
Parent 566bbf0c
@@ -25,7 +25,7 @@ else()
 endif()
 set(XPU_XCCL_BASE_URL
-    "https://klx-sdk-release-public.su.bcebos.com/xccl/release/1.0.0")
+    "https://klx-sdk-release-public.su.bcebos.com/xccl/release/1.0.4")
 if(WITH_AARCH64)
   set(XPU_XRE_DIR_NAME "xre-kylin_aarch64")
@@ -52,7 +52,7 @@ elseif(WITH_BDCENTOS)
 elseif(WITH_UBUNTU)
   set(XPU_XRE_DIR_NAME "xre-ubuntu_x86_64")
   set(XPU_XDNN_DIR_NAME "xdnn-ubuntu_x86_64")
-  set(XPU_XCCL_DIR_NAME "xccl-bdcentos_x86_64")
+  set(XPU_XCCL_DIR_NAME "xccl-ubuntu_x86_64")
   # ubuntu and centos: use output by XDNN API team
   set(XPU_XDNN_URL
       "${XPU_XDNN_BASE_URL}/${XPU_XDNN_DIR_NAME}.tar.gz"
@@ -68,7 +68,7 @@ elseif(WITH_CENTOS)
 else()
   set(XPU_XRE_DIR_NAME "xre-ubuntu_x86_64")
   set(XPU_XDNN_DIR_NAME "xdnn-ubuntu_x86_64")
-  set(XPU_XCCL_DIR_NAME "xccl-bdcentos_x86_64")
+  set(XPU_XCCL_DIR_NAME "xccl-ubuntu_x86_64")
   # default: use output by XDNN API team
   set(XPU_XDNN_URL
       "${XPU_XDNN_BASE_URL}/${XPU_XDNN_DIR_NAME}.tar.gz"
@@ -301,6 +301,29 @@ void DistModel::InsertCommOp(std::string tmp_var_name,
     comm_init_op->SetAttr("op_role",
                           static_cast<int>(framework::OpRole::kForward));
     comm_init_op->CheckAttrs();
+  } else if (config_.place == "XPU") {
+    framework::VarDesc *new_var = block->Var(tmp_var_name);
+    new_var->SetType(framework::proto::VarType::RAW);
+    new_var->SetPersistable(true);
+    framework::OpDesc *gen_bkcl_id_op = block->AppendOp();
+    gen_bkcl_id_op->SetType("c_gen_bkcl_id");
+    gen_bkcl_id_op->SetOutput("Out", {tmp_var_name});
+    gen_bkcl_id_op->SetAttr("rank", rank);
+    gen_bkcl_id_op->SetAttr("endpoint", config_.current_endpoint);
+    gen_bkcl_id_op->SetAttr("other_endpoints", peer_endpoints);
+    gen_bkcl_id_op->SetAttr("ring_id", ring_id);
+    gen_bkcl_id_op->SetAttr("op_role",
+                            static_cast<int>(framework::OpRole::kForward));
+    gen_bkcl_id_op->CheckAttrs();
+    framework::OpDesc *comm_init_op = block->AppendOp();
+    comm_init_op->SetType("c_comm_init");
+    comm_init_op->SetInput("X", {tmp_var_name});
+    comm_init_op->SetAttr("rank", rank);
+    comm_init_op->SetAttr("nranks", nranks);
+    comm_init_op->SetAttr("ring_id", ring_id);
+    comm_init_op->SetAttr("op_role",
+                          static_cast<int>(framework::OpRole::kForward));
+    comm_init_op->CheckAttrs();
   } else {
     LOG(WARNING) << "DistModelInf doesn't init comm.";
     // TODO(fleet exe dev): comm init for more devices
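The C++ branch above appends a c_gen_bkcl_id op followed by a c_comm_init op; the Python test base later in this commit builds the same pair through the static-graph API. A minimal sketch of that sequence, assuming `block` is a Program's global block, `core` is `paddle.fluid.core`, and `rank`, `nranks`, `ring_id`, and the endpoint strings are already defined:

    # Hedged sketch mirroring InsertCommOp for XPU; not the committed code itself.
    bkcl_id_var = block.create_var(name='bkcl_id',
                                   persistable=True,
                                   type=core.VarDesc.VarType.RAW)
    block.append_op(type='c_gen_bkcl_id',
                    outputs={'Out': bkcl_id_var},
                    attrs={'rank': rank,
                           'endpoint': current_endpoint,
                           'other_endpoints': other_endpoints,
                           'ring_id': ring_id})
    block.append_op(type='c_comm_init',
                    inputs={'X': bkcl_id_var},
                    attrs={'rank': rank,
                           'nranks': nranks,
                           'ring_id': ring_id})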
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/collective/c_allgather_op.h"
#ifdef PADDLE_WITH_XPU_BKCL
#include "paddle/fluid/platform/collective_helper.h"
#include "paddle/fluid/platform/device/xpu/bkcl_helper.h"
#endif
namespace paddle {
namespace operators {
template <typename T>
class CAllGatherOpXPUKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
#if defined(PADDLE_WITH_XPU_BKCL)
auto in = ctx.Input<framework::Tensor>("X");
auto out = ctx.Output<framework::Tensor>("Out");
BKCLDataType dtype =
platform::ToBKCLDataType(framework::TransToProtoVarType(in->dtype()));
int nranks = ctx.Attr<int>("nranks");
int rid = ctx.Attr<int>("ring_id");
auto place = ctx.GetPlace();
auto comm = platform::BKCLCommContext::Instance().Get(rid, place);
PADDLE_ENFORCE_EQ(
nranks,
comm->nranks(),
platform::errors::InvalidArgument(
"nranks: %s should equal to %s", nranks, comm->nranks()));
framework::DDim out_dims = in->dims();
out_dims[0] *= nranks;
size_t numel = in->numel();
const void* sendbuff = in->data<T>();
void* recvbuff = out->mutable_data<T>(out_dims, place);
XPUStream stream = nullptr;
if (ctx.Attr<bool>("use_calc_stream")) {
auto dev_ctx = platform::DeviceContextPool::Instance().Get(place);
stream = static_cast<platform::XPUDeviceContext*>(dev_ctx)
->x_context()
->xpu_stream;
} else {
stream = comm->stream();
}
// BKCLResult_t bkcl_all_gather(const BKCLContext_t ctx, const void*
// sendbuf, size_t sendcnt, void* recvbuf, BKCLDataType datatype, XPUStream
// stream);
PADDLE_ENFORCE_EQ(
bkcl_all_gather(comm->comm(), sendbuff, numel, recvbuff, dtype, stream),
BKCL_SUCCESS,
platform::errors::PreconditionNotMet("BKCL all gather failed"));
#else
PADDLE_THROW(platform::errors::PreconditionNotMet(
"PaddlePaddle should be compiled with XPU and bkcl."));
#endif
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
namespace plat = paddle::platform;
REGISTER_OP_XPU_KERNEL(c_allgather,
ops::CAllGatherOpXPUKernel<float>,
ops::CAllGatherOpXPUKernel<double>,
ops::CAllGatherOpXPUKernel<int>,
ops::CAllGatherOpXPUKernel<int64_t>,
ops::CAllGatherOpXPUKernel<plat::float16>);
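The kernel sizes its output by multiplying dim 0 of the input by `nranks` before the bkcl_all_gather call; a quick shape illustration in Python, assuming the (10, 1000) input and two ranks used by the tests below:

    # Hedged shape check for c_allgather with nranks = 2.
    nranks = 2
    in_shape = (10, 1000)
    out_shape = (in_shape[0] * nranks, in_shape[1])
    assert out_shape == (20, 1000)  # matches np.vstack((input1, input2)).shape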
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/collective/c_allreduce_op.h"
namespace ops = paddle::operators;
namespace plat = paddle::platform;
REGISTER_OP_XPU_KERNEL(c_allreduce_sum,
ops::CAllReduceOpXPUKernel<ops::kRedSum, float>,
ops::CAllReduceOpXPUKernel<ops::kRedSum, plat::float16>);
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/collective/c_identity_op.h"
namespace ops = paddle::operators;
namespace plat = paddle::platform;
REGISTER_OP_XPU_KERNEL(c_identity,
ops::CIdentityOpKernel<float>,
ops::CIdentityOpKernel<double>,
ops::CIdentityOpKernel<int>,
ops::CIdentityOpKernel<int64_t>,
ops::CIdentityOpKernel<plat::float16>);
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/collective/c_sync_comm_stream_op.h"
namespace ops = paddle::operators;
namespace plat = paddle::platform;
REGISTER_OP_XPU_KERNEL(c_sync_comm_stream, ops::CSyncCommStreamKernel<float>);
@@ -42,10 +42,18 @@ namespace platform {
 inline BKCLDataType ToBKCLDataType(framework::proto::VarType::Type type) {
   if (type == framework::proto::VarType::FP32) {
     return BKCL_FLOAT;
+  } else if (type == framework::proto::VarType::INT64) {
+    return BKCL_INT64;
+  } else if (type == framework::proto::VarType::INT32) {
+    return BKCL_INT32;
+  } else if (type == framework::proto::VarType::FP64) {
+    return BKCL_FLOAT64;
+  } else if (type == framework::proto::VarType::FP16) {
+    return BKCL_FLOAT16;
   } else {
-    PADDLE_THROW(
-        platform::errors::Unimplemented("BKCL currently only support FP32, "
-                                        "other data types are not supported."));
+    PADDLE_THROW(platform::errors::Unimplemented(
+        "BKCL currently only supports FP32, INT64, INT32, FP64 and FP16; "
+        "other data types are not supported."));
   }
 }
@@ -66,6 +66,23 @@ XPUOpMap& get_kl2_ops() {
     {"bilinear_interp_v2_grad",
      XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
     {"broadcast", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
+    {"c_allgather",
+     XPUKernelSet({pOpKernelType(vartype::FP16, XPUPlace()),
+                   pOpKernelType(vartype::FP32, XPUPlace()),
+                   pOpKernelType(vartype::FP64, XPUPlace()),
+                   pOpKernelType(vartype::INT32, XPUPlace()),
+                   pOpKernelType(vartype::INT64, XPUPlace())})},
+    {"c_allreduce_sum",
+     XPUKernelSet({pOpKernelType(vartype::FP16, XPUPlace()),
+                   pOpKernelType(vartype::FP32, XPUPlace())})},
+    {"c_identity",
+     XPUKernelSet({pOpKernelType(vartype::FP16, XPUPlace()),
+                   pOpKernelType(vartype::FP32, XPUPlace()),
+                   pOpKernelType(vartype::FP64, XPUPlace()),
+                   pOpKernelType(vartype::INT32, XPUPlace()),
+                   pOpKernelType(vartype::INT64, XPUPlace())})},
+    {"c_sync_comm_stream",
+     XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
     {"cast",
      XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()),
                    pOpKernelType(vartype::FP16, XPUPlace()),
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import paddle
import paddle.fluid as fluid
from paddle.fluid import core
import paddle.fluid.layers as layers
from test_collective_base_xpu import TestCollectiveRunnerBase, runtime_main
paddle.enable_static()
class TestCollectiveAllGather(TestCollectiveRunnerBase):
def __init__(self):
self.global_ring_id = 0
def get_model(self, main_prog, startup_program):
ring_id = 0
nranks = 2
with fluid.program_guard(main_prog, startup_program):
tindata = layers.data(name="tindata",
shape=[10, 1000],
dtype='float32')
toutdata = main_prog.current_block().create_var(
name="outofgather",
dtype='float32',
type=core.VarDesc.VarType.LOD_TENSOR,
persistable=False,
stop_gradient=False)
main_prog.global_block().append_op(type="c_allgather",
inputs={'X': tindata},
attrs={
'ring_id': ring_id,
'nranks': nranks
},
outputs={'Out': toutdata})
main_prog.global_block().append_op(type="c_sync_comm_stream",
inputs={'X': toutdata},
outputs={'Out': toutdata},
attrs={'ring_id': ring_id})
return toutdata
if __name__ == "__main__":
runtime_main(TestCollectiveAllGather, "allgather", 0)
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import os
import paddle
import paddle.fluid as fluid
from paddle.fluid import core
import paddle.fluid.layers as layers
from test_collective_base_xpu import TestCollectiveRunnerBase, runtime_main
paddle.enable_static()
class TestCollectiveAllReduce(TestCollectiveRunnerBase):
def __init__(self):
self.global_ring_id = 0
def get_model(self, main_prog, startup_program):
ring_id = 0
with fluid.program_guard(main_prog, startup_program):
tindata = layers.data(name="tindata",
shape=[10, 1000],
dtype='float32')
toutdata = main_prog.current_block().create_var(
name="outofreduce",
dtype='float32',
type=core.VarDesc.VarType.LOD_TENSOR,
persistable=False,
stop_gradient=False)
main_prog.global_block().append_op(type="c_allreduce_sum",
inputs={'X': tindata},
attrs={
'ring_id': ring_id,
},
outputs={'Out': toutdata})
main_prog.global_block().append_op(type="c_sync_comm_stream",
inputs={'X': toutdata},
outputs={'Out': toutdata},
attrs={'ring_id': ring_id})
return toutdata
if __name__ == "__main__":
os.environ["BKCL_PCIE_RING"] = "1"
runtime_main(TestCollectiveAllReduce, "allreduce", 0)
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import paddle
import paddle.fluid as fluid
from paddle.fluid import core
import paddle.fluid.layers as layers
from test_collective_base_xpu import TestCollectiveRunnerBase, runtime_main
paddle.enable_static()
class TestCollectiveIdentity(TestCollectiveRunnerBase):
def __init__(self):
self.global_ring_id = 0
def get_model(self, main_prog, startup_program):
ring_id = 0
nranks = 2
with fluid.program_guard(main_prog, startup_program):
tindata = layers.data(name="tindata",
shape=[10, 1000],
dtype='float32')
toutdata = main_prog.current_block().create_var(
name="outofgather",
dtype='float32',
type=core.VarDesc.VarType.LOD_TENSOR,
persistable=False,
stop_gradient=False)
main_prog.global_block().append_op(type="c_identity",
inputs={'X': tindata},
outputs={'Out': toutdata},
attrs={
'ring_id': ring_id,
'nranks': nranks
})
return toutdata
if __name__ == "__main__":
runtime_main(TestCollectiveIdentity, "identity", 0)
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import unittest
import paddle
from paddle.fluid import core
from test_collective_base_xpu import TestDistBase
import sys
sys.path.append("..")
from xpu.get_test_cover_info import create_test_class, get_xpu_op_support_types, XPUOpTestWrapper
paddle.enable_static()
class XPUTestCAllgatherOP(XPUOpTestWrapper):
def __init__(self):
self.op_name = 'c_allgather'
self.use_dynamic_create_class = False
class TestCAllgatherOp(TestDistBase):
def _setup_config(self):
pass
def test_allgather(self):
self.check_with_place("collective_allgather_op_xpu.py", "allgather",
self.in_type_str)
support_types = get_xpu_op_support_types('c_allgather')
for stype in support_types:
create_test_class(globals(),
XPUTestCAllgatherOP,
stype,
ignore_device_version=[core.XPUVersion.XPU1])
if __name__ == '__main__':
unittest.main()
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import unittest
import paddle
from paddle.fluid import core
from test_collective_base_xpu import TestDistBase
import sys
sys.path.append("..")
from xpu.get_test_cover_info import create_test_class, get_xpu_op_support_types, XPUOpTestWrapper
paddle.enable_static()
class XPUTestCAllreduceOP(XPUOpTestWrapper):
def __init__(self):
self.op_name = 'c_allreduce_sum'
self.use_dynamic_create_class = False
class TestCAllreduceOp(TestDistBase):
def _setup_config(self):
pass
def test_allreduce(self):
self.check_with_place("collective_allreduce_op_xpu.py", "allreduce",
self.in_type_str)
support_types = get_xpu_op_support_types('c_allreduce_sum')
for stype in support_types:
create_test_class(globals(),
XPUTestCAllreduceOP,
stype,
ignore_device_version=[core.XPUVersion.XPU1])
if __name__ == '__main__':
unittest.main()
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import numpy as np
import unittest
import time
import os
import sys
import subprocess
import pickle
import socket
import tempfile
from contextlib import closing
import paddle.fluid as fluid
import paddle.fluid.unique_name as nameGen
from paddle.fluid import core
def DataTypeCast(data_type):
    np_data_type = None
    if data_type == "float16":
        np_data_type = np.float16
    elif data_type == "float32":
        np_data_type = np.float32
    elif data_type == "float64":
        np_data_type = np.float64
    elif data_type == "int8":
        np_data_type = np.int8
    elif data_type == "int16":
        np_data_type = np.int16
    elif data_type == "int32":
        np_data_type = np.int32
    elif data_type == "int64":
        np_data_type = np.int64
    else:
        raise ValueError("This data type is not supported!")
    return np_data_type
class TestCollectiveRunnerBase(object):
def get_model(self, train_prog, startup_prog):
raise NotImplementedError(
"get model should be implemented by child class.")
def wait_server_ready(self, endpoints):
while True:
all_ok = True
not_ready_endpoints = []
for ep in endpoints:
ip_port = ep.split(":")
with closing(socket.socket(socket.AF_INET,
socket.SOCK_STREAM)) as sock:
sock.settimeout(2)
sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
if hasattr(socket, 'SO_REUSEPORT'):
sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT,
1)
result = sock.connect_ex((ip_port[0], int(ip_port[1])))
if result != 0:
all_ok = False
not_ready_endpoints.append(ep)
if not all_ok:
sys.stderr.write("server not ready, wait 3 sec to retry...\n")
sys.stderr.write("not ready endpoints:" +
str(not_ready_endpoints) + "\n")
sys.stderr.flush()
time.sleep(3)
else:
break
#endpoints should be ["ip1:port1","ip2:port2"]
def initCommunicator(self, program, rank, nranks, wait_port,
current_endpoint, endpoints):
other_endpoints = endpoints[:]
other_endpoints.remove(current_endpoint)
if rank == 0 and wait_port:
self.wait_server_ready(other_endpoints)
block = program.global_block()
bkcl_id_var = block.create_var(name=nameGen.generate('bkcl_id'),
persistable=True,
type=core.VarDesc.VarType.RAW)
block.append_op(type='c_gen_bkcl_id',
inputs={},
outputs={'Out': bkcl_id_var},
attrs={
'rank': rank,
'endpoint': current_endpoint,
'other_endpoints': other_endpoints
})
block.append_op(type='c_comm_init',
inputs={'X': bkcl_id_var},
outputs={},
attrs={
'nranks': nranks,
'rank': rank,
'ring_id': self.global_ring_id
})
def run_trainer(self, args):
train_prog = fluid.Program()
startup_prog = fluid.Program()
endpoints = args["endpoints"].split(",")
rank = args["trainerid"]
current_endpoint = args["currentendpoint"]
nranks = 2
self.initCommunicator(startup_prog, rank, nranks, True,
current_endpoint, endpoints)
self.rank = rank
result = self.get_model(train_prog, startup_prog)
device_id = int(os.getenv("FLAGS_selected_xpus", "0"))
place = fluid.XPUPlace(device_id)
exe = fluid.Executor(place)
exe.run(startup_prog)
np.random.seed(os.getpid())
np_data_type = DataTypeCast(args["data_type"])
indata = np.random.uniform(low=-10.0, high=10.0,
size=(10, 1000)).astype(np_data_type)
out = exe.run(train_prog,
feed={'tindata': indata},
fetch_list=[result.name])
sys.stdout.buffer.write(pickle.dumps(out[0]))
def runtime_main(test_class, col_type, sub_type):
args = {}
model = test_class()
args["deviceid"] = os.getenv("FLAGS_selected_xpus")
args["trainerid"] = int(os.getenv("PADDLE_TRAINER_ID"))
args["trainernum"] = int(os.getenv("PADDLE_TRAINERS_NUM"))
args["endpoints"] = os.getenv('PADDLE_TRAINER_ENDPOINTS')
args["currentendpoint"] = os.getenv("PADDLE_CURRENT_ENDPOINT")
args["col_type"] = col_type
args["data_type"] = os.getenv("DATA_TYPE")
model.run_trainer(args)
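runtime_main pulls its configuration from environment variables; a minimal sketch of launching rank 0 of a model file by hand, assuming placeholder ports (the values mirror env0 in TestDistBase._run_cluster below):

    # Hedged sketch: run rank 0 of the allgather model manually; ports are placeholders.
    import os
    import subprocess
    import sys

    env = dict(os.environ,
               FLAGS_selected_xpus="0",
               PADDLE_TRAINER_ID="0",
               PADDLE_TRAINERS_NUM="2",
               PADDLE_TRAINER_ENDPOINTS="127.0.0.1:6170,127.0.0.1:6171",
               PADDLE_CURRENT_ENDPOINT="127.0.0.1:6170",
               DATA_TYPE="float32")
    subprocess.Popen([sys.executable, "collective_allgather_op_xpu.py"], env=env)

A second process with PADDLE_TRAINER_ID=1 and the other endpoint must be started as well, or rank 0 blocks in wait_server_ready waiting for the ring to form.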
class TestDistBase(unittest.TestCase):
def setUp(self):
self._port_set = set()
self._trainers = 2
self._ps_endpoints = "127.0.0.1:%s,127.0.0.1:%s" % (
self._find_free_port(), self._find_free_port())
self._python_interp = sys.executable
self.temp_dir = tempfile.TemporaryDirectory()
def tearDown(self):
self.temp_dir.cleanup()
def _find_free_port(self):
def __free_port():
with closing(socket.socket(socket.AF_INET,
socket.SOCK_STREAM)) as s:
s.bind(('', 0))
return s.getsockname()[1]
while True:
port = __free_port()
if port not in self._port_set:
self._port_set.add(port)
return port
def _run_cluster(self, model_file, envs):
worker_endpoints = self._ps_endpoints.split(",")
w0_ep, w1_ep = worker_endpoints
env0 = {
"FLAGS_selected_xpus": "0",
"PADDLE_TRAINER_ID": "0",
"PADDLE_TRAINERS_NUM": "2",
"PADDLE_TRAINER_ENDPOINTS": self._ps_endpoints,
"PADDLE_CURRENT_ENDPOINT": w0_ep
}
env1 = {
"FLAGS_selected_xpus": "1",
"PADDLE_TRAINER_ID": "1",
"PADDLE_TRAINERS_NUM": "2",
"PADDLE_TRAINER_ENDPOINTS": self._ps_endpoints,
"PADDLE_CURRENT_ENDPOINT": w1_ep
}
#update environment
env0.update(envs)
env1.update(envs)
tr_cmd = "%s %s"
tr0_cmd = tr_cmd % (self._python_interp, model_file)
tr1_cmd = tr_cmd % (self._python_interp, model_file)
path0 = os.path.join(self.temp_dir.name, "tr0_err.log")
path1 = os.path.join(self.temp_dir.name, "tr1_err.log")
tr0_pipe = open(path0, "wb")
tr1_pipe = open(path1, "wb")
tr0_proc = subprocess.Popen(
tr0_cmd.strip().split(),
stdout=subprocess.PIPE,
#stderr=tr0_pipe,
env=env0)
tr1_proc = subprocess.Popen(
    tr1_cmd.strip().split(),
    stdout=subprocess.PIPE,
    #stderr=tr1_pipe,
    env=env1)
tr0_out, tr0_err = tr0_proc.communicate()
tr1_out, tr1_err = tr1_proc.communicate()
sys.stderr.write('trainer 0 stderr: %s\n' % tr0_err)
sys.stderr.write('trainer 1 stderr: %s\n' % tr1_err)
# close trainer file
tr0_pipe.close()
tr1_pipe.close()
return pickle.loads(tr0_out), pickle.loads(
tr1_out), tr0_proc.pid, tr1_proc.pid
def check_with_place(self,
model_file,
col_type,
data_type,
check_error_log=False,
need_envs={}):
required_envs = {
"FLAGS_eager_delete_tensor_gb": "0.0",
"PATH": os.getenv("PATH"),
"PYTHONPATH": os.getenv("PYTHONPATH", ""),
"LD_LIBRARY_PATH": os.getenv("LD_LIBRARY_PATH", ""),
"LD_PRELOAD": os.getenv("LD_PRELOAD", ""),
"GLOG_v": "3",
"DATA_TYPE": data_type,
}
required_envs.update(need_envs)
if check_error_log:
required_envs["GLOG_v"] = "3"
required_envs["GLOG_logtostderr"] = "1"
tr0_out, tr1_out, pid0, pid1 = self._run_cluster(
model_file, required_envs)
np_data_type = DataTypeCast(data_type)
np.random.seed(pid0)
input1 = np.random.uniform(low=-10.0, high=10.0,
size=(10, 1000)).astype(np_data_type)
np.random.seed(pid1)
input2 = np.random.uniform(low=-10.0, high=10.0,
size=(10, 1000)).astype(np_data_type)
if col_type == "allgather":
need_result = np.vstack((input1, input2))
np.testing.assert_allclose(tr0_out, need_result)
np.testing.assert_allclose(tr1_out, need_result)
elif col_type == "broadcast":
need_result = input2
np.testing.assert_allclose(tr0_out, need_result)
np.testing.assert_allclose(tr1_out, need_result)
elif col_type == "reduce":
need_result = input1 + input2
np.testing.assert_allclose(tr1_out, need_result)
elif col_type == "scatter":
need_result = input2
need_result1 = need_result[0:need_result.shape[0] // 2]
need_result2 = need_result[need_result.shape[0] // 2:]
np.testing.assert_allclose(tr0_out, need_result1)
np.testing.assert_allclose(tr1_out, need_result2)
elif col_type == "allreduce":
need_result = input1 + input2
np.testing.assert_allclose(tr0_out,
need_result,
rtol=1e-05,
atol=1e-05)
np.testing.assert_allclose(tr1_out,
need_result,
rtol=1e-05,
atol=1e-05)
elif col_type == "reduce_scatter":
tmp = input1 + input2
need_result1 = tmp[0:tmp.shape[0] // 2]
need_result2 = tmp[tmp.shape[0] // 2:]
np.testing.assert_allclose(tr0_out,
need_result1,
rtol=1e-05,
atol=1e-05)
np.testing.assert_allclose(tr1_out,
need_result2,
rtol=1e-05,
atol=1e-05)
elif col_type == "sendrecv":
need_result = input1
np.testing.assert_allclose(tr1_out,
need_result,
rtol=1e-05,
atol=1e-05)
elif col_type == "identity":
need_result1 = input1
need_result2 = input2
np.testing.assert_allclose(tr0_out, need_result1, rtol=0, atol=0)
np.testing.assert_allclose(tr1_out, need_result2, rtol=0, atol=0)
elif col_type == "reduce_slicegather":
slicesize = input1.shape[0] // 2
tmp10 = input1[0:slicesize]
tmp11 = input2[0:slicesize]
need_result1 = np.concatenate((tmp10, tmp11), axis=1)
tmp20 = input1[slicesize:]
tmp21 = input2[slicesize:]
need_result2 = np.concatenate((tmp20, tmp21), axis=1)
np.testing.assert_allclose(tr0_out, need_result1)
np.testing.assert_allclose(tr1_out, need_result2)
elif col_type == "concat":
need_result = np.concatenate((input1, input2), axis=1)
np.testing.assert_allclose(tr0_out,
need_result,
rtol=1e-05,
atol=1e-05)
np.testing.assert_allclose(tr1_out,
need_result,
rtol=1e-05,
atol=1e-05)
elif col_type == "split":
need_result1 = np.split(input1, 2, axis=1)[0]
need_result2 = np.split(input2, 2, axis=1)[1]
np.testing.assert_allclose(tr0_out,
need_result1,
rtol=1e-05,
atol=1e-05)
np.testing.assert_allclose(tr1_out,
need_result2,
rtol=1e-05,
atol=1e-05)
elif col_type == "sendrecv_array":
need_result1 = np.array([[0, 1, 2]])
need_result2 = np.array([[3, 4, 5]])
np.testing.assert_allclose(tr1_out[0][0],
need_result1,
rtol=1e-05,
atol=1e-05)
np.testing.assert_allclose(tr1_out[0][1],
need_result2,
rtol=1e-05,
atol=1e-05)
else:
pass
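For the three collectives this commit adds, the reference results checked above reduce to simple numpy expressions; a compact sketch, assuming two ranks and the (10, 1000) float32 inputs used throughout:

    # Hedged illustration of the reference results computed in check_with_place.
    import numpy as np

    input1 = np.random.uniform(-10.0, 10.0, size=(10, 1000)).astype(np.float32)
    input2 = np.random.uniform(-10.0, 10.0, size=(10, 1000)).astype(np.float32)
    allgather_ref = np.vstack((input1, input2))    # (20, 1000), identical on both ranks
    allreduce_ref = input1 + input2                # elementwise sum, identical on both ranks
    identity_ref0, identity_ref1 = input1, input2  # c_identity passes each rank's data through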
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import unittest
import paddle
from paddle.fluid import core
from test_collective_base_xpu import TestDistBase
import sys
sys.path.append("..")
from xpu.get_test_cover_info import create_test_class, get_xpu_op_support_types, XPUOpTestWrapper
paddle.enable_static()
class XPUTestCIdentityOP(XPUOpTestWrapper):
def __init__(self):
self.op_name = 'c_identity'
self.use_dynamic_create_class = False
class TestCIdentityOp(TestDistBase):
def _setup_config(self):
pass
def test_identity(self):
self.check_with_place("collective_identity_op_xpu.py", "identity",
self.in_type_str)
support_types = get_xpu_op_support_types('c_identity')
for stype in support_types:
create_test_class(globals(),
XPUTestCIdentityOP,
stype,
ignore_device_version=[core.XPUVersion.XPU1])
if __name__ == '__main__':
unittest.main()