Unverified commit 52edad6a, authored by 0x45f, committed by GitHub

[Dy2stat] support pure fp16 for dy2stat (#36944)

* run dy2stat pure fp16 in Linear model

* do not use self._pure_fp16_inputs

* add test and fix Adam error in dy2stat pure fp16 training

* use paddle.optimizer.Adam

* run test on GPU

* change test time for CI

* enlarge atol for test_resnet_pure_fp16

* refine code and enlarge atol

* make custom_white_list and custom_black_list take effect for AMP and pure fp16

* check tracer is not None

* use default atol

* change filter_size

* change atol and add some NOTE
Parent: 93aefceb
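For context, a minimal end-to-end sketch of the workflow this commit enables, modeled on the new unit tests further below (run on a CUDA build, as the tests do; the toy Linear model, shapes, and hyperparameters are illustrative, not part of the commit):

import numpy as np
import paddle

# Dy2stat pure fp16: convert the dygraph Layer with to_static, decorate model
# and optimizer for AMP level 'O2', then run the forward pass inside an O2
# auto_cast scope. GradScaler handles loss scaling as in regular fp16 training.
model = paddle.jit.to_static(paddle.nn.Linear(16, 8))
optimizer = paddle.optimizer.Adam(parameters=model.parameters())
scaler = paddle.amp.GradScaler(init_loss_scaling=1024)
model, optimizer = paddle.amp.decorate(
    models=model, optimizers=optimizer, level='O2', save_dtype='float32')

x = paddle.to_tensor(np.random.random([4, 16]).astype('float32'))
with paddle.amp.auto_cast(enable=True, level='O2'):
    loss = paddle.mean(model(x))
scaled = scaler.scale(loss)
scaled.backward()
scaler.minimize(optimizer, scaled)
model.clear_gradients()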
...@@ -261,6 +261,13 @@ NameVarBaseMap CastPureFp16Inputs(const std::string& op_type,
     dst_type = framework::proto::VarType::FP32;
   }
   for (auto& pair : new_ins) {
+    // NOTE: The run_program OP only has FP32 kernel. In dy2stat pure fp16
+    // training, we have correctly cast the inputs of run_program OP before,
+    // so here should avoid casting for run_program OP.
+    if (op_type == "run_program") {
+      continue;
+    }
     if ((op_type == "batch_norm" || op_type == "layer_norm" ||
          op_type == "sync_batch_norm") &&
         pair.first != "X") {
...
...@@ -118,6 +118,11 @@ def _in_amp_guard():
         return False


+def _in_pure_fp16_guard():
+    tracer = _dygraph_tracer()
+    return tracer and tracer._amp_level == core.AmpLevel.O2
+
+
 @dygraph_only
 def pure_fp16_initialize(models):
     for idx in range(len(models)):
...
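As a quick illustration of the new helper above (an internal API; the import path follows the changed module in the next hunk), the guard is expected to be active only inside an O2 auto_cast scope, since that is when the tracer's AMP level is AmpLevel.O2. This assumes a CUDA device; on a CPU-only build auto_cast has no effect and the guard stays False:

import paddle
from paddle.fluid.dygraph.amp.auto_cast import _in_pure_fp16_guard  # internal helper

assert not _in_pure_fp16_guard()                 # no AMP scope active
with paddle.amp.auto_cast(enable=True, level='O2'):
    assert _in_pure_fp16_guard()                 # tracer._amp_level == AmpLevel.O2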
...@@ -27,8 +27,8 @@ from paddle.fluid.layers.utils import pack_sequence_as
 from paddle.fluid.layers.utils import _hash_with_id
 from paddle.fluid.compiler import BuildStrategy
 from paddle.fluid.contrib.mixed_precision.decorator import AutoMixedPrecisionLists
-from paddle.fluid.contrib.mixed_precision.fp16_utils import rewrite_program
-from paddle.fluid.dygraph.amp.auto_cast import _in_amp_guard
+from paddle.fluid.contrib.mixed_precision.fp16_utils import rewrite_program, cast_model_to_fp16
+from paddle.fluid.dygraph.amp.auto_cast import _in_amp_guard, _in_pure_fp16_guard
 import paddle.compat as cpt
 from paddle import _C_ops
...@@ -152,8 +152,14 @@ class PartialProgramLayer:
         self._double_grads = self._get_double_grads(self._origin_main_program)
         self.training = True

+        custom_white_list, custom_black_list = None, None
+        tracer = framework._dygraph_tracer()
+        if tracer:
+            custom_white_list, custom_black_list = tracer._get_amp_op_list()
         # For AMP training
-        self._amp_list = AutoMixedPrecisionLists()
+        self._amp_list = AutoMixedPrecisionLists(
+            custom_white_list=custom_white_list,
+            custom_black_list=custom_black_list)

     @LazyInitialized
     def _infer_program(self):
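A hedged usage sketch of the hunk above: op lists passed to paddle.amp.auto_cast are stored on the tracer, and PartialProgramLayer now reads them via tracer._get_amp_op_list() and forwards them to AutoMixedPrecisionLists, so they also affect how the converted static program is rewritten. The model, input, and op names here are placeholders:

import paddle

net = paddle.jit.to_static(paddle.nn.Linear(16, 16))    # placeholder model
x = paddle.randn([4, 16])
with paddle.amp.auto_cast(
        enable=True,
        custom_white_list={'elementwise_add'},           # illustrative op choices
        custom_black_list={'reduce_mean'},
        level='O1'):
    out = net(x)   # rewrite_program / cast_model_to_fp16 now see these lists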
...@@ -193,6 +199,26 @@ class PartialProgramLayer:
         """
         return self._append_backward_desc(self._infer_amp_program)

+    @LazyInitialized
+    @switch_to_static_graph
+    def _infer_pure_fp16_program(self):
+        """
+        Lazy initialized property of _infer_pure_fp16_program.
+        """
+        infer_pure_fp16_program = self._origin_main_program.clone()
+        with program_guard(infer_pure_fp16_program):
+            cast_model_to_fp16(
+                infer_pure_fp16_program, self._amp_list, use_fp16_guard=False)
+
+        return infer_pure_fp16_program
+
+    @LazyInitialized
+    def _train_pure_fp16_program(self):
+        """
+        Lazy initialized property of _train_pure_fp16_program.
+        """
+        return self._append_backward_desc(self._infer_pure_fp16_program)
+
     @LazyInitialized
     def _infer_program_id(self):
         return _hash_with_id(self._infer_program, self)
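For reference, a minimal static-graph sketch of the fp16_utils helper used above, on a small toy program (it assumes the same call signature as in _infer_pure_fp16_program): cast_model_to_fp16 rewrites a Program so that computation runs in fp16 except for black-listed ops:

import paddle
from paddle.fluid.contrib.mixed_precision.decorator import AutoMixedPrecisionLists
from paddle.fluid.contrib.mixed_precision.fp16_utils import cast_model_to_fp16

paddle.enable_static()
main_prog = paddle.static.Program()
with paddle.static.program_guard(main_prog):
    x = paddle.static.data(name='x', shape=[None, 16], dtype='float32')
    out = paddle.static.nn.fc(x, size=8)

# Rewrite the whole program to fp16, mirroring how _infer_pure_fp16_program
# is built from the origin main program above.
cast_model_to_fp16(main_prog, AutoMixedPrecisionLists(), use_fp16_guard=False)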
...@@ -213,6 +239,14 @@ class PartialProgramLayer:
         return program_id

+    @LazyInitialized
+    def _train_pure_fp16_program_id(self):
+        program_id = _hash_with_id(self._train_pure_fp16_program, self)
+        core._set_cached_executor_build_strategy(program_id,
+                                                 self._build_strategy)
+        return program_id
+
     def _verify_program(self, main_program):
         """
         Verify that the program parameter is initialized, prune some unused params,
...@@ -275,8 +309,12 @@ class PartialProgramLayer:
         return self._valid_vars(double_grads)

     def _get_end_op_index(self):
-        infer_program = self._infer_amp_program if _in_amp_guard(
-        ) else self._infer_program
+        if _in_amp_guard():
+            infer_program = self._infer_amp_program
+        elif _in_pure_fp16_guard():
+            infer_program = self._infer_pure_fp16_program
+        else:
+            infer_program = self._infer_program
         return infer_program.desc.block(0).op_size()

     def __call__(self, inputs):
...@@ -285,6 +323,9 @@
         attrs = ('global_block', self.program.desc.block(0), 'start_op_index',
                  0, 'end_op_index', self._get_end_op_index(), 'is_test',
                  not self.training, 'program_id', self.program_id)
+
+        self._cast_fp16_if_pure_fp16(in_vars)
+
         _C_ops.run_program(
             self._valid_vars(in_vars),
             self._valid_vars(self._params),
...@@ -294,6 +335,16 @@
         restored_nest_out = self._restore_out(out_vars)
         return self._remove_no_value(restored_nest_out)

+    def _cast_fp16_if_pure_fp16(self, in_vars):
+        if _in_pure_fp16_guard():
+            for i, var in enumerate(in_vars):
+                name = var.name
+                if (self.program.global_block().has_var(name) and
+                        self.program.global_block().var(name).dtype ==
+                        paddle.float16):
+                    in_vars[i] = var.astype('float16')
+                    in_vars[i].name = name
+
     def drop_scope_if_no_grad(self):
         tracer = framework._dygraph_tracer()
         if self.training and not tracer._has_grad:
...@@ -302,16 +353,24 @@
     @property
     def program(self):
         if self.training:
-            return self._train_amp_program if _in_amp_guard(
-            ) else self._train_program
+            if _in_amp_guard():
+                return self._train_amp_program
+            elif _in_pure_fp16_guard():
+                return self._train_pure_fp16_program
+            else:
+                return self._train_program
         else:
             return self._infer_program

     @property
     def program_id(self):
         if self.training:
-            return self._train_amp_program_id if _in_amp_guard(
-            ) else self._train_program_id
+            if _in_amp_guard():
+                return self._train_amp_program_id
+            elif _in_pure_fp16_guard():
+                return self._train_pure_fp16_program_id
+            else:
+                return self._train_program_id
         else:
             return self._infer_program_id
...
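Taken together, when _in_pure_fp16_guard() is active the layer selects _train_pure_fp16_program and its program id, and __call__ first casts every input whose matching program variable was converted to fp16, which is also why the C++ CastPureFp16Inputs pass above skips the run_program OP. A tiny illustration of that input cast (the tensor here is made up; the name is restored so the variable presumably still matches the program when run_program binds inputs):

import paddle

x = paddle.ones([2, 2], dtype='float32')
name = x.name
x = x.astype('float16')   # same Tensor.astype cast used by _cast_fp16_if_pure_fp16
x.name = name             # keep the original variable name after the cast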
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
import unittest
import numpy as np
from time import time
from test_mnist import MNIST, TestMNIST, SEED, SimpleImgConvPool
from paddle.jit import ProgramTranslator
from paddle.fluid.optimizer import AdamOptimizer
if paddle.fluid.is_compiled_with_cuda():
paddle.fluid.set_flags({'FLAGS_cudnn_deterministic': True})
class TestPureFP16(TestMNIST):
def train_static(self):
return self.train(to_static=True)
def train_dygraph(self):
return self.train(to_static=False)
def test_mnist_to_static(self):
if paddle.fluid.is_compiled_with_cuda():
dygraph_loss = self.train_dygraph()
static_loss = self.train_static()
# NOTE: In pure fp16 training, loss is not stable, so we enlarge atol here.
self.assertTrue(
np.allclose(
dygraph_loss, static_loss, atol=1e-3),
msg='dygraph is {}\n static_res is \n{}'.format(dygraph_loss,
static_loss))
def train(self, to_static=False):
np.random.seed(SEED)
paddle.seed(SEED)
paddle.framework.random._manual_program_seed(SEED)
mnist = MNIST()
if to_static:
print("Successfully to apply @to_static.")
mnist = paddle.jit.to_static(mnist)
optimizer = paddle.optimizer.Adam(
learning_rate=0.001, parameters=mnist.parameters())
scaler = paddle.amp.GradScaler(init_loss_scaling=1024)
mnist, optimizer = paddle.amp.decorate(
models=mnist,
optimizers=optimizer,
level='O2',
save_dtype='float32')
loss_data = []
for epoch in range(self.epoch_num):
start = time()
for batch_id, data in enumerate(self.train_reader()):
dy_x_data = np.array([x[0].reshape(1, 28, 28)
for x in data]).astype('float32')
y_data = np.array(
[x[1] for x in data]).astype('int64').reshape(-1, 1)
img = paddle.to_tensor(dy_x_data)
label = paddle.to_tensor(y_data)
label.stop_gradient = True
with paddle.amp.auto_cast(
enable=True,
custom_white_list=None,
custom_black_list=None,
level='O2'):
prediction, acc, avg_loss = mnist(img, label=label)
scaled = scaler.scale(avg_loss)
scaled.backward()
scaler.minimize(optimizer, scaled)
loss_data.append(avg_loss.numpy()[0])
# save checkpoint
mnist.clear_gradients()
if batch_id % 10 == 0:
print(
"Loss at epoch {} step {}: loss: {:}, acc: {}, cost: {}"
.format(epoch, batch_id,
avg_loss.numpy(), acc.numpy(), time() - start))
start = time()
if batch_id == 50:
break
return loss_data
if __name__ == '__main__':
unittest.main()
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import math
import time
import unittest
import numpy as np
import paddle
import paddle.fluid as fluid
from paddle.fluid.dygraph import declarative, ProgramTranslator
from paddle.fluid.dygraph.nn import BatchNorm, Conv2D, Linear, Pool2D
from test_resnet import ResNet, optimizer_setting, SEED
# NOTE: Reduce batch_size from 8 to 2 to avoid unittest timeout.
batch_size = 2
epoch_num = 1
program_translator = ProgramTranslator()
if fluid.is_compiled_with_cuda():
fluid.set_flags({'FLAGS_cudnn_deterministic': True})
def train(to_static, build_strategy=None):
"""
Tests model decorated by `dygraph_to_static_output` in static mode. For users, the model is defined in dygraph mode and trained in static mode.
"""
np.random.seed(SEED)
paddle.seed(SEED)
paddle.framework.random._manual_program_seed(SEED)
resnet = ResNet()
if to_static:
resnet = paddle.jit.to_static(resnet, build_strategy=build_strategy)
optimizer = optimizer_setting(parameter_list=resnet.parameters())
scaler = paddle.amp.GradScaler(init_loss_scaling=1024)
resnet, optimizer = paddle.amp.decorate(
models=resnet, optimizers=optimizer, level='O2', save_dtype='float32')
for epoch in range(epoch_num):
loss_data = []
total_loss = 0.0
total_acc1 = 0.0
total_acc5 = 0.0
total_sample = 0
for batch_id in range(100):
start_time = time.time()
img = paddle.to_tensor(
np.random.random([batch_size, 3, 224, 224]).astype('float32'))
label = paddle.to_tensor(
np.random.randint(
0, 100, [batch_size, 1], dtype='int64'))
img.stop_gradient = True
label.stop_gradient = True
with paddle.amp.auto_cast(
enable=True,
custom_white_list=None,
custom_black_list=None,
level='O2'):
pred = resnet(img)
loss = fluid.layers.cross_entropy(input=pred, label=label)
avg_loss = fluid.layers.mean(x=pred)
acc_top1 = fluid.layers.accuracy(input=pred, label=label, k=1)
acc_top5 = fluid.layers.accuracy(input=pred, label=label, k=5)
scaled = scaler.scale(avg_loss)
scaled.backward()
scaler.minimize(optimizer, scaled)
resnet.clear_gradients()
loss_data.append(avg_loss.numpy()[0])
total_loss += avg_loss
total_acc1 += acc_top1
total_acc5 += acc_top5
total_sample += 1
end_time = time.time()
if batch_id % 2 == 0:
print( "epoch %d | batch step %d, loss %0.3f, acc1 %0.3f, acc5 %0.3f, time %f" % \
( epoch, batch_id, total_loss.numpy() / total_sample, \
total_acc1.numpy() / total_sample, total_acc5.numpy() / total_sample, end_time-start_time))
if batch_id == 10:
break
return loss_data
class TestResnet(unittest.TestCase):
def train(self, to_static):
program_translator.enable(to_static)
return train(to_static)
def test_resnet(self):
if fluid.is_compiled_with_cuda():
static_loss = self.train(to_static=True)
dygraph_loss = self.train(to_static=False)
# NOTE: In pure fp16 training, loss is not stable, so we enlarge atol here.
self.assertTrue(
np.allclose(
static_loss, dygraph_loss, atol=1e-3),
msg="static_loss: {} \n dygraph_loss: {}".format(static_loss,
dygraph_loss))
if __name__ == '__main__':
unittest.main()