Unverified commit 79a01d6c, authored by Yiqun Liu, committed by GitHub

[AMP] Support overload of paddle.static.amp.decorate function. (#52918)

* Implement a common AmpTestBase.

* Support overload of decorate.

* Change the ignore list of flake and fix an error.
Parent a70d9db9
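
After this change, a single paddle.static.amp.decorate entry point serves both the legacy FP16-only argument list and the newer level/dtype style; the function_overload module added below dispatches on the actual arguments. A minimal sketch of the two call forms (plain Python, not part of the diff; the SGD optimizer and the float third positional argument, init_loss_scaling in the legacy signature, are illustrative assumptions):

    import paddle

    paddle.enable_static()
    opt = paddle.optimizer.SGD(learning_rate=0.01)

    # Newer common form: the keywords level/dtype only exist in the COMMON
    # overload, so the dispatcher picks that implementation.
    mp_opt = paddle.static.amp.decorate(
        opt, amp_lists=None, level='O1', dtype='float16'
    )

    # Legacy FP16-only form: a float third positional argument (the old
    # init_loss_scaling, assumed here) makes the dispatcher fall back to the
    # FP16_ONLY overload.
    mp_opt = paddle.static.amp.decorate(opt, None, 128.0)
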
@@ -39,3 +39,4 @@ per-file-ignores =
     .cmake-format.py: F821
     test/dygraph_to_static/test_loop.py: F821
     test/dygraph_to_static/test_closure_analysis.py: F821
+    python/paddle/static/amp/decorator.py: F811
@@ -13,7 +13,7 @@
 # limitations under the License.
 from . import decorator
-from .decorator import decorate, amp_decorate
+from .decorator import decorate
 from . import fp16_lists
 from .fp16_lists import CustomOpLists, AutoMixedPrecisionLists
 from . import fp16_utils
......
@@ -32,6 +32,7 @@ from .fp16_utils import (
     rewrite_program,
     update_role_var_grad,
 )
+from .function_overload import FunctionType, overload


 class OptimizerWithMixedPrecision:
@@ -610,6 +611,7 @@ class OptimizerWithMixedPrecision:
         return optimize_ops, scaled_params_grads


+@overload(key=FunctionType.FP16_ONLY)
 def decorate(
     optimizer,
     amp_lists=None,
@@ -739,7 +741,8 @@ def decorate(
     return mp_optimizer


-def amp_decorate(
+@overload(key=FunctionType.COMMON)
+def decorate(
     optimizer,
     amp_lists=None,
     level='O1',
......
@@ -626,11 +626,14 @@ def cast_parameters_to_fp16(
     for block in program.blocks:
         all_parameters.extend(block.all_parameters())

+    dtype_str = get_low_precision_dtypestr(dest_type)
     fp16_var_names = to_fp16_var_names if to_fp16_var_names else set()
     var_scope = scope if scope else global_scope()
     for param in all_parameters:
         if param.name in fp16_var_names:
-            _logger.debug(f"---- cast {param.name} to fp16/bf16 dtype ----")
+            _logger.debug(
+                f"-- cast {param.name} to {dtype_str}, place is {place}"
+            )
             if var_scope.find_var(param.name):
                 param_t = var_scope.find_var(param.name).get_tensor()
                 data = np.array(param_t)
......
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# The implementation refers to https://arpitbhayani.me/blogs/function-overloading.
# Note: it is customized for the paddle.static.amp.decorate function.

import inspect
import logging
from enum import Enum

from paddle.fluid.log_helper import get_logger

_logger = get_logger(
    __name__, logging.INFO, fmt='%(asctime)s-%(levelname)s: %(message)s'
)


class FunctionType(Enum):
    FP16_ONLY = 0
    COMMON = 1


class Function:
    """
    Function is a wrapper over a standard python function.
    An instance of this Function class is also callable,
    just like the python function that it wraps.

    When the instance is "called" like a function, it fetches
    the function to be invoked from the virtual namespace and then
    invokes it.
    """

    def __init__(self, fn):
        self.fn = fn

    def __call__(self, *args, **kwargs):
        """
        Overriding the __call__ method makes the instance callable.
        """
        # Fetch the function to be invoked from the virtual namespace
        # according to the actual arguments.
        fn = Namespace.get_instance().get(*args, **kwargs)

        # Invoke the wrapped function and return the value.
        return fn(*args, **kwargs)


class Namespace:
    """
    Namespace is the singleton class that is responsible
    for holding all the registered functions.
    """

    __instance = None

    def __init__(self):
        if self.__instance is None:
            self.function_map = {}
            Namespace.__instance = self
        else:
            raise Exception("cannot instantiate Namespace again.")

    @staticmethod
    def get_instance():
        if Namespace.__instance is None:
            Namespace()
        return Namespace.__instance

    def register(self, fn, key):
        """
        Register the function in the virtual namespace and return
        an instance of callable Function that wraps the function fn.

        Args:
            fn (function): the native python function handle.
            key (FunctionType): the specified type.
        """
        assert isinstance(
            key, FunctionType
        ), f"The type of key is expected to be FunctionType, but received {type(key)}."
        func = Function(fn)
        self.function_map[key] = fn
        return func

    def get(self, *args, **kwargs):
        """
        Get the matching function from the virtual namespace according to the
        actual arguments. Return None if no matching function is found.
        """
        _logger.debug(f"get function: args={args}, kwargs={kwargs}")
        satisfied_function_keys = set(self.function_map.keys())
        num_actual_args = len(args) + len(kwargs)
        for func_key in self.function_map.keys():
            if func_key not in satisfied_function_keys:
                continue
            fn = self.function_map[func_key]
            specs = inspect.getfullargspec(fn)
            if len(specs.args) < num_actual_args:
                # Remove the function whose number of formal arguments is
                # smaller than the number of actual arguments.
                _logger.debug(
                    f"fn={fn} (key={func_key}) is not satisfied and removed."
                )
                satisfied_function_keys.remove(func_key)
                continue
            if len(kwargs) > 0:
                # Remove the function that does not accept one of the keyword
                # arguments passed in kwargs.
                for arg_name, value in kwargs.items():
                    if arg_name not in specs.args:
                        _logger.debug(
                            f"fn={fn} (key={func_key}) is not satisfied and removed."
                        )
                        satisfied_function_keys.remove(func_key)
                        break
        if len(satisfied_function_keys) == 1:
            key = list(satisfied_function_keys)[0]
        elif len(args) >= 3 and isinstance(args[2], float):
            key = FunctionType.FP16_ONLY
        else:
            key = FunctionType.COMMON
        return self.function_map.get(key)


def overload(key):
    """
    overload is the decorator that wraps the function
    and returns a callable object of type Function.
    """

    def decorator(fn):
        return Namespace.get_instance().register(fn, key)

    return decorator
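
As a quick illustration of the dispatch rules above, here is a toy sketch (not part of the commit). The scale functions and their argument names are invented; because Namespace is a process-wide singleton keyed by FunctionType, registering these toys would shadow the real decorate registrations, so this is for reading only. Redefining the same name is also what triggers the F811 warning that decorator.py is exempted from in the .flake8 change above.

    from paddle.static.amp.function_overload import FunctionType, overload


    @overload(key=FunctionType.FP16_ONLY)
    def scale(x, factor=1.0, loss_scaling=8.0):
        # Legacy-style variant: selected when the third positional argument
        # is a float.
        return x * factor * loss_scaling


    @overload(key=FunctionType.COMMON)
    def scale(x, factor=1.0, level='O1'):
        # Newer-style variant: selected when a keyword such as `level` only
        # matches this signature.
        return x * factor


    print(scale(2.0, 3.0, 4.0))    # args[2] is a float -> FP16_ONLY, prints 24.0
    print(scale(2.0, level='O2'))  # `level` matches only COMMON, prints 2.0
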
@@ -12,16 +12,23 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+import unittest
+
 import numpy as np

 import paddle
 from paddle import nn
+from paddle.fluid import core

 _fixed_add_param = np.random.random(size=[16, 16]).astype("float32")


 def _build_optimizer(
-    use_amp, amp_dtype="float16", amp_level="O1", use_grad_clip=False
+    use_amp,
+    amp_dtype="float16",
+    amp_level="O1",
+    amp_lists=None,
+    use_grad_clip=False,
 ):
     if use_grad_clip:
         grad_clip = paddle.nn.ClipGradByGlobalNorm(clip_norm=1.0)
@@ -37,13 +44,8 @@ def _build_optimizer(
         multi_precision=True,
     )
     if use_amp:
-        amp_lists = paddle.static.amp.AutoMixedPrecisionLists(
-            custom_white_list=["elementwise_add"],
-            custom_black_list=["reduce_mean"],
-            dtype=amp_dtype,
-        )
-        optimizer = paddle.static.amp.amp_decorate(
-            optimizer, amp_lists=amp_lists, level=amp_level, dtype=amp_dtype
+        optimizer = paddle.static.amp.decorate(
+            optimizer, amp_lists, level=amp_level, dtype=amp_dtype
         )
     return optimizer
@@ -80,7 +82,18 @@ def build_add_model(use_amp, amp_dtype="float16", amp_level="O1"):
             x = paddle.static.data(name='input', shape=[16, 16], dtype=x_dtype)
             out = model(x)
             loss = paddle.mean(out)
-            optimizer = _build_optimizer(use_amp, amp_dtype, amp_level)
+            if use_amp:
+                amp_lists = paddle.static.amp.AutoMixedPrecisionLists(
+                    custom_white_list=["elementwise_add"],
+                    custom_black_list=["reduce_mean"],
+                    dtype=amp_dtype,
+                )
+            else:
+                amp_lists = None
+            optimizer = _build_optimizer(
+                use_amp, amp_dtype, amp_level, amp_lists
+            )
             optimizer.minimize(loss)
     feed_vars = [x]
     fetch_vars = [loss]
@@ -145,7 +158,9 @@ def build_embedding_model(use_amp, amp_dtype="float16", amp_level="O1"):
             x = paddle.static.data(name='x', shape=[None, 32], dtype='int64')
             out = model(x)
             loss = paddle.mean(out)
-            optimizer = _build_optimizer(use_amp, amp_dtype, amp_level, True)
+            optimizer = _build_optimizer(
+                use_amp, amp_dtype, amp_level, None, True
+            )
             optimizer.minimize(loss)
     return main_program, startup_program
@@ -186,3 +201,13 @@ def build_while_model():
             out = model(x)
             loss = paddle.mean(out)
     return main_program, startup_program
+
+
+@unittest.skipIf(
+    not core.is_compiled_with_cuda(),
+    "core is not compiled with CUDA and does not support amp.",
+)
+class AmpTestBase(unittest.TestCase):
+    def setUp(self):
+        self.amp_dtype = None
+        self.amp_level = None
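
A hedged sketch of how a concrete test might build on the new AmpTestBase together with the helpers in this file (the class name, the dtype/level choices, and the bare call are illustrative only, not part of this commit):

    class TestAddModelWithNewDecorate(AmpTestBase):
        def setUp(self):
            self.amp_dtype = "float16"
            self.amp_level = "O1"

        def test_build_program(self):
            # build_add_model routes through _build_optimizer, which now calls
            # paddle.static.amp.decorate(optimizer, amp_lists, level=..., dtype=...).
            build_add_model(True, self.amp_dtype, self.amp_level)
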
@@ -17,7 +17,7 @@ import struct
 import unittest

 import numpy as np
-from amp_base_models import build_add_model, build_embedding_model
+from amp_base_models import AmpTestBase, build_add_model, build_embedding_model

 import paddle
 from paddle import fluid
@@ -220,11 +220,7 @@ class TestModelCastBF16(unittest.TestCase):
         )


-@unittest.skipIf(
-    not core.is_compiled_with_cuda(),
-    "core is not complied with CUDA and not support the bfloat16",
-)
-class TestProgramBF16(unittest.TestCase):
+class TestProgramBF16(AmpTestBase):
     def _check_bf16_calls(self, op_stats_dict, expected_bf16_calls):
         for op_type, value in expected_bf16_calls.items():
             self.assertEqual(
@@ -270,11 +266,7 @@ class TestProgramBF16(unittest.TestCase):
         self._check_bf16_calls(op_stats_list[0], expected_bf16_calls)


-@unittest.skipIf(
-    not core.is_compiled_with_cuda(),
-    "core is not complied with CUDA and not support the bfloat16",
-)
-class TestStaticBF16(unittest.TestCase):
+class TestStaticBF16(AmpTestBase):
     def _generate_feed_x(self):
         x = np.random.random(size=[16, 16]).astype("float32")
         x_bf16 = convert_float_to_uint16(x)
@@ -282,7 +274,7 @@ class TestStaticBF16(unittest.TestCase):
         return x_fp32, x_bf16

     def test_compare_o1_o2(self):
-        def _run_o1(exe, x_np, max_iters):
+        def _run_o1(place, exe, x_np, max_iters):
             (
                 main_program,
                 startup_program,
@@ -305,7 +297,7 @@ class TestStaticBF16(unittest.TestCase):
                 losses.append(results[0])
             return losses

-        def _run_o2(exe, x_np, max_iters):
+        def _run_o2(place, exe, x_np, max_iters):
             (
                 main_program,
                 startup_program,
@@ -334,8 +326,8 @@ class TestStaticBF16(unittest.TestCase):
         max_iters = 2
         x_fp32, x_bf16 = self._generate_feed_x()

-        losses_o1 = _run_o1(exe, x_fp32, max_iters)
-        losses_o2 = _run_o2(exe, x_bf16, max_iters)
+        losses_o1 = _run_o1(place, exe, x_fp32, max_iters)
+        losses_o2 = _run_o2(place, exe, x_bf16, max_iters)


 if __name__ == '__main__':
......