diff --git a/.flake8 b/.flake8 index 0315113df588392f7d98e453e31f0cdf61ed5fca..eeee9c2329a8256e35b1bdf987b3f93a4a89992b 100644 --- a/.flake8 +++ b/.flake8 @@ -39,3 +39,4 @@ per-file-ignores = .cmake-format.py: F821 test/dygraph_to_static/test_loop.py: F821 test/dygraph_to_static/test_closure_analysis.py: F821 + python/paddle/static/amp/decorator.py: F811 diff --git a/python/paddle/static/amp/__init__.py b/python/paddle/static/amp/__init__.py index 48856cc1af916ba6fb067aade8a1459fe5cbfed2..843be1443a4fa09508a8966f0190b056aea11b0a 100644 --- a/python/paddle/static/amp/__init__.py +++ b/python/paddle/static/amp/__init__.py @@ -13,7 +13,7 @@ # limitations under the License. from . import decorator -from .decorator import decorate, amp_decorate +from .decorator import decorate from . import fp16_lists from .fp16_lists import CustomOpLists, AutoMixedPrecisionLists from . import fp16_utils diff --git a/python/paddle/static/amp/decorator.py b/python/paddle/static/amp/decorator.py index ae3a98c37b3ce4cc97335223c6ff37ef75988d5b..b1dc122c3845263731cd330ccf0decbeaa7f5b35 100644 --- a/python/paddle/static/amp/decorator.py +++ b/python/paddle/static/amp/decorator.py @@ -32,6 +32,7 @@ from .fp16_utils import ( rewrite_program, update_role_var_grad, ) +from .function_overload import FunctionType, overload class OptimizerWithMixedPrecision: @@ -610,6 +611,7 @@ class OptimizerWithMixedPrecision: return optimize_ops, scaled_params_grads +@overload(key=FunctionType.FP16_ONLY) def decorate( optimizer, amp_lists=None, @@ -739,7 +741,8 @@ def decorate( return mp_optimizer -def amp_decorate( +@overload(key=FunctionType.COMMON) +def decorate( optimizer, amp_lists=None, level='O1', diff --git a/python/paddle/static/amp/fp16_utils.py b/python/paddle/static/amp/fp16_utils.py index 21b5268aa40c9b2c5ac415337e45bce4af9cb11b..4874f99e3eda6a245c207985a1463a8cf90702ce 100644 --- a/python/paddle/static/amp/fp16_utils.py +++ b/python/paddle/static/amp/fp16_utils.py @@ -626,11 +626,14 @@ def 
cast_parameters_to_fp16( for block in program.blocks: all_parameters.extend(block.all_parameters()) + dtype_str = get_low_precision_dtypestr(dest_type) fp16_var_names = to_fp16_var_names if to_fp16_var_names else set() var_scope = scope if scope else global_scope() for param in all_parameters: if param.name in fp16_var_names: - _logger.debug(f"---- cast {param.name} to fp16/bf16 dtype ----") + _logger.debug( + f"-- cast {param.name} to {dtype_str}, place is {place}" + ) if var_scope.find_var(param.name): param_t = var_scope.find_var(param.name).get_tensor() data = np.array(param_t) diff --git a/python/paddle/static/amp/function_overload.py b/python/paddle/static/amp/function_overload.py new file mode 100644 index 0000000000000000000000000000000000000000..8139401c21db1a269ac56597964ae5200abe68c3 --- /dev/null +++ b/python/paddle/static/amp/function_overload.py @@ -0,0 +1,142 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# The implementation refers to https://arpitbhayani.me/blogs/function-overloading. +# Note: it is customized for the paddle.static.amp.decorate function.
import inspect
import logging
from enum import Enum

from paddle.fluid.log_helper import get_logger

_logger = get_logger(
    __name__, logging.INFO, fmt='%(asctime)s-%(levelname)s: %(message)s'
)


class FunctionType(Enum):
    """Keys under which the overloaded implementations are registered."""

    FP16_ONLY = 0
    COMMON = 1


class Function:
    """
    Callable wrapper returned by the ``overload`` decorator.

    Calling an instance does not invoke ``self.fn`` directly. It asks the
    singleton ``Namespace`` to pick one of the registered implementations
    from the actual arguments and invokes that one. The namespace is keyed
    only by ``FunctionType`` (not by function name), so every ``Function``
    instance dispatches over the same set of registered implementations.
    """

    def __init__(self, fn):
        # Kept for reference only; dispatch happens through the Namespace.
        self.fn = fn

    def __call__(self, *args, **kwargs):
        """
        Dispatch the call to the implementation matching the arguments.

        Raises:
            TypeError: if the selected FunctionType has no registered
                implementation.
        """
        fn = Namespace.get_instance().get(*args, **kwargs)
        if fn is None:
            # Same exception type as calling ``None`` would produce, but
            # with a message that points at the real cause.
            raise TypeError(
                "No registered overloaded function matches the given "
                f"arguments: args={args}, kwargs={kwargs}."
            )
        return fn(*args, **kwargs)


class Namespace:
    """
    Singleton that holds every registered implementation, keyed by
    ``FunctionType``. Use ``Namespace.get_instance()`` to obtain it.
    """

    __instance = None

    def __init__(self):
        """
        Create the singleton.

        Raises:
            Exception: if an instance has already been created.
        """
        if Namespace.__instance is not None:
            raise Exception("cannot instantiate Namespace again.")
        self.function_map = {}
        Namespace.__instance = self

    @staticmethod
    def get_instance():
        """Return the singleton, creating it on first use."""
        if Namespace.__instance is None:
            Namespace()
        return Namespace.__instance

    def register(self, fn, key):
        """
        Register the function in the virtual namespace and return
        an instance of callable Function that wraps the function fn.
        Registering a second function under the same key replaces the
        first one.

        Args:
            fn (function): the native python function handle.
            key (FunctionType): the specified type.

        Returns:
            Function: the callable wrapper around ``fn``.
        """
        assert isinstance(
            key, FunctionType
        ), f"The type of key is expected to be FunctionType, but recieved {type(key)}."
        func = Function(fn)
        self.function_map[key] = fn
        return func

    def get(self, *args, **kwargs):
        """
        Get the matching function from the virtual namespace according to
        the actual arguments.

        A registered function is a candidate when it declares at least as
        many parameters as were passed and every keyword argument names one
        of its declared parameters. If exactly one candidate remains it is
        chosen. Otherwise FP16_ONLY is chosen when the third positional
        argument is a float, and COMMON in all other cases.

        Returns:
            The selected function, or None if nothing is registered under
            the selected key.
        """
        _logger.debug("get function: args=%s, kwargs=%s", args, kwargs)
        num_actual_args = len(args) + len(kwargs)
        satisfied_function_keys = set()
        for func_key, fn in self.function_map.items():
            specs = inspect.getfullargspec(fn)
            declared_names = set(specs.args)
            # ``specs`` is a fixed-size named tuple, so the number of
            # declared parameters is len(specs.args), not len(specs).
            too_many_args = len(specs.args) < num_actual_args
            unknown_keyword = any(
                arg_name not in declared_names for arg_name in kwargs
            )
            if too_many_args or unknown_keyword:
                _logger.debug(
                    "fn=%s (key=%s) is not satisfied and removed.",
                    fn,
                    func_key,
                )
                continue
            satisfied_function_keys.add(func_key)

        if len(satisfied_function_keys) == 1:
            key = next(iter(satisfied_function_keys))
        elif len(args) >= 3 and isinstance(args[2], float):
            # Ambiguous (or no) match: a float in third position selects
            # the FP16_ONLY implementation.
            key = FunctionType.FP16_ONLY
        else:
            key = FunctionType.COMMON
        return self.function_map.get(key)


def overload(key):
    """overload is the decorator that wraps the function
    and returns a callable object of type Function.

    Args:
        key (FunctionType): the key the decorated function is registered
            under in the singleton Namespace.
    """

    def decorator(fn):
        return Namespace.get_instance().register(fn, key)

    return decorator


# --- unified-diff text that shared the last source line with this module,
# --- kept verbatim (start of the diff for test/amp/amp_base_models.py):
# diff --git a/test/amp/amp_base_models.py b/test/amp/amp_base_models.py
# index 7f97a923f046f7bb6562db6a1af7e97fb7247cac..ce31bfa98dc10c09025fd6a6ce3e4cca41c8cf24 100644
# --- a/test/amp/amp_base_models.py
# +++ b/test/amp/amp_base_models.py
# @@ -12,16 +12,23 @@
# # See the License for the specific language governing permissions and
# # limitations under the License.
+import unittest + import numpy as np import paddle from paddle import nn +from paddle.fluid import core _fixed_add_param = np.random.random(size=[16, 16]).astype("float32") def _build_optimizer( - use_amp, amp_dtype="float16", amp_level="O1", use_grad_clip=False + use_amp, + amp_dtype="float16", + amp_level="O1", + amp_lists=None, + use_grad_clip=False, ): if use_grad_clip: grad_clip = paddle.nn.ClipGradByGlobalNorm(clip_norm=1.0) @@ -37,13 +44,8 @@ def _build_optimizer( multi_precision=True, ) if use_amp: - amp_lists = paddle.static.amp.AutoMixedPrecisionLists( - custom_white_list=["elementwise_add"], - custom_black_list=["reduce_mean"], - dtype=amp_dtype, - ) - optimizer = paddle.static.amp.amp_decorate( - optimizer, amp_lists=amp_lists, level=amp_level, dtype=amp_dtype + optimizer = paddle.static.amp.decorate( + optimizer, amp_lists, level=amp_level, dtype=amp_dtype ) return optimizer @@ -80,7 +82,18 @@ def build_add_model(use_amp, amp_dtype="float16", amp_level="O1"): x = paddle.static.data(name='input', shape=[16, 16], dtype=x_dtype) out = model(x) loss = paddle.mean(out) - optimizer = _build_optimizer(use_amp, amp_dtype, amp_level) + + if use_amp: + amp_lists = paddle.static.amp.AutoMixedPrecisionLists( + custom_white_list=["elementwise_add"], + custom_black_list=["reduce_mean"], + dtype=amp_dtype, + ) + else: + amp_lists = None + optimizer = _build_optimizer( + use_amp, amp_dtype, amp_level, amp_lists + ) optimizer.minimize(loss) feed_vars = [x] fetch_vars = [loss] @@ -145,7 +158,9 @@ def build_embedding_model(use_amp, amp_dtype="float16", amp_level="O1"): x = paddle.static.data(name='x', shape=[None, 32], dtype='int64') out = model(x) loss = paddle.mean(out) - optimizer = _build_optimizer(use_amp, amp_dtype, amp_level, True) + optimizer = _build_optimizer( + use_amp, amp_dtype, amp_level, None, True + ) optimizer.minimize(loss) return main_program, startup_program @@ -186,3 +201,13 @@ def build_while_model(): out = model(x) loss = paddle.mean(out) return 
main_program, startup_program + + +@unittest.skipIf( + not core.is_compiled_with_cuda(), + "core is not complied with CUDA and not support amp.", +) +class AmpTestBase(unittest.TestCase): + def setUp(self): + self.amp_dtype = None + self.amp_level = None diff --git a/test/amp/test_model_cast_to_bf16.py b/test/amp/test_model_cast_to_bf16.py index c09c15e37d2a4760e6b80f5f6786838afc6f205c..296db41d2f6958ba2e60c09537833075e31e31ca 100644 --- a/test/amp/test_model_cast_to_bf16.py +++ b/test/amp/test_model_cast_to_bf16.py @@ -17,7 +17,7 @@ import struct import unittest import numpy as np -from amp_base_models import build_add_model, build_embedding_model +from amp_base_models import AmpTestBase, build_add_model, build_embedding_model import paddle from paddle import fluid @@ -220,11 +220,7 @@ class TestModelCastBF16(unittest.TestCase): ) -@unittest.skipIf( - not core.is_compiled_with_cuda(), - "core is not complied with CUDA and not support the bfloat16", -) -class TestProgramBF16(unittest.TestCase): +class TestProgramBF16(AmpTestBase): def _check_bf16_calls(self, op_stats_dict, expected_bf16_calls): for op_type, value in expected_bf16_calls.items(): self.assertEqual( @@ -270,11 +266,7 @@ class TestProgramBF16(unittest.TestCase): self._check_bf16_calls(op_stats_list[0], expected_bf16_calls) -@unittest.skipIf( - not core.is_compiled_with_cuda(), - "core is not complied with CUDA and not support the bfloat16", -) -class TestStaticBF16(unittest.TestCase): +class TestStaticBF16(AmpTestBase): def _generate_feed_x(self): x = np.random.random(size=[16, 16]).astype("float32") x_bf16 = convert_float_to_uint16(x) @@ -282,7 +274,7 @@ class TestStaticBF16(unittest.TestCase): return x_fp32, x_bf16 def test_compare_o1_o2(self): - def _run_o1(exe, x_np, max_iters): + def _run_o1(place, exe, x_np, max_iters): ( main_program, startup_program, @@ -305,7 +297,7 @@ class TestStaticBF16(unittest.TestCase): losses.append(results[0]) return losses - def _run_o2(exe, x_np, max_iters): + def 
_run_o2(place, exe, x_np, max_iters): ( main_program, startup_program, @@ -334,8 +326,8 @@ class TestStaticBF16(unittest.TestCase): max_iters = 2 x_fp32, x_bf16 = self._generate_feed_x() - losses_o1 = _run_o1(exe, x_fp32, max_iters) - losses_o2 = _run_o2(exe, x_bf16, max_iters) + losses_o1 = _run_o1(place, exe, x_fp32, max_iters) + losses_o2 = _run_o2(place, exe, x_bf16, max_iters) if __name__ == '__main__':