# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest

import paddle
import paddle.nn.functional as F


# Conv2D + BatchNorm block used as the convolutional sub-network under test.
class ConvBNLayer(paddle.nn.Layer):
    def __init__(
        self,
        num_channels,
        num_filters,
        filter_size,
        stride=1,
        groups=1,
        act=None,
    ):
        super().__init__()
        self._conv = paddle.nn.Conv2D(
            in_channels=num_channels,
            out_channels=num_filters,
            kernel_size=filter_size,
            stride=stride,
            padding=(filter_size - 1) // 2,
            groups=groups,
            bias_attr=None,
        )
        self._batch_norm = paddle.nn.BatchNorm(num_filters, act=act)

    def forward(self, inputs):
        y = self._conv(inputs)
        y = self._batch_norm(y)
        return y


class Model(paddle.nn.Layer):
    def __init__(
        self, input_channel, hidden_size, fp16_conv=True, fp16_linear=True
    ):
        super().__init__()
        self.conv = ConvBNLayer(input_channel, 8, 3)
        self.linear = paddle.nn.Linear(8, hidden_size)
        self.layernorm = paddle.nn.Sequential(
            paddle.nn.LayerNorm(hidden_size),
            paddle.nn.LayerNorm(hidden_size),
        )
        self.fp16_conv = fp16_conv
        self.fp16_linear = fp16_linear

    def forward(self, inputs):
        # Run the conv and linear blocks under separate auto_cast scopes so
        # each block can be kept in float32 when its fp16_* flag is disabled.
        with paddle.amp.auto_cast(enable=self.fp16_conv):
            if not self.fp16_conv:
                inputs = inputs.astype('float32')
            x = self.conv(inputs)
        with paddle.amp.auto_cast(enable=self.fp16_linear):
            if not self.fp16_linear:
                x = x.astype('float32')
            x = self.linear(x)
        x = F.relu(x)
        x = self.layernorm(x)
        return x


class TestAMPDecorate(unittest.TestCase):
    def check_results(self, fp32_layers=[], fp16_layers=[]):
        # Parameters of excluded layers should stay in float32; the rest
        # should have been cast to float16 by paddle.amp.decorate.
        for idx in range(len(fp32_layers)):
            for layer in fp32_layers[idx].sublayers(include_self=False):
                self.assertEqual(layer.weight.dtype, paddle.float32)
                self.assertEqual(layer.bias.dtype, paddle.float32)
        for idx in range(len(fp16_layers)):
            for layer in fp16_layers[idx].sublayers(include_self=False):
                self.assertEqual(layer.weight.dtype, paddle.float16)
                self.assertEqual(layer.bias.dtype, paddle.float16)

    def test_excluded_layers(self):
        # excluded_layers given as a single layer instance.
        if not paddle.amp.is_float16_supported():
            return
        model = Model(4, 8, fp16_conv=False)
        model = paddle.amp.decorate(
            models=model,
            level='O2',
            dtype='float16',
            excluded_layers=model.conv,
        )
        with paddle.amp.auto_cast(level='O2'):
            out = model(paddle.rand(shape=[2, 4, 8, 8], dtype='float32'))
        self.check_results(
            fp32_layers=[model.conv, model.layernorm],
            fp16_layers=[model.linear],
        )

    def test_excluded_layers_attr_list(self):
        # excluded_layers given as a list of layer instances.
        if not paddle.amp.is_float16_supported():
            return
        model = Model(4, 8, fp16_conv=False, fp16_linear=False)
        model = paddle.amp.decorate(
            models=model,
            level='O2',
            dtype='float16',
            excluded_layers=[model.conv, model.linear],
        )
        with paddle.amp.auto_cast(level='O2'):
            out = model(paddle.rand(shape=[2, 4, 8, 8], dtype='float32'))
        self.check_results(
            fp32_layers=[model.conv, model.linear, model.layernorm]
        )

    def test_excluded_layers_attr_types(self):
        # excluded_layers mixing a layer type and a layer instance.
        if not paddle.amp.is_float16_supported():
            return
        model = Model(4, 8)
        model = paddle.amp.decorate(
            models=model,
            level='O2',
            dtype='float16',
            excluded_layers=[paddle.nn.Conv2D, model.linear],
        )
        with paddle.amp.auto_cast(level='O2'):
            out = model(paddle.rand(shape=[2, 4, 8, 8], dtype='float16'))
        self.check_results(
            fp32_layers=[model.conv, model.linear, model.layernorm]
        )

    def test_excluded_layers_attr_none(self):
        # excluded_layers=None: only the norm layers are expected to remain
        # in float32 under O2 decoration.
        if not paddle.amp.is_float16_supported():
            return
        model = Model(4, 8)
        model = paddle.amp.decorate(
            models=model,
            level='O2',
            dtype='float16',
            excluded_layers=None,
        )
        with paddle.amp.auto_cast(level='O2'):
            out = model(paddle.rand(shape=[2, 4, 8, 8], dtype='float16'))
        self.check_results(
            fp32_layers=[model.layernorm, model.conv._batch_norm],
            fp16_layers=[model.conv._conv, model.linear],
        )


if __name__ == '__main__':
    unittest.main()