Unverified · Commit 972581d8, authored by shaojie_wang, committed by GitHub

[AMP]Master grad in static graph (#53362)

* add master gradients on static graph

* add unit test for bf16 master grad static graph

* use float16 as v100 test dtype

* only skip GPUs that do not support bf16

* use linear layer to test master grad

* 1. push master grad creation before all optimizer ops; 2. remove useless unittest; 3. use a function to create master grad states
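For orientation, a minimal usage sketch of the feature this commit adds: enabling fp32 master gradients through the `master_grad` argument of `paddle.static.amp.decorate` (the toy network and optimizer below are illustrative placeholders, not code from this PR):

```python
import paddle

paddle.enable_static()

main_prog, startup_prog = paddle.static.Program(), paddle.static.Program()
with paddle.static.program_guard(main_prog, startup_prog):
    x = paddle.static.data(name='x', shape=[None, 16], dtype='float32')
    loss = paddle.mean(paddle.nn.Linear(16, 10)(x))

    optimizer = paddle.optimizer.AdamW(learning_rate=1e-3)
    # level='O2' with master_grad=True: each fp16/bf16 gradient is cast to an
    # fp32 master copy before gradient clipping and the optimizer update.
    optimizer = paddle.static.amp.decorate(
        optimizer, level='O2', dtype='float16', master_grad=True
    )
    optimizer.minimize(loss)

# For O2 training one would still run startup_prog and call
# optimizer.amp_init(place) before the first step, as in the existing AMP flow.
```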
Parent 4f1bf199
@@ -283,6 +283,8 @@ class AdamW(Optimizer):
         self._auxiliary_vars = {}
         self._already_create_accumulater = set()
+        self._create_master_grad_states()
 
     def _set_auxiliary_var(self, key, val):
         self._auxiliary_vars[key] = val
......
@@ -275,6 +275,14 @@ class Optimizer:
         self._auxiliary_vars = {}
         self._already_create_accumulater = set()
+        # create master gradients' states
+        self._create_master_grad_states()
+
+    def _create_master_grad_states(self):
+        # master gradients states
+        self._master_grads = {}
+        self._master_grad = False
 
     def _set_auxiliary_var(self, key, val):
         self._auxiliary_vars[key] = val
@@ -669,6 +677,25 @@ class Optimizer:
             self._master_weights[param.name] = var
         return var
 
+    def _create_master_grad(self, grad):
+        assert self._is_dtype_fp16_or_bf16(grad.dtype)
+        if grad.name in self._master_grads:
+            var = self._master_grads[grad.name]
+        else:
+            var_name = grad.name + "_fp32_master"
+            var_name = unique_name.generate(var_name)
+            var = grad.block.create_var(
+                name=var_name,
+                shape=grad.shape,
+                value=0,
+                dtype='float32',
+                lod_level=grad.lod_level,
+                persistable=grad.persistable,
+                is_data=grad.is_data,
+            )
+            self._master_grads[grad.name] = var
+        return var
+
     def _create_accumulators(self, block, parameters):
         """Create all accumulators needed by the parameters
@@ -1168,7 +1195,6 @@ class Optimizer:
         if self._grad_clip is not None:
             params_grads = self._grad_clip(params_grads)
         else:
             params_grads = paddle.nn.clip.append_gradient_clip_ops(params_grads)
 
         # Add regularization if any
......
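The point of the fp32 master-gradient state added above: once gradients live in fp16, reductions performed on them downstream, such as the squared L2 norm used by gradient clipping, can overflow or lose precision. A small self-contained numpy illustration (not part of the change itself):

```python
import numpy as np

# A gradient tensor with moderately large entries.
grad_fp32 = np.full([1024], 300.0, dtype=np.float32)
grad_fp16 = grad_fp32.astype(np.float16)

# Squared L2 norm, as a global-norm gradient-clipping step would compute it.
sq_norm_fp32 = np.sum(grad_fp32 ** 2)  # 300**2 * 1024 = 92160000.0, finite
sq_norm_fp16 = np.sum(grad_fp16 ** 2)  # 300**2 = 90000 > fp16 max (65504) -> inf

print(sq_norm_fp32, sq_norm_fp16)  # 92160000.0 inf
```

Casting each low-precision gradient into the fp32 master copy created by `_create_master_grad` keeps such reductions finite and accurate.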
@@ -80,6 +80,7 @@ class OptimizerWithMixedPrecision:
                                      the loss scaling.
        use_amp_guard(bool): Whether to use `fp16_guard` when constructing the program.
                             Default None, which means that its value is equal to `use_pure_fp16`.
+       use_master_grad(bool): Whether to use fp32 master gradients during the optimizer update. Default is False.
        use_promote(bool): Whether to promote to fp32 when an op has any float32 inputs. Default is False.
    """
@@ -96,6 +97,7 @@ class OptimizerWithMixedPrecision:
         incr_ratio,
         decr_ratio,
         use_amp_guard=None,
+        use_master_grad=False,
         use_promote=False,
     ):
         self._optimizer = optimizer
@@ -104,6 +106,7 @@ class OptimizerWithMixedPrecision:
         self._train_program = None
         self._is_distributed = False
+        self._use_master_grad = False
         self._scaled_loss = None
         self._loss_scaling = None
         self._init_loss_scaling = init_loss_scaling
@@ -122,6 +125,9 @@ class OptimizerWithMixedPrecision:
             self._learning_rate = optimizer._learning_rate
             self._learning_rate_map = optimizer._learning_rate_map
         self._use_pure_fp16 = level == "O2"
+        if self._use_pure_fp16 and (dtype == "bfloat16" or dtype == "float16"):
+            self._use_master_grad = use_master_grad
+            self._optimizer._master_grad = use_master_grad
         self._amp_level = level
         self._use_fp16_guard = use_amp_guard
         self._to_fp16_var_names = None
@@ -384,6 +390,51 @@ class OptimizerWithMixedPrecision:
                 use_promote=self.use_promote,
             )
 
+    def _append_cast_to_master_grad_op(self, param_grads):
+        """
+        Create master gradient vars and add cast-gradient-to-master-gradient ops in the main program.
+
+        Args:
+            param_grads(list(tuple(Tensor, Tensor))): A list of (parameter, gradient) pairs to update.
+
+        Returns:
+            list: A list of (parameter, master_gradient) pairs. In the following gradient clipping
+                and optimizer steps, parameters are updated by the master gradients. main_prog will
+                also have the cast ops appended before the gradient clipping ops.
+        """
+        if not self._use_master_grad:
+            return param_grads
+
+        global_block = self._train_program.global_block()
+        target_block = global_block
+        current_block = self._train_program.current_block()
+        if current_block.idx != global_block.idx:
+            target_block = self._train_program.blocks[
+                current_block.backward_block_idx
+            ]
+        params_master_grads = []
+
+        assert isinstance(target_block, paddle.fluid.framework.Block)
+        # create
+        for p, g in param_grads:
+            if g.name not in self._optimizer._master_grads.keys():
+                if self._optimizer._is_dtype_fp16_or_bf16(g.dtype):
+                    master_g = self._optimizer._create_master_grad(g)
+                    params_master_grads.append((p, master_g))
+                    target_block.append_op(
+                        type="cast",
+                        inputs={"X": [g]},
+                        outputs={"Out": [master_g]},
+                        attrs={
+                            "in_dtype": g.dtype,
+                            "out_dtype": master_g.dtype,
+                        },
+                    )
+                else:
+                    params_master_grads.append((p, g))
+
+        return params_master_grads
+
     def apply_gradients(self, params_grads):
         """
         Check scaled gradients to determine whether to update loss scaling and update
@@ -400,6 +451,9 @@ class OptimizerWithMixedPrecision:
         # transferred across GPUs can be FP16.
         update_role_var_grad(self._train_program, params_grads)
 
+        # Create master grad and add cast op into program
+        params_grads = self._append_cast_to_master_grad_op(params_grads)
+
         # When not using dynamic loss scaling and the init loss scaling value is equal to 1.0,
         # the model can be optimized.
         if (
@@ -756,6 +810,7 @@ def decorate(
     level='O1',
     dtype='float16',
     master_weight=None,
+    master_grad=False,
     init_loss_scaling=2**15,
     incr_every_n_steps=1000,
     decr_every_n_nan_or_inf=2,
@@ -782,6 +837,9 @@ def decorate(
        master_weight(bool, optional): For level='O2', whether to use multi-precision
            during weight updating. If master_weight is None, the O2-level optimizer
            will use multi-precision. Default is None.
+       master_grad(bool, optional): For level='O2', whether to use master gradients
+           during weight updating. If master_grad is False, the O2-level optimizer
+           will not use master gradients. Default is False.
        init_loss_scaling(float, optional): The initial loss scaling factor.
            Default is 32768.
        incr_every_n_steps(int, optional): Increases loss scaling every n
@@ -883,6 +941,7 @@ def decorate(
         decr_ratio=decr_ratio,
         use_amp_guard=use_amp_guard,
         use_promote=use_promote,
+        use_master_grad=master_grad,
     )
 
     return mp_optimizer
@@ -478,9 +478,9 @@ def op_need_keep_fp32(op, amp_lists, use_fp16_guard, params_list):
                 need_keep_fp32 = True
         for in_name in op.input_names:
             for params in params_list:
-                if op.input(in_name)[0] == params.name:
+                if params.name in op.input(in_name):
                     fp16_varname_list_in_fp32_op = (
-                        fp16_varname_list_in_fp32_op.union(op.input(in_name))
+                        fp16_varname_list_in_fp32_op.union([params.name])
                     )
 
     return need_keep_fp32, fp16_varname_list_in_fp32_op
......
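The `op_need_keep_fp32` change above replaces a first-element comparison with a membership test, so a parameter kept in fp32 is detected even when it is not the first input of the op, and only the parameter's own name is added to the result set. A tiny illustration with made-up variable names:

```python
# Hypothetical input slot of an fp32 op: the parameter is the second entry.
op_inputs = ["embedding_out_cast", "simple_net.w_0"]
param_name = "simple_net.w_0"

# Before: only the first entry was compared, so the parameter was missed.
matched_before = op_inputs[0] == param_name   # False

# After: a membership test over the whole slot finds it, and only the
# parameter name (not every input of the slot) is recorded.
matched_after = param_name in op_inputs       # True
fp16_varname_list_in_fp32_op = set().union([param_name])

print(matched_before, matched_after, fp16_varname_list_in_fp32_op)
```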
@@ -13,6 +13,7 @@
 # limitations under the License.
 
 import copy
+import struct
 import unittest
 
 import numpy as np
@@ -22,6 +23,34 @@ from paddle import nn
 from paddle.fluid import core
 from paddle.fluid.framework import _non_static_mode
+
+def copy_bits_from_float_to_uint16(f):
+    return struct.unpack('<I', struct.pack('<f', f))[0] >> 16
+
+
+def convert_float_to_uint16(in_list):
+    if in_list.dtype == np.float32:
+        new_output = []
+        for x in np.nditer(in_list):
+            new_output.append(np.uint16(copy_bits_from_float_to_uint16(x)))
+        new_output = np.reshape(new_output, in_list.shape).view(np.uint16)
+        return new_output
+    else:
+        return in_list
+
+
+def convert_uint16_to_float(in_list):
+    if in_list.dtype == np.uint16:
+        in_list = np.asarray(in_list)
+        out = np.vectorize(
+            lambda x: struct.unpack('<f', struct.pack('<I', x << 16))[0],
+            otypes=[np.float32],
+        )(in_list.flat)
+        return np.reshape(out, in_list.shape)
+    else:
+        return in_list
+
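The two helpers above move float32 data to and from the uint16 container that Paddle uses for bfloat16 tensors. A standalone scalar sketch of the same bit manipulation (function names here are illustrative, not part of the test utilities):

```python
import struct

import numpy as np


def float_to_bf16_bits(f):
    # Keep the high 16 bits of the float32 representation (truncation, no rounding).
    return struct.unpack('<I', struct.pack('<f', f))[0] >> 16


def bf16_bits_to_float(bits):
    # Re-expand to float32 by padding the low 16 mantissa bits with zeros.
    return struct.unpack('<f', struct.pack('<I', bits << 16))[0]


x = np.float32(1.5)   # 0x3FC00000: the low 16 bits are zero, so the round trip is exact
assert bf16_bits_to_float(float_to_bf16_bits(x)) == 1.5

y = np.float32(0.1)   # low mantissa bits are lost, so the round trip is only approximate
print(bf16_bits_to_float(float_to_bf16_bits(y)))  # 0.099609375
```

This truncation-based round trip is what `cast_add_param` below relies on to pre-quantize the shared add parameter for bfloat16 runs.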
 _fixed_add_param = np.random.random(size=[16, 16]).astype("float32")
@@ -32,6 +61,7 @@ def _build_optimizer(
     amp_lists=None,
     use_grad_clip=False,
     use_promote=False,
+    use_master_grad=False,
     model=None,
 ):
     if use_grad_clip:
@@ -58,6 +88,7 @@ def _build_optimizer(
             amp_lists,
             level=amp_level,
             dtype=amp_dtype,
+            master_grad=use_master_grad,
             use_promote=use_promote,
         )
     return optimizer
@@ -80,6 +111,15 @@ class SimpleAddNet(nn.Layer):
         return x + self.weight
 
 
+def cast_add_param(amp_dtype):
+    global _fixed_add_param
+    if amp_dtype == "bfloat16":
+        _fixed_add_param_bf16 = convert_float_to_uint16(_fixed_add_param)
+        _fixed_add_param = convert_uint16_to_float(_fixed_add_param_bf16)
+    else:
+        pass
+
+
 def build_add_model(
     use_amp, amp_dtype="float16", amp_level="O1", use_promote=False
 ):
@@ -93,6 +133,7 @@ def build_add_model(
                     x_dtype = "uint16"
                 elif amp_dtype == "float16":
                     x_dtype = "float16"
+            cast_add_param(amp_dtype)
             model = SimpleAddNet(x_dtype)
             x = paddle.static.data(name='input', shape=[16, 16], dtype=x_dtype)
             out = model(x)
@@ -177,8 +218,6 @@ class SimpleEmbeddingNet(nn.Layer):
         super().__init__()
         self.vocab_size = 128
         self.hidden_size = 16
-        self.vocab_size = 128
-        self.hidden_size = 16
         self.embedding = nn.Embedding(self.vocab_size, self.hidden_size)
         self.linear = nn.Linear(in_features=16, out_features=10)
@@ -192,7 +231,11 @@ class SimpleEmbeddingNet(nn.Layer):
 
 
 def build_embedding_model(
-    use_amp, amp_dtype="float16", amp_level="O1", use_promote=False
+    use_amp,
+    amp_dtype="float16",
+    amp_level="O1",
+    use_promote=False,
+    use_master_grad=False,
 ):
     main_program = paddle.static.Program()
     startup_program = paddle.static.Program()
@@ -202,16 +245,90 @@ def build_embedding_model(
             x = paddle.static.data(name='x', shape=[None, 32], dtype='int64')
             out = model(x)
             loss = paddle.mean(out)
+            if use_amp:
+                amp_lists = paddle.static.amp.AutoMixedPrecisionLists(
+                    custom_white_list=["elementwise_mul"],
+                    custom_black_list=["reduce_mean"],
+                    dtype=amp_dtype,
+                )
+            else:
+                amp_lists = None
             optimizer = _build_optimizer(
                 use_amp,
                 amp_dtype,
                 amp_level,
-                None,
+                amp_lists,
                 True,
                 use_promote=use_promote,
+                use_master_grad=use_master_grad,
             )
             optimizer.minimize(loss)
-    return main_program, startup_program
+
+    feed_vars = [x]
+    fetch_vars = [loss]
+    return main_program, startup_program, optimizer, feed_vars, fetch_vars
+
+
+class SimpleMLPNet(nn.Layer):
+    def __init__(self):
+        super().__init__()
+        self.linear0 = paddle.nn.Linear(16, 10)
+        self.linear1 = paddle.nn.Linear(10, 32)
+
+    def forward(self, x):
+        out = self.linear0(x)
+        out = nn.functional.relu(out)
+        out = self.linear1(out)
+        out = nn.functional.relu(out)
+        out = nn.functional.dropout(out, p=0.2)
+        return out
+
+
+def build_MLP_model(
+    use_amp,
+    use_grad_clip=False,
+    amp_dtype="float16",
+    amp_level="O1",
+    use_promote=False,
+    use_master_grad=False,
+):
+    main_program = paddle.static.Program()
+    startup_program = paddle.static.Program()
+    with paddle.utils.unique_name.guard():
+        with paddle.static.program_guard(main_program, startup_program):
+            model = SimpleMLPNet()
+            x_dtype = "float32"
+            if use_amp and amp_level == "O2":
+                if amp_dtype == "bfloat16":
+                    x_dtype = "uint16"
+                elif amp_dtype == "float16":
+                    x_dtype = "float16"
+            x = paddle.static.data(name='x', shape=[None, 16], dtype=x_dtype)
+            out = model(x)
+            loss = paddle.mean(out)
+            if use_amp:
+                amp_lists = paddle.static.amp.AutoMixedPrecisionLists(
+                    custom_black_list=["reduce_mean"],
+                    dtype=amp_dtype,
+                )
+            else:
+                amp_lists = None
+            optimizer = _build_optimizer(
+                use_amp,
+                amp_dtype,
+                amp_level,
+                amp_lists,
+                use_grad_clip=use_grad_clip,
+                use_promote=use_promote,
+                use_master_grad=use_master_grad,
+            )
+            optimizer.minimize(loss)
+
+    feed_vars = [x]
+    fetch_vars = [loss]
+    return main_program, startup_program, optimizer, feed_vars, fetch_vars
+
+
 class SimpleWhileNet(nn.Layer):
......
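A rough sketch of how the new `build_MLP_model` helper is meant to be driven end to end (the executor and `amp_init` details are assumptions based on the usual static-graph AMP O2 flow and the tests below, not part of this file):

```python
import numpy as np

import paddle
from amp_base_models import build_MLP_model

paddle.enable_static()

main_prog, startup_prog, optimizer, feed_vars, fetch_vars = build_MLP_model(
    use_amp=True,
    amp_dtype="float16",
    amp_level="O2",
    use_master_grad=True,
)

place = paddle.CUDAPlace(0)
exe = paddle.static.Executor(place)
exe.run(startup_prog)
# For pure-fp16 (O2) training the decorated optimizer casts the fp32
# parameters to fp16 once before the first step.
optimizer.amp_init(place)

x = np.random.random([64, 16]).astype("float16")
(loss,) = exe.run(main_prog, feed={feed_vars[0].name: x}, fetch_list=fetch_vars)
print(loss)
```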
@@ -81,6 +81,7 @@ class TestStaticDecorate(AmpTestBase):
                 exe,
                 x_fp32,
                 max_iters,
+                dtype,
                 level,
             )
......
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import random
import unittest

import numpy as np
from amp_base_models import (
    AmpTestBase,
    build_embedding_model,
    build_MLP_model,
    convert_float_to_uint16,
    convert_uint16_to_float,
)

import paddle
from paddle.static import amp

paddle.enable_static()
class TestStaticMasterGradProgramFP16(AmpTestBase):
    def _check_optimizer(self, program, expected_num_mp):
        optimizers = []
        for block in program.blocks:
            for op in block.ops:
                if "Param" in op.input_names and "Grad" in op.input_names:
                    optimizers.append(op)

        actual_num_mp = 0
        for op in optimizers:
            if op.has_attr("multi_precision") and op.attr("multi_precision"):
                actual_num_mp += 1
        self.assertEqual(
            actual_num_mp,
            expected_num_mp,
            f"The number of optimizers with multi_precision = True is expected to be {expected_num_mp}, but received {actual_num_mp}.",
        )

    def amp_fp16_o2(self, use_master_grad):
        main_program, _, _, _, _ = build_embedding_model(
            True, "float16", "O2", use_master_grad=use_master_grad
        )
        self.assertEqual(main_program.num_blocks, 1)

        amp.debugging.collect_operator_stats(main_program)
        op_stats_list = amp.debugging._get_op_stats_list(main_program)
        expected_fp32_calls = {"lookup_table_v2": 1}
        if use_master_grad:
            expected_fp16_calls = {
                "matmul_v2": 1,
                "elementwise_add": 1,
                "dropout": 1,
                "lookup_table_v2": 0,
                "squared_l2_norm": 0,
                "adamw": 3,
            }
        else:
            expected_fp16_calls = {
                "matmul_v2": 1,
                "elementwise_add": 1,
                "dropout": 1,
                "lookup_table_v2": 0,
                "squared_l2_norm": 3,
                "adamw": 3,
            }
        self._check_optimizer(
            main_program,
            expected_fp16_calls["matmul_v2"]
            + expected_fp16_calls["elementwise_add"]
            + expected_fp32_calls["lookup_table_v2"],
        )
        self._check_op_calls(
            op_stats_list[0], expected_fp16_calls=expected_fp16_calls
        )

    def test_amp_fp16_o2(self):
        use_master_grad_list = [False, True]
        for master_grad in use_master_grad_list:
            self.amp_fp16_o2(master_grad)
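Beyond op statistics, one can also look for the inserted cast ops directly. A rough sketch (assuming the module context above; the name check relies on the `_fp32_master` suffix used by `Optimizer._create_master_grad`):

```python
main_program, _, _, _, _ = build_embedding_model(
    True, "float16", "O2", use_master_grad=True
)
# Every fp16 gradient that received an fp32 master copy gets one cast op
# whose output variable name contains the "_fp32_master" suffix.
master_grad_casts = [
    op
    for op in main_program.global_block().ops
    if op.type == "cast" and "_fp32_master" in op.output("Out")[0]
]
print(len(master_grad_casts))
```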
class TestMasterGradAccuracy(AmpTestBase):
    def _generate_feed_x(self, dtype="float16"):
        seed = 0
        paddle.seed(seed)
        np.random.seed(seed)
        random.seed(seed)

        x = np.random.random(size=[64, 16]).astype("float32")
        if dtype == "bfloat16":
            x_f16 = convert_float_to_uint16(x)
            x_f32 = convert_uint16_to_float(x_f16)
        elif dtype == "float16":
            x_f16 = x.astype(np.float16)
            x_f32 = x_f16.astype(np.float32)
        else:
            raise AssertionError(f"unknown dtype: {dtype}")
        return x_f32, x_f16

    def test_compare_o1_and_o2_master_grad(self):
        def _run(
            place,
            exe,
            x_np,
            max_iters,
            level,
            use_grad_clip,
            dtype="float16",
            use_master_grad=False,
        ):
            (
                main_program,
                startup_program,
                optimizer,
                feed_vars,
                fetch_vars,
            ) = build_MLP_model(
                True,
                use_grad_clip=use_grad_clip,
                amp_dtype=dtype,
                amp_level=level,
                use_master_grad=use_master_grad,
            )

            seed = 0
            paddle.seed(seed)
            np.random.seed(seed)
            random.seed(seed)

            losses = self.run_program(
                main_program,
                startup_program,
                optimizer,
                feed_vars,
                fetch_vars,
                place,
                exe,
                x_np,
                max_iters,
                dtype,
                level,
            )
            return losses

        dtype = "float16"
        max_iters = 25
        x_f32, x_f16 = self._generate_feed_x(dtype)
        place = paddle.CUDAPlace(0)
        exe = paddle.static.Executor(place)
        use_grad_clip_list = [False, True]
        for use_grad_clip in use_grad_clip_list:
            losses_o1 = _run(
                place, exe, x_f32, max_iters, 'O1', use_grad_clip, dtype=dtype
            )
            losses_o2_no_master_grad = _run(
                place,
                exe,
                x_f16,
                max_iters,
                'O2',
                use_grad_clip,
                dtype=dtype,
                use_master_grad=False,
            )
            losses_o2_master_grad = _run(
                place,
                exe,
                x_f16,
                max_iters,
                'O2',
                use_grad_clip,
                dtype=dtype,
                use_master_grad=True,
            )
            self.assertNotEqual(
                losses_o1,
                losses_o2_no_master_grad,
                f"dtype: {dtype}, the losses of O1 and O2 without master_grad should not be equal, but received O1 loss: {losses_o1}, O2 loss: {losses_o2_no_master_grad}",
            )
            self.assertEqual(
                losses_o1,
                losses_o2_master_grad,
                f"dtype: {dtype}, the losses of O1 and O2 with master_grad should be equal, but received O1 loss: {losses_o1}, O2 loss: {losses_o2_master_grad}",
            )


if __name__ == '__main__':
    unittest.main()
@@ -13,11 +13,16 @@
 # limitations under the License.
 
 import contextlib
-import struct
 import unittest
 
 import numpy as np
-from amp_base_models import AmpTestBase, build_add_model, build_embedding_model
+from amp_base_models import (
+    AmpTestBase,
+    build_add_model,
+    build_embedding_model,
+    convert_float_to_uint16,
+    convert_uint16_to_float,
+)
 
 import paddle
 from paddle import fluid
@@ -26,34 +31,6 @@ from paddle.static import amp
 
 paddle.enable_static()
 
-
-def copy_bits_from_float_to_uint16(f):
-    return struct.unpack('<I', struct.pack('<f', f))[0] >> 16
-
-
-def convert_float_to_uint16(in_list):
-    if in_list.dtype == np.float32:
-        new_output = []
-        for x in np.nditer(in_list):
-            new_output.append(np.uint16(copy_bits_from_float_to_uint16(x)))
-        new_output = np.reshape(new_output, in_list.shape).view(np.uint16)
-        return new_output
-    else:
-        return in_list
-
-
-def convert_uint16_to_float(in_list):
-    if in_list.dtype == np.uint16:
-        in_list = np.asarray(in_list)
-        out = np.vectorize(
-            lambda x: struct.unpack('<f', struct.pack('<I', x << 16))[0],
-            otypes=[np.float32],
-        )(in_list.flat)
-        return np.reshape(out, in_list.shape)
-    else:
-        return in_list
-
 cutf = convert_uint16_to_float
@@ -239,7 +216,7 @@ class TestProgramBF16(AmpTestBase):
         )
 
     def test_amp_bf16_o1(self):
-        main_program, startup_program = build_embedding_model(
+        main_program, startup_program, _, _, _ = build_embedding_model(
             True, "bfloat16", "O1"
         )
         self.assertEqual(main_program.num_blocks, 1)
@@ -258,7 +235,7 @@ class TestProgramBF16(AmpTestBase):
         self._check_op_calls(op_stats_list[0], expected_bf16_calls)
 
     def test_amp_bf16_o2(self):
-        main_program, startup_program = build_embedding_model(
+        main_program, startup_program, _, _, _ = build_embedding_model(
             True, "bfloat16", "O2"
        )
         self.assertEqual(main_program.num_blocks, 1)
@@ -322,6 +299,12 @@ class TestStaticBF16(AmpTestBase):
         losses_o1 = _run(place, exe, x_fp32, max_iters, 'O1')
         losses_o2 = _run(place, exe, x_bf16, max_iters, 'O2')
 
+        self.assertEqual(
+            losses_o1,
+            losses_o2,
+            f"The losses of O1 and O2 should be equal, but received O1 loss: {losses_o1}, O2 loss: {losses_o2}",
+        )
+
 
 if __name__ == '__main__':
     unittest.main()