From c09b1d68b9afe78435ae7af76f86ace6da9ee9db Mon Sep 17 00:00:00 2001
From: Allen Guo
Date: Fri, 22 Apr 2022 11:24:10 +0800
Subject: [PATCH] [IPU] add mixed-precision support for ipu (#41733) (#41906)

add mixed-precision support for ipu
cherry-pick from #41733
---
 .../ir/ipu/optimizer_extract_pass.cc          |   7 +-
 .../contrib/mixed_precision/fp16_utils.py     |  14 +-
 .../fluid/tests/unittests/ipu/op_test_ipu.py  |  54 ++-
 .../ipu/test_mixed_precision_inference_ipu.py | 140 +++++++
 .../ipu/test_mixed_precision_training_ipu.py  | 151 ++++++++
 .../unittests/ipu/test_model_parallel_ipu.py  | 357 ++++++++++++++++++
 .../unittests/ipu/test_weight_decay_ipu.py    | 118 ++++++
 7 files changed, 827 insertions(+), 14 deletions(-)
 create mode 100644 python/paddle/fluid/tests/unittests/ipu/test_mixed_precision_inference_ipu.py
 create mode 100644 python/paddle/fluid/tests/unittests/ipu/test_mixed_precision_training_ipu.py
 create mode 100644 python/paddle/fluid/tests/unittests/ipu/test_model_parallel_ipu.py
 create mode 100644 python/paddle/fluid/tests/unittests/ipu/test_weight_decay_ipu.py

diff --git a/paddle/fluid/framework/ir/ipu/optimizer_extract_pass.cc b/paddle/fluid/framework/ir/ipu/optimizer_extract_pass.cc
index 7cdb7a8854a..7c517a50e9a 100644
--- a/paddle/fluid/framework/ir/ipu/optimizer_extract_pass.cc
+++ b/paddle/fluid/framework/ir/ipu/optimizer_extract_pass.cc
@@ -30,9 +30,10 @@ std::set<std::string> ignored_ops = {
     "elementwise_max",
     "elementwise_div",
     "elementwise_mul",
-    "scale",  // adamax
-    "assign",  // adamw
-    "squared_l2_norm"  // gradient_clip_norm
+    "scale",            // adamax
+    "assign",           // adamw
+    "squared_l2_norm",  // gradient_clip_norm
+    "cast",             // mixed-precision support
 };
 
 const bool startswith(const std::string& str, const std::string& pre) {
diff --git a/python/paddle/fluid/contrib/mixed_precision/fp16_utils.py b/python/paddle/fluid/contrib/mixed_precision/fp16_utils.py
index e3e5bc4f327..760e9ceb9ea 100644
--- a/python/paddle/fluid/contrib/mixed_precision/fp16_utils.py
+++ b/python/paddle/fluid/contrib/mixed_precision/fp16_utils.py
@@ -191,7 +191,8 @@ def _insert_cast_op(block, op, idx, src_dtype, dest_dtype):
                 attrs={
                     "in_dtype": in_var.dtype,
                     "out_dtype": out_var.dtype,
-                    "op_device": op_device
+                    "op_device": op_device,
+                    "op_role": op.attr("op_role"),
                 })
             num_cast_ops += 1
         _rename_arg(op, in_var.name, out_var.name)
@@ -241,7 +242,8 @@ def _insert_cast_post_op(block, op, idx, src_dtype, dest_dtype, target_name,
         attrs={
             "in_dtype": target_var.dtype,
             "out_dtype": cast_var.dtype,
-            "op_device": op.attr("op_device")
+            "op_device": op.attr("op_device"),
+            "op_role": op.attr("op_role"),
         })
     num_cast_ops += 1
     op_var_rename_map[block.idx][target_var.name] = cast_var.name
@@ -415,7 +417,9 @@ def cast_model_to_fp16(program, amp_lists=None, use_fp16_guard=True):
                 keep_fp32_ops.add(op)
                 continue  # processed below
             for in_name in op.input_names:
-                if _keep_fp32_input(op, in_name):
+                # for ipu, all inputs must be converted to fp16
+                if not core.is_compiled_with_ipu() and _keep_fp32_input(
+                        op, in_name):
                     continue
                 for in_var_name in op.input(in_name):
                     in_var = None
@@ -443,7 +447,9 @@ def cast_model_to_fp16(program, amp_lists=None, use_fp16_guard=True):
                             format(op.type, in_var_name, in_var.dtype))
 
             for out_name in op.output_names:
-                if _keep_fp32_output(op, out_name):
+                # for ipu, all outputs must be converted to fp16
+                if not core.is_compiled_with_ipu() and _keep_fp32_output(
+                        op, out_name):
                     continue
                 for out_var_name in op.output(out_name):
                     out_var = None
diff --git a/python/paddle/fluid/tests/unittests/ipu/op_test_ipu.py b/python/paddle/fluid/tests/unittests/ipu/op_test_ipu.py
index 790388f30ea..26fd42be6cd 100644
--- a/python/paddle/fluid/tests/unittests/ipu/op_test_ipu.py
+++ b/python/paddle/fluid/tests/unittests/ipu/op_test_ipu.py
@@ -16,7 +16,7 @@ import os
 import random
 import unittest
 import numpy as np
-from enum import Enum
+from enum import IntEnum
 
 import paddle
 import paddle.static
@@ -33,17 +33,24 @@ map_np_dtype_to_fluid_dtype = {
 }
 
 
-class ExecutionMode(Enum):
+class ExecutionModeFull(IntEnum):
+    # Run fp32 model on cpu
     CPU_FP32 = 1
+    # Run fp32 model on ipu
     IPU_FP32 = 2
-    # enable_fp16 through ipu_strategy.enable_fp16
+    # Convert model to fp16 using popart transform
+    # All parameters will be converted to fp16
+    # TODO rename to IPU_FP16
     IPU_POPART_FP16 = 3
+    # Mixed-precision mode, using `paddle.static.amp.fp16_guard()` to control
+    # the precision of each operator
+    IPU_MIXED_PRECISION = 4
 
-    def __lt__(self, other):
-        return self.value < other.value
 
-    def __gt__(self, other):
-        return self.value > other.value
+class ExecutionMode(IntEnum):
+    CPU_FP32 = ExecutionModeFull.CPU_FP32
+    IPU_FP32 = ExecutionModeFull.IPU_FP32
+    IPU_POPART_FP16 = ExecutionModeFull.IPU_POPART_FP16
 
 
 def np_dtype_to_fluid_str(dtype: np.dtype) -> str:
@@ -61,6 +68,12 @@ class IPUOpTest(unittest.TestCase):
         np.random.seed(cls.SEED)
         random.seed(cls.SEED)
 
+        # For ipu, most ops support fp16
+        cls.amp_list = paddle.static.amp.CustomOpLists(
+            custom_black_list=[], custom_white_list=[])
+        cls.amp_list.unsupported_list = {}
+        cls.amp_list.black_list = {}
+
         # Enable paddle static graph mode
         paddle.enable_static()
 
@@ -114,3 +127,30 @@ class IPUOpTest(unittest.TestCase):
 
             if check_shape:
                 self.assertTrue(ipu_popart_fp16.shape == cpu_fp32.shape)
+
+        ipu_mixed_precision = None
+        if ExecutionModeFull.IPU_MIXED_PRECISION in outputs.keys():
+            ipu_mixed_precision = outputs[
+                ExecutionModeFull.IPU_MIXED_PRECISION]
+            max_diff = np.abs(
+                ipu_mixed_precision.astype(np.float32) - cpu_fp32).max()
+            fp16_flag = np.allclose(
+                ipu_mixed_precision.astype(np.float32),
+                cpu_fp32,
+                rtol=self.rtol_fp16,
+                atol=self.atol_fp16)
+            self.assertTrue(fp16_flag, "max diff is %f" % (max_diff))
+
+            if check_shape:
+                self.assertTrue(ipu_mixed_precision.shape == cpu_fp32.shape)
+
+        if ExecutionMode.IPU_POPART_FP16 in outputs.keys(
+        ) and ExecutionModeFull.IPU_MIXED_PRECISION in outputs.keys():
+            max_diff = np.abs(ipu_popart_fp16 - ipu_mixed_precision).max()
+            fp16_flag = np.allclose(
+                ipu_popart_fp16, ipu_mixed_precision, atol=self.atol_fp16)
+            self.assertTrue(fp16_flag, "max diff is %f" % (max_diff))
+
+            if check_shape:
+                self.assertTrue(
+                    ipu_popart_fp16.shape == ipu_mixed_precision.shape)
diff --git a/python/paddle/fluid/tests/unittests/ipu/test_mixed_precision_inference_ipu.py b/python/paddle/fluid/tests/unittests/ipu/test_mixed_precision_inference_ipu.py
new file mode 100644
index 00000000000..a70550c1df7
--- /dev/null
+++ b/python/paddle/fluid/tests/unittests/ipu/test_mixed_precision_inference_ipu.py
@@ -0,0 +1,140 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+
+import numpy as np
+import paddle
+import paddle.static
+import paddle.nn.functional as F
+from paddle.fluid.tests.unittests.ipu.op_test_ipu import IPUOpTest, ExecutionModeFull
+
+
+@unittest.skipIf(not paddle.is_compiled_with_ipu(),
+                 "core is not compiled with IPU")
+class TestBase(IPUOpTest):
+    def setUp(self):
+        self.set_atol()
+        self.set_data_feed()
+        self.set_feed_attr()
+
+    @property
+    def fp16_enabled(self):
+        return True
+
+    def set_atol(self):
+        self.atol = 1e-6
+        self.rtol = 1e-6
+        self.atol_fp16 = 1e-3
+        self.rtol_fp16 = 1e-3
+
+    def set_data_feed(self):
+        data = np.random.uniform(size=[1, 10, 27, 27])
+        self.feed_fp32 = {"in_0": data.astype(np.float32)}
+        self.feed_fp16 = {"in_0": data.astype(np.float16)}
+
+    def set_feed_attr(self):
+        self.feed_shape = [x.shape for x in self.feed_fp32.values()]
+        self.feed_list = list(self.feed_fp32.keys())
+
+    def dtype_check(self, program, to_fp16_var_names):
+        block = program.global_block()
+        assert len(to_fp16_var_names) > 0
+        for var_name in to_fp16_var_names:
+            assert block.var(var_name).dtype == paddle.float16
+
+    def _test_base(self, exec_mode):
+        generator = paddle.fluid.unique_name.UniqueNameGenerator()
+        scope = paddle.static.Scope()
+        main_prog = paddle.static.Program()
+        startup_prog = paddle.static.Program()
+        main_prog.random_seed = self.SEED
+        startup_prog.random_seed = self.SEED
+
+        with paddle.fluid.unique_name.guard(generator):
+            with paddle.static.scope_guard(scope):
+                with paddle.static.program_guard(main_prog, startup_prog):
+                    x = paddle.static.data(
+                        name=self.feed_list[0],
+                        shape=self.feed_shape[0],
+                        dtype='float32')
+
+                    # using fp32
+                    x = paddle.static.nn.conv2d(
+                        input=x, num_filters=3, filter_size=3)
+                    x = paddle.static.nn.batch_norm(x, act='relu')
+                    x = F.max_pool2d(x, kernel_size=2, stride=2)
+
+                    # using fp16
+                    with paddle.static.amp.fp16_guard():
+                        x = paddle.static.nn.conv2d(
+                            input=x, num_filters=6, filter_size=3)
+                        x = paddle.static.nn.batch_norm(x, act='relu')
+                        x = F.max_pool2d(x, kernel_size=2, stride=2)
+
+                    # using fp32
+                    x = paddle.static.nn.fc(x, size=10)
+                    loss = paddle.mean(x)
+                    fetch_list = [loss.name]
+
+                    if exec_mode == ExecutionModeFull.CPU_FP32:
+                        place = paddle.CPUPlace()
+                    else:
+                        place = paddle.IPUPlace()
+
+                    # cast model to fp16
+                    if exec_mode == ExecutionModeFull.IPU_MIXED_PRECISION:
+                        to_fp16_var_names = paddle.static.amp.cast_model_to_fp16(
+                            main_prog, self.amp_list)
+                        self.dtype_check(main_prog, to_fp16_var_names)
+
+                    exe = paddle.static.Executor(place)
+                    exe.run(startup_prog)
+
+                    # cast parameters to fp16
+                    if exec_mode == ExecutionModeFull.IPU_MIXED_PRECISION:
+                        paddle.static.amp.cast_parameters_to_fp16(
+                            paddle.CPUPlace(),
+                            main_prog,
+                            to_fp16_var_names=to_fp16_var_names)
+
+                    if exec_mode != ExecutionModeFull.CPU_FP32:
+                        ipu_strategy = paddle.static.IpuStrategy()
+                        ipu_strategy.set_graph_config(is_training=False)
+                        if exec_mode == ExecutionModeFull.IPU_POPART_FP16:
+                            ipu_strategy.set_precision_config(
+                                enable_fp16=True)
+                        program = paddle.static.IpuCompiledProgram(
+                            main_prog, ipu_strategy=ipu_strategy).compile(
+                                self.feed_list, fetch_list)
+                    else:
+                        program = main_prog
+
+                    feed = self.feed_fp32
+                    result = exe.run(program, feed=feed, fetch_list=fetch_list)
+                    return result[0]
+
+    def test(self):
+        output_dict = {}
+        for mode in ExecutionModeFull:
+            if mode == ExecutionModeFull.IPU_POPART_FP16:
+                continue
+            if mode > ExecutionModeFull.IPU_FP32 and not self.fp16_enabled:
+                break
+            output_dict[mode] = self._test_base(mode).flatten()
+
+        self.check(output_dict)
+
+
+if __name__ == "__main__":
+    unittest.main()
diff --git a/python/paddle/fluid/tests/unittests/ipu/test_mixed_precision_training_ipu.py b/python/paddle/fluid/tests/unittests/ipu/test_mixed_precision_training_ipu.py
new file mode 100644
index 00000000000..224c0bddc22
--- /dev/null
+++ b/python/paddle/fluid/tests/unittests/ipu/test_mixed_precision_training_ipu.py
@@ -0,0 +1,151 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+
+import numpy as np
+import paddle
+import paddle.static
+import paddle.nn.functional as F
+from paddle.fluid.tests.unittests.ipu.op_test_ipu import IPUOpTest, ExecutionModeFull
+
+
+@unittest.skipIf(not paddle.is_compiled_with_ipu(),
+                 "core is not compiled with IPU")
+class TestBase(IPUOpTest):
+    def setUp(self):
+        self.set_atol()
+        self.set_training()
+        self.set_data_feed()
+        self.set_feed_attr()
+
+    @property
+    def fp16_enabled(self):
+        return True
+
+    def set_atol(self):
+        self.atol = 2e-6
+        self.rtol = 1e-5
+        self.atol_fp16 = 1e-2
+        self.rtol_fp16 = 1e-3
+
+    def set_training(self):
+        self.is_training = True
+        self.epoch = 20
+
+    def set_data_feed(self):
+        data = np.random.uniform(size=[1, 3, 28, 28])
+        self.feed_fp32 = {"in_0": data.astype(np.float32)}
+        self.feed_fp16 = {"in_0": data.astype(np.float16)}
+
+    def set_feed_attr(self):
+        self.feed_shape = [x.shape for x in self.feed_fp32.values()]
+        self.feed_list = list(self.feed_fp32.keys())
+
+    def dtype_check(self, program, to_fp16_var_names):
+        block = program.global_block()
+        assert len(to_fp16_var_names) > 0
+        for var_name in to_fp16_var_names:
+            assert block.var(var_name).dtype == paddle.float16
+
+    def _test_base(self, exec_mode):
+        generator = paddle.fluid.unique_name.UniqueNameGenerator()
+        scope = paddle.static.Scope()
+        main_prog = paddle.static.Program()
+        startup_prog = paddle.static.Program()
+        main_prog.random_seed = self.SEED
+        startup_prog.random_seed = self.SEED
+
+        with paddle.fluid.unique_name.guard(generator):
+            with paddle.static.scope_guard(scope):
+                with paddle.static.program_guard(main_prog, startup_prog):
+                    x = paddle.static.data(
+                        name=self.feed_list[0],
+                        shape=self.feed_shape[0],
+                        dtype='float32')
+
+                    # using fp32
+                    x = paddle.static.nn.conv2d(
+                        input=x, num_filters=3, filter_size=3)
+                    x = paddle.static.nn.batch_norm(x, act='relu')
+                    x = F.max_pool2d(x, kernel_size=2, stride=2)
+
+                    # using fp16
+                    with paddle.static.amp.fp16_guard():
+                        x = paddle.static.nn.conv2d(
+                            input=x, num_filters=6, filter_size=3)
+                        x = paddle.static.nn.batch_norm(x, act='relu')
+                        x = F.max_pool2d(x, kernel_size=2, stride=2)
+
+                    # using fp32
+                    x = paddle.static.nn.fc(x, size=10)
+                    loss = paddle.mean(x)
+
+                    # optimizer
+                    optimizer = paddle.optimizer.Adam(learning_rate=1e-2)
+                    optimizer.minimize(loss, startup_prog)
+                    fetch_list = [loss.name]
+
+                    # cast model to fp16
+                    if exec_mode ==
ExecutionModeFull.IPU_MIXED_PRECISION: + to_fp16_var_names = paddle.static.amp.cast_model_to_fp16( + main_prog, self.amp_list) + self.dtype_check(main_prog, to_fp16_var_names) + + if exec_mode == ExecutionModeFull.CPU_FP32: + place = paddle.CPUPlace() + else: + place = paddle.IPUPlace() + exe = paddle.static.Executor(place) + exe.run(startup_prog) + + # cast parameters to fp16 + if exec_mode == ExecutionModeFull.IPU_MIXED_PRECISION: + paddle.static.amp.cast_parameters_to_fp16( + paddle.CPUPlace(), + main_prog, + to_fp16_var_names=to_fp16_var_names) + + if exec_mode != ExecutionModeFull.CPU_FP32: + ipu_strategy = paddle.static.IpuStrategy() + ipu_strategy.set_graph_config(is_training=self.is_training) + if exec_mode == ExecutionModeFull.IPU_POPART_FP16: + ipu_strategy.set_precision_config(enable_fp16=True) + program = paddle.static.IpuCompiledProgram( + main_prog, ipu_strategy=ipu_strategy).compile( + self.feed_list, fetch_list) + else: + program = main_prog + + feed = self.feed_fp32 + result = [] + for i in range(self.epoch): + out = exe.run(program, feed=feed, fetch_list=fetch_list) + result.append(out) + return np.array(result) + + def test_base(self): + output_dict = {} + for mode in ExecutionModeFull: + if mode == ExecutionModeFull.IPU_POPART_FP16: + continue + if mode > ExecutionModeFull.IPU_FP32 and not self.fp16_enabled: + break + output_dict[mode] = self._test_base(mode).flatten() + + self.check(output_dict) + + +if __name__ == "__main__": + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/ipu/test_model_parallel_ipu.py b/python/paddle/fluid/tests/unittests/ipu/test_model_parallel_ipu.py new file mode 100644 index 00000000000..792b88849fa --- /dev/null +++ b/python/paddle/fluid/tests/unittests/ipu/test_model_parallel_ipu.py @@ -0,0 +1,357 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import unittest + +import numpy as np +import paddle +import paddle.static +from paddle.fluid.tests.unittests.ipu.op_test_ipu import IPUOpTest + + +@unittest.skipIf(not paddle.is_compiled_with_ipu(), + "core is not compiled with IPU") +class TestBase(IPUOpTest): + def setUp(self): + self.set_atol() + self.set_training() + self.set_attrs() + self.set_data_feed() + + def set_training(self): + self.is_training = False + self.epoch = 10 + + def set_attrs(self): + self.ipu_options = { + "batches_per_step": 1, + "enable_pipelining": False, + "enable_gradient_accumulation": False, + "accumulation_factor": 1, + "enable_replicated_graphs": False, + "replicated_graph_count": 1, + } + self.cpu_bs = 1 + self.ipu_bs = 1 + + def set_data_feed(self): + np_image = np.random.rand(1, 3, 10, 10).astype(np.float32) + self.feed_cpu = {"image": np_image} + self.feed_ipu = {"image": np_image} + + def _test_base(self, run_ipu=True): + scope = paddle.static.Scope() + main_prog = paddle.static.Program() + startup_prog = paddle.static.Program() + main_prog.random_seed = self.SEED + startup_prog.random_seed = self.SEED + + bs = self.ipu_bs if run_ipu else self.cpu_bs + with paddle.static.scope_guard(scope): + with paddle.static.program_guard(main_prog, startup_prog): + image = paddle.static.data( + name='image', shape=[bs, 3, 10, 10], dtype='float32') + with paddle.static.ipu_shard_guard(index=0): + conv1 = paddle.static.nn.conv2d( + image, num_filters=3, filter_size=3, bias_attr=False) + with paddle.static.ipu_shard_guard(index=1): + conv2 = paddle.static.nn.conv2d( + conv1, num_filters=3, filter_size=3, bias_attr=False) + # should consider influence of bs + loss = paddle.mean(conv2) + + if self.is_training: + if self.optimizer == 'sgd': + opt = paddle.optimizer.SGD(learning_rate=1e-2) + elif self.optimizer == 'adam': + opt = paddle.optimizer.Adam(learning_rate=1e-2) + elif self.optimizer == 'lamb': + opt = paddle.optimizer.Lamb(learning_rate=1e-2) + else: + raise Exception('optimizer must be sgd, adam or lamb') + + opt.minimize(loss) + + if run_ipu: + place = paddle.IPUPlace() + else: + place = paddle.CPUPlace() + executor = paddle.static.Executor(place) + executor.run(startup_prog) + + if run_ipu: + feed_list = [image.name] + fetch_list = [loss.name] + ipu_strategy = paddle.static.IpuStrategy() + ipu_strategy.set_graph_config( + num_ipus=2 * self.ipu_options['replicated_graph_count'], + is_training=self.is_training, + enable_manual_shard=True) + ipu_strategy.set_options(self.ipu_options) + program = paddle.static.IpuCompiledProgram( + main_prog, + ipu_strategy=ipu_strategy).compile(feed_list, fetch_list) + else: + program = main_prog + + feed = self.feed_ipu if run_ipu else self.feed_cpu + epoch = self.epoch + if not run_ipu: + epoch *= self.ipu_options['replicated_graph_count'] + epoch *= self.ipu_options['batches_per_step'] + epoch *= self.ipu_options['accumulation_factor'] + epoch = epoch / (self.cpu_bs / self.ipu_bs) + result = [] + for i in range(int(epoch)): + loss_res = executor.run(program, feed=feed, fetch_list=[loss]) + result.append(loss_res) + return np.array(result).flatten() + + def test(self): + cpu_outputs = self._test_base(False) + ipu_outputs = self._test_base(True) + + self.assertTrue(np.allclose(cpu_outputs, ipu_outputs, atol=self.atol)) + + +class TestReplicaInference(TestBase): + def set_attrs(self): + self.ipu_options = { + "batches_per_step": 1, + "enable_pipelining": False, + "enable_gradient_accumulation": False, + "accumulation_factor": 1, + "enable_replicated_graphs": True, + 
"replicated_graph_count": 2, + } + self.cpu_bs = 1 + self.ipu_bs = 1 + + def set_data_feed(self): + np_image = np.random.rand(1, 3, 10, 10).astype(np.float32) + self.feed_cpu = {"image": np_image} + self.feed_ipu = { + "image": + np.tile(np_image, + [self.ipu_options['replicated_graph_count'], 1, 1, 1]) + } + + +class TestPipelineInference(TestBase): + def set_attrs(self): + self.ipu_options = { + "batches_per_step": 2, + "enable_pipelining": True, + "enable_gradient_accumulation": False, + "accumulation_factor": 1, + "enable_replicated_graphs": False, + "replicated_graph_count": 1, + } + self.cpu_bs = 1 + self.ipu_bs = 1 + + def set_data_feed(self): + np_image = np.random.rand(1, 3, 10, 10).astype(np.float32) + self.feed_cpu = {"image": np_image} + self.feed_ipu = { + "image": np.tile(np_image, + [self.ipu_options['batches_per_step'], 1, 1, 1]) + } + + +class TestTrainBase(TestBase): + def set_training(self): + self.is_training = True + self.epoch = 10 + + def set_attrs(self): + self.ipu_options = { + "batches_per_step": 1, + "enable_pipelining": False, + "enable_gradient_accumulation": False, + "accumulation_factor": 1, + "enable_replicated_graphs": False, + "replicated_graph_count": 1, + } + self.cpu_bs = 1 + self.ipu_bs = 1 + self.optimizer = 'sgd' + + +class TestReplicaTrain(TestTrainBase): + def set_attrs(self): + self.ipu_options = { + "batches_per_step": 1, + "enable_pipelining": False, + "enable_gradient_accumulation": False, + "accumulation_factor": 1, + "enable_replicated_graphs": True, + "replicated_graph_count": 2, + } + self.cpu_bs = 2 + self.ipu_bs = 1 + self.optimizer = 'sgd' + + def set_data_feed(self): + np_image = np.random.rand(1, 3, 10, 10).astype(np.float32) + self.feed_cpu = {"image": np.tile(np_image, [self.cpu_bs, 1, 1, 1])} + self.feed_ipu = { + "image": + np.tile(np_image, + [self.ipu_options['replicated_graph_count'], 1, 1, 1]) + } + + def test(self): + cpu_outputs = self._test_base(False) + ipu_outputs = self._test_base(True)[::2] + + self.assertTrue(np.allclose(cpu_outputs, ipu_outputs, atol=self.atol)) + + +class TestPipelineTrain(TestTrainBase): + def set_attrs(self): + self.ipu_options = { + "batches_per_step": 3, + "enable_pipelining": True, + "enable_gradient_accumulation": True, + "accumulation_factor": 3, + "enable_replicated_graphs": False, + "replicated_graph_count": 1, + } + self.cpu_bs = 3 + self.ipu_bs = 1 + self.optimizer = 'sgd' + + def set_data_feed(self): + np_image = np.random.rand(1, 3, 10, 10).astype(np.float32) + self.feed_cpu = {"image": np.tile(np_image, [self.cpu_bs, 1, 1, 1])} + bps_acc = self.ipu_options['batches_per_step'] * self.ipu_options[ + 'accumulation_factor'] + self.feed_ipu = {"image": np.tile(np_image, [bps_acc, 1, 1, 1])} + + def test(self): + cpu_outputs = self._test_base(False) + ipu_outputs = self._test_base(True)[::3] + + self.assertTrue(np.allclose(cpu_outputs, ipu_outputs, atol=self.atol)) + + +class TestAdamTrain(TestTrainBase): + def set_attrs(self): + self.ipu_options = { + "batches_per_step": 1, + "enable_pipelining": False, + "enable_gradient_accumulation": False, + "accumulation_factor": 1, + "enable_replicated_graphs": False, + "replicated_graph_count": 1, + } + self.cpu_bs = 1 + self.ipu_bs = 1 + self.optimizer = 'adam' + + +class TestAdamReplicaTrain(TestReplicaTrain): + def set_attrs(self): + self.ipu_options = { + "batches_per_step": 1, + "enable_pipelining": False, + "enable_gradient_accumulation": False, + "accumulation_factor": 1, + "enable_replicated_graphs": True, + "replicated_graph_count": 2, + } + 
self.cpu_bs = 2 + self.ipu_bs = 1 + self.optimizer = 'adam' + + +class TestAdamPipelineTrain(TestPipelineTrain): + def set_attrs(self): + self.ipu_options = { + "batches_per_step": 3, + "enable_pipelining": True, + "enable_gradient_accumulation": True, + "accumulation_factor": 3, + "enable_replicated_graphs": False, + "replicated_graph_count": 1, + } + self.cpu_bs = 3 + self.ipu_bs = 1 + self.optimizer = 'adam' + + +class TestAdamRecomputationTrain(TestPipelineTrain): + def set_attrs(self): + self.ipu_options = { + "batches_per_step": 3, + "enable_pipelining": True, + "enable_gradient_accumulation": True, + "accumulation_factor": 3, + "enable_replicated_graphs": False, + "replicated_graph_count": 1, + "auto_recomputation": 3, + } + self.cpu_bs = 3 + self.ipu_bs = 1 + self.optimizer = 'adam' + + +class TestLambTrain(TestAdamTrain): + def set_attrs(self): + self.ipu_options = { + "batches_per_step": 1, + "enable_pipelining": False, + "enable_gradient_accumulation": False, + "accumulation_factor": 1, + "enable_replicated_graphs": False, + "replicated_graph_count": 1, + } + self.cpu_bs = 1 + self.ipu_bs = 1 + self.optimizer = 'lamb' + + +class TestLambReplicaTrain(TestAdamReplicaTrain): + def set_attrs(self): + self.ipu_options = { + "batches_per_step": 1, + "enable_pipelining": False, + "enable_gradient_accumulation": False, + "accumulation_factor": 1, + "enable_replicated_graphs": True, + "replicated_graph_count": 2, + } + self.cpu_bs = 2 + self.ipu_bs = 1 + self.optimizer = 'lamb' + + +class TestLambPipelineTrain(TestAdamPipelineTrain): + def set_attrs(self): + self.ipu_options = { + "batches_per_step": 3, + "enable_pipelining": True, + "enable_gradient_accumulation": True, + "accumulation_factor": 3, + "enable_replicated_graphs": False, + "replicated_graph_count": 1, + } + self.cpu_bs = 3 + self.ipu_bs = 1 + self.optimizer = 'lamb' + + +if __name__ == "__main__": + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/ipu/test_weight_decay_ipu.py b/python/paddle/fluid/tests/unittests/ipu/test_weight_decay_ipu.py new file mode 100644 index 00000000000..5e652ce4833 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/ipu/test_weight_decay_ipu.py @@ -0,0 +1,118 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+
+import unittest
+
+import numpy as np
+import paddle
+import paddle.static
+from paddle.fluid.tests.unittests.ipu.op_test_ipu import IPUOpTest
+
+
+@unittest.skipIf(not paddle.is_compiled_with_ipu(),
+                 "core is not compiled with IPU")
+@unittest.skipIf(IPUOpTest.use_ipumodel(), "skip for ipumodel")
+class TestBase(IPUOpTest):
+    def setUp(self):
+        self.set_atol()
+        self.set_data_feed()
+        self.set_feed_attr()
+        self.set_attrs()
+
+    def set_atol(self):
+        self.atol = 1e-6
+
+    def set_data_feed(self):
+        self.feed = {
+            "image": np.random.uniform(size=[1, 3, 10, 10]).astype('float32'),
+        }
+
+    def set_feed_attr(self):
+        self.feed_shape = [x.shape for x in self.feed.values()]
+        self.feed_list = list(self.feed.keys())
+        self.feed_dtype = [x.dtype for x in self.feed.values()]
+
+    def set_attrs(self):
+        self.attrs = {
+            "weight_decay": 4.0,
+            "loss_scaling": 1.0,
+        }
+
+    def _test_optimizer(self, run_ipu=True):
+        def exclude_fn(param):
+            return param.name.endswith('.w_0')
+
+        scope = paddle.static.Scope()
+        main_prog = paddle.static.Program()
+        startup_prog = paddle.static.Program()
+        main_prog.random_seed = self.SEED
+        startup_prog.random_seed = self.SEED
+        np.random.seed(self.SEED)
+
+        with paddle.static.scope_guard(scope):
+            with paddle.static.program_guard(main_prog, startup_prog):
+                image = paddle.static.data(
+                    name='image', shape=[1, 3, 10, 10], dtype='float32')
+                bias = paddle.fluid.layers.create_parameter(
+                    shape=[1, 3, 10, 10], is_bias=True, dtype='float32')
+                add1 = image + bias
+                conv1 = paddle.static.nn.conv2d(
+                    add1, num_filters=3, filter_size=3, bias_attr=False)
+
+                loss = paddle.mean(conv1)
+                opt = paddle.optimizer.Lamb(
+                    learning_rate=1e-1,
+                    lamb_weight_decay=self.attrs['weight_decay'],
+                    exclude_from_weight_decay_fn=exclude_fn)
+                opt.minimize(loss)
+
+        if run_ipu:
+            place = paddle.IPUPlace()
+        else:
+            place = paddle.CPUPlace()
+        exe = paddle.static.Executor(place)
+        exe.run(startup_prog)
+        paddle.static.save(main_prog, "weight_decay")
+
+        if run_ipu:
+            feed_list = [image.name]
+            fetch_list = [loss.name]
+            ipu_strategy = paddle.static.IpuStrategy()
+            ipu_strategy.set_graph_config(is_training=True)
+            ipu_strategy.set_options({
+                'loss_scaling': self.attrs["loss_scaling"]
+            })
+            program = paddle.static.IpuCompiledProgram(
+                main_prog, ipu_strategy=ipu_strategy).compile(feed_list,
+                                                              fetch_list)
+        else:
+            program = main_prog
+
+        result = []
+        for epoch in range(100):
+            loss_res = exe.run(program, feed=self.feed, fetch_list=[loss])
+            result.append(loss_res)
+
+        return np.array(result)
+
+    def test(self):
+        # cpu and ipu dimension mismatch, cpu:(100, 1, 1), ipu:(100, 1)
+        ipu_loss = self._test_optimizer(True).flatten()
+        cpu_loss = self._test_optimizer(False).flatten()
+
+        self.assertTrue(np.allclose(ipu_loss, cpu_loss, atol=self.atol))
+
+
+if __name__ == "__main__":
+    unittest.main()
--
GitLab
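For reference, the user-facing flow that the new mixed-precision tests exercise can be summarized outside the patch as follows. This is a minimal sketch distilled from test_mixed_precision_training_ipu.py, not part of the change itself: the toy network is illustrative, it assumes a Paddle build with IPU support, and it only uses APIs that appear above (fp16_guard, CustomOpLists, cast_model_to_fp16, cast_parameters_to_fp16, IpuStrategy, IpuCompiledProgram).

    import numpy as np
    import paddle
    import paddle.static

    paddle.enable_static()

    main_prog = paddle.static.Program()
    startup_prog = paddle.static.Program()
    with paddle.static.program_guard(main_prog, startup_prog):
        x = paddle.static.data(name='in_0', shape=[1, 3, 28, 28], dtype='float32')
        # ops built outside fp16_guard stay in fp32
        x = paddle.static.nn.conv2d(input=x, num_filters=3, filter_size=3)
        # ops built inside fp16_guard are rewritten to fp16 by cast_model_to_fp16
        with paddle.static.amp.fp16_guard():
            x = paddle.static.nn.conv2d(input=x, num_filters=6, filter_size=3)
        loss = paddle.mean(x)
        paddle.optimizer.Adam(learning_rate=1e-2).minimize(loss, startup_prog)

    # on IPU most ops support fp16, so the test harness clears the default lists
    amp_list = paddle.static.amp.CustomOpLists(custom_black_list=[],
                                               custom_white_list=[])
    amp_list.unsupported_list = {}
    amp_list.black_list = {}
    to_fp16_var_names = paddle.static.amp.cast_model_to_fp16(main_prog, amp_list)

    exe = paddle.static.Executor(paddle.IPUPlace())
    exe.run(startup_prog)
    # convert the already-initialized fp32 parameters that were marked for fp16
    paddle.static.amp.cast_parameters_to_fp16(
        paddle.CPUPlace(), main_prog, to_fp16_var_names=to_fp16_var_names)

    ipu_strategy = paddle.static.IpuStrategy()
    ipu_strategy.set_graph_config(is_training=True)
    program = paddle.static.IpuCompiledProgram(
        main_prog, ipu_strategy=ipu_strategy).compile(['in_0'], [loss.name])

    feed = {'in_0': np.random.uniform(size=[1, 3, 28, 28]).astype(np.float32)}
    loss_val, = exe.run(program, feed=feed, fetch_list=[loss.name])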