From d2bd05b9372026cd6e76a16dd48aedfffee60cad Mon Sep 17 00:00:00 2001 From: zn <96479180+kangna-qi@users.noreply.github.com> Date: Fri, 18 Feb 2022 09:56:06 +0800 Subject: [PATCH] [MLU]add sync stream ops and broadcast pytest (#39518) * [MLU]add sync stream ops and broadcast pytest * [MLU]fix broadcast pytest to add data type --- .../collective/c_sync_calc_stream_op.cc | 12 + .../collective/c_sync_comm_stream_op.cc | 16 ++ .../fluid/tests/unittests/mlu/CMakeLists.txt | 2 +- .../unittests/mlu/collective_broadcast_op.py | 72 +++++ .../unittests/mlu/test_collective_base_mlu.py | 266 ++++++++++++++++++ .../mlu/test_collective_broadcast.py | 55 ++++ 6 files changed, 422 insertions(+), 1 deletion(-) create mode 100644 python/paddle/fluid/tests/unittests/mlu/collective_broadcast_op.py create mode 100644 python/paddle/fluid/tests/unittests/mlu/test_collective_base_mlu.py create mode 100644 python/paddle/fluid/tests/unittests/mlu/test_collective_broadcast.py diff --git a/paddle/fluid/operators/collective/c_sync_calc_stream_op.cc b/paddle/fluid/operators/collective/c_sync_calc_stream_op.cc index 8a4c1979adb..42584948e06 100644 --- a/paddle/fluid/operators/collective/c_sync_calc_stream_op.cc +++ b/paddle/fluid/operators/collective/c_sync_calc_stream_op.cc @@ -67,6 +67,16 @@ class CSyncCalcStreamKernel : public framework::OpKernel { platform::DeviceContextPool::Instance().Get(place)); platform::NPUStreamSync(dev_ctx->stream()); +#elif defined(PADDLE_WITH_CNCL) + auto place = ctx.GetPlace(); + PADDLE_ENFORCE_EQ(platform::is_mlu_place(place), true, + platform::errors::PreconditionNotMet( + "Sync stream op can run on mlu place only for now.")); + + auto dev_ctx = static_cast( + platform::DeviceContextPool::Instance().Get(place)); + platform::MLUStreamSync(dev_ctx->stream()); + #else PADDLE_THROW(platform::errors::PreconditionNotMet( "PaddlePaddle should compile with GPU.")); @@ -85,3 +95,5 @@ REGISTER_OP_WITHOUT_GRADIENT(c_sync_calc_stream, ops::CSyncCalcStreamOp, REGISTER_OP_CUDA_KERNEL(c_sync_calc_stream, ops::CSyncCalcStreamKernel); REGISTER_OP_NPU_KERNEL(c_sync_calc_stream, ops::CSyncCalcStreamKernel); + +REGISTER_OP_MLU_KERNEL(c_sync_calc_stream, ops::CSyncCalcStreamKernel); diff --git a/paddle/fluid/operators/collective/c_sync_comm_stream_op.cc b/paddle/fluid/operators/collective/c_sync_comm_stream_op.cc index 893cc90762f..37ce4ef7ee2 100644 --- a/paddle/fluid/operators/collective/c_sync_comm_stream_op.cc +++ b/paddle/fluid/operators/collective/c_sync_comm_stream_op.cc @@ -24,6 +24,10 @@ limitations under the License. */ #include "paddle/fluid/platform/device/npu/hccl_helper.h" #endif +#if defined(PADDLE_WITH_CNCL) +#include "paddle/fluid/platform/device/mlu/cncl_helper.h" +#endif + namespace paddle { namespace operators { @@ -81,6 +85,16 @@ class CSyncCommStreamKernel : public framework::OpKernel { platform::HCCLCommContext::Instance().Get(ring_id, place)->stream(); platform::NPUStreamSync(stream); +#elif defined(PADDLE_WITH_CNCL) + auto place = ctx.GetPlace(); + PADDLE_ENFORCE_EQ(platform::is_mlu_place(place), true, + platform::errors::PreconditionNotMet( + "Sync stream op can run on mlu place only for now.")); + int ring_id = ctx.Attr("ring_id"); + auto stream = + platform::CNCLCommContext::Instance().Get(ring_id, place)->stream(); + platform::MLUStreamSync(stream); + #else PADDLE_THROW(platform::errors::PreconditionNotMet( "PaddlePaddle should compile with GPU.")); @@ -99,3 +113,5 @@ REGISTER_OP_WITHOUT_GRADIENT(c_sync_comm_stream, ops::CSyncCommStreamOp, REGISTER_OP_CUDA_KERNEL(c_sync_comm_stream, ops::CSyncCommStreamKernel); REGISTER_OP_NPU_KERNEL(c_sync_comm_stream, ops::CSyncCommStreamKernel); + +REGISTER_OP_MLU_KERNEL(c_sync_comm_stream, ops::CSyncCommStreamKernel); diff --git a/python/paddle/fluid/tests/unittests/mlu/CMakeLists.txt b/python/paddle/fluid/tests/unittests/mlu/CMakeLists.txt index 8fcd3f196dc..2e588355ce7 100644 --- a/python/paddle/fluid/tests/unittests/mlu/CMakeLists.txt +++ b/python/paddle/fluid/tests/unittests/mlu/CMakeLists.txt @@ -5,5 +5,5 @@ if (WITH_MLU) foreach(TEST_OP ${TEST_OPS}) py_test_modules(${TEST_OP} MODULES ${TEST_OP}) endforeach(TEST_OP) - + set_tests_properties(test_collective_broadcast PROPERTIES TIMEOUT 120) endif() diff --git a/python/paddle/fluid/tests/unittests/mlu/collective_broadcast_op.py b/python/paddle/fluid/tests/unittests/mlu/collective_broadcast_op.py new file mode 100644 index 00000000000..d4f32b5f524 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/mlu/collective_broadcast_op.py @@ -0,0 +1,72 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + +import numpy as np +import argparse +import os +import sys +import signal +import time +import socket +from contextlib import closing +from six import string_types +import math +import paddle +import paddle.fluid as fluid +import paddle.fluid.profiler as profiler +import paddle.fluid.unique_name as nameGen +from paddle.fluid import core +import unittest +from multiprocessing import Process +import paddle.fluid.layers as layers +from functools import reduce +from test_collective_base_mlu import TestCollectiveRunnerBase, runtime_main + +paddle.enable_static() + + +class TestCollectiveBroadcast(TestCollectiveRunnerBase): + def __init__(self): + self.global_ring_id = 0 + + def get_model(self, main_prog, startup_program): + ring_id = 0 + rootid = 1 + with fluid.program_guard(main_prog, startup_program): + tindata = layers.data( + name="tindata", shape=[10, 1000], dtype='float32') + toutdata = main_prog.current_block().create_var( + name="outofbroadcast", + dtype='float32', + type=core.VarDesc.VarType.LOD_TENSOR, + persistable=False, + stop_gradient=False) + main_prog.global_block().append_op( + type="c_broadcast", + inputs={'X': tindata}, + attrs={'ring_id': ring_id, + 'root': rootid}, + outputs={'Out': toutdata}) + main_prog.global_block().append_op( + type="c_sync_comm_stream", + inputs={'X': toutdata}, + outputs={'Out': toutdata}, + attrs={'ring_id': ring_id}) + return toutdata + + +if __name__ == "__main__": + runtime_main(TestCollectiveBroadcast, "broadcast", 0) diff --git a/python/paddle/fluid/tests/unittests/mlu/test_collective_base_mlu.py b/python/paddle/fluid/tests/unittests/mlu/test_collective_base_mlu.py new file mode 100644 index 00000000000..2a7c64fe489 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/mlu/test_collective_base_mlu.py @@ -0,0 +1,266 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function +import numpy as np +import unittest +import time +import argparse +import os +import sys +import subprocess +import traceback +import functools +import pickle +from contextlib import closing +import paddle.fluid as fluid +import paddle.fluid.unique_name as nameGen +from paddle.fluid import core + + +def DataTypeCast(date_type): + np_data_type = None + + if date_type == "float16": + np_data_type = np.float16 + elif date_type == "float32": + np_data_type = np.float32 + elif date_type == "float64": + np_data_type = np.float64 + elif date_type == "int8": + np_data_type = np.int8 + elif date_type == "int16": + np_data_type = np.int16 + elif date_type == "int32": + np_data_type = np.int32 + elif date_type == "uint8": + np_data_type = np.uint8 + else: + raise ValueError("This data type is not support!") + + return np_data_type + + +class TestCollectiveRunnerBase(object): + def get_model(self, train_prog, startup_prog): + raise NotImplementedError( + "get model should be implemented by child class.") + + def wait_server_ready(self, endpoints): + while True: + all_ok = True + not_ready_endpoints = [] + for ep in endpoints: + ip_port = ep.split(":") + with closing( + socket.socket(socket.AF_INET, + socket.SOCK_STREAM)) as sock: + sock.settimeout(2) + sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + if hasattr(socket, 'SO_REUSEPORT'): + sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT, + 1) + + result = sock.connect_ex((ip_port[0], int(ip_port[1]))) + if result != 0: + all_ok = False + not_ready_endpoints.append(ep) + if not all_ok: + sys.stderr.write("server not ready, wait 3 sec to retry...\n") + sys.stderr.write("not ready endpoints:" + str( + not_ready_endpoints) + "\n") + sys.stderr.flush() + time.sleep(3) + else: + break + +#endpoints should be ["ip1:port1","ip2:port2"] + + def initCommunicator(self, program, rank, nranks, wait_port, + current_endpoint, endpoints): + other_endpoints = endpoints[:] + other_endpoints.remove(current_endpoint) + if rank == 0 and wait_port: + self.wait_server_ready(other_endpoints) + block = program.global_block() + cncl_id_var = block.create_var( + name=nameGen.generate('cncl_id'), + persistable=True, + type=core.VarDesc.VarType.RAW) + + block.append_op( + type='c_gen_cncl_id', + inputs={}, + outputs={'Out': cncl_id_var}, + attrs={ + 'rank': rank, + 'endpoint': current_endpoint, + 'other_endpoints': other_endpoints + }) + + block.append_op( + type='c_comm_init', + inputs={'X': cncl_id_var}, + outputs={}, + attrs={ + 'nranks': nranks, + 'rank': rank, + 'ring_id': self.global_ring_id + }) + + def run_trainer(self, args): + train_prog = fluid.Program() + startup_prog = fluid.Program() + endpoints = args["endpoints"].split(",") + rank = args["trainerid"] + current_endpoint = args["currentendpoint"] + nranks = 2 + self.initCommunicator(startup_prog, rank, nranks, True, + current_endpoint, endpoints) + self.rank = rank + result = self.get_model(train_prog, startup_prog) + device_id = int(os.getenv("FLAGS_selected_mlus", "0")) + place = fluid.MLUPlace(device_id) + exe = fluid.Executor(place) + exe.run(startup_prog) + np.random.seed(os.getpid()) + np_data_type = DataTypeCast(args["data_type"]) + indata = np.random.random((10, 1000)).astype(np_data_type) + out = exe.run(train_prog, + feed={'tindata': indata}, + fetch_list=[result.name]) + sys.stdout.buffer.write(pickle.dumps(out)) + + +def runtime_main(test_class, col_type, sub_type): + args = {} + model = test_class() + args["deviceid"] = os.getenv("FLAGS_selected_mlus") + args["trainerid"] = int(os.getenv("PADDLE_TRAINER_ID")) + args["trainernum"] = int(os.getenv("PADDLE_TRAINERS_NUM")) + args["endpoints"] = os.getenv('PADDLE_TRAINER_ENDPOINTS') + args["currentendpoint"] = os.getenv("PADDLE_CURRENT_ENDPOINT") + args["col_type"] = col_type + args["data_type"] = os.getenv("DATA_TYPE") + model.run_trainer(args) + + +import paddle.compat as cpt +import socket +from contextlib import closing + + +class TestDistBase(unittest.TestCase): + def setUp(self): + self._port_set = set() + self._trainers = 2 + self._ps_endpoints = "127.0.0.1:%s,127.0.0.1:%s" % ( + self._find_free_port(), self._find_free_port()) + self._python_interp = sys.executable + + def _find_free_port(self): + def __free_port(): + with closing(socket.socket(socket.AF_INET, + socket.SOCK_STREAM)) as s: + s.bind(('', 0)) + return s.getsockname()[1] + + while True: + port = __free_port() + if port not in self._port_set: + self._port_set.add(port) + return port + + def _run_cluster(self, model_file, envs): + worker_endpoints = self._ps_endpoints.split(",") + w0_ep, w1_ep = worker_endpoints + #print("w0_ep:",w0_ep," w1_ep:",w1_ep) + env0 = { + "FLAGS_selected_mlus": "0", + "PADDLE_TRAINER_ID": "0", + "PADDLE_TRAINERS_NUM": "2", + "PADDLE_TRAINER_ENDPOINTS": self._ps_endpoints, + "PADDLE_CURRENT_ENDPOINT": w0_ep + } + + env1 = { + "FLAGS_selected_mlus": "1", + "PADDLE_TRAINER_ID": "1", + "PADDLE_TRAINERS_NUM": "2", + "PADDLE_TRAINER_ENDPOINTS": self._ps_endpoints, + "PADDLE_CURRENT_ENDPOINT": w1_ep + } + #update environment + env0.update(envs) + env1.update(envs) + tr_cmd = "%s %s" + tr0_cmd = tr_cmd % (self._python_interp, model_file) + tr1_cmd = tr_cmd % (self._python_interp, model_file) + tr0_pipe = open("/tmp/tr0_err.log", "wb") + tr1_pipe = open("/tmp/tr1_err.log", "wb") + #print(tr0_cmd) + tr0_proc = subprocess.Popen( + tr0_cmd.strip().split(), + stdout=subprocess.PIPE, + stderr=tr0_pipe, + env=env0) + + tr1_proc = subprocess.Popen( + tr0_cmd.strip().split(), + stdout=subprocess.PIPE, + stderr=tr1_pipe, + env=env1) + + tr0_out, tr0_err = tr0_proc.communicate() + tr1_out, tr1_err = tr1_proc.communicate() + sys.stderr.write('trainer 0 stderr: %s\n' % tr0_err) + sys.stderr.write('trainer 1 stderr: %s\n' % tr1_err) + # close trainer file + tr0_pipe.close() + tr1_pipe.close() + return pickle.loads(tr0_out), pickle.loads( + tr1_out), tr0_proc.pid, tr1_proc.pid + + def check_with_place(self, + model_file, + col_type, + data_type, + check_error_log=False, + need_envs={}): + required_envs = { + "FLAGS_eager_delete_tensor_gb": "0.0", + "PATH": os.getenv("PATH"), + "PYTHONPATH": os.getenv("PYTHONPATH", ""), + "LD_LIBRARY_PATH": os.getenv("LD_LIBRARY_PATH", ""), + "LD_PRELOAD": os.getenv("LD_PRELOAD", ""), + "GLOG_v": "3", + "DATA_TYPE": data_type, + } + required_envs.update(need_envs) + if check_error_log: + required_envs["GLOG_v"] = "3" + required_envs["GLOG_logtostderr"] = "1" + tr0_out, tr1_out, pid0, pid1 = self._run_cluster(model_file, + required_envs) + np_data_type = DataTypeCast(data_type) + np.random.seed(pid0) + input1 = np.random.random((10, 1000)).astype(np_data_type) + np.random.seed(pid1) + input2 = np.random.random((10, 1000)).astype(np_data_type) + if col_type == "broadcast": + need_result = input2 + self.assertTrue(np.allclose(tr0_out, need_result)) + self.assertTrue(np.allclose(tr1_out, need_result)) + else: + pass diff --git a/python/paddle/fluid/tests/unittests/mlu/test_collective_broadcast.py b/python/paddle/fluid/tests/unittests/mlu/test_collective_broadcast.py new file mode 100644 index 00000000000..d9f3aca0314 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/mlu/test_collective_broadcast.py @@ -0,0 +1,55 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function +import sys +import unittest +import numpy as np +import paddle + +from test_collective_base_mlu import TestDistBase + +paddle.enable_static() + + +class TestCBroadcastOp(TestDistBase): + def _setup_config(self): + pass + + def test_broadcast_fp32(self): + self.check_with_place("collective_broadcast_op.py", "broadcast", + "float32") + + def test_broadcast_fp16(self): + self.check_with_place("collective_broadcast_op.py", "broadcast", + "float16") + + def test_broadcast_int32(self): + self.check_with_place("collective_broadcast_op.py", "broadcast", + "int32") + + def test_broadcast_int16(self): + self.check_with_place("collective_broadcast_op.py", "broadcast", + "int16") + + def test_broadcast_int8(self): + self.check_with_place("collective_broadcast_op.py", "broadcast", "int8") + + def test_broadcast_uint8(self): + self.check_with_place("collective_broadcast_op.py", "broadcast", + "uint8") + + +if __name__ == '__main__': + unittest.main() -- GitLab