Unverified commit 73544322 authored by Yiqun Liu, committed by GitHub

[AMP] Add python API for collecting operator stats. (#52215)

* [AMP] Add python API for collecting operator stats.

* Fix imports and polish code.

* Add more unittests.

* Add doc for the new APIs.
Parent 28927209
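In short, the commit adds three APIs under paddle.amp.debugging: enable_operator_stats_collection, disable_operator_stats_collection, and the collect_operator_stats context manager. A minimal usage sketch assembled from the docstrings added below (the printed table is omitted; exact counts depend on the model and AMP configuration):

import paddle

conv = paddle.nn.Conv2D(3, 2, 3)
x = paddle.rand([10, 3, 32, 32])

# Explicit enable/disable pair; the stats table is printed by the disable call.
paddle.amp.debugging.enable_operator_stats_collection()
with paddle.amp.auto_cast(enable=True, level='O2'):
    out = conv(x)
paddle.amp.debugging.disable_operator_stats_collection()

# Equivalent context-manager form; the stats table is printed on exit.
with paddle.amp.debugging.collect_operator_stats():
    with paddle.amp.auto_cast(enable=True, level='O2'):
        out = conv(x)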
@@ -2808,6 +2808,23 @@ All parameter, weight, gradient are variables in Paddle.
.def("is_pattern_enabled", &platform::ipu::IpuStrategy::IsPatternEnabled);
#endif
m.def("get_low_precision_op_list", [] {
py::dict op_list;
auto list_op = phi::KernelFactory::Instance().GetLowPrecisionKernelList();
for (auto iter = list_op.begin(); iter != list_op.end(); iter++) {
auto op_name = (iter->first).c_str();
auto counts = iter->second;
op_list[op_name] = std::to_string(counts.fp16_called_) + "," +
std::to_string(counts.bf16_called_) + "," +
std::to_string(counts.fp32_called_) + "," +
std::to_string(counts.other_called_);
}
return op_list;
});
m.def("clear_low_precision_op_list",
[] { phi::KernelFactory::Instance().ClearLowPrecisionKernelList(); });
m.def("enable_autotune", [] {
return phi::autotune::AutoTuneStatus::Instance().EnableAutoTune();
});
@@ -2824,20 +2841,6 @@ All parameter, weight, gradient are variables in Paddle.
m.def("update_autotune_status",
[] { return phi::autotune::AutoTuneStatus::Instance().Update(); });
m.def("get_low_precision_op_list", [] {
py::dict op_list;
auto list_op = phi::KernelFactory::Instance().GetLowPrecisionKernelList();
for (auto iter = list_op.begin(); iter != list_op.end(); iter++) {
auto op_name = (iter->first).c_str();
auto counts = iter->second;
op_list[op_name] = std::to_string(counts.fp16_called_) + "," +
std::to_string(counts.bf16_called_) + "," +
std::to_string(counts.fp32_called_) + "," +
std::to_string(counts.other_called_);
}
return op_list;
});
m.def("autotune_status", [] {
py::dict res;
phi::autotune::AutoTuneCache::Instance().UpdateStatus();
......
@@ -340,6 +340,8 @@ class KernelFactory {
std::map<const std::string, OpCount> GetLowPrecisionKernelList();
void ClearLowPrecisionKernelList() { low_precision_kernels_.clear(); }
private:
KernelFactory() = default;
......
@@ -16,7 +16,6 @@ from .auto_cast import auto_cast # noqa: F401
from .auto_cast import decorate # noqa: F401
from .auto_cast import amp_guard # noqa: F401
from .auto_cast import amp_decorate # noqa: F401
from .auto_cast import low_precision_op_list # noqa: F401
from .auto_cast import FP16_WHITE_LIST # noqa: F401
from .auto_cast import FP16_BLACK_LIST # noqa: F401
from .auto_cast import PURE_FP16_WHITE_LIST # noqa: F401
@@ -27,4 +26,6 @@ from .grad_scaler import GradScaler # noqa: F401
from .grad_scaler import AmpScaler # noqa: F401
from .grad_scaler import OptimizerState # noqa: F401
from . import debugging # noqa: F401
__all__ = ['auto_cast', 'GradScaler', 'decorate']
@@ -13,7 +13,6 @@
# limitations under the License.
import copy
import os
import warnings
import paddle
@@ -97,34 +96,6 @@ PURE_BF16_BLACK_LIST = set()
_g_amp_state_ = None
def low_precision_op_list():
if os.getenv("FLAGS_low_precision_op_list") is not None:
level = int(os.getenv("FLAGS_low_precision_op_list"))
print('<{:-^120}>'.format(" op list "))
op_list = paddle.fluid.core.get_low_precision_op_list()
op_count = 0
print(
'<{:-^40}'.format(" Op Name "),
'|',
'{:-^17}'.format("FP16 Calls"),
'|',
'{:-^17}'.format("BF16 Calls"),
'|',
'{:-^17}'.format('FP32 Calls'),
'|',
'{:-^17}>'.format('Other Calls'),
)
for x in op_list:
# fp16, bf16, fp32, other
called = op_list[x].split(",")
print(
' %-40s| %-17s| %-17s| %-17s| %-17s'
% (x, called[0], called[1], called[2], called[3])
)
op_count += 1
print('<{:-^120}>'.format(" op count: " + str(op_count) + " "))
def amp_state():
global _g_amp_state_
return _g_amp_state_
......
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import contextlib
import paddle
from paddle.fluid.framework import dygraph_only
__all__ = [
"enable_operator_stats_collection",
"disable_operator_stats_collection",
"collect_operator_stats",
]
def _get_operator_stats_flag():
flags = paddle.get_flags(["FLAGS_low_precision_op_list"])
return flags["FLAGS_low_precision_op_list"]
def _print_operator_stats(op_count_dict):
"""
Parse and print the stats of operators, i.e. the number of times each
operator was called with fp16, bf16, fp32 and other dtypes.
Args:
op_count_dict(dict): a dict that records the number of calls for each
operator and dtype, e.g.
{'conv2d': '1,0,0,0', 'elementwise_add': '1,0,0,0'} or
{'conv2d': [1, 0, 0, 0], 'elementwise_add': [1, 0, 0, 0]}.
"""
print("<{:-^120}>".format(" op list "))
total_ops = 0
print(
"<{:-^40}".format(" Op Name "),
"|",
"{:-^17}".format(" FP16 Calls "),
"|",
"{:-^17}".format(" BF16 Calls "),
"|",
"{:-^17}".format(" FP32 Calls "),
"|",
"{:-^17}>".format(" Other Calls "),
)
if op_count_dict is not None and isinstance(op_count_dict, dict):
for op_type in op_count_dict:
# fp16, bf16, fp32, other
value = op_count_dict[op_type]
if isinstance(value, list):
called = value
elif isinstance(value, str):
called = value.split(",")
else:
raise ValueError(
"Input {} is expected to be a list or a comma-separated str, but received {}.".format(
value, type(value)
)
)
print(
" %-40s| %-17s| %-17s| %-17s| %-17s"
% (op_type, called[0], called[1], called[2], called[3])
)
total_ops += 1
print("<{:-^120}>\n".format(" op count: " + str(total_ops) + " "))
@dygraph_only
def enable_operator_stats_collection():
"""
Enable the collection of operator call counts for different data types.
The statistics are categorized into four data types, namely float32,
float16, bfloat16 and others. This function is used together with the
corresponding disable function.
Examples:
.. code-block:: python
import paddle
conv = paddle.nn.Conv2D(3, 2, 3)
x = paddle.rand([10, 3, 32, 32])
paddle.amp.debugging.enable_operator_stats_collection()
# AMP list including conv2d, elementwise_add, reshape2, cast (transfer_dtype)
with paddle.amp.auto_cast(enable=True, level='O2'):
out = conv(x)
# Print to the standard output.
paddle.amp.debugging.disable_operator_stats_collection()
# <------------------------------------------------------- op list -------------------------------------------------------->
# <--------------- Op Name ---------------- | -- FP16 Calls --- | -- BF16 Calls --- | -- FP32 Calls --- | -- Other Calls -->
# conv2d | 1 | 0 | 0 | 0
# elementwise_add | 1 | 0 | 0 | 0
# reshape2 | 1 | 0 | 0 | 0
# transfer_dtype | 0 | 0 | 3 | 0
# <----------------------------------------------------- op count: 4 ------------------------------------------------------>
"""
# Clear the previous stats.
paddle.fluid.core.clear_low_precision_op_list()
paddle.set_flags({'FLAGS_low_precision_op_list': 1})
@dygraph_only
def disable_operator_stats_collection():
"""
Disable the collection of operator call counts for different data types.
This function is used together with the corresponding enable function.
The statistics are categorized into four data types, namely float32,
float16, bfloat16 and others, and are printed when this function is
called.
Examples:
.. code-block:: python
import paddle
conv = paddle.nn.Conv2D(3, 2, 3)
x = paddle.rand([10, 3, 32, 32])
paddle.amp.debugging.enable_operator_stats_collection()
# AMP list including conv2d, elementwise_add, reshape2, cast (transfer_dtype)
with paddle.amp.auto_cast(enable=True, level='O2'):
out = conv(x)
# Print to the standard output.
paddle.amp.debugging.disable_operator_stats_collection()
# <------------------------------------------------------- op list -------------------------------------------------------->
# <--------------- Op Name ---------------- | -- FP16 Calls --- | -- BF16 Calls --- | -- FP32 Calls --- | -- Other Calls -->
# conv2d | 1 | 0 | 0 | 0
# elementwise_add | 1 | 0 | 0 | 0
# reshape2 | 1 | 0 | 0 | 0
# transfer_dtype | 0 | 0 | 3 | 0
# <----------------------------------------------------- op count: 4 ------------------------------------------------------>
"""
if not _get_operator_stats_flag():
return
op_count_dict = paddle.fluid.core.get_low_precision_op_list()
_print_operator_stats(op_count_dict)
paddle.set_flags({'FLAGS_low_precision_op_list': 0})
@dygraph_only
@contextlib.contextmanager
def collect_operator_stats():
"""
A context manager that enables the collection of operator call counts
for different data types. The statistics are categorized into four data
types, namely float32, float16, bfloat16 and others, and are printed
when the context exits.
Examples:
.. code-block:: python
import paddle
conv = paddle.nn.Conv2D(3, 2, 3)
x = paddle.rand([10, 3, 32, 32])
with paddle.amp.debugging.collect_operator_stats():
# AMP list including conv2d, elementwise_add, reshape2, cast (transfer_dtype)
with paddle.amp.auto_cast(enable=True, level='O2'):
out = conv(x)
# Print to the standard output.
# <------------------------------------------------------- op list -------------------------------------------------------->
# <--------------- Op Name ---------------- | -- FP16 Calls --- | -- BF16 Calls --- | -- FP32 Calls --- | -- Other Calls -->
# conv2d | 1 | 0 | 0 | 0
# elementwise_add | 1 | 0 | 0 | 0
# reshape2 | 1 | 0 | 0 | 0
# transfer_dtype | 0 | 0 | 3 | 0
# <----------------------------------------------------- op count: 4 ------------------------------------------------------>
"""
enable_operator_stats_collection()
try:
    yield
finally:
    # Ensure the stats are printed and collection is disabled even if the
    # guarded code raises.
    disable_operator_stats_collection()
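Since _print_operator_stats accepts either list values or comma-separated strings, the formatter can be exercised in isolation; a small sketch with hypothetical counts (order: fp16, bf16, fp32, other):

from paddle.amp.debugging import _print_operator_stats

# Hypothetical counts; both accepted value forms are shown.
_print_operator_stats({'conv2d': '1,0,0,0', 'elementwise_add': [1, 0, 0, 0]})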
@@ -53,6 +53,7 @@ function(bash_test_modules TARGET_NAME)
endfunction()
if(WITH_TESTING)
add_subdirectory(amp)
add_subdirectory(asp)
# add_subdirectory(auto_parallel)
add_subdirectory(autograd)
......
file(
GLOB TEST_OPS
RELATIVE "${CMAKE_CURRENT_SOURCE_DIR}"
"test_*.py")
string(REPLACE ".py" "" TEST_OPS "${TEST_OPS}")
function(py_test_modules TARGET_NAME)
if(WITH_TESTING)
set(options SERIAL)
set(oneValueArgs "")
set(multiValueArgs MODULES DEPS ENVS)
cmake_parse_arguments(py_test_modules "${options}" "${oneValueArgs}"
"${multiValueArgs}" ${ARGN})
if(WITH_COVERAGE AND NOT (WITH_INCREMENTAL_COVERAGE
AND "$ENV{PADDLE_GIT_DIFF_PY_FILE}" STREQUAL ""))
if(WITH_ASCEND_CL)
add_test(
NAME ${TARGET_NAME}
COMMAND
${CMAKE_COMMAND} -E env
PYTHONPATH=${PADDLE_BINARY_DIR}/python:$ENV{PYTHONPATH}
${py_test_modules_ENVS}
COVERAGE_FILE=${PADDLE_BINARY_DIR}/python-coverage.data
${PYTHON_EXECUTABLE} -m coverage run --branch -p
${PADDLE_SOURCE_DIR}/tools/test_runner.py ${py_test_modules_MODULES}
WORKING_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR})
else()
add_test(
NAME ${TARGET_NAME}
COMMAND
${CMAKE_COMMAND} -E env PYTHONPATH=${PADDLE_BINARY_DIR}/python
${py_test_modules_ENVS}
COVERAGE_FILE=${PADDLE_BINARY_DIR}/python-coverage.data
${PYTHON_EXECUTABLE} -m coverage run --branch -p
${PADDLE_SOURCE_DIR}/tools/test_runner.py ${py_test_modules_MODULES}
WORKING_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR})
endif()
else()
if(WITH_ASCEND_CL)
add_test(
NAME ${TARGET_NAME}
COMMAND
${CMAKE_COMMAND} -E env
PYTHONPATH=${PADDLE_BINARY_DIR}/python:$ENV{PYTHONPATH}
${py_test_modules_ENVS} ${PYTHON_EXECUTABLE}
${PADDLE_SOURCE_DIR}/tools/test_runner.py ${py_test_modules_MODULES}
WORKING_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR})
else()
add_test(
NAME ${TARGET_NAME}
COMMAND
${CMAKE_COMMAND} -E env PYTHONPATH=${PADDLE_BINARY_DIR}/python
${py_test_modules_ENVS} ${PYTHON_EXECUTABLE}
${PADDLE_SOURCE_DIR}/tools/test_runner.py ${py_test_modules_MODULES}
WORKING_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR})
endif()
endif()
if(py_test_modules_SERIAL)
set_property(TEST ${TARGET_NAME} PROPERTY RUN_SERIAL 1)
endif()
if(WIN32)
set_tests_properties(${TARGET_NAME} PROPERTIES TIMEOUT 150)
endif()
endif()
endfunction()
foreach(TEST_OP ${TEST_OPS})
py_test_modules(${TEST_OP} MODULES ${TEST_OP})
endforeach()
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -11,24 +11,15 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import paddle
class TestAMPList(unittest.TestCase):
def test_main(self):
conv2d = paddle.nn.Conv2D(3, 2, 3, bias_attr=False)
data = paddle.rand([10, 3, 32, 32])
paddle.set_flags({'FLAGS_low_precision_op_list': 1})
a = paddle.rand([2, 3])
b = paddle.rand([2, 3])
# amp list conv2d, cast
with paddle.amp.auto_cast(enable=True, level='O2'):
conv = conv2d(data)
c = a + b
paddle.amp.low_precision_op_list()
def _check_result(self, dtype):
# The op stats are returned as a dict.
op_list = paddle.fluid.core.get_low_precision_op_list()
self.assertTrue('elementwise_add' in op_list)
@@ -45,10 +36,34 @@ class TestAMPList(unittest.TestCase):
self.assertTrue(conv_num == 1)
self.assertTrue(add_num == 1)
if conv.dtype == "float16":
if dtype == "float16":
self.assertTrue(int(conv2d_called[0]) == 1)
self.assertTrue(int(add_called[0]) == 1)
def test_enable_disable(self):
conv = paddle.nn.Conv2D(3, 2, 3)
x = paddle.rand([10, 3, 32, 32])
paddle.amp.debugging.enable_operator_stats_collection()
# amp list conv2d, elementwise_add, cast (transfer_dtype)
with paddle.amp.auto_cast(enable=True, level='O2'):
out = conv(x)
# Print to the standard output.
paddle.amp.debugging.disable_operator_stats_collection()
self._check_result(dtype=out.dtype)
def test_context(self):
conv = paddle.nn.Conv2D(3, 2, 3)
x = paddle.rand([10, 3, 32, 32])
with paddle.amp.debugging.collect_operator_stats():
# amp list conv2d, elementwise_add, cast (transfer_dtype)
with paddle.amp.auto_cast(enable=True, level='O2'):
out = conv(x)
self._check_result(dtype=out.dtype)
if __name__ == "__main__":
unittest.main()