Unverified commit c09b1d68, authored by Allen Guo, committed by GitHub

[IPU] add mixed-precision support for ipu (#41733) (#41906)

add mixed-precision support for ipu

cherry-pick from #41733
parent fd9c7818
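For orientation, the end-to-end flow this commit enables is roughly the following; this is a minimal sketch assembled from the test code further down, not a separate API. `build_model`, `feed_list`, and `fetch_list` are placeholders for any static-graph network that wraps its fp16 region in `paddle.static.amp.fp16_guard()`.

    import paddle
    import paddle.static

    paddle.enable_static()

    main_prog = paddle.static.Program()
    startup_prog = paddle.static.Program()
    with paddle.static.program_guard(main_prog, startup_prog):
        loss = build_model()  # placeholder: ops inside fp16_guard() run in fp16

    # Allow (almost) every op to be cast: empty black/unsupported lists.
    amp_list = paddle.static.amp.CustomOpLists(
        custom_black_list=[], custom_white_list=[])
    amp_list.unsupported_list = {}
    amp_list.black_list = {}

    # Rewrite the program in place; returns the names of vars that became fp16.
    to_fp16_var_names = paddle.static.amp.cast_model_to_fp16(main_prog, amp_list)

    exe = paddle.static.Executor(paddle.IPUPlace())
    exe.run(startup_prog)
    paddle.static.amp.cast_parameters_to_fp16(
        paddle.CPUPlace(), main_prog, to_fp16_var_names=to_fp16_var_names)

    ipu_strategy = paddle.static.IpuStrategy()
    ipu_strategy.set_graph_config(is_training=False)
    program = paddle.static.IpuCompiledProgram(
        main_prog, ipu_strategy=ipu_strategy).compile(feed_list, fetch_list)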
@@ -30,9 +30,10 @@ std::set<std::string> ignored_ops = {
     "elementwise_max",
     "elementwise_div",
     "elementwise_mul",
     "scale",            // adamax
     "assign",           // adamw
-    "squared_l2_norm"   // gradient_clip_norm
+    "squared_l2_norm",  // gradient_clip_norm
+    "cast",             // mixed-precision support
 };
 const bool startswith(const std::string& str, const std::string& pre) {
...
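Why `cast` joins `ignored_ops`: with mixed precision enabled, AMP inserts `cast` ops into the graph, including around optimizer state, and the pass that owns this set (it already skips the `scale`, `assign`, and `squared_l2_norm` ops emitted by adamax, adamw, and gradient clipping) must not mistake them for the optimizer op itself. A hedged Python mirror of the membership check; the real pass is C++ and does considerably more:

    ignored_ops = {
        "elementwise_max", "elementwise_div", "elementwise_mul",
        "scale",            # adamax
        "assign",           # adamw
        "squared_l2_norm",  # gradient_clip_norm
        "cast",             # mixed-precision support
    }

    def is_candidate_optimizer_op(op_type: str) -> bool:
        # Sketch only: auxiliary ops are filtered out before optimizer extraction.
        return op_type not in ignored_ops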
@@ -191,7 +191,8 @@ def _insert_cast_op(block, op, idx, src_dtype, dest_dtype):
             attrs={
                 "in_dtype": in_var.dtype,
                 "out_dtype": out_var.dtype,
-                "op_device": op_device
+                "op_device": op_device,
+                "op_role": op.attr("op_role"),
             })
         num_cast_ops += 1
     _rename_arg(op, in_var.name, out_var.name)
@@ -241,7 +242,8 @@ def _insert_cast_post_op(block, op, idx, src_dtype, dest_dtype, target_name,
         attrs={
             "in_dtype": target_var.dtype,
             "out_dtype": cast_var.dtype,
-            "op_device": op.attr("op_device")
+            "op_device": op.attr("op_device"),
+            "op_role": op.attr("op_role"),
         })
     num_cast_ops += 1
     op_var_rename_map[block.idx][target_var.name] = cast_var.name
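Both hunks sit inside a `block._insert_op` call in fp16_utils; a hedged reconstruction of the `_insert_cast_op` call site (attr names taken from the hunk, the surrounding variable setup assumed):

    # out_var is the fp16 twin of in_var, created just above (setup omitted)
    block._insert_op(
        idx,
        type="cast",
        inputs={"X": in_var},
        outputs={"Out": out_var},
        attrs={
            "in_dtype": in_var.dtype,
            "out_dtype": out_var.dtype,
            "op_device": op_device,
            "op_role": op.attr("op_role"),  # new: inherit the consumer's role
        })

Propagating `op_role` presumably keeps the inserted cast in the same forward/backward/optimize section as the op it feeds, which role-aware passes downstream depend on.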
@@ -415,7 +417,9 @@ def cast_model_to_fp16(program, amp_lists=None, use_fp16_guard=True):
             keep_fp32_ops.add(op)
             continue  # processed below
         for in_name in op.input_names:
-            if _keep_fp32_input(op, in_name):
+            # for ipu, all inputs must be converted to fp16
+            if not core.is_compiled_with_ipu() and _keep_fp32_input(
+                    op, in_name):
                 continue
             for in_var_name in op.input(in_name):
                 in_var = None
@@ -443,7 +447,9 @@ def cast_model_to_fp16(program, amp_lists=None, use_fp16_guard=True):
                     format(op.type, in_var_name, in_var.dtype))
         for out_name in op.output_names:
-            if _keep_fp32_output(op, out_name):
+            # for ipu, all outputs must be converted to fp16
+            if not core.is_compiled_with_ipu() and _keep_fp32_output(
+                    op, out_name):
                 continue
             for out_var_name in op.output(out_name):
                 out_var = None
...
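The two hunks apply one pattern: on non-IPU builds the fp32 keep-lists behave as before, while on IPU builds they are bypassed so every input and output is eligible for the fp16 cast. Written as a standalone predicate (hypothetical helper name, not in the patch):

    def _skip_fp32_keep_list(op, name, keep_fn):
        # keep_fn is _keep_fp32_input or _keep_fp32_output from fp16_utils.
        # On IPU builds the keep-list is ignored entirely.
        return (not core.is_compiled_with_ipu()) and keep_fn(op, name)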
@@ -16,7 +16,7 @@ import os
 import random
 import unittest
 import numpy as np
-from enum import Enum
+from enum import IntEnum

 import paddle
 import paddle.static
@@ -33,17 +33,24 @@ map_np_dtype_to_fluid_dtype = {
 }

-class ExecutionMode(Enum):
+class ExecutionModeFull(IntEnum):
+    # Run fp32 model on cpu
     CPU_FP32 = 1
+    # Run fp32 model on ipu
     IPU_FP32 = 2
-    # enable_fp16 through ipu_strategy.enable_fp16
+    # Convert model to fp16 using popart transform
+    # All parameters will be converted to fp16
+    # TODO rename to IPU_FP16
     IPU_POPART_FP16 = 3
+    # Mixed-precision mode, using `paddle.static.amp.fp16_guard()` to
+    # control the precision of each operator
+    IPU_MIXED_PRECISION = 4

-    def __lt__(self, other):
-        return self.value < other.value
-
-    def __gt__(self, other):
-        return self.value > other.value
+
+class ExecutionMode(IntEnum):
+    CPU_FP32 = ExecutionModeFull.CPU_FP32
+    IPU_FP32 = ExecutionModeFull.IPU_FP32
+    IPU_POPART_FP16 = ExecutionModeFull.IPU_POPART_FP16

 def np_dtype_to_fluid_str(dtype: np.dtype) -> str:
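Switching from `Enum` to `IntEnum` is what lets the hand-written `__lt__`/`__gt__` be deleted: `IntEnum` members inherit integer ordering, so the tests can compare modes directly. A quick standalone illustration:

    from enum import IntEnum

    class ExecutionModeFull(IntEnum):
        CPU_FP32 = 1
        IPU_FP32 = 2
        IPU_POPART_FP16 = 3
        IPU_MIXED_PRECISION = 4

    # Ordering works out of the box, as the test drivers below rely on:
    assert ExecutionModeFull.IPU_MIXED_PRECISION > ExecutionModeFull.IPU_FP32
    assert sorted(ExecutionModeFull)[0] is ExecutionModeFull.CPU_FP32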
@@ -61,6 +68,12 @@ class IPUOpTest(unittest.TestCase):
         np.random.seed(cls.SEED)
         random.seed(cls.SEED)
+        # For ipu, most ops support fp16
+        cls.amp_list = paddle.static.amp.CustomOpLists(
+            custom_black_list=[], custom_white_list=[])
+        cls.amp_list.unsupported_list = {}
+        cls.amp_list.black_list = {}

         # Enable paddle static graph mode
         paddle.enable_static()
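`CustomOpLists` builds the AMP black/white/unsupported op lists; emptying `unsupported_list` and `black_list` means the cast pass treats essentially every op as fp16-capable, leaving `fp16_guard()` scopes as the only precision boundary. The list is consumed later in the tests like so (a one-line sketch mirroring the test code below):

    to_fp16_var_names = paddle.static.amp.cast_model_to_fp16(
        main_prog, cls.amp_list)  # rewrites the program, returns fp16 var names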
@@ -114,3 +127,30 @@ class IPUOpTest(unittest.TestCase):
             if check_shape:
                 self.assertTrue(ipu_popart_fp16.shape == cpu_fp32.shape)
+
+        ipu_mixed_precision = None
+        if ExecutionModeFull.IPU_MIXED_PRECISION in outputs.keys():
+            ipu_mixed_precision = outputs[
+                ExecutionModeFull.IPU_MIXED_PRECISION]
+            max_diff = np.abs(
+                ipu_mixed_precision.astype(np.float32) - cpu_fp32).max()
+            fp16_flag = np.allclose(
+                ipu_mixed_precision.astype(np.float32),
+                cpu_fp32,
+                rtol=self.rtol_fp16,
+                atol=self.atol_fp16)
+            self.assertTrue(fp16_flag, "max diff is %f" % (max_diff))
+            if check_shape:
+                self.assertTrue(ipu_mixed_precision.shape == cpu_fp32.shape)
+
+        if (ExecutionMode.IPU_POPART_FP16 in outputs.keys() and
+                ExecutionModeFull.IPU_MIXED_PRECISION in outputs.keys()):
+            max_diff = np.abs(ipu_popart_fp16 - ipu_mixed_precision).max()
+            # compare element-wise; assertEqual(a.all(), b.all()) would only
+            # compare two scalar booleans
+            self.assertTrue(
+                np.allclose(ipu_popart_fp16, ipu_mixed_precision,
+                            rtol=self.rtol_fp16, atol=self.atol_fp16),
+                "max diff is %f" % (max_diff))
+            if check_shape:
+                self.assertTrue(
+                    ipu_popart_fp16.shape == ipu_mixed_precision.shape)
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import numpy as np
import paddle
import paddle.static
import paddle.nn.functional as F
from paddle.fluid.tests.unittests.ipu.op_test_ipu import IPUOpTest, ExecutionModeFull
@unittest.skipIf(not paddle.is_compiled_with_ipu(),
"core is not compiled with IPU")
class TestBase(IPUOpTest):
def setUp(self):
self.set_atol()
self.set_data_feed()
self.set_feed_attr()
@property
def fp16_enabled(self):
return True
def set_atol(self):
self.atol = 1e-6
self.rtol = 1e-6
self.atol_fp16 = 1e-3
self.rtol_fp16 = 1e-3
def set_data_feed(self):
data = np.random.uniform(size=[1, 10, 27, 27])
self.feed_fp32 = {"in_0": data.astype(np.float32)}
self.feed_fp16 = {"in_0": data.astype(np.float16)}
def set_feed_attr(self):
self.feed_shape = [x.shape for x in self.feed_fp32.values()]
self.feed_list = list(self.feed_fp32.keys())
def dtype_check(self, program, to_fp16_var_names):
    block = program.global_block()
    assert len(to_fp16_var_names) > 0
    for var_name in to_fp16_var_names:
        # compare dtypes; asserting the tuple (dtype, float16) is always true
        assert block.var(var_name).dtype == paddle.float16
def _test_base(self, exec_mode):
generator = paddle.fluid.unique_name.UniqueNameGenerator()
scope = paddle.static.Scope()
main_prog = paddle.static.Program()
startup_prog = paddle.static.Program()
main_prog.random_seed = self.SEED
startup_prog.random_seed = self.SEED
with paddle.fluid.unique_name.guard(generator):
with paddle.static.scope_guard(scope):
with paddle.static.program_guard(main_prog, startup_prog):
x = paddle.static.data(
name=self.feed_list[0],
shape=self.feed_shape[0],
dtype='float32')
# using fp32
x = paddle.static.nn.conv2d(
input=x, num_filters=3, filter_size=3)
x = paddle.static.nn.batch_norm(x, act='relu')
x = F.max_pool2d(x, kernel_size=2, stride=2)
# using fp16
with paddle.static.amp.fp16_guard():
x = paddle.static.nn.conv2d(
input=x, num_filters=6, filter_size=3)
x = paddle.static.nn.batch_norm(x, act='relu')
x = F.max_pool2d(x, kernel_size=2, stride=2)
# using fp32
x = paddle.static.nn.fc(x, size=10)
loss = paddle.mean(x)
fetch_list = [loss.name]
if exec_mode == ExecutionModeFull.CPU_FP32:
place = paddle.CPUPlace()
else:
place = paddle.IPUPlace()
# cast model to fp16
if exec_mode == ExecutionModeFull.IPU_MIXED_PRECISION:
to_fp16_var_names = paddle.static.amp.cast_model_to_fp16(
main_prog, self.amp_list)
self.dtype_check(main_prog, to_fp16_var_names)
exe = paddle.static.Executor(place)
exe.run(startup_prog)
# cast parameters to fp16
if exec_mode == ExecutionModeFull.IPU_MIXED_PRECISION:
paddle.static.amp.cast_parameters_to_fp16(
paddle.CPUPlace(),
main_prog,
to_fp16_var_names=to_fp16_var_names)
if exec_mode != ExecutionModeFull.CPU_FP32:
ipu_strategy = paddle.static.IpuStrategy()
ipu_strategy.set_graph_config(is_training=False)
if exec_mode == ExecutionModeFull.IPU_POPART_FP16:
ipu_strategy.set_precision_config(enable_fp16=True)
program = paddle.static.IpuCompiledProgram(
main_prog, ipu_strategy=ipu_strategy).compile(
self.feed_list, fetch_list)
else:
program = main_prog
feed = self.feed_fp32
result = exe.run(program, feed=feed, fetch_list=fetch_list)
return result[0]
def test(self):
output_dict = {}
for mode in ExecutionModeFull:
if mode == ExecutionModeFull.IPU_POPART_FP16:
continue
if mode > ExecutionModeFull.IPU_FP32 and not self.fp16_enabled:
break
output_dict[mode] = self._test_base(mode).flatten()
self.check(output_dict)
if __name__ == "__main__":
unittest.main()
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import numpy as np
import paddle
import paddle.static
import paddle.nn.functional as F
from paddle.fluid.tests.unittests.ipu.op_test_ipu import IPUOpTest, ExecutionModeFull
@unittest.skipIf(not paddle.is_compiled_with_ipu(),
"core is not compiled with IPU")
class TestBase(IPUOpTest):
def setUp(self):
self.set_atol()
self.set_training()
self.set_data_feed()
self.set_feed_attr()
@property
def fp16_enabled(self):
return True
def set_atol(self):
self.atol = 2e-6
self.rtol = 1e-5
self.atol_fp16 = 1e-2
self.rtol_fp16 = 1e-3
def set_training(self):
self.is_training = True
self.epoch = 20
def set_data_feed(self):
data = np.random.uniform(size=[1, 3, 28, 28])
self.feed_fp32 = {"in_0": data.astype(np.float32)}
self.feed_fp16 = {"in_0": data.astype(np.float16)}
def set_feed_attr(self):
self.feed_shape = [x.shape for x in self.feed_fp32.values()]
self.feed_list = list(self.feed_fp32.keys())
def dtype_check(self, program, to_fp16_var_names):
    block = program.global_block()
    assert len(to_fp16_var_names) > 0
    for var_name in to_fp16_var_names:
        # compare dtypes; asserting the tuple (dtype, float16) is always true
        assert block.var(var_name).dtype == paddle.float16
def _test_base(self, exec_mode):
generator = paddle.fluid.unique_name.UniqueNameGenerator()
scope = paddle.static.Scope()
main_prog = paddle.static.Program()
startup_prog = paddle.static.Program()
main_prog.random_seed = self.SEED
startup_prog.random_seed = self.SEED
with paddle.fluid.unique_name.guard(generator):
with paddle.static.scope_guard(scope):
with paddle.static.program_guard(main_prog, startup_prog):
x = paddle.static.data(
name=self.feed_list[0],
shape=self.feed_shape[0],
dtype='float32')
# using fp32
x = paddle.static.nn.conv2d(
input=x, num_filters=3, filter_size=3)
x = paddle.static.nn.batch_norm(x, act='relu')
x = F.max_pool2d(x, kernel_size=2, stride=2)
# using fp16
with paddle.static.amp.fp16_guard():
x = paddle.static.nn.conv2d(
input=x, num_filters=6, filter_size=3)
x = paddle.static.nn.batch_norm(x, act='relu')
x = F.max_pool2d(x, kernel_size=2, stride=2)
# using fp32
x = paddle.static.nn.fc(x, size=10)
loss = paddle.mean(x)
# optimizer
optimizer = paddle.optimizer.Adam(learning_rate=1e-2)
optimizer.minimize(loss, startup_prog)
fetch_list = [loss.name]
# cast model to fp16
if exec_mode == ExecutionModeFull.IPU_MIXED_PRECISION:
to_fp16_var_names = paddle.static.amp.cast_model_to_fp16(
main_prog, self.amp_list)
self.dtype_check(main_prog, to_fp16_var_names)
if exec_mode == ExecutionModeFull.CPU_FP32:
place = paddle.CPUPlace()
else:
place = paddle.IPUPlace()
exe = paddle.static.Executor(place)
exe.run(startup_prog)
# cast parameters to fp16
if exec_mode == ExecutionModeFull.IPU_MIXED_PRECISION:
paddle.static.amp.cast_parameters_to_fp16(
paddle.CPUPlace(),
main_prog,
to_fp16_var_names=to_fp16_var_names)
if exec_mode != ExecutionModeFull.CPU_FP32:
ipu_strategy = paddle.static.IpuStrategy()
ipu_strategy.set_graph_config(is_training=self.is_training)
if exec_mode == ExecutionModeFull.IPU_POPART_FP16:
ipu_strategy.set_precision_config(enable_fp16=True)
program = paddle.static.IpuCompiledProgram(
main_prog, ipu_strategy=ipu_strategy).compile(
self.feed_list, fetch_list)
else:
program = main_prog
feed = self.feed_fp32
result = []
for i in range(self.epoch):
out = exe.run(program, feed=feed, fetch_list=fetch_list)
result.append(out)
return np.array(result)
def test_base(self):
output_dict = {}
for mode in ExecutionModeFull:
if mode == ExecutionModeFull.IPU_POPART_FP16:
continue
if mode > ExecutionModeFull.IPU_FP32 and not self.fp16_enabled:
break
output_dict[mode] = self._test_base(mode).flatten()
self.check(output_dict)
if __name__ == "__main__":
unittest.main()
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import numpy as np
import paddle
import paddle.static
from paddle.fluid.tests.unittests.ipu.op_test_ipu import IPUOpTest
@unittest.skipIf(not paddle.is_compiled_with_ipu(),
"core is not compiled with IPU")
class TestBase(IPUOpTest):
def setUp(self):
self.set_atol()
self.set_training()
self.set_attrs()
self.set_data_feed()
def set_training(self):
self.is_training = False
self.epoch = 10
def set_attrs(self):
self.ipu_options = {
"batches_per_step": 1,
"enable_pipelining": False,
"enable_gradient_accumulation": False,
"accumulation_factor": 1,
"enable_replicated_graphs": False,
"replicated_graph_count": 1,
}
self.cpu_bs = 1
self.ipu_bs = 1
def set_data_feed(self):
np_image = np.random.rand(1, 3, 10, 10).astype(np.float32)
self.feed_cpu = {"image": np_image}
self.feed_ipu = {"image": np_image}
def _test_base(self, run_ipu=True):
scope = paddle.static.Scope()
main_prog = paddle.static.Program()
startup_prog = paddle.static.Program()
main_prog.random_seed = self.SEED
startup_prog.random_seed = self.SEED
bs = self.ipu_bs if run_ipu else self.cpu_bs
with paddle.static.scope_guard(scope):
with paddle.static.program_guard(main_prog, startup_prog):
image = paddle.static.data(
name='image', shape=[bs, 3, 10, 10], dtype='float32')
with paddle.static.ipu_shard_guard(index=0):
conv1 = paddle.static.nn.conv2d(
image, num_filters=3, filter_size=3, bias_attr=False)
with paddle.static.ipu_shard_guard(index=1):
conv2 = paddle.static.nn.conv2d(
conv1, num_filters=3, filter_size=3, bias_attr=False)
# should consider influence of bs
loss = paddle.mean(conv2)
if self.is_training:
if self.optimizer == 'sgd':
opt = paddle.optimizer.SGD(learning_rate=1e-2)
elif self.optimizer == 'adam':
opt = paddle.optimizer.Adam(learning_rate=1e-2)
elif self.optimizer == 'lamb':
opt = paddle.optimizer.Lamb(learning_rate=1e-2)
else:
raise Exception('optimizer must be sgd, adam or lamb')
opt.minimize(loss)
if run_ipu:
place = paddle.IPUPlace()
else:
place = paddle.CPUPlace()
executor = paddle.static.Executor(place)
executor.run(startup_prog)
if run_ipu:
feed_list = [image.name]
fetch_list = [loss.name]
ipu_strategy = paddle.static.IpuStrategy()
ipu_strategy.set_graph_config(
num_ipus=2 * self.ipu_options['replicated_graph_count'],
is_training=self.is_training,
enable_manual_shard=True)
ipu_strategy.set_options(self.ipu_options)
program = paddle.static.IpuCompiledProgram(
main_prog,
ipu_strategy=ipu_strategy).compile(feed_list, fetch_list)
else:
program = main_prog
feed = self.feed_ipu if run_ipu else self.feed_cpu
epoch = self.epoch
if not run_ipu:
epoch *= self.ipu_options['replicated_graph_count']
epoch *= self.ipu_options['batches_per_step']
epoch *= self.ipu_options['accumulation_factor']
epoch = epoch / (self.cpu_bs / self.ipu_bs)
result = []
for i in range(int(epoch)):
loss_res = executor.run(program, feed=feed, fetch_list=[loss])
result.append(loss_res)
return np.array(result).flatten()
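The epoch scaling just above keeps the number of samples seen by CPU and IPU equal: each IPU run consumes `replicated_graph_count * batches_per_step * accumulation_factor` batches of size `ipu_bs`, while each CPU run consumes one batch of size `cpu_bs`. Worked numbers for `TestReplicaTrain` below (all values from that class, nothing new):

    # replicated_graph_count=2, batches_per_step=1, accumulation_factor=1,
    # cpu_bs=2, ipu_bs=1, self.epoch=10
    epoch = 10
    epoch *= 2           # replicated_graph_count
    epoch *= 1           # batches_per_step
    epoch *= 1           # accumulation_factor
    epoch /= (2 / 1)     # cpu_bs / ipu_bs
    assert int(epoch) == 10   # CPU also runs 10 times, at batch size 2

On the IPU side, each run returns one loss per replica, so the test strides the flattened output with `[::2]` to line losses up one-to-one with CPU steps.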
def test(self):
cpu_outputs = self._test_base(False)
ipu_outputs = self._test_base(True)
self.assertTrue(np.allclose(cpu_outputs, ipu_outputs, atol=self.atol))
class TestReplicaInference(TestBase):
def set_attrs(self):
self.ipu_options = {
"batches_per_step": 1,
"enable_pipelining": False,
"enable_gradient_accumulation": False,
"accumulation_factor": 1,
"enable_replicated_graphs": True,
"replicated_graph_count": 2,
}
self.cpu_bs = 1
self.ipu_bs = 1
def set_data_feed(self):
np_image = np.random.rand(1, 3, 10, 10).astype(np.float32)
self.feed_cpu = {"image": np_image}
self.feed_ipu = {
"image":
np.tile(np_image,
[self.ipu_options['replicated_graph_count'], 1, 1, 1])
}
class TestPipelineInference(TestBase):
def set_attrs(self):
self.ipu_options = {
"batches_per_step": 2,
"enable_pipelining": True,
"enable_gradient_accumulation": False,
"accumulation_factor": 1,
"enable_replicated_graphs": False,
"replicated_graph_count": 1,
}
self.cpu_bs = 1
self.ipu_bs = 1
def set_data_feed(self):
np_image = np.random.rand(1, 3, 10, 10).astype(np.float32)
self.feed_cpu = {"image": np_image}
self.feed_ipu = {
"image": np.tile(np_image,
[self.ipu_options['batches_per_step'], 1, 1, 1])
}
class TestTrainBase(TestBase):
def set_training(self):
self.is_training = True
self.epoch = 10
def set_attrs(self):
self.ipu_options = {
"batches_per_step": 1,
"enable_pipelining": False,
"enable_gradient_accumulation": False,
"accumulation_factor": 1,
"enable_replicated_graphs": False,
"replicated_graph_count": 1,
}
self.cpu_bs = 1
self.ipu_bs = 1
self.optimizer = 'sgd'
class TestReplicaTrain(TestTrainBase):
def set_attrs(self):
self.ipu_options = {
"batches_per_step": 1,
"enable_pipelining": False,
"enable_gradient_accumulation": False,
"accumulation_factor": 1,
"enable_replicated_graphs": True,
"replicated_graph_count": 2,
}
self.cpu_bs = 2
self.ipu_bs = 1
self.optimizer = 'sgd'
def set_data_feed(self):
np_image = np.random.rand(1, 3, 10, 10).astype(np.float32)
self.feed_cpu = {"image": np.tile(np_image, [self.cpu_bs, 1, 1, 1])}
self.feed_ipu = {
"image":
np.tile(np_image,
[self.ipu_options['replicated_graph_count'], 1, 1, 1])
}
def test(self):
cpu_outputs = self._test_base(False)
ipu_outputs = self._test_base(True)[::2]
self.assertTrue(np.allclose(cpu_outputs, ipu_outputs, atol=self.atol))
class TestPipelineTrain(TestTrainBase):
def set_attrs(self):
self.ipu_options = {
"batches_per_step": 3,
"enable_pipelining": True,
"enable_gradient_accumulation": True,
"accumulation_factor": 3,
"enable_replicated_graphs": False,
"replicated_graph_count": 1,
}
self.cpu_bs = 3
self.ipu_bs = 1
self.optimizer = 'sgd'
def set_data_feed(self):
np_image = np.random.rand(1, 3, 10, 10).astype(np.float32)
self.feed_cpu = {"image": np.tile(np_image, [self.cpu_bs, 1, 1, 1])}
bps_acc = self.ipu_options['batches_per_step'] * self.ipu_options[
'accumulation_factor']
self.feed_ipu = {"image": np.tile(np_image, [bps_acc, 1, 1, 1])}
def test(self):
cpu_outputs = self._test_base(False)
ipu_outputs = self._test_base(True)[::3]
self.assertTrue(np.allclose(cpu_outputs, ipu_outputs, atol=self.atol))
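Same bookkeeping for the pipelined case: with `batches_per_step=3` and `accumulation_factor=3`, each IPU `run` needs 9 micro-batches, so the feed is tiled 9x, and the test keeps every third fetched loss (`[::3]`) so the remaining entries line up with CPU steps at `cpu_bs=3`. A quick check of the numbers (all taken from `TestPipelineTrain`):

    # batches_per_step=3, accumulation_factor=3, cpu_bs=3, ipu_bs=1, epoch=10
    bps_acc = 3 * 3                        # 9 micro-batches per IPU run
    ipu_feed = "np.tile(np_image, [9, 1, 1, 1])"   # matches set_data_feed
    cpu_runs = int(10 * 3 * 3 / (3 / 1))   # epoch scaling in _test_base -> 30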
class TestAdamTrain(TestTrainBase):
def set_attrs(self):
self.ipu_options = {
"batches_per_step": 1,
"enable_pipelining": False,
"enable_gradient_accumulation": False,
"accumulation_factor": 1,
"enable_replicated_graphs": False,
"replicated_graph_count": 1,
}
self.cpu_bs = 1
self.ipu_bs = 1
self.optimizer = 'adam'
class TestAdamReplicaTrain(TestReplicaTrain):
def set_attrs(self):
self.ipu_options = {
"batches_per_step": 1,
"enable_pipelining": False,
"enable_gradient_accumulation": False,
"accumulation_factor": 1,
"enable_replicated_graphs": True,
"replicated_graph_count": 2,
}
self.cpu_bs = 2
self.ipu_bs = 1
self.optimizer = 'adam'
class TestAdamPipelineTrain(TestPipelineTrain):
def set_attrs(self):
self.ipu_options = {
"batches_per_step": 3,
"enable_pipelining": True,
"enable_gradient_accumulation": True,
"accumulation_factor": 3,
"enable_replicated_graphs": False,
"replicated_graph_count": 1,
}
self.cpu_bs = 3
self.ipu_bs = 1
self.optimizer = 'adam'
class TestAdamRecomputationTrain(TestPipelineTrain):
def set_attrs(self):
self.ipu_options = {
"batches_per_step": 3,
"enable_pipelining": True,
"enable_gradient_accumulation": True,
"accumulation_factor": 3,
"enable_replicated_graphs": False,
"replicated_graph_count": 1,
"auto_recomputation": 3,
}
self.cpu_bs = 3
self.ipu_bs = 1
self.optimizer = 'adam'
class TestLambTrain(TestAdamTrain):
def set_attrs(self):
self.ipu_options = {
"batches_per_step": 1,
"enable_pipelining": False,
"enable_gradient_accumulation": False,
"accumulation_factor": 1,
"enable_replicated_graphs": False,
"replicated_graph_count": 1,
}
self.cpu_bs = 1
self.ipu_bs = 1
self.optimizer = 'lamb'
class TestLambReplicaTrain(TestAdamReplicaTrain):
def set_attrs(self):
self.ipu_options = {
"batches_per_step": 1,
"enable_pipelining": False,
"enable_gradient_accumulation": False,
"accumulation_factor": 1,
"enable_replicated_graphs": True,
"replicated_graph_count": 2,
}
self.cpu_bs = 2
self.ipu_bs = 1
self.optimizer = 'lamb'
class TestLambPipelineTrain(TestAdamPipelineTrain):
def set_attrs(self):
self.ipu_options = {
"batches_per_step": 3,
"enable_pipelining": True,
"enable_gradient_accumulation": True,
"accumulation_factor": 3,
"enable_replicated_graphs": False,
"replicated_graph_count": 1,
}
self.cpu_bs = 3
self.ipu_bs = 1
self.optimizer = 'lamb'
if __name__ == "__main__":
unittest.main()
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import numpy as np
import paddle
import paddle.static
from paddle.fluid.tests.unittests.ipu.op_test_ipu import IPUOpTest
@unittest.skipIf(not paddle.is_compiled_with_ipu(),
"core is not compiled with IPU")
@unittest.skipIf(IPUOpTest.use_ipumodel(), "skip for ipumodel")
class TestBase(IPUOpTest):
def setUp(self):
self.set_atol()
self.set_data_feed()
self.set_feed_attr()
self.set_attrs()
def set_atol(self):
self.atol = 1e-6
def set_data_feed(self):
self.feed = {
"image": np.random.uniform(size=[1, 3, 10, 10]).astype('float32'),
}
def set_feed_attr(self):
self.feed_shape = [x.shape for x in self.feed.values()]
self.feed_list = list(self.feed.keys())
self.feed_dtype = [x.dtype for x in self.feed.values()]
def set_attrs(self):
self.attrs = {
"weight_decay": 4.0,
"loss_scaling": 1.0,
}
def _test_optimizer(self, run_ipu=True):
def exclude_fn(param):
return param.name.endswith('.w_0')
scope = paddle.static.Scope()
main_prog = paddle.static.Program()
startup_prog = paddle.static.Program()
main_prog.random_seed = self.SEED
startup_prog.random_seed = self.SEED
np.random.seed(self.SEED)
with paddle.static.scope_guard(scope):
with paddle.static.program_guard(main_prog, startup_prog):
image = paddle.static.data(
name='image', shape=[1, 3, 10, 10], dtype='float32')
bias = paddle.fluid.layers.create_parameter(
shape=[1, 3, 10, 10], is_bias=True, dtype='float32')
add1 = image + bias
conv1 = paddle.static.nn.conv2d(
add1, num_filters=3, filter_size=3, bias_attr=False)
loss = paddle.mean(conv1)
opt = paddle.optimizer.Lamb(
learning_rate=1e-1,
lamb_weight_decay=self.attrs['weight_decay'],
exclude_from_weight_decay_fn=exclude_fn)
opt.minimize(loss)
if run_ipu:
place = paddle.IPUPlace()
else:
place = paddle.CPUPlace()
exe = paddle.static.Executor(place)
exe.run(startup_prog)
paddle.static.save(main_prog, "weight_decay")
if run_ipu:
feed_list = [image.name]
fetch_list = [loss.name]
ipu_strategy = paddle.static.IpuStrategy()
ipu_strategy.set_graph_config(is_training=True)
ipu_strategy.set_options({
'loss_scaling': self.attrs["loss_scaling"]
})
program = paddle.static.IpuCompiledProgram(
main_prog, ipu_strategy=ipu_strategy).compile(feed_list,
fetch_list)
else:
program = main_prog
result = []
for epoch in range(100):
loss_res = exe.run(program, feed=self.feed, fetch_list=[loss])
result.append(loss_res)
return np.array(result)
def test(self):
# cpu and ipu dimension mismatch, cpu:(100, 1, 1), ipu:(100, 1)
ipu_loss = self._test_optimizer(True).flatten()
cpu_loss = self._test_optimizer(False).flatten()
self.assertTrue(np.allclose(ipu_loss, cpu_loss, atol=self.atol))
if __name__ == "__main__":
unittest.main()
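One subtlety in this last test: `exclude_from_weight_decay_fn` tells Lamb which parameters to *exclude* from the `lamb_weight_decay` term, so the filter above removes decay from every parameter whose name ends in `.w_0`, leaving it applied to the rest (e.g. the bias). A self-contained check of the predicate, with hypothetical parameter names for illustration:

    class FakeParam:
        # stand-in for a Parameter; exclude_fn only consults .name
        def __init__(self, name):
            self.name = name

    def exclude_fn(param):
        return param.name.endswith('.w_0')

    assert exclude_fn(FakeParam("conv2d_0.w_0"))            # no weight decay
    assert not exclude_fn(FakeParam("create_parameter_0.b_0"))  # decays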