未验证 提交 195736cf 编写于 作者: N niuliling123 提交者: GitHub

Add FLAGS_low_precision_op_list to get amp list of current module (#48843)

上级 b51a752f
......@@ -100,6 +100,7 @@ inline paddle::experimental::DataType GetAmpDestDtype(
if (paddle::imperative::AmpOperators::Instance()
.GetMutableAllowOps()
->count(op_name)) {
paddle::imperative::AmpOperators::Instance().AddToAmpOpList(op_name);
return paddle::experimental::DataType::FLOAT16;
} else if (paddle::imperative::AmpOperators::Instance()
.GetMutableBlockOps()
......@@ -117,6 +118,8 @@ inline paddle::experimental::DataType GetAmpDestDtype(
.GetMutableUnsupportedFp16Ops()
->count(op_name)) {
dst_type = paddle::experimental::DataType::FLOAT32;
} else {
paddle::imperative::AmpOperators::Instance().AddToAmpOpList(op_name);
}
return dst_type;
}
......@@ -129,6 +132,8 @@ inline paddle::experimental::DataType GetAmpDestDtype(
.GetMutableBlockOps()
->count(op_name)) {
dst_type = paddle::experimental::DataType::FLOAT32;
} else {
paddle::imperative::AmpOperators::Instance().AddToAmpOpList(op_name);
}
return dst_type;
}
......@@ -137,6 +142,7 @@ inline paddle::experimental::DataType GetAmpDestDtype(
if (paddle::imperative::AmpOperators::Instance()
.GetMutableAllowOps()
->count(op_name)) {
paddle::imperative::AmpOperators::Instance().AddToAmpOpList(op_name);
return paddle::experimental::DataType::BFLOAT16;
} else if (paddle::imperative::AmpOperators::Instance()
.GetMutableBlockOps()
......@@ -152,6 +158,8 @@ inline paddle::experimental::DataType GetAmpDestDtype(
.GetMutableUnsupportedBf16Ops()
->count(op_name)) {
dst_type = paddle::experimental::DataType::FLOAT32;
} else {
paddle::imperative::AmpOperators::Instance().AddToAmpOpList(op_name);
}
return dst_type;
}
......@@ -164,6 +172,8 @@ inline paddle::experimental::DataType GetAmpDestDtype(
.GetMutableBlockOps()
->count(op_name)) {
dst_type = paddle::experimental::DataType::FLOAT32;
} else {
paddle::imperative::AmpOperators::Instance().AddToAmpOpList(op_name);
}
return dst_type;
}
......
......@@ -22,6 +22,7 @@
#include "paddle/fluid/imperative/type_defs.h"
#include "paddle/fluid/imperative/var_helper.h"
DECLARE_bool(low_precision_op_list);
namespace paddle {
namespace imperative {
......@@ -193,6 +194,16 @@ AmpOperators::GetMutableUnsupportedBf16Ops() {
return unsupported_bf16_ops_;
}
void AmpOperators::AddToAmpOpList(const std::string& op_name) {
if (FLAGS_low_precision_op_list) {
current_amp_ops_[op_name] += 1;
}
}
std::map<const std::string, int> AmpOperators::GetAmpOpList() {
return current_amp_ops_;
}
std::ostream& operator<<(std::ostream& os, AmpOperators& ops) {
os << "allow ops: ";
auto allow_ops = ops.GetMutableAllowOps();
......
......@@ -60,6 +60,10 @@ class AmpOperators {
std::shared_ptr<std::unordered_set<std::string>>
GetMutableUnsupportedBf16Ops();
void AddToAmpOpList(const std::string& op_name);
std::map<const std::string, int> GetAmpOpList();
private:
AmpOperators(); // forbid calling default constructor
......@@ -76,6 +80,9 @@ class AmpOperators {
// The set of ops that has no bf16 CUDA kennel.
std::shared_ptr<std::unordered_set<std::string>> unsupported_bf16_ops_;
// The amp op list of current module.
std::map<const std::string, int> current_amp_ops_;
};
std::ostream& operator<<(std::ostream& os, AmpOperators& ops);
......
......@@ -2545,6 +2545,10 @@ All parameter, weight, gradient are variables in Paddle.
m.def("update_autotune_status",
[] { return phi::autotune::AutoTuneStatus::Instance().Update(); });
m.def("get_low_precision_op_list", [] {
return paddle::imperative::AmpOperators::Instance().GetAmpOpList();
});
m.def("autotune_status", [] {
py::dict res;
phi::autotune::AutoTuneCache::Instance().UpdateStatus();
......
......@@ -52,6 +52,20 @@ PADDLE_DEFINE_EXPORTED_int32(paddle_num_threads,
1,
"Number of threads for each paddle instance.");
/**
* Low Precision Op related FLAG
* Name: FLAGS_low_precision_op_list
* Since Version: 0.13.0
* Value Range: bool, default=false
* Example:
* Note: Used to debug. Get the low precision op list of current module.
*/
PADDLE_DEFINE_EXPORTED_bool(low_precision_op_list,
false,
"Checking whether get the low precision op list of "
"current module. It will be "
"rerun the low precision list after module.");
/**
* Operator related FLAG
* Name: FLAGS_check_nan_inf
......
......@@ -110,6 +110,21 @@ PURE_BF16_BLACK_LIST = set()
_g_amp_state_ = None
def low_precision_op_list():
op_list = paddle.fluid.core.get_low_precision_op_list()
op_count = 0
print('<---------------- low precision op list ------------------->')
print('<---- op name ------|------- op count---------------------->')
for x in op_list:
print(' %-18s| %4d' % (x, op_list[x]))
op_count += 1
print(
'<------------- low precision op num:{:5d} ----------------->'.format(
op_count
)
)
def amp_state():
global _g_amp_state_
return _g_amp_state_
......
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import paddle
class TestAMPList(unittest.TestCase):
def test_main(self):
conv2d = paddle.nn.Conv2D(3, 2, 3, bias_attr=False)
data = paddle.rand([10, 3, 32, 32])
paddle.set_flags({'FLAGS_low_precision_op_list': 1})
a = paddle.rand([2, 3])
b = paddle.rand([2, 3])
# amp list conv2d, cast
with paddle.amp.auto_cast():
conv = conv2d(data)
c = a + b
paddle.fluid.dygraph.amp.auto_cast.low_precision_op_list()
op_list = paddle.fluid.core.get_low_precision_op_list()
print(conv.dtype)
if conv.dtype == paddle.float16:
self.assertTrue('elementwise_add' in op_list)
self.assertTrue('conv2d' in op_list)
self.assertTrue(2 == len(op_list))
else:
self.assertTrue(0 == len(op_list))
if __name__ == "__main__":
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册