Unverified commit 79a01d6c authored by Yiqun Liu, committed by GitHub

[AMP] Support overload of paddle.static.amp.decorate function. (#52918)

* Implement a common AmpTestBase.

* Support overload of decorate.

* Change the ignore list of flake8 and fix an error.
Parent a70d9db9
......@@ -39,3 +39,4 @@ per-file-ignores =
.cmake-format.py: F821
test/dygraph_to_static/test_loop.py: F821
test/dygraph_to_static/test_closure_analysis.py: F821
python/paddle/static/amp/decorator.py: F811
......@@ -13,7 +13,7 @@
# limitations under the License.
from . import decorator
from .decorator import decorate, amp_decorate
from .decorator import decorate
from . import fp16_lists
from .fp16_lists import CustomOpLists, AutoMixedPrecisionLists
from . import fp16_utils
......
......@@ -32,6 +32,7 @@ from .fp16_utils import (
rewrite_program,
update_role_var_grad,
)
from .function_overload import FunctionType, overload
class OptimizerWithMixedPrecision:
......@@ -610,6 +611,7 @@ class OptimizerWithMixedPrecision:
return optimize_ops, scaled_params_grads
@overload(key=FunctionType.FP16_ONLY)
def decorate(
optimizer,
amp_lists=None,
......@@ -739,7 +741,8 @@ def decorate(
return mp_optimizer
def amp_decorate(
@overload(key=FunctionType.COMMON)
def decorate(
optimizer,
amp_lists=None,
level='O1',
......
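For context, the net effect of the two `@overload`-decorated definitions above is that both the legacy FP16-only signature and the newer common signature are reachable through the single name `paddle.static.amp.decorate`. Below is a minimal sketch, not part of the diff: only `optimizer`, `amp_lists`, `level`, and `dtype` are taken from the hunks in this commit; the float third positional argument of the legacy form (assumed to be `init_loss_scaling`) and all concrete values are illustrative assumptions.

```python
# Illustrative sketch, not part of the diff. Both calls go through the same
# overloaded name; dispatch happens in function_overload.Namespace.get.
import paddle

paddle.enable_static()
optimizer = paddle.optimizer.SGD(learning_rate=0.01)  # hypothetical optimizer

# Legacy FP16-only form: a float third positional argument (assumed to be
# init_loss_scaling) routes the call to the FunctionType.FP16_ONLY variant.
mp_opt_legacy = paddle.static.amp.decorate(optimizer, None, 128.0)

# Common form: the 'level' keyword only exists in the COMMON variant, so the
# call is routed to FunctionType.COMMON (this call style matches the
# amp_base_models.py hunk later in this commit).
mp_opt_common = paddle.static.amp.decorate(
    optimizer, amp_lists=None, level='O1', dtype='float16'
)
```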
......@@ -626,11 +626,14 @@ def cast_parameters_to_fp16(
for block in program.blocks:
all_parameters.extend(block.all_parameters())
dtype_str = get_low_precision_dtypestr(dest_type)
fp16_var_names = to_fp16_var_names if to_fp16_var_names else set()
var_scope = scope if scope else global_scope()
for param in all_parameters:
if param.name in fp16_var_names:
_logger.debug(f"---- cast {param.name} to fp16/bf16 dtype ----")
_logger.debug(
f"-- cast {param.name} to {dtype_str}, place is {place}"
)
if var_scope.find_var(param.name):
param_t = var_scope.find_var(param.name).get_tensor()
data = np.array(param_t)
......
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# The implementation refers to https://arpitbhayani.me/blogs/function-overloading.
# Note: it is customized for the paddle.static.amp.decorate function.
import inspect
import logging
from enum import Enum
from paddle.fluid.log_helper import get_logger
_logger = get_logger(
__name__, logging.INFO, fmt='%(asctime)s-%(levelname)s: %(message)s'
)
class FunctionType(Enum):
FP16_ONLY = 0
COMMON = 1
class Function:
"""
Function is a wrapper over a standard python function.
An instance of this Function class is also callable,
just like the python function that it wraps.
When the instance is "called" like a function, it fetches
the function to be invoked from the virtual namespace and then
invokes it.
"""
def __init__(self, fn):
self.fn = fn
def __call__(self, *args, **kwargs):
"""
Overriding the __call__ function which makes the
instance callable.
"""
# fetching the function to be invoked from the virtual namespace
# through the arguments.
fn = Namespace.get_instance().get(*args, **kwargs)
# invoking the wrapped function and returning the value.
return fn(*args, **kwargs)
class Namespace:
"""
Namespace is the singleton class that is responsible
for holding all the functions.
"""
__instance = None
def __init__(self):
if self.__instance is None:
self.function_map = {}
Namespace.__instance = self
else:
raise Exception("cannot instantiate Namespace again.")
@staticmethod
def get_instance():
if Namespace.__instance is None:
Namespace()
return Namespace.__instance
def register(self, fn, key):
"""
Register the function in the virtual namespace and return
an instance of callable Function that wraps the function fn.
Args:
fn (function): the native python function handle.
key (FunctionType): the specified type.
"""
assert isinstance(
key, FunctionType
), f"The type of key is expected to be FunctionType, but recieved {type(key)}."
func = Function(fn)
self.function_map[key] = fn
return func
def get(self, *args, **kwargs):
"""
Get the matching function from the virtual namespace according to the actual arguments.
Return None if no matching function is found.
"""
_logger.debug(f"get function: args={args}, kwargs={kwargs}")
satisfied_function_keys = set(self.function_map.keys())
num_actual_args = len(args) + len(kwargs)
for func_key in self.function_map.keys():
if func_key not in satisfied_function_keys:
continue
fn = self.function_map[func_key]
specs = inspect.getfullargspec(fn)
if len(specs.args) < num_actual_args:
# Remove candidates whose signature cannot accept the number of actual arguments.
_logger.debug(
f"fn={fn} (key={func_key}) is not satisfied and removed."
)
satisfied_function_keys.remove(func_key)
continue
if len(kwargs) > 0:
# Remove candidates whose signature does not accept one of the keyword arguments.
for arg_name, value in kwargs.items():
if arg_name not in specs.args:
_logger.debug(
f"fn={fn} (key={func_key}) is not satisfied and removed."
)
satisfied_function_keys.remove(func_key)
break
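# If exactly one candidate survived the checks above, use it; otherwise
# assume the legacy FP16-only signature when the third positional argument
# is a float (presumably init_loss_scaling), and fall back to COMMON.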
if len(satisfied_function_keys) == 1:
key = list(satisfied_function_keys)[0]
elif len(args) >= 3 and isinstance(args[2], float):
key = FunctionType.FP16_ONLY
else:
key = FunctionType.COMMON
return self.function_map.get(key)
def overload(key):
"""overload is the decorator that wraps the function
and returns a callable object of type Function.
"""
def decorator(fn):
return Namespace.get_instance().register(fn, key)
return decorator
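To illustrate the dispatch rules implemented above, here is a minimal, hypothetical usage sketch that is not part of the commit: the `scale` functions are invented for demonstration, and because the singleton `Namespace` keys only on `FunctionType`, running this in a process that already imported `paddle.static.amp.decorator` would overwrite its registrations, so treat it as a standalone experiment.

```python
# Hypothetical demonstration of the overload mechanism; run standalone.
from paddle.static.amp.function_overload import FunctionType, overload


@overload(key=FunctionType.FP16_ONLY)
def scale(x, factor, init_loss_scaling=1.0):
    return x * factor * init_loss_scaling


@overload(key=FunctionType.COMMON)
def scale(x, factor, level='O1'):  # noqa: F811 (intentional redefinition)
    return x * factor


# Three positional args with a float in third position -> FP16_ONLY variant.
print(scale(2, 3, 4.0))  # 24.0

# The keyword 'level' only exists in the COMMON variant, so the FP16_ONLY
# candidate is discarded and the COMMON variant is called.
print(scale(2, 3, level='O2'))  # 6
```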
......@@ -12,16 +12,23 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import numpy as np
import paddle
from paddle import nn
from paddle.fluid import core
_fixed_add_param = np.random.random(size=[16, 16]).astype("float32")
def _build_optimizer(
use_amp, amp_dtype="float16", amp_level="O1", use_grad_clip=False
use_amp,
amp_dtype="float16",
amp_level="O1",
amp_lists=None,
use_grad_clip=False,
):
if use_grad_clip:
grad_clip = paddle.nn.ClipGradByGlobalNorm(clip_norm=1.0)
......@@ -37,13 +44,8 @@ def _build_optimizer(
multi_precision=True,
)
if use_amp:
amp_lists = paddle.static.amp.AutoMixedPrecisionLists(
custom_white_list=["elementwise_add"],
custom_black_list=["reduce_mean"],
dtype=amp_dtype,
)
optimizer = paddle.static.amp.amp_decorate(
optimizer, amp_lists=amp_lists, level=amp_level, dtype=amp_dtype
optimizer = paddle.static.amp.decorate(
optimizer, amp_lists, level=amp_level, dtype=amp_dtype
)
return optimizer
......@@ -80,7 +82,18 @@ def build_add_model(use_amp, amp_dtype="float16", amp_level="O1"):
x = paddle.static.data(name='input', shape=[16, 16], dtype=x_dtype)
out = model(x)
loss = paddle.mean(out)
optimizer = _build_optimizer(use_amp, amp_dtype, amp_level)
if use_amp:
amp_lists = paddle.static.amp.AutoMixedPrecisionLists(
custom_white_list=["elementwise_add"],
custom_black_list=["reduce_mean"],
dtype=amp_dtype,
)
else:
amp_lists = None
optimizer = _build_optimizer(
use_amp, amp_dtype, amp_level, amp_lists
)
optimizer.minimize(loss)
feed_vars = [x]
fetch_vars = [loss]
......@@ -145,7 +158,9 @@ def build_embedding_model(use_amp, amp_dtype="float16", amp_level="O1"):
x = paddle.static.data(name='x', shape=[None, 32], dtype='int64')
out = model(x)
loss = paddle.mean(out)
optimizer = _build_optimizer(use_amp, amp_dtype, amp_level, True)
optimizer = _build_optimizer(
use_amp, amp_dtype, amp_level, None, True
)
optimizer.minimize(loss)
return main_program, startup_program
......@@ -186,3 +201,13 @@ def build_while_model():
out = model(x)
loss = paddle.mean(out)
return main_program, startup_program
@unittest.skipIf(
not core.is_compiled_with_cuda(),
"core is not complied with CUDA and not support amp.",
)
class AmpTestBase(unittest.TestCase):
def setUp(self):
self.amp_dtype = None
self.amp_level = None
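A short, hypothetical sketch of how a concrete test is expected to build on `AmpTestBase`; the real usages appear in the `test_model_cast_to_bf16.py` diff that follows, and the attribute values here are illustrative only.

```python
# Hypothetical subclass for illustration; the skipIf decorator inherited from
# AmpTestBase means its tests only run on CUDA builds of Paddle.
class TestMyBF16Case(AmpTestBase):
    def setUp(self):
        self.amp_dtype = "bfloat16"
        self.amp_level = "O2"

    def test_dtype_and_level(self):
        self.assertEqual(self.amp_dtype, "bfloat16")
        self.assertEqual(self.amp_level, "O2")
```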
......@@ -17,7 +17,7 @@ import struct
import unittest
import numpy as np
from amp_base_models import build_add_model, build_embedding_model
from amp_base_models import AmpTestBase, build_add_model, build_embedding_model
import paddle
from paddle import fluid
......@@ -220,11 +220,7 @@ class TestModelCastBF16(unittest.TestCase):
)
@unittest.skipIf(
not core.is_compiled_with_cuda(),
"core is not complied with CUDA and not support the bfloat16",
)
class TestProgramBF16(unittest.TestCase):
class TestProgramBF16(AmpTestBase):
def _check_bf16_calls(self, op_stats_dict, expected_bf16_calls):
for op_type, value in expected_bf16_calls.items():
self.assertEqual(
......@@ -270,11 +266,7 @@ class TestProgramBF16(unittest.TestCase):
self._check_bf16_calls(op_stats_list[0], expected_bf16_calls)
@unittest.skipIf(
not core.is_compiled_with_cuda(),
"core is not complied with CUDA and not support the bfloat16",
)
class TestStaticBF16(unittest.TestCase):
class TestStaticBF16(AmpTestBase):
def _generate_feed_x(self):
x = np.random.random(size=[16, 16]).astype("float32")
x_bf16 = convert_float_to_uint16(x)
......@@ -282,7 +274,7 @@ class TestStaticBF16(unittest.TestCase):
return x_fp32, x_bf16
def test_compare_o1_o2(self):
def _run_o1(exe, x_np, max_iters):
def _run_o1(place, exe, x_np, max_iters):
(
main_program,
startup_program,
......@@ -305,7 +297,7 @@ class TestStaticBF16(unittest.TestCase):
losses.append(results[0])
return losses
def _run_o2(exe, x_np, max_iters):
def _run_o2(place, exe, x_np, max_iters):
(
main_program,
startup_program,
......@@ -334,8 +326,8 @@ class TestStaticBF16(unittest.TestCase):
max_iters = 2
x_fp32, x_bf16 = self._generate_feed_x()
losses_o1 = _run_o1(exe, x_fp32, max_iters)
losses_o2 = _run_o2(exe, x_bf16, max_iters)
losses_o1 = _run_o1(place, exe, x_fp32, max_iters)
losses_o2 = _run_o2(place, exe, x_bf16, max_iters)
if __name__ == '__main__':
......