未验证 提交 a842828a 编写于 作者: A Aurelius84 提交者: GitHub

[Dy2Stat]Support Mixed Precision training in @to_static (#34562)

* Support Mixed Precision training in @to_static

* fix block.vars logic

* fix GPU training loss diff

* remove unused code
上级 bb7b4c0c
......@@ -90,6 +90,17 @@ def _update_list(custom_white_list, custom_black_list):
return _white_list, _black_list
def _in_amp_guard():
"""
Judge whether current code block is in `amp_guard` context.
"""
tracer = _dygraph_tracer()
if tracer:
return tracer._enable_autocast
else:
return False
@signature_safe_contextmanager
@dygraph_only
def amp_guard(enable=True, custom_white_list=None, custom_black_list=None):
......
......@@ -17,7 +17,7 @@ import numpy as np
import six
import paddle
from paddle.fluid import framework, backward, core
from paddle.fluid import framework, backward, core, program_guard
from paddle.fluid.dygraph import layers
from paddle.fluid.dygraph.base import switch_to_static_graph
from paddle.fluid.dygraph.dygraph_to_static import logging_utils
......@@ -26,6 +26,9 @@ from paddle.fluid.layers.utils import flatten
from paddle.fluid.layers.utils import pack_sequence_as
from paddle.fluid.layers.utils import _hash_with_id
from paddle.fluid.compiler import BuildStrategy
from paddle.fluid.contrib.mixed_precision.decorator import AutoMixedPrecisionLists
from paddle.fluid.contrib.mixed_precision.fp16_utils import rewrite_program
from paddle.fluid.dygraph.amp.auto_cast import _in_amp_guard
import paddle.compat as cpt
from paddle import _C_ops
......@@ -149,6 +152,9 @@ class PartialProgramLayer:
self._double_grads = self._get_double_grads(self._origin_main_program)
self.training = True
# For AMP training
self._amp_list = AutoMixedPrecisionLists()
@LazyInitialized
def _infer_program(self):
"""
......@@ -168,6 +174,25 @@ class PartialProgramLayer:
return train_program
@LazyInitialized
@switch_to_static_graph
def _infer_amp_program(self):
"""
Lazy initialized property of infer_amp_program.
"""
infer_amp_program = self._origin_main_program.clone()
with program_guard(infer_amp_program):
rewrite_program(infer_amp_program, self._amp_list)
return infer_amp_program
@LazyInitialized
def _train_amp_program(self):
"""
Lazy initialized property of train_amp_program.
"""
return self._append_backward_desc(self._infer_amp_program)
@LazyInitialized
def _infer_program_id(self):
return _hash_with_id(self._infer_program, self)
......@@ -180,6 +205,14 @@ class PartialProgramLayer:
return program_id
@LazyInitialized
def _train_amp_program_id(self):
program_id = _hash_with_id(self._train_amp_program, self)
core._set_cached_executor_build_strategy(program_id,
self._build_strategy)
return program_id
def _verify_program(self, main_program):
"""
Verify that the program parameter is initialized, prune some unused params,
......@@ -241,12 +274,17 @@ class PartialProgramLayer:
double_grads.append(var_base)
return self._valid_vars(double_grads)
def _get_end_op_index(self):
infer_program = self._infer_amp_program if _in_amp_guard(
) else self._infer_program
return infer_program.desc.block(0).op_size()
def __call__(self, inputs):
in_vars, out_vars = self._prepare(inputs)
attrs = ('global_block', self.program.desc.block(0), 'start_op_index',
0, 'end_op_index', self._infer_program.desc.block(0).op_size(),
'is_test', not self.training, 'program_id', self.program_id)
0, 'end_op_index', self._get_end_op_index(), 'is_test',
not self.training, 'program_id', self.program_id)
_C_ops.run_program(
self._valid_vars(in_vars),
self._valid_vars(self._params),
......@@ -258,11 +296,19 @@ class PartialProgramLayer:
@property
def program(self):
return self._train_program if self.training else self._infer_program
if self.training:
return self._train_amp_program if _in_amp_guard(
) else self._train_program
else:
return self._infer_program
@property
def program_id(self):
return self._train_program_id if self.training else self._infer_program_id
if self.training:
return self._train_amp_program_id if _in_amp_guard(
) else self._train_program_id
else:
return self._infer_program_id
def _prepare(self, inputs):
"""
......
......@@ -2035,6 +2035,11 @@ class Operator(object):
del op_attrs[role_var_name]
if len(self.desc.type()) != 0:
# NOTE(Aurelius84): prog.clone() will lead that var.op is always None,
# we add this to fix the problem.
for arg in self.desc.output_arg_names():
if block.has_var(arg) and block.var(arg).op is None:
block.var(arg).op = self
return
if type is None:
raise ValueError(
......
......@@ -32,6 +32,9 @@ from predictor_utils import PredictorTools
SEED = 2020
if paddle.fluid.is_compiled_with_cuda():
paddle.fluid.set_flags({'FLAGS_cudnn_deterministic': True})
class SimpleImgConvPool(fluid.dygraph.Layer):
def __init__(self,
......@@ -48,7 +51,7 @@ class SimpleImgConvPool(fluid.dygraph.Layer):
conv_dilation=1,
conv_groups=1,
act=None,
use_cudnn=False,
use_cudnn=True,
param_attr=None,
bias_attr=None):
super(SimpleImgConvPool, self).__init__()
......@@ -101,7 +104,6 @@ class MNIST(fluid.dygraph.Layer):
loc=0.0, scale=scale)),
act="softmax")
@paddle.jit.to_static
def forward(self, inputs, label=None):
x = self.inference(inputs)
if label is not None:
......@@ -167,14 +169,14 @@ class TestMNISTWithToStatic(TestMNIST):
dygraph_loss_cpu, dygraph_loss_mkldnn))
def train(self, to_static=False):
prog_trans = ProgramTranslator()
prog_trans.enable(to_static)
loss_data = []
with fluid.dygraph.guard(self.place):
fluid.default_main_program().random_seed = SEED
fluid.default_startup_program().random_seed = SEED
mnist = MNIST()
if to_static:
mnist = paddle.jit.to_static(mnist)
adam = AdamOptimizer(
learning_rate=0.001, parameter_list=mnist.parameters())
......
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
import unittest
import numpy as np
from time import time
from test_mnist import MNIST, TestMNIST, SEED
from paddle.jit import ProgramTranslator
from paddle.fluid.optimizer import AdamOptimizer
if paddle.fluid.is_compiled_with_cuda():
paddle.fluid.set_flags({'FLAGS_cudnn_deterministic': True})
class TestAMP(TestMNIST):
def train_static(self):
return self.train(to_static=True)
def train_dygraph(self):
return self.train(to_static=False)
def test_mnist_to_static(self):
dygraph_loss = self.train_dygraph()
static_loss = self.train_static()
# NOTE(Aurelius84): In static AMP training, there is a grep_list but
# dygraph AMP don't. It will bring the numbers of cast_op is different
# and leads to loss has a bit diff.
self.assertTrue(
np.allclose(
dygraph_loss, static_loss, atol=1e-3),
msg='dygraph is {}\n static_res is \n{}'.format(dygraph_loss,
static_loss))
def train(self, to_static=False):
paddle.seed(SEED)
mnist = MNIST()
if to_static:
print("Successfully to apply @to_static.")
mnist = paddle.jit.to_static(mnist)
adam = AdamOptimizer(
learning_rate=0.001, parameter_list=mnist.parameters())
scaler = paddle.amp.GradScaler(init_loss_scaling=1024)
loss_data = []
for epoch in range(self.epoch_num):
start = time()
for batch_id, data in enumerate(self.train_reader()):
dy_x_data = np.array([x[0].reshape(1, 28, 28)
for x in data]).astype('float32')
y_data = np.array(
[x[1] for x in data]).astype('int64').reshape(-1, 1)
img = paddle.to_tensor(dy_x_data)
label = paddle.to_tensor(y_data)
label.stop_gradient = True
with paddle.amp.auto_cast():
prediction, acc, avg_loss = mnist(img, label=label)
scaled = scaler.scale(avg_loss)
scaled.backward()
scaler.minimize(adam, scaled)
loss_data.append(avg_loss.numpy()[0])
# save checkpoint
mnist.clear_gradients()
if batch_id % 10 == 0:
print(
"Loss at epoch {} step {}: loss: {:}, acc: {}, cost: {}"
.format(epoch, batch_id,
avg_loss.numpy(), acc.numpy(), time() - start))
start = time()
if batch_id == 50:
break
return loss_data
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册