提交 0293d58a 编写于 作者: M Megvii Engine Team 提交者: Xinran Xu

feat(mge): add bfloat16 support

GitOrigin-RevId: a942ce67915dac6b1203c974daee7ad8d0902e0a
上级 c3d5b61f
......@@ -29,6 +29,7 @@
#define MEGDNN_FLOAT16_SELECT(_x, _y) _y
#else
#include "megdnn/dtype/half.hpp"
#include "megdnn/dtype/bfloat16.hpp"
#define MEGDNN_INC_FLOAT16(_x) _x
#define MEGDNN_FLOAT16_SELECT(_x, _y) _x
#endif
......@@ -49,6 +50,7 @@ namespace megdnn {
cb(IntB4) \
cb(Byte) \
MEGDNN_INC_FLOAT16(cb(Float16)) \
MEGDNN_INC_FLOAT16(cb(BFloat16)) \
cb(UintB4) \
/*!
......@@ -62,6 +64,7 @@ namespace megdnn {
cb(Int32) \
cb(Byte) \
MEGDNN_INC_FLOAT16(cb(Float16)) \
MEGDNN_INC_FLOAT16(cb(BFloat16)) \
/*!
* \brief iterate through each fractional byte dtype
......@@ -101,6 +104,7 @@ namespace megdnn {
#define MEGDNN_FOREACH_COMPUTING_DTYPE_FLOAT(cb) \
cb(::megdnn::dtype::Float32) \
MEGDNN_INC_FLOAT16(cb(::megdnn::dtype::Float16)) \
MEGDNN_INC_FLOAT16(cb(::megdnn::dtype::BFloat16)) \
/*!
* \brief iterate through each dtype object that can be involved in integer
......@@ -345,6 +349,7 @@ typedef int16_t dt_int16;
typedef int8_t dt_int8;
typedef uint8_t dt_uint8;
MEGDNN_INC_FLOAT16(typedef half_float::half dt_float16;)
MEGDNN_INC_FLOAT16(typedef half_bfloat16::bfloat16 dt_bfloat16;)
#define MEGDNN_PARAMETERIZED_DTYPE_ENUM_BASE 100000
#if MEGDNN_CC_HOST
......@@ -367,6 +372,9 @@ MEGDNN_INC_FLOAT16(typedef half_float::half dt_float16;)
Float16,
#endif
UintB4 = 10,
#if !MEGDNN_DISABLE_FLOAT16
BFloat16 = 11,
#endif
#define FST(_name) _name = MEGDNN_PARAMETERIZED_DTYPE_ENUM_BASE,
#define D(_name) _name,
......@@ -702,6 +710,9 @@ MEGDNN_DEF_DT(Uint8, dt_uint8, INT, UNSIGNED, 0, UINT8_MAX);
MEGDNN_INC_FLOAT16(MEGDNN_DEF_DT(Float16, dt_float16, FLOAT, SIGNED,
std::numeric_limits<dt_float16>::lowest(),
std::numeric_limits<dt_float16>::max()));
MEGDNN_INC_FLOAT16(MEGDNN_DEF_DT(BFloat16, dt_bfloat16, FLOAT, SIGNED,
std::numeric_limits<dt_bfloat16>::lowest(),
std::numeric_limits<dt_bfloat16>::max()));
template <>
struct DTypeTrait<dtype::Byte> {
......
此差异已折叠。
......@@ -50,167 +50,7 @@
#include <hip/hip_fp16.h>
#endif
/// Combined gcc version number.
#define HALF_GNUC_VERSION (__GNUC__*100+__GNUC_MINOR__)
//check C++11 language features
#if defined(__clang__) //clang
#if __has_feature(cxx_static_assert) && !defined(HALF_ENABLE_CPP11_STATIC_ASSERT)
#define HALF_ENABLE_CPP11_STATIC_ASSERT 1
#endif
#if __has_feature(cxx_constexpr) && !defined(HALF_ENABLE_CPP11_CONSTEXPR)
#define HALF_ENABLE_CPP11_CONSTEXPR 1
#endif
#if __has_feature(cxx_noexcept) && !defined(HALF_ENABLE_CPP11_NOEXCEPT)
#define HALF_ENABLE_CPP11_NOEXCEPT 1
#endif
#if __has_feature(cxx_user_literals) && !defined(HALF_ENABLE_CPP11_USER_LITERALS)
#define HALF_ENABLE_CPP11_USER_LITERALS 1
#endif
#if (defined(__GXX_EXPERIMENTAL_CXX0X__) || __cplusplus >= 201103L) && !defined(HALF_ENABLE_CPP11_LONG_LONG)
#define HALF_ENABLE_CPP11_LONG_LONG 1
#endif
/*#elif defined(__INTEL_COMPILER) //Intel C++
#if __INTEL_COMPILER >= 1100 && !defined(HALF_ENABLE_CPP11_STATIC_ASSERT) ????????
#define HALF_ENABLE_CPP11_STATIC_ASSERT 1
#endif
#if __INTEL_COMPILER >= 1300 && !defined(HALF_ENABLE_CPP11_CONSTEXPR) ????????
#define HALF_ENABLE_CPP11_CONSTEXPR 1
#endif
#if __INTEL_COMPILER >= 1300 && !defined(HALF_ENABLE_CPP11_NOEXCEPT) ????????
#define HALF_ENABLE_CPP11_NOEXCEPT 1
#endif
#if __INTEL_COMPILER >= 1100 && !defined(HALF_ENABLE_CPP11_LONG_LONG) ????????
#define HALF_ENABLE_CPP11_LONG_LONG 1
#endif*/
#elif defined(__GNUC__) //gcc
#if defined(__GXX_EXPERIMENTAL_CXX0X__) || __cplusplus >= 201103L
#if HALF_GNUC_VERSION >= 403 && !defined(HALF_ENABLE_CPP11_STATIC_ASSERT)
#define HALF_ENABLE_CPP11_STATIC_ASSERT 1
#endif
#if HALF_GNUC_VERSION >= 406 && !defined(HALF_ENABLE_CPP11_CONSTEXPR)
#define HALF_ENABLE_CPP11_CONSTEXPR 1
#endif
#if HALF_GNUC_VERSION >= 406 && !defined(HALF_ENABLE_CPP11_NOEXCEPT)
#define HALF_ENABLE_CPP11_NOEXCEPT 1
#endif
#if HALF_GNUC_VERSION >= 407 && !defined(HALF_ENABLE_CPP11_USER_LITERALS)
#define HALF_ENABLE_CPP11_USER_LITERALS 1
#endif
#if !defined(HALF_ENABLE_CPP11_LONG_LONG)
#define HALF_ENABLE_CPP11_LONG_LONG 1
#endif
#endif
#elif defined(_MSC_VER) //Visual C++
#if _MSC_VER >= 1600 && !defined(HALF_ENABLE_CPP11_STATIC_ASSERT)
#define HALF_ENABLE_CPP11_STATIC_ASSERT 1
#endif
#if _MSC_VER >= 1310 && !defined(HALF_ENABLE_CPP11_LONG_LONG)
#define HALF_ENABLE_CPP11_LONG_LONG 1
#endif
#define HALF_POP_WARNINGS 1
#pragma warning(push)
//! 4521 and 4522 are multiple copy/assignment operators specified
#pragma warning(disable : 4099 4127 4146 4521 4522) //struct vs class, constant in if, negative unsigned
#endif
//check C++11 library features
#include <utility>
#if defined(_LIBCPP_VERSION) //libc++
#if defined(__GXX_EXPERIMENTAL_CXX0X__) || __cplusplus >= 201103
#ifndef HALF_ENABLE_CPP11_TYPE_TRAITS
#define HALF_ENABLE_CPP11_TYPE_TRAITS 1
#endif
#ifndef HALF_ENABLE_CPP11_CSTDINT
#define HALF_ENABLE_CPP11_CSTDINT 1
#endif
#ifndef HALF_ENABLE_CPP11_CMATH
#define HALF_ENABLE_CPP11_CMATH 1
#endif
#ifndef HALF_ENABLE_CPP11_HASH
#define HALF_ENABLE_CPP11_HASH 1
#endif
#endif
#elif defined(__GLIBCXX__) //libstdc++
#if defined(__GXX_EXPERIMENTAL_CXX0X__) || __cplusplus >= 201103
#ifdef __clang__
#if __GLIBCXX__ >= 20080606 && !defined(HALF_ENABLE_CPP11_TYPE_TRAITS)
#define HALF_ENABLE_CPP11_TYPE_TRAITS 1
#endif
#if __GLIBCXX__ >= 20080606 && !defined(HALF_ENABLE_CPP11_CSTDINT)
#define HALF_ENABLE_CPP11_CSTDINT 1
#endif
#if __GLIBCXX__ >= 20080606 && !defined(HALF_ENABLE_CPP11_CMATH)
#define HALF_ENABLE_CPP11_CMATH 1
#endif
#if __GLIBCXX__ >= 20080606 && !defined(HALF_ENABLE_CPP11_HASH)
#define HALF_ENABLE_CPP11_HASH 1
#endif
#else
#if HALF_GNUC_VERSION >= 403 && !defined(HALF_ENABLE_CPP11_CSTDINT)
#define HALF_ENABLE_CPP11_CSTDINT 1
#endif
#if HALF_GNUC_VERSION >= 403 && !defined(HALF_ENABLE_CPP11_CMATH)
#define HALF_ENABLE_CPP11_CMATH 1
#endif
#if HALF_GNUC_VERSION >= 403 && !defined(HALF_ENABLE_CPP11_HASH)
#define HALF_ENABLE_CPP11_HASH 1
#endif
#endif
#endif
#elif defined(_CPPLIB_VER) //Dinkumware/Visual C++
#if _CPPLIB_VER >= 520
#ifndef HALF_ENABLE_CPP11_TYPE_TRAITS
#define HALF_ENABLE_CPP11_TYPE_TRAITS 1
#endif
#ifndef HALF_ENABLE_CPP11_CSTDINT
#define HALF_ENABLE_CPP11_CSTDINT 1
#endif
#ifndef HALF_ENABLE_CPP11_HASH
#define HALF_ENABLE_CPP11_HASH 1
#endif
#endif
#if _CPPLIB_VER >= 610
#ifndef HALF_ENABLE_CPP11_CMATH
#define HALF_ENABLE_CPP11_CMATH 1
#endif
#endif
#endif
#undef HALF_GNUC_VERSION
//support constexpr
#if HALF_ENABLE_CPP11_CONSTEXPR
#define HALF_CONSTEXPR constexpr
#define HALF_CONSTEXPR_CONST constexpr
#else
#define HALF_CONSTEXPR
#define HALF_CONSTEXPR_CONST const
#endif
//support noexcept
#if HALF_ENABLE_CPP11_NOEXCEPT
#define HALF_NOEXCEPT noexcept
#define HALF_NOTHROW noexcept
#else
#define HALF_NOEXCEPT
#define HALF_NOTHROW throw()
#endif
#include <algorithm>
#include <limits>
#include <climits>
#include <cmath>
#include <cstring>
#if HALF_ENABLE_CPP11_TYPE_TRAITS
#include <type_traits>
#endif
#if HALF_ENABLE_CPP11_CSTDINT
#include <cstdint>
#endif
#if HALF_ENABLE_CPP11_HASH
#include <functional>
#endif
#include "megdnn/dtype/half_common_prologue.h"
/// Default rounding mode.
/// This specifies the rounding mode used for all conversions between [half](\ref half_float::half)s and `float`s as well as
......@@ -3141,16 +2981,7 @@ namespace std
#endif
}
#undef HALF_CONSTEXPR
#undef HALF_CONSTEXPR_CONST
#undef HALF_NOEXCEPT
#undef HALF_NOTHROW
#ifdef HALF_POP_WARNINGS
#pragma warning(pop)
#undef HALF_POP_WARNINGS
#endif
#include "megdnn/dtype/half_common_epilogue.h"
#endif
// vim: syntax=cpp.doxygen
/**
* half - IEEE 754-based half-precision floating point library.
*
* Copyright (c) 2012-2013 Christian Rau <rauy@users.sourceforge.net>
*
* Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation
* files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy,
* modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the
* Software is furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE
* WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
* COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE,
* ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*
* Version 1.11.0
* \file
* Main header file for half precision functionality.
*
* --------------------------------------------------------------------------
* \file include/megdnn/dtype/half_common_epilogue.h
*
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*
* This file has been modified by Megvii ("Megvii Modifications").
* All Megvii Modifications are Copyright (C) 2014-2019 Megvii Inc. All rights reserved.
*
* --------------------------------------------------------------------------
*/
#undef HALF_CONSTEXPR
#undef HALF_CONSTEXPR_CONST
#undef HALF_NOEXCEPT
#undef HALF_NOTHROW
#ifdef HALF_POP_WARNINGS
#pragma warning(pop)
#undef HALF_POP_WARNINGS
#endif
// vim: syntax=cpp.doxygen
/**
* half - IEEE 754-based half-precision floating point library.
*
* Copyright (c) 2012-2013 Christian Rau <rauy@users.sourceforge.net>
*
* Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation
* files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy,
* modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the
* Software is furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE
* WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
* COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE,
* ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*
* Version 1.11.0
* \file
* Main header file for half precision functionality.
*
* --------------------------------------------------------------------------
* \file dnn/include/megdnn/dtype/half_common_prologue.h
*
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*
* This file has been modified by Megvii ("Megvii Modifications").
* All Megvii Modifications are Copyright (C) 2014-2019 Megvii Inc. All rights reserved.
*
* --------------------------------------------------------------------------
*/
#include "megdnn/arch.h"
/// Combined gcc version number.
#define HALF_GNUC_VERSION (__GNUC__*100+__GNUC_MINOR__)
//check C++11 language features
#if defined(__clang__) //clang
#if __has_feature(cxx_static_assert) && !defined(HALF_ENABLE_CPP11_STATIC_ASSERT)
#define HALF_ENABLE_CPP11_STATIC_ASSERT 1
#endif
#if __has_feature(cxx_constexpr) && !defined(HALF_ENABLE_CPP11_CONSTEXPR)
#define HALF_ENABLE_CPP11_CONSTEXPR 1
#endif
#if __has_feature(cxx_noexcept) && !defined(HALF_ENABLE_CPP11_NOEXCEPT)
#define HALF_ENABLE_CPP11_NOEXCEPT 1
#endif
#if __has_feature(cxx_user_literals) && !defined(HALF_ENABLE_CPP11_USER_LITERALS)
#define HALF_ENABLE_CPP11_USER_LITERALS 1
#endif
#if (defined(__GXX_EXPERIMENTAL_CXX0X__) || __cplusplus >= 201103L) && !defined(HALF_ENABLE_CPP11_LONG_LONG)
#define HALF_ENABLE_CPP11_LONG_LONG 1
#endif
/*#elif defined(__INTEL_COMPILER) //Intel C++
#if __INTEL_COMPILER >= 1100 && !defined(HALF_ENABLE_CPP11_STATIC_ASSERT) ????????
#define HALF_ENABLE_CPP11_STATIC_ASSERT 1
#endif
#if __INTEL_COMPILER >= 1300 && !defined(HALF_ENABLE_CPP11_CONSTEXPR) ????????
#define HALF_ENABLE_CPP11_CONSTEXPR 1
#endif
#if __INTEL_COMPILER >= 1300 && !defined(HALF_ENABLE_CPP11_NOEXCEPT) ????????
#define HALF_ENABLE_CPP11_NOEXCEPT 1
#endif
#if __INTEL_COMPILER >= 1100 && !defined(HALF_ENABLE_CPP11_LONG_LONG) ????????
#define HALF_ENABLE_CPP11_LONG_LONG 1
#endif*/
#elif defined(__GNUC__) //gcc
#if defined(__GXX_EXPERIMENTAL_CXX0X__) || __cplusplus >= 201103L
#if HALF_GNUC_VERSION >= 403 && !defined(HALF_ENABLE_CPP11_STATIC_ASSERT)
#define HALF_ENABLE_CPP11_STATIC_ASSERT 1
#endif
#if HALF_GNUC_VERSION >= 406 && !defined(HALF_ENABLE_CPP11_CONSTEXPR)
#define HALF_ENABLE_CPP11_CONSTEXPR 1
#endif
#if HALF_GNUC_VERSION >= 406 && !defined(HALF_ENABLE_CPP11_NOEXCEPT)
#define HALF_ENABLE_CPP11_NOEXCEPT 1
#endif
#if HALF_GNUC_VERSION >= 407 && !defined(HALF_ENABLE_CPP11_USER_LITERALS)
#define HALF_ENABLE_CPP11_USER_LITERALS 1
#endif
#if !defined(HALF_ENABLE_CPP11_LONG_LONG)
#define HALF_ENABLE_CPP11_LONG_LONG 1
#endif
#endif
#elif defined(_MSC_VER) //Visual C++
#if _MSC_VER >= 1600 && !defined(HALF_ENABLE_CPP11_STATIC_ASSERT)
#define HALF_ENABLE_CPP11_STATIC_ASSERT 1
#endif
#if _MSC_VER >= 1310 && !defined(HALF_ENABLE_CPP11_LONG_LONG)
#define HALF_ENABLE_CPP11_LONG_LONG 1
#endif
#define HALF_POP_WARNINGS 1
#pragma warning(push)
//! 4521 and 4522 are multiple copy/assignment operators specified
#pragma warning(disable : 4099 4127 4146 4521 4522) //struct vs class, constant in if, negative unsigned
#endif
//check C++11 library features
#include <utility>
#if defined(_LIBCPP_VERSION) //libc++
#if defined(__GXX_EXPERIMENTAL_CXX0X__) || __cplusplus >= 201103
#ifndef HALF_ENABLE_CPP11_TYPE_TRAITS
#define HALF_ENABLE_CPP11_TYPE_TRAITS 1
#endif
#ifndef HALF_ENABLE_CPP11_CSTDINT
#define HALF_ENABLE_CPP11_CSTDINT 1
#endif
#ifndef HALF_ENABLE_CPP11_CMATH
#define HALF_ENABLE_CPP11_CMATH 1
#endif
#ifndef HALF_ENABLE_CPP11_HASH
#define HALF_ENABLE_CPP11_HASH 1
#endif
#endif
#elif defined(__GLIBCXX__) //libstdc++
#if defined(__GXX_EXPERIMENTAL_CXX0X__) || __cplusplus >= 201103
#ifdef __clang__
#if __GLIBCXX__ >= 20080606 && !defined(HALF_ENABLE_CPP11_TYPE_TRAITS)
#define HALF_ENABLE_CPP11_TYPE_TRAITS 1
#endif
#if __GLIBCXX__ >= 20080606 && !defined(HALF_ENABLE_CPP11_CSTDINT)
#define HALF_ENABLE_CPP11_CSTDINT 1
#endif
#if __GLIBCXX__ >= 20080606 && !defined(HALF_ENABLE_CPP11_CMATH)
#define HALF_ENABLE_CPP11_CMATH 1
#endif
#if __GLIBCXX__ >= 20080606 && !defined(HALF_ENABLE_CPP11_HASH)
#define HALF_ENABLE_CPP11_HASH 1
#endif
#else
#if HALF_GNUC_VERSION >= 403 && !defined(HALF_ENABLE_CPP11_CSTDINT)
#define HALF_ENABLE_CPP11_CSTDINT 1
#endif
#if HALF_GNUC_VERSION >= 403 && !defined(HALF_ENABLE_CPP11_CMATH)
#define HALF_ENABLE_CPP11_CMATH 1
#endif
#if HALF_GNUC_VERSION >= 403 && !defined(HALF_ENABLE_CPP11_HASH)
#define HALF_ENABLE_CPP11_HASH 1
#endif
#endif
#endif
#elif defined(_CPPLIB_VER) //Dinkumware/Visual C++
#if _CPPLIB_VER >= 520
#ifndef HALF_ENABLE_CPP11_TYPE_TRAITS
#define HALF_ENABLE_CPP11_TYPE_TRAITS 1
#endif
#ifndef HALF_ENABLE_CPP11_CSTDINT
#define HALF_ENABLE_CPP11_CSTDINT 1
#endif
#ifndef HALF_ENABLE_CPP11_HASH
#define HALF_ENABLE_CPP11_HASH 1
#endif
#endif
#if _CPPLIB_VER >= 610
#ifndef HALF_ENABLE_CPP11_CMATH
#define HALF_ENABLE_CPP11_CMATH 1
#endif
#endif
#endif
#undef HALF_GNUC_VERSION
//support constexpr
#if HALF_ENABLE_CPP11_CONSTEXPR
#define HALF_CONSTEXPR constexpr
#define HALF_CONSTEXPR_CONST constexpr
#else
#define HALF_CONSTEXPR
#define HALF_CONSTEXPR_CONST const
#endif
//support noexcept
#if HALF_ENABLE_CPP11_NOEXCEPT
#define HALF_NOEXCEPT noexcept
#define HALF_NOTHROW noexcept
#else
#define HALF_NOEXCEPT
#define HALF_NOTHROW throw()
#endif
#include <algorithm>
#include <limits>
#include <climits>
#include <cmath>
#include <cstring>
#if HALF_ENABLE_CPP11_TYPE_TRAITS
#include <type_traits>
#endif
#if HALF_ENABLE_CPP11_CSTDINT
#include <cstdint>
#endif
#if HALF_ENABLE_CPP11_HASH
#include <functional>
#endif
// vim: syntax=cpp.doxygen
......@@ -30,7 +30,7 @@ def main():
w('// generated by gen_cond_take_kern_impls.py')
w('#include "../kern.inl"')
w('')
if dtype == 'dt_float16':
if dtype == 'dt_float16' or dtype == 'dt_bfloat16':
w('#if !MEGDNN_DISABLE_FLOAT16')
w('namespace megdnn {')
w('namespace cuda {')
......@@ -48,7 +48,7 @@ def main():
w('} // cond_take')
w('} // cuda')
w('} // megdnn')
if dtype == 'dt_float16':
if dtype == 'dt_float16' or dtype == 'dt_bfloat16':
w('#endif')
print('generated {}'.format(fname))
......
......@@ -34,7 +34,7 @@ def main():
w = lambda s: print(s, file=fout)
w('// generated by gen_elemwise_kern_impls.py')
if ctype == 'dt_float16':
if ctype == 'dt_float16' or ctype == 'dt_bfloat16':
w('#if !MEGDNN_DISABLE_FLOAT16')
w('#define KERN_IMPL_MODE(cb) {}'.format(formode))
......@@ -42,7 +42,7 @@ def main():
w('#define KERN_IMPL_CTYPE {}'.format(ctype))
w('#include "../kern_impl.inl"')
if ctype == 'dt_float16':
if ctype == 'dt_float16' or ctype == 'dt_bfloat16':
w('#endif')
print('generated {}'.format(fname))
......
......@@ -30,14 +30,14 @@ def main():
w = lambda s: print(s, file=fout)
w('// generated by gen_elemwise_special_kern_impls.py')
if dtype == 'dt_float16':
if dtype == 'dt_float16' or dtype == 'dt_bfloat16':
w('#if !MEGDNN_DISABLE_FLOAT16')
w('#include "../special_kerns.inl"')
w('INST(::megdnn::dtype::{})'.format(DTYPES[dtype][0]))
w('#undef INST')
w('}')
w('}')
if dtype == 'dt_float16':
if dtype == 'dt_float16' or dtype == 'dt_bfloat16':
w('#endif')
print('generated {}'.format(fname))
......
......@@ -6,7 +6,8 @@ DTYPES = {'dt_int32': ('Int32', 'INT'),
'dt_int8': ('Int8', 'INT'),
'dt_int16': ('Int16', 'INT'),
'dt_float32': ('Float32', 'FLOAT'),
'dt_float16': ('Float16', 'FLOAT')
'dt_float16': ('Float16', 'FLOAT'),
'dt_bfloat16': ('BFloat16', 'FLOAT')
}
MODES = {
......
......@@ -618,9 +618,10 @@ void ConvolutionBase<Parameter>::check_or_deduce_dtype_fwd(DType src,
megdnn_assert(param().compute_mode != Param::ComputeMode::FLOAT32
#if !MEGDNN_DISABLE_FLOAT16
|| src.enumv() == DTypeEnum::Float16
|| src.enumv() == DTypeEnum::BFloat16
#endif
,
"ComputeMode::FLOAT32 is only available for Float16 "
,
"ComputeMode::FLOAT32 is only available for Float16/BFloat16 "
"input / output.");
}
......@@ -1036,9 +1037,10 @@ void ConvolutionBackwardData::deduce_dtype(DType filter, DType diff,
megdnn_assert(param().compute_mode != Param::ComputeMode::FLOAT32
#if !MEGDNN_DISABLE_FLOAT16
|| filter.enumv() == DTypeEnum::Float16
|| filter.enumv() == DTypeEnum::BFloat16
#endif
,
"ComputeMode::FLOAT32 is only available for Float16 "
,
"ComputeMode::FLOAT32 is only available for Float16/BFloat16 "
"input / output.");
}
......
......@@ -87,7 +87,8 @@ namespace megdnn {
//! define kernel for all float types
#define DEF_KERN_FLOAT(_mode, _imp) \
DEF_KERN(dt_float32, _mode, _imp); \
MEGDNN_INC_FLOAT16(DEF_KERN(dt_float16, _mode, _imp);)
MEGDNN_INC_FLOAT16(DEF_KERN(dt_float16, _mode, _imp);) \
MEGDNN_INC_FLOAT16(DEF_KERN(dt_bfloat16, _mode, _imp);)
//! define kernel for all int types
#define DEF_KERN_INT(_mode, _imp) \
......
......@@ -69,11 +69,11 @@ void MatrixMulForward::deduce_layout(const TensorLayout& A,
C = TensorLayout(TensorShape({A0, B1}), C.dtype);
} else {
auto do_deduce = [&](size_t pack_size) {
megdnn_assert(
A.ndim == 4 && B.ndim == 3,
"matmul requires input dimension to be A(4), B(3); get: %s %s",
A.TensorShape::to_string().c_str(),
B.TensorShape::to_string().c_str());
megdnn_assert(A.ndim == 4 && B.ndim == 3,
"matmul requires input dimension to be A(4), B(3); "
"get: %s %s",
A.TensorShape::to_string().c_str(),
B.TensorShape::to_string().c_str());
A0 = A.shape[0];
A1 = A.shape[1];
B0 = B.shape[0];
......@@ -82,11 +82,11 @@ void MatrixMulForward::deduce_layout(const TensorLayout& A,
std::swap(A0, A1);
if (m_param.transposeB)
std::swap(B0, B1);
megdnn_assert(
A1 == B0,
"shape mismatch in matmal: (transposed) A is (%zu,%zu,4,4), "
"(transposed) B is (%zu,%zu,4)",
A0, A1, B0, B1);
megdnn_assert(A1 == B0,
"shape mismatch in matmal: (transposed) A is "
"(%zu,%zu,4,4), "
"(transposed) B is (%zu,%zu,4)",
A0, A1, B0, B1);
C = TensorLayout(TensorShape({A0, B1, pack_size}), C.dtype);
};
do_deduce(pack_size(param().format));
......@@ -172,8 +172,9 @@ void MatrixMulForward::check_exec(const TensorLayout& A, const TensorLayout& B,
}
megdnn_assert(param().compute_mode !=
Param::ComputeMode::FLOAT32 MEGDNN_INC_FLOAT16(
|| A.dtype == dtype::Float16()),
"ComputeMode::FLOAT32 is only available for Float16 "
|| A.dtype == dtype::Float16() ||
A.dtype == dtype::BFloat16()),
"ComputeMode::FLOAT32 is only available for Float16/BFloat16 "
"input / output.");
auto required_workspace_in_bytes = get_workspace_in_bytes(A, B, C);
megdnn_assert(workspace_in_bytes >= required_workspace_in_bytes);
......
......@@ -46,6 +46,14 @@ struct RoundingConverter<half_float::half> {
}
};
//! RoundingConverter specialization for bfloat16: narrows an fp32 value via
//! the half_bfloat16::bfloat16 converting constructor (the rounding behavior
//! is whatever that type implements).
template <>
struct RoundingConverter<half_bfloat16::bfloat16> {
    __host__ __device__ __forceinline__ half_bfloat16::bfloat16 operator()(
            float src_value) const {
        half_bfloat16::bfloat16 narrowed =
                static_cast<half_bfloat16::bfloat16>(src_value);
        return narrowed;
    }
};
#endif // #ifdef MEGDNN_DISABLE_FLOAT16
template <>
......
......@@ -16,6 +16,7 @@
#include "megdnn/dtype.h"
#include "megdnn/handle.h"
#include "megdnn/thin/small_vector.h"
#include "megdnn/oprs/general.h"
#include "src/common/hash_ct.h"
#include "src/common/utils.cuh"
......@@ -548,6 +549,59 @@ public:
std::string to_string() const;
};
/*!
 * \brief helper for operators that take SrcType input / produce DstType
 *        output but compute in CompType (e.g. bfloat16 oprs computing in
 *        fp32), using TypeCvt to convert at the boundaries
 *
 * Converted tensors are materialized in consecutive slots of the supplied
 * WorkspaceBundle, consumed in the order src_to_comp_type() is called;
 * workspace() returns the bundle workspace starting at the first unused
 * slot (intended for the wrapped operator's own workspace).
 *
 * \tparam SrcType  source dtype tag
 * \tparam CompType compute dtype tag, such as fp32 for conv
 * \tparam DstType  destination dtype tag; defaults to SrcType
 */
template <typename SrcType, typename CompType, typename DstType = SrcType>
struct CompTypeCvter {
    std::unique_ptr<TypeCvt> m_cvt_opr;   //!< performs the dtype conversions
    WorkspaceBundle* m_workspace_bundle;  //!< non-owning; supplies temp buffers
    size_t m_workspace_idx;               //!< next unused slot in the bundle

    CompTypeCvter(Handle* handle, WorkspaceBundle* bundle)
            : m_workspace_bundle(bundle), m_workspace_idx(0) {
        // A converter is pointless if either endpoint already has the
        // compute dtype; reject that configuration up front.
        megdnn_assert(
                (DTypeTrait<SrcType>::enumv != DTypeTrait<CompType>::enumv &&
                 DTypeTrait<DstType>::enumv != DTypeTrait<CompType>::enumv),
                "SrcType(%s) == CompType(%s) or DstType(%s) == CompType(%s) is "
                "not supported.",
                SrcType().name(), CompType().name(), DstType().name(),
                CompType().name());
        m_cvt_opr = handle->create_operator<TypeCvt>();
    }

    //! Convert tensor dtype from SrcType to CompType.
    //! If \p comp already carries CompType dtype it is assumed converted and
    //! left untouched; otherwise it is bound to the next workspace slot and
    //! filled via TypeCvt (the copy is skipped for empty tensors, ndim == 0).
    CompTypeCvter& src_to_comp_type(const TensorND& src, TensorND& comp) {
        if (src.layout.dtype.enumv() == DTypeTrait<SrcType>::enumv) {
            if (!comp.layout.dtype.valid() ||
                comp.layout.dtype.enumv() != DTypeTrait<CompType>::enumv) {
                comp.layout.dtype = CompType();
                comp.layout.init_contiguous_stride();
                comp.raw_ptr = m_workspace_bundle->get(m_workspace_idx++);
                if (src.layout.ndim) {
                    m_cvt_opr->exec(src, comp);
                }
            }
        }
        return *this;
    }

    //! Convert tensor dtype from CompType back to DstType.
    //! \p comp must hold CompType data; this is a no-op when \p dst does not
    //! have DstType dtype.
    CompTypeCvter& comp_to_dst_type(const TensorND& comp, const TensorND& dst) {
        megdnn_assert(comp.layout.dtype.enumv() == DTypeTrait<CompType>::enumv);
        if (dst.layout.dtype.enumv() == DTypeTrait<DstType>::enumv) {
            m_cvt_opr->exec(comp, dst);
        }
        return *this;
    }

    //! Workspace remaining after the conversion slots, for the inner opr.
    Workspace workspace() {
        return m_workspace_bundle->get_workspace(m_workspace_idx);
    }
};
} // namespace megdnn
// vim: syntax=cpp.doxygen
......@@ -55,17 +55,19 @@ void WarpPerspectiveBase::check_layout_fwd(const TensorLayout &src,
megdnn_assert(mat.shape[2] == 3_z, "%s", errmsg().c_str());
if (param().format == param::WarpPerspective::Format::NCHW) {
megdnn_assert(src.dtype.enumv() == DTypeEnum::Float32 ||
MEGDNN_FLOAT16_SELECT(
src.dtype.enumv() == DTypeEnum::Float16,
false) ||
src.dtype.enumv() == DTypeEnum::Int8 ||
src.dtype.enumv() == DTypeEnum::Uint8 ||
(src.dtype.enumv() == DTypeEnum::QuantizedS8 ||
src.dtype.enumv() == DTypeEnum::Quantized8Asymm),
"WarpPerspective NCHW input dtype should be "
"Float32/Int8/Uint8/QInt8/QUint8" MEGDNN_FLOAT16_SELECT(
"/Float16", "") ".");
megdnn_assert(
src.dtype.enumv() == DTypeEnum::Float32 ||
MEGDNN_FLOAT16_SELECT(
(src.dtype.enumv() == DTypeEnum::Float16 ||
src.dtype.enumv() == DTypeEnum::BFloat16),
false) ||
src.dtype.enumv() == DTypeEnum::Int8 ||
src.dtype.enumv() == DTypeEnum::Uint8 ||
(src.dtype.enumv() == DTypeEnum::QuantizedS8 ||
src.dtype.enumv() == DTypeEnum::Quantized8Asymm),
"WarpPerspective NCHW input dtype should be "
"Float32/Int8/Uint8/QInt8/QUint8" MEGDNN_FLOAT16_SELECT(
"/Float16/BFloat16", "") ".");
megdnn_assert(
(src.dtype.category() == DTypeCategory::FLOAT &&
(src.dtype == mat.dtype ||
......@@ -107,14 +109,17 @@ void WarpPerspectiveBase::check_layout_fwd(const TensorLayout &src,
param::WarpPerspective::BorderMode::ISOLATED);
} else {
megdnn_assert(param().format == param::WarpPerspective::Format::NHWCD4);
megdnn_assert(src.dtype == dtype::Float32() ||
MEGDNN_FLOAT16_SELECT(
src.dtype == dtype::Float16(), false) ||
src.dtype.enumv() == DTypeEnum::QuantizedS8 ||
src.dtype.enumv() == DTypeEnum::Quantized8Asymm,
"WarpPerspective NHWCD4 input dtype should be "
"Float32" MEGDNN_FLOAT16_SELECT(
"/Float16", "") ",QunatizedS8, Quantized8Asymm.");
megdnn_assert(
src.dtype == dtype::Float32() ||
MEGDNN_FLOAT16_SELECT((src.dtype == dtype::Float16() ||
src.dtype == dtype::BFloat16()),
false) ||
src.dtype.enumv() == DTypeEnum::QuantizedS8 ||
src.dtype.enumv() == DTypeEnum::Quantized8Asymm,
"WarpPerspective NHWCD4 input dtype should be "
"Float32" MEGDNN_FLOAT16_SELECT(
"/Float16/BFloat16",
"") ",QunatizedS8, Quantized8Asymm.");
megdnn_assert(
(src.dtype == mat.dtype || mat.dtype == dtype::Float32()),
"The input to WarpPerspective is in NHWCD4 format, in this "
......@@ -253,30 +258,30 @@ void WarpPerspectiveForward::check_exec_allow_nhwc_mat_idx(
}
}
void WarpPerspectiveBackwardData::check_exec(const TensorLayout &mat,
const TensorLayout &diff,
const TensorLayout &grad,
size_t workspace_in_bytes)
{
void WarpPerspectiveBackwardData::check_exec(const TensorLayout& mat,
const TensorLayout& diff,
const TensorLayout& grad,
size_t workspace_in_bytes) {
check_layout_fwd(grad, mat, diff);
megdnn_assert(grad.dtype == dtype::Float32(),
"Backward WarpPerspective only supports Float32.");
megdnn_assert(grad.dtype == dtype::Float32() MEGDNN_INC_FLOAT16(
|| grad.dtype == dtype::BFloat16()),
"Backward WarpPerspective only supports Float32/BFloat16.");
auto required_workspace_in_bytes = get_workspace_in_bytes(mat, diff, grad);
megdnn_assert(workspace_in_bytes >= required_workspace_in_bytes);
}
void WarpPerspectiveBackwardMat::check_exec(const TensorLayout &src,
const TensorLayout &mat,
const TensorLayout &diff,
const TensorLayout &grad,
size_t workspace_in_bytes)
{
void WarpPerspectiveBackwardMat::check_exec(const TensorLayout& src,
const TensorLayout& mat,
const TensorLayout& diff,
const TensorLayout& grad,
size_t workspace_in_bytes) {
check_layout_fwd(src, mat, diff);
megdnn_assert_eq_layout(mat, grad);
megdnn_assert(grad.dtype == dtype::Float32(),
"Backward WarpPerspective only supports Float32.");
auto required_workspace_in_bytes = get_workspace_in_bytes(src,
mat, diff, grad);
megdnn_assert(grad.dtype == dtype::Float32() MEGDNN_INC_FLOAT16(
|| grad.dtype == dtype::BFloat16()),
"Backward WarpPerspective only supports Float32/BFloat16.");
auto required_workspace_in_bytes =
get_workspace_in_bytes(src, mat, diff, grad);
megdnn_assert(workspace_in_bytes >= required_workspace_in_bytes);
}
......
/**
 * \file dnn/src/cuda/cond_take/kimpl/dt_bfloat16.cu
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 */
// generated by gen_cond_take_kern_impls.py
// Instantiates the cond_take CUDA kernels for the BFloat16 dtype; the
// inst_genidx / inst_copy / inst_copy_ macros are provided by kern.inl and
// #undef'd here after use so each generated file stays self-contained.
#include "../kern.inl"
// BFloat16 kernels are compiled out together with Float16 support.
#if !MEGDNN_DISABLE_FLOAT16
namespace megdnn {
namespace cuda {
namespace cond_take {
inst_genidx(::megdnn::dtype::BFloat16)
#undef inst_genidx
inst_copy(::megdnn::dtype::BFloat16)
#undef inst_copy
#undef inst_copy_
} // cond_take
} // cuda
} // megdnn
#endif
......@@ -62,6 +62,13 @@ ConvBiasForwardImpl::AlgoPack::AlgoPack() {
non_cudnn_algos.push_back(all_algos.rbegin()[1]); // group batched_matmul
non_cudnn_algos.push_back(all_algos.rbegin()[0]); // group 1x1
algo_size = all_algos.size();
for (size_t i = 0; i < algo_size; ++i) {
bfloat16_refhold.emplace_back(new AlgoBFloat16(all_algos[i]));
all_algos.push_back(bfloat16_refhold.back().get());
bfloat16_algos.push_back(bfloat16_refhold.back().get());
}
size_t all_algo_size = all_algos.size();
#if CUDA_VERSION >= 10000
fill_imma_algos();
......
......@@ -499,6 +499,28 @@ private:
};
#endif
//! Wrapper algorithm that runs a BFloat16 conv_bias through a wrapped
//! implementation (m_impl); layouts with BFloat16 dtype are rewritten to
//! Float32 for the inner run (see the float_args implementation).
class ConvBiasForwardImpl::AlgoBFloat16 final : public AlgoBase {
public:
    //! \param impl the wrapped algorithm that performs the computation
    AlgoBFloat16(AlgoBase* impl);
    bool is_available(const SizeArgs& args) const override;
    size_t get_workspace_in_bytes(const SizeArgs& args) const override;
    void exec(const ExecArgs& args) const override;
    //! name of the wrapped algorithm with a "BFLOAT16:" prefix
    const char* name() const override { return m_name.c_str(); }
    //! reproducibility is inherited from the wrapped implementation
    bool is_reproducible() const override { return m_impl->is_reproducible(); }

private:
    //! build fp32 layouts/args from the BFloat16 args for the inner operator
    SizeArgs float_args(const SizeArgs& args, ConvBiasForwardImpl* opr,
                        TensorLayout& fsrc, TensorLayout& ffilter,
                        TensorLayout& fbias, TensorLayout& fz,
                        TensorLayout& fdst) const;
    //! one workspace slot per converted tensor + the inner algo's workspace
    WorkspaceBundle get_workspace_bundle(void* ptr, const SizeArgs& args) const;
    AlgoBase* m_impl;    //!< wrapped algorithm (not owned)
    std::string m_name;  //!< "BFLOAT16:<wrapped algorithm name>"
};
class ConvBiasForwardImpl::AlgoPack {
AlgoPack(const AlgoPack&) = delete;
AlgoPack& operator=(const AlgoPack&) = delete;
......@@ -508,7 +530,8 @@ public:
std::vector<AlgoBase*> all_algos,
//! non-cudnn algos, used for heuristic if cudnn is not supported
non_cudnn_algos;
non_cudnn_algos,
bfloat16_algos;
std::vector<AlgoCUDNNConvBiasActivation> cudnn_conv_bias_activations;
std::vector<AlgoCUDNNConv> cudnn_convs;
AlgoChanwise chanwise;
......@@ -531,6 +554,7 @@ public:
int8_chwn4_imma_unroll_width;
#endif
std::vector<std::unique_ptr<AlgoGroupConvGeneral>> gconv_refhold;
std::vector<std::unique_ptr<AlgoBFloat16>> bfloat16_refhold;
std::unordered_map<AlgoBase*, AlgoGroupConvGeneral*> algo2gconv;
AlgoBase* cudnn_conv_bias_act_from_enum(cudnnConvolutionFwdAlgo_t algo);
......
/**
* \file dnn/src/cuda/conv_bias/bfloat16.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#include "src/cuda/conv_bias/algo.h"
#include "src/cuda/handle.h"
#include "src/cuda/utils.cuh"
#include "src/cuda/utils.h"
using namespace megdnn;
using namespace cuda;
using namespace conv_bias;
// Wrap an existing algorithm; \p algorithm must be non-null.  The exposed
// name is the wrapped algorithm's name with a "BFLOAT16:" prefix, so the
// wrapper is distinguishable in algorithm listings.
ConvBiasForwardImpl::AlgoBFloat16::AlgoBFloat16(
        ConvBiasForwardImpl::AlgoBase* algorithm)
        : m_impl(algorithm) {
    megdnn_assert_internal(algorithm);
    m_name = ssprintf("BFLOAT16:%s", m_impl->name());
}
// Build the fp32 counterpart of a BFloat16 problem: copy all layouts,
// rewrite every BFloat16 dtype to Float32, and configure \p opr (a scratch
// ConvBiasForwardImpl owned by the caller) to run the wrapped algorithm in
// the default compute mode.
ConvBiasForwardImpl::AlgoBase::SizeArgs
ConvBiasForwardImpl::AlgoBFloat16::float_args(
        const SizeArgs& args, ConvBiasForwardImpl* opr, TensorLayout& fsrc,
        TensorLayout& ffilter, TensorLayout& fbias, TensorLayout& fz,
        TensorLayout& fdst) const {
    fsrc = *args.src_layout;
    ffilter = *args.filter_layout;
    fbias = *args.bias_layout;
    fz = *args.z_layout;
    fdst = *args.dst_layout;
    // Promote each BFloat16 layout to Float32 so the inner implementation
    // sees an all-fp32 problem; other dtypes pass through untouched.
    for (TensorLayout* layout : {&fsrc, &ffilter, &fbias, &fz, &fdst}) {
        if (layout->dtype == dtype::BFloat16()) {
            layout->dtype = dtype::Float32();
        }
    }
    opr->param() = args.opr->param();
    opr->param().compute_mode = Param::ComputeMode::DEFAULT;
    // Pin the inner operator to the wrapped algorithm.
    opr->execution_policy() = {m_impl};
    return SizeArgs(opr, fsrc, ffilter, fbias, fz, fdst);
}
bool ConvBiasForwardImpl::AlgoBFloat16::is_available(
        const SizeArgs& args) const {
    // Build the float32 view of the problem on a scratch operator, then ask
    // the wrapped algorithm whether it can handle that view.
    TensorLayout fsrc, ffilter, fbias, fz, fdst;
    auto fwd_opr = args.handle->create_operator<ConvBias>();
    SizeArgs fargs =
            float_args(args, static_cast<ConvBiasForwardImpl*>(fwd_opr.get()),
                       fsrc, ffilter, fbias, fz, fdst);
    // This wrapper only applies when src and filter are both BFloat16.
    bool dtype_ok = args.src_layout->dtype == args.filter_layout->dtype &&
                    args.src_layout->dtype == dtype::BFloat16();
    return dtype_ok && m_impl->is_available(fargs);
}
WorkspaceBundle ConvBiasForwardImpl::AlgoBFloat16::get_workspace_bundle(
        void* ptr, const SizeArgs& args) const {
    TensorLayout fsrc, ffilter, fbias, fz, fdst;
    auto fwd_opr = args.handle->create_operator<ConvBias>();
    SizeArgs fargs =
            float_args(args, static_cast<ConvBiasForwardImpl*>(fwd_opr.get()),
                       fsrc, ffilter, fbias, fz, fdst);
    SmallVector<size_t> sizes;
    // One scratch buffer per tensor whose dtype actually changes; tensors
    // already in the compute dtype need no conversion buffer.
    auto add_cvt_size = [&sizes](const TensorLayout& orig,
                                 const TensorLayout& cvt) {
        if (orig.dtype != cvt.dtype)
            sizes.push_back(cvt.span().dist_byte());
    };
    add_cvt_size(*args.src_layout, fsrc);
    add_cvt_size(*args.filter_layout, ffilter);
    add_cvt_size(*args.bias_layout, fbias);
    add_cvt_size(*args.z_layout, fz);
    add_cvt_size(*args.dst_layout, fdst);
    // Final chunk: the wrapped float32 algorithm's own workspace.
    sizes.push_back(m_impl->get_workspace_in_bytes(fargs));
    return {ptr, std::move(sizes)};
}
size_t ConvBiasForwardImpl::AlgoBFloat16::get_workspace_in_bytes(
        const SizeArgs& args) const {
    // Sum of all conversion buffers plus the wrapped algorithm's workspace.
    auto bundle = get_workspace_bundle(nullptr, args);
    return bundle.total_size_in_bytes();
}
void ConvBiasForwardImpl::AlgoBFloat16::exec(const ExecArgs& args) const {
    // Run a bf16 conv-bias by converting all operands to float32 scratch
    // buffers, delegating to the wrapped float32 algorithm, and converting
    // the result back into the caller's bf16 dst tensor.
    // These start as aliases of the original tensors; the converter rebinds
    // them to workspace buffers for any tensor whose dtype changes.
    TensorND fsrc_tensor = *args.src_tensor;
    TensorND ffilter_tensor = *args.filter_tensor;
    TensorND fbias_tensor = *args.bias_tensor;
    TensorND fz_tensor = *args.z_tensor;
    TensorND fdst_tensor = *args.dst_tensor;
    auto bundle = get_workspace_bundle(args.workspace.raw_ptr, args);
    // Converts BFloat16 <-> Float32, drawing scratch space from `bundle` in
    // the same order the bundle's sizes were pushed in get_workspace_bundle.
    CompTypeCvter<dtype::BFloat16, dtype::Float32> cvter(args.handle, &bundle);
    {
        // NOTE: dst is converted too so fdst_tensor points at a float32
        // buffer the wrapped algorithm can write into.
        cvter.src_to_comp_type(*args.src_tensor, fsrc_tensor)
                .src_to_comp_type(*args.filter_tensor, ffilter_tensor)
                .src_to_comp_type(*args.bias_tensor, fbias_tensor)
                .src_to_comp_type(*args.z_tensor, fz_tensor)
                .src_to_comp_type(*args.dst_tensor, fdst_tensor);
    }
    {
        // Delegate to the wrapped algorithm on a scratch operator configured
        // for plain float32 computation; the remaining workspace (after the
        // conversion buffers) is handed to it.
        auto convbias_opr = args.handle->create_operator<ConvBias>();
        convbias_opr->param() = args.opr->param();
        convbias_opr->param().compute_mode = Param::ComputeMode::DEFAULT;
        convbias_opr->execution_policy() = {m_impl};
        convbias_opr->exec(fsrc_tensor, ffilter_tensor, fbias_tensor, fz_tensor,
                           fdst_tensor, cvter.workspace());
    }
    // Convert the float32 result back into the caller's bf16 dst tensor.
    { cvter.comp_to_dst_type(fdst_tensor, *args.dst_tensor); }
}
// vim: syntax=cpp.doxygen
......@@ -20,6 +20,10 @@ using namespace conv_bias;
bool ConvBiasForwardImpl::AlgoChanwise::is_available(
const SizeArgs& args) const {
if (args.src_layout->dtype == args.filter_layout->dtype &&
args.src_layout->dtype == dtype::BFloat16()) {
return false;
}
if (args.z_layout->ndim > 0)
return false;
......
......@@ -30,6 +30,10 @@ inline bool is_available_small(const chanwise::Param& param) {
bool ConvBiasForwardImpl::AlgoChanwiseSmall::is_available(
const SizeArgs& args) const {
if (args.src_layout->dtype == args.filter_layout->dtype &&
args.src_layout->dtype == dtype::BFloat16()) {
return false;
}
if (args.z_layout->ndim > 0)
return false;
#if CUDA_VERSION < 9000
......
......@@ -23,6 +23,10 @@ using namespace conv_bias;
bool ConvBiasForwardImpl::AlgoCUDNNConvBiasActivation::is_available(
const SizeArgs& args) const {
if (args.src_layout->dtype == args.filter_layout->dtype &&
args.src_layout->dtype == dtype::BFloat16()) {
return false;
}
if (args.bias_layout->ndim == 0 ||
args.bias_layout->eq_shape(*args.dst_layout))
return false;
......
......@@ -50,6 +50,10 @@ ConvBiasForwardImpl::AlgoGroupConvGeneral::AlgoGroupConvGeneral(AlgoBase* impl)
bool ConvBiasForwardImpl::AlgoGroupConvGeneral::is_available(
const SizeArgs& args) const {
if (args.src_layout->dtype == args.filter_layout->dtype &&
args.src_layout->dtype == dtype::BFloat16()) {
return false;
}
if (args.z_layout->ndim > 0 || args.filter_meta.group <= 1)
return false;
auto&& param = args.opr->param();
......
......@@ -136,6 +136,11 @@ void ConvBiasDesc::set_conv(DType data_type, const param::ConvBias& param,
namespace conv_bias {
bool is_cudnn_supported(const BiasForwardSizeArgs& args) {
if (args.src_layout->dtype == args.filter_layout->dtype &&
args.src_layout->dtype == dtype::BFloat16()) {
return false;
}
// CUDNN_STATUS_EXECUTION_FAILED on Tegra K1, so disable CUDNN
// on Tegra K1.
if (args.handle->is_tegra_k1())
......
......@@ -20,6 +20,10 @@ using namespace cuda;
using namespace conv_bias;
bool ConvBiasForwardImpl::AlgoMatmul::is_available(const SizeArgs& args) const {
if (args.src_layout->dtype == args.filter_layout->dtype &&
args.src_layout->dtype == dtype::BFloat16()) {
return false;
}
if (args.z_layout->ndim > 0)
return false;
......
......@@ -9,6 +9,7 @@
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#include "src/cuda/conv_bias/opr_impl.h"
#include "megdnn/dtype.h"
#include "src/cuda/conv_bias/helper.h"
#include "src/cuda/conv_bias/algo.h"
#include "src/cuda/handle.h"
......@@ -176,14 +177,26 @@ ConvBiasForward::Algorithm* ConvBiasForwardImpl::get_algorithm_heuristic(
conv_args = orig_args;
}
if (reproducible) {
return megdnn::get_reproducible_algo<ConvBiasForwardImpl>(
sm_algo_pack.non_cudnn_algos, args, workspace_limit_in_bytes,
"cuda convbias fwd");
if (args.src_layout->dtype.enumv() != DTypeTrait<dtype::BFloat16>::enumv) {
if (reproducible) {
return megdnn::get_reproducible_algo<ConvBiasForwardImpl>(
sm_algo_pack.non_cudnn_algos, args,
workspace_limit_in_bytes, "cuda convbias fwd");
} else {
return megdnn::get_usable_algo<ConvBiasForwardImpl>(
sm_algo_pack.non_cudnn_algos, args,
workspace_limit_in_bytes, "cuda convbias fwd");
}
} else {
return megdnn::get_usable_algo<ConvBiasForwardImpl>(
sm_algo_pack.non_cudnn_algos, args, workspace_limit_in_bytes,
"cuda convbias fwd");
if (reproducible) {
return megdnn::get_reproducible_algo<ConvBiasForwardImpl>(
sm_algo_pack.bfloat16_algos, args, workspace_limit_in_bytes,
"cuda convbias fwd");
} else {
return megdnn::get_usable_algo<ConvBiasForwardImpl>(
sm_algo_pack.bfloat16_algos, args, workspace_limit_in_bytes,
"cuda convbias fwd");
}
}
}
......
......@@ -57,6 +57,7 @@ public:
class AlgoInt8NCHW4IMMAImplicitGemm;
class AlgoInt8CHWN4IMMAImplicitGemmReorderFilter;
class AlgoInt8CHWN4IMMAImplicitGemmUnrollWidth;
class AlgoBFloat16;
class AlgoPack;
......
......@@ -33,11 +33,12 @@ ConvolutionBackwardDataImpl::AlgoPack::AlgoPack() {
// add gconv algos by AlgoGroupConvGeneral
auto all_algos_data = all_algos.data();
for (size_t i = 2; i < all_algos.size(); ++ i) {
size_t group_algo_start = 2;
for (size_t i = group_algo_start; i < all_algos.size(); ++ i) {
gconv.push_back({all_algos[i]});
}
for (size_t i = 2; i < all_algos.size(); ++ i) {
algo2gconv[all_algos[i]] = &gconv[i - 2];
for (size_t i = group_algo_start; i < all_algos.size(); ++ i) {
algo2gconv[all_algos[i]] = &gconv[i - group_algo_start];
}
for (auto &&i: gconv) {
all_algos.push_back(&i);
......@@ -45,6 +46,12 @@ ConvolutionBackwardDataImpl::AlgoPack::AlgoPack() {
megdnn_assert(all_algos_data == all_algos.data());
non_cudnn_algos.push_back(all_algos.rbegin()[0]); // group matmul
size_t algo_size = all_algos.size();
for (size_t i=0; i<algo_size; ++i) {
bfloat16_refhold.emplace_back(new AlgoBFloat16(all_algos[i]));
all_algos.push_back(bfloat16_refhold.back().get());
bfloat16_algos.push_back(bfloat16_refhold.back().get());
}
}
ConvolutionBackwardDataImpl::AlgoCUDNN*
......@@ -65,18 +72,19 @@ ConvolutionBackwardDataImpl::AlgoBase::SizeArgs::SizeArgs(
ConvolutionBackwardDataImpl *o,
const TensorLayout &filter, const TensorLayout &diff,
const TensorLayout &grad):
SizeArgs(o, o->check_layout_fwd(grad, filter, diff), diff, grad)
SizeArgs(o, filter, o->check_layout_fwd(grad, filter, diff), diff, grad)
{
}
ConvolutionBackwardDataImpl::AlgoBase::SizeArgs::SizeArgs(
ConvolutionBackwardDataImpl *o,
const CanonizedFilterMeta &filter, const TensorLayout &diff,
ConvolutionBackwardDataImpl *o, const TensorLayout& filter,
const CanonizedFilterMeta &filter_meta, const TensorLayout &diff,
const TensorLayout &grad):
handle{concrete_handle(o->handle())},
filter_meta{filter},
filter_meta{filter_meta},
diff_layout{&diff},
grad_layout{&grad},
filter_layout{&filter},
opr{o}
{
}
......
......@@ -31,22 +31,24 @@ class ConvolutionBackwardDataImpl::AlgoBase: public Algorithm {
struct SizeArgs {
HandleImpl *handle;
CanonizedFilterMeta filter_meta;
const TensorLayout *diff_layout, *grad_layout;
const TensorLayout *diff_layout, *grad_layout, *filter_layout;
ConvolutionBackwardDataImpl *opr;
std::string to_string() const;
void init_desc(convolution::CUDNNBwdDataDescs &desc) const {
desc.set(filter_meta, *diff_layout, *grad_layout, opr->param());
}
SizeArgs(ConvolutionBackwardDataImpl *opr,
const TensorLayout &filter, const TensorLayout &diff,
const TensorLayout &grad);
SizeArgs(ConvolutionBackwardDataImpl *opr,
const CanonizedFilterMeta &filter, const TensorLayout &diff,
const TensorLayout &grad);
SizeArgs(ConvolutionBackwardDataImpl* opr,
const TensorLayout& filter, const TensorLayout& diff,
const TensorLayout& grad);
SizeArgs(ConvolutionBackwardDataImpl* opr,
const TensorLayout& filter,
const CanonizedFilterMeta& filter_meta,
const TensorLayout& diff, const TensorLayout& grad);
convolution::ForwardSizeArgs as_fwd_args() const {
return {handle, grad_layout, filter_meta, diff_layout};
return {handle, grad_layout, filter_layout, filter_meta,
diff_layout};
}
};
struct ExecArgs: public SizeArgs {
......@@ -170,6 +172,25 @@ class ConvolutionBackwardDataImpl::AlgoChanwiseSmall final: public AlgoBase {
}
};
//! Wrapper that serves BFloat16 backward-data convolution by converting the
//! operands to Float32, running the wrapped algorithm, and converting the
//! gradient back (see backward_data/bfloat16.cpp).
class ConvolutionBackwardDataImpl::AlgoBFloat16 final : public AlgoBase {
public:
    //! \param algorithm the underlying float32 algorithm to delegate to
    AlgoBFloat16(ConvolutionBackwardDataImpl::AlgoBase*);
    bool is_available(const SizeArgs& args) const override;
    size_t get_workspace_in_bytes(const SizeArgs& args) const override;
    void exec(const ExecArgs& args) const override;
    const char* name() const override { return m_name.c_str(); }
    bool is_reproducible() const override { return true; }
private:
    //! "CONVOLUTION_BACKWARD_DATD_BFLOAT16:<wrapped algo name>"
    std::string m_name;
    //! wrapped float32 algorithm; never null after construction
    ConvolutionBackwardDataImpl::AlgoBase* m_algorithm = nullptr;
    //! build the float32 view of \p args on the scratch operator \p opr
    SizeArgs float_args(const SizeArgs& args, ConvolutionBackwardDataImpl* opr,
                        TensorLayout& fsrc, TensorLayout& ffilter,
                        TensorLayout& fdst) const;
    //! conversion buffers (one per dtype-changed tensor) + wrapped workspace
    WorkspaceBundle get_workspace_bundle(void* ptr, const SizeArgs& args) const;
};
//! implement group conv by another algo
class ConvolutionBackwardDataImpl::AlgoGroupConvGeneral final: public AlgoBase {
AlgoBase *m_impl;
......@@ -210,12 +231,14 @@ class ConvolutionBackwardDataImpl::AlgoPack {
AlgoChanwiseSmall chanwise_small;
std::vector<AlgoGroupConvGeneral> gconv;
std::unordered_map<AlgoBase*, AlgoGroupConvGeneral*> algo2gconv;
std::vector<std::unique_ptr<AlgoBFloat16>> bfloat16_refhold;
std::vector<AlgoBase*>
//! all algorithms
all_algos,
//! non-cudnn algos, used for heuristic if cudnn is not supported
non_cudnn_algos;
non_cudnn_algos,
bfloat16_algos;
AlgoCUDNN* cudnn_from_enum(cudnnConvolutionBwdDataAlgo_t algo);
};
......
/**
* \file src/cuda/convolution/backward_data/bfloat16.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#include "./algo.h"
#include "src/cuda/convolution/chanwise/kern.cuh"
#include "src/cuda/utils.h"
using namespace megdnn;
using namespace cuda;
using namespace convolution;
ConvolutionBackwardDataImpl::AlgoBFloat16::AlgoBFloat16(
        ConvolutionBackwardDataImpl::AlgoBase* algorithm)
        : m_algorithm(algorithm) {
    // A null wrapped algorithm would make every later delegation crash.
    megdnn_assert_internal(algorithm);
    // Fix typo: "DATD" -> "DATA", matching the forward ("BFLOAT16:%s") and
    // backward-filter wrapper naming scheme.
    m_name = ssprintf("CONVOLUTION_BACKWARD_DATA_BFLOAT16:%s",
                      m_algorithm->name());
}
ConvolutionBackwardDataImpl::AlgoBase::SizeArgs
ConvolutionBackwardDataImpl::AlgoBFloat16::float_args(
        const SizeArgs& args, ConvolutionBackwardDataImpl* opr,
        TensorLayout& ffilter, TensorLayout& fdiff, TensorLayout& fgrad) const {
    // Copy the original layouts, then rewrite any BFloat16 dtype to Float32;
    // layouts of other dtypes pass through unchanged.
    ffilter = *args.filter_layout;
    fdiff = *args.diff_layout;
    fgrad = *args.grad_layout;
    auto to_float32 = [](TensorLayout& layout) {
        if (layout.dtype == dtype::BFloat16())
            layout.dtype = dtype::Float32();
    };
    TensorLayout* layouts[] = {&ffilter, &fdiff, &fgrad};
    for (auto* ly : layouts)
        to_float32(*ly);
    // Configure the proxy operator: same hyper-parameters as the caller's,
    // but computing in plain float32 with the wrapped algorithm pinned.
    opr->param() = args.opr->param();
    opr->param().compute_mode = Param::ComputeMode::DEFAULT;
    opr->execution_policy() = {m_algorithm};
    return SizeArgs(opr, ffilter, fdiff, fgrad);
}
bool ConvolutionBackwardDataImpl::AlgoBFloat16::is_available(
        const SizeArgs& args) const {
    // Build the float32 view of the problem on a scratch operator, then ask
    // the wrapped algorithm whether it can handle that view.
    TensorLayout ffilter, fdiff, fgrad;
    auto bwd_data_opr =
            args.handle->create_operator<ConvolutionBackwardData>();
    SizeArgs fargs = float_args(
            args,
            static_cast<ConvolutionBackwardDataImpl*>(bwd_data_opr.get()),
            ffilter, fdiff, fgrad);
    // This wrapper only applies when diff and filter are both BFloat16.
    bool dtype_ok = args.diff_layout->dtype == args.filter_layout->dtype &&
                    args.diff_layout->dtype == dtype::BFloat16();
    return dtype_ok && m_algorithm->is_available(fargs);
}
WorkspaceBundle ConvolutionBackwardDataImpl::AlgoBFloat16::get_workspace_bundle(
        void* ptr, const SizeArgs& args) const {
    TensorLayout ffilter, fdiff, fgrad;
    auto bwd_data_opr =
            args.handle->create_operator<ConvolutionBackwardData>();
    SizeArgs fargs = float_args(
            args,
            static_cast<ConvolutionBackwardDataImpl*>(bwd_data_opr.get()),
            ffilter, fdiff, fgrad);
    SmallVector<size_t> sizes;
    // One scratch buffer per tensor whose dtype actually changes; tensors
    // already in the compute dtype need no conversion buffer.
    auto add_cvt_size = [&sizes](const TensorLayout& orig,
                                 const TensorLayout& cvt) {
        if (orig.dtype != cvt.dtype)
            sizes.push_back(cvt.span().dist_byte());
    };
    add_cvt_size(*args.filter_layout, ffilter);
    add_cvt_size(*args.diff_layout, fdiff);
    add_cvt_size(*args.grad_layout, fgrad);
    // Final chunk: the wrapped float32 algorithm's own workspace.
    sizes.push_back(m_algorithm->get_workspace_in_bytes(fargs));
    return {ptr, std::move(sizes)};
}
size_t ConvolutionBackwardDataImpl::AlgoBFloat16::get_workspace_in_bytes(
        const SizeArgs& args) const {
    // Sum of all conversion buffers plus the wrapped algorithm's workspace.
    auto bundle = get_workspace_bundle(nullptr, args);
    return bundle.total_size_in_bytes();
}
void ConvolutionBackwardDataImpl::AlgoBFloat16::exec(
        const ExecArgs& args) const {
    // Run a bf16 backward-data convolution by converting the operands to
    // float32 scratch buffers, delegating to the wrapped float32 algorithm,
    // and converting the gradient back into the caller's bf16 grad tensor.
    // These start as aliases of the original tensors; the converter rebinds
    // them to workspace buffers for any tensor whose dtype changes.
    TensorND ffilter_tensor = *args.filter_tensor;
    TensorND fdiff_tensor = *args.diff_tensor;
    TensorND fgrad_tensor = *args.grad_tensor;
    auto bundle = get_workspace_bundle(args.workspace.raw_ptr, args);
    // Converts BFloat16 <-> Float32, drawing scratch space from `bundle` in
    // the same order the bundle's sizes were pushed in get_workspace_bundle.
    CompTypeCvter<dtype::BFloat16, dtype::Float32> cvter(args.handle, &bundle);
    {
        // NOTE: grad is converted too so fgrad_tensor points at a float32
        // buffer the wrapped algorithm can write into.
        cvter.src_to_comp_type(*args.filter_tensor, ffilter_tensor)
                .src_to_comp_type(*args.diff_tensor, fdiff_tensor)
                .src_to_comp_type(*args.grad_tensor, fgrad_tensor);
    }
    {
        // Delegate to the wrapped algorithm on a scratch operator configured
        // for plain float32 computation; the remaining workspace (after the
        // conversion buffers) is handed to it.
        auto conv_back_data_opr =
                args.handle->create_operator<ConvolutionBackwardData>();
        conv_back_data_opr->param() = args.opr->param();
        conv_back_data_opr->param().compute_mode = Param::ComputeMode::DEFAULT;
        conv_back_data_opr->execution_policy() = {m_algorithm};
        conv_back_data_opr->exec(ffilter_tensor, fdiff_tensor, fgrad_tensor,
                                 cvter.workspace());
    }
    // Convert the float32 gradient back into the caller's bf16 grad tensor.
    { cvter.comp_to_dst_type(fgrad_tensor, *args.grad_tensor); }
}
// vim: syntax=cpp.doxygen
......@@ -19,6 +19,10 @@ using namespace convolution;
bool ConvolutionBackwardDataImpl::AlgoChanwise::is_available(
const SizeArgs& args) const {
if (args.diff_layout->dtype == args.filter_layout->dtype &&
args.diff_layout->dtype == dtype::BFloat16()) {
return false;
}
auto&& fm = args.filter_meta;
return args.filter_meta.format == Param::Format::NCHW &&
args.diff_layout->dtype.category() == DTypeCategory::FLOAT &&
......
......@@ -29,6 +29,10 @@ inline bool is_available_small(const chanwise::Param& param) {
bool ConvolutionBackwardDataImpl::AlgoChanwiseSmall::is_available(
const SizeArgs &args) const {
if (args.diff_layout->dtype == args.filter_layout->dtype &&
args.diff_layout->dtype == dtype::BFloat16()) {
return false;
}
#if CUDA_VERSION < 9000
if (args.diff_layout->dtype.enumv() == DTypeEnum::Float16)
return false;
......
......@@ -38,6 +38,10 @@ ConvolutionBackwardDataImpl::AlgoGroupConvGeneral::AlgoGroupConvGeneral(
bool ConvolutionBackwardDataImpl::AlgoGroupConvGeneral::is_available(
const SizeArgs &args) const {
if (args.diff_layout->dtype == args.filter_layout->dtype &&
args.diff_layout->dtype == dtype::BFloat16()) {
return false;
}
auto sub_args = args;
TensorLayout diff_pg, grad_pg;
modify_size_args(sub_args, diff_pg, grad_pg);
......
......@@ -20,6 +20,10 @@ using namespace cuda;
bool ConvolutionBackwardDataImpl::AlgoMatmul::is_available(
const SizeArgs &args) const {
if (args.diff_layout->dtype == args.filter_layout->dtype &&
args.diff_layout->dtype == dtype::BFloat16()) {
return false;
}
auto &&fm = args.filter_meta;
return args.filter_meta.format == Param::Format::NCHW &&
args.diff_layout->dtype.category() == DTypeCategory::FLOAT &&
......
......@@ -43,6 +43,12 @@ ConvolutionBackwardFilterImpl::AlgoPack::AlgoPack() {
megdnn_assert(all_algos_data == all_algos.data());
non_cudnn_algos.push_back(all_algos.rbegin()[0]); // group matmul
size_t algo_size = all_algos.size();
for (size_t i=0; i<algo_size; ++i) {
bfloat16_refhold.emplace_back(new AlgoBFloat16(all_algos[i]));
all_algos.push_back(bfloat16_refhold.back().get());
bfloat16_algos.push_back(bfloat16_refhold.back().get());
}
}
ConvolutionBackwardFilterImpl::AlgoCUDNN*
......@@ -64,21 +70,20 @@ ConvolutionBackwardFilterImpl::AlgoBase::SizeArgs::SizeArgs(
ConvolutionBackwardFilterImpl *o,
const TensorLayout &src, const TensorLayout &diff,
const TensorLayout &grad):
SizeArgs(o, src, diff, o->check_layout_fwd(src, grad, diff))
SizeArgs(o, src, diff, grad, o->check_layout_fwd(src, grad, diff))
{
}
ConvolutionBackwardFilterImpl::AlgoBase::SizeArgs::SizeArgs(
ConvolutionBackwardFilterImpl *o,
const TensorLayout &src, const TensorLayout &diff,
const CanonizedFilterMeta &grad):
handle{concrete_handle(o->handle())},
src_layout{&src},
diff_layout{&diff},
grad_filter_meta{grad},
opr{o}
{
}
ConvolutionBackwardFilterImpl* o, const TensorLayout& src,
const TensorLayout& diff, const TensorLayout& grad,
const CanonizedFilterMeta& grad_meta)
: handle{concrete_handle(o->handle())},
src_layout{&src},
diff_layout{&diff},
grad_layout{&grad},
grad_filter_meta{grad_meta},
opr{o} {}
ConvolutionBackwardFilterImpl::AlgoBase::ExecArgs::ExecArgs(
ConvolutionBackwardFilterImpl *opr,
......
......@@ -30,7 +30,7 @@ class ConvolutionBackwardFilterImpl::AlgoBase: public Algorithm {
public:
struct SizeArgs {
HandleImpl *handle;
const TensorLayout *src_layout, *diff_layout;
const TensorLayout *src_layout, *diff_layout, *grad_layout;
CanonizedFilterMeta grad_filter_meta;
ConvolutionBackwardFilterImpl *opr;
......@@ -42,12 +42,14 @@ class ConvolutionBackwardFilterImpl::AlgoBase: public Algorithm {
SizeArgs(ConvolutionBackwardFilterImpl *opr,
const TensorLayout &src, const TensorLayout &diff,
const TensorLayout &grad);
SizeArgs(ConvolutionBackwardFilterImpl *opr,
const TensorLayout &src, const TensorLayout &diff,
const CanonizedFilterMeta &grad);
SizeArgs(ConvolutionBackwardFilterImpl* opr,
const TensorLayout& src, const TensorLayout& diff,
const TensorLayout& grad,
const CanonizedFilterMeta& grad_meta);
convolution::ForwardSizeArgs as_fwd_args() const {
return {handle, src_layout, grad_filter_meta, diff_layout};
return {handle, src_layout, grad_layout, grad_filter_meta,
diff_layout};
}
};
struct ExecArgs: public SizeArgs {
......@@ -157,6 +159,25 @@ class ConvolutionBackwardFilterImpl::AlgoChanwise final: public AlgoBase {
}
};
//! Wrapper that serves BFloat16 backward-filter convolution by converting the
//! operands to Float32, running the wrapped algorithm, and converting the
//! gradient back (see backward_filter/bfloat16.cpp).
class ConvolutionBackwardFilterImpl::AlgoBFloat16 final : public AlgoBase {
public:
    //! \param algorithm the underlying float32 algorithm to delegate to
    AlgoBFloat16(ConvolutionBackwardFilterImpl::AlgoBase*);
    bool is_available(const SizeArgs& args) const override;
    size_t get_workspace_in_bytes(const SizeArgs& args) const override;
    void exec(const ExecArgs& args) const override;
    const char* name() const override { return m_name.c_str(); }
    bool is_reproducible() const override { return true; }
private:
    //! "CONVOLUTION_BACKWARD_Filter_BFLOAT16:<wrapped algo name>"
    std::string m_name;
    //! wrapped float32 algorithm; never null after construction
    ConvolutionBackwardFilterImpl::AlgoBase* m_algorithm = nullptr;
    //! build the float32 view of \p args on the scratch operator \p opr
    SizeArgs float_args(const SizeArgs& args,
                        ConvolutionBackwardFilterImpl* opr, TensorLayout& fsrc,
                        TensorLayout& ffilter, TensorLayout& fdst) const;
    //! conversion buffers (one per dtype-changed tensor) + wrapped workspace
    WorkspaceBundle get_workspace_bundle(void* ptr, const SizeArgs& args) const;
};
//! implement group conv by another algo
class ConvolutionBackwardFilterImpl::AlgoGroupConvGeneral final: public AlgoBase {
AlgoBase *m_impl;
......@@ -196,12 +217,14 @@ class ConvolutionBackwardFilterImpl::AlgoPack {
AlgoChanwise chanwise;
std::vector<AlgoGroupConvGeneral> gconv;
std::unordered_map<AlgoBase*, AlgoGroupConvGeneral*> algo2gconv;
std::vector<std::unique_ptr<AlgoBFloat16>> bfloat16_refhold;
std::vector<AlgoBase*>
//! all algorithms
all_algos,
//! non-cudnn algos, used for heuristic if cudnn is not supported
non_cudnn_algos;
non_cudnn_algos,
bfloat16_algos;
AlgoCUDNN* cudnn_from_enum(cudnnConvolutionBwdFilterAlgo_t algo);
};
......
/**
* \file src/cuda/convolution/backward_filter/bfloat16.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#include "./algo.h"
#include "src/cuda/convolution/chanwise/kern.cuh"
#include "src/cuda/utils.h"
using namespace megdnn;
using namespace cuda;
using namespace convolution;
ConvolutionBackwardFilterImpl::AlgoBFloat16::AlgoBFloat16(
        ConvolutionBackwardFilterImpl::AlgoBase* algorithm)
        : m_algorithm(algorithm) {
    // A null wrapped algorithm would make every later delegation crash.
    megdnn_assert_internal(algorithm);
    // Fix inconsistent casing: "Filter" -> "FILTER", matching the all-caps
    // naming scheme of the other bfloat16 wrappers.
    m_name = ssprintf("CONVOLUTION_BACKWARD_FILTER_BFLOAT16:%s",
                      m_algorithm->name());
}
ConvolutionBackwardFilterImpl::AlgoBase::SizeArgs
ConvolutionBackwardFilterImpl::AlgoBFloat16::float_args(
        const SizeArgs& args, ConvolutionBackwardFilterImpl* opr,
        TensorLayout& fsrc, TensorLayout& fdiff, TensorLayout& fgrad) const {
    // Copy the original layouts, then rewrite any BFloat16 dtype to Float32;
    // layouts of other dtypes pass through unchanged.
    fsrc = *args.src_layout;
    fdiff = *args.diff_layout;
    fgrad = *args.grad_layout;
    auto to_float32 = [](TensorLayout& layout) {
        if (layout.dtype == dtype::BFloat16())
            layout.dtype = dtype::Float32();
    };
    TensorLayout* layouts[] = {&fsrc, &fdiff, &fgrad};
    for (auto* ly : layouts)
        to_float32(*ly);
    // Configure the proxy operator: same hyper-parameters as the caller's,
    // but computing in plain float32 with the wrapped algorithm pinned.
    opr->param() = args.opr->param();
    opr->param().compute_mode = Param::ComputeMode::DEFAULT;
    opr->execution_policy() = {m_algorithm};
    return SizeArgs(opr, fsrc, fdiff, fgrad);
}
bool ConvolutionBackwardFilterImpl::AlgoBFloat16::is_available(
        const SizeArgs& args) const {
    // Build the float32 view of the problem on a scratch operator, then ask
    // the wrapped algorithm whether it can handle that view.
    TensorLayout fsrc, fdiff, fgrad;
    auto bwd_filter_opr =
            args.handle->create_operator<ConvolutionBackwardFilter>();
    SizeArgs fargs = float_args(
            args,
            static_cast<ConvolutionBackwardFilterImpl*>(bwd_filter_opr.get()),
            fsrc, fdiff, fgrad);
    // This wrapper only applies when src and diff are both BFloat16.
    bool dtype_ok = args.src_layout->dtype == args.diff_layout->dtype &&
                    args.src_layout->dtype == dtype::BFloat16();
    return dtype_ok && m_algorithm->is_available(fargs);
}
WorkspaceBundle
ConvolutionBackwardFilterImpl::AlgoBFloat16::get_workspace_bundle(
        void* ptr, const SizeArgs& args) const {
    TensorLayout fsrc, fdiff, fgrad;
    auto bwd_filter_opr =
            args.handle->create_operator<ConvolutionBackwardFilter>();
    SizeArgs fargs = float_args(
            args,
            static_cast<ConvolutionBackwardFilterImpl*>(bwd_filter_opr.get()),
            fsrc, fdiff, fgrad);
    SmallVector<size_t> sizes;
    // One scratch buffer per tensor whose dtype actually changes; tensors
    // already in the compute dtype need no conversion buffer.
    auto add_cvt_size = [&sizes](const TensorLayout& orig,
                                 const TensorLayout& cvt) {
        if (orig.dtype != cvt.dtype)
            sizes.push_back(cvt.span().dist_byte());
    };
    add_cvt_size(*args.src_layout, fsrc);
    add_cvt_size(*args.diff_layout, fdiff);
    add_cvt_size(*args.grad_layout, fgrad);
    // Final chunk: the wrapped float32 algorithm's own workspace.
    sizes.push_back(m_algorithm->get_workspace_in_bytes(fargs));
    return {ptr, std::move(sizes)};
}
size_t ConvolutionBackwardFilterImpl::AlgoBFloat16::get_workspace_in_bytes(
        const SizeArgs& args) const {
    // Sum of all conversion buffers plus the wrapped algorithm's workspace.
    auto bundle = get_workspace_bundle(nullptr, args);
    return bundle.total_size_in_bytes();
}
void ConvolutionBackwardFilterImpl::AlgoBFloat16::exec(
        const ExecArgs& args) const {
    // Run a bf16 backward-filter convolution by converting the operands to
    // float32 scratch buffers, delegating to the wrapped float32 algorithm,
    // and converting the gradient back into the caller's bf16 grad tensor.
    // These start as aliases of the original tensors; the converter rebinds
    // them to workspace buffers for any tensor whose dtype changes.
    TensorND fsrc_tensor = *args.src_tensor;
    TensorND fdiff_tensor = *args.diff_tensor;
    TensorND fgrad_tensor = *args.grad_tensor;
    auto bundle = get_workspace_bundle(args.workspace.raw_ptr, args);
    // Converts BFloat16 <-> Float32, drawing scratch space from `bundle` in
    // the same order the bundle's sizes were pushed in get_workspace_bundle.
    CompTypeCvter<dtype::BFloat16, dtype::Float32> cvter(args.handle, &bundle);
    {
        // NOTE: grad is converted too so fgrad_tensor points at a float32
        // buffer the wrapped algorithm can write into.
        cvter.src_to_comp_type(*args.src_tensor, fsrc_tensor)
                .src_to_comp_type(*args.diff_tensor, fdiff_tensor)
                .src_to_comp_type(*args.grad_tensor, fgrad_tensor);
    }
    {
        // Delegate to the wrapped algorithm on a scratch operator configured
        // for plain float32 computation; the remaining workspace (after the
        // conversion buffers) is handed to it.
        auto conv_back_filter_opr =
                args.handle->create_operator<ConvolutionBackwardFilter>();
        conv_back_filter_opr->param() = args.opr->param();
        conv_back_filter_opr->param().compute_mode =
                Param::ComputeMode::DEFAULT;
        conv_back_filter_opr->execution_policy() = {m_algorithm};
        conv_back_filter_opr->exec(fsrc_tensor, fdiff_tensor, fgrad_tensor,
                                   cvter.workspace());
    }
    // Convert the float32 gradient back into the caller's bf16 grad tensor.
    { cvter.comp_to_dst_type(fgrad_tensor, *args.grad_tensor); }
}
// vim: syntax=cpp.doxygen
......@@ -19,6 +19,10 @@ using namespace convolution;
bool ConvolutionBackwardFilterImpl::AlgoChanwise::is_available(
const SizeArgs &args) const {
if (args.src_layout->dtype == args.src_layout->dtype &&
args.diff_layout->dtype == dtype::BFloat16()) {
return false;
}
auto &&fm = args.grad_filter_meta;
return fm.format == Param::Format::NCHW &&
args.diff_layout->dtype.category() == DTypeCategory::FLOAT &&
......
......@@ -38,6 +38,10 @@ ConvolutionBackwardFilterImpl::AlgoGroupConvGeneral::AlgoGroupConvGeneral(
bool ConvolutionBackwardFilterImpl::AlgoGroupConvGeneral::is_available(
const SizeArgs &args) const {
if (args.src_layout->dtype == args.src_layout->dtype &&
args.diff_layout->dtype == dtype::BFloat16()) {
return false;
}
auto sub_args = args;
TensorLayout src_pg, diff_pg;
modify_size_args(sub_args, src_pg, diff_pg);
......
......@@ -19,6 +19,10 @@ using namespace cuda;
bool ConvolutionBackwardFilterImpl::AlgoMatmul::is_available(
const SizeArgs &args) const {
if (args.src_layout->dtype == args.src_layout->dtype &&
args.diff_layout->dtype == dtype::BFloat16()) {
return false;
}
auto &&fm = args.grad_filter_meta;
return fm.format == Param::Format::NCHW &&
args.diff_layout->dtype.category() == DTypeCategory::FLOAT &&
......
......@@ -16,6 +16,10 @@ using namespace cuda;
using namespace convolution;
bool convolution::is_cudnn_supported(const ForwardSizeArgs &args) {
if (args.src_layout->dtype == args.filter_layout->dtype &&
args.src_layout->dtype == dtype::BFloat16()) {
return false;
}
// CUDNN_STATUS_EXECUTION_FAILED on Tegra K1, so disable CUDNN
// on Tegra K1.
......
......@@ -25,6 +25,7 @@ namespace convolution {
struct ForwardSizeArgs {
HandleImpl *handle;
const TensorLayout *src_layout;
const TensorLayout *filter_layout;
CanonizedFilterMeta filter_meta;
const TensorLayout *dst_layout;
};
......
......@@ -102,7 +102,8 @@ void ConvolutionBackwardDataImpl::exec(_megdnn_tensor_in filter,
_megdnn_tensor_out grad,
_megdnn_workspace workspace) {
AlgoBase::ExecArgs args(this, filter, diff, grad, workspace);
auto algo = get_algorithm(this, args.filter_meta, diff.layout, grad.layout);
auto algo = get_algorithm(this, filter.layout, args.filter_meta,
diff.layout, grad.layout);
algo->check_workspace(args, workspace).exec(args);
}
......@@ -120,16 +121,16 @@ ConvolutionBackwardDataImpl::get_algorithm_heuristic(
const TensorLayout& grad, size_t workspace_limit_in_bytes,
bool reproducible) {
auto fm = check_layout_fwd(grad, filter, diff);
return get_algorithm_heuristic(fm, diff, grad, workspace_limit_in_bytes,
reproducible);
return get_algorithm_heuristic(filter, fm, diff, grad,
workspace_limit_in_bytes, reproducible);
}
ConvolutionBackwardDataImpl::Algorithm*
ConvolutionBackwardDataImpl::get_algorithm_heuristic(
const CanonizedFilterMeta& filter, const TensorLayout& diff,
ConvolutionBackwardDataImpl::get_algorithm_heuristic(const TensorLayout& filter,
const CanonizedFilterMeta& filter_meta, const TensorLayout& diff,
const TensorLayout& grad, size_t workspace_limit_in_bytes,
bool reproducible) {
AlgoBase::SizeArgs args(this, filter, diff, grad);
AlgoBase::SizeArgs args(this, filter, filter_meta, diff, grad);
if (args.filter_meta.group > 1 &&
sm_algo_pack.chanwise.is_available_reproducible(
......@@ -209,14 +210,27 @@ ConvolutionBackwardDataImpl::get_algorithm_heuristic(
args = orig_args;
}
if (reproducible) {
return megdnn::get_reproducible_algo<ConvolutionBackwardDataImpl>(
sm_algo_pack.non_cudnn_algos, args, workspace_limit_in_bytes,
"cuda conv bwd_data");
if (args.filter_layout->dtype.enumv() !=
DTypeTrait<dtype::BFloat16>::enumv) {
if (reproducible) {
return megdnn::get_reproducible_algo<ConvolutionBackwardDataImpl>(
sm_algo_pack.non_cudnn_algos, args,
workspace_limit_in_bytes, "cuda conv bwd_data");
} else {
return megdnn::get_usable_algo<ConvolutionBackwardDataImpl>(
sm_algo_pack.non_cudnn_algos, args,
workspace_limit_in_bytes, "cuda conv bwd_data");
}
} else {
return megdnn::get_usable_algo<ConvolutionBackwardDataImpl>(
sm_algo_pack.non_cudnn_algos, args, workspace_limit_in_bytes,
"cuda conv bwd_data");
if (reproducible) {
return megdnn::get_reproducible_algo<ConvolutionBackwardDataImpl>(
sm_algo_pack.bfloat16_algos, args, workspace_limit_in_bytes,
"cuda conv bwd_data");
} else {
return megdnn::get_usable_algo<ConvolutionBackwardDataImpl>(
sm_algo_pack.bfloat16_algos, args, workspace_limit_in_bytes,
"cuda conv bwd_data");
}
}
}
......@@ -225,7 +239,7 @@ size_t ConvolutionBackwardDataImpl::get_workspace_in_bytes(
const TensorLayout &diff,
const TensorLayout &grad) {
AlgoBase::SizeArgs args(this, filter, diff, grad);
return get_algorithm(this, args.filter_meta, diff, grad)->
return get_algorithm(this, filter, args.filter_meta, diff, grad)->
get_workspace_in_bytes(args);
}
......@@ -241,7 +255,7 @@ void ConvolutionBackwardFilterImpl::exec(_megdnn_tensor_in src,
_megdnn_workspace workspace) {
AlgoBase::ExecArgs args(this, src, diff, grad, workspace);
auto algo = get_algorithm(this, src.layout, diff.layout,
args.grad_filter_meta);
grad.layout, args.grad_filter_meta);
algo->check_workspace(args, workspace).exec(args);
}
......@@ -259,16 +273,16 @@ ConvolutionBackwardFilterImpl::get_algorithm_heuristic(
const TensorLayout& grad, size_t workspace_limit_in_bytes,
bool reproducible) {
auto fm = check_layout_fwd(src, grad, diff);
return get_algorithm_heuristic(src, diff, fm, workspace_limit_in_bytes,
reproducible);
return get_algorithm_heuristic(src, diff, grad, fm,
workspace_limit_in_bytes, reproducible);
}
ConvolutionBackwardFilterImpl::Algorithm*
ConvolutionBackwardFilterImpl::get_algorithm_heuristic(
const TensorLayout& src, const TensorLayout& diff,
const CanonizedFilterMeta& grad, size_t workspace_limit_in_bytes,
bool reproducible) {
AlgoBase::SizeArgs args(this, src, diff, grad);
const TensorLayout& grad, const CanonizedFilterMeta& grad_meta,
size_t workspace_limit_in_bytes, bool reproducible) {
AlgoBase::SizeArgs args(this, src, diff, grad, grad_meta);
if (args.grad_filter_meta.group > 1 &&
sm_algo_pack.chanwise.is_available_reproducible(
......@@ -349,14 +363,26 @@ ConvolutionBackwardFilterImpl::get_algorithm_heuristic(
args = orig_args;
}
if (reproducible) {
return megdnn::get_reproducible_algo<ConvolutionBackwardFilterImpl>(
sm_algo_pack.non_cudnn_algos, args, workspace_limit_in_bytes,
"cuda conv bwd_filter");
if (args.src_layout->dtype.enumv() != DTypeTrait<dtype::BFloat16>::enumv) {
if (reproducible) {
return megdnn::get_reproducible_algo<ConvolutionBackwardFilterImpl>(
sm_algo_pack.non_cudnn_algos, args,
workspace_limit_in_bytes, "cuda conv bwd_filter");
} else {
return megdnn::get_usable_algo<ConvolutionBackwardFilterImpl>(
sm_algo_pack.non_cudnn_algos, args,
workspace_limit_in_bytes, "cuda conv bwd_filter");
}
} else {
return megdnn::get_usable_algo<ConvolutionBackwardFilterImpl>(
sm_algo_pack.non_cudnn_algos, args, workspace_limit_in_bytes,
"cuda conv bwd_filter");
if (reproducible) {
return megdnn::get_reproducible_algo<ConvolutionBackwardFilterImpl>(
sm_algo_pack.bfloat16_algos, args, workspace_limit_in_bytes,
"cuda conv bwd_filter");
} else {
return megdnn::get_usable_algo<ConvolutionBackwardFilterImpl>(
sm_algo_pack.bfloat16_algos, args, workspace_limit_in_bytes,
"cuda conv bwd_filter");
}
}
}
......@@ -365,7 +391,7 @@ size_t ConvolutionBackwardFilterImpl::get_workspace_in_bytes(
const TensorLayout &diff,
const TensorLayout &grad) {
AlgoBase::SizeArgs args(this, src, diff, grad);
return get_algorithm(this, src, diff, args.grad_filter_meta)->
return get_algorithm(this, src, diff, grad, args.grad_filter_meta)->
get_workspace_in_bytes(args);
}
......
......@@ -60,11 +60,11 @@ class ConvolutionBackwardDataImpl: public ConvolutionBackwardData {
const TensorLayout& grad,
size_t workspace_limit_in_bytes,
bool reproducible) override;
Algorithm* get_algorithm_heuristic(const CanonizedFilterMeta& filter,
const TensorLayout& diff,
const TensorLayout& grad,
size_t workspace_limit_in_bytes,
bool reproducible);
Algorithm* get_algorithm_heuristic(
const TensorLayout& filter,
const CanonizedFilterMeta& filter_meta,
const TensorLayout& diff, const TensorLayout& grad,
size_t workspace_limit_in_bytes, bool reproducible);
size_t get_workspace_in_bytes(const TensorLayout& filter,
const TensorLayout& diff,
const TensorLayout& grad) override;
......@@ -76,6 +76,7 @@ class ConvolutionBackwardDataImpl: public ConvolutionBackwardData {
class AlgoChanwise;
class AlgoChanwiseSmall;
class AlgoGroupConvGeneral;
class AlgoBFloat16;
class AlgoPack;
......@@ -104,7 +105,8 @@ class ConvolutionBackwardFilterImpl: public ConvolutionBackwardFilter {
bool reproducible) override;
Algorithm* get_algorithm_heuristic(const TensorLayout& src,
const TensorLayout& diff,
const CanonizedFilterMeta& grad,
const TensorLayout& gradk,
const CanonizedFilterMeta& grad_meta,
size_t workspace_limit_in_bytes,
bool reproducible);
size_t get_workspace_in_bytes(const TensorLayout& src,
......@@ -117,6 +119,7 @@ class ConvolutionBackwardFilterImpl: public ConvolutionBackwardFilter {
class AlgoMatmul;
class AlgoChanwise;
class AlgoGroupConvGeneral;
class AlgoBFloat16;
class AlgoPack;
......
......@@ -50,7 +50,7 @@ class Convolution3DForwardImpl::AlgoBase: public Algorithm {
const CanonizedFilterMeta &filter,
const TensorLayout &dst);
};
struct ExecArgs: public SizeArgs {
struct ExecArgs : public SizeArgs {
const TensorND *src_tensor, *filter_tensor, *dst_tensor;
Workspace workspace;
......
/**
 * \file dnn/src/cuda/elemwise/kimpl/ABS_GRAD_dt_bfloat16.cu
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 */
// generated by gen_elemwise_kern_impls.py
// Instantiates the binary (arity-2) ABS_GRAD elemwise CUDA kernel for
// dt_bfloat16. Compiled only when 16-bit float support is enabled.
#if !MEGDNN_DISABLE_FLOAT16
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(ABS_GRAD, cb)
#define KERN_IMPL_ARITY 2
#define KERN_IMPL_CTYPE dt_bfloat16
#include "../kern_impl.inl"
#endif
/**
 * \file dnn/src/cuda/elemwise/kimpl/ABS_dt_bfloat16.cu
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 */
// generated by gen_elemwise_kern_impls.py
// Instantiates the unary (arity-1) ABS elemwise CUDA kernel for
// dt_bfloat16. Compiled only when 16-bit float support is enabled.
#if !MEGDNN_DISABLE_FLOAT16
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(ABS, cb)
#define KERN_IMPL_ARITY 1
#define KERN_IMPL_CTYPE dt_bfloat16
#include "../kern_impl.inl"
#endif
/**
 * \file dnn/src/cuda/elemwise/kimpl/ACOS_dt_bfloat16.cu
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 */
// generated by gen_elemwise_kern_impls.py
// Instantiates the unary (arity-1) ACOS elemwise CUDA kernel for
// dt_bfloat16. Compiled only when 16-bit float support is enabled.
#if !MEGDNN_DISABLE_FLOAT16
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(ACOS, cb)
#define KERN_IMPL_ARITY 1
#define KERN_IMPL_CTYPE dt_bfloat16
#include "../kern_impl.inl"
#endif
/**
 * \file dnn/src/cuda/elemwise/kimpl/ADD_dt_bfloat16.cu
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 */
// generated by gen_elemwise_kern_impls.py
// Instantiates the binary (arity-2) ADD elemwise CUDA kernel for
// dt_bfloat16. Compiled only when 16-bit float support is enabled.
#if !MEGDNN_DISABLE_FLOAT16
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(ADD, cb)
#define KERN_IMPL_ARITY 2
#define KERN_IMPL_CTYPE dt_bfloat16
#include "../kern_impl.inl"
#endif
/**
 * \file dnn/src/cuda/elemwise/kimpl/ASIN_dt_bfloat16.cu
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 */
// generated by gen_elemwise_kern_impls.py
// Instantiates the unary (arity-1) ASIN elemwise CUDA kernel for
// dt_bfloat16. Compiled only when 16-bit float support is enabled.
#if !MEGDNN_DISABLE_FLOAT16
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(ASIN, cb)
#define KERN_IMPL_ARITY 1
#define KERN_IMPL_CTYPE dt_bfloat16
#include "../kern_impl.inl"
#endif
/**
 * \file dnn/src/cuda/elemwise/kimpl/ATAN2_dt_bfloat16.cu
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 */
// generated by gen_elemwise_kern_impls.py
// Instantiates the binary (arity-2) ATAN2 elemwise CUDA kernel for
// dt_bfloat16. Compiled only when 16-bit float support is enabled.
#if !MEGDNN_DISABLE_FLOAT16
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(ATAN2, cb)
#define KERN_IMPL_ARITY 2
#define KERN_IMPL_CTYPE dt_bfloat16
#include "../kern_impl.inl"
#endif
/**
 * \file dnn/src/cuda/elemwise/kimpl/CEIL_dt_bfloat16.cu
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 */
// generated by gen_elemwise_kern_impls.py
// Instantiates the unary (arity-1) CEIL elemwise CUDA kernel for
// dt_bfloat16. Compiled only when 16-bit float support is enabled.
#if !MEGDNN_DISABLE_FLOAT16
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(CEIL, cb)
#define KERN_IMPL_ARITY 1
#define KERN_IMPL_CTYPE dt_bfloat16
#include "../kern_impl.inl"
#endif
/**
 * \file dnn/src/cuda/elemwise/kimpl/COND_LEQ_MOV_dt_bfloat16.cu
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 */
// generated by gen_elemwise_kern_impls.py
// Instantiates the ternary (arity-3) COND_LEQ_MOV elemwise CUDA kernel for
// dt_bfloat16. Compiled only when 16-bit float support is enabled.
#if !MEGDNN_DISABLE_FLOAT16
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(COND_LEQ_MOV, cb)
#define KERN_IMPL_ARITY 3
#define KERN_IMPL_CTYPE dt_bfloat16
#include "../kern_impl.inl"
#endif
/**
 * \file dnn/src/cuda/elemwise/kimpl/COS_dt_bfloat16.cu
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 */
// generated by gen_elemwise_kern_impls.py
// Instantiates the unary (arity-1) COS elemwise CUDA kernel for
// dt_bfloat16. Compiled only when 16-bit float support is enabled.
#if !MEGDNN_DISABLE_FLOAT16
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(COS, cb)
#define KERN_IMPL_ARITY 1
#define KERN_IMPL_CTYPE dt_bfloat16
#include "../kern_impl.inl"
#endif
/**
 * \file dnn/src/cuda/elemwise/kimpl/EQ_dt_bfloat16.cu
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 */
// generated by gen_elemwise_kern_impls.py
// Instantiates the binary (arity-2) EQ elemwise CUDA kernel for
// dt_bfloat16. Compiled only when 16-bit float support is enabled.
#if !MEGDNN_DISABLE_FLOAT16
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(EQ, cb)
#define KERN_IMPL_ARITY 2
#define KERN_IMPL_CTYPE dt_bfloat16
#include "../kern_impl.inl"
#endif
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
......@@ -141,6 +141,9 @@ INST_FOR_CTYPE
#define ct dt_float16
INST_FOR_CTYPE
#undef ct
#define ct dt_bfloat16
INST_FOR_CTYPE
#undef ct
#define ct dt_int8
INST_FOR_CTYPE
#undef ct
......
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册