From a842828af291b772110c3128c618ce0a5887a29e Mon Sep 17 00:00:00 2001
From: Aurelius84
Date: Thu, 5 Aug 2021 16:05:54 +0800
Subject: [PATCH] [Dy2Stat]Support Mixed Precision training in @to_static (#34562)

* Support Mixed Precision training in @to_static

* fix block.vars logic

* fix GPU training loss diff

* remove unused code
---
 python/paddle/fluid/dygraph/amp/auto_cast.py   | 11 +++
 .../dygraph_to_static/partial_program.py       | 56 ++++++++++-
 python/paddle/fluid/framework.py               |  5 +
 .../unittests/dygraph_to_static/test_mnist.py  | 10 +-
 .../dygraph_to_static/test_mnist_amp.py        | 94 +++++++++++++++++++
 5 files changed, 167 insertions(+), 9 deletions(-)
 create mode 100644 python/paddle/fluid/tests/unittests/dygraph_to_static/test_mnist_amp.py

diff --git a/python/paddle/fluid/dygraph/amp/auto_cast.py b/python/paddle/fluid/dygraph/amp/auto_cast.py
index bd464450aef..a7eb0d31b7f 100644
--- a/python/paddle/fluid/dygraph/amp/auto_cast.py
+++ b/python/paddle/fluid/dygraph/amp/auto_cast.py
@@ -90,6 +90,17 @@ def _update_list(custom_white_list, custom_black_list):
     return _white_list, _black_list
 
 
+def _in_amp_guard():
+    """
+    Judge whether the current code block is within an `amp_guard` context.
+    """
+    tracer = _dygraph_tracer()
+    if tracer:
+        return tracer._enable_autocast
+    else:
+        return False
+
+
 @signature_safe_contextmanager
 @dygraph_only
 def amp_guard(enable=True, custom_white_list=None, custom_black_list=None):
diff --git a/python/paddle/fluid/dygraph/dygraph_to_static/partial_program.py b/python/paddle/fluid/dygraph/dygraph_to_static/partial_program.py
index a99a5d50813..e275ee04858 100644
--- a/python/paddle/fluid/dygraph/dygraph_to_static/partial_program.py
+++ b/python/paddle/fluid/dygraph/dygraph_to_static/partial_program.py
@@ -17,7 +17,7 @@ import numpy as np
 import six
 
 import paddle
-from paddle.fluid import framework, backward, core
+from paddle.fluid import framework, backward, core, program_guard
 from paddle.fluid.dygraph import layers
 from paddle.fluid.dygraph.base import switch_to_static_graph
 from paddle.fluid.dygraph.dygraph_to_static import logging_utils
@@ -26,6 +26,9 @@ from paddle.fluid.layers.utils import flatten
 from paddle.fluid.layers.utils import pack_sequence_as
 from paddle.fluid.layers.utils import _hash_with_id
 from paddle.fluid.compiler import BuildStrategy
+from paddle.fluid.contrib.mixed_precision.decorator import AutoMixedPrecisionLists
+from paddle.fluid.contrib.mixed_precision.fp16_utils import rewrite_program
+from paddle.fluid.dygraph.amp.auto_cast import _in_amp_guard
 import paddle.compat as cpt
 from paddle import _C_ops
 
@@ -149,6 +152,9 @@ class PartialProgramLayer:
         self._double_grads = self._get_double_grads(self._origin_main_program)
         self.training = True
 
+        # For AMP training
+        self._amp_list = AutoMixedPrecisionLists()
+
     @LazyInitialized
     def _infer_program(self):
         """
@@ -168,6 +174,25 @@
 
         return train_program
 
+    @LazyInitialized
+    @switch_to_static_graph
+    def _infer_amp_program(self):
+        """
+        Lazy initialized property of infer_amp_program.
+        """
+        infer_amp_program = self._origin_main_program.clone()
+        with program_guard(infer_amp_program):
+            rewrite_program(infer_amp_program, self._amp_list)
+
+        return infer_amp_program
+
+    @LazyInitialized
+    def _train_amp_program(self):
+        """
+        Lazy initialized property of train_amp_program.
+        """
+        return self._append_backward_desc(self._infer_amp_program)
+
     @LazyInitialized
     def _infer_program_id(self):
         return _hash_with_id(self._infer_program, self)
@@ -180,6 +205,14 @@
 
         return program_id
 
+    @LazyInitialized
+    def _train_amp_program_id(self):
+        program_id = _hash_with_id(self._train_amp_program, self)
+        core._set_cached_executor_build_strategy(program_id,
+                                                 self._build_strategy)
+
+        return program_id
+
     def _verify_program(self, main_program):
         """
         Verify that the program parameter is initialized, prune some unused params,
@@ -241,12 +274,17 @@
             double_grads.append(var_base)
         return self._valid_vars(double_grads)
 
+    def _get_end_op_index(self):
+        infer_program = self._infer_amp_program if _in_amp_guard(
+        ) else self._infer_program
+        return infer_program.desc.block(0).op_size()
+
     def __call__(self, inputs):
         in_vars, out_vars = self._prepare(inputs)
 
         attrs = ('global_block', self.program.desc.block(0), 'start_op_index',
-                 0, 'end_op_index', self._infer_program.desc.block(0).op_size(),
-                 'is_test', not self.training, 'program_id', self.program_id)
+                 0, 'end_op_index', self._get_end_op_index(), 'is_test',
+                 not self.training, 'program_id', self.program_id)
 
         _C_ops.run_program(
             self._valid_vars(in_vars),
             self._valid_vars(self._params),
@@ -258,11 +296,19 @@
 
     @property
     def program(self):
-        return self._train_program if self.training else self._infer_program
+        if self.training:
+            return self._train_amp_program if _in_amp_guard(
+            ) else self._train_program
+        else:
+            return self._infer_program
 
     @property
     def program_id(self):
-        return self._train_program_id if self.training else self._infer_program_id
+        if self.training:
+            return self._train_amp_program_id if _in_amp_guard(
+            ) else self._train_program_id
+        else:
+            return self._infer_program_id
 
     def _prepare(self, inputs):
         """
diff --git a/python/paddle/fluid/framework.py b/python/paddle/fluid/framework.py
index 2247d494830..02f9fd1a95e 100644
--- a/python/paddle/fluid/framework.py
+++ b/python/paddle/fluid/framework.py
@@ -2035,6 +2035,11 @@ class Operator(object):
                     del op_attrs[role_var_name]
 
         if len(self.desc.type()) != 0:
+            # NOTE(Aurelius84): prog.clone() leaves var.op as None on cloned
+            # variables; set it here to fix that problem.
+            for arg in self.desc.output_arg_names():
+                if block.has_var(arg) and block.var(arg).op is None:
+                    block.var(arg).op = self
             return
         if type is None:
             raise ValueError(
diff --git a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_mnist.py b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_mnist.py
index 8a21c4cfd0e..cac64c73913 100644
--- a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_mnist.py
+++ b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_mnist.py
@@ -32,6 +32,9 @@ from predictor_utils import PredictorTools
 
 SEED = 2020
 
+if paddle.fluid.is_compiled_with_cuda():
+    paddle.fluid.set_flags({'FLAGS_cudnn_deterministic': True})
+
 
 class SimpleImgConvPool(fluid.dygraph.Layer):
     def __init__(self,
@@ -48,7 +51,7 @@ class SimpleImgConvPool(fluid.dygraph.Layer):
                  conv_dilation=1,
                  conv_groups=1,
                  act=None,
-                 use_cudnn=False,
+                 use_cudnn=True,
                  param_attr=None,
                  bias_attr=None):
         super(SimpleImgConvPool, self).__init__()
@@ -101,7 +104,6 @@ class MNIST(fluid.dygraph.Layer):
                     loc=0.0, scale=scale)),
             act="softmax")
 
-    @paddle.jit.to_static
     def forward(self, inputs, label=None):
         x = self.inference(inputs)
         if label is not None:
@@ -167,14 +169,14 @@ class TestMNISTWithToStatic(TestMNIST):
                 dygraph_loss_cpu, dygraph_loss_mkldnn))
 
     def train(self, to_static=False):
-        prog_trans = ProgramTranslator()
-        prog_trans.enable(to_static)
 
         loss_data = []
         with fluid.dygraph.guard(self.place):
             fluid.default_main_program().random_seed = SEED
             fluid.default_startup_program().random_seed = SEED
             mnist = MNIST()
+            if to_static:
+                mnist = paddle.jit.to_static(mnist)
             adam = AdamOptimizer(
                 learning_rate=0.001, parameter_list=mnist.parameters())
diff --git a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_mnist_amp.py b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_mnist_amp.py
new file mode 100644
index 00000000000..d2160ca6416
--- /dev/null
+++ b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_mnist_amp.py
@@ -0,0 +1,94 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import paddle
+import unittest
+import numpy as np
+from time import time
+from test_mnist import MNIST, TestMNIST, SEED
+from paddle.jit import ProgramTranslator
+from paddle.fluid.optimizer import AdamOptimizer
+
+if paddle.fluid.is_compiled_with_cuda():
+    paddle.fluid.set_flags({'FLAGS_cudnn_deterministic': True})
+
+
+class TestAMP(TestMNIST):
+    def train_static(self):
+        return self.train(to_static=True)
+
+    def train_dygraph(self):
+        return self.train(to_static=False)
+
+    def test_mnist_to_static(self):
+        dygraph_loss = self.train_dygraph()
+        static_loss = self.train_static()
+        # NOTE(Aurelius84): Static-graph AMP keeps a gray op list that dygraph
+        # AMP does not, so the two modes insert a different number of cast ops
+        # and their losses show a small difference.
+        self.assertTrue(
+            np.allclose(
+                dygraph_loss, static_loss, atol=1e-3),
+            msg='dygraph is {}\n static_res is \n{}'.format(dygraph_loss,
+                                                            static_loss))
+
+    def train(self, to_static=False):
+        paddle.seed(SEED)
+        mnist = MNIST()
+
+        if to_static:
+            print("Successfully applied @to_static.")
+            mnist = paddle.jit.to_static(mnist)
+
+        adam = AdamOptimizer(
+            learning_rate=0.001, parameter_list=mnist.parameters())
+
+        scaler = paddle.amp.GradScaler(init_loss_scaling=1024)
+
+        loss_data = []
+        for epoch in range(self.epoch_num):
+            start = time()
+            for batch_id, data in enumerate(self.train_reader()):
+                dy_x_data = np.array([x[0].reshape(1, 28, 28)
+                                      for x in data]).astype('float32')
+                y_data = np.array(
+                    [x[1] for x in data]).astype('int64').reshape(-1, 1)
+
+                img = paddle.to_tensor(dy_x_data)
+                label = paddle.to_tensor(y_data)
+                label.stop_gradient = True
+
+                with paddle.amp.auto_cast():
+                    prediction, acc, avg_loss = mnist(img, label=label)
+
+                scaled = scaler.scale(avg_loss)
+                scaled.backward()
+                scaler.minimize(adam, scaled)
+
+                loss_data.append(avg_loss.numpy()[0])
+                # clear gradients for the next batch
+                mnist.clear_gradients()
+                if batch_id % 10 == 0:
+                    print(
+                        "Loss at epoch {} step {}: loss: {:}, acc: {}, cost: {}"
+                        .format(epoch, batch_id,
+                                avg_loss.numpy(), acc.numpy(), time() - start))
+                    start = time()
+                if batch_id == 50:
+                    break
+        return loss_data
+
+
+if __name__ == '__main__':
+    unittest.main()
-- 
GitLab
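For a quick way to try the feature outside the unit-test harness, the sketch below condenses the training pattern exercised by the new test_mnist_amp.py into a minimal standalone script. The tiny SimpleNet model, the random data, and the optimizer settings are illustrative assumptions only; the @to_static / paddle.amp.auto_cast() / GradScaler combination is the same pattern the patch's own test uses.

import numpy as np
import paddle


class SimpleNet(paddle.nn.Layer):
    # A toy regression network; any dygraph Layer converted with to_static works.
    def __init__(self):
        super(SimpleNet, self).__init__()
        self.fc = paddle.nn.Linear(16, 1)

    def forward(self, x):
        return self.fc(x)


# Convert the dygraph model into a static-graph program (Dy2Stat).
net = paddle.jit.to_static(SimpleNet())
adam = paddle.optimizer.Adam(parameters=net.parameters())
scaler = paddle.amp.GradScaler(init_loss_scaling=1024)

for step in range(10):
    x = paddle.to_tensor(np.random.randn(4, 16).astype('float32'))
    y = paddle.to_tensor(np.random.randn(4, 1).astype('float32'))

    # Run the converted program under AMP; with this patch, the static program
    # is rewritten (rewrite_program) so white-list ops execute in float16.
    with paddle.amp.auto_cast():
        loss = paddle.nn.functional.mse_loss(net(x), y)

    scaled = scaler.scale(loss)      # scale the loss to avoid fp16 gradient underflow
    scaled.backward()
    scaler.minimize(adam, scaled)    # unscale gradients and apply the update
    net.clear_gradients()

On GPU, ops on the AMP white list (e.g. matmul, conv) run in float16 while black-list ops stay in float32, which is exactly what the rewrite_program call added to PartialProgramLayer arranges for the @to_static training and inference programs.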