From 195736cfe77d0e7e2dce546eacb3eb8bda6cf584 Mon Sep 17 00:00:00 2001
From: niuliling123 <51102941+niuliling123@users.noreply.github.com>
Date: Wed, 21 Dec 2022 15:24:18 +0800
Subject: [PATCH] Add FLAGS_low_precision_op_list to get amp list of current module (#48843)

---
 paddle/fluid/eager/amp_utils.h               | 10 +++++
 paddle/fluid/imperative/amp_auto_cast.cc     | 11 +++++
 paddle/fluid/imperative/amp_auto_cast.h      |  7 +++
 paddle/fluid/pybind/pybind.cc                |  4 ++
 paddle/phi/core/flags.cc                     | 14 ++++++
 python/paddle/fluid/dygraph/amp/auto_cast.py | 15 +++++++
 .../unittests/test_low_precision_list.py     | 43 +++++++++++++++++++
 7 files changed, 104 insertions(+)
 create mode 100644 python/paddle/fluid/tests/unittests/test_low_precision_list.py

diff --git a/paddle/fluid/eager/amp_utils.h b/paddle/fluid/eager/amp_utils.h
index 115811f6a3..c63912312b 100644
--- a/paddle/fluid/eager/amp_utils.h
+++ b/paddle/fluid/eager/amp_utils.h
@@ -100,6 +100,7 @@ inline paddle::experimental::DataType GetAmpDestDtype(
     if (paddle::imperative::AmpOperators::Instance()
             .GetMutableAllowOps()
             ->count(op_name)) {
+      paddle::imperative::AmpOperators::Instance().AddToAmpOpList(op_name);
       return paddle::experimental::DataType::FLOAT16;
     } else if (paddle::imperative::AmpOperators::Instance()
                    .GetMutableBlockOps()
@@ -117,6 +118,8 @@ inline paddle::experimental::DataType GetAmpDestDtype(
               .GetMutableUnsupportedFp16Ops()
               ->count(op_name)) {
         dst_type = paddle::experimental::DataType::FLOAT32;
+      } else {
+        paddle::imperative::AmpOperators::Instance().AddToAmpOpList(op_name);
       }
       return dst_type;
     }
@@ -129,6 +132,8 @@ inline paddle::experimental::DataType GetAmpDestDtype(
             .GetMutableBlockOps()
             ->count(op_name)) {
       dst_type = paddle::experimental::DataType::FLOAT32;
+    } else {
+      paddle::imperative::AmpOperators::Instance().AddToAmpOpList(op_name);
     }
     return dst_type;
   }
@@ -137,6 +142,7 @@ inline paddle::experimental::DataType GetAmpDestDtype(
     if (paddle::imperative::AmpOperators::Instance()
             .GetMutableAllowOps()
             ->count(op_name)) {
+      paddle::imperative::AmpOperators::Instance().AddToAmpOpList(op_name);
       return paddle::experimental::DataType::BFLOAT16;
     } else if (paddle::imperative::AmpOperators::Instance()
                    .GetMutableBlockOps()
@@ -152,6 +158,8 @@ inline paddle::experimental::DataType GetAmpDestDtype(
               .GetMutableUnsupportedBf16Ops()
               ->count(op_name)) {
         dst_type = paddle::experimental::DataType::FLOAT32;
+      } else {
+        paddle::imperative::AmpOperators::Instance().AddToAmpOpList(op_name);
       }
       return dst_type;
     }
@@ -164,6 +172,8 @@ inline paddle::experimental::DataType GetAmpDestDtype(
             .GetMutableBlockOps()
             ->count(op_name)) {
       dst_type = paddle::experimental::DataType::FLOAT32;
+    } else {
+      paddle::imperative::AmpOperators::Instance().AddToAmpOpList(op_name);
     }
     return dst_type;
   }
diff --git a/paddle/fluid/imperative/amp_auto_cast.cc b/paddle/fluid/imperative/amp_auto_cast.cc
index 55c1520820..f8cea38ea6 100644
--- a/paddle/fluid/imperative/amp_auto_cast.cc
+++ b/paddle/fluid/imperative/amp_auto_cast.cc
@@ -22,6 +22,7 @@
 #include "paddle/fluid/imperative/type_defs.h"
 #include "paddle/fluid/imperative/var_helper.h"
 
+DECLARE_bool(low_precision_op_list);
 namespace paddle {
 namespace imperative {
 
@@ -193,6 +194,16 @@ AmpOperators::GetMutableUnsupportedBf16Ops() {
   return unsupported_bf16_ops_;
 }
 
+void AmpOperators::AddToAmpOpList(const std::string& op_name) {
+  if (FLAGS_low_precision_op_list) {
+    current_amp_ops_[op_name] += 1;
+  }
+}
+
+std::map<const std::string, int> AmpOperators::GetAmpOpList() {
+  return current_amp_ops_;
+}
+
 std::ostream& operator<<(std::ostream& os, AmpOperators& ops) {
   os << "allow ops: ";
   auto allow_ops = ops.GetMutableAllowOps();
diff --git a/paddle/fluid/imperative/amp_auto_cast.h b/paddle/fluid/imperative/amp_auto_cast.h
index 3bee230860..343b01dedb 100644
--- a/paddle/fluid/imperative/amp_auto_cast.h
+++ b/paddle/fluid/imperative/amp_auto_cast.h
@@ -60,6 +60,10 @@ class AmpOperators {
   std::shared_ptr<std::unordered_set<std::string>>
   GetMutableUnsupportedBf16Ops();
 
+  void AddToAmpOpList(const std::string& op_name);
+
+  std::map<const std::string, int> GetAmpOpList();
+
  private:
   AmpOperators();  // forbid calling default constructor
 
@@ -76,6 +80,9 @@ class AmpOperators {
 
   // The set of ops that has no bf16 CUDA kennel.
   std::shared_ptr<std::unordered_set<std::string>> unsupported_bf16_ops_;
+
+  // The amp op list of current module.
+  std::map<const std::string, int> current_amp_ops_;
 };
 
 std::ostream& operator<<(std::ostream& os, AmpOperators& ops);
diff --git a/paddle/fluid/pybind/pybind.cc b/paddle/fluid/pybind/pybind.cc
index 9ee0d3e473..2b07b0d9cd 100644
--- a/paddle/fluid/pybind/pybind.cc
+++ b/paddle/fluid/pybind/pybind.cc
@@ -2545,6 +2545,10 @@ All parameter, weight, gradient are variables in Paddle.
   m.def("update_autotune_status",
         [] { return phi::autotune::AutoTuneStatus::Instance().Update(); });
 
+  m.def("get_low_precision_op_list", [] {
+    return paddle::imperative::AmpOperators::Instance().GetAmpOpList();
+  });
+
   m.def("autotune_status", [] {
     py::dict res;
     phi::autotune::AutoTuneCache::Instance().UpdateStatus();
diff --git a/paddle/phi/core/flags.cc b/paddle/phi/core/flags.cc
index ee3caeea36..cdcf67f245 100644
--- a/paddle/phi/core/flags.cc
+++ b/paddle/phi/core/flags.cc
@@ -52,6 +52,20 @@ PADDLE_DEFINE_EXPORTED_int32(paddle_num_threads,
                              1,
                              "Number of threads for each paddle instance.");
 
+/**
+ * Low Precision Op related FLAG
+ * Name: FLAGS_low_precision_op_list
+ * Since Version: 0.13.0
+ * Value Range: bool, default=false
+ * Example:
+ * Note: Used to debug. Get the low precision op list of current module.
+ */
+PADDLE_DEFINE_EXPORTED_bool(low_precision_op_list,
+                            false,
+                            "Whether to record the low precision op list of "
+                            "the current module for debugging; the list can "
+                            "be queried after the module has run.");
+
 /**
  * Operator related FLAG
  * Name: FLAGS_check_nan_inf
diff --git a/python/paddle/fluid/dygraph/amp/auto_cast.py b/python/paddle/fluid/dygraph/amp/auto_cast.py
index 1f644147a2..9c62f8edba 100644
--- a/python/paddle/fluid/dygraph/amp/auto_cast.py
+++ b/python/paddle/fluid/dygraph/amp/auto_cast.py
@@ -110,6 +110,21 @@ PURE_BF16_BLACK_LIST = set()
 _g_amp_state_ = None
 
 
+def low_precision_op_list():
+    op_list = paddle.fluid.core.get_low_precision_op_list()
+    op_count = 0
+    print('<---------------- low precision op list ------------------->')
+    print('<---- op name ------|------- op count---------------------->')
+    for x in op_list:
+        print(' %-18s| %4d' % (x, op_list[x]))
+        op_count += 1
+    print(
+        '<------------- low precision op num:{:5d} ----------------->'.format(
+            op_count
+        )
+    )
+
+
 def amp_state():
     global _g_amp_state_
     return _g_amp_state_
diff --git a/python/paddle/fluid/tests/unittests/test_low_precision_list.py b/python/paddle/fluid/tests/unittests/test_low_precision_list.py
new file mode 100644
index 0000000000..7099fbe168
--- /dev/null
+++ b/python/paddle/fluid/tests/unittests/test_low_precision_list.py
@@ -0,0 +1,43 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import unittest
+
+import paddle
+
+
+class TestAMPList(unittest.TestCase):
+    def test_main(self):
+        conv2d = paddle.nn.Conv2D(3, 2, 3, bias_attr=False)
+        data = paddle.rand([10, 3, 32, 32])
+        paddle.set_flags({'FLAGS_low_precision_op_list': 1})
+        a = paddle.rand([2, 3])
+        b = paddle.rand([2, 3])
+
+        # amp list conv2d, cast
+        with paddle.amp.auto_cast():
+            conv = conv2d(data)
+            c = a + b
+        paddle.fluid.dygraph.amp.auto_cast.low_precision_op_list()
+        op_list = paddle.fluid.core.get_low_precision_op_list()
+        print(conv.dtype)
+        if conv.dtype == paddle.float16:
+            self.assertTrue('elementwise_add' in op_list)
+            self.assertTrue('conv2d' in op_list)
+            self.assertTrue(2 == len(op_list))
+        else:
+            self.assertTrue(0 == len(op_list))
+
+
+if __name__ == "__main__":
+    unittest.main()
--
GitLab
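Usage sketch (following the unit test above): enable the flag, run the module under AMP, then query or print the collected op list. This is only a minimal illustration; which ops appear, and their counts, depend on whether they actually ran in float16/bfloat16 on the current device, and the Linear layer below is just an arbitrary example module, not part of the patch.

    import paddle

    # Turn on collection of the low precision op list (off by default).
    paddle.set_flags({'FLAGS_low_precision_op_list': 1})

    linear = paddle.nn.Linear(4, 4)  # example module; any dygraph layer works
    x = paddle.rand([2, 4])

    # Run the module under AMP so eligible ops are cast to low precision.
    with paddle.amp.auto_cast():
        y = linear(x)

    # Raw {op_name: call count} dict exposed by the new pybind API.
    op_stats = paddle.fluid.core.get_low_precision_op_list()
    print(op_stats)

    # Pretty-printed summary added to auto_cast.py by this patch.
    paddle.fluid.dygraph.amp.auto_cast.low_precision_op_list()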