From e680d581c4ff906e84ae273d2c2b3dbee96ee9db Mon Sep 17 00:00:00 2001 From: yaozhixin Date: Tue, 31 May 2022 16:25:40 +0800 Subject: [PATCH] [IPU] support paddle.distributed.launch with IPUs (#43087) * [IPU] support paddle.distributed.launch with IPUs * add device_num to env_args_mapping --- .../distributed/launch/context/args_envs.py | 7 + .../distributed/launch/context/device.py | 12 +- .../launch/controllers/collective.py | 6 +- python/paddle/distributed/launch/main.py | 4 +- .../distributed/launch/plugins/__init__.py | 18 +- .../distributed/launch/utils/ipu_launch.py | 167 ++++++++++++++++ .../unittests/ipu/distributed/run_dist_ipu.sh | 80 ++++++++ .../test_dist_data_parallel_ipu.py | 184 ++++++++++++++++++ .../distributed/test_dist_pod128_sample.py | 111 +++++++++++ .../ipu/distributed/test_dist_sample.py | 177 +++++++++++++++++ 10 files changed, 761 insertions(+), 5 deletions(-) create mode 100644 python/paddle/distributed/launch/utils/ipu_launch.py create mode 100644 python/paddle/fluid/tests/unittests/ipu/distributed/run_dist_ipu.sh create mode 100644 python/paddle/fluid/tests/unittests/ipu/distributed/test_dist_data_parallel_ipu.py create mode 100644 python/paddle/fluid/tests/unittests/ipu/distributed/test_dist_pod128_sample.py create mode 100644 python/paddle/fluid/tests/unittests/ipu/distributed/test_dist_sample.py diff --git a/python/paddle/distributed/launch/context/args_envs.py b/python/paddle/distributed/launch/context/args_envs.py index ea8bf3d597..b70dd7d3f7 100644 --- a/python/paddle/distributed/launch/context/args_envs.py +++ b/python/paddle/distributed/launch/context/args_envs.py @@ -35,6 +35,7 @@ env_args_mapping = { 'PADDLE_TRAINERS_ENDPOINTS': 'trainers', 'PADDLE_GLOO_PORT': 'gloo_port', 'PADDLE_WITH_GLOO': 'with_gloo', + 'PADDLE_DEVICE_NUM': 'device_num' } @@ -100,6 +101,12 @@ def parse_args(): default=None, help="accelerate devices. 
as --gpus,npus,xps") + base_group.add_argument( + "--device_num", + type=int, + default=None, + help="the number of accelerate devices.") + base_group.add_argument("--host", type=str, default=None, help="host ip") base_group.add_argument( diff --git a/python/paddle/distributed/launch/context/device.py b/python/paddle/distributed/launch/context/device.py index 30b8cc1538..61ffe8e809 100644 --- a/python/paddle/distributed/launch/context/device.py +++ b/python/paddle/distributed/launch/context/device.py @@ -21,6 +21,7 @@ class DeviceType: XPU = 'xpu' NPU = 'npu' MLU = 'mlu' + IPU = 'ipu' class Device(object): @@ -68,12 +69,18 @@ class Device(object): return 'FLAGS_selected_xpus' if self._dtype == DeviceType.MLU: return 'FLAGS_selected_mlus' + if self._dtype == DeviceType.IPU: + return 'FLAGS_selected_ipus' return 'FLAGS_selected_devices' - def get_selected_devices(self, devices=''): + def get_selected_devices(self, devices='', device_num=None): ''' return the device label/id relative to the visible devices ''' + if self._dtype == DeviceType.IPU: + if not device_num: + raise RuntimeError("The \'device_num\' is required by IPUs.") + return [str(device_num)] if not devices: return [str(x) for x in range(0, len(self._labels))] else: @@ -129,6 +136,9 @@ class Device(object): dev._dtype = DeviceType.MLU num = fluid.core.get_mlu_device_count() visible_devices = os.getenv("MLU_VISIBLE_DEVICES") + elif fluid.core.is_compiled_with_ipu(): + dev._dtype = DeviceType.IPU + num = fluid.core.get_ipu_device_count() if num == 0: dev._dtype = DeviceType.CPU diff --git a/python/paddle/distributed/launch/controllers/collective.py b/python/paddle/distributed/launch/controllers/collective.py index 5225fd6e81..166eb3a4f9 100644 --- a/python/paddle/distributed/launch/controllers/collective.py +++ b/python/paddle/distributed/launch/controllers/collective.py @@ -79,7 +79,8 @@ class CollectiveController(Controller): self.pod.reset() selected_dev_key = 
self.ctx.node.device.get_selected_device_key() selected_dev_list = self.ctx.node.device.get_selected_devices( - self.ctx.args.devices) + self.ctx.args.devices, self.ctx.args.device_num) + for i in range(self.pod.replicas): e = { "PADDLE_MASTER": collective_master, @@ -95,7 +96,8 @@ class CollectiveController(Controller): "PADDLE_TRAINERS_NUM": "{}".format(global_size), "PADDLE_RANK_IN_NODE": str(i), } - if self.pod.replicas == 1: + + if self.pod.replicas == 1 or self.ctx.node.device.dtype == "ipu": e.update({selected_dev_key: ",".join(selected_dev_list)}) else: e.update({selected_dev_key: selected_dev_list[i]}) diff --git a/python/paddle/distributed/launch/main.py b/python/paddle/distributed/launch/main.py index b2c87e737c..92585c9e76 100644 --- a/python/paddle/distributed/launch/main.py +++ b/python/paddle/distributed/launch/main.py @@ -52,7 +52,9 @@ def launch(): - ``--job_id``: The job unique id, it affects the log files' name. e.g., ``--job_id=job1``. Default ``--job_id=default``. - - ``--devices``: The selected accelerate devices on nodes, can be gpu/xpu/npu/mlu etc.. e.g., ``--devices=0,1,2,3`` will launch four training processes each bound to one device. + - ``--devices``: The selected accelerate devices on nodes, can be gpu/xpu/npu/mlu/ipu etc.. e.g., ``--devices=0,1,2,3`` will launch four training processes each bound to one device. + + - ``--device_num``: The number of selected accelerate devices on nodes, can be gpu/xpu/npu/mlu/ipu etc.. e.g., ``--device_num=4`` will require four devices per node. - ``training_script``: The full path to the single GPU training program/script to be launched in parallel, followed by all the arguments for the training script. 
e.g., ``training.py``

diff --git a/python/paddle/distributed/launch/plugins/__init__.py b/python/paddle/distributed/launch/plugins/__init__.py
index 13c09b4c27..faa8f28237 100644
--- a/python/paddle/distributed/launch/plugins/__init__.py
+++ b/python/paddle/distributed/launch/plugins/__init__.py
@@ -25,6 +25,21 @@ def log(ctx):
     ctx.logger.info("--------------------------------------------------")


+def rewrite_ipu_script(ctx):
+    """Replace the pseudo script ``ipu`` by the path of ``utils/ipu_launch.py``."""
+    import os
+    import paddle.fluid as fluid
+    if not fluid.core.is_compiled_with_ipu():
+        return
+    if ctx.args.training_script != "ipu":
+        raise RuntimeError(
+            "Only support to run the script \'ipu\' for IPU distributed computing."
+        )
+    launch_root = os.path.dirname(os.path.dirname(__file__))
+    ctx.args.training_script = os.path.abspath(
+        os.path.join(launch_root, "utils/ipu_launch.py"))
+
+
 def process_args(ctx):
     # reset device by args
     #argdev = ctx.args.gpus or ctx.args.xpus or ctx.args.npus
@@ -60,4 +75,6 @@ def rewrite_host_ip(ctx):
         ctx.node.ip = ctx.args.host


-enabled_plugins = [collective_compatible, rewrite_host_ip, process_args]
+enabled_plugins = [
+    collective_compatible, rewrite_host_ip, process_args, rewrite_ipu_script
+]
diff --git a/python/paddle/distributed/launch/utils/ipu_launch.py b/python/paddle/distributed/launch/utils/ipu_launch.py
new file mode 100644
index 0000000000..595243cdf9
--- /dev/null
+++ b/python/paddle/distributed/launch/utils/ipu_launch.py
@@ -0,0 +1,167 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import paddle.fluid as fluid
+
+import subprocess
+import argparse
+import os
+import logging
+import shlex
+import sys
+
+
+class IPULaunch(object):
+    """Build and run the ``poprun`` command line of an IPU distributed job.
+
+    The total number of IPUs is read from the ``FLAGS_selected_ipus``
+    environment variable; all other settings are constructor arguments.
+    """
+
+    def __init__(self, hosts, ipus_per_replica, nproc_per_host, ipu_partition,
+                 vipu_server, training_script, training_script_args):
+        """Store the launch settings.
+
+        Raises:
+            RuntimeError: if Paddle is not compiled with IPU support or
+                ``FLAGS_selected_ipus`` is not set.
+        """
+        if not fluid.core.is_compiled_with_ipu():
+            raise RuntimeError(
+                "Can not call ipu_launch.py in non IPU compiled environment, please re-compile with WITH_IPU=ON."
+            )
+        self._hosts = hosts
+        self._ipus_per_replica = ipus_per_replica
+        self._nproc_per_host = nproc_per_host
+        self._ipu_partition = ipu_partition
+        self._vipu_server = vipu_server
+        self._training_script = training_script
+        self._training_script_args = training_script_args
+
+        # int(None) would only raise an opaque TypeError, so check first.
+        selected_ipus = os.getenv("FLAGS_selected_ipus")
+        if not selected_ipus:
+            raise RuntimeError(
+                "FLAGS_selected_ipus is not set, please run ipu_launch.py "
+                "through paddle.distributed.launch with --device_num.")
+        self._num_ipus = int(selected_ipus)
+        self.logger = self.get_logger()
+
+    @classmethod
+    def parse_ipu_args(cls):
+        """Parse ``sys.argv`` and return the matching launcher instance."""
+        parser = argparse.ArgumentParser()
+        # launch() cannot work without the next three options, so let
+        # argparse report a missing one instead of crashing later.
+        parser.add_argument(
+            "--hosts",
+            type=str,
+            required=True,
+            help="The hosts for IPU PopRun distributed computing.")
+        parser.add_argument(
+            "--ipus_per_replica",
+            type=int,
+            required=True,
+            help="The number of IPUs per replica.")
+        parser.add_argument(
+            "--nproc_per_host",
+            type=int,
+            required=True,
+            help="The number of processes per host.")
+        parser.add_argument(
+            "--ipu_partition", type=str, help="The partition name of IPU.")
+        parser.add_argument(
+            "--vipu_server",
+            type=str,
+            help="The vipu server host to enable vipu.")
+        parser.add_argument(
+            "training_script",
+            type=str,
+            help="The full path to the single IPU replica training program/script to be launched in parallel."
+        )
+        parser.add_argument('training_script_args', nargs=argparse.REMAINDER)
+        args = parser.parse_args()
+
+        return cls(
+            hosts=args.hosts,
+            ipus_per_replica=args.ipus_per_replica,
+            nproc_per_host=args.nproc_per_host,
+            ipu_partition=args.ipu_partition,
+            vipu_server=args.vipu_server,
+            training_script=args.training_script,
+            training_script_args=args.training_script_args, )
+
+    def get_logger(self, level=logging.INFO):
+        """Return the ``LAUNCH`` logger, configured with one stream handler."""
+        logger = logging.getLogger("LAUNCH")
+        logger.setLevel(level)
+        # getLogger() returns a process-wide object: attach the handler only
+        # once, otherwise every message is printed once per call.
+        if not logger.handlers:
+            formatter = logging.Formatter(
+                fmt='%(name)s %(levelname)s %(asctime)s %(message)s')
+            ch = logging.StreamHandler()
+            ch.setFormatter(formatter)
+            logger.addHandler(ch)
+        return logger
+
+    def launch(self):
+        """Assemble the ``poprun`` command, print it and execute it.
+
+        Returns:
+            int: the exit code of the ``poprun`` process.
+
+        Raises:
+            ValueError: if the IPUs do not divide evenly into replicas or
+                the replicas do not divide evenly into processes.
+        """
+        # The number of replicas for data parallel
+        # (explicit raise: an ``assert`` is stripped under ``python -O``)
+        if self._num_ipus % self._ipus_per_replica != 0:
+            raise ValueError(
+                "The number of IPUs:{} mod the number of IPUs per replica:{} must == 0".
+                format(self._num_ipus, self._ipus_per_replica))
+        num_replicas = self._num_ipus // self._ipus_per_replica
+        self.logger.info("The number of total replicas is {}.".format(
+            num_replicas))
+
+        # hosts and endpoints
+        hosts = self._hosts.replace(' ', '').split(',')
+        endpoints = [x + ":8090" for x in hosts]
+
+        # The number of processes
+        num_procs = len(hosts) * self._nproc_per_host
+        self.logger.info("The number of total processes is {}.".format(
+            num_procs))
+        if num_replicas % num_procs != 0:
+            raise ValueError(
+                "The number of replicas:{} mod the number of processes:{} must == 0".
+                format(num_replicas, num_procs))
+
+        # args for poprun; every list item is exactly one argv entry
+        poprun_command = [
+            'poprun',
+            '--num-instances={}'.format(num_procs),
+            '--num-replicas={}'.format(num_replicas),
+            '--ipus-per-replica={}'.format(self._ipus_per_replica),
+            '--host={}'.format(','.join(hosts)),
+            '--vipu-partition={}'.format(self._ipu_partition),
+            '--vipu-server-host={}'.format(self._vipu_server),
+            '--update-partition=no',
+            '--vipu-server-timeout=120',
+            '--print-topology=yes',
+            '--numa-aware=yes',
+        ]
+
+        # global envs
+        global_envs = []
+        log_level = os.getenv('POPART_LOG_LEVEL', None)
+        if log_level:
+            global_envs.append('-x POPART_LOG_LEVEL={}'.format(log_level))
+        global_envs.append('-x PADDLE_TRAINERS_NUM={}'.format(num_procs))
+        global_envs.append('-x PADDLE_TRAINER_ENDPOINTS={}'.format(','.join(
+            endpoints)))
+        poprun_command.append('--mpi-local-args={}'.format(' '.join(
+            global_envs)))
+
+        # local envs
+        for idx in range(num_procs):
+            cur_endpoint = endpoints[idx // self._nproc_per_host]
+            rank_in_node = idx % self._nproc_per_host
+            poprun_command.append(
+                '--instance-mpi-local-args={}:-x PADDLE_TRAINER_ID={} -x PADDLE_CURRENT_ENDPOINT={} -x PADDLE_RANK_IN_NODE={}'.
+                format(idx, idx, cur_endpoint, rank_in_node))
+
+        # executor, script and script args
+        poprun_command.append(sys.executable)
+        poprun_command.append(self._training_script)
+        poprun_command.extend(self._training_script_args)
+
+        # for debug: shell-quoted so that the output can be pasted into a shell
+        print("----------- PopRun Command -----------")
+        print(" \\\n".join(shlex.quote(arg) for arg in poprun_command))
+        print("---------------------------------------")
+
+        # Launch without a shell, so script arguments containing spaces or
+        # shell metacharacters reach poprun unchanged.
+        return subprocess.run(poprun_command).returncode
+
+
+if __name__ == '__main__':
+    ipu_launch = IPULaunch.parse_ipu_args()
+    # Forward the poprun exit code so that the caller can detect failures.
+    sys.exit(ipu_launch.launch())
diff --git a/python/paddle/fluid/tests/unittests/ipu/distributed/run_dist_ipu.sh b/python/paddle/fluid/tests/unittests/ipu/distributed/run_dist_ipu.sh
new file mode 100644
index 0000000000..6f491ef107
--- /dev/null
+++ b/python/paddle/fluid/tests/unittests/ipu/distributed/run_dist_ipu.sh
@@ -0,0 +1,80 @@
+#!/bin/bash
+
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+set -e
+
+partition_name=pod64
+vipu_server=10.137.96.62
+allclose_script="
+import sys
+import numpy as np
+data1 = np.loadtxt(\"ipu_res.txt\")
+data2 = np.loadtxt(\"cpu_res.txt\")
+if np.allclose(data1[::16], data2, atol=1e-6):
+    sys.exit(0)
+else:
+    sys.exit(1)
+"
+
+for opt in lamb sgd adam ;
+do
+    for onchip in False True ;
+    do
+        for rts in False True ;
+        do
+            echo "Testcase: opt: ${opt}, onchip: ${onchip}, rts: ${rts}"
+            echo "paddle.distributed.fleet.launch test with IPUs..."
+            python3.7 -m paddle.distributed.launch \
+            --device_num=8 \
+            ipu \
+            --hosts=localhost \
+            --nproc_per_host=2 \
+            --ipus_per_replica=2 \
+            --ipu_partition=${partition_name} \
+            --vipu_server=${vipu_server} \
+            test_dist_data_parallel_ipu.py ${opt} ipu_res.txt ${onchip} ${rts} > ipu.log
+            echo "paddle.distributed.fleet.launch test with IPUs...Done"
+
+            echo "paddle normal test with CPU..."
+            export POPLAR_IPUMODEL=1
+            python3.7 test_dist_data_parallel_ipu.py ${opt} cpu_res.txt > cpu.log
+            unset POPLAR_IPUMODEL
+            echo "paddle normal test with CPU...Done"
+
+            echo "Compare results..."
+            # Run the check inside `if`: with `set -e` a bare failing command aborts before `$?` is read.
+            if python3.7 -c "${allclose_script}"; then
+                echo "Compare results...Done"
+            else
+                echo "Error occurs. Please check ipu.log, cpu.log, ipu_res.txt and cpu_res.txt"
+                exit 1
+            fi
+        done
+    done
+done
+
+if [ -f "ipu.log" ]; then
+    rm "ipu.log"
+fi
+if [ -f "cpu.log" ]; then
+    rm "cpu.log"
+fi
+if [ -f "ipu_res.txt" ]; then
+    rm "ipu_res.txt"
+fi
+if [ -f "cpu_res.txt" ]; then
+    rm "cpu_res.txt"
+fi
diff --git a/python/paddle/fluid/tests/unittests/ipu/distributed/test_dist_data_parallel_ipu.py b/python/paddle/fluid/tests/unittests/ipu/distributed/test_dist_data_parallel_ipu.py
new file mode 100644
index 0000000000..6054f2be75
--- /dev/null
+++ b/python/paddle/fluid/tests/unittests/ipu/distributed/test_dist_data_parallel_ipu.py
@@ -0,0 +1,184 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+
+import sys
+import os
+import random
+import numpy as np
+import paddle
+import paddle.static
+from paddle.fluid.tests.unittests.ipu.op_test_ipu import IPUOpTest
+
+# MPI communicator; only set when the script runs as a distributed IPU test.
+mpi_comm = None
+
+
+@unittest.skip('Disable distributed tests on auto CI.')
+class TestBase(IPUOpTest):
+    """Train a two-conv model on CPU or on distributed IPUs and log the loss."""
+
+    def set_attrs(self, enable_ipu, optimizer, log, onchip=False, rts=False):
+        """Store the run configuration.
+
+        Args:
+            enable_ipu: run on IPUs (distributed) instead of the CPU.
+            optimizer: one of 'sgd', 'adam' or 'lamb'.
+            log: path of the text file the losses are written to.
+            onchip: value of ``location_optimizer.on_chip``.
+            rts: value of ``location_optimizer.use_replicated_tensor_sharding``.
+        """
+        self.ipu_options = {
+            "enable_pipelining": True,
+            "batches_per_step": 1,
+            "enable_gradient_accumulation": True,
+            "accumulation_factor": 4,
+            "enable_replicated_graphs": True,
+            "replicated_graph_count": 2,
+            "location_optimizer": {
+                "on_chip": onchip,
+                "use_replicated_tensor_sharding": rts
+            }
+        }
+
+        self.cpu_bs = 16
+        self.ipu_bs = 1
+        self.optimizer = optimizer
+        self.log = log
+        self.enable_ipu = enable_ipu
+
+    def test(self):
+        """Build the model, run the training loop and save the losses."""
+        seed = 2021
+        np.random.seed(seed)
+        random.seed(seed)
+        scope = paddle.static.Scope()
+        main_prog = paddle.static.Program()
+        startup_prog = paddle.static.Program()
+        main_prog.random_seed = seed
+        startup_prog.random_seed = seed
+
+        bs = self.ipu_bs if self.enable_ipu else self.cpu_bs
+        data = np.random.rand(1, 3, 10, 10).astype(np.float32)
+
+        # Parsed once; indexing (not .get) names the variable if it is missing.
+        trainer_id = None
+        if self.enable_ipu:
+            trainer_id = int(os.environ["PADDLE_TRAINER_ID"])
+
+        with paddle.static.scope_guard(scope):
+            with paddle.static.program_guard(main_prog, startup_prog):
+                image = paddle.static.data(
+                    name='image', shape=[bs, 3, 10, 10], dtype='float32')
+                with paddle.static.ipu_shard_guard(index=0, stage=0):
+                    conv1 = paddle.static.nn.conv2d(
+                        image, num_filters=3, filter_size=3, bias_attr=False)
+                with paddle.static.ipu_shard_guard(index=1, stage=1):
+                    conv2 = paddle.static.nn.conv2d(
+                        conv1, num_filters=3, filter_size=3, bias_attr=False)
+                    # should consider influence of bs
+                    loss = paddle.mean(conv2)
+
+                if self.optimizer == 'sgd':
+                    opt = paddle.optimizer.SGD(learning_rate=1e-2)
+                elif self.optimizer == 'adam':
+                    opt = paddle.optimizer.Adam(learning_rate=1e-2)
+                elif self.optimizer == 'lamb':
+                    opt = paddle.optimizer.Lamb(learning_rate=1e-2)
+                else:
+                    raise Exception('optimizer must be sgd, adam or lamb')
+
+                opt.minimize(loss)
+
+            if self.enable_ipu:
+                place = paddle.IPUPlace()
+            else:
+                place = paddle.CPUPlace()
+            executor = paddle.static.Executor(place)
+            executor.run(startup_prog)
+
+            if self.enable_ipu:
+                local_replicas = self.ipu_options['replicated_graph_count']
+                feed_list = [image.name]
+                fetch_list = [loss.name]
+                ipu_strategy = paddle.static.IpuStrategy()
+                ipu_strategy.set_graph_config(
+                    num_ipus=2 * local_replicas,
+                    is_training=True,
+                    enable_manual_shard=True)
+                ipu_strategy.set_options(self.ipu_options)
+                ipu_strategy.set_options({
+                    "enable_distribution": True,
+                    "enable_distributed_replicated_graphs": True,
+                    # first global replica index owned by this process
+                    "global_replica_offset": trainer_id * local_replicas,
+                    "global_replication_factor": 4
+                })
+                program = paddle.static.IpuCompiledProgram(
+                    main_prog, ipu_strategy=ipu_strategy).compile(
+                        feed_list, fetch_list)
+                feed = {
+                    "image": np.tile(data, [
+                        local_replicas * self.ipu_options['batches_per_step']
+                        * self.ipu_options['accumulation_factor'], 1, 1, 1
+                    ])
+                }
+
+            else:
+                program = main_prog
+                feed = {"image": np.tile(data, [self.cpu_bs, 1, 1, 1])}
+
+            epoch = 10
+            if not self.enable_ipu:
+                # global replication factor
+                epoch *= 4
+                epoch *= self.ipu_options['batches_per_step']
+                epoch *= self.ipu_options['accumulation_factor']
+                epoch = epoch / (self.cpu_bs / self.ipu_bs)
+
+            results = []
+            for i in range(int(epoch)):
+                res = executor.run(program, feed=feed, fetch_list=[loss])
+                if self.enable_ipu:
+                    res = mpi_comm.gather(res, root=0)
+                results.append(res)
+            # In a distributed run only trainer 0 writes the log file.
+            if not self.enable_ipu or trainer_id == 0:
+                np.savetxt(self.log, np.array(results).flatten())
+
+
+if __name__ == "__main__":
+    paddle.enable_static()
+    # Run distributed tests
+    if len(sys.argv) == 5:
+        from mpi4py import MPI
+
+        mpi_comm = MPI.COMM_WORLD
+
+        test = TestBase()
+        test.set_attrs(
+            enable_ipu=True,
+            optimizer=sys.argv[1],
+            log=sys.argv[2],
+            onchip=sys.argv[3] == "True",
+            rts=sys.argv[4] == "True")
+        test.test()
+    # Run cpu tests for compare
+    elif len(sys.argv) == 3:
+        test = TestBase()
+        test.set_attrs(enable_ipu=False, optimizer=sys.argv[1], log=sys.argv[2])
+        test.test()
+    else:
+        raise ValueError(
+            "Only support 3 or 5 args. 3 for cpu test, 5 for ipu distributed test"
+        )
diff --git a/python/paddle/fluid/tests/unittests/ipu/distributed/test_dist_pod128_sample.py b/python/paddle/fluid/tests/unittests/ipu/distributed/test_dist_pod128_sample.py
new file mode 100644
index 0000000000..44c26d123b
--- /dev/null
+++ b/python/paddle/fluid/tests/unittests/ipu/distributed/test_dist_pod128_sample.py
@@ -0,0 +1,111 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+'''
+python3.7 -m paddle.distributed.launch \
+--device_num=128 \
+ipu \
+--hosts=host1,host2 \
+--ipus_per_replica=2 \
+--nproc_per_host=1 \
+--ipu_partition=pod128 \
+--vipu_server=lr17-1-ctrl \
+python/paddle/fluid/tests/unittests/ipu/distributed/test_dist_pod128_sample.py
+
+Equal to:
+
+poprun \
+--host=host1,host2 \
+--num-instances=2 \
+--num-replicas=64 \
+--ipus-per-replica=2 \
+--print-topology=yes \
+--vipu-partition=pod128 \
+--vipu-server-host=lr17-1-ctrl \
+--update-partition=no \
+python3.7 python/paddle/fluid/tests/unittests/ipu/distributed/test_dist_pod128_sample.py
+'''
+
+import os
+import numpy as np
+import paddle
+
+
+def TestDistTraining():
+    """Run 10 training steps of a pipelined embedding model and print the loss.
+
+    Reads ``PADDLE_TRAINER_ID`` from the environment to compute the global
+    replica offset of this process.
+    """
+    paddle.enable_static()
+
+    attrs = {"size": [128, 16], "padding_idx": -1, "dtype": 'float32'}
+
+    # Replicas owned by this process; also the stride of the replica offset.
+    replicas_per_instance = 32
+    # Indexing (not .get) names the variable in the error if it is missing.
+    trainer_id = int(os.environ["PADDLE_TRAINER_ID"])
+
+    scope = paddle.fluid.core.Scope()
+    main_prog = paddle.static.Program()
+    startup_prog = paddle.static.Program()
+    main_prog.random_seed = 42
+    startup_prog.random_seed = 42
+
+    np.random.seed(42)
+    input_data = np.random.uniform(0, 127, size=[128, 3, 2, 1]).astype(np.int32)
+
+    with paddle.fluid.scope_guard(scope):
+        with paddle.static.program_guard(main_prog, startup_prog):
+            x = paddle.static.data(name="x", shape=[3, 2, 1], dtype='int64')
+            with paddle.static.ipu_shard_guard(index=0, stage=0):
+                out = paddle.fluid.layers.embedding(x, **attrs)
+            with paddle.static.ipu_shard_guard(index=1, stage=1):
+                loss = paddle.mean(out)
+            opt = paddle.optimizer.Adam(learning_rate=1e-1)
+            opt.minimize(loss)
+
+            feed_list = ["x"]
+            fetch_list = [loss.name]
+
+            place = paddle.IPUPlace()
+            exe = paddle.static.Executor(place)
+            exe.run(startup_prog)
+
+            ipu_strategy = paddle.static.IpuStrategy()
+            ipu_strategy.set_graph_config(
+                num_ipus=64, is_training=True, enable_manual_shard=True)
+            ipu_strategy.set_pipelining_config(
+                enable_pipelining=True,
+                batches_per_step=1,
+                enable_gradient_accumulation=True,
+                accumulation_factor=4)
+            ipu_strategy.set_options({
+                "enable_distribution": True,
+                "enable_replicated_graphs": True,
+                "replicated_graph_count": replicas_per_instance,
+                "enable_distributed_replicated_graphs": True,
+                "global_replica_offset":
+                # Paddle : int(os.environ.get("PADDLE_TRAINER_ID")) * 32
+                # PopRun : int(os.environ.get("POPDIST_REPLICA_INDEX_OFFSET"))
+                trainer_id * replicas_per_instance,
+                "global_replication_factor": 64,
+                "location_optimizer": {
+                    "on_chip": False,
+                    "use_replicated_tensor_sharding": True
+                }
+            })
+
+            ipu_program = paddle.static.IpuCompiledProgram(
+                main_prog, ipu_strategy=ipu_strategy)
+            program = ipu_program.compile(feed_list, fetch_list)
+
+            for i in range(10):
+                res = exe.run(program,
+                              feed={"x": input_data},
+                              fetch_list=fetch_list)
+                print("index: {}, result: {}".format(i, res))
+
+
+if __name__ == "__main__":
+    TestDistTraining()
diff --git a/python/paddle/fluid/tests/unittests/ipu/distributed/test_dist_sample.py b/python/paddle/fluid/tests/unittests/ipu/distributed/test_dist_sample.py
new file mode 100644
index 0000000000..6ca9222d91
--- /dev/null
+++ b/python/paddle/fluid/tests/unittests/ipu/distributed/test_dist_sample.py
@@ -0,0 +1,177 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+'''
+Single host:
+
+python3.7 -m paddle.distributed.launch \
+--device_num=4 \
+ipu \
+--hosts=localhost \
+--nproc_per_host=2 \
+--ipus_per_replica=1 \
+--ipu_partition=pod64 \
+--vipu_server=10.137.96.62 \
+python/paddle/fluid/tests/unittests/ipu/distributed/test_dist_sample.py ipu_res.txt
+
+Equal to:
+
+poprun \
+--host=localhost \
+--num-instances=2 \
+--num-replicas=4 \
+--ipus-per-replica=1 \
+--print-topology=yes \
+python3.7 python/paddle/fluid/tests/unittests/ipu/distributed/test_dist_sample.py ipu_res.txt
+'''
+'''
+Multi hosts:
+
+python3.7 -m paddle.distributed.launch \
+--device_num=4 \
+ipu \
+--hosts=host1,host2 \
+--nproc_per_host=1 \
+--ipus_per_replica=1 \
+--ipu_partition=pod64 \
+--vipu_server=10.137.96.62 \
+python/paddle/fluid/tests/unittests/ipu/distributed/test_dist_sample.py ipu_res.txt
+
+Equal to:
+
+poprun \
+--host=host1,host2 \
+--num-instances=2 \
+--num-replicas=4 \
+--ipus-per-replica=1 \
+--print-topology=yes \
+python3.7 python/paddle/fluid/tests/unittests/ipu/distributed/test_dist_sample.py ipu_res.txt
+'''
+
+import os
+import sys
+import paddle
+import numpy as np
+
+# MPI communicator; only set when PADDLE_TRAINER_ID is present.
+mpi_comm = None
+
+
+def Test(use_dist, file_name):
+    """Train a small embedding model on IPUs and save the last fetched loss.
+
+    Args:
+        use_dist: configure the run as one process of a distributed job
+            (2 local replicas, global replication factor 4) instead of a
+            single process with 4 replicas.
+        file_name: path of the text file the result is written to.
+    """
+    paddle.enable_static()
+
+    attrs = {"size": [128, 16], "padding_idx": -1, "dtype": 'float32'}
+
+    scope = paddle.fluid.core.Scope()
+    main_prog = paddle.static.Program()
+    startup_prog = paddle.static.Program()
+    main_prog.random_seed = 42
+    startup_prog.random_seed = 42
+
+    # The two distinct samples every feed below is assembled from.
+    sample_a = np.array([[[1], [3]], [[2], [4]], [[4], [127]]]).astype(
+        np.int32)
+    sample_b = np.array([[[8], [60]], [[50], [77]], [[90], [13]]]).astype(
+        np.int32)
+
+    with paddle.fluid.scope_guard(scope):
+        with paddle.static.program_guard(main_prog, startup_prog):
+            x = paddle.static.data(name="x", shape=[3, 2, 1], dtype='int64')
+
+            out = paddle.fluid.layers.embedding(x, **attrs)
+            loss = paddle.mean(out)
+            opt = paddle.optimizer.Adam(learning_rate=1e-1)
+            opt.minimize(loss)
+
+            feed_list = ["x"]
+            fetch_list = [loss.name]
+
+            place = paddle.IPUPlace()
+            exe = paddle.static.Executor(place)
+            exe.run(startup_prog)
+
+            ipu_strategy = paddle.static.IpuStrategy()
+            if use_dist:
+                ipu_strategy.set_graph_config(num_ipus=2, is_training=True)
+                # Set distributed envs
+                ipu_strategy.set_options({
+                    "enable_distribution": True,
+                    "enable_replicated_graphs": True,
+                    "replicated_graph_count": 2,
+                    "enable_distributed_replicated_graphs": True,
+                    "global_replica_offset":
+                    int(os.environ.get("PADDLE_TRAINER_ID")) * 2,
+                    "global_replication_factor": 4
+                })
+            else:
+                ipu_strategy.set_graph_config(num_ipus=4, is_training=True)
+                ipu_strategy.set_options({
+                    "enable_replicated_graphs": True,
+                    "replicated_graph_count": 4,
+                })
+
+            ipu_program = paddle.static.IpuCompiledProgram(
+                main_prog, ipu_strategy=ipu_strategy)
+            program = ipu_program.compile(feed_list, fetch_list)
+
+            if use_dist:
+                # Trainer 0 feeds sample_a twice, every other trainer sample_b.
+                if os.environ.get("PADDLE_TRAINER_ID") == "0":
+                    input_data = np.concatenate([sample_a, sample_a])
+                else:
+                    input_data = np.concatenate([sample_b, sample_b])
+            else:
+                input_data = np.concatenate(
+                    [sample_a, sample_a, sample_b, sample_b])
+            feed_data = {"x": input_data}
+
+            for _ in range(10):
+                res = exe.run(program, feed=feed_data, fetch_list=fetch_list)
+
+            if use_dist:
+                # gather() is a collective call: every rank has to enter it,
+                # otherwise the root blocks forever. Only the root gets the
+                # gathered list back; the other ranks receive None.
+                gathered = mpi_comm.gather(res, root=0)
+                if gathered is not None:
+                    np.savetxt(file_name, np.array(gathered).flatten())
+            else:
+                np.savetxt(file_name, res)
+
+
+if __name__ == "__main__":
+    file_name = sys.argv[1]
+
+    # Treat the presence of PADDLE_TRAINER_ID as "running distributed".
+    use_dist = 'PADDLE_TRAINER_ID' in os.environ
+    if use_dist:
+        from mpi4py import MPI
+
+        mpi_comm = MPI.COMM_WORLD
+
+    Test(use_dist, file_name)
-- 
GitLab