diff --git a/dnn/include/megdnn/dtype.h b/dnn/include/megdnn/dtype.h
index df3319f0b6d00945bedf1d5ed138514adb8dd564..bd11994ea98acff9b6c301c8f28d23d86ccc87ad 100644
--- a/dnn/include/megdnn/dtype.h
+++ b/dnn/include/megdnn/dtype.h
@@ -29,6 +29,7 @@
 #define MEGDNN_FLOAT16_SELECT(_x, _y) _y
 #else
 #include "megdnn/dtype/half.hpp"
+#include "megdnn/dtype/bfloat16.hpp"
 #define MEGDNN_INC_FLOAT16(_x) _x
 #define MEGDNN_FLOAT16_SELECT(_x, _y) _x
 #endif
@@ -49,6 +50,7 @@ namespace megdnn {
     cb(IntB4) \
     cb(Byte) \
     MEGDNN_INC_FLOAT16(cb(Float16)) \
+    MEGDNN_INC_FLOAT16(cb(BFloat16)) \
     cb(UintB4) \
 
 /*!
@@ -62,6 +64,7 @@ namespace megdnn {
     cb(Int32) \
     cb(Byte) \
     MEGDNN_INC_FLOAT16(cb(Float16)) \
+    MEGDNN_INC_FLOAT16(cb(BFloat16)) \
 
 /*!
  * \brief iterate through each fractional byte dtype
@@ -101,6 +104,7 @@ namespace megdnn {
 #define MEGDNN_FOREACH_COMPUTING_DTYPE_FLOAT(cb) \
     cb(::megdnn::dtype::Float32) \
     MEGDNN_INC_FLOAT16(cb(::megdnn::dtype::Float16)) \
+    MEGDNN_INC_FLOAT16(cb(::megdnn::dtype::BFloat16)) \
 
 /*!
  * \brief iterate through each dtype object that can be involved in integer
@@ -345,6 +349,7 @@ typedef int16_t dt_int16;
 typedef int8_t dt_int8;
 typedef uint8_t dt_uint8;
 MEGDNN_INC_FLOAT16(typedef half_float::half dt_float16;)
+MEGDNN_INC_FLOAT16(typedef half_bfloat16::bfloat16 dt_bfloat16;)
 
 #define MEGDNN_PARAMETERIZED_DTYPE_ENUM_BASE 100000
 #if MEGDNN_CC_HOST
@@ -367,6 +372,9 @@ MEGDNN_INC_FLOAT16(typedef half_float::half dt_float16;)
         Float16,
 #endif
         UintB4 = 10,
+#if !MEGDNN_DISABLE_FLOAT16
+        BFloat16 = 11,
+#endif
 
 #define FST(_name) _name = MEGDNN_PARAMETERIZED_DTYPE_ENUM_BASE,
 #define D(_name) _name,
@@ -702,6 +710,9 @@ MEGDNN_DEF_DT(Uint8, dt_uint8, INT, UNSIGNED, 0, UINT8_MAX);
 MEGDNN_INC_FLOAT16(MEGDNN_DEF_DT(Float16, dt_float16, FLOAT, SIGNED,
                                  std::numeric_limits<dt_float16>::lowest(),
                                  std::numeric_limits<dt_float16>::max()));
+MEGDNN_INC_FLOAT16(MEGDNN_DEF_DT(BFloat16, dt_bfloat16, FLOAT, SIGNED,
+                                 std::numeric_limits<dt_bfloat16>::lowest(),
+                                 std::numeric_limits<dt_bfloat16>::max()));
 
 template <>
 struct DTypeTrait {
diff --git a/dnn/include/megdnn/dtype/bfloat16.hpp b/dnn/include/megdnn/dtype/bfloat16.hpp
new file mode 100644
index 0000000000000000000000000000000000000000..7e05c3b8563062aaaf5b65401b7f09fdf9057400
--- /dev/null
+++ b/dnn/include/megdnn/dtype/bfloat16.hpp
@@ -0,0 +1,2965 @@
+/**
+ * half - IEEE 754-based half-precision floating point library.
+ *
+ * Copyright (c) 2012-2013 Christian Rau <rauy@users.sourceforge.net>
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation
+ * files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy,
+ * modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the
+ * Software is furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE
+ * WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
+ * COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE,
+ * ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
+ *
+ * Version 1.11.0
+ * \file
+ * Main header file for half precision functionality.
+ *
+ * --------------------------------------------------------------------------
+ * \file include/megdnn/dtype/bfloat16.hpp
+ *
+ * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
+ *
+ * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ *
+ * This file has been modified by Megvii ("Megvii Modifications").
+ * All Megvii Modifications are Copyright (C) 2014-2019 Megvii Inc. All rights reserved.
+ *
+ * --------------------------------------------------------------------------
+ */
+
+#ifndef BFLOAT16_BFLOAT16_HPP
+#define BFLOAT16_BFLOAT16_HPP
+#include "megdnn/arch.h"
+
+#include "megdnn/dtype/half_common_prologue.h"
+#include
+
+#if !(HALF_ENABLE_CPP11_CSTDINT & HALF_ENABLE_CPP11_CMATH & \
+      HALF_ENABLE_CPP11_HASH)
+#error "Should use --std=c++11 option to compile."
+#endif
+
+/// Default rounding mode.
+/// This specifies the rounding mode used for all conversions between
+/// [bfloat16](\ref half_bfloat16::bfloat16)s and `float`s as well as for
+/// bfloat16_cast() if not specifying a rounding mode explicitly. It can be
+/// redefined (before including bfloat16.hpp) to one of the standard rounding
+/// modes using their respective constants or the equivalent values of
+/// `float_round_style`:
+///
+/// `float_round_style`              | value | rounding
+/// ---------------------------------|-------|-------------------------
+/// `round_indeterminate`            | -1    | fastest
+/// `round_toward_zero`              | 0     | toward zero
+/// `round_to_nearest`               | 1     | to nearest (default)
+/// `round_toward_infinity`          | 2     | toward positive infinity
+/// `round_toward_neg_infinity`      | 3     | toward negative infinity
+///
+/// By default this is set to `1` (`round_to_nearest`). It can even be set to
+/// `numeric_limits<float>::round_style` to synchronize the rounding mode with
+/// that of the underlying single-precision implementation.
+#ifndef BFLOAT16_ROUND_STYLE
+#define BFLOAT16_ROUND_STYLE 1 // = to nearest
+#endif
+
+/// Tie-breaking behaviour for round to nearest.
+/// This specifies if ties in round to nearest should be resolved by rounding
+/// to the nearest even value. By default this is defined to `1`, resulting in
+/// rounding to the nearest even value in half-way cases, but it can be
+/// redefined to `0` (before including bfloat16.hpp), in which case half-way
+/// cases round away from zero, like the round() function does.
+#ifndef BFLOAT16_ROUND_TIES_TO_EVEN
+#define BFLOAT16_ROUND_TIES_TO_EVEN 1 // ties round to nearest even.
+#endif
+
+#if !BFLOAT16_ROUND_TIES_TO_EVEN
+#error "BFloat16 only supports round-ties-to-even for now."
+#endif
+
+/// Value signaling overflow.
+/// In correspondence with `HUGE_VAL[F|L]` from `<cmath>` this symbol expands
+/// to a positive value signaling the overflow of an operation, in particular
+/// it just evaluates to positive infinity.
+#define HUGE_VALBH std::numeric_limits<half_bfloat16::bfloat16>::infinity()
+
+/// Fast bfloat16 fma function.
+/// This symbol is only defined if the fma() function generally executes as
+/// fast as, or faster than, a separate bfloat16 multiplication followed by an
+/// addition. Due to the internal single-precision implementation of all
+/// arithmetic operations, this is in fact always the case.
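+///
+/// For example, the following evaluates as a single float multiply-add (or
+/// fused fma, where available) followed by one rounding to bfloat16, rather
+/// than two separately rounded bfloat16 operations:
+/// ~~~~{.cpp}
+/// half_bfloat16::bfloat16 a(1.5f), b(2.0f), c(0.25f);
+/// half_bfloat16::bfloat16 r = fma(a, b, c);
+/// ~~~~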
+#define FP_FAST_FMAH 1
+
+#ifndef FP_ILOGB0
+#define FP_ILOGB0 INT_MIN
+#endif
+#ifndef FP_ILOGBNAN
+#define FP_ILOGBNAN INT_MAX
+#endif
+#ifndef FP_SUBNORMAL
+#define FP_SUBNORMAL 0
+#endif
+#ifndef FP_ZERO
+#define FP_ZERO 1
+#endif
+#ifndef FP_NAN
+#define FP_NAN 2
+#endif
+#ifndef FP_INFINITE
+#define FP_INFINITE 3
+#endif
+#ifndef FP_NORMAL
+#define FP_NORMAL 4
+#endif
+
+/// Main namespace for bfloat16 functionality.
+/// This namespace contains all the functionality provided by the library.
+/// Bfloat16 has the following format:
+/// Sign bit: 1 bit
+/// Exponent width: 8 bits
+/// Significand precision: 8 bits (7 explicitly stored), as opposed to 24 bits
+/// in a classical single-precision floating-point format
namespace half_bfloat16 {
class bfloat16;
/// \internal
/// \brief Implementation details.
namespace detail {
#if HALF_ENABLE_CPP11_TYPE_TRAITS
/// Conditional type.
template <bool B, typename T, typename F>
struct conditional : std::conditional<B, T, F> {};

/// Helper for tag dispatching.
template <bool B>
struct bool_type : std::integral_constant<bool, B> {};
using std::false_type;
using std::true_type;

/// Type traits for floating point types.
template <typename T>
struct is_float : std::is_floating_point<T> {};
#else
/// Conditional type.
template <bool, typename T, typename F>
struct conditional {
    typedef T type;
};
template <typename T, typename F>
struct conditional<false, T, F> {
    typedef F type;
};

/// Helper for tag dispatching.
template <bool>
struct bool_type {};
typedef bool_type<true> true_type;
typedef bool_type<false> false_type;

/// Type traits for floating point types.
template <typename T>
struct is_float : false_type {};
template <typename T>
struct is_float<const T> : is_float<T> {};
template <typename T>
struct is_float<volatile T> : is_float<T> {};
template <typename T>
struct is_float<const volatile T> : is_float<T> {};
template <>
struct is_float<float> : true_type {};
template <>
struct is_float<double> : true_type {};
template <>
struct is_float<long double> : true_type {};
#endif

/// Unsigned integer of (at least) 16 bits width.
typedef uint_least16_t uint16;

/// Unsigned integer of (at least) 32 bits width.
typedef uint_least32_t uint32;

/// Fastest signed integer capable of holding all values of type uint16.
typedef int_fast32_t int17;

/// Tag type for binary construction.
struct binary_t {};

/// Temporary bfloat16 expression.
/// This class represents a bfloat16 expression which just stores a
/// single-precision value internally.
struct expr {
    /// Conversion constructor.
    /// \param f single-precision value to convert
    MEGDNN_HOST MEGDNN_DEVICE explicit HALF_CONSTEXPR expr(float f)
            : value_(f) {}

    /// Conversion to single-precision.
    /// \return single precision value representing expression value
    MEGDNN_HOST MEGDNN_DEVICE HALF_CONSTEXPR operator float() const {
        return value_;
    }

private:
    /// Internal expression value stored in single-precision.
    float value_;
};

/// SFINAE helper for generic bfloat16 functions.
/// This class template has to be specialized for each valid combination of
/// argument types to provide a corresponding `type` member equivalent to \a T.
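/// (The specializations that follow cover bfloat16 and detail::expr in up to
/// three argument positions, mirroring the SFINAE machinery of the original
/// half-precision library this file is derived from.)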
/// \tparam T type to return
template <typename T, typename, typename = void, typename = void>
struct enable {};
template <typename T>
struct enable<T, bfloat16, void, void> {
    typedef T type;
};
template <typename T>
struct enable<T, expr, void, void> {
    typedef T type;
};
template <typename T>
struct enable<T, bfloat16, bfloat16, void> {
    typedef T type;
};
template <typename T>
struct enable<T, bfloat16, expr, void> {
    typedef T type;
};
template <typename T>
struct enable<T, expr, bfloat16, void> {
    typedef T type;
};
template <typename T>
struct enable<T, expr, expr, void> {
    typedef T type;
};
template <typename T>
struct enable<T, bfloat16, bfloat16, bfloat16> {
    typedef T type;
};
template <typename T>
struct enable<T, bfloat16, bfloat16, expr> {
    typedef T type;
};
template <typename T>
struct enable<T, bfloat16, expr, bfloat16> {
    typedef T type;
};
template <typename T>
struct enable<T, bfloat16, expr, expr> {
    typedef T type;
};
template <typename T>
struct enable<T, expr, bfloat16, bfloat16> {
    typedef T type;
};
template <typename T>
struct enable<T, expr, bfloat16, expr> {
    typedef T type;
};
template <typename T>
struct enable<T, expr, expr, bfloat16> {
    typedef T type;
};
template <typename T>
struct enable<T, expr, expr, expr> {
    typedef T type;
};

/// Return type for specialized generic 2-argument bfloat16 functions.
/// This class template has to be specialized for each valid combination of
/// argument types to provide a corresponding `type` member denoting the
/// appropriate return type. \tparam T first argument type \tparam U second
/// argument type
template <typename T, typename U>
struct result : enable<expr, T, U> {};
template <>
struct result<bfloat16, bfloat16> {
    typedef bfloat16 type;
};

/// \name Classification helpers
/// \{

/// Check for infinity.
/// \tparam T argument type (builtin floating point type)
/// \param arg value to query
/// \retval true if infinity
/// \retval false else
template <typename T>
MEGDNN_HOST MEGDNN_DEVICE bool builtin_isinf(T arg) {
#if defined(__CUDA_ARCH__)
    return ::isinf(arg);
#elif HALF_ENABLE_CPP11_CMATH
    return ::std::isinf(arg);
#elif defined(_MSC_VER)
    return !_finite(static_cast<double>(arg)) &&
           !_isnan(static_cast<double>(arg));
#else
    return arg == std::numeric_limits<T>::infinity() ||
           arg == -std::numeric_limits<T>::infinity();
#endif
}

/// Check for NaN.
/// \tparam T argument type (builtin floating point type)
/// \param arg value to query
/// \retval true if not a number
/// \retval false else
template <typename T>
MEGDNN_HOST MEGDNN_DEVICE bool builtin_isnan(T arg) {
#if defined(__CUDA_ARCH__)
    return ::isnan(arg);
#elif HALF_ENABLE_CPP11_CMATH
    return std::isnan(arg);
#elif defined(_MSC_VER)
    return _isnan(static_cast<double>(arg)) != 0;
#else
    return arg != arg;
#endif
}

/// Check sign.
/// \tparam T argument type (builtin floating point type)
/// \param arg value to query
/// \retval true if signbit set
/// \retval false else
template <typename T>
MEGDNN_HOST MEGDNN_DEVICE bool builtin_signbit(T arg) {
#if defined(__CUDA_ARCH__)
    return ::signbit(arg);
#elif HALF_ENABLE_CPP11_CMATH
    return std::signbit(arg);
#else
    return arg < T() || (arg == T() && T(1) / arg < T());
#endif
}

/// \}
/// \name Conversion
/// \{

/// Convert single-precision to bfloat16.
/// \param value single-precision value
/// \return binary_t() representation of bfloat16-precision value
template <std::float_round_style R>
MEGDNN_HOST MEGDNN_DEVICE uint16 float2bfloat16(float value) {
#if HALF_ENABLE_CPP11_STATIC_ASSERT
    static_assert(std::numeric_limits<float>::is_iec559,
                  "float to bfloat16 conversion needs IEEE 754 "
                  "conformant 'float' type");
    static_assert(sizeof(uint32) == sizeof(float),
                  "float to bfloat16 conversion needs unsigned integer "
                  "type of exactly the size of a 'float'");
    static_assert(R == std::round_to_nearest,
                  "Only support rounding-mode "
                  "round-to-nearest currently.");
#endif

    union {
        float fraw;
        uint32_t int_raw;
    } r = {value};
    if (~r.int_raw & 0x7f800000) {
        //! When the exponent bits are not all 1s, then the value is zero,
        //! normal, or subnormal.
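        //! Adding 0x7fff plus the lowest kept bit ((r.int_raw >> 16) & 1)
        //! implements round-to-nearest-even on truncation to the top 16
        //! bits: e.g. 0x3F808000 (1.00390625f) rounds down to 0x3F80, while
        //! 0x3F818000 (1.01171875f) rounds up to 0x3F82, so ties always go
        //! to the even 16-bit pattern.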
        r.int_raw += 0x7fff + ((r.int_raw >> 16) & 1);
    } else if (r.int_raw & 0xffff) {
        //! When all of the exponent bits are 1, the value is Inf or NaN.
        //! Preserve signaling NaN here.
        r.int_raw |= 0x10000;
    }
    return uint16(r.int_raw >> 16);
}

/// Convert integer to bfloat16 floating point.
/// \tparam R rounding mode to use, `round_indeterminate` for fastest rounding
/// \tparam T type to convert (builtin integer type)
/// \param value integral value
/// \return binary_t() representation of bfloat16-precision value
template <std::float_round_style R, typename T>
MEGDNN_HOST MEGDNN_DEVICE uint16 int2bfloat16(T value) {
    return float2bfloat16<R>(static_cast<float>(value));
}

/// Convert bfloat16 to single-precision.
/// \param value binary_t() representation of bfloat16 value
/// \return single-precision value
MEGDNN_HOST MEGDNN_DEVICE inline float bfloat162float(uint16 value) {
#if HALF_ENABLE_CPP11_STATIC_ASSERT
    static_assert(std::numeric_limits<float>::is_iec559,
                  "bfloat16 to float conversion needs IEEE 754 conformant "
                  "'float' type");
    static_assert(sizeof(uint32) == sizeof(float),
                  "bfloat16 to float conversion needs unsigned integer type of "
                  "exactly the size of a 'float'");
#endif
    union {
        uint32_t int_raw;
        float fraw;
    } r = {uint32_t(value) << 16};
    return r.fraw;
}

/// Convert bfloat16 floating point to integer.
/// \tparam T type to convert to (builtin integer type with at least 16 bits
/// precision, excluding any implicit sign bits)
/// \param value binary_t() representation of bfloat16-precision value
/// \return integral value
template <typename T>
MEGDNN_HOST MEGDNN_DEVICE T bfloat162int(uint16 value) {
    return static_cast<T>(bfloat162float(value));
}

/// Round bfloat16 number to nearest integer value.
/// \tparam R rounding mode to use, `round_indeterminate` for fastest rounding
/// \tparam E `true` for round to even, `false` for round away from zero
/// \param value binary_t() representation of bfloat16-precision value
/// \return bfloat16 bits for nearest integral value
template <std::float_round_style R, bool E>
MEGDNN_HOST MEGDNN_DEVICE uint16 round_bfloat16_impl(uint16 value) {
    unsigned int e = value & 0x7FFF;
    uint16 result = value;
    if (e < 0x3F80) {
        result &= 0x8000;
        if (R == std::round_to_nearest)
            result |= 0x3F80U & -(e >= (0x3F00 + E));
        else if (R == std::round_toward_infinity)
            result |= 0x3F80U & -(~(value >> 15) & (e != 0));
        else if (R == std::round_toward_neg_infinity)
            result |= 0x3F80U & -(value > 0x8000);
    } else if (e < 0x4300) {
        e = 134 - (e >> 7);
        unsigned int mask = (1 << e) - 1;
        if (R == std::round_to_nearest)
            result += (1 << (e - 1)) - (~(result >> e) & E);
        else if (R == std::round_toward_infinity)
            result += mask & ((value >> 15) - 1);
        else if (R == std::round_toward_neg_infinity)
            result += mask & -(value >> 15);
        result &= ~mask;
    }
    return result;
}

/// Round bfloat16 number to nearest integer value.
/// \tparam R rounding mode to use, `round_indeterminate` for fastest rounding
/// \param value binary_t() representation of bfloat16-precision value
/// \return bfloat16 bits for nearest integral value
template <std::float_round_style R>
MEGDNN_HOST MEGDNN_DEVICE uint16 round_bfloat16(uint16 value) {
    return round_bfloat16_impl<R, BFLOAT16_ROUND_TIES_TO_EVEN>(value);
}

/// Round bfloat16 number to nearest integer value using
/// round-to-nearest-away-from-zero.
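/// (In these integer-rounding helpers the magic constants bound the only
/// interesting range: absolute bit patterns below 0x3F80 encode |x| < 1,
/// while patterns at or above 0x4300 have a biased exponent of at least 134,
/// i.e. |x| >= 128, which is already integral with a 7-bit mantissa;
/// `134 - (e >> 7)` is then the number of fractional bits to clear.)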
\param value binary_t() representation of +/// bfloat16-precision value \return bfloat16-precision bits for nearest +/// integral value +MEGDNN_HOST MEGDNN_DEVICE inline uint16 round_bfloat16_up(uint16 value) { + return round_bfloat16_impl(value); +} +/// \} + +struct functions; +template +struct unary_specialized; +template +struct binary_specialized; +template +struct bfloat16_caster; +} + +/// bfloat16 floating point type. +class bfloat16 { + friend struct detail::functions; + friend struct detail::unary_specialized; + friend struct detail::binary_specialized; + template + friend struct detail::bfloat16_caster; +#if HALF_ENABLE_CPP11_HASH + friend struct std::hash; +#endif + +public: + /// Default constructor. + MEGDNN_HOST MEGDNN_DEVICE bfloat16() {} + + /// Copy constructor. + /// \tparam T type of concrete bfloat16 expression + /// \param rhs bfloat16 expression to copy from + MEGDNN_HOST MEGDNN_DEVICE bfloat16(detail::expr rhs) + : data_(detail::float2bfloat16(rhs)) {} + + MEGDNN_HOST MEGDNN_DEVICE HALF_CONSTEXPR bfloat16(const bfloat16& rhs) + : data_(rhs.data_) {} + + MEGDNN_HOST MEGDNN_DEVICE bfloat16(const volatile bfloat16& rhs) + : data_(rhs.data_) {} + + MEGDNN_HOST MEGDNN_DEVICE bfloat16& operator=(const bfloat16& rhs) { + data_ = rhs.data_; + return *this; + } + + MEGDNN_HOST MEGDNN_DEVICE bfloat16& operator=( + const volatile bfloat16& rhs) { + data_ = rhs.data_; + return *this; + } + + MEGDNN_HOST MEGDNN_DEVICE volatile bfloat16& operator=( + const bfloat16& rhs) volatile { + data_ = rhs.data_; + return *this; + } + + /// Conversion constructor. + /// \param rhs float to convert + // MEGDNN_HOST MEGDNN_DEVICE explicit bfloat16(float rhs) + // : data_(detail::float2bfloat16(rhs)) {} + + MEGDNN_HOST MEGDNN_DEVICE explicit bfloat16(float rhs) { + data_ = detail::float2bfloat16(rhs); + } + + /// Conversion to single-precision. + /// \return single precision value representing expression value + MEGDNN_HOST MEGDNN_DEVICE operator float() const { + return detail::bfloat162float(data_); + } + + /// Assignment operator. + /// \tparam T type of concrete bfloat16 expression + /// \param rhs bfloat16 expression to copy from + /// \return reference to this bfloat16 + MEGDNN_HOST MEGDNN_DEVICE bfloat16& operator=(detail::expr rhs) { + return *this = static_cast(rhs); + } + + /// Arithmetic assignment. + /// \tparam T type of concrete bfloat16 expression + /// \param rhs bfloat16 expression to add + /// \return reference to this bfloat16 + template + MEGDNN_HOST MEGDNN_DEVICE typename detail::enable::type + operator+=(T rhs) { + return *this += static_cast(rhs); + } + + /// Arithmetic assignment. + /// \tparam T type of concrete bfloat16 expression + /// \param rhs bfloat16 expression to subtract + /// \return reference to this bfloat16 + template + MEGDNN_HOST MEGDNN_DEVICE typename detail::enable::type + operator-=(T rhs) { + return *this -= static_cast(rhs); + } + + /// Arithmetic assignment. + /// \tparam T type of concrete bfloat16 expression + /// \param rhs bfloat16 expression to multiply with + /// \return reference to this bfloat16 + template + MEGDNN_HOST MEGDNN_DEVICE typename detail::enable::type + operator*=(T rhs) { + return *this *= static_cast(rhs); + } + + /// Arithmetic assignment. 
+ /// \tparam T type of concrete bfloat16 expression + /// \param rhs bfloat16 expression to divide by + /// \return reference to this bfloat16 + template + MEGDNN_HOST MEGDNN_DEVICE typename detail::enable::type + operator/=(T rhs) { + return *this /= static_cast(rhs); + } + + /// Assignment operator. + /// \param rhs single-precision value to copy from + /// \return reference to this bfloat16 + MEGDNN_HOST MEGDNN_DEVICE bfloat16& operator=(float rhs) { + data_ = detail::float2bfloat16(rhs); + return *this; + } + + /// Arithmetic assignment. + /// \param rhs single-precision value to add + /// \return reference to this bfloat16 + MEGDNN_HOST MEGDNN_DEVICE bfloat16& operator+=(float rhs) { + data_ = detail::float2bfloat16( + detail::bfloat162float(data_) + rhs); + return *this; + } + + /// Arithmetic assignment. + /// \param rhs single-precision value to subtract + /// \return reference to this bfloat16 + MEGDNN_HOST MEGDNN_DEVICE bfloat16& operator-=(float rhs) { + data_ = detail::float2bfloat16( + detail::bfloat162float(data_) - rhs); + return *this; + } + + /// Arithmetic assignment. + /// \param rhs single-precision value to multiply with + /// \return reference to this bfloat16 + MEGDNN_HOST MEGDNN_DEVICE bfloat16& operator*=(float rhs) { + data_ = detail::float2bfloat16( + detail::bfloat162float(data_) * rhs); + return *this; + } + + /// Arithmetic assignment. + /// \param rhs single-precision value to divide by + /// \return reference to this bfloat16 + MEGDNN_HOST MEGDNN_DEVICE bfloat16& operator/=(float rhs) { + data_ = detail::float2bfloat16( + detail::bfloat162float(data_) / rhs); + return *this; + } + + /// Prefix increment. + /// \return incremented bfloat16 value + MEGDNN_HOST MEGDNN_DEVICE bfloat16& operator++() { return *this += 1.0f; } + + /// Prefix decrement. + /// \return decremented bfloat16 value + MEGDNN_HOST MEGDNN_DEVICE bfloat16& operator--() { return *this -= 1.0f; } + + /// Postfix increment. + /// \return non-incremented bfloat16 value + MEGDNN_HOST MEGDNN_DEVICE bfloat16 operator++(int) { + bfloat16 out(*this); + ++*this; + return out; + } + + /// Postfix decrement. + /// \return non-decremented bfloat16 value + MEGDNN_HOST MEGDNN_DEVICE bfloat16 operator--(int) { + bfloat16 out(*this); + --*this; + return out; + } + + /// Constructor. + /// \param bits binary_t() representation to set bfloat16 to + MEGDNN_HOST MEGDNN_DEVICE HALF_CONSTEXPR bfloat16(detail::binary_t, + detail::uint16 bits) + : data_(bits) {} + + /// Rounding mode to use (always `round_to_nearest` with + /// BFLOAT16_ROUND_TIES_TO_EVEN on) + static HALF_CONSTEXPR_CONST std::float_round_style round_style = + (std::float_round_style)(BFLOAT16_ROUND_STYLE); + + // private: + /// Internal binary_t() representation + detail::uint16 data_; +}; + +#if HALF_ENABLE_CPP11_USER_LITERALS +/// Library-defined bfloat16 literals. +/// Import this namespace to enable bfloat16 floating point literals: +/// ~~~~{.cpp} +/// using namespace half_bfloat16::literal; +/// half_bfloat16::bfloat16 = 4.2_h; +/// ~~~~ +namespace literal { +/// Half literal. +/// While this returns an actual bfloat16-precision value, bfloat16 literals can +/// unfortunately not be constant expressions due to rather involved +/// single-to-bfloat16 conversion. 
\param value literal value \return bfloat16 +/// with given value (if representable) +inline bfloat16 operator"" _h(long double value) { + return bfloat16(static_cast(value)); +} +} // namespace literal +#endif + +namespace detail { +/// Wrapper implementing unspecialized bfloat16 functions. +struct functions { + /// Addition implementation. + /// \param x first operand + /// \param y second operand + /// \return bfloat16 sum stored in single-precision + MEGDNN_HOST MEGDNN_DEVICE static expr plus(float x, float y) { + return expr(x + y); + } + + /// Subtraction implementation. + /// \param x first operand + /// \param y second operand + /// \return bfloat16 difference stored in single-precision + MEGDNN_HOST MEGDNN_DEVICE static expr minus(float x, float y) { + return expr(x - y); + } + + /// Multiplication implementation. + /// \param x first operand + /// \param y second operand + /// \return bfloat16 product stored in single-precision + MEGDNN_HOST MEGDNN_DEVICE static expr multiplies(float x, float y) { + return expr(x * y); + } + + /// Division implementation. + /// \param x first operand + /// \param y second operand + /// \return bfloat16 quotient stored in single-precision + MEGDNN_HOST MEGDNN_DEVICE static expr divides(float x, float y) { + return expr(x / y); + } + + /// Output implementation. + /// \param out stream to write to + /// \param arg value to write + /// \return reference to stream + template + static std::basic_ostream& write( + std::basic_ostream& out, float arg) { + return out << arg; + } + + /// Input implementation. + /// \param in stream to read from + /// \param arg bfloat16 to read into + /// \return reference to stream + template + static std::basic_istream& read( + std::basic_istream& in, bfloat16& arg) { + float f; + if (in >> f) + arg = f; + return in; + } + + /// Modulo implementation. + /// \param x first operand + /// \param y second operand + /// \return bfloat16 division remainder stored in single-precision + MEGDNN_HOST MEGDNN_DEVICE static expr fmod(float x, float y) { +#if defined(__CUDA_ARCH__) + return expr(fmodf(x, y)); +#else + return expr(std::fmod(x, y)); +#endif + } + + /// Remainder implementation. + /// \param x first operand + /// \param y second operand + /// \return bfloat16 division remainder stored in single-precision + MEGDNN_HOST MEGDNN_DEVICE static expr remainder(float x, float y) { +#if defined(__CUDA_ARCH__) + return expr(remainderf(x, y)); +#else + return expr(std::remainder(x, y)); +#endif + } + + /// Positive difference implementation. + /// \param x first operand + /// \param y second operand + /// \return Positive difference stored in single-precision + MEGDNN_HOST MEGDNN_DEVICE static expr fdim(float x, float y) { +#if defined(__CUDA_ARCH__) + return expr(fdimf(x, y)); +#else + return expr(std::fdim(x, y)); +#endif + } + + /// Fused multiply-add implementation. + /// \param x first operand + /// \param y second operand + /// \param z third operand + /// \return \a x * \a y + \a z stored in single-precision + MEGDNN_HOST MEGDNN_DEVICE static expr fma(float x, float y, float z) { +#if defined(__CUDA_ARCH__) + return expr(fmaf(x, y, z)); +#elif defined(FP_FAST_FMAF) + return expr(std::fma(x, y, z)); +#else + return expr(x * y + z); +#endif + } + + /// Get NaN. + /// \return bfloat16 quiet NaN + MEGDNN_HOST MEGDNN_DEVICE static bfloat16 nanh(const char*) { + return bfloat16(binary_t(), 0x7FFF); + } + + /// Exponential implementation. 
+ /// \param arg function argument + /// \return function value stored in single-preicision + MEGDNN_HOST MEGDNN_DEVICE static expr exp(float arg) { +#if defined(__CUDA_ARCH__) + return expr(expf(arg)); +#else + return expr(std::exp(arg)); +#endif + } + + /// Exponential implementation. + /// \param arg function argument + /// \return function value stored in single-preicision + MEGDNN_HOST MEGDNN_DEVICE static expr expm1(float arg) { +#if defined(__CUDA_ARCH__) + return expr(expm1f(arg)); +#else + return expr(std::expm1(arg)); +#endif + } + + /// Binary exponential implementation. + /// \param arg function argument + /// \return function value stored in single-preicision + MEGDNN_HOST MEGDNN_DEVICE static expr exp2(float arg) { +#if defined(__CUDA_ARCH__) + return expr(exp2f(arg)); +#else + return expr(std::exp2(arg)); +#endif + } + + /// Logarithm implementation. + /// \param arg function argument + /// \return function value stored in single-preicision + MEGDNN_HOST MEGDNN_DEVICE static expr log(float arg) { +#if defined(__CUDA_ARCH__) + return expr(logf(arg)); +#else + return expr(std::log(arg)); +#endif + } + + /// Common logarithm implementation. + /// \param arg function argument + /// \return function value stored in single-preicision + MEGDNN_HOST MEGDNN_DEVICE static expr log10(float arg) { +#if defined(__CUDA_ARCH__) + return expr(log10f(arg)); +#else + return expr(std::log10(arg)); +#endif + } + + /// Logarithm implementation. + /// \param arg function argument + /// \return function value stored in single-preicision + MEGDNN_HOST MEGDNN_DEVICE static expr log1p(float arg) { +#if defined(__CUDA_ARCH__) + return expr(log1pf(arg)); +#else + return expr(std::log1p(arg)); +#endif + } + + /// Binary logarithm implementation. + /// \param arg function argument + /// \return function value stored in single-preicision + MEGDNN_HOST MEGDNN_DEVICE static expr log2(float arg) { +#if defined(__CUDA_ARCH__) + return expr(log2f(arg)); +#else + return expr(std::log2(arg)); +#endif + } + + /// Square root implementation. + /// \param arg function argument + /// \return function value stored in single-preicision + MEGDNN_HOST MEGDNN_DEVICE static expr sqrt(float arg) { +#if defined(__CUDA_ARCH__) + return expr(sqrtf(arg)); +#else + return expr(std::sqrt(arg)); +#endif + } + + /// Cubic root implementation. + /// \param arg function argument + /// \return function value stored in single-preicision + MEGDNN_HOST MEGDNN_DEVICE static expr cbrt(float arg) { +#if defined(__CUDA_ARCH__) + return expr(cbrtf(arg)); +#else + return expr(std::cbrt(arg)); +#endif + } + + /// Hypotenuse implementation. + /// \param x first argument + /// \param y second argument + /// \return function value stored in single-preicision + MEGDNN_HOST MEGDNN_DEVICE static expr hypot(float x, float y) { +#if defined(__CUDA_ARCH__) + return expr(hypotf(x, y)); +#else + return expr(std::hypot(x, y)); +#endif + } + + /// Power implementation. + /// \param base value to exponentiate + /// \param exp power to expontiate to + /// \return function value stored in single-preicision + MEGDNN_HOST MEGDNN_DEVICE static expr pow(float base, float exp) { +#if defined(__CUDA_ARCH__) + return expr(powf(base, exp)); +#else + return expr(std::pow(base, exp)); +#endif + } + + /// Sine implementation. 
+ /// \param arg function argument + /// \return function value stored in single-preicision + MEGDNN_HOST MEGDNN_DEVICE static expr sin(float arg) { +#if defined(__CUDA_ARCH__) + return expr(sinf(arg)); +#else + return expr(std::sin(arg)); +#endif + } + + /// Cosine implementation. + /// \param arg function argument + /// \return function value stored in single-preicision + MEGDNN_HOST MEGDNN_DEVICE static expr cos(float arg) { +#if defined(__CUDA_ARCH__) + return expr(cosf(arg)); +#else + return expr(std::cos(arg)); +#endif + } + + /// Tan implementation. + /// \param arg function argument + /// \return function value stored in single-preicision + MEGDNN_HOST MEGDNN_DEVICE static expr tan(float arg) { +#if defined(__CUDA_ARCH__) + return expr(tanf(arg)); +#else + return expr(std::tan(arg)); +#endif + } + + /// Arc sine implementation. + /// \param arg function argument + /// \return function value stored in single-preicision + MEGDNN_HOST MEGDNN_DEVICE static expr asin(float arg) { +#if defined(__CUDA_ARCH__) + return expr(asinf(arg)); +#else + return expr(std::asin(arg)); +#endif + } + + /// Arc cosine implementation. + /// \param arg function argument + /// \return function value stored in single-preicision + MEGDNN_HOST MEGDNN_DEVICE static expr acos(float arg) { +#if defined(__CUDA_ARCH__) + return expr(acosf(arg)); +#else + return expr(std::acos(arg)); +#endif + } + + /// Arc tangent implementation. + /// \param arg function argument + /// \return function value stored in single-preicision + MEGDNN_HOST MEGDNN_DEVICE static expr atan(float arg) { +#if defined(__CUDA_ARCH__) + return expr(atanf(arg)); +#else + return expr(std::atan(arg)); +#endif + } + + /// Arc tangent implementation. + /// \param x first argument + /// \param y second argument + /// \return function value stored in single-preicision + MEGDNN_HOST MEGDNN_DEVICE static expr atan2(float x, float y) { +#if defined(__CUDA_ARCH__) + return expr(atan2f(x, y)); +#else + return expr(std::atan2(x, y)); +#endif + } + + /// Hyperbolic sine implementation. + /// \param arg function argument + /// \return function value stored in single-preicision + MEGDNN_HOST MEGDNN_DEVICE static expr sinh(float arg) { +#if defined(__CUDA_ARCH__) + return expr(sinhf(arg)); +#else + return expr(std::sinh(arg)); +#endif + } + + /// Hyperbolic cosine implementation. + /// \param arg function argument + /// \return function value stored in single-preicision + MEGDNN_HOST MEGDNN_DEVICE static expr cosh(float arg) { +#if defined(__CUDA_ARCH__) + return expr(coshf(arg)); +#else + return expr(std::cosh(arg)); +#endif + } + + /// Hyperbolic tangent implementation. + /// \param arg function argument + /// \return function value stored in single-preicision + MEGDNN_HOST MEGDNN_DEVICE static expr tanh(float arg) { +#if defined(__CUDA_ARCH__) + return expr(tanhf(arg)); +#else + return expr(std::tanh(arg)); +#endif + } + + /// Hyperbolic area sine implementation. + /// \param arg function argument + /// \return function value stored in single-preicision + MEGDNN_HOST MEGDNN_DEVICE static expr asinh(float arg) { +#if defined(__CUDA_ARCH__) + return expr(asinhf(arg)); +#else + return expr(std::asinh(arg)); +#endif + } + + /// Hyperbolic area cosine implementation. 
+ /// \param arg function argument + /// \return function value stored in single-preicision + MEGDNN_HOST MEGDNN_DEVICE static expr acosh(float arg) { +#if defined(__CUDA_ARCH__) + return expr(acoshf(arg)); +#else + return expr(std::acosh(arg)); +#endif + } + + /// Hyperbolic area tangent implementation. + /// \param arg function argument + /// \return function value stored in single-preicision + MEGDNN_HOST MEGDNN_DEVICE static expr atanh(float arg) { +#if defined(__CUDA_ARCH__) + return expr(atanhf(arg)); +#else + return expr(std::atanh(arg)); +#endif + } + + /// Error function implementation. + /// \param arg function argument + /// \return function value stored in single-preicision + MEGDNN_HOST MEGDNN_DEVICE static expr erf(float arg) { +#if defined(__CUDA_ARCH__) + return expr(erff(arg)); +#else + return expr(std::erf(arg)); +#endif + } + + /// Complementary implementation. + /// \param arg function argument + /// \return function value stored in single-preicision + MEGDNN_HOST MEGDNN_DEVICE static expr erfc(float arg) { +#if defined(__CUDA_ARCH__) + return expr(erfcf(arg)); +#else + return expr(std::erfc(arg)); +#endif + } + + /// Gamma logarithm implementation. + /// \param arg function argument + /// \return function value stored in single-preicision + MEGDNN_HOST MEGDNN_DEVICE static expr lgamma(float arg) { +#if defined(__CUDA_ARCH__) + return expr(lgammaf(arg)); +#else + return expr(std::lgamma(arg)); +#endif + } + + /// Gamma implementation. + /// \param arg function argument + /// \return function value stored in single-preicision + MEGDNN_HOST MEGDNN_DEVICE static expr tgamma(float arg) { +#if defined(__CUDA_ARCH__) + return expr(tgammaf(arg)); +#else + return expr(std::tgamma(arg)); +#endif + } + + /// Minimum implementation. + /// \param x first operand + /// \param y second operand + /// \return minimum value + MEGDNN_HOST MEGDNN_DEVICE static expr fmin(float x, float y) { + return expr(::fmin(x, y)); + } + + /// Maximum implementation. + /// \param x first operand + /// \param y second operand + /// \return maximum value + MEGDNN_HOST MEGDNN_DEVICE static expr fmax(float x, float y) { + return expr(::fmax(x, y)); + } + + /// Floor implementation. + /// \param arg value to round + /// \return rounded value + MEGDNN_HOST MEGDNN_DEVICE static bfloat16 floor(bfloat16 arg) { + return bfloat16( + binary_t(), + round_bfloat16(arg.data_)); + } + + /// Ceiling implementation. + /// \param arg value to round + /// \return rounded value + MEGDNN_HOST MEGDNN_DEVICE static bfloat16 ceil(bfloat16 arg) { + return bfloat16(binary_t(), + round_bfloat16(arg.data_)); + } + + /// Truncation implementation. + /// \param arg value to round + /// \return rounded value + MEGDNN_HOST MEGDNN_DEVICE static bfloat16 trunc(bfloat16 arg) { + return bfloat16(binary_t(), + round_bfloat16(arg.data_)); + } + + /// Nearest integer implementation. + /// \param arg value to round + /// \return rounded value + MEGDNN_HOST MEGDNN_DEVICE static bfloat16 round(bfloat16 arg) { + return bfloat16(binary_t(), round_bfloat16_up(arg.data_)); + } + + /// Nearest integer implementation. + /// \param arg value to round + /// \return rounded value + MEGDNN_HOST MEGDNN_DEVICE static bfloat16 rint(bfloat16 arg) { + return bfloat16(binary_t(), + round_bfloat16(arg.data_)); + } + + /// Nearest integer implementation. 
+ /// \param arg value to round + /// \return rounded value + MEGDNN_HOST MEGDNN_DEVICE static long lrint(bfloat16 arg) { + return detail::bfloat162int(arg.data_); + } + +#if HALF_ENABLE_CPP11_LONG_LONG + /// Nearest integer implementation. + /// \param arg value to round + /// \return rounded value + MEGDNN_HOST MEGDNN_DEVICE static long long llrint(bfloat16 arg) { + return detail::bfloat162int( + arg.data_); + } +#endif + + /// Decompression implementation. + /// \param arg number to decompress + /// \param exp address to store exponent at + /// \return normalized significant + MEGDNN_HOST MEGDNN_DEVICE static bfloat16 frexp(float arg, int* exp) { + return bfloat16(binary_t(), float2bfloat16( + std::frexp(arg, exp))); + } + + /// Decompression implementation. + /// \param arg number to decompress + /// \param iptr address to store integer part at + /// \return fractional part + MEGDNN_HOST MEGDNN_DEVICE static bfloat16 modf(float arg, bfloat16* iptr) { + float fptr = 0.f; + bfloat16 ret = bfloat16( + binary_t(), + float2bfloat16(std::modf(arg, &fptr))); + *iptr = fptr; + return ret; + } + + /// Scaling implementation. + /// \param arg number to scale + /// \param exp power of two to scale by + /// \return scaled number + MEGDNN_HOST MEGDNN_DEVICE static bfloat16 scalbln(float arg, long exp) { + return bfloat16(binary_t(), float2bfloat16( + std::scalbln(arg, exp))); + } + + /// Exponent implementation. + /// \param arg number to query + /// \return floating point exponent + MEGDNN_HOST MEGDNN_DEVICE static int ilogb(float arg) { + return std::ilogb(arg); + } + + /// Exponent implementation. + /// \param arg number to query + /// \return floating point exponent + MEGDNN_HOST MEGDNN_DEVICE static bfloat16 logb(bfloat16 arg) { + return bfloat16(binary_t(), + float2bfloat16(std::logb(arg))); + } + + /// Enumeration implementation. + /// \param from number to increase/decrease + /// \param to direction to enumerate into + /// \return next representable number + MEGDNN_HOST MEGDNN_DEVICE static bfloat16 nextafter(bfloat16 from, + bfloat16 to) { + uint16 fabs = from.data_ & 0x7FFF, tabs = to.data_ & 0x7FFF; + if (fabs > 0x7F80) + return from; + if (tabs > 0x7F80 || from.data_ == to.data_ || !(fabs | tabs)) + return to; + if (!fabs) + return bfloat16(binary_t(), (to.data_ & 0x8000) + 1); + bool lt = (signbit(from) ? (static_cast(0x8000) - from.data_) + : static_cast(from.data_)) < + (signbit(to) ? (static_cast(0x8000) - to.data_) + : static_cast(to.data_)); + return bfloat16( + binary_t(), + from.data_ + + (((from.data_ >> 15) ^ static_cast(lt)) << 1) - + 1); + } + + /// Sign implementation + /// \param x first operand + /// \param y second operand + /// \return composed value + MEGDNN_HOST MEGDNN_DEVICE static bfloat16 copysign(bfloat16 x, bfloat16 y) { + return bfloat16(binary_t(), x.data_ ^ ((x.data_ ^ y.data_) & 0x8000)); + } + + /// Classification implementation. + /// \param arg value to classify + /// \retval true if infinite number + /// \retval false else + MEGDNN_HOST MEGDNN_DEVICE static int fpclassify(bfloat16 arg) { + unsigned int abs = arg.data_ & 0x7FFF; + if (abs > 0x7F80) + return FP_NAN; + if (abs == 0x7F80) + return FP_INFINITE; + if (abs > 0x7F) + return FP_NORMAL; + return abs ? FP_SUBNORMAL : FP_ZERO; + } + + /// Classification implementation. + /// \param arg value to classify + /// \retval true if finite number + /// \retval false else + MEGDNN_HOST MEGDNN_DEVICE static bool isfinite(float arg) { + return std::isfinite(arg); + } + + /// Classification implementation. 
+ /// \param arg value to classify + /// \retval true if infinite number + /// \retval false else + MEGDNN_HOST MEGDNN_DEVICE static bool isinf(float arg) { + return std::isinf(arg); + } + + /// Classification implementation. + /// \param arg value to classify + /// \retval true if not a number + /// \retval false else + MEGDNN_HOST MEGDNN_DEVICE static bool isnan(bfloat16 arg) { + return (arg.data_ & 0x7FFF) > 0x7F80; + } + + /// Classification implementation. + /// \param arg value to classify + /// \retval true if normal number + /// \retval false else + + MEGDNN_HOST MEGDNN_DEVICE static bool isnormal(bfloat16 arg) { + return ((arg.data_ & 0x7F80) != 0) & ((arg.data_ & 0x7F80) != 0x7F80); + } + + /// Sign bit implementation. + /// \param arg value to check + /// \retval true if signed + /// \retval false if unsigned + MEGDNN_HOST MEGDNN_DEVICE static bool signbit(bfloat16 arg) { + return (arg.data_ & 0x8000) != 0; + } + + /// Comparison implementation. + /// \param x first operand + /// \param y second operand + /// \retval true if operands equal + /// \retval false else + MEGDNN_HOST MEGDNN_DEVICE static bool isequal(float x, float y) { + return x == y; + } + + /// Comparison implementation. + /// \param x first operand + /// \param y second operand + /// \retval true if operands not equal + /// \retval false else + MEGDNN_HOST MEGDNN_DEVICE static bool isnotequal(float x, float y) { + return x != y; + } + + /// Comparison implementation. + /// \param x first operand + /// \param y second operand + /// \retval true if \a x > \a y + /// \retval false else + MEGDNN_HOST MEGDNN_DEVICE static bool isgreater(float x, float y) { + return x > y; + } + + /// Comparison implementation. + /// \param x first operand + /// \param y second operand + /// \retval true if \a x >= \a y + /// \retval false else + MEGDNN_HOST MEGDNN_DEVICE static bool isgreaterequal(float x, float y) { + return x >= y; + } + + /// Comparison implementation. + /// \param x first operand + /// \param y second operand + /// \retval true if \a x < \a y + /// \retval false else + MEGDNN_HOST MEGDNN_DEVICE static bool isless(float x, float y) { + return x < y; + } + + /// Comparison implementation. + /// \param x first operand + /// \param y second operand + /// \retval true if \a x <= \a y + /// \retval false else + MEGDNN_HOST MEGDNN_DEVICE static bool islessequal(float x, float y) { + return x <= y; + } + + /// Comparison implementation. + /// \param x first operand + /// \param y second operand + /// \retval true neither \a x > \a y nor \a x < \a y + /// \retval false else + MEGDNN_HOST MEGDNN_DEVICE static bool islessgreater(float x, float y) { + return x < y || x > y; + } + + /// Comparison implementation. + /// \param x first operand + /// \param y second operand + /// \retval true if operand unordered + /// \retval false else + MEGDNN_HOST MEGDNN_DEVICE static bool isunordered(bfloat16 x, bfloat16 y) { + return isnan(x) || isnan(y); + } +}; + +/// Wrapper for unary bfloat16 functions needing specialization for +/// individual argument types. \tparam T argument type +template +struct unary_specialized { + /// Negation implementation. + /// \param arg value to negate + /// \return negated value + MEGDNN_HOST MEGDNN_DEVICE static HALF_CONSTEXPR bfloat16 + negate(bfloat16 arg) { + return bfloat16(binary_t(), arg.data_ ^ 0x8000); + } + + /// Absolute value implementation. 
+ /// \param arg function argument + /// \return absolute value + MEGDNN_HOST MEGDNN_DEVICE static bfloat16 fabs(bfloat16 arg) { + return bfloat16(binary_t(), arg.data_ & 0x7FFF); + } +}; +template <> +struct unary_specialized { + MEGDNN_HOST MEGDNN_DEVICE static HALF_CONSTEXPR expr negate(float arg) { + return expr(-arg); + } + MEGDNN_HOST MEGDNN_DEVICE static expr fabs(float arg) { +#if defined(__CUDA_ARCH__) + return expr(fabsf(arg)); +#else + return expr(std::fabs(arg)); +#endif + } +}; + +/// Wrapper for binary_t() bfloat16-precision functions needing +/// specialization for individual argument types. \tparam T first argument +/// type \tparam U first argument type +template +struct binary_specialized { + /// Minimum implementation. + /// \param x first operand + /// \param y second operand + /// \return minimum value + MEGDNN_HOST MEGDNN_DEVICE static expr fmin(float x, float y) { + return detail::functions::fmin(x, y); + } + + /// Maximum implementation. + /// \param x first operand + /// \param y second operand + /// \return maximum value + MEGDNN_HOST MEGDNN_DEVICE static expr fmax(float x, float y) { + return detail::functions::fmax(x, y); + } +}; +template <> +struct binary_specialized { + MEGDNN_HOST MEGDNN_DEVICE static bfloat16 fmin(bfloat16 x, bfloat16 y) { + return bfloat16(binary_t(), + float2bfloat16( + static_cast(functions::fmin(x, y)))); + } + MEGDNN_HOST MEGDNN_DEVICE static bfloat16 fmax(bfloat16 x, bfloat16 y) { + return bfloat16(binary_t(), + float2bfloat16( + static_cast(functions::fmax(x, y)))); + } +}; + +/// Helper class for bfloat16 casts. +/// This class template has to be specialized for all valid cast argument to +/// define an appropriate static `cast` member function and a corresponding +/// `type` member denoting its return type. \tparam T destination type +/// \tparam U source type \tparam R rounding mode to use +template +struct bfloat16_caster {}; +template +struct bfloat16_caster { +#if HALF_ENABLE_CPP11_STATIC_ASSERT && HALF_ENABLE_CPP11_TYPE_TRAITS + static_assert(std::is_arithmetic::value, + "bfloat16_cast from non-arithmetic type unsupported"); +#endif + + typedef bfloat16 type; + MEGDNN_HOST MEGDNN_DEVICE static bfloat16 cast(U arg) { + return cast_impl(arg, is_float()); + }; + +private: + MEGDNN_HOST MEGDNN_DEVICE static bfloat16 cast_impl(U arg, true_type) { + return bfloat16(binary_t(), float2bfloat16(static_cast(arg))); + } + MEGDNN_HOST MEGDNN_DEVICE static bfloat16 cast_impl(U arg, false_type) { + return bfloat16(binary_t(), int2bfloat16(arg)); + } +}; +template +struct bfloat16_caster { +#if HALF_ENABLE_CPP11_STATIC_ASSERT && HALF_ENABLE_CPP11_TYPE_TRAITS + static_assert(std::is_arithmetic::value, + "bfloat16_cast to non-arithmetic type unsupported"); +#endif + + typedef T type; + template + MEGDNN_HOST MEGDNN_DEVICE static T cast(U arg) { + return cast_impl(arg, is_float()); + } + +private: + MEGDNN_HOST MEGDNN_DEVICE static T cast_impl(float arg, true_type) { + return static_cast(arg); + } + MEGDNN_HOST MEGDNN_DEVICE static T cast_impl(bfloat16 arg, false_type) { + return bfloat162int(arg.data_); + } +}; +template +struct bfloat16_caster : public bfloat16_caster {}; +template +struct bfloat16_caster { + typedef bfloat16 type; + MEGDNN_HOST MEGDNN_DEVICE static bfloat16 cast(bfloat16 arg) { return arg; } +}; +template +struct bfloat16_caster + : public bfloat16_caster {}; + +/// \name Comparison operators +/// \{ + +/// Comparison for equality. 
+/// \param x first operand +/// \param y second operand +/// \retval true if operands equal +/// \retval false else +template +MEGDNN_HOST MEGDNN_DEVICE typename enable::type operator==(T x, + U y) { + return functions::isequal(x, y); +} + +/// Comparison for inequality. +/// \param x first operand +/// \param y second operand +/// \retval true if operands not equal +/// \retval false else +template +MEGDNN_HOST MEGDNN_DEVICE typename enable::type operator!=(T x, + U y) { + return functions::isnotequal(x, y); +} + +/// Comparison for less than. +/// \param x first operand +/// \param y second operand +/// \retval true if \a x less than \a y +/// \retval false else +template +MEGDNN_HOST MEGDNN_DEVICE typename enable::type operator<(T x, + U y) { + return functions::isless(x, y); +} + +/// Comparison for greater than. +/// \param x first operand +/// \param y second operand +/// \retval true if \a x greater than \a y +/// \retval false else +template +MEGDNN_HOST MEGDNN_DEVICE typename enable::type operator>(T x, + U y) { + return functions::isgreater(x, y); +} + +/// Comparison for less equal. +/// \param x first operand +/// \param y second operand +/// \retval true if \a x less equal \a y +/// \retval false else +template +MEGDNN_HOST MEGDNN_DEVICE typename enable::type operator<=(T x, + U y) { + return functions::islessequal(x, y); +} + +/// Comparison for greater equal. +/// \param x first operand +/// \param y second operand +/// \retval true if \a x greater equal \a y +/// \retval false else +template +MEGDNN_HOST MEGDNN_DEVICE typename enable::type operator>=(T x, + U y) { + return functions::isgreaterequal(x, y); +} + +/// \} +/// \name Arithmetic operators +/// \{ + +/// Add bfloat16s. +/// \param x left operand +/// \param y right operand +/// \return sum of bfloat16 expressions +template +MEGDNN_HOST MEGDNN_DEVICE typename enable::type operator+(T x, + U y) { + return functions::plus(x, y); +} + +/// Subtract bfloat16s. +/// \param x left operand +/// \param y right operand +/// \return difference of bfloat16 expressions +template +MEGDNN_HOST MEGDNN_DEVICE typename enable::type operator-(T x, + U y) { + return functions::minus(x, y); +} + +/// Multiply bfloat16s. +/// \param x left operand +/// \param y right operand +/// \return product of bfloat16 expressions +template +MEGDNN_HOST MEGDNN_DEVICE typename enable::type operator*(T x, + U y) { + return functions::multiplies(x, y); +} + +/// Divide bfloat16s. +/// \param x left operand +/// \param y right operand +/// \return quotient of bfloat16 expressions +template +MEGDNN_HOST MEGDNN_DEVICE typename enable::type operator/(T x, + U y) { + return functions::divides(x, y); +} + +/// Identity. +/// \param arg operand +/// \return uncahnged operand +template +MEGDNN_HOST MEGDNN_DEVICE HALF_CONSTEXPR typename enable::type operator+( + T arg) { + return arg; +} + +/// Negation. +/// \param arg operand +/// \return negated operand +template +MEGDNN_HOST MEGDNN_DEVICE HALF_CONSTEXPR typename enable::type operator-( + T arg) { + return unary_specialized::negate(arg); +} + +/// \} +/// \name Input and output +/// \{ + +/// Output operator. +/// \param out output stream to write into +/// \param arg bfloat16 expression to write +/// \return reference to output stream +template +typename enable&, T>::type operator<<( + std::basic_ostream& out, T arg) { + return functions::write(out, arg); +} + +/// Input operator. 
+/// \param in input stream to read from +/// \param arg bfloat16 to read into +/// \return reference to input stream +template +std::basic_istream& operator>>( + std::basic_istream& in, bfloat16& arg) { + return functions::read(in, arg); +} + +/// \} +/// \name Basic mathematical operations +/// \{ + +/// Absolute value. +/// \param arg operand +/// \return absolute value of \a arg +// template typename enable::type abs(T arg) { +// return unary_specialized::fabs(arg); } +MEGDNN_HOST MEGDNN_DEVICE inline bfloat16 abs(bfloat16 arg) { + return unary_specialized::fabs(arg); +} +MEGDNN_HOST MEGDNN_DEVICE inline expr abs(expr arg) { + return unary_specialized::fabs(arg); +} + +/// Absolute value. +/// \param arg operand +/// \return absolute value of \a arg +// template typename enable::type fabs(T arg) { +// return unary_specialized::fabs(arg); } +MEGDNN_HOST MEGDNN_DEVICE inline bfloat16 fabs(bfloat16 arg) { + return unary_specialized::fabs(arg); +} +MEGDNN_HOST MEGDNN_DEVICE inline expr fabs(expr arg) { + return unary_specialized::fabs(arg); +} + +/// Remainder of division. +/// \param x first operand +/// \param y second operand +/// \return remainder of floating point division. +// template typename enable::type +// fmod(T x, U y) { return functions::fmod(x, y); } +MEGDNN_HOST MEGDNN_DEVICE inline expr fmod(bfloat16 x, bfloat16 y) { + return functions::fmod(x, y); +} +MEGDNN_HOST MEGDNN_DEVICE inline expr fmod(bfloat16 x, expr y) { + return functions::fmod(x, y); +} +MEGDNN_HOST MEGDNN_DEVICE inline expr fmod(expr x, bfloat16 y) { + return functions::fmod(x, y); +} +MEGDNN_HOST MEGDNN_DEVICE inline expr fmod(expr x, expr y) { + return functions::fmod(x, y); +} + +/// Remainder of division. +/// \param x first operand +/// \param y second operand +/// \return remainder of floating point division. +// template typename enable::type +// remainder(T x, U y) { return functions::remainder(x, y); } +MEGDNN_HOST MEGDNN_DEVICE inline expr remainder(bfloat16 x, bfloat16 y) { + return functions::remainder(x, y); +} +MEGDNN_HOST MEGDNN_DEVICE inline expr remainder(bfloat16 x, expr y) { + return functions::remainder(x, y); +} +MEGDNN_HOST MEGDNN_DEVICE inline expr remainder(expr x, bfloat16 y) { + return functions::remainder(x, y); +} +MEGDNN_HOST MEGDNN_DEVICE inline expr remainder(expr x, expr y) { + return functions::remainder(x, y); +} + +/// Fused multiply add. +/// \param x first operand +/// \param y second operand +/// \param z third operand +/// \return ( \a x * \a y ) + \a z rounded as one operation. 
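/// Like the other math wrappers in this header, the overloads below return a
/// detail::expr holding the single-precision result, so a chained expression
/// is rounded to bfloat16 only once, on the final assignment or conversion:
/// ~~~~{.cpp}
/// bfloat16 x(0.5f), y(2.0f), z(1.0f);
/// bfloat16 r = fma(x, y, z); // expr -> bfloat16 conversion rounds here
/// ~~~~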
+// template typename +// enable::type fma(T x, U y, V z) { return functions::fma(x, y, +// z); +//} +MEGDNN_HOST MEGDNN_DEVICE inline expr fma(bfloat16 x, bfloat16 y, bfloat16 z) { + return functions::fma(x, y, z); +} +MEGDNN_HOST MEGDNN_DEVICE inline expr fma(bfloat16 x, bfloat16 y, expr z) { + return functions::fma(x, y, z); +} +MEGDNN_HOST MEGDNN_DEVICE inline expr fma(bfloat16 x, expr y, bfloat16 z) { + return functions::fma(x, y, z); +} +MEGDNN_HOST MEGDNN_DEVICE inline expr fma(bfloat16 x, expr y, expr z) { + return functions::fma(x, y, z); +} +MEGDNN_HOST MEGDNN_DEVICE inline expr fma(expr x, bfloat16 y, bfloat16 z) { + return functions::fma(x, y, z); +} +MEGDNN_HOST MEGDNN_DEVICE inline expr fma(expr x, bfloat16 y, expr z) { + return functions::fma(x, y, z); +} +MEGDNN_HOST MEGDNN_DEVICE inline expr fma(expr x, expr y, bfloat16 z) { + return functions::fma(x, y, z); +} +MEGDNN_HOST MEGDNN_DEVICE inline expr fma(expr x, expr y, expr z) { + return functions::fma(x, y, z); +} + +/// Maximum of bfloat16 expressions. +/// \param x first operand +/// \param y second operand +/// \return maximum of operands +// template typename result::type +// fmax(T x, U y) { return binary_specialized::fmax(x, y); } +MEGDNN_HOST MEGDNN_DEVICE inline bfloat16 fmax(bfloat16 x, bfloat16 y) { + return binary_specialized::fmax(x, y); +} +MEGDNN_HOST MEGDNN_DEVICE inline expr fmax(bfloat16 x, expr y) { + return binary_specialized::fmax(x, y); +} +MEGDNN_HOST MEGDNN_DEVICE inline expr fmax(expr x, bfloat16 y) { + return binary_specialized::fmax(x, y); +} +MEGDNN_HOST MEGDNN_DEVICE inline expr fmax(expr x, expr y) { + return binary_specialized::fmax(x, y); +} + +/// Minimum of bfloat16 expressions. +/// \param x first operand +/// \param y second operand +/// \return minimum of operands +// template typename result::type +// fmin(T x, U y) { return binary_specialized::fmin(x, y); } +MEGDNN_HOST MEGDNN_DEVICE inline bfloat16 fmin(bfloat16 x, bfloat16 y) { + return binary_specialized::fmin(x, y); +} +MEGDNN_HOST MEGDNN_DEVICE inline expr fmin(bfloat16 x, expr y) { + return binary_specialized::fmin(x, y); +} +MEGDNN_HOST MEGDNN_DEVICE inline expr fmin(expr x, bfloat16 y) { + return binary_specialized::fmin(x, y); +} +MEGDNN_HOST MEGDNN_DEVICE inline expr fmin(expr x, expr y) { + return binary_specialized::fmin(x, y); +} + +/// Positive difference. +/// \param x first operand +/// \param y second operand +/// \return \a x - \a y or 0 if difference negative +// template typename enable::type +// fdim(T x, U y) { return functions::fdim(x, y); } +MEGDNN_HOST MEGDNN_DEVICE inline expr fdim(bfloat16 x, bfloat16 y) { + return functions::fdim(x, y); +} +MEGDNN_HOST MEGDNN_DEVICE inline expr fdim(bfloat16 x, expr y) { + return functions::fdim(x, y); +} +MEGDNN_HOST MEGDNN_DEVICE inline expr fdim(expr x, bfloat16 y) { + return functions::fdim(x, y); +} +MEGDNN_HOST MEGDNN_DEVICE inline expr fdim(expr x, expr y) { + return functions::fdim(x, y); +} + +/// Get NaN value. +/// \param arg descriptive string (ignored) +/// \return quiet NaN +MEGDNN_HOST MEGDNN_DEVICE inline bfloat16 nanh(const char* arg) { + return functions::nanh(arg); +} + +/// \} +/// \name Exponential functions +/// \{ + +/// Exponential function. 
+/// \param arg function argument +/// \return e raised to \a arg +// template typename enable::type exp(T arg) { +// return functions::exp(arg); } +MEGDNN_HOST MEGDNN_DEVICE inline expr exp(bfloat16 arg) { + return functions::exp(arg); +} +MEGDNN_HOST MEGDNN_DEVICE inline expr exp(expr arg) { + return functions::exp(arg); +} + +/// Exponential minus one. +/// \param arg function argument +/// \return e raised to \a arg subtracted by 1 +// template typename enable::type expm1(T arg) +//{ return functions::expm1(arg); } +MEGDNN_HOST MEGDNN_DEVICE inline expr expm1(bfloat16 arg) { + return functions::expm1(arg); +} +MEGDNN_HOST MEGDNN_DEVICE inline expr expm1(expr arg) { + return functions::expm1(arg); +} + +/// Binary exponential. +/// \param arg function argument +/// \return 2 raised to \a arg +// template typename enable::type exp2(T arg) { +// return functions::exp2(arg); } +MEGDNN_HOST MEGDNN_DEVICE inline expr exp2(bfloat16 arg) { + return functions::exp2(arg); +} +MEGDNN_HOST MEGDNN_DEVICE inline expr exp2(expr arg) { + return functions::exp2(arg); +} + +/// Natural logorithm. +/// \param arg function argument +/// \return logarithm of \a arg to base e +// template typename enable::type log(T arg) { +// return functions::log(arg); } +MEGDNN_HOST MEGDNN_DEVICE inline expr log(bfloat16 arg) { + return functions::log(arg); +} +MEGDNN_HOST MEGDNN_DEVICE inline expr log(expr arg) { + return functions::log(arg); +} + +/// Common logorithm. +/// \param arg function argument +/// \return logarithm of \a arg to base 10 +// template typename enable::type log10(T arg) +//{ return functions::log10(arg); } +MEGDNN_HOST MEGDNN_DEVICE inline expr log10(bfloat16 arg) { + return functions::log10(arg); +} +MEGDNN_HOST MEGDNN_DEVICE inline expr log10(expr arg) { + return functions::log10(arg); +} + +/// Natural logorithm. +/// \param arg function argument +/// \return logarithm of \a arg plus 1 to base e +// template typename enable::type log1p(T arg) +//{ return functions::log1p(arg); } +MEGDNN_HOST MEGDNN_DEVICE inline expr log1p(bfloat16 arg) { + return functions::log1p(arg); +} +MEGDNN_HOST MEGDNN_DEVICE inline expr log1p(expr arg) { + return functions::log1p(arg); +} + +/// Binary logorithm. +/// \param arg function argument +/// \return logarithm of \a arg to base 2 +// template typename enable::type log2(T arg) { +// return functions::log2(arg); } +MEGDNN_HOST MEGDNN_DEVICE inline expr log2(bfloat16 arg) { + return functions::log2(arg); +} +MEGDNN_HOST MEGDNN_DEVICE inline expr log2(expr arg) { + return functions::log2(arg); +} + +/// \} +/// \name Power functions +/// \{ + +/// Square root. +/// \param arg function argument +/// \return square root of \a arg +// template typename enable::type sqrt(T arg) { +// return functions::sqrt(arg); } +MEGDNN_HOST MEGDNN_DEVICE inline expr sqrt(bfloat16 arg) { + return functions::sqrt(arg); +} +MEGDNN_HOST MEGDNN_DEVICE inline expr sqrt(expr arg) { + return functions::sqrt(arg); +} + +/// Cubic root. +/// \param arg function argument +/// \return cubic root of \a arg +// template typename enable::type cbrt(T arg) { +// return functions::cbrt(arg); } +MEGDNN_HOST MEGDNN_DEVICE inline expr cbrt(bfloat16 arg) { + return functions::cbrt(arg); +} +MEGDNN_HOST MEGDNN_DEVICE inline expr cbrt(expr arg) { + return functions::cbrt(arg); +} + +/// Hypotenuse function. 
+/// \param x first argument +/// \param y second argument +/// \return square root of sum of squares without internal over- or +/// underflows +// template typename enable::type +// hypot(T x, U y) { return functions::hypot(x, y); } +MEGDNN_HOST MEGDNN_DEVICE inline expr hypot(bfloat16 x, bfloat16 y) { + return functions::hypot(x, y); +} +MEGDNN_HOST MEGDNN_DEVICE inline expr hypot(bfloat16 x, expr y) { + return functions::hypot(x, y); +} +MEGDNN_HOST MEGDNN_DEVICE inline expr hypot(expr x, bfloat16 y) { + return functions::hypot(x, y); +} +MEGDNN_HOST MEGDNN_DEVICE inline expr hypot(expr x, expr y) { + return functions::hypot(x, y); +} + +/// Power function. +/// \param base first argument +/// \param exp second argument +/// \return \a base raised to \a exp +// template typename enable::type +// pow(T base, U exp) { return functions::pow(base, exp); } +MEGDNN_HOST MEGDNN_DEVICE inline expr pow(bfloat16 base, bfloat16 exp) { + return functions::pow(base, exp); +} +MEGDNN_HOST MEGDNN_DEVICE inline expr pow(bfloat16 base, expr exp) { + return functions::pow(base, exp); +} +MEGDNN_HOST MEGDNN_DEVICE inline expr pow(expr base, bfloat16 exp) { + return functions::pow(base, exp); +} +MEGDNN_HOST MEGDNN_DEVICE inline expr pow(expr base, expr exp) { + return functions::pow(base, exp); +} + +/// \} +/// \name Trigonometric functions +/// \{ + +/// Sine function. +/// \param arg function argument +/// \return sine value of \a arg +// template typename enable::type sin(T arg) { +// return functions::sin(arg); } +MEGDNN_HOST MEGDNN_DEVICE inline expr sin(bfloat16 arg) { + return functions::sin(arg); +} +MEGDNN_HOST MEGDNN_DEVICE inline expr sin(expr arg) { + return functions::sin(arg); +} + +/// Cosine function. +/// \param arg function argument +/// \return cosine value of \a arg +// template typename enable::type cos(T arg) { +// return functions::cos(arg); } +MEGDNN_HOST MEGDNN_DEVICE inline expr cos(bfloat16 arg) { + return functions::cos(arg); +} +MEGDNN_HOST MEGDNN_DEVICE inline expr cos(expr arg) { + return functions::cos(arg); +} + +/// Tangent function. +/// \param arg function argument +/// \return tangent value of \a arg +// template typename enable::type tan(T arg) { +// return functions::tan(arg); } +MEGDNN_HOST MEGDNN_DEVICE inline expr tan(bfloat16 arg) { + return functions::tan(arg); +} +MEGDNN_HOST MEGDNN_DEVICE inline expr tan(expr arg) { + return functions::tan(arg); +} + +/// Arc sine. +/// \param arg function argument +/// \return arc sine value of \a arg +// template typename enable::type asin(T arg) { +// return functions::asin(arg); } +MEGDNN_HOST MEGDNN_DEVICE inline expr asin(bfloat16 arg) { + return functions::asin(arg); +} +MEGDNN_HOST MEGDNN_DEVICE inline expr asin(expr arg) { + return functions::asin(arg); +} + +/// Arc cosine function. +/// \param arg function argument +/// \return arc cosine value of \a arg +// template typename enable::type acos(T arg) { +// return functions::acos(arg); } +MEGDNN_HOST MEGDNN_DEVICE inline expr acos(bfloat16 arg) { + return functions::acos(arg); +} +MEGDNN_HOST MEGDNN_DEVICE inline expr acos(expr arg) { + return functions::acos(arg); +} + +/// Arc tangent function. +/// \param arg function argument +/// \return arc tangent value of \a arg +// template typename enable::type atan(T arg) { +// return functions::atan(arg); } +MEGDNN_HOST MEGDNN_DEVICE inline expr atan(bfloat16 arg) { + return functions::atan(arg); +} +MEGDNN_HOST MEGDNN_DEVICE inline expr atan(expr arg) { + return functions::atan(arg); +} + +/// Arc tangent function. 
+/// \param x first argument +/// \param y second argument +/// \return arc tangent value +// template typename enable::type +// atan2(T x, U y) { return functions::atan2(x, y); } +MEGDNN_HOST MEGDNN_DEVICE inline expr atan2(bfloat16 x, bfloat16 y) { + return functions::atan2(x, y); +} +MEGDNN_HOST MEGDNN_DEVICE inline expr atan2(bfloat16 x, expr y) { + return functions::atan2(x, y); +} +MEGDNN_HOST MEGDNN_DEVICE inline expr atan2(expr x, bfloat16 y) { + return functions::atan2(x, y); +} +MEGDNN_HOST MEGDNN_DEVICE inline expr atan2(expr x, expr y) { + return functions::atan2(x, y); +} + +/// \} +/// \name Hyperbolic functions +/// \{ + +/// Hyperbolic sine. +/// \param arg function argument +/// \return hyperbolic sine value of \a arg +// template typename enable::type sinh(T arg) { +// return functions::sinh(arg); } +MEGDNN_HOST MEGDNN_DEVICE inline expr sinh(bfloat16 arg) { + return functions::sinh(arg); +} +MEGDNN_HOST MEGDNN_DEVICE inline expr sinh(expr arg) { + return functions::sinh(arg); +} + +/// Hyperbolic cosine. +/// \param arg function argument +/// \return hyperbolic cosine value of \a arg +// template typename enable::type cosh(T arg) { +// return functions::cosh(arg); } +MEGDNN_HOST MEGDNN_DEVICE inline expr cosh(bfloat16 arg) { + return functions::cosh(arg); +} +MEGDNN_HOST MEGDNN_DEVICE inline expr cosh(expr arg) { + return functions::cosh(arg); +} + +/// Hyperbolic tangent. +/// \param arg function argument +/// \return hyperbolic tangent value of \a arg +// template typename enable::type tanh(T arg) { +// return functions::tanh(arg); } +MEGDNN_HOST MEGDNN_DEVICE inline expr tanh(bfloat16 arg) { + return functions::tanh(arg); +} +MEGDNN_HOST MEGDNN_DEVICE inline expr tanh(expr arg) { + return functions::tanh(arg); +} + +/// Hyperbolic area sine. +/// \param arg function argument +/// \return area sine value of \a arg +// template typename enable::type asinh(T arg) +//{ return functions::asinh(arg); } +MEGDNN_HOST MEGDNN_DEVICE inline expr asinh(bfloat16 arg) { + return functions::asinh(arg); +} +MEGDNN_HOST MEGDNN_DEVICE inline expr asinh(expr arg) { + return functions::asinh(arg); +} + +/// Hyperbolic area cosine. +/// \param arg function argument +/// \return area cosine value of \a arg +// template typename enable::type acosh(T arg) +//{ return functions::acosh(arg); } +MEGDNN_HOST MEGDNN_DEVICE inline expr acosh(bfloat16 arg) { + return functions::acosh(arg); +} +MEGDNN_HOST MEGDNN_DEVICE inline expr acosh(expr arg) { + return functions::acosh(arg); +} + +/// Hyperbolic area tangent. +/// \param arg function argument +/// \return area tangent value of \a arg +// template typename enable::type atanh(T arg) +//{ return functions::atanh(arg); } +MEGDNN_HOST MEGDNN_DEVICE inline expr atanh(bfloat16 arg) { + return functions::atanh(arg); +} +MEGDNN_HOST MEGDNN_DEVICE inline expr atanh(expr arg) { + return functions::atanh(arg); +} + +/// \} +/// \name Error and gamma functions +/// \{ + +/// Error function. +/// \param arg function argument +/// \return error function value of \a arg +// template typename enable::type erf(T arg) { +// return functions::erf(arg); } +MEGDNN_HOST MEGDNN_DEVICE inline expr erf(bfloat16 arg) { + return functions::erf(arg); +} +MEGDNN_HOST MEGDNN_DEVICE inline expr erf(expr arg) { + return functions::erf(arg); +} + +/// Complementary error function. 
+/// \param arg function argument
+/// \return 1 minus error function value of \a arg
+// template <typename T> typename enable<expr, T>::type erfc(T arg) {
+//     return functions::erfc(arg);
+// }
+MEGDNN_HOST MEGDNN_DEVICE inline expr erfc(bfloat16 arg) {
+    return functions::erfc(arg);
+}
+MEGDNN_HOST MEGDNN_DEVICE inline expr erfc(expr arg) {
+    return functions::erfc(arg);
+}
+
+/// Natural logarithm of gamma function.
+/// \param arg function argument
+/// \return natural logarithm of the gamma function for \a arg
+// template <typename T> typename enable<expr, T>::type lgamma(T arg) {
+//     return functions::lgamma(arg);
+// }
+MEGDNN_HOST MEGDNN_DEVICE inline expr lgamma(bfloat16 arg) {
+    return functions::lgamma(arg);
+}
+MEGDNN_HOST MEGDNN_DEVICE inline expr lgamma(expr arg) {
+    return functions::lgamma(arg);
+}
+
+/// Gamma function.
+/// \param arg function argument
+/// \return gamma function value of \a arg
+// template <typename T> typename enable<expr, T>::type tgamma(T arg) {
+//     return functions::tgamma(arg);
+// }
+MEGDNN_HOST MEGDNN_DEVICE inline expr tgamma(bfloat16 arg) {
+    return functions::tgamma(arg);
+}
+MEGDNN_HOST MEGDNN_DEVICE inline expr tgamma(expr arg) {
+    return functions::tgamma(arg);
+}
+
+/// \}
+/// \name Rounding
+/// \{
+
+/// Nearest integer not less than bfloat16 value.
+/// \param arg bfloat16 to round
+/// \return nearest integer not less than \a arg
+// template <typename T> typename enable<bfloat16, T>::type ceil(T arg) {
+//     return functions::ceil(arg);
+// }
+MEGDNN_HOST MEGDNN_DEVICE inline bfloat16 ceil(bfloat16 arg) {
+    return functions::ceil(arg);
+}
+MEGDNN_HOST MEGDNN_DEVICE inline bfloat16 ceil(expr arg) {
+    return functions::ceil(arg);
+}
+
+/// Nearest integer not greater than bfloat16 value.
+/// \param arg bfloat16 to round
+/// \return nearest integer not greater than \a arg
+// template <typename T> typename enable<bfloat16, T>::type floor(T arg) {
+//     return functions::floor(arg);
+// }
+MEGDNN_HOST MEGDNN_DEVICE inline bfloat16 floor(bfloat16 arg) {
+    return functions::floor(arg);
+}
+MEGDNN_HOST MEGDNN_DEVICE inline bfloat16 floor(expr arg) {
+    return functions::floor(arg);
+}
+
+/// Nearest integer not greater in magnitude than bfloat16 value.
+/// \param arg bfloat16 to round
+/// \return nearest integer not greater in magnitude than \a arg
+// template <typename T> typename enable<bfloat16, T>::type trunc(T arg) {
+//     return functions::trunc(arg);
+// }
+MEGDNN_HOST MEGDNN_DEVICE inline bfloat16 trunc(bfloat16 arg) {
+    return functions::trunc(arg);
+}
+MEGDNN_HOST MEGDNN_DEVICE inline bfloat16 trunc(expr arg) {
+    return functions::trunc(arg);
+}
+
+/// Nearest integer.
+/// \param arg bfloat16 to round
+/// \return nearest integer, rounded away from zero in half-way cases
+// template <typename T> typename enable<bfloat16, T>::type round(T arg) {
+//     return functions::round(arg);
+// }
+MEGDNN_HOST MEGDNN_DEVICE inline bfloat16 round(bfloat16 arg) {
+    return functions::round(arg);
+}
+MEGDNN_HOST MEGDNN_DEVICE inline bfloat16 round(expr arg) {
+    return functions::round(arg);
+}
+
+/// Nearest integer using bfloat16's internal rounding mode.
+/// \param arg bfloat16 expression to round
+/// \return nearest integer using default rounding mode
+// template <typename T> typename enable<bfloat16, T>::type nearbyint(T arg) {
+//     return functions::nearbyint(arg);
+// }
+MEGDNN_HOST MEGDNN_DEVICE inline bfloat16 nearbyint(bfloat16 arg) {
+    return functions::rint(arg);
+}
+MEGDNN_HOST MEGDNN_DEVICE inline bfloat16 nearbyint(expr arg) {
+    return functions::rint(arg);
+}
+
+/// Nearest integer using bfloat16's internal rounding mode.
+/// \param arg bfloat16 expression to round
+/// \return nearest integer using default rounding mode
+// template <typename T> typename enable<bfloat16, T>::type rint(T arg) {
+//     return functions::rint(arg);
+// }
+MEGDNN_HOST MEGDNN_DEVICE inline bfloat16 rint(bfloat16 arg) {
+    return functions::rint(arg);
+}
+MEGDNN_HOST MEGDNN_DEVICE inline bfloat16 rint(expr arg) {
+    return functions::rint(arg);
+}
+
+/// Nearest integer using bfloat16's internal rounding mode.
+/// \param arg bfloat16 expression to round
+/// \return nearest integer using default rounding mode
+// template <typename T> typename enable<long, T>::type lrint(T arg) {
+//     return functions::lrint(arg);
+// }
+MEGDNN_HOST MEGDNN_DEVICE inline long lrint(bfloat16 arg) {
+    return functions::lrint(arg);
+}
+MEGDNN_HOST MEGDNN_DEVICE inline long lrint(expr arg) {
+    return functions::lrint(arg);
+}
+#if HALF_ENABLE_CPP11_LONG_LONG
+/// Nearest integer using bfloat16's internal rounding mode.
+/// \param arg bfloat16 expression to round
+/// \return nearest integer using default rounding mode
+// template <typename T> typename enable<long long, T>::type llrint(T arg) {
+//     return functions::llrint(arg);
+// }
+MEGDNN_HOST MEGDNN_DEVICE inline long long llrint(bfloat16 arg) {
+    return functions::llrint(arg);
+}
+MEGDNN_HOST MEGDNN_DEVICE inline long long llrint(expr arg) {
+    return functions::llrint(arg);
+}
+#endif
+
+/// \}
+/// \name Floating point manipulation
+/// \{
+
+/// Decompress floating point number.
+/// \param arg number to decompress
+/// \param exp address to store exponent at
+/// \return significand in range [0.5, 1)
+// template <typename T> typename enable<bfloat16, T>::type frexp(T arg, int* exp) {
+//     return functions::frexp(arg, exp);
+// }
+MEGDNN_HOST MEGDNN_DEVICE inline bfloat16 frexp(bfloat16 arg, int* exp) {
+    return functions::frexp(arg, exp);
+}
+MEGDNN_HOST MEGDNN_DEVICE inline bfloat16 frexp(expr arg, int* exp) {
+    return functions::frexp(arg, exp);
+}
+
+/// Multiply by power of two.
+/// \param arg number to modify
+/// \param exp power of two to multiply with
+/// \return \a arg multiplied by 2 raised to \a exp
+// template <typename T> typename enable<bfloat16, T>::type ldexp(T arg, int exp) {
+//     return functions::scalbln(arg, exp);
+// }
+MEGDNN_HOST MEGDNN_DEVICE inline bfloat16 ldexp(bfloat16 arg, int exp) {
+    return functions::scalbln(arg, exp);
+}
+MEGDNN_HOST MEGDNN_DEVICE inline bfloat16 ldexp(expr arg, int exp) {
+    return functions::scalbln(arg, exp);
+}
+
+/// Extract integer and fractional parts.
+/// \param arg number to decompress
+/// \param iptr address to store integer part at
+/// \return fractional part
+// template <typename T> typename enable<bfloat16, T>::type modf(T arg, bfloat16* iptr) {
+//     return functions::modf(arg, iptr);
+// }
+MEGDNN_HOST MEGDNN_DEVICE inline bfloat16 modf(bfloat16 arg, bfloat16* iptr) {
+    return functions::modf(arg, iptr);
+}
+MEGDNN_HOST MEGDNN_DEVICE inline bfloat16 modf(expr arg, bfloat16* iptr) {
+    return functions::modf(arg, iptr);
+}
+
+/// Multiply by power of two.
+/// \param arg number to modify
+/// \param exp power of two to multiply with
+/// \return \a arg multiplied by 2 raised to \a exp
+// template <typename T> typename enable<bfloat16, T>::type scalbn(T arg, int exp) {
+//     return functions::scalbln(arg, exp);
+// }
+MEGDNN_HOST MEGDNN_DEVICE inline bfloat16 scalbn(bfloat16 arg, int exp) {
+    return functions::scalbln(arg, exp);
+}
+MEGDNN_HOST MEGDNN_DEVICE inline bfloat16 scalbn(expr arg, int exp) {
+    return functions::scalbln(arg, exp);
+}
+
+/// Multiply by power of two.
+/// \param arg number to modify +/// \param exp power of two to multiply with +/// \return \a arg multplied by 2 raised to \a exp +// template typename enable::type scalbln(T +// arg, long exp) { return functions::scalbln(arg, exp); } +MEGDNN_HOST MEGDNN_DEVICE inline bfloat16 scalbln(bfloat16 arg, long exp) { + return functions::scalbln(arg, exp); +} +MEGDNN_HOST MEGDNN_DEVICE inline bfloat16 scalbln(expr arg, long exp) { + return functions::scalbln(arg, exp); +} + +/// Extract exponent. +/// \param arg number to query +/// \return floating point exponent +/// \retval FP_ILOGB0 for zero +/// \retval FP_ILOGBNAN for NaN +/// \retval MAX_INT for infinity +// template typename enable::type ilogb(T arg) { +// return functions::ilogb(arg); } +MEGDNN_HOST MEGDNN_DEVICE inline int ilogb(bfloat16 arg) { + return functions::ilogb(arg); +} +MEGDNN_HOST MEGDNN_DEVICE inline int ilogb(expr arg) { + return functions::ilogb(arg); +} + +/// Extract exponent. +/// \param arg number to query +/// \return floating point exponent +// template typename enable::type logb(T +// arg) { return functions::logb(arg); } +MEGDNN_HOST MEGDNN_DEVICE inline bfloat16 logb(bfloat16 arg) { + return functions::logb(arg); +} +MEGDNN_HOST MEGDNN_DEVICE inline bfloat16 logb(expr arg) { + return functions::logb(arg); +} + +/// Next representable value. +/// \param from value to compute next representable value for +/// \param to direction towards which to compute next value +/// \return next representable value after \a from in direction towards \a +/// to +// template typename +// enable::type nextafter(T from, U to) { return +// functions::nextafter(from, to); } +MEGDNN_HOST MEGDNN_DEVICE inline bfloat16 nextafter(bfloat16 from, + bfloat16 to) { + return functions::nextafter(from, to); +} +MEGDNN_HOST MEGDNN_DEVICE inline bfloat16 nextafter(bfloat16 from, expr to) { + return functions::nextafter(from, to); +} +MEGDNN_HOST MEGDNN_DEVICE inline bfloat16 nextafter(expr from, bfloat16 to) { + return functions::nextafter(from, to); +} +MEGDNN_HOST MEGDNN_DEVICE inline bfloat16 nextafter(expr from, expr to) { + return functions::nextafter(from, to); +} + +/// Take sign. +/// \param x value to change sign for +/// \param y value to take sign from +/// \return value equal to \a x in magnitude and to \a y in sign +// template typename +// enable::type copysign(T x, U y) { return +// functions::copysign(x, y); } +MEGDNN_HOST MEGDNN_DEVICE inline bfloat16 copysign(bfloat16 x, bfloat16 y) { + return functions::copysign(x, y); +} +MEGDNN_HOST MEGDNN_DEVICE inline bfloat16 copysign(bfloat16 x, expr y) { + return functions::copysign(x, y); +} +MEGDNN_HOST MEGDNN_DEVICE inline bfloat16 copysign(expr x, bfloat16 y) { + return functions::copysign(x, y); +} +MEGDNN_HOST MEGDNN_DEVICE inline bfloat16 copysign(expr x, expr y) { + return functions::copysign(x, y); +} + +/// \} +/// \name Floating point classification +/// \{ + +/// Classify floating point value. 
+/// \param arg number to classify +/// \retval FP_ZERO for positive and negative zero +/// \retval FP_SUBNORMAL for subnormal numbers +/// \retval FP_INFINITY for positive and negative infinity +/// \retval FP_NAN for NaNs +/// \retval FP_NORMAL for all other (normal) values +// template typename enable::type fpclassify(T +// arg) { return functions::fpclassify(arg); } +MEGDNN_HOST MEGDNN_DEVICE inline int fpclassify(bfloat16 arg) { + return functions::fpclassify(arg); +} +MEGDNN_HOST MEGDNN_DEVICE inline int fpclassify(expr arg) { + return functions::fpclassify(arg); +} + +/// Check if finite number. +/// \param arg number to check +/// \retval true if neither infinity nor NaN +/// \retval false else +// template typename enable::type isfinite(T +// arg) { return functions::isfinite(arg); } +MEGDNN_HOST MEGDNN_DEVICE inline bool isfinite(bfloat16 arg) { + return functions::isfinite(arg); +} +MEGDNN_HOST MEGDNN_DEVICE inline bool isfinite(expr arg) { + return functions::isfinite(arg); +} + +/// Check for infinity. +/// \param arg number to check +/// \retval true for positive or negative infinity +/// \retval false else +// template typename enable::type isinf(T arg) +//{ return functions::isinf(arg); } +MEGDNN_HOST MEGDNN_DEVICE inline bool isinf(bfloat16 arg) { + return functions::isinf(arg); +} +MEGDNN_HOST MEGDNN_DEVICE inline bool isinf(expr arg) { + return functions::isinf(arg); +} + +/// Check for NaN. +/// \param arg number to check +/// \retval true for NaNs +/// \retval false else +// template typename enable::type isnan(T arg) +//{ return functions::isnan(arg); } +MEGDNN_HOST MEGDNN_DEVICE inline bool isnan(bfloat16 arg) { + return functions::isnan(arg); +} +MEGDNN_HOST MEGDNN_DEVICE inline bool isnan(expr arg) { + return functions::isnan(arg); +} + +/// Check if normal number. +/// \param arg number to check +/// \retval true if normal number +/// \retval false if either subnormal, zero, infinity or NaN +// template typename enable::type isnormal(T +// arg) { return functions::isnormal(arg); } +MEGDNN_HOST MEGDNN_DEVICE inline bool isnormal(bfloat16 arg) { + return functions::isnormal(arg); +} +MEGDNN_HOST MEGDNN_DEVICE inline bool isnormal(expr arg) { + return functions::isnormal(arg); +} + +/// Check sign. +/// \param arg number to check +/// \retval true for negative number +/// \retval false for positive number +// template typename enable::type signbit(T +// arg) { return functions::signbit(arg); } +MEGDNN_HOST MEGDNN_DEVICE inline bool signbit(bfloat16 arg) { + return functions::signbit(arg); +} +MEGDNN_HOST MEGDNN_DEVICE inline bool signbit(expr arg) { + return functions::signbit(arg); +} + +/// \} +/// \name Comparison +/// \{ + +/// Comparison for greater than. +/// \param x first operand +/// \param y second operand +/// \retval true if \a x greater than \a y +/// \retval false else +// template typename enable::type +// isgreater(T x, U y) { return functions::isgreater(x, y); } +MEGDNN_HOST MEGDNN_DEVICE inline bool isgreater(bfloat16 x, bfloat16 y) { + return functions::isgreater(x, y); +} +MEGDNN_HOST MEGDNN_DEVICE inline bool isgreater(bfloat16 x, expr y) { + return functions::isgreater(x, y); +} +MEGDNN_HOST MEGDNN_DEVICE inline bool isgreater(expr x, bfloat16 y) { + return functions::isgreater(x, y); +} +MEGDNN_HOST MEGDNN_DEVICE inline bool isgreater(expr x, expr y) { + return functions::isgreater(x, y); +} + +/// Comparison for greater equal. 
+/// \param x first operand +/// \param y second operand +/// \retval true if \a x greater equal \a y +/// \retval false else +// template typename enable::type +// isgreaterequal(T x, U y) { return functions::isgreaterequal(x, y); } +MEGDNN_HOST MEGDNN_DEVICE inline bool isgreaterequal(bfloat16 x, bfloat16 y) { + return functions::isgreaterequal(x, y); +} +MEGDNN_HOST MEGDNN_DEVICE inline bool isgreaterequal(bfloat16 x, expr y) { + return functions::isgreaterequal(x, y); +} +MEGDNN_HOST MEGDNN_DEVICE inline bool isgreaterequal(expr x, bfloat16 y) { + return functions::isgreaterequal(x, y); +} +MEGDNN_HOST MEGDNN_DEVICE inline bool isgreaterequal(expr x, expr y) { + return functions::isgreaterequal(x, y); +} + +/// Comparison for less than. +/// \param x first operand +/// \param y second operand +/// \retval true if \a x less than \a y +/// \retval false else +// template typename enable::type +// isless(T x, U y) { return functions::isless(x, y); } +MEGDNN_HOST MEGDNN_DEVICE inline bool isless(bfloat16 x, bfloat16 y) { + return functions::isless(x, y); +} +MEGDNN_HOST MEGDNN_DEVICE inline bool isless(bfloat16 x, expr y) { + return functions::isless(x, y); +} +MEGDNN_HOST MEGDNN_DEVICE inline bool isless(expr x, bfloat16 y) { + return functions::isless(x, y); +} +MEGDNN_HOST MEGDNN_DEVICE inline bool isless(expr x, expr y) { + return functions::isless(x, y); +} + +/// Comparison for less equal. +/// \param x first operand +/// \param y second operand +/// \retval true if \a x less equal \a y +/// \retval false else +// template typename enable::type +// islessequal(T x, U y) { return functions::islessequal(x, y); } +MEGDNN_HOST MEGDNN_DEVICE inline bool islessequal(bfloat16 x, bfloat16 y) { + return functions::islessequal(x, y); +} +MEGDNN_HOST MEGDNN_DEVICE inline bool islessequal(bfloat16 x, expr y) { + return functions::islessequal(x, y); +} +MEGDNN_HOST MEGDNN_DEVICE inline bool islessequal(expr x, bfloat16 y) { + return functions::islessequal(x, y); +} +MEGDNN_HOST MEGDNN_DEVICE inline bool islessequal(expr x, expr y) { + return functions::islessequal(x, y); +} + +/// Comarison for less or greater. +/// \param x first operand +/// \param y second operand +/// \retval true if either less or greater +/// \retval false else +// template typename enable::type +// islessgreater(T x, U y) { return functions::islessgreater(x, y); } +MEGDNN_HOST MEGDNN_DEVICE inline bool islessgreater(bfloat16 x, bfloat16 y) { + return functions::islessgreater(x, y); +} +MEGDNN_HOST MEGDNN_DEVICE inline bool islessgreater(bfloat16 x, expr y) { + return functions::islessgreater(x, y); +} +MEGDNN_HOST MEGDNN_DEVICE inline bool islessgreater(expr x, bfloat16 y) { + return functions::islessgreater(x, y); +} +MEGDNN_HOST MEGDNN_DEVICE inline bool islessgreater(expr x, expr y) { + return functions::islessgreater(x, y); +} + +/// Check if unordered. 
+/// \param x first operand
+/// \param y second operand
+/// \retval true if unordered (one or two NaN operands)
+/// \retval false else
+// template <typename T, typename U>
+// typename enable<bool, T, U>::type isunordered(T x, U y) {
+//     return functions::isunordered(x, y);
+// }
+MEGDNN_HOST MEGDNN_DEVICE inline bool isunordered(bfloat16 x, bfloat16 y) {
+    return functions::isunordered(x, y);
+}
+MEGDNN_HOST MEGDNN_DEVICE inline bool isunordered(bfloat16 x, expr y) {
+    return functions::isunordered(x, y);
+}
+MEGDNN_HOST MEGDNN_DEVICE inline bool isunordered(expr x, bfloat16 y) {
+    return functions::isunordered(x, y);
+}
+MEGDNN_HOST MEGDNN_DEVICE inline bool isunordered(expr x, expr y) {
+    return functions::isunordered(x, y);
+}
+
+/// \}
+/// \name Casting
+/// \{
+
+/// Cast to or from bfloat16-precision floating point number.
+/// This casts between [bfloat16](\ref half_bfloat16::bfloat16) and any
+/// built-in arithmetic type. Floating point types are converted via an
+/// explicit cast to/from `float` (using the rounding mode of the built-in
+/// single precision implementation) and thus any possible warnings due to
+/// an otherwise implicit conversion to/from `float` will be suppressed.
+/// Integer types are converted directly using the given rounding mode,
+/// without any roundtrip over `float` that a `static_cast` would otherwise
+/// do. It uses the default rounding mode.
+///
+/// Using this cast with neither of the two types being a [bfloat16](\ref
+/// half_bfloat16::bfloat16) or with any of the two types not being a
+/// built-in arithmetic type (apart from [bfloat16](\ref
+/// half_bfloat16::bfloat16), of course) results in a compiler error and
+/// casting between [bfloat16](\ref half_bfloat16::bfloat16)s is just a
+/// no-op.
+/// \tparam T destination type (bfloat16 or built-in arithmetic type)
+/// \tparam U source type (bfloat16 or built-in arithmetic type)
+/// \param arg value to cast
+/// \return \a arg converted to destination type
+template <typename T, typename U>
+MEGDNN_HOST MEGDNN_DEVICE typename bfloat16_caster<T, U>::type bfloat16_cast(
+        U arg) {
+    return bfloat16_caster<T, U>::cast(arg);
+}
+
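+// A minimal usage sketch for the cast above (hedged; the variable names and
+// values are illustrative, not part of this header):
+//
+//     float f = 1.5f;
+//     bfloat16 b = bfloat16_cast<bfloat16>(f);  // explicit float -> bfloat16
+//     int i = bfloat16_cast<int>(b);  // integer cast, no float roundtrip
+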
+/// Cast to or from bfloat16-precision floating point number.
+/// This casts between [bfloat16](\ref half_bfloat16::bfloat16) and any
+/// built-in arithmetic type. Floating point types are converted via an
+/// explicit cast to/from `float` (using the rounding mode of the built-in
+/// single precision implementation) and thus any possible warnings due to
+/// an otherwise implicit conversion to/from `float` will be suppressed.
+/// Integer types are converted directly using the given rounding mode,
+/// without any roundtrip over `float` that a `static_cast` would otherwise
+/// do.
+///
+/// Using this cast with neither of the two types being a [bfloat16](\ref
+/// half_bfloat16::bfloat16) or with any of the two types not being a
+/// built-in arithmetic type (apart from [bfloat16](\ref
+/// half_bfloat16::bfloat16), of course) results in a compiler error and
+/// casting between [bfloat16](\ref half_bfloat16::bfloat16)s is just a
+/// no-op.
+/// \tparam T destination type (bfloat16 or built-in arithmetic type)
+/// \tparam R rounding mode to use
+/// \tparam U source type (bfloat16 or built-in arithmetic type)
+/// \param arg value to cast
+/// \return \a arg converted to destination type
+template <typename T, std::float_round_style R, typename U>
+MEGDNN_HOST MEGDNN_DEVICE typename bfloat16_caster<T, U, R>::type
+bfloat16_cast(U arg) {
+    return bfloat16_caster<T, U, R>::cast(arg);
+}
+/// \}
+} // namespace detail
+
+using detail::operator==;
+using detail::operator!=;
+using detail::operator<;
+using detail::operator>;
+using detail::operator<=;
+using detail::operator>=;
+using detail::operator+;
+using detail::operator-;
+using detail::operator*;
+using detail::operator/;
+using detail::operator<<;
+using detail::operator>>;
+
+using detail::abs;
+using detail::acos;
+using detail::acosh;
+using detail::asin;
+using detail::asinh;
+using detail::atan;
+using detail::atan2;
+using detail::atanh;
+using detail::cbrt;
+using detail::ceil;
+using detail::cos;
+using detail::cosh;
+using detail::erf;
+using detail::erfc;
+using detail::exp;
+using detail::exp2;
+using detail::expm1;
+using detail::fabs;
+using detail::fdim;
+using detail::floor;
+using detail::fma;
+using detail::fmax;
+using detail::fmin;
+using detail::fmod;
+using detail::hypot;
+using detail::lgamma;
+using detail::log;
+using detail::log10;
+using detail::log1p;
+using detail::log2;
+using detail::lrint;
+using detail::nanh;
+using detail::nearbyint;
+using detail::pow;
+using detail::remainder;
+using detail::rint;
+using detail::round;
+using detail::sin;
+using detail::sinh;
+using detail::sqrt;
+using detail::tan;
+using detail::tanh;
+using detail::tgamma;
+using detail::trunc;
+#if HALF_ENABLE_CPP11_LONG_LONG
+using detail::llrint;
+#endif
+using detail::copysign;
+using detail::fpclassify;
+using detail::frexp;
+using detail::ilogb;
+using detail::isfinite;
+using detail::isgreater;
+using detail::isgreaterequal;
+using detail::isinf;
+using detail::isless;
+using detail::islessequal;
+using detail::islessgreater;
+using detail::isnan;
+using detail::isnormal;
+using detail::isunordered;
+using detail::ldexp;
+using detail::logb;
+using detail::modf;
+using detail::nextafter;
+using detail::scalbln;
+using detail::scalbn;
+using detail::signbit;
+
+using detail::bfloat16_cast;
+} // namespace half_bfloat16
+
+/// Extensions to the C++ standard library.
+namespace std {
+/// Numeric limits for bfloat16-precision floats.
+/// Because of the underlying single-precision implementation of many
+/// operations, it inherits some properties from `numeric_limits<float>`.
+template <>
+class numeric_limits<half_bfloat16::bfloat16> : public numeric_limits<float> {
+public:
+    /// Supports signed values.
+    static HALF_CONSTEXPR_CONST bool is_signed = true;
+
+    /// Is not exact.
+    static HALF_CONSTEXPR_CONST bool is_exact = false;
+
+    /// Doesn't provide modulo arithmetic.
+    static HALF_CONSTEXPR_CONST bool is_modulo = false;
+
+    /// Not IEEE conformant.
+    static HALF_CONSTEXPR_CONST bool is_iec559 = false;
+
+    /// Supports infinity.
+    static HALF_CONSTEXPR_CONST bool has_infinity = true;
+
+    /// Supports quiet NaNs.
+    static HALF_CONSTEXPR_CONST bool has_quiet_NaN = true;
+
+    /// Supports subnormal values.
+    static HALF_CONSTEXPR_CONST float_denorm_style has_denorm = denorm_present;
+
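+    // A hedged sanity sketch of the constants below (they follow from the
+    // bfloat16 layout of 1 sign bit, 8 exponent bits and 7 mantissa bits):
+    //
+    //     static_assert(
+    //             std::numeric_limits<half_bfloat16::bfloat16>::digits == 8,
+    //             "7 stored mantissa bits plus the implicit leading one");
+    //     // epsilon() is 2^-7 = 0.0078125, i.e. the bit pattern 0x3C00
+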
+    /// Rounding mode.
+    /// Due to the mix of internal single-precision computations (using the
+    /// rounding mode of the underlying single-precision implementation) with
+    /// explicit truncation of the single-to-bfloat16 conversions, the actual
+    /// rounding mode is indeterminate.
+    static HALF_CONSTEXPR_CONST float_round_style round_style =
+            (numeric_limits<float>::round_style ==
+             half_bfloat16::bfloat16::round_style)
+                    ? half_bfloat16::bfloat16::round_style
+                    : round_indeterminate;
+
+    /// Significant digits.
+    static HALF_CONSTEXPR_CONST int digits = 8;
+
+    /// Significant decimal digits.
+    static HALF_CONSTEXPR_CONST int digits10 = 2;
+
+    /// Required decimal digits to represent all possible values.
+    static HALF_CONSTEXPR_CONST int max_digits10 = 4;
+
+    /// Number base.
+    static HALF_CONSTEXPR_CONST int radix = 2;
+
+    /// One more than smallest exponent.
+    static HALF_CONSTEXPR_CONST int min_exponent = -125;
+
+    /// Smallest normalized representable power of 10.
+    static HALF_CONSTEXPR_CONST int min_exponent10 = -37;
+
+    /// One more than largest exponent.
+    static HALF_CONSTEXPR_CONST int max_exponent = 128;
+
+    /// Largest finitely representable power of 10.
+    static HALF_CONSTEXPR_CONST int max_exponent10 = 38;
+
+    /// Smallest positive normal value.
+    MEGDNN_HOST MEGDNN_DEVICE static HALF_CONSTEXPR half_bfloat16::bfloat16
+    min() HALF_NOTHROW {
+        return half_bfloat16::bfloat16(half_bfloat16::detail::binary_t(),
+                                       0x0080);
+    }
+
+    /// Smallest finite value.
+    MEGDNN_HOST MEGDNN_DEVICE static HALF_CONSTEXPR half_bfloat16::bfloat16
+    lowest() HALF_NOTHROW {
+        return half_bfloat16::bfloat16(half_bfloat16::detail::binary_t(),
+                                       0xFF7F);
+    }
+
+    /// Largest finite value.
+    MEGDNN_HOST MEGDNN_DEVICE static HALF_CONSTEXPR half_bfloat16::bfloat16
+    max() HALF_NOTHROW {
+        return half_bfloat16::bfloat16(half_bfloat16::detail::binary_t(),
+                                       0x7F7F);
+    }
+
+    /// Difference between one and next representable value.
+    MEGDNN_HOST MEGDNN_DEVICE static HALF_CONSTEXPR half_bfloat16::bfloat16
+    epsilon() HALF_NOTHROW {
+        return half_bfloat16::bfloat16(half_bfloat16::detail::binary_t(),
+                                       0x3C00);
+    }
+
+    /// Maximum rounding error.
+    MEGDNN_HOST MEGDNN_DEVICE static HALF_CONSTEXPR half_bfloat16::bfloat16
+    round_error() HALF_NOTHROW {
+        return half_bfloat16::bfloat16(
+                half_bfloat16::detail::binary_t(),
+                (round_style == round_to_nearest) ? 0x3F00 : 0x3F80);
+    }
+
+    /// Positive infinity.
+    MEGDNN_HOST MEGDNN_DEVICE static HALF_CONSTEXPR half_bfloat16::bfloat16
+    infinity() HALF_NOTHROW {
+        return half_bfloat16::bfloat16(half_bfloat16::detail::binary_t(),
+                                       0x7F80);
+    }
+
+    /// Quiet NaN.
+    MEGDNN_HOST MEGDNN_DEVICE static HALF_CONSTEXPR half_bfloat16::bfloat16
+    quiet_NaN() HALF_NOTHROW {
+        return half_bfloat16::bfloat16(half_bfloat16::detail::binary_t(),
+                                       0x7FFF);
+    }
+
+    /// Signalling NaN.
+    MEGDNN_HOST MEGDNN_DEVICE static HALF_CONSTEXPR half_bfloat16::bfloat16
+    signaling_NaN() HALF_NOTHROW {
+        return half_bfloat16::bfloat16(half_bfloat16::detail::binary_t(),
+                                       0x7FBF);
+    }
+
+    /// Smallest positive subnormal value.
+    MEGDNN_HOST MEGDNN_DEVICE static HALF_CONSTEXPR half_bfloat16::bfloat16
+    denorm_min() HALF_NOTHROW {
+        return half_bfloat16::bfloat16(half_bfloat16::detail::binary_t(),
+                                       0x0001);
+    }
+};
+
+#ifdef MEGDNN_CC_HOST
+#if HALF_ENABLE_CPP11_HASH
+/// Hash function for bfloat16-precision floats.
+/// This is only defined if C++11 `hash` is supported and enabled.
+template <>
+struct hash<half_bfloat16::bfloat16>
+{
+    /// Type of function argument.
+    typedef half_bfloat16::bfloat16 argument_type;
+
+    /// Function return type.
+    typedef size_t result_type;
+
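+    // Note for the operator below: 0x8000 is the bit pattern of negative
+    // zero, and `-(arg.data_ != 0x8000)` is all-ones for every other value,
+    // so -0.0 is masked to the same key as +0.0 and the two equal values
+    // hash identically. A hedged sketch:
+    //
+    //     std::hash<half_bfloat16::bfloat16> h;
+    //     assert(h(half_bfloat16::bfloat16(0.f)) ==
+    //            h(half_bfloat16::bfloat16(-0.f)));
+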
+    /// Compute hash function.
+    /// \param arg bfloat16 to hash
+    /// \return hash value
+    MEGDNN_HOST MEGDNN_DEVICE result_type operator()(argument_type arg) const {
+        return hash<half_bfloat16::detail::uint16>()(
+                static_cast<unsigned>(arg.data_) & -(arg.data_ != 0x8000));
+    }
+};
+#endif
+#endif
+} // namespace std
+
+#include "megdnn/dtype/half_common_epilogue.h"
+
+#endif
+
+// vim: syntax=cpp.doxygen
diff --git a/dnn/include/megdnn/dtype/half.hpp b/dnn/include/megdnn/dtype/half.hpp
index 1621d7bcd1b233c2e387d865c76f3b829a705467..b602f64f716c5698d0e725559d7b6b84e298c276 100644
--- a/dnn/include/megdnn/dtype/half.hpp
+++ b/dnn/include/megdnn/dtype/half.hpp
@@ -50,167 +50,7 @@
 #include 
 #endif
 
-/// Combined gcc version number.
-#define HALF_GNUC_VERSION (__GNUC__*100+__GNUC_MINOR__)
-
-//check C++11 language features
-#if defined(__clang__) //clang
-    #if __has_feature(cxx_static_assert) && !defined(HALF_ENABLE_CPP11_STATIC_ASSERT)
-        #define HALF_ENABLE_CPP11_STATIC_ASSERT 1
-    #endif
-    #if __has_feature(cxx_constexpr) && !defined(HALF_ENABLE_CPP11_CONSTEXPR)
-        #define HALF_ENABLE_CPP11_CONSTEXPR 1
-    #endif
-    #if __has_feature(cxx_noexcept) && !defined(HALF_ENABLE_CPP11_NOEXCEPT)
-        #define HALF_ENABLE_CPP11_NOEXCEPT 1
-    #endif
-    #if __has_feature(cxx_user_literals) && !defined(HALF_ENABLE_CPP11_USER_LITERALS)
-        #define HALF_ENABLE_CPP11_USER_LITERALS 1
-    #endif
-    #if (defined(__GXX_EXPERIMENTAL_CXX0X__) || __cplusplus >= 201103L) && !defined(HALF_ENABLE_CPP11_LONG_LONG)
-        #define HALF_ENABLE_CPP11_LONG_LONG 1
-    #endif
-/*#elif defined(__INTEL_COMPILER) //Intel C++
-    #if __INTEL_COMPILER >= 1100 && !defined(HALF_ENABLE_CPP11_STATIC_ASSERT) ????????
-        #define HALF_ENABLE_CPP11_STATIC_ASSERT 1
-    #endif
-    #if __INTEL_COMPILER >= 1300 && !defined(HALF_ENABLE_CPP11_CONSTEXPR) ????????
-        #define HALF_ENABLE_CPP11_CONSTEXPR 1
-    #endif
-    #if __INTEL_COMPILER >= 1300 && !defined(HALF_ENABLE_CPP11_NOEXCEPT) ????????
-        #define HALF_ENABLE_CPP11_NOEXCEPT 1
-    #endif
-    #if __INTEL_COMPILER >= 1100 && !defined(HALF_ENABLE_CPP11_LONG_LONG) ????????
-        #define HALF_ENABLE_CPP11_LONG_LONG 1
-    #endif*/
-#elif defined(__GNUC__) //gcc
-    #if defined(__GXX_EXPERIMENTAL_CXX0X__) || __cplusplus >= 201103L
-        #if HALF_GNUC_VERSION >= 403 && !defined(HALF_ENABLE_CPP11_STATIC_ASSERT)
-            #define HALF_ENABLE_CPP11_STATIC_ASSERT 1
-        #endif
-        #if HALF_GNUC_VERSION >= 406 && !defined(HALF_ENABLE_CPP11_CONSTEXPR)
-            #define HALF_ENABLE_CPP11_CONSTEXPR 1
-        #endif
-        #if HALF_GNUC_VERSION >= 406 && !defined(HALF_ENABLE_CPP11_NOEXCEPT)
-            #define HALF_ENABLE_CPP11_NOEXCEPT 1
-        #endif
-        #if HALF_GNUC_VERSION >= 407 && !defined(HALF_ENABLE_CPP11_USER_LITERALS)
-            #define HALF_ENABLE_CPP11_USER_LITERALS 1
-        #endif
-        #if !defined(HALF_ENABLE_CPP11_LONG_LONG)
-            #define HALF_ENABLE_CPP11_LONG_LONG 1
-        #endif
-    #endif
-#elif defined(_MSC_VER) //Visual C++
-    #if _MSC_VER >= 1600 && !defined(HALF_ENABLE_CPP11_STATIC_ASSERT)
-        #define HALF_ENABLE_CPP11_STATIC_ASSERT 1
-    #endif
-    #if _MSC_VER >= 1310 && !defined(HALF_ENABLE_CPP11_LONG_LONG)
-        #define HALF_ENABLE_CPP11_LONG_LONG 1
-    #endif
-    #define HALF_POP_WARNINGS 1
-    #pragma warning(push)
-    //! 
4521 and 4522 is multiple copy/assigment operator specified - #pragma warning(disable : 4099 4127 4146 4521 4522) //struct vs class, constant in if, negative unsigned -#endif - -//check C++11 library features -#include -#if defined(_LIBCPP_VERSION) //libc++ - #if defined(__GXX_EXPERIMENTAL_CXX0X__) || __cplusplus >= 201103 - #ifndef HALF_ENABLE_CPP11_TYPE_TRAITS - #define HALF_ENABLE_CPP11_TYPE_TRAITS 1 - #endif - #ifndef HALF_ENABLE_CPP11_CSTDINT - #define HALF_ENABLE_CPP11_CSTDINT 1 - #endif - #ifndef HALF_ENABLE_CPP11_CMATH - #define HALF_ENABLE_CPP11_CMATH 1 - #endif - #ifndef HALF_ENABLE_CPP11_HASH - #define HALF_ENABLE_CPP11_HASH 1 - #endif - #endif -#elif defined(__GLIBCXX__) //libstdc++ - #if defined(__GXX_EXPERIMENTAL_CXX0X__) || __cplusplus >= 201103 - #ifdef __clang__ - #if __GLIBCXX__ >= 20080606 && !defined(HALF_ENABLE_CPP11_TYPE_TRAITS) - #define HALF_ENABLE_CPP11_TYPE_TRAITS 1 - #endif - #if __GLIBCXX__ >= 20080606 && !defined(HALF_ENABLE_CPP11_CSTDINT) - #define HALF_ENABLE_CPP11_CSTDINT 1 - #endif - #if __GLIBCXX__ >= 20080606 && !defined(HALF_ENABLE_CPP11_CMATH) - #define HALF_ENABLE_CPP11_CMATH 1 - #endif - #if __GLIBCXX__ >= 20080606 && !defined(HALF_ENABLE_CPP11_HASH) - #define HALF_ENABLE_CPP11_HASH 1 - #endif - #else - #if HALF_GNUC_VERSION >= 403 && !defined(HALF_ENABLE_CPP11_CSTDINT) - #define HALF_ENABLE_CPP11_CSTDINT 1 - #endif - #if HALF_GNUC_VERSION >= 403 && !defined(HALF_ENABLE_CPP11_CMATH) - #define HALF_ENABLE_CPP11_CMATH 1 - #endif - #if HALF_GNUC_VERSION >= 403 && !defined(HALF_ENABLE_CPP11_HASH) - #define HALF_ENABLE_CPP11_HASH 1 - #endif - #endif - #endif -#elif defined(_CPPLIB_VER) //Dinkumware/Visual C++ - #if _CPPLIB_VER >= 520 - #ifndef HALF_ENABLE_CPP11_TYPE_TRAITS - #define HALF_ENABLE_CPP11_TYPE_TRAITS 1 - #endif - #ifndef HALF_ENABLE_CPP11_CSTDINT - #define HALF_ENABLE_CPP11_CSTDINT 1 - #endif - #ifndef HALF_ENABLE_CPP11_HASH - #define HALF_ENABLE_CPP11_HASH 1 - #endif - #endif - #if _CPPLIB_VER >= 610 - #ifndef HALF_ENABLE_CPP11_CMATH - #define HALF_ENABLE_CPP11_CMATH 1 - #endif - #endif -#endif -#undef HALF_GNUC_VERSION - -//support constexpr -#if HALF_ENABLE_CPP11_CONSTEXPR - #define HALF_CONSTEXPR constexpr - #define HALF_CONSTEXPR_CONST constexpr -#else - #define HALF_CONSTEXPR - #define HALF_CONSTEXPR_CONST const -#endif - -//support noexcept -#if HALF_ENABLE_CPP11_NOEXCEPT - #define HALF_NOEXCEPT noexcept - #define HALF_NOTHROW noexcept -#else - #define HALF_NOEXCEPT - #define HALF_NOTHROW throw() -#endif - -#include -#include -#include -#include -#include -#if HALF_ENABLE_CPP11_TYPE_TRAITS - #include -#endif -#if HALF_ENABLE_CPP11_CSTDINT - #include -#endif -#if HALF_ENABLE_CPP11_HASH - #include -#endif - +#include "megdnn/dtype/half_common_prologue.h" /// Default rounding mode. 
/// This specifies the rounding mode used for all conversions between [half](\ref half_float::half)s and `float`s as well as @@ -3141,16 +2981,7 @@ namespace std #endif } - -#undef HALF_CONSTEXPR -#undef HALF_CONSTEXPR_CONST -#undef HALF_NOEXCEPT -#undef HALF_NOTHROW -#ifdef HALF_POP_WARNINGS - #pragma warning(pop) - #undef HALF_POP_WARNINGS -#endif - +#include "megdnn/dtype/half_common_epilogue.h" #endif // vim: syntax=cpp.doxygen diff --git a/dnn/include/megdnn/dtype/half_common_epilogue.h b/dnn/include/megdnn/dtype/half_common_epilogue.h new file mode 100644 index 0000000000000000000000000000000000000000..34bae2d325d2e42e8436733330f9ebe0012c99d6 --- /dev/null +++ b/dnn/include/megdnn/dtype/half_common_epilogue.h @@ -0,0 +1,48 @@ +/** + * half - IEEE 754-based half-precision floating point library. + * + * Copyright (c) 2012-2013 Christian Rau + * + * Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation + * files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, + * modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the + * Software is furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE + * WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR + * COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, + * ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + * + * Version 1.11.0 + * \file + * Main header file for half precision functionality. + * + * -------------------------------------------------------------------------- + * \file include/megdnn/dtype/half_common_epilogue.h + * + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * + * This file has been modified by Megvii ("Megvii Modifications"). + * All Megvii Modifications are Copyright (C) 2014-2019 Megvii Inc. All rights reserved. + * + * -------------------------------------------------------------------------- + */ + +#undef HALF_CONSTEXPR +#undef HALF_CONSTEXPR_CONST +#undef HALF_NOEXCEPT +#undef HALF_NOTHROW +#ifdef HALF_POP_WARNINGS + #pragma warning(pop) + #undef HALF_POP_WARNINGS +#endif + +// vim: syntax=cpp.doxygen diff --git a/dnn/include/megdnn/dtype/half_common_prologue.h b/dnn/include/megdnn/dtype/half_common_prologue.h new file mode 100644 index 0000000000000000000000000000000000000000..6de710ebf8e5ac1c79721312057a29e313615f75 --- /dev/null +++ b/dnn/include/megdnn/dtype/half_common_prologue.h @@ -0,0 +1,202 @@ +/** + * half - IEEE 754-based half-precision floating point library. 
+ * + * Copyright (c) 2012-2013 Christian Rau + * + * Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation + * files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, + * modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the + * Software is furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE + * WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR + * COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, + * ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + * + * Version 1.11.0 + * \file + * Main header file for half precision functionality. + * + * -------------------------------------------------------------------------- + * \file dnn/include/megdnn/dtype/half_common_prologue.h + * + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * + * This file has been modified by Megvii ("Megvii Modifications"). + * All Megvii Modifications are Copyright (C) 2014-2019 Megvii Inc. All rights reserved. + * + * -------------------------------------------------------------------------- + */ + +#include "megdnn/arch.h" + +/// Combined gcc version number. +#define HALF_GNUC_VERSION (__GNUC__*100+__GNUC_MINOR__) + +//check C++11 language features +#if defined(__clang__) //clang + #if __has_feature(cxx_static_assert) && !defined(HALF_ENABLE_CPP11_STATIC_ASSERT) + #define HALF_ENABLE_CPP11_STATIC_ASSERT 1 + #endif + #if __has_feature(cxx_constexpr) && !defined(HALF_ENABLE_CPP11_CONSTEXPR) + #define HALF_ENABLE_CPP11_CONSTEXPR 1 + #endif + #if __has_feature(cxx_noexcept) && !defined(HALF_ENABLE_CPP11_NOEXCEPT) + #define HALF_ENABLE_CPP11_NOEXCEPT 1 + #endif + #if __has_feature(cxx_user_literals) && !defined(HALF_ENABLE_CPP11_USER_LITERALS) + #define HALF_ENABLE_CPP11_USER_LITERALS 1 + #endif + #if (defined(__GXX_EXPERIMENTAL_CXX0X__) || __cplusplus >= 201103L) && !defined(HALF_ENABLE_CPP11_LONG_LONG) + #define HALF_ENABLE_CPP11_LONG_LONG 1 + #endif +/*#elif defined(__INTEL_COMPILER) //Intel C++ + #if __INTEL_COMPILER >= 1100 && !defined(HALF_ENABLE_CPP11_STATIC_ASSERT) ???????? + #define HALF_ENABLE_CPP11_STATIC_ASSERT 1 + #endif + #if __INTEL_COMPILER >= 1300 && !defined(HALF_ENABLE_CPP11_CONSTEXPR) ???????? + #define HALF_ENABLE_CPP11_CONSTEXPR 1 + #endif + #if __INTEL_COMPILER >= 1300 && !defined(HALF_ENABLE_CPP11_NOEXCEPT) ???????? + #define HALF_ENABLE_CPP11_NOEXCEPT 1 + #endif + #if __INTEL_COMPILER >= 1100 && !defined(HALF_ENABLE_CPP11_LONG_LONG) ???????? 
+ #define HALF_ENABLE_CPP11_LONG_LONG 1 + #endif*/ +#elif defined(__GNUC__) //gcc + #if defined(__GXX_EXPERIMENTAL_CXX0X__) || __cplusplus >= 201103L + #if HALF_GNUC_VERSION >= 403 && !defined(HALF_ENABLE_CPP11_STATIC_ASSERT) + #define HALF_ENABLE_CPP11_STATIC_ASSERT 1 + #endif + #if HALF_GNUC_VERSION >= 406 && !defined(HALF_ENABLE_CPP11_CONSTEXPR) + #define HALF_ENABLE_CPP11_CONSTEXPR 1 + #endif + #if HALF_GNUC_VERSION >= 406 && !defined(HALF_ENABLE_CPP11_NOEXCEPT) + #define HALF_ENABLE_CPP11_NOEXCEPT 1 + #endif + #if HALF_GNUC_VERSION >= 407 && !defined(HALF_ENABLE_CPP11_USER_LITERALS) + #define HALF_ENABLE_CPP11_USER_LITERALS 1 + #endif + #if !defined(HALF_ENABLE_CPP11_LONG_LONG) + #define HALF_ENABLE_CPP11_LONG_LONG 1 + #endif + #endif +#elif defined(_MSC_VER) //Visual C++ + #if _MSC_VER >= 1600 && !defined(HALF_ENABLE_CPP11_STATIC_ASSERT) + #define HALF_ENABLE_CPP11_STATIC_ASSERT 1 + #endif + #if _MSC_VER >= 1310 && !defined(HALF_ENABLE_CPP11_LONG_LONG) + #define HALF_ENABLE_CPP11_LONG_LONG 1 + #endif + #define HALF_POP_WARNINGS 1 + #pragma warning(push) + //! 4521 and 4522 is multiple copy/assigment operator specified + #pragma warning(disable : 4099 4127 4146 4521 4522) //struct vs class, constant in if, negative unsigned +#endif + +//check C++11 library features +#include +#if defined(_LIBCPP_VERSION) //libc++ + #if defined(__GXX_EXPERIMENTAL_CXX0X__) || __cplusplus >= 201103 + #ifndef HALF_ENABLE_CPP11_TYPE_TRAITS + #define HALF_ENABLE_CPP11_TYPE_TRAITS 1 + #endif + #ifndef HALF_ENABLE_CPP11_CSTDINT + #define HALF_ENABLE_CPP11_CSTDINT 1 + #endif + #ifndef HALF_ENABLE_CPP11_CMATH + #define HALF_ENABLE_CPP11_CMATH 1 + #endif + #ifndef HALF_ENABLE_CPP11_HASH + #define HALF_ENABLE_CPP11_HASH 1 + #endif + #endif +#elif defined(__GLIBCXX__) //libstdc++ + #if defined(__GXX_EXPERIMENTAL_CXX0X__) || __cplusplus >= 201103 + #ifdef __clang__ + #if __GLIBCXX__ >= 20080606 && !defined(HALF_ENABLE_CPP11_TYPE_TRAITS) + #define HALF_ENABLE_CPP11_TYPE_TRAITS 1 + #endif + #if __GLIBCXX__ >= 20080606 && !defined(HALF_ENABLE_CPP11_CSTDINT) + #define HALF_ENABLE_CPP11_CSTDINT 1 + #endif + #if __GLIBCXX__ >= 20080606 && !defined(HALF_ENABLE_CPP11_CMATH) + #define HALF_ENABLE_CPP11_CMATH 1 + #endif + #if __GLIBCXX__ >= 20080606 && !defined(HALF_ENABLE_CPP11_HASH) + #define HALF_ENABLE_CPP11_HASH 1 + #endif + #else + #if HALF_GNUC_VERSION >= 403 && !defined(HALF_ENABLE_CPP11_CSTDINT) + #define HALF_ENABLE_CPP11_CSTDINT 1 + #endif + #if HALF_GNUC_VERSION >= 403 && !defined(HALF_ENABLE_CPP11_CMATH) + #define HALF_ENABLE_CPP11_CMATH 1 + #endif + #if HALF_GNUC_VERSION >= 403 && !defined(HALF_ENABLE_CPP11_HASH) + #define HALF_ENABLE_CPP11_HASH 1 + #endif + #endif + #endif +#elif defined(_CPPLIB_VER) //Dinkumware/Visual C++ + #if _CPPLIB_VER >= 520 + #ifndef HALF_ENABLE_CPP11_TYPE_TRAITS + #define HALF_ENABLE_CPP11_TYPE_TRAITS 1 + #endif + #ifndef HALF_ENABLE_CPP11_CSTDINT + #define HALF_ENABLE_CPP11_CSTDINT 1 + #endif + #ifndef HALF_ENABLE_CPP11_HASH + #define HALF_ENABLE_CPP11_HASH 1 + #endif + #endif + #if _CPPLIB_VER >= 610 + #ifndef HALF_ENABLE_CPP11_CMATH + #define HALF_ENABLE_CPP11_CMATH 1 + #endif + #endif +#endif +#undef HALF_GNUC_VERSION + +//support constexpr +#if HALF_ENABLE_CPP11_CONSTEXPR + #define HALF_CONSTEXPR constexpr + #define HALF_CONSTEXPR_CONST constexpr +#else + #define HALF_CONSTEXPR + #define HALF_CONSTEXPR_CONST const +#endif + +//support noexcept +#if HALF_ENABLE_CPP11_NOEXCEPT + #define HALF_NOEXCEPT noexcept + #define HALF_NOTHROW noexcept +#else + #define HALF_NOEXCEPT + #define 
HALF_NOTHROW throw()
+#endif
+
+#include <algorithm>
+#include <iostream>
+#include <limits>
+#include <climits>
+#include <cmath>
+#if HALF_ENABLE_CPP11_TYPE_TRAITS
+    #include <type_traits>
+#endif
+#if HALF_ENABLE_CPP11_CSTDINT
+    #include <cstdint>
+#endif
+#if HALF_ENABLE_CPP11_HASH
+    #include <functional>
+#endif
+
+// vim: syntax=cpp.doxygen
diff --git a/dnn/scripts/gen_cond_take_kern_impls.py b/dnn/scripts/gen_cond_take_kern_impls.py
index e06add1b84e301e86ab5f25405418b6e6884a6e4..8e6fe3e43c87bc6a30b438e48909cecd73bc5bde 100755
--- a/dnn/scripts/gen_cond_take_kern_impls.py
+++ b/dnn/scripts/gen_cond_take_kern_impls.py
@@ -30,7 +30,7 @@ def main():
         w('// generated by gen_cond_take_kern_impls.py')
         w('#include "../kern.inl"')
         w('')
-        if dtype == 'dt_float16':
+        if dtype == 'dt_float16' or dtype == 'dt_bfloat16':
            w('#if !MEGDNN_DISABLE_FLOAT16')
         w('namespace megdnn {')
         w('namespace cuda {')
@@ -48,7 +48,7 @@ def main():
         w('} // cond_take')
         w('} // cuda')
        w('} // megdnn')
-        if dtype == 'dt_float16':
+        if dtype == 'dt_float16' or dtype == 'dt_bfloat16':
            w('#endif')
 
     print('generated {}'.format(fname))
diff --git a/dnn/scripts/gen_elemwise_kern_impls.py b/dnn/scripts/gen_elemwise_kern_impls.py
index 309725677f3236a70db65292fb52484eacc1208f..05f4e579f821c9b70b16d166c34c9dda94279b78 100755
--- a/dnn/scripts/gen_elemwise_kern_impls.py
+++ b/dnn/scripts/gen_elemwise_kern_impls.py
@@ -34,7 +34,7 @@ def main():
             w = lambda s: print(s, file=fout)
 
             w('// generated by gen_elemwise_kern_impls.py')
-            if ctype == 'dt_float16':
+            if ctype == 'dt_float16' or ctype == 'dt_bfloat16':
                w('#if !MEGDNN_DISABLE_FLOAT16')
 
             w('#define KERN_IMPL_MODE(cb) {}'.format(formode))
@@ -42,7 +42,7 @@ def main():
             w('#define KERN_IMPL_CTYPE {}'.format(ctype))
             w('#include "../kern_impl.inl"')
 
-            if ctype == 'dt_float16':
+            if ctype == 'dt_float16' or ctype == 'dt_bfloat16':
                w('#endif')
 
     print('generated {}'.format(fname))
diff --git a/dnn/scripts/gen_elemwise_special_kern_impls.py b/dnn/scripts/gen_elemwise_special_kern_impls.py
index a9c868ae96b6099ac6e12a30732a4bdbc2b4f1a8..2e75e7203e1984ae9a312355955a40bb319041d3 100755
--- a/dnn/scripts/gen_elemwise_special_kern_impls.py
+++ b/dnn/scripts/gen_elemwise_special_kern_impls.py
@@ -30,14 +30,14 @@ def main():
         w = lambda s: print(s, file=fout)
 
         w('// generated by gen_elemwise_special_kern_impls.py')
-        if dtype == 'dt_float16':
+        if dtype == 'dt_float16' or dtype == 'dt_bfloat16':
            w('#if !MEGDNN_DISABLE_FLOAT16')
         w('#include "../special_kerns.inl"')
         w('INST(::megdnn::dtype::{})'.format(DTYPES[dtype][0]))
         w('#undef INST')
         w('}')
         w('}')
-        if dtype == 'dt_float16':
+        if dtype == 'dt_float16' or dtype == 'dt_bfloat16':
            w('#endif')
 
     print('generated {}'.format(fname))
diff --git a/dnn/scripts/gen_elemwise_utils.py b/dnn/scripts/gen_elemwise_utils.py
index 3a3b04cb1b3db69fc183df4aeafca92d8117a5f5..f6968f0e0acdecdfa7b445e4d5c6c89620e2cc20 100755
--- a/dnn/scripts/gen_elemwise_utils.py
+++ b/dnn/scripts/gen_elemwise_utils.py
@@ -6,7 +6,8 @@
 DTYPES = {'dt_int32': ('Int32', 'INT'),
           'dt_int8': ('Int8', 'INT'),
           'dt_int16': ('Int16', 'INT'),
           'dt_float32': ('Float32', 'FLOAT'),
-          'dt_float16': ('Float16', 'FLOAT')
+          'dt_float16': ('Float16', 'FLOAT'),
+          'dt_bfloat16': ('BFloat16', 'FLOAT')
           }
 
 MODES = {
diff --git a/dnn/src/common/convolution.cpp b/dnn/src/common/convolution.cpp
index 0a89010fb25b6393b836ea98bfa4914386b07c19..50166749e61ac4e644bbf778e3bdd1dcac79a2a7 100644
--- a/dnn/src/common/convolution.cpp
+++ b/dnn/src/common/convolution.cpp
@@ -618,9 +618,10 @@ void ConvolutionBase::check_or_deduce_dtype_fwd(DType src,
     megdnn_assert(param().compute_mode != 
Param::ComputeMode::FLOAT32 #if !MEGDNN_DISABLE_FLOAT16 || src.enumv() == DTypeEnum::Float16 + || src.enumv() == DTypeEnum::BFloat16 #endif - , - "ComputeMode::FLOAT32 is only available for Float16 " + , + "ComputeMode::FLOAT32 is only available for Float16/BFloat16 " "input / output."); } @@ -1036,9 +1037,10 @@ void ConvolutionBackwardData::deduce_dtype(DType filter, DType diff, megdnn_assert(param().compute_mode != Param::ComputeMode::FLOAT32 #if !MEGDNN_DISABLE_FLOAT16 || filter.enumv() == DTypeEnum::Float16 + || filter.enumv() == DTypeEnum::BFloat16 #endif - , - "ComputeMode::FLOAT32 is only available for Float16 " + , + "ComputeMode::FLOAT32 is only available for Float16/BFloat16 " "input / output."); } diff --git a/dnn/src/common/elemwise/kern_defs.cuh b/dnn/src/common/elemwise/kern_defs.cuh index 49bd21deeff4f1dd60f6b2a5c618b35801ba264c..02f2eabff66347fd37eb2d1b20da2774218f2dd8 100644 --- a/dnn/src/common/elemwise/kern_defs.cuh +++ b/dnn/src/common/elemwise/kern_defs.cuh @@ -87,7 +87,8 @@ namespace megdnn { //! define kernel for all float types #define DEF_KERN_FLOAT(_mode, _imp) \ DEF_KERN(dt_float32, _mode, _imp); \ - MEGDNN_INC_FLOAT16(DEF_KERN(dt_float16, _mode, _imp);) + MEGDNN_INC_FLOAT16(DEF_KERN(dt_float16, _mode, _imp);) \ + MEGDNN_INC_FLOAT16(DEF_KERN(dt_bfloat16, _mode, _imp);) //! define kernel for all int types #define DEF_KERN_INT(_mode, _imp) \ diff --git a/dnn/src/common/matrix_mul.cpp b/dnn/src/common/matrix_mul.cpp index f96c2b16bfa5dc2a76bfefa49e03d939d00816b5..b12871c4007d36f3240a66ec4ca84723e94941b2 100644 --- a/dnn/src/common/matrix_mul.cpp +++ b/dnn/src/common/matrix_mul.cpp @@ -69,11 +69,11 @@ void MatrixMulForward::deduce_layout(const TensorLayout& A, C = TensorLayout(TensorShape({A0, B1}), C.dtype); } else { auto do_deduce = [&](size_t pack_size) { - megdnn_assert( - A.ndim == 4 && B.ndim == 3, - "matmul requires input dimension to be A(4), B(3); get: %s %s", - A.TensorShape::to_string().c_str(), - B.TensorShape::to_string().c_str()); + megdnn_assert(A.ndim == 4 && B.ndim == 3, + "matmul requires input dimension to be A(4), B(3); " + "get: %s %s", + A.TensorShape::to_string().c_str(), + B.TensorShape::to_string().c_str()); A0 = A.shape[0]; A1 = A.shape[1]; B0 = B.shape[0]; @@ -82,11 +82,11 @@ void MatrixMulForward::deduce_layout(const TensorLayout& A, std::swap(A0, A1); if (m_param.transposeB) std::swap(B0, B1); - megdnn_assert( - A1 == B0, - "shape mismatch in matmal: (transposed) A is (%zu,%zu,4,4), " - "(transposed) B is (%zu,%zu,4)", - A0, A1, B0, B1); + megdnn_assert(A1 == B0, + "shape mismatch in matmal: (transposed) A is " + "(%zu,%zu,4,4), " + "(transposed) B is (%zu,%zu,4)", + A0, A1, B0, B1); C = TensorLayout(TensorShape({A0, B1, pack_size}), C.dtype); }; do_deduce(pack_size(param().format)); @@ -172,8 +172,9 @@ void MatrixMulForward::check_exec(const TensorLayout& A, const TensorLayout& B, } megdnn_assert(param().compute_mode != Param::ComputeMode::FLOAT32 MEGDNN_INC_FLOAT16( - || A.dtype == dtype::Float16()), - "ComputeMode::FLOAT32 is only available for Float16 " + || A.dtype == dtype::Float16() || + A.dtype == dtype::BFloat16()), + "ComputeMode::FLOAT32 is only available for Float16/BFloat16 " "input / output."); auto required_workspace_in_bytes = get_workspace_in_bytes(A, B, C); megdnn_assert(workspace_in_bytes >= required_workspace_in_bytes); diff --git a/dnn/src/common/rounding_converter.cuh b/dnn/src/common/rounding_converter.cuh index 5a1c6327d3bf60d3f11c08f439775c221152f763..ca17d14c00dcbbbbc506299dbd7445dcb4bb4675 100644 --- 
a/dnn/src/common/rounding_converter.cuh
+++ b/dnn/src/common/rounding_converter.cuh
@@ -46,6 +46,14 @@ struct RoundingConverter<half_float::half> {
     }
 };
 
+template <>
+struct RoundingConverter<half_bfloat16::bfloat16> {
+    __host__ __device__ __forceinline__ half_bfloat16::bfloat16 operator()(
+            float x) const {
+        return static_cast<half_bfloat16::bfloat16>(x);
+    }
+};
+
 #endif  // #ifdef MEGDNN_DISABLE_FLOAT16
 
 template <>
diff --git a/dnn/src/common/utils.h b/dnn/src/common/utils.h
index 1061ab479cc88ec8eb2e2ae43c3033272f56e7a2..449c9b04eff77e2d2f286a42e395a5f6df9e633c 100644
--- a/dnn/src/common/utils.h
+++ b/dnn/src/common/utils.h
@@ -16,6 +16,7 @@
 #include "megdnn/dtype.h"
 #include "megdnn/handle.h"
 #include "megdnn/thin/small_vector.h"
+#include "megdnn/oprs/general.h"
 
 #include "src/common/hash_ct.h"
 #include "src/common/utils.cuh"
@@ -548,6 +549,59 @@ public:
     std::string to_string() const;
 };
 
+/**!
+ * \brief helpers for oprs using typecvt between comp_type and dst_type
+ * \tparam SrcType src type
+ * \tparam CompType compute type, such as fp32 for conv
+ * \tparam DstType dst type
+ */
+template <typename SrcType, typename CompType, typename DstType>
+struct CompTypeCvter {
+    std::unique_ptr<TypeCvt> m_cvt_opr;
+    WorkspaceBundle* m_workspace_bundle;
+    size_t m_workspace_idx;
+    CompTypeCvter(Handle* handle, WorkspaceBundle* bundle)
+            : m_workspace_bundle(bundle), m_workspace_idx(0) {
+        megdnn_assert(
+                (DTypeTrait<SrcType>::enumv != DTypeTrait<CompType>::enumv &&
+                 DTypeTrait<DstType>::enumv != DTypeTrait<CompType>::enumv),
+                "SrcType(%s) == CompType(%s) or DstType(%s) == CompType(%s) "
+                "is not supported.",
+                SrcType().name(), CompType().name(), DstType().name(),
+                CompType().name());
+        m_cvt_opr = handle->create_operator<TypeCvt>();
+    }
+
+    //! Convert tensor dtype from SrcType to CompType.
+    CompTypeCvter& src_to_comp_type(const TensorND& src, TensorND& comp) {
+        if (src.layout.dtype.enumv() == DTypeTrait<SrcType>::enumv) {
+            if (!comp.layout.dtype.valid() ||
+                comp.layout.dtype.enumv() != DTypeTrait<CompType>::enumv) {
+                comp.layout.dtype = CompType();
+                comp.layout.init_contiguous_stride();
+                comp.raw_ptr = m_workspace_bundle->get(m_workspace_idx++);
+                if (src.layout.ndim) {
+                    m_cvt_opr->exec(src, comp);
+                }
+            }
+        }
+        return *this;
+    }
+
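+    // A hedged usage sketch for this helper, e.g. a BFloat16 opr computing
+    // in Float32 (`handle`, `bundle` and the tensors are assumed to be
+    // provided by the caller):
+    //
+    //     CompTypeCvter<dtype::BFloat16, dtype::Float32, dtype::BFloat16>
+    //             cvter(handle, &bundle);
+    //     TensorND comp_src, comp_dst;
+    //     cvter.src_to_comp_type(src, comp_src)
+    //             .src_to_comp_type(dst, comp_dst);
+    //     // ... run the Float32 kernel on comp_src / comp_dst ...
+    //     cvter.comp_to_dst_type(comp_dst, dst);
+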
+
+    //! Convert tensor dtype from CompType to DstType.
+    CompTypeCvter& comp_to_dst_type(const TensorND& comp, const TensorND& dst) {
+        megdnn_assert(comp.layout.dtype.enumv() == DTypeTrait<CompType>::enumv);
+        if (dst.layout.dtype.enumv() == DTypeTrait<DstType>::enumv) {
+            m_cvt_opr->exec(comp, dst);
+        }
+        return *this;
+    }
+
+    Workspace workspace() {
+        return m_workspace_bundle->get_workspace(m_workspace_idx);
+    }
+};
+
 } // namespace megdnn
 
 // vim: syntax=cpp.doxygen
diff --git a/dnn/src/common/warp_perspective.cpp b/dnn/src/common/warp_perspective.cpp
index 83a4624f76a9e98b3b03322542f0dba9ec8e8aa5..c8d568ad61bd54f1bea5d730b63a06ccc780f025 100644
--- a/dnn/src/common/warp_perspective.cpp
+++ b/dnn/src/common/warp_perspective.cpp
@@ -55,17 +55,19 @@ void WarpPerspectiveBase::check_layout_fwd(const TensorLayout &src,
     megdnn_assert(mat.shape[2] == 3_z, "%s", errmsg().c_str());
 
     if (param().format == param::WarpPerspective::Format::NCHW) {
-        megdnn_assert(src.dtype.enumv() == DTypeEnum::Float32 ||
-                              MEGDNN_FLOAT16_SELECT(
-                                      src.dtype.enumv() == DTypeEnum::Float16,
-                                      false) ||
-                              src.dtype.enumv() == DTypeEnum::Int8 ||
-                              src.dtype.enumv() == DTypeEnum::Uint8 ||
-                              (src.dtype.enumv() == DTypeEnum::QuantizedS8 ||
-                               src.dtype.enumv() == DTypeEnum::Quantized8Asymm),
-                      "WarpPerspective NCHW input dtype should be "
-                      "Float32/Int8/Uint8/QInt8/QUint8" MEGDNN_FLOAT16_SELECT(
-                              "/Float16", "") ".");
+        megdnn_assert(
+                src.dtype.enumv() == DTypeEnum::Float32 ||
+                        MEGDNN_FLOAT16_SELECT(
+                                (src.dtype.enumv() == DTypeEnum::Float16 ||
+                                 src.dtype.enumv() == DTypeEnum::BFloat16),
+                                false) ||
+                        src.dtype.enumv() == DTypeEnum::Int8 ||
+                        src.dtype.enumv() == DTypeEnum::Uint8 ||
+                        (src.dtype.enumv() == DTypeEnum::QuantizedS8 ||
+                         src.dtype.enumv() == DTypeEnum::Quantized8Asymm),
+                "WarpPerspective NCHW input dtype should be "
+                "Float32/Int8/Uint8/QInt8/QUint8" MEGDNN_FLOAT16_SELECT(
+                        "/Float16/BFloat16", "") ".");
         megdnn_assert(
                 (src.dtype.category() == DTypeCategory::FLOAT &&
                  (src.dtype == mat.dtype ||
@@ -107,14 +109,17 @@ void WarpPerspectiveBase::check_layout_fwd(const TensorLayout &src,
                       param::WarpPerspective::BorderMode::ISOLATED);
     } else {
         megdnn_assert(param().format == param::WarpPerspective::Format::NHWCD4);
-        megdnn_assert(src.dtype == dtype::Float32() ||
-                              MEGDNN_FLOAT16_SELECT(
-                                      src.dtype == dtype::Float16(), false) ||
-                              src.dtype.enumv() == DTypeEnum::QuantizedS8 ||
-                              src.dtype.enumv() == DTypeEnum::Quantized8Asymm,
-                      "WarpPerspective NHWCD4 input dtype should be "
-                      "Float32" MEGDNN_FLOAT16_SELECT(
-                              "/Float16", "") ",QunatizedS8, Quantized8Asymm.");
+        megdnn_assert(
+                src.dtype == dtype::Float32() ||
+                        MEGDNN_FLOAT16_SELECT((src.dtype == dtype::Float16() ||
+                                               src.dtype == dtype::BFloat16()),
+                                              false) ||
+                        src.dtype.enumv() == DTypeEnum::QuantizedS8 ||
+                        src.dtype.enumv() == DTypeEnum::Quantized8Asymm,
+                "WarpPerspective NHWCD4 input dtype should be "
+                "Float32" MEGDNN_FLOAT16_SELECT(
+                        "/Float16/BFloat16",
+                        "") ", QuantizedS8, Quantized8Asymm.");
         megdnn_assert(
                 (src.dtype == mat.dtype || mat.dtype == dtype::Float32()),
                 "The input to WarpPerspective is in NHWCD4 format, in this "
@@ -253,30 +258,30 @@ void WarpPerspectiveForward::check_exec_allow_nhwc_mat_idx(
     }
 }
 
-void WarpPerspectiveBackwardData::check_exec(const TensorLayout &mat,
-                                             const TensorLayout &diff,
-                                             const TensorLayout &grad,
-                                             size_t workspace_in_bytes)
-{
+void WarpPerspectiveBackwardData::check_exec(const TensorLayout& mat,
+                                             const TensorLayout& diff,
+                                             const TensorLayout& grad,
+                                             size_t workspace_in_bytes) {
     check_layout_fwd(grad, mat, diff);
-    megdnn_assert(grad.dtype == dtype::Float32(),
-                  "Backward WarpPerspective only supports Float32.");
+    megdnn_assert(grad.dtype == dtype::Float32() MEGDNN_INC_FLOAT16(
+                          || grad.dtype == dtype::BFloat16()),
+                  "Backward WarpPerspective only supports Float32/BFloat16.");
     auto required_workspace_in_bytes = get_workspace_in_bytes(mat, diff, grad);
     megdnn_assert(workspace_in_bytes >= required_workspace_in_bytes);
 }
 
-void WarpPerspectiveBackwardMat::check_exec(const TensorLayout &src,
-                                            const TensorLayout &mat,
-                                            const TensorLayout &diff,
-                                            const TensorLayout &grad,
-                                            size_t workspace_in_bytes)
-{
+void WarpPerspectiveBackwardMat::check_exec(const TensorLayout& src,
+                                            const TensorLayout& mat,
+                                            const TensorLayout& diff,
+                                            const TensorLayout& grad,
+                                            size_t workspace_in_bytes) {
     check_layout_fwd(src, mat, diff);
     megdnn_assert_eq_layout(mat, grad);
-    megdnn_assert(grad.dtype == dtype::Float32(),
-                  "Backward WarpPerspective only supports Float32.");
-    auto required_workspace_in_bytes = get_workspace_in_bytes(src,
-            mat, diff, grad);
+    megdnn_assert(grad.dtype == dtype::Float32() MEGDNN_INC_FLOAT16(
+                          || grad.dtype == dtype::BFloat16()),
+                  "Backward WarpPerspective only supports Float32/BFloat16.");
+    auto required_workspace_in_bytes =
+            get_workspace_in_bytes(src, mat, diff, grad);
     megdnn_assert(workspace_in_bytes >= required_workspace_in_bytes);
 }
diff --git a/dnn/src/cuda/cond_take/kimpl/dt_bfloat16.cu b/dnn/src/cuda/cond_take/kimpl/dt_bfloat16.cu
new file mode 100644
index 0000000000000000000000000000000000000000..06161eaafbb17bb49e6250efeb69b68e11462f00
--- /dev/null
+++ b/dnn/src/cuda/cond_take/kimpl/dt_bfloat16.cu
@@ -0,0 +1,29 @@
+/**
+ * \file dnn/src/cuda/cond_take/kimpl/dt_bfloat16.cu
+ * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
+ *
+ * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ */ +// generated by gen_cond_take_kern_impls.py +#include "../kern.inl" + +#if !MEGDNN_DISABLE_FLOAT16 +namespace megdnn { +namespace cuda { +namespace cond_take { + +inst_genidx(::megdnn::dtype::BFloat16) +#undef inst_genidx + +inst_copy(::megdnn::dtype::BFloat16) +#undef inst_copy +#undef inst_copy_ + +} // cond_take +} // cuda +} // megdnn +#endif diff --git a/dnn/src/cuda/conv_bias/algo.cpp b/dnn/src/cuda/conv_bias/algo.cpp index fa015d3359a9918b33d43be00e85f61eb5b22e17..235f3cf93b305a3b97f7e91846d46db4b0f94457 100644 --- a/dnn/src/cuda/conv_bias/algo.cpp +++ b/dnn/src/cuda/conv_bias/algo.cpp @@ -62,6 +62,13 @@ ConvBiasForwardImpl::AlgoPack::AlgoPack() { non_cudnn_algos.push_back(all_algos.rbegin()[1]); // group batched_matmul non_cudnn_algos.push_back(all_algos.rbegin()[0]); // group 1x1 + algo_size = all_algos.size(); + for (size_t i = 0; i < algo_size; ++i) { + bfloat16_refhold.emplace_back(new AlgoBFloat16(all_algos[i])); + all_algos.push_back(bfloat16_refhold.back().get()); + bfloat16_algos.push_back(bfloat16_refhold.back().get()); + } + size_t all_algo_size = all_algos.size(); #if CUDA_VERSION >= 10000 fill_imma_algos(); diff --git a/dnn/src/cuda/conv_bias/algo.h b/dnn/src/cuda/conv_bias/algo.h index 7b7a042e09fc557f8535794528dc78247ab31045..bc45781bf697249421a89cb88bafa370365a99e8 100644 --- a/dnn/src/cuda/conv_bias/algo.h +++ b/dnn/src/cuda/conv_bias/algo.h @@ -499,6 +499,28 @@ private: }; #endif +class ConvBiasForwardImpl::AlgoBFloat16 final : public AlgoBase { +public: + AlgoBFloat16(AlgoBase* impl); + + bool is_available(const SizeArgs& args) const override; + size_t get_workspace_in_bytes(const SizeArgs& args) const override; + void exec(const ExecArgs& args) const override; + + const char* name() const override { return m_name.c_str(); } + + bool is_reproducible() const override { return m_impl->is_reproducible(); } + +private: + SizeArgs float_args(const SizeArgs& args, ConvBiasForwardImpl* opr, + TensorLayout& fsrc, TensorLayout& ffilter, + TensorLayout& fbias, TensorLayout& fz, + TensorLayout& fdst) const; + WorkspaceBundle get_workspace_bundle(void* ptr, const SizeArgs& args) const; + AlgoBase* m_impl; + std::string m_name; +}; + class ConvBiasForwardImpl::AlgoPack { AlgoPack(const AlgoPack&) = delete; AlgoPack& operator=(const AlgoPack&) = delete; @@ -508,7 +530,8 @@ public: std::vector all_algos, //! non-cudnn algos, used for heuristic if cudnn is not supported - non_cudnn_algos; + non_cudnn_algos, + bfloat16_algos; std::vector cudnn_conv_bias_activations; std::vector cudnn_convs; AlgoChanwise chanwise; @@ -531,6 +554,7 @@ public: int8_chwn4_imma_unroll_width; #endif std::vector> gconv_refhold; + std::vector> bfloat16_refhold; std::unordered_map algo2gconv; AlgoBase* cudnn_conv_bias_act_from_enum(cudnnConvolutionFwdAlgo_t algo); diff --git a/dnn/src/cuda/conv_bias/bfloat16.cpp b/dnn/src/cuda/conv_bias/bfloat16.cpp new file mode 100644 index 0000000000000000000000000000000000000000..aa70ea87a80af35b0f3608ad562b4d900f768d8c --- /dev/null +++ b/dnn/src/cuda/conv_bias/bfloat16.cpp @@ -0,0 +1,120 @@ +/** + * \file dnn/src/cuda/conv_bias/bfloat16.cpp + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ */
+
+#include "src/cuda/conv_bias/algo.h"
+#include "src/cuda/handle.h"
+#include "src/cuda/utils.cuh"
+#include "src/cuda/utils.h"
+
+using namespace megdnn;
+using namespace cuda;
+using namespace conv_bias;
+
+ConvBiasForwardImpl::AlgoBFloat16::AlgoBFloat16(
+        ConvBiasForwardImpl::AlgoBase* algorithm)
+        : m_impl(algorithm) {
+    megdnn_assert_internal(algorithm);
+    m_name = ssprintf("BFLOAT16:%s", m_impl->name());
+}
+
+ConvBiasForwardImpl::AlgoBase::SizeArgs
+ConvBiasForwardImpl::AlgoBFloat16::float_args(
+        const SizeArgs& args, ConvBiasForwardImpl* opr, TensorLayout& fsrc,
+        TensorLayout& ffilter, TensorLayout& fbias, TensorLayout& fz,
+        TensorLayout& fdst) const {
+    fsrc = *args.src_layout;
+    ffilter = *args.filter_layout;
+    fbias = *args.bias_layout;
+    fz = *args.z_layout;
+    fdst = *args.dst_layout;
+    auto change_dtype = [](TensorLayout& layout) {
+        if (layout.dtype == dtype::BFloat16()) {
+            layout.dtype = dtype::Float32();
+        }
+    };
+    change_dtype(fsrc);
+    change_dtype(ffilter);
+    change_dtype(fbias);
+    change_dtype(fz);
+    change_dtype(fdst);
+    opr->param() = args.opr->param();
+    opr->param().compute_mode = Param::ComputeMode::DEFAULT;
+    opr->execution_policy() = {m_impl};
+    return SizeArgs(opr, fsrc, ffilter, fbias, fz, fdst);
+}
+
+bool ConvBiasForwardImpl::AlgoBFloat16::is_available(
+        const SizeArgs& args) const {
+    TensorLayout fsrc, ffilter, fbias, fz, fdst;
+    auto convbias_opr = args.handle->create_operator<ConvBiasForward>();
+    SizeArgs fargs = float_args(
+            args, static_cast<ConvBiasForwardImpl*>(convbias_opr.get()), fsrc,
+            ffilter, fbias, fz, fdst);
+    return args.src_layout->dtype == args.filter_layout->dtype &&
+           args.src_layout->dtype == dtype::BFloat16() &&
+           m_impl->is_available(fargs);
+}
+
+WorkspaceBundle ConvBiasForwardImpl::AlgoBFloat16::get_workspace_bundle(
+        void* ptr, const SizeArgs& args) const {
+    TensorLayout fsrc, ffilter, fbias, fz, fdst;
+    auto convbias_opr = args.handle->create_operator<ConvBiasForward>();
+    SizeArgs fargs = float_args(
+            args, static_cast<ConvBiasForwardImpl*>(convbias_opr.get()), fsrc,
+            ffilter, fbias, fz, fdst);
+    SmallVector<size_t> sizes;
+    auto get_workspace = [&sizes](const TensorLayout& src,
+                                  const TensorLayout& dst) {
+        if (src.dtype != dst.dtype) {
+            sizes.push_back(dst.span().dist_byte());
+        }
+    };
+    get_workspace(*args.src_layout, fsrc);
+    get_workspace(*args.filter_layout, ffilter);
+    get_workspace(*args.bias_layout, fbias);
+    get_workspace(*args.z_layout, fz);
+    get_workspace(*args.dst_layout, fdst);
+    sizes.push_back(m_impl->get_workspace_in_bytes(fargs));
+    return {ptr, std::move(sizes)};
+}
+
+size_t ConvBiasForwardImpl::AlgoBFloat16::get_workspace_in_bytes(
+        const SizeArgs& args) const {
+    return get_workspace_bundle(nullptr, args).total_size_in_bytes();
+}
+
+void ConvBiasForwardImpl::AlgoBFloat16::exec(const ExecArgs& args) const {
+    TensorND fsrc_tensor = *args.src_tensor;
+    TensorND ffilter_tensor = *args.filter_tensor;
+    TensorND fbias_tensor = *args.bias_tensor;
+    TensorND fz_tensor = *args.z_tensor;
+    TensorND fdst_tensor = *args.dst_tensor;
+    auto bundle = get_workspace_bundle(args.workspace.raw_ptr, args);
+    CompTypeCvter<dtype::BFloat16, dtype::Float32, dtype::BFloat16> cvter(
+            args.handle, &bundle);
+    {
+        cvter.src_to_comp_type(*args.src_tensor, fsrc_tensor)
+                .src_to_comp_type(*args.filter_tensor, ffilter_tensor)
+                .src_to_comp_type(*args.bias_tensor, fbias_tensor)
+                .src_to_comp_type(*args.z_tensor, fz_tensor)
+                .src_to_comp_type(*args.dst_tensor, fdst_tensor);
+    }
+    {
+        auto convbias_opr = args.handle->create_operator<ConvBiasForward>();
+        convbias_opr->param() = args.opr->param();
+        convbias_opr->param().compute_mode = Param::ComputeMode::DEFAULT;
+        convbias_opr->execution_policy() = {m_impl};
+        convbias_opr->exec(fsrc_tensor, ffilter_tensor, fbias_tensor,
+                           fz_tensor, fdst_tensor, cvter.workspace());
+    }
+    { cvter.comp_to_dst_type(fdst_tensor, *args.dst_tensor); }
+}
+
+// vim: syntax=cpp.doxygen
diff --git a/dnn/src/cuda/conv_bias/chanwise.cpp b/dnn/src/cuda/conv_bias/chanwise.cpp
index cdff851e326d5e4a648758edaaa60914b8c0f3e4..c02cae5da6812787c6aefcdb732be80735995871 100644
--- a/dnn/src/cuda/conv_bias/chanwise.cpp
+++ b/dnn/src/cuda/conv_bias/chanwise.cpp
@@ -20,6 +20,10 @@ using namespace conv_bias;
 
 bool ConvBiasForwardImpl::AlgoChanwise::is_available(
         const SizeArgs& args) const {
+    if (args.src_layout->dtype == args.filter_layout->dtype &&
+        args.src_layout->dtype == dtype::BFloat16()) {
+        return false;
+    }
     if (args.z_layout->ndim > 0)
         return false;
 
diff --git a/dnn/src/cuda/conv_bias/chanwise_small.cpp b/dnn/src/cuda/conv_bias/chanwise_small.cpp
index 3c8d4dcacca935ab11532f1faa567bb908c93df1..9aed9f0346172d7b0d394c217062c767a8ecdc17 100644
--- a/dnn/src/cuda/conv_bias/chanwise_small.cpp
+++ b/dnn/src/cuda/conv_bias/chanwise_small.cpp
@@ -30,6 +30,10 @@ inline bool is_available_small(const chanwise::Param& param) {
 
 bool ConvBiasForwardImpl::AlgoChanwiseSmall::is_available(
         const SizeArgs& args) const {
+    if (args.src_layout->dtype == args.filter_layout->dtype &&
+        args.src_layout->dtype == dtype::BFloat16()) {
+        return false;
+    }
     if (args.z_layout->ndim > 0)
         return false;
 #if CUDA_VERSION < 9000
diff --git a/dnn/src/cuda/conv_bias/cudnn_conv_bias_activation.cpp b/dnn/src/cuda/conv_bias/cudnn_conv_bias_activation.cpp
index 09efb16070ce052578d6e76f30a0de9faecf9e61..f38357718f4f5ff0347903e58b8f2e5e182ea5af 100644
--- a/dnn/src/cuda/conv_bias/cudnn_conv_bias_activation.cpp
+++ b/dnn/src/cuda/conv_bias/cudnn_conv_bias_activation.cpp
@@ -23,6 +23,10 @@ using namespace conv_bias;
 
 bool ConvBiasForwardImpl::AlgoCUDNNConvBiasActivation::is_available(
         const SizeArgs& args) const {
+    if (args.src_layout->dtype == args.filter_layout->dtype &&
+        args.src_layout->dtype == dtype::BFloat16()) {
+        return false;
+    }
     if (args.bias_layout->ndim == 0 ||
         args.bias_layout->eq_shape(*args.dst_layout))
         return false;
diff --git a/dnn/src/cuda/conv_bias/group_conv.cpp b/dnn/src/cuda/conv_bias/group_conv.cpp
index cfcf60a3afce76b190fbf0389add4d0ee5a69399..4063f049fd8c2ae2b2e9c2c7d49983923faa081b 100644
--- a/dnn/src/cuda/conv_bias/group_conv.cpp
+++ b/dnn/src/cuda/conv_bias/group_conv.cpp
@@ -50,6 +50,10 @@ ConvBiasForwardImpl::AlgoGroupConvGeneral::AlgoGroupConvGeneral(AlgoBase* impl)
 
 bool ConvBiasForwardImpl::AlgoGroupConvGeneral::is_available(
         const SizeArgs& args) const {
+    if (args.src_layout->dtype == args.filter_layout->dtype &&
+        args.src_layout->dtype == dtype::BFloat16()) {
+        return false;
+    }
     if (args.z_layout->ndim > 0 || args.filter_meta.group <= 1)
         return false;
     auto&& param = args.opr->param();
diff --git a/dnn/src/cuda/conv_bias/helper.cpp b/dnn/src/cuda/conv_bias/helper.cpp
index e36eb88a85cdda510c9b30d20766ca39b7d5b623..b466628bc20a847268b50d55fdbd102a84962eb5 100644
--- a/dnn/src/cuda/conv_bias/helper.cpp
+++ b/dnn/src/cuda/conv_bias/helper.cpp
@@ -136,6 +136,11 @@ void ConvBiasDesc::set_conv(DType data_type, const param::ConvBias& param,
 namespace conv_bias {
 
 bool is_cudnn_supported(const BiasForwardSizeArgs& args) {
+    if (args.src_layout->dtype == args.filter_layout->dtype &&
+        args.src_layout->dtype == dtype::BFloat16()) {
+        return false;
+    }
+
     // CUDNN_STATUS_EXECUTION_FAILED on Tegra K1, so disable CUDNN
     // on Tegra K1.
if (args.handle->is_tegra_k1()) diff --git a/dnn/src/cuda/conv_bias/matmul.cpp b/dnn/src/cuda/conv_bias/matmul.cpp index 1f7956a6bff6bc369b1101ec20a53bdbfa20129e..fc99218eb6d3bea9d24797cbf6e4783c49a2e15d 100644 --- a/dnn/src/cuda/conv_bias/matmul.cpp +++ b/dnn/src/cuda/conv_bias/matmul.cpp @@ -20,6 +20,10 @@ using namespace cuda; using namespace conv_bias; bool ConvBiasForwardImpl::AlgoMatmul::is_available(const SizeArgs& args) const { + if (args.src_layout->dtype == args.filter_layout->dtype && + args.src_layout->dtype == dtype::BFloat16()) { + return false; + } if (args.z_layout->ndim > 0) return false; diff --git a/dnn/src/cuda/conv_bias/opr_impl.cpp b/dnn/src/cuda/conv_bias/opr_impl.cpp index 79ae71fc5f972bf4f2c32275a6bacf402d44a258..c8b91ca13c18753e8322bdb028eb88582aba8e1a 100644 --- a/dnn/src/cuda/conv_bias/opr_impl.cpp +++ b/dnn/src/cuda/conv_bias/opr_impl.cpp @@ -9,6 +9,7 @@ * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. */ #include "src/cuda/conv_bias/opr_impl.h" +#include "megdnn/dtype.h" #include "src/cuda/conv_bias/helper.h" #include "src/cuda/conv_bias/algo.h" #include "src/cuda/handle.h" @@ -176,14 +177,26 @@ ConvBiasForward::Algorithm* ConvBiasForwardImpl::get_algorithm_heuristic( conv_args = orig_args; } - if (reproducible) { - return megdnn::get_reproducible_algo( - sm_algo_pack.non_cudnn_algos, args, workspace_limit_in_bytes, - "cuda convbias fwd"); + if (args.src_layout->dtype.enumv() != DTypeTrait::enumv) { + if (reproducible) { + return megdnn::get_reproducible_algo( + sm_algo_pack.non_cudnn_algos, args, + workspace_limit_in_bytes, "cuda convbias fwd"); + } else { + return megdnn::get_usable_algo( + sm_algo_pack.non_cudnn_algos, args, + workspace_limit_in_bytes, "cuda convbias fwd"); + } } else { - return megdnn::get_usable_algo( - sm_algo_pack.non_cudnn_algos, args, workspace_limit_in_bytes, - "cuda convbias fwd"); + if (reproducible) { + return megdnn::get_reproducible_algo( + sm_algo_pack.bfloat16_algos, args, workspace_limit_in_bytes, + "cuda convbias fwd"); + } else { + return megdnn::get_usable_algo( + sm_algo_pack.bfloat16_algos, args, workspace_limit_in_bytes, + "cuda convbias fwd"); + } } } diff --git a/dnn/src/cuda/conv_bias/opr_impl.h b/dnn/src/cuda/conv_bias/opr_impl.h index 4efc46b3ef759b0fd840b0d2ede289a60992caa8..a5fcaeda3f7a81680fe1a8dd82729daa3cf5d08b 100644 --- a/dnn/src/cuda/conv_bias/opr_impl.h +++ b/dnn/src/cuda/conv_bias/opr_impl.h @@ -57,6 +57,7 @@ public: class AlgoInt8NCHW4IMMAImplicitGemm; class AlgoInt8CHWN4IMMAImplicitGemmReorderFilter; class AlgoInt8CHWN4IMMAImplicitGemmUnrollWidth; + class AlgoBFloat16; class AlgoPack; diff --git a/dnn/src/cuda/convolution/backward_data/algo.cpp b/dnn/src/cuda/convolution/backward_data/algo.cpp index 5ef94ebb81d1403408150cc262aa7c4647d6e066..b888e947a9974bc35457508772555a4a640d1b17 100644 --- a/dnn/src/cuda/convolution/backward_data/algo.cpp +++ b/dnn/src/cuda/convolution/backward_data/algo.cpp @@ -33,11 +33,12 @@ ConvolutionBackwardDataImpl::AlgoPack::AlgoPack() { // add gconv algos by AlgoGroupConvGeneral auto all_algos_data = all_algos.data(); - for (size_t i = 2; i < all_algos.size(); ++ i) { + size_t group_algo_start = 2; + for (size_t i = group_algo_start; i < all_algos.size(); ++ i) { gconv.push_back({all_algos[i]}); } - for (size_t i = 2; i < all_algos.size(); ++ i) { - algo2gconv[all_algos[i]] = &gconv[i - 2]; + for (size_t i = group_algo_start; i < all_algos.size(); ++ i) { + algo2gconv[all_algos[i]] = &gconv[i - group_algo_start]; } for (auto &&i: gconv) 
{
        all_algos.push_back(&i);
    }
    megdnn_assert(all_algos_data == all_algos.data());
 
    non_cudnn_algos.push_back(all_algos.rbegin()[0]);  // group matmul
+
+    size_t algo_size = all_algos.size();
+    for (size_t i = 0; i < algo_size; ++i) {
+        bfloat16_refhold.emplace_back(new AlgoBFloat16(all_algos[i]));
+        all_algos.push_back(bfloat16_refhold.back().get());
+        bfloat16_algos.push_back(bfloat16_refhold.back().get());
+    }
 }
 
 ConvolutionBackwardDataImpl::AlgoBase::SizeArgs::SizeArgs(
         ConvolutionBackwardDataImpl *o,
         const TensorLayout &filter, const TensorLayout &diff,
         const TensorLayout &grad):
-    SizeArgs(o, o->check_layout_fwd(grad, filter, diff), diff, grad)
+    SizeArgs(o, filter, o->check_layout_fwd(grad, filter, diff), diff, grad)
 {
 }
 
 ConvolutionBackwardDataImpl::AlgoBase::SizeArgs::SizeArgs(
-        ConvolutionBackwardDataImpl *o,
-        const CanonizedFilterMeta &filter, const TensorLayout &diff,
+        ConvolutionBackwardDataImpl *o, const TensorLayout& filter,
+        const CanonizedFilterMeta &filter_meta, const TensorLayout &diff,
         const TensorLayout &grad):
     handle{concrete_handle(o->handle())},
-    filter_meta{filter},
+    filter_meta{filter_meta},
     diff_layout{&diff},
     grad_layout{&grad},
+    filter_layout{&filter},
     opr{o}
 {
 }
diff --git a/dnn/src/cuda/convolution/backward_data/algo.h b/dnn/src/cuda/convolution/backward_data/algo.h
index 0a97f17deb288fa59a3272af59e43e63deaa0744..9e505e189c0548af16344ce98a10d0a8b49d7a31 100644
--- a/dnn/src/cuda/convolution/backward_data/algo.h
+++ b/dnn/src/cuda/convolution/backward_data/algo.h
@@ -31,22 +31,24 @@ class ConvolutionBackwardDataImpl::AlgoBase: public Algorithm {
     struct SizeArgs {
         HandleImpl *handle;
         CanonizedFilterMeta filter_meta;
-        const TensorLayout *diff_layout, *grad_layout;
+        const TensorLayout *diff_layout, *grad_layout, *filter_layout;
         ConvolutionBackwardDataImpl *opr;
 
         std::string to_string() const;
         void init_desc(convolution::CUDNNBwdDataDescs &desc) const {
             desc.set(filter_meta, *diff_layout, *grad_layout, opr->param());
         }
-        SizeArgs(ConvolutionBackwardDataImpl *opr,
-                const TensorLayout &filter, const TensorLayout &diff,
-                const TensorLayout &grad);
-        SizeArgs(ConvolutionBackwardDataImpl *opr,
-                const CanonizedFilterMeta &filter, const TensorLayout &diff,
-                const TensorLayout &grad);
+        SizeArgs(ConvolutionBackwardDataImpl* opr,
+                 const TensorLayout& filter, const TensorLayout& diff,
+                 const TensorLayout& grad);
+        SizeArgs(ConvolutionBackwardDataImpl* opr,
+                 const TensorLayout& filter,
+                 const CanonizedFilterMeta& filter_meta,
+                 const TensorLayout& diff, const TensorLayout& grad);
 
         convolution::ForwardSizeArgs as_fwd_args() const {
-            return {handle, grad_layout, filter_meta, diff_layout};
+            return {handle, grad_layout, filter_layout, filter_meta,
+                    diff_layout};
         }
     };
     struct ExecArgs: public SizeArgs {
@@ -170,6 +172,25 @@ class ConvolutionBackwardDataImpl::AlgoChanwiseSmall final: public AlgoBase {
         }
 };
 
+class ConvolutionBackwardDataImpl::AlgoBFloat16 final : public AlgoBase {
+public:
+    AlgoBFloat16(ConvolutionBackwardDataImpl::AlgoBase*);
+    bool is_available(const SizeArgs& args) const override;
+    size_t get_workspace_in_bytes(const SizeArgs& args) const override;
+    void exec(const ExecArgs& args) const override;
+
+    const char* name() const override { return m_name.c_str(); }
+    bool is_reproducible() const override { return true; }
+
+private:
+    std::string m_name;
+    ConvolutionBackwardDataImpl::AlgoBase* m_algorithm = nullptr;
+    SizeArgs float_args(const SizeArgs& args, ConvolutionBackwardDataImpl* opr,
+                        TensorLayout& fsrc, TensorLayout& ffilter,
+                        TensorLayout& fdst) const;
+    WorkspaceBundle get_workspace_bundle(void* ptr, const SizeArgs& args) const;
+};
+
 //! implement group conv by another algo
 class ConvolutionBackwardDataImpl::AlgoGroupConvGeneral final: public AlgoBase {
     AlgoBase *m_impl;
@@ -210,12 +231,14 @@ class ConvolutionBackwardDataImpl::AlgoPack {
         AlgoChanwiseSmall chanwise_small;
         std::vector<AlgoGroupConvGeneral> gconv;
         std::unordered_map<AlgoBase*, AlgoGroupConvGeneral*> algo2gconv;
+        std::vector<std::unique_ptr<AlgoBFloat16>> bfloat16_refhold;
 
         std::vector<AlgoBase*>
             //! all algorithms
             all_algos,
             //! non-cudnn algos, used for heuristic if cudnn is not supported
-            non_cudnn_algos;
+            non_cudnn_algos,
+            bfloat16_algos;
 
         AlgoCUDNN* cudnn_from_enum(cudnnConvolutionBwdDataAlgo_t algo);
 };
diff --git a/dnn/src/cuda/convolution/backward_data/bfloat16.cpp b/dnn/src/cuda/convolution/backward_data/bfloat16.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..2c33ae83d39023c68fbb584a6b5cd821711e6686
--- /dev/null
+++ b/dnn/src/cuda/convolution/backward_data/bfloat16.cpp
@@ -0,0 +1,115 @@
+/**
+ * \file src/cuda/convolution/backward_data/bfloat16.cpp
+ * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
+ *
+ * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ */
+
+#include "./algo.h"
+#include "src/cuda/convolution/chanwise/kern.cuh"
+#include "src/cuda/utils.h"
+
+using namespace megdnn;
+using namespace cuda;
+using namespace convolution;
+
+ConvolutionBackwardDataImpl::AlgoBFloat16::AlgoBFloat16(
+        ConvolutionBackwardDataImpl::AlgoBase* algorithm)
+        : m_algorithm(algorithm) {
+    megdnn_assert_internal(algorithm);
+    m_name = ssprintf("CONVOLUTION_BACKWARD_DATA_BFLOAT16:%s",
+                      m_algorithm->name());
+}
+
+ConvolutionBackwardDataImpl::AlgoBase::SizeArgs
+ConvolutionBackwardDataImpl::AlgoBFloat16::float_args(
+        const SizeArgs& args, ConvolutionBackwardDataImpl* opr,
+        TensorLayout& ffilter, TensorLayout& fdiff, TensorLayout& fgrad) const {
+    ffilter = *args.filter_layout;
+    fdiff = *args.diff_layout;
+    fgrad = *args.grad_layout;
+    auto change_dtype = [](TensorLayout& layout) {
+        if (layout.dtype == dtype::BFloat16()) {
+            layout.dtype = dtype::Float32();
+        }
+    };
+    change_dtype(ffilter);
+    change_dtype(fdiff);
+    change_dtype(fgrad);
+    opr->param() = args.opr->param();
+    opr->param().compute_mode = Param::ComputeMode::DEFAULT;
+    opr->execution_policy() = {m_algorithm};
+    return SizeArgs(opr, ffilter, fdiff, fgrad);
+}
+
+bool ConvolutionBackwardDataImpl::AlgoBFloat16::is_available(
+        const SizeArgs& args) const {
+    TensorLayout ffilter, fdiff, fgrad;
+    auto conv_back_data_opr =
+            args.handle->create_operator<ConvolutionBackwardData>();
+    SizeArgs fargs = float_args(
+            args,
+            static_cast<ConvolutionBackwardDataImpl*>(conv_back_data_opr.get()),
+            ffilter, fdiff, fgrad);
+    return args.diff_layout->dtype == args.filter_layout->dtype &&
+           args.diff_layout->dtype == dtype::BFloat16() &&
+           m_algorithm->is_available(fargs);
+}
+
+WorkspaceBundle ConvolutionBackwardDataImpl::AlgoBFloat16::get_workspace_bundle(
+        void* ptr, const SizeArgs& args) const {
+    TensorLayout ffilter, fdiff, fgrad;
+    auto conv_back_data_opr =
+            args.handle->create_operator<ConvolutionBackwardData>();
+    SizeArgs fargs = float_args(
+            args,
+            static_cast<ConvolutionBackwardDataImpl*>(conv_back_data_opr.get()),
+            ffilter, fdiff, fgrad);
+    SmallVector<size_t> sizes;
+    auto get_workspace = [&sizes](const TensorLayout& src,
+                                  const TensorLayout& dst) {
+        if (src.dtype != dst.dtype) {
+            sizes.push_back(dst.span().dist_byte());
+        }
+    };
+    get_workspace(*args.filter_layout, ffilter);
+    get_workspace(*args.diff_layout, fdiff);
+    get_workspace(*args.grad_layout, fgrad);
+    sizes.push_back(m_algorithm->get_workspace_in_bytes(fargs));
+    return {ptr, std::move(sizes)};
+}
+
+size_t ConvolutionBackwardDataImpl::AlgoBFloat16::get_workspace_in_bytes(
+        const SizeArgs& args) const {
+    return get_workspace_bundle(nullptr, args).total_size_in_bytes();
+}
+
+void ConvolutionBackwardDataImpl::AlgoBFloat16::exec(
+        const ExecArgs& args) const {
+    TensorND ffilter_tensor = *args.filter_tensor;
+    TensorND fdiff_tensor = *args.diff_tensor;
+    TensorND fgrad_tensor = *args.grad_tensor;
+    auto bundle = get_workspace_bundle(args.workspace.raw_ptr, args);
+    CompTypeCvter<dtype::BFloat16, dtype::Float32, dtype::BFloat16> cvter(
+            args.handle, &bundle);
+    {
+        cvter.src_to_comp_type(*args.filter_tensor, ffilter_tensor)
+                .src_to_comp_type(*args.diff_tensor, fdiff_tensor)
+                .src_to_comp_type(*args.grad_tensor, fgrad_tensor);
+    }
+    {
+        auto conv_back_data_opr =
+                args.handle->create_operator<ConvolutionBackwardData>();
+        conv_back_data_opr->param() = args.opr->param();
+        conv_back_data_opr->param().compute_mode = Param::ComputeMode::DEFAULT;
+        conv_back_data_opr->execution_policy() = {m_algorithm};
+        conv_back_data_opr->exec(ffilter_tensor, fdiff_tensor, fgrad_tensor,
+                                 cvter.workspace());
+    }
+    { cvter.comp_to_dst_type(fgrad_tensor, *args.grad_tensor); }
+}
+
+// vim: syntax=cpp.doxygen
diff --git a/dnn/src/cuda/convolution/backward_data/chanwise.cpp b/dnn/src/cuda/convolution/backward_data/chanwise.cpp
index d2fc6249906b185ad4890a1a9f1b343b6f4cae1c..57d3242ea971da58a8421409c644ad8e24d83af8 100644
--- a/dnn/src/cuda/convolution/backward_data/chanwise.cpp
+++ b/dnn/src/cuda/convolution/backward_data/chanwise.cpp
@@ -19,6 +19,10 @@ using namespace convolution;
 
 bool ConvolutionBackwardDataImpl::AlgoChanwise::is_available(
         const SizeArgs& args) const {
+    if (args.diff_layout->dtype == args.filter_layout->dtype &&
+        args.diff_layout->dtype == dtype::BFloat16()) {
+        return false;
+    }
     auto&& fm = args.filter_meta;
     return args.filter_meta.format == Param::Format::NCHW &&
            args.diff_layout->dtype.category() == DTypeCategory::FLOAT &&
diff --git a/dnn/src/cuda/convolution/backward_data/chanwise_small.cpp b/dnn/src/cuda/convolution/backward_data/chanwise_small.cpp
index 562644beba0ce1101b7d56b9f7e4be7bae720298..7f00a29eac0742f7429fde64bc675ef02441881f 100644
--- a/dnn/src/cuda/convolution/backward_data/chanwise_small.cpp
+++ b/dnn/src/cuda/convolution/backward_data/chanwise_small.cpp
@@ -29,6 +29,10 @@ inline bool is_available_small(const chanwise::Param& param) {
 
 bool ConvolutionBackwardDataImpl::AlgoChanwiseSmall::is_available(
         const SizeArgs &args) const {
+    if (args.diff_layout->dtype == args.filter_layout->dtype &&
+        args.diff_layout->dtype == dtype::BFloat16()) {
+        return false;
+    }
 #if CUDA_VERSION < 9000
     if (args.diff_layout->dtype.enumv() == DTypeEnum::Float16)
         return false;
diff --git a/dnn/src/cuda/convolution/backward_data/group_conv.cpp b/dnn/src/cuda/convolution/backward_data/group_conv.cpp
index 2e60eb98da5835cdf6914fc579aa206a832a6705..51e2a4b14fd8afcebcc97abc1144492b847895b6 100644
--- a/dnn/src/cuda/convolution/backward_data/group_conv.cpp
+++ b/dnn/src/cuda/convolution/backward_data/group_conv.cpp
@@ -38,6 +38,10 @@ ConvolutionBackwardDataImpl::AlgoGroupConvGeneral::AlgoGroupConvGeneral(
 
 bool ConvolutionBackwardDataImpl::AlgoGroupConvGeneral::is_available(
         const SizeArgs &args) const {
+    if (args.diff_layout->dtype == args.filter_layout->dtype &&
+        args.diff_layout->dtype == dtype::BFloat16()) {
+        return false;
+    }
     auto sub_args = args;
     TensorLayout diff_pg, grad_pg;
     modify_size_args(sub_args, diff_pg,
                          grad_pg);
diff --git a/dnn/src/cuda/convolution/backward_data/matmul.cpp b/dnn/src/cuda/convolution/backward_data/matmul.cpp
index 1a873f1ddca2e1333bcc34b4a8fbb202f982e809..1188f3e5bcdd58890a0bba318da4780e84c669d8 100644
--- a/dnn/src/cuda/convolution/backward_data/matmul.cpp
+++ b/dnn/src/cuda/convolution/backward_data/matmul.cpp
@@ -20,6 +20,10 @@ using namespace cuda;
 
 bool ConvolutionBackwardDataImpl::AlgoMatmul::is_available(
         const SizeArgs &args) const {
+    if (args.diff_layout->dtype == args.filter_layout->dtype &&
+        args.diff_layout->dtype == dtype::BFloat16()) {
+        return false;
+    }
     auto &&fm = args.filter_meta;
     return args.filter_meta.format == Param::Format::NCHW &&
            args.diff_layout->dtype.category() == DTypeCategory::FLOAT &&
diff --git a/dnn/src/cuda/convolution/backward_filter/algo.cpp b/dnn/src/cuda/convolution/backward_filter/algo.cpp
index fdffefa8f18e5ea218c3a8cc55b48624d99f4e31..601663af7f557eb4c2dd91f248639936e8d39431 100644
--- a/dnn/src/cuda/convolution/backward_filter/algo.cpp
+++ b/dnn/src/cuda/convolution/backward_filter/algo.cpp
@@ -43,6 +43,12 @@ ConvolutionBackwardFilterImpl::AlgoPack::AlgoPack() {
     megdnn_assert(all_algos_data == all_algos.data());
 
     non_cudnn_algos.push_back(all_algos.rbegin()[0]);  // group matmul
+    size_t algo_size = all_algos.size();
+    for (size_t i = 0; i < algo_size; ++i) {
+        bfloat16_refhold.emplace_back(new AlgoBFloat16(all_algos[i]));
+        all_algos.push_back(bfloat16_refhold.back().get());
+        bfloat16_algos.push_back(bfloat16_refhold.back().get());
+    }
 }
 
 ConvolutionBackwardFilterImpl::AlgoBase::SizeArgs::SizeArgs(
         ConvolutionBackwardFilterImpl *o,
         const TensorLayout &src, const TensorLayout &diff,
         const TensorLayout &grad):
-    SizeArgs(o, src, diff, o->check_layout_fwd(src, grad, diff))
+    SizeArgs(o, src, diff, grad, o->check_layout_fwd(src, grad, diff))
 {
 }
 
 ConvolutionBackwardFilterImpl::AlgoBase::SizeArgs::SizeArgs(
-        ConvolutionBackwardFilterImpl *o,
-        const TensorLayout &src, const TensorLayout &diff,
-        const CanonizedFilterMeta &grad):
-    handle{concrete_handle(o->handle())},
-    src_layout{&src},
-    diff_layout{&diff},
-    grad_filter_meta{grad},
-    opr{o}
-{
-}
+        ConvolutionBackwardFilterImpl* o, const TensorLayout& src,
+        const TensorLayout& diff, const TensorLayout& grad,
+        const CanonizedFilterMeta& grad_meta)
+        : handle{concrete_handle(o->handle())},
+          src_layout{&src},
+          diff_layout{&diff},
+          grad_layout{&grad},
+          grad_filter_meta{grad_meta},
+          opr{o} {}
 
 ConvolutionBackwardFilterImpl::AlgoBase::ExecArgs::ExecArgs(
         ConvolutionBackwardFilterImpl *opr,
diff --git a/dnn/src/cuda/convolution/backward_filter/algo.h b/dnn/src/cuda/convolution/backward_filter/algo.h
index c1a25860b020508883292de123c427f19932dc15..ef70258b922046bbff44630cdd42483960160963 100644
--- a/dnn/src/cuda/convolution/backward_filter/algo.h
+++ b/dnn/src/cuda/convolution/backward_filter/algo.h
@@ -30,7 +30,7 @@ class ConvolutionBackwardFilterImpl::AlgoBase: public Algorithm {
     public:
         struct SizeArgs {
             HandleImpl *handle;
-            const TensorLayout *src_layout, *diff_layout;
+            const TensorLayout *src_layout, *diff_layout, *grad_layout;
             CanonizedFilterMeta grad_filter_meta;
             ConvolutionBackwardFilterImpl *opr;
 
@@ -42,12 +42,14 @@ class ConvolutionBackwardFilterImpl::AlgoBase: public Algorithm {
             SizeArgs(ConvolutionBackwardFilterImpl *opr,
                     const TensorLayout &src, const TensorLayout &diff,
                     const TensorLayout &grad);
-            SizeArgs(ConvolutionBackwardFilterImpl *opr,
-                    const TensorLayout &src, const TensorLayout &diff,
-                    const CanonizedFilterMeta &grad);
+            SizeArgs(ConvolutionBackwardFilterImpl* opr,
+                     const TensorLayout& src, const TensorLayout& diff,
+                     const TensorLayout& grad,
+                     const CanonizedFilterMeta& grad_meta);
 
             convolution::ForwardSizeArgs as_fwd_args() const {
-                return {handle, src_layout, grad_filter_meta, diff_layout};
+                return {handle, src_layout, grad_layout, grad_filter_meta,
+                        diff_layout};
             }
         };
         struct ExecArgs: public SizeArgs {
@@ -157,6 +159,25 @@ class ConvolutionBackwardFilterImpl::AlgoChanwise final: public AlgoBase {
         }
 };
 
+class ConvolutionBackwardFilterImpl::AlgoBFloat16 final : public AlgoBase {
+public:
+    AlgoBFloat16(ConvolutionBackwardFilterImpl::AlgoBase*);
+    bool is_available(const SizeArgs& args) const override;
+    size_t get_workspace_in_bytes(const SizeArgs& args) const override;
+    void exec(const ExecArgs& args) const override;
+
+    const char* name() const override { return m_name.c_str(); }
+    bool is_reproducible() const override { return true; }
+
+private:
+    std::string m_name;
+    ConvolutionBackwardFilterImpl::AlgoBase* m_algorithm = nullptr;
+    SizeArgs float_args(const SizeArgs& args,
+                        ConvolutionBackwardFilterImpl* opr, TensorLayout& fsrc,
+                        TensorLayout& ffilter, TensorLayout& fdst) const;
+    WorkspaceBundle get_workspace_bundle(void* ptr, const SizeArgs& args) const;
+};
+
 //! implement group conv by another algo
 class ConvolutionBackwardFilterImpl::AlgoGroupConvGeneral final: public AlgoBase {
     AlgoBase *m_impl;
@@ -196,12 +217,14 @@ class ConvolutionBackwardFilterImpl::AlgoPack {
         AlgoChanwise chanwise;
         std::vector<AlgoGroupConvGeneral> gconv;
         std::unordered_map<AlgoBase*, AlgoGroupConvGeneral*> algo2gconv;
+        std::vector<std::unique_ptr<AlgoBFloat16>> bfloat16_refhold;
 
         std::vector<AlgoBase*>
            //! all algorithms
            all_algos,
            //! non-cudnn algos, used for heuristic if cudnn is not supported
-           non_cudnn_algos;
+           non_cudnn_algos,
+           bfloat16_algos;
 
         AlgoCUDNN* cudnn_from_enum(cudnnConvolutionBwdFilterAlgo_t algo);
 };
diff --git a/dnn/src/cuda/convolution/backward_filter/bfloat16.cpp b/dnn/src/cuda/convolution/backward_filter/bfloat16.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..21c98745819140b44e265687c11605b279334cfb
--- /dev/null
+++ b/dnn/src/cuda/convolution/backward_filter/bfloat16.cpp
@@ -0,0 +1,117 @@
+/**
+ * \file src/cuda/convolution/backward_filter/bfloat16.cpp
+ * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
+ *
+ * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ */
+
+#include "./algo.h"
+#include "src/cuda/convolution/chanwise/kern.cuh"
+#include "src/cuda/utils.h"
+
+using namespace megdnn;
+using namespace cuda;
+using namespace convolution;
+
+ConvolutionBackwardFilterImpl::AlgoBFloat16::AlgoBFloat16(
+        ConvolutionBackwardFilterImpl::AlgoBase* algorithm)
+        : m_algorithm(algorithm) {
+    megdnn_assert_internal(algorithm);
+    m_name = ssprintf("CONVOLUTION_BACKWARD_FILTER_BFLOAT16:%s",
+                      m_algorithm->name());
+}
+
+ConvolutionBackwardFilterImpl::AlgoBase::SizeArgs
+ConvolutionBackwardFilterImpl::AlgoBFloat16::float_args(
+        const SizeArgs& args, ConvolutionBackwardFilterImpl* opr,
+        TensorLayout& fsrc, TensorLayout& fdiff, TensorLayout& fgrad) const {
+    fsrc = *args.src_layout;
+    fdiff = *args.diff_layout;
+    fgrad = *args.grad_layout;
+    auto change_dtype = [](TensorLayout& layout) {
+        if (layout.dtype == dtype::BFloat16()) {
+            layout.dtype = dtype::Float32();
+        }
+    };
+    change_dtype(fsrc);
+    change_dtype(fdiff);
+    change_dtype(fgrad);
+    opr->param() = args.opr->param();
+    opr->param().compute_mode = Param::ComputeMode::DEFAULT;
+    opr->execution_policy() = {m_algorithm};
+    return SizeArgs(opr, fsrc, fdiff, fgrad);
+}
+
+bool ConvolutionBackwardFilterImpl::AlgoBFloat16::is_available(
+        const SizeArgs& args) const {
+    TensorLayout fsrc, fdiff, fgrad;
+    auto conv_back_filter_opr =
+            args.handle->create_operator<ConvolutionBackwardFilter>();
+    SizeArgs fargs = float_args(args,
+                                static_cast<ConvolutionBackwardFilterImpl*>(
+                                        conv_back_filter_opr.get()),
+                                fsrc, fdiff, fgrad);
+    return args.src_layout->dtype == args.diff_layout->dtype &&
+           args.src_layout->dtype == dtype::BFloat16() &&
+           m_algorithm->is_available(fargs);
+}
+
+WorkspaceBundle
+ConvolutionBackwardFilterImpl::AlgoBFloat16::get_workspace_bundle(
+        void* ptr, const SizeArgs& args) const {
+    TensorLayout fsrc, fdiff, fgrad;
+    auto conv_back_filter_opr =
+            args.handle->create_operator<ConvolutionBackwardFilter>();
+    SizeArgs fargs = float_args(args,
+                                static_cast<ConvolutionBackwardFilterImpl*>(
+                                        conv_back_filter_opr.get()),
+                                fsrc, fdiff, fgrad);
+    SmallVector<size_t> sizes;
+    auto get_workspace = [&sizes](const TensorLayout& src,
+                                  const TensorLayout& dst) {
+        if (src.dtype != dst.dtype) {
+            sizes.push_back(dst.span().dist_byte());
+        }
+    };
+    get_workspace(*args.src_layout, fsrc);
+    get_workspace(*args.diff_layout, fdiff);
+    get_workspace(*args.grad_layout, fgrad);
+    sizes.push_back(m_algorithm->get_workspace_in_bytes(fargs));
+    return {ptr, std::move(sizes)};
+}
+
+size_t ConvolutionBackwardFilterImpl::AlgoBFloat16::get_workspace_in_bytes(
+        const SizeArgs& args) const {
+    return get_workspace_bundle(nullptr, args).total_size_in_bytes();
+}
+
+void ConvolutionBackwardFilterImpl::AlgoBFloat16::exec(
+        const ExecArgs& args) const {
+    TensorND fsrc_tensor = *args.src_tensor;
+    TensorND fdiff_tensor = *args.diff_tensor;
+    TensorND fgrad_tensor = *args.grad_tensor;
+    auto bundle = get_workspace_bundle(args.workspace.raw_ptr, args);
+    CompTypeCvter<dtype::BFloat16, dtype::Float32, dtype::BFloat16> cvter(
+            args.handle, &bundle);
+    {
+        cvter.src_to_comp_type(*args.src_tensor, fsrc_tensor)
+                .src_to_comp_type(*args.diff_tensor, fdiff_tensor)
+                .src_to_comp_type(*args.grad_tensor, fgrad_tensor);
+    }
+    {
+        auto conv_back_filter_opr =
+                args.handle->create_operator<ConvolutionBackwardFilter>();
+        conv_back_filter_opr->param() = args.opr->param();
+        conv_back_filter_opr->param().compute_mode =
+                Param::ComputeMode::DEFAULT;
+        conv_back_filter_opr->execution_policy() = {m_algorithm};
+        conv_back_filter_opr->exec(fsrc_tensor, fdiff_tensor, fgrad_tensor,
+                                   cvter.workspace());
+    }
+    { cvter.comp_to_dst_type(fgrad_tensor, *args.grad_tensor); }
+}
+
+// vim: syntax=cpp.doxygen
diff --git a/dnn/src/cuda/convolution/backward_filter/chanwise.cpp b/dnn/src/cuda/convolution/backward_filter/chanwise.cpp
index 52f590b159732a03dd318090a7597d29a31f399b..f12c5101fe1afd6704dcee8a9fe03023e1cd9e94 100644
--- a/dnn/src/cuda/convolution/backward_filter/chanwise.cpp
+++ b/dnn/src/cuda/convolution/backward_filter/chanwise.cpp
@@ -19,6 +19,10 @@ using namespace convolution;
 
 bool ConvolutionBackwardFilterImpl::AlgoChanwise::is_available(
         const SizeArgs &args) const {
+    if (args.src_layout->dtype == args.diff_layout->dtype &&
+        args.diff_layout->dtype == dtype::BFloat16()) {
+        return false;
+    }
     auto &&fm = args.grad_filter_meta;
     return fm.format == Param::Format::NCHW &&
            args.diff_layout->dtype.category() == DTypeCategory::FLOAT &&
diff --git a/dnn/src/cuda/convolution/backward_filter/group_conv.cpp b/dnn/src/cuda/convolution/backward_filter/group_conv.cpp
index 164145fce42d7367628ba1c2c2fecf76b29f76dd..a2569813d4d3c61178717d2fc00cb51cea2859c5 100644
--- a/dnn/src/cuda/convolution/backward_filter/group_conv.cpp
+++ b/dnn/src/cuda/convolution/backward_filter/group_conv.cpp
@@ -38,6 +38,10 @@ ConvolutionBackwardFilterImpl::AlgoGroupConvGeneral::AlgoGroupConvGeneral(
 
 bool ConvolutionBackwardFilterImpl::AlgoGroupConvGeneral::is_available(
         const SizeArgs &args) const {
+    if (args.src_layout->dtype == args.diff_layout->dtype &&
+        args.diff_layout->dtype == dtype::BFloat16()) {
+        return false;
+    }
     auto sub_args = args;
     TensorLayout src_pg, diff_pg;
     modify_size_args(sub_args, src_pg, diff_pg);
diff --git a/dnn/src/cuda/convolution/backward_filter/matmul.cpp b/dnn/src/cuda/convolution/backward_filter/matmul.cpp
index 7d454534c3f2da6a1c8513472f3242cb5309eeec..761f0ec1e21d688c9b859fc8f87a694a8901699c 100644
--- a/dnn/src/cuda/convolution/backward_filter/matmul.cpp
+++ b/dnn/src/cuda/convolution/backward_filter/matmul.cpp
@@ -19,6 +19,10 @@ using namespace cuda;
 
 bool ConvolutionBackwardFilterImpl::AlgoMatmul::is_available(
         const SizeArgs &args) const {
+    if (args.src_layout->dtype == args.diff_layout->dtype &&
+        args.diff_layout->dtype == dtype::BFloat16()) {
+        return false;
+    }
     auto &&fm = args.grad_filter_meta;
     return fm.format == Param::Format::NCHW &&
            args.diff_layout->dtype.category() == DTypeCategory::FLOAT &&
diff --git a/dnn/src/cuda/convolution/helper.cpp b/dnn/src/cuda/convolution/helper.cpp
index 807df29e79bd3f13b407500e911a7607d1716559..06878e51316edb9b76040f97404d589bd1f4e6cb 100644
--- a/dnn/src/cuda/convolution/helper.cpp
+++ b/dnn/src/cuda/convolution/helper.cpp
@@ -16,6 +16,10 @@ using namespace cuda;
 using namespace convolution;
 
 bool convolution::is_cudnn_supported(const ForwardSizeArgs &args) {
+    if (args.src_layout->dtype == args.filter_layout->dtype &&
+        args.src_layout->dtype == dtype::BFloat16()) {
+        return false;
+    }
     // CUDNN_STATUS_EXECUTION_FAILED on Tegra K1, so disable CUDNN
     // on Tegra K1.
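The guards above make the direct CUDA and cuDNN algorithms refuse pure-BFloat16 problems, so only the AlgoBFloat16 wrappers handle them, and RoundingConverter<dt_bfloat16> then truncates the Float32 results with round-to-nearest-even, the only tie-breaking mode bfloat16.hpp accepts (BFLOAT16_ROUND_TIES_TO_EVEN). A self-contained sketch of that conversion; float_to_bf16_rne is a hypothetical stand-in for the library internals, not a function from this patch:

#include <cstdint>
#include <cstring>

// float -> bfloat16 bits with round-to-nearest-even: add a bias that depends
// on the lowest kept bit, then keep the top 16 bits of the IEEE-754 single.
// (A real implementation would special-case NaN inputs first.)
static inline uint16_t float_to_bf16_rne(float f) {
    uint32_t u;
    std::memcpy(&u, &f, sizeof(u));
    const uint32_t bias = 0x7fffu + ((u >> 16) & 1u);
    return static_cast<uint16_t>((u + bias) >> 16);
}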
diff --git a/dnn/src/cuda/convolution/helper.h b/dnn/src/cuda/convolution/helper.h index e61449dc2ec3480f167f891c36268e78362525b7..92d24c7b11892e1f94831289cb79d01c7f19a8ce 100644 --- a/dnn/src/cuda/convolution/helper.h +++ b/dnn/src/cuda/convolution/helper.h @@ -25,6 +25,7 @@ namespace convolution { struct ForwardSizeArgs { HandleImpl *handle; const TensorLayout *src_layout; + const TensorLayout *filter_layout; CanonizedFilterMeta filter_meta; const TensorLayout *dst_layout; }; diff --git a/dnn/src/cuda/convolution/opr_impl.cpp b/dnn/src/cuda/convolution/opr_impl.cpp index 3558dfaa694026f05b91ff0042c4d0599f281998..7832a0b48027c5a000aeddadb7bfeead82133c21 100644 --- a/dnn/src/cuda/convolution/opr_impl.cpp +++ b/dnn/src/cuda/convolution/opr_impl.cpp @@ -102,7 +102,8 @@ void ConvolutionBackwardDataImpl::exec(_megdnn_tensor_in filter, _megdnn_tensor_out grad, _megdnn_workspace workspace) { AlgoBase::ExecArgs args(this, filter, diff, grad, workspace); - auto algo = get_algorithm(this, args.filter_meta, diff.layout, grad.layout); + auto algo = get_algorithm(this, filter.layout, args.filter_meta, + diff.layout, grad.layout); algo->check_workspace(args, workspace).exec(args); } @@ -120,16 +121,16 @@ ConvolutionBackwardDataImpl::get_algorithm_heuristic( const TensorLayout& grad, size_t workspace_limit_in_bytes, bool reproducible) { auto fm = check_layout_fwd(grad, filter, diff); - return get_algorithm_heuristic(fm, diff, grad, workspace_limit_in_bytes, - reproducible); + return get_algorithm_heuristic(filter, fm, diff, grad, + workspace_limit_in_bytes, reproducible); } ConvolutionBackwardDataImpl::Algorithm* -ConvolutionBackwardDataImpl::get_algorithm_heuristic( - const CanonizedFilterMeta& filter, const TensorLayout& diff, +ConvolutionBackwardDataImpl::get_algorithm_heuristic(const TensorLayout& filter, + const CanonizedFilterMeta& filter_meta, const TensorLayout& diff, const TensorLayout& grad, size_t workspace_limit_in_bytes, bool reproducible) { - AlgoBase::SizeArgs args(this, filter, diff, grad); + AlgoBase::SizeArgs args(this, filter, filter_meta, diff, grad); if (args.filter_meta.group > 1 && sm_algo_pack.chanwise.is_available_reproducible( @@ -209,14 +210,27 @@ ConvolutionBackwardDataImpl::get_algorithm_heuristic( args = orig_args; } - if (reproducible) { - return megdnn::get_reproducible_algo( - sm_algo_pack.non_cudnn_algos, args, workspace_limit_in_bytes, - "cuda conv bwd_data"); + if (args.filter_layout->dtype.enumv() != + DTypeTrait::enumv) { + if (reproducible) { + return megdnn::get_reproducible_algo( + sm_algo_pack.non_cudnn_algos, args, + workspace_limit_in_bytes, "cuda conv bwd_data"); + } else { + return megdnn::get_usable_algo( + sm_algo_pack.non_cudnn_algos, args, + workspace_limit_in_bytes, "cuda conv bwd_data"); + } } else { - return megdnn::get_usable_algo( - sm_algo_pack.non_cudnn_algos, args, workspace_limit_in_bytes, - "cuda conv bwd_data"); + if (reproducible) { + return megdnn::get_reproducible_algo( + sm_algo_pack.bfloat16_algos, args, workspace_limit_in_bytes, + "cuda conv bwd_data"); + } else { + return megdnn::get_usable_algo( + sm_algo_pack.bfloat16_algos, args, workspace_limit_in_bytes, + "cuda conv bwd_data"); + } } } @@ -225,7 +239,7 @@ size_t ConvolutionBackwardDataImpl::get_workspace_in_bytes( const TensorLayout &diff, const TensorLayout &grad) { AlgoBase::SizeArgs args(this, filter, diff, grad); - return get_algorithm(this, args.filter_meta, diff, grad)-> + return get_algorithm(this, filter, args.filter_meta, diff, grad)-> get_workspace_in_bytes(args); } @@ 
-241,7 +255,7 @@ void ConvolutionBackwardFilterImpl::exec(_megdnn_tensor_in src, _megdnn_workspace workspace) { AlgoBase::ExecArgs args(this, src, diff, grad, workspace); auto algo = get_algorithm(this, src.layout, diff.layout, - args.grad_filter_meta); + grad.layout, args.grad_filter_meta); algo->check_workspace(args, workspace).exec(args); } @@ -259,16 +273,16 @@ ConvolutionBackwardFilterImpl::get_algorithm_heuristic( const TensorLayout& grad, size_t workspace_limit_in_bytes, bool reproducible) { auto fm = check_layout_fwd(src, grad, diff); - return get_algorithm_heuristic(src, diff, fm, workspace_limit_in_bytes, - reproducible); + return get_algorithm_heuristic(src, diff, grad, fm, + workspace_limit_in_bytes, reproducible); } ConvolutionBackwardFilterImpl::Algorithm* ConvolutionBackwardFilterImpl::get_algorithm_heuristic( const TensorLayout& src, const TensorLayout& diff, - const CanonizedFilterMeta& grad, size_t workspace_limit_in_bytes, - bool reproducible) { - AlgoBase::SizeArgs args(this, src, diff, grad); + const TensorLayout& grad, const CanonizedFilterMeta& grad_meta, + size_t workspace_limit_in_bytes, bool reproducible) { + AlgoBase::SizeArgs args(this, src, diff, grad, grad_meta); if (args.grad_filter_meta.group > 1 && sm_algo_pack.chanwise.is_available_reproducible( @@ -349,14 +363,26 @@ ConvolutionBackwardFilterImpl::get_algorithm_heuristic( args = orig_args; } - if (reproducible) { - return megdnn::get_reproducible_algo( - sm_algo_pack.non_cudnn_algos, args, workspace_limit_in_bytes, - "cuda conv bwd_filter"); + if (args.src_layout->dtype.enumv() != DTypeTrait::enumv) { + if (reproducible) { + return megdnn::get_reproducible_algo( + sm_algo_pack.non_cudnn_algos, args, + workspace_limit_in_bytes, "cuda conv bwd_filter"); + } else { + return megdnn::get_usable_algo( + sm_algo_pack.non_cudnn_algos, args, + workspace_limit_in_bytes, "cuda conv bwd_filter"); + } } else { - return megdnn::get_usable_algo( - sm_algo_pack.non_cudnn_algos, args, workspace_limit_in_bytes, - "cuda conv bwd_filter"); + if (reproducible) { + return megdnn::get_reproducible_algo( + sm_algo_pack.bfloat16_algos, args, workspace_limit_in_bytes, + "cuda conv bwd_filter"); + } else { + return megdnn::get_usable_algo( + sm_algo_pack.bfloat16_algos, args, workspace_limit_in_bytes, + "cuda conv bwd_filter"); + } } } @@ -365,7 +391,7 @@ size_t ConvolutionBackwardFilterImpl::get_workspace_in_bytes( const TensorLayout &diff, const TensorLayout &grad) { AlgoBase::SizeArgs args(this, src, diff, grad); - return get_algorithm(this, src, diff, args.grad_filter_meta)-> + return get_algorithm(this, src, diff, grad, args.grad_filter_meta)-> get_workspace_in_bytes(args); } diff --git a/dnn/src/cuda/convolution/opr_impl.h b/dnn/src/cuda/convolution/opr_impl.h index 393bd9d590b9e09b9f10934e909367e8199bf041..e8c73cec7c89991d6d60fa626b7953d0cb8dfee2 100644 --- a/dnn/src/cuda/convolution/opr_impl.h +++ b/dnn/src/cuda/convolution/opr_impl.h @@ -60,11 +60,11 @@ class ConvolutionBackwardDataImpl: public ConvolutionBackwardData { const TensorLayout& grad, size_t workspace_limit_in_bytes, bool reproducible) override; - Algorithm* get_algorithm_heuristic(const CanonizedFilterMeta& filter, - const TensorLayout& diff, - const TensorLayout& grad, - size_t workspace_limit_in_bytes, - bool reproducible); + Algorithm* get_algorithm_heuristic( + const TensorLayout& filter, + const CanonizedFilterMeta& filter_meta, + const TensorLayout& diff, const TensorLayout& grad, + size_t workspace_limit_in_bytes, bool reproducible); size_t 
        get_workspace_in_bytes(const TensorLayout& filter,
                               const TensorLayout& diff,
                               const TensorLayout& grad) override;
@@ -76,6 +76,7 @@ class ConvolutionBackwardDataImpl: public ConvolutionBackwardData {
         class AlgoChanwise;
         class AlgoChanwiseSmall;
         class AlgoGroupConvGeneral;
+        class AlgoBFloat16;
 
         class AlgoPack;
 
@@ -104,7 +105,8 @@ class ConvolutionBackwardFilterImpl: public ConvolutionBackwardFilter {
                                            bool reproducible) override;
         Algorithm* get_algorithm_heuristic(const TensorLayout& src,
                                            const TensorLayout& diff,
-                                           const CanonizedFilterMeta& grad,
+                                           const TensorLayout& grad,
+                                           const CanonizedFilterMeta& grad_meta,
                                            size_t workspace_limit_in_bytes,
                                            bool reproducible);
         size_t get_workspace_in_bytes(const TensorLayout& src,
@@ -117,6 +119,7 @@ class ConvolutionBackwardFilterImpl: public ConvolutionBackwardFilter {
         class AlgoMatmul;
         class AlgoChanwise;
         class AlgoGroupConvGeneral;
+        class AlgoBFloat16;
 
         class AlgoPack;
 
diff --git a/dnn/src/cuda/convolution3d/forward/algo.h b/dnn/src/cuda/convolution3d/forward/algo.h
index 46974d68d8f7b866f78735def5fe750af1b0176b..baf6ad16209fceb43d99a6f3288bfe553e15a856 100644
--- a/dnn/src/cuda/convolution3d/forward/algo.h
+++ b/dnn/src/cuda/convolution3d/forward/algo.h
@@ -50,7 +50,7 @@ class Convolution3DForwardImpl::AlgoBase: public Algorithm {
                 const CanonizedFilterMeta &filter, const TensorLayout &dst);
         };
 
-    struct ExecArgs: public SizeArgs {
+    struct ExecArgs : public SizeArgs {
         const TensorND *src_tensor, *filter_tensor, *dst_tensor;
         Workspace workspace;
 
diff --git a/dnn/src/cuda/elemwise/kimpl/ABS_GRAD_dt_bfloat16.cu b/dnn/src/cuda/elemwise/kimpl/ABS_GRAD_dt_bfloat16.cu
new file mode 100644
index 0000000000000000000000000000000000000000..0bd78ac30fbe9089271e949086bcae803e1c9428
--- /dev/null
+++ b/dnn/src/cuda/elemwise/kimpl/ABS_GRAD_dt_bfloat16.cu
@@ -0,0 +1,17 @@
+/**
+ * \file dnn/src/cuda/elemwise/kimpl/ABS_GRAD_dt_bfloat16.cu
+ * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
+ *
+ * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ */
+// generated by gen_elemwise_kern_impls.py
+#if !MEGDNN_DISABLE_FLOAT16
+#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(ABS_GRAD, cb)
+#define KERN_IMPL_ARITY 2
+#define KERN_IMPL_CTYPE dt_bfloat16
+#include "../kern_impl.inl"
+#endif
diff --git a/dnn/src/cuda/elemwise/kimpl/ABS_dt_bfloat16.cu b/dnn/src/cuda/elemwise/kimpl/ABS_dt_bfloat16.cu
new file mode 100644
index 0000000000000000000000000000000000000000..a34d79473a2e47cbdbc77e0fc9aee38ba47549dc
--- /dev/null
+++ b/dnn/src/cuda/elemwise/kimpl/ABS_dt_bfloat16.cu
@@ -0,0 +1,17 @@
+/**
+ * \file dnn/src/cuda/elemwise/kimpl/ABS_dt_bfloat16.cu
+ * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
+ *
+ * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ */ +// generated by gen_elemwise_kern_impls.py +#if !MEGDNN_DISABLE_FLOAT16 +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(ABS, cb) +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_CTYPE dt_bfloat16 +#include "../kern_impl.inl" +#endif diff --git a/dnn/src/cuda/elemwise/kimpl/ACOS_dt_bfloat16.cu b/dnn/src/cuda/elemwise/kimpl/ACOS_dt_bfloat16.cu new file mode 100644 index 0000000000000000000000000000000000000000..b7717ef7561229f614bed2e428f4f9dad5c5bf6d --- /dev/null +++ b/dnn/src/cuda/elemwise/kimpl/ACOS_dt_bfloat16.cu @@ -0,0 +1,17 @@ +/** + * \file dnn/src/cuda/elemwise/kimpl/ACOS_dt_bfloat16.cu + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + */ +// generated by gen_elemwise_kern_impls.py +#if !MEGDNN_DISABLE_FLOAT16 +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(ACOS, cb) +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_CTYPE dt_bfloat16 +#include "../kern_impl.inl" +#endif diff --git a/dnn/src/cuda/elemwise/kimpl/ADD_dt_bfloat16.cu b/dnn/src/cuda/elemwise/kimpl/ADD_dt_bfloat16.cu new file mode 100644 index 0000000000000000000000000000000000000000..2b6d9c2989f59e45103b14506159180de2062a2c --- /dev/null +++ b/dnn/src/cuda/elemwise/kimpl/ADD_dt_bfloat16.cu @@ -0,0 +1,17 @@ +/** + * \file dnn/src/cuda/elemwise/kimpl/ADD_dt_bfloat16.cu + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + */ +// generated by gen_elemwise_kern_impls.py +#if !MEGDNN_DISABLE_FLOAT16 +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(ADD, cb) +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_CTYPE dt_bfloat16 +#include "../kern_impl.inl" +#endif diff --git a/dnn/src/cuda/elemwise/kimpl/ASIN_dt_bfloat16.cu b/dnn/src/cuda/elemwise/kimpl/ASIN_dt_bfloat16.cu new file mode 100644 index 0000000000000000000000000000000000000000..613a14c8af7b6cf63bcff86587b8e2ad452985b9 --- /dev/null +++ b/dnn/src/cuda/elemwise/kimpl/ASIN_dt_bfloat16.cu @@ -0,0 +1,17 @@ +/** + * \file dnn/src/cuda/elemwise/kimpl/ASIN_dt_bfloat16.cu + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ */ +// generated by gen_elemwise_kern_impls.py +#if !MEGDNN_DISABLE_FLOAT16 +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(ASIN, cb) +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_CTYPE dt_bfloat16 +#include "../kern_impl.inl" +#endif diff --git a/dnn/src/cuda/elemwise/kimpl/ATAN2_dt_bfloat16.cu b/dnn/src/cuda/elemwise/kimpl/ATAN2_dt_bfloat16.cu new file mode 100644 index 0000000000000000000000000000000000000000..bd707b777cdd456e922a3a7d1d4d45f81416d688 --- /dev/null +++ b/dnn/src/cuda/elemwise/kimpl/ATAN2_dt_bfloat16.cu @@ -0,0 +1,17 @@ +/** + * \file dnn/src/cuda/elemwise/kimpl/ATAN2_dt_bfloat16.cu + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + */ +// generated by gen_elemwise_kern_impls.py +#if !MEGDNN_DISABLE_FLOAT16 +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(ATAN2, cb) +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_CTYPE dt_bfloat16 +#include "../kern_impl.inl" +#endif diff --git a/dnn/src/cuda/elemwise/kimpl/CEIL_dt_bfloat16.cu b/dnn/src/cuda/elemwise/kimpl/CEIL_dt_bfloat16.cu new file mode 100644 index 0000000000000000000000000000000000000000..63a9e634419dd9370d8d0aecd12e42539babd204 --- /dev/null +++ b/dnn/src/cuda/elemwise/kimpl/CEIL_dt_bfloat16.cu @@ -0,0 +1,17 @@ +/** + * \file dnn/src/cuda/elemwise/kimpl/CEIL_dt_bfloat16.cu + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + */ +// generated by gen_elemwise_kern_impls.py +#if !MEGDNN_DISABLE_FLOAT16 +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(CEIL, cb) +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_CTYPE dt_bfloat16 +#include "../kern_impl.inl" +#endif diff --git a/dnn/src/cuda/elemwise/kimpl/COND_LEQ_MOV_dt_bfloat16.cu b/dnn/src/cuda/elemwise/kimpl/COND_LEQ_MOV_dt_bfloat16.cu new file mode 100644 index 0000000000000000000000000000000000000000..866c5eb2304f9d47133252721b3e6b5aee952c1b --- /dev/null +++ b/dnn/src/cuda/elemwise/kimpl/COND_LEQ_MOV_dt_bfloat16.cu @@ -0,0 +1,17 @@ +/** + * \file dnn/src/cuda/elemwise/kimpl/COND_LEQ_MOV_dt_bfloat16.cu + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ */ +// generated by gen_elemwise_kern_impls.py +#if !MEGDNN_DISABLE_FLOAT16 +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(COND_LEQ_MOV, cb) +#define KERN_IMPL_ARITY 3 +#define KERN_IMPL_CTYPE dt_bfloat16 +#include "../kern_impl.inl" +#endif diff --git a/dnn/src/cuda/elemwise/kimpl/COS_dt_bfloat16.cu b/dnn/src/cuda/elemwise/kimpl/COS_dt_bfloat16.cu new file mode 100644 index 0000000000000000000000000000000000000000..ccb477e01edbb906c484b768c3749a0d6e29b09a --- /dev/null +++ b/dnn/src/cuda/elemwise/kimpl/COS_dt_bfloat16.cu @@ -0,0 +1,17 @@ +/** + * \file dnn/src/cuda/elemwise/kimpl/COS_dt_bfloat16.cu + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + */ +// generated by gen_elemwise_kern_impls.py +#if !MEGDNN_DISABLE_FLOAT16 +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(COS, cb) +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_CTYPE dt_bfloat16 +#include "../kern_impl.inl" +#endif diff --git a/dnn/src/cuda/elemwise/kimpl/EQ_dt_bfloat16.cu b/dnn/src/cuda/elemwise/kimpl/EQ_dt_bfloat16.cu new file mode 100644 index 0000000000000000000000000000000000000000..e2eb5c93ea4788a81ed6097a780cc82d5ade44b7 --- /dev/null +++ b/dnn/src/cuda/elemwise/kimpl/EQ_dt_bfloat16.cu @@ -0,0 +1,17 @@ +/** + * \file dnn/src/cuda/elemwise/kimpl/EQ_dt_bfloat16.cu + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + */ +// generated by gen_elemwise_kern_impls.py +#if !MEGDNN_DISABLE_FLOAT16 +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(EQ, cb) +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_CTYPE dt_bfloat16 +#include "../kern_impl.inl" +#endif diff --git a/dnn/src/cuda/elemwise/kimpl/ERFCINV_dt_bfloat16.cu b/dnn/src/cuda/elemwise/kimpl/ERFCINV_dt_bfloat16.cu new file mode 100644 index 0000000000000000000000000000000000000000..c236e3d9d3938ab6017bc31947c527cda67e7d32 --- /dev/null +++ b/dnn/src/cuda/elemwise/kimpl/ERFCINV_dt_bfloat16.cu @@ -0,0 +1,17 @@ +/** + * \file dnn/src/cuda/elemwise/kimpl/ERFCINV_dt_bfloat16.cu + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ */ +// generated by gen_elemwise_kern_impls.py +#if !MEGDNN_DISABLE_FLOAT16 +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(ERFCINV, cb) +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_CTYPE dt_bfloat16 +#include "../kern_impl.inl" +#endif diff --git a/dnn/src/cuda/elemwise/kimpl/ERFC_dt_bfloat16.cu b/dnn/src/cuda/elemwise/kimpl/ERFC_dt_bfloat16.cu new file mode 100644 index 0000000000000000000000000000000000000000..2cf4591d55d215012451783773bb6476b4e92872 --- /dev/null +++ b/dnn/src/cuda/elemwise/kimpl/ERFC_dt_bfloat16.cu @@ -0,0 +1,17 @@ +/** + * \file dnn/src/cuda/elemwise/kimpl/ERFC_dt_bfloat16.cu + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + */ +// generated by gen_elemwise_kern_impls.py +#if !MEGDNN_DISABLE_FLOAT16 +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(ERFC, cb) +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_CTYPE dt_bfloat16 +#include "../kern_impl.inl" +#endif diff --git a/dnn/src/cuda/elemwise/kimpl/ERFINV_dt_bfloat16.cu b/dnn/src/cuda/elemwise/kimpl/ERFINV_dt_bfloat16.cu new file mode 100644 index 0000000000000000000000000000000000000000..7e0a8143cc560b3dee6f1a28540a57abf9a119a8 --- /dev/null +++ b/dnn/src/cuda/elemwise/kimpl/ERFINV_dt_bfloat16.cu @@ -0,0 +1,17 @@ +/** + * \file dnn/src/cuda/elemwise/kimpl/ERFINV_dt_bfloat16.cu + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + */ +// generated by gen_elemwise_kern_impls.py +#if !MEGDNN_DISABLE_FLOAT16 +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(ERFINV, cb) +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_CTYPE dt_bfloat16 +#include "../kern_impl.inl" +#endif diff --git a/dnn/src/cuda/elemwise/kimpl/ERF_dt_bfloat16.cu b/dnn/src/cuda/elemwise/kimpl/ERF_dt_bfloat16.cu new file mode 100644 index 0000000000000000000000000000000000000000..9179d28afea6ca4051c2df8baffd4e750a8fba0a --- /dev/null +++ b/dnn/src/cuda/elemwise/kimpl/ERF_dt_bfloat16.cu @@ -0,0 +1,17 @@ +/** + * \file dnn/src/cuda/elemwise/kimpl/ERF_dt_bfloat16.cu + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ */ +// generated by gen_elemwise_kern_impls.py +#if !MEGDNN_DISABLE_FLOAT16 +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(ERF, cb) +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_CTYPE dt_bfloat16 +#include "../kern_impl.inl" +#endif diff --git a/dnn/src/cuda/elemwise/kimpl/EXPM1_dt_bfloat16.cu b/dnn/src/cuda/elemwise/kimpl/EXPM1_dt_bfloat16.cu new file mode 100644 index 0000000000000000000000000000000000000000..872d9211d945a91d4b5c37119550d9c7cbba7a36 --- /dev/null +++ b/dnn/src/cuda/elemwise/kimpl/EXPM1_dt_bfloat16.cu @@ -0,0 +1,17 @@ +/** + * \file dnn/src/cuda/elemwise/kimpl/EXPM1_dt_bfloat16.cu + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + */ +// generated by gen_elemwise_kern_impls.py +#if !MEGDNN_DISABLE_FLOAT16 +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(EXPM1, cb) +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_CTYPE dt_bfloat16 +#include "../kern_impl.inl" +#endif diff --git a/dnn/src/cuda/elemwise/kimpl/EXP_dt_bfloat16.cu b/dnn/src/cuda/elemwise/kimpl/EXP_dt_bfloat16.cu new file mode 100644 index 0000000000000000000000000000000000000000..bfd06555ee1e53df1c99d4058bb63b50326081ee --- /dev/null +++ b/dnn/src/cuda/elemwise/kimpl/EXP_dt_bfloat16.cu @@ -0,0 +1,17 @@ +/** + * \file dnn/src/cuda/elemwise/kimpl/EXP_dt_bfloat16.cu + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + */ +// generated by gen_elemwise_kern_impls.py +#if !MEGDNN_DISABLE_FLOAT16 +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(EXP, cb) +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_CTYPE dt_bfloat16 +#include "../kern_impl.inl" +#endif diff --git a/dnn/src/cuda/elemwise/kimpl/FAST_TANH_GRAD_dt_bfloat16.cu b/dnn/src/cuda/elemwise/kimpl/FAST_TANH_GRAD_dt_bfloat16.cu new file mode 100644 index 0000000000000000000000000000000000000000..9f49a00b32b780614090e811c5a1ef5e474744d7 --- /dev/null +++ b/dnn/src/cuda/elemwise/kimpl/FAST_TANH_GRAD_dt_bfloat16.cu @@ -0,0 +1,17 @@ +/** + * \file dnn/src/cuda/elemwise/kimpl/FAST_TANH_GRAD_dt_bfloat16.cu + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ */ +// generated by gen_elemwise_kern_impls.py +#if !MEGDNN_DISABLE_FLOAT16 +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(FAST_TANH_GRAD, cb) +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_CTYPE dt_bfloat16 +#include "../kern_impl.inl" +#endif diff --git a/dnn/src/cuda/elemwise/kimpl/FAST_TANH_dt_bfloat16.cu b/dnn/src/cuda/elemwise/kimpl/FAST_TANH_dt_bfloat16.cu new file mode 100644 index 0000000000000000000000000000000000000000..c4864a30def1666d21512ca3f936dcc1516522a0 --- /dev/null +++ b/dnn/src/cuda/elemwise/kimpl/FAST_TANH_dt_bfloat16.cu @@ -0,0 +1,17 @@ +/** + * \file dnn/src/cuda/elemwise/kimpl/FAST_TANH_dt_bfloat16.cu + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + */ +// generated by gen_elemwise_kern_impls.py +#if !MEGDNN_DISABLE_FLOAT16 +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(FAST_TANH, cb) +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_CTYPE dt_bfloat16 +#include "../kern_impl.inl" +#endif diff --git a/dnn/src/cuda/elemwise/kimpl/FLOOR_DIV_dt_bfloat16.cu b/dnn/src/cuda/elemwise/kimpl/FLOOR_DIV_dt_bfloat16.cu new file mode 100644 index 0000000000000000000000000000000000000000..bfbdc1536755a8debc92a24a419a8ec5c3c8b127 --- /dev/null +++ b/dnn/src/cuda/elemwise/kimpl/FLOOR_DIV_dt_bfloat16.cu @@ -0,0 +1,17 @@ +/** + * \file dnn/src/cuda/elemwise/kimpl/FLOOR_DIV_dt_bfloat16.cu + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + */ +// generated by gen_elemwise_kern_impls.py +#if !MEGDNN_DISABLE_FLOAT16 +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(FLOOR_DIV, cb) +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_CTYPE dt_bfloat16 +#include "../kern_impl.inl" +#endif diff --git a/dnn/src/cuda/elemwise/kimpl/FLOOR_dt_bfloat16.cu b/dnn/src/cuda/elemwise/kimpl/FLOOR_dt_bfloat16.cu new file mode 100644 index 0000000000000000000000000000000000000000..0068af13be1edb03037d2ec7cf17a1ccbfbe1858 --- /dev/null +++ b/dnn/src/cuda/elemwise/kimpl/FLOOR_dt_bfloat16.cu @@ -0,0 +1,17 @@ +/** + * \file dnn/src/cuda/elemwise/kimpl/FLOOR_dt_bfloat16.cu + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ */ +// generated by gen_elemwise_kern_impls.py +#if !MEGDNN_DISABLE_FLOAT16 +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(FLOOR, cb) +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_CTYPE dt_bfloat16 +#include "../kern_impl.inl" +#endif diff --git a/dnn/src/cuda/elemwise/kimpl/FUSE_ADD_H_SWISH_dt_bfloat16.cu b/dnn/src/cuda/elemwise/kimpl/FUSE_ADD_H_SWISH_dt_bfloat16.cu new file mode 100644 index 0000000000000000000000000000000000000000..57efd743cbce8375d1a7b483bd849be72ec346b7 --- /dev/null +++ b/dnn/src/cuda/elemwise/kimpl/FUSE_ADD_H_SWISH_dt_bfloat16.cu @@ -0,0 +1,17 @@ +/** + * \file dnn/src/cuda/elemwise/kimpl/FUSE_ADD_H_SWISH_dt_bfloat16.cu + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + */ +// generated by gen_elemwise_kern_impls.py +#if !MEGDNN_DISABLE_FLOAT16 +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(FUSE_ADD_H_SWISH, cb) +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_CTYPE dt_bfloat16 +#include "../kern_impl.inl" +#endif diff --git a/dnn/src/cuda/elemwise/kimpl/FUSE_ADD_RELU_dt_bfloat16.cu b/dnn/src/cuda/elemwise/kimpl/FUSE_ADD_RELU_dt_bfloat16.cu new file mode 100644 index 0000000000000000000000000000000000000000..ac4261fe924ddeb9fee92e42b661f6854b429688 --- /dev/null +++ b/dnn/src/cuda/elemwise/kimpl/FUSE_ADD_RELU_dt_bfloat16.cu @@ -0,0 +1,17 @@ +/** + * \file dnn/src/cuda/elemwise/kimpl/FUSE_ADD_RELU_dt_bfloat16.cu + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + */ +// generated by gen_elemwise_kern_impls.py +#if !MEGDNN_DISABLE_FLOAT16 +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(FUSE_ADD_RELU, cb) +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_CTYPE dt_bfloat16 +#include "../kern_impl.inl" +#endif diff --git a/dnn/src/cuda/elemwise/kimpl/FUSE_ADD_SIGMOID_dt_bfloat16.cu b/dnn/src/cuda/elemwise/kimpl/FUSE_ADD_SIGMOID_dt_bfloat16.cu new file mode 100644 index 0000000000000000000000000000000000000000..9759107a2ba1d88a2ff3b2d6e17b01771ed0316e --- /dev/null +++ b/dnn/src/cuda/elemwise/kimpl/FUSE_ADD_SIGMOID_dt_bfloat16.cu @@ -0,0 +1,17 @@ +/** + * \file dnn/src/cuda/elemwise/kimpl/FUSE_ADD_SIGMOID_dt_bfloat16.cu + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ */ +// generated by gen_elemwise_kern_impls.py +#if !MEGDNN_DISABLE_FLOAT16 +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(FUSE_ADD_SIGMOID, cb) +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_CTYPE dt_bfloat16 +#include "../kern_impl.inl" +#endif diff --git a/dnn/src/cuda/elemwise/kimpl/FUSE_ADD_TANH_dt_bfloat16.cu b/dnn/src/cuda/elemwise/kimpl/FUSE_ADD_TANH_dt_bfloat16.cu new file mode 100644 index 0000000000000000000000000000000000000000..4b6b9c1b4d587a8feba3b92302c922aa3a21a8b0 --- /dev/null +++ b/dnn/src/cuda/elemwise/kimpl/FUSE_ADD_TANH_dt_bfloat16.cu @@ -0,0 +1,17 @@ +/** + * \file dnn/src/cuda/elemwise/kimpl/FUSE_ADD_TANH_dt_bfloat16.cu + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + */ +// generated by gen_elemwise_kern_impls.py +#if !MEGDNN_DISABLE_FLOAT16 +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(FUSE_ADD_TANH, cb) +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_CTYPE dt_bfloat16 +#include "../kern_impl.inl" +#endif diff --git a/dnn/src/cuda/elemwise/kimpl/FUSE_MUL_ADD3_dt_bfloat16.cu b/dnn/src/cuda/elemwise/kimpl/FUSE_MUL_ADD3_dt_bfloat16.cu new file mode 100644 index 0000000000000000000000000000000000000000..66eeec8c152b36b524c543405a9a69767a83477f --- /dev/null +++ b/dnn/src/cuda/elemwise/kimpl/FUSE_MUL_ADD3_dt_bfloat16.cu @@ -0,0 +1,17 @@ +/** + * \file dnn/src/cuda/elemwise/kimpl/FUSE_MUL_ADD3_dt_bfloat16.cu + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + */ +// generated by gen_elemwise_kern_impls.py +#if !MEGDNN_DISABLE_FLOAT16 +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(FUSE_MUL_ADD3, cb) +#define KERN_IMPL_ARITY 3 +#define KERN_IMPL_CTYPE dt_bfloat16 +#include "../kern_impl.inl" +#endif diff --git a/dnn/src/cuda/elemwise/kimpl/H_SWISH_GRAD_dt_bfloat16.cu b/dnn/src/cuda/elemwise/kimpl/H_SWISH_GRAD_dt_bfloat16.cu new file mode 100644 index 0000000000000000000000000000000000000000..d6717af8919d79474690da93c2305c78494d0dfa --- /dev/null +++ b/dnn/src/cuda/elemwise/kimpl/H_SWISH_GRAD_dt_bfloat16.cu @@ -0,0 +1,17 @@ +/** + * \file dnn/src/cuda/elemwise/kimpl/H_SWISH_GRAD_dt_bfloat16.cu + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ */ +// generated by gen_elemwise_kern_impls.py +#if !MEGDNN_DISABLE_FLOAT16 +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(H_SWISH_GRAD, cb) +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_CTYPE dt_bfloat16 +#include "../kern_impl.inl" +#endif diff --git a/dnn/src/cuda/elemwise/kimpl/H_SWISH_dt_bfloat16.cu b/dnn/src/cuda/elemwise/kimpl/H_SWISH_dt_bfloat16.cu new file mode 100644 index 0000000000000000000000000000000000000000..639d23b497b9e3d82b11e58d9e2cfb019010520b --- /dev/null +++ b/dnn/src/cuda/elemwise/kimpl/H_SWISH_dt_bfloat16.cu @@ -0,0 +1,17 @@ +/** + * \file dnn/src/cuda/elemwise/kimpl/H_SWISH_dt_bfloat16.cu + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + */ +// generated by gen_elemwise_kern_impls.py +#if !MEGDNN_DISABLE_FLOAT16 +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(H_SWISH, cb) +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_CTYPE dt_bfloat16 +#include "../kern_impl.inl" +#endif diff --git a/dnn/src/cuda/elemwise/kimpl/LEQ_dt_bfloat16.cu b/dnn/src/cuda/elemwise/kimpl/LEQ_dt_bfloat16.cu new file mode 100644 index 0000000000000000000000000000000000000000..1e0560cb29fb5f1fa2bf26245b104c309543c178 --- /dev/null +++ b/dnn/src/cuda/elemwise/kimpl/LEQ_dt_bfloat16.cu @@ -0,0 +1,17 @@ +/** + * \file dnn/src/cuda/elemwise/kimpl/LEQ_dt_bfloat16.cu + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + */ +// generated by gen_elemwise_kern_impls.py +#if !MEGDNN_DISABLE_FLOAT16 +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(LEQ, cb) +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_CTYPE dt_bfloat16 +#include "../kern_impl.inl" +#endif diff --git a/dnn/src/cuda/elemwise/kimpl/LOG1P_dt_bfloat16.cu b/dnn/src/cuda/elemwise/kimpl/LOG1P_dt_bfloat16.cu new file mode 100644 index 0000000000000000000000000000000000000000..4728eb070101fa219cd303e4a7beff906489a324 --- /dev/null +++ b/dnn/src/cuda/elemwise/kimpl/LOG1P_dt_bfloat16.cu @@ -0,0 +1,17 @@ +/** + * \file dnn/src/cuda/elemwise/kimpl/LOG1P_dt_bfloat16.cu + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ */ +// generated by gen_elemwise_kern_impls.py +#if !MEGDNN_DISABLE_FLOAT16 +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(LOG1P, cb) +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_CTYPE dt_bfloat16 +#include "../kern_impl.inl" +#endif diff --git a/dnn/src/cuda/elemwise/kimpl/LOG_SUM_EXP_dt_bfloat16.cu b/dnn/src/cuda/elemwise/kimpl/LOG_SUM_EXP_dt_bfloat16.cu new file mode 100644 index 0000000000000000000000000000000000000000..08cd429f66d76770ea46dd5ca6b29cddf4350769 --- /dev/null +++ b/dnn/src/cuda/elemwise/kimpl/LOG_SUM_EXP_dt_bfloat16.cu @@ -0,0 +1,17 @@ +/** + * \file dnn/src/cuda/elemwise/kimpl/LOG_SUM_EXP_dt_bfloat16.cu + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + */ +// generated by gen_elemwise_kern_impls.py +#if !MEGDNN_DISABLE_FLOAT16 +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(LOG_SUM_EXP, cb) +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_CTYPE dt_bfloat16 +#include "../kern_impl.inl" +#endif diff --git a/dnn/src/cuda/elemwise/kimpl/LOG_dt_bfloat16.cu b/dnn/src/cuda/elemwise/kimpl/LOG_dt_bfloat16.cu new file mode 100644 index 0000000000000000000000000000000000000000..235ae74d5d14d9836da5f3a7dd0fb69d6e33cf61 --- /dev/null +++ b/dnn/src/cuda/elemwise/kimpl/LOG_dt_bfloat16.cu @@ -0,0 +1,17 @@ +/** + * \file dnn/src/cuda/elemwise/kimpl/LOG_dt_bfloat16.cu + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + */ +// generated by gen_elemwise_kern_impls.py +#if !MEGDNN_DISABLE_FLOAT16 +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(LOG, cb) +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_CTYPE dt_bfloat16 +#include "../kern_impl.inl" +#endif diff --git a/dnn/src/cuda/elemwise/kimpl/LT_dt_bfloat16.cu b/dnn/src/cuda/elemwise/kimpl/LT_dt_bfloat16.cu new file mode 100644 index 0000000000000000000000000000000000000000..5093ee94a7135a4d983ce21022e93eea972e5a38 --- /dev/null +++ b/dnn/src/cuda/elemwise/kimpl/LT_dt_bfloat16.cu @@ -0,0 +1,17 @@ +/** + * \file dnn/src/cuda/elemwise/kimpl/LT_dt_bfloat16.cu + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ */ +// generated by gen_elemwise_kern_impls.py +#if !MEGDNN_DISABLE_FLOAT16 +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(LT, cb) +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_CTYPE dt_bfloat16 +#include "../kern_impl.inl" +#endif diff --git a/dnn/src/cuda/elemwise/kimpl/MAX_dt_bfloat16.cu b/dnn/src/cuda/elemwise/kimpl/MAX_dt_bfloat16.cu new file mode 100644 index 0000000000000000000000000000000000000000..bb14bd818d5424d967388589fd8831a39c93f6e4 --- /dev/null +++ b/dnn/src/cuda/elemwise/kimpl/MAX_dt_bfloat16.cu @@ -0,0 +1,17 @@ +/** + * \file dnn/src/cuda/elemwise/kimpl/MAX_dt_bfloat16.cu + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + */ +// generated by gen_elemwise_kern_impls.py +#if !MEGDNN_DISABLE_FLOAT16 +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(MAX, cb) +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_CTYPE dt_bfloat16 +#include "../kern_impl.inl" +#endif diff --git a/dnn/src/cuda/elemwise/kimpl/MIN_dt_bfloat16.cu b/dnn/src/cuda/elemwise/kimpl/MIN_dt_bfloat16.cu new file mode 100644 index 0000000000000000000000000000000000000000..a0205661e311796d785609fa420eba0ab7b50143 --- /dev/null +++ b/dnn/src/cuda/elemwise/kimpl/MIN_dt_bfloat16.cu @@ -0,0 +1,17 @@ +/** + * \file dnn/src/cuda/elemwise/kimpl/MIN_dt_bfloat16.cu + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + */ +// generated by gen_elemwise_kern_impls.py +#if !MEGDNN_DISABLE_FLOAT16 +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(MIN, cb) +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_CTYPE dt_bfloat16 +#include "../kern_impl.inl" +#endif diff --git a/dnn/src/cuda/elemwise/kimpl/MOD_dt_bfloat16.cu b/dnn/src/cuda/elemwise/kimpl/MOD_dt_bfloat16.cu new file mode 100644 index 0000000000000000000000000000000000000000..397e050c63ec77fd3cd0127582ba7ca28d5dd27b --- /dev/null +++ b/dnn/src/cuda/elemwise/kimpl/MOD_dt_bfloat16.cu @@ -0,0 +1,17 @@ +/** + * \file dnn/src/cuda/elemwise/kimpl/MOD_dt_bfloat16.cu + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ */ +// generated by gen_elemwise_kern_impls.py +#if !MEGDNN_DISABLE_FLOAT16 +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(MOD, cb) +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_CTYPE dt_bfloat16 +#include "../kern_impl.inl" +#endif diff --git a/dnn/src/cuda/elemwise/kimpl/MUL_dt_bfloat16.cu b/dnn/src/cuda/elemwise/kimpl/MUL_dt_bfloat16.cu new file mode 100644 index 0000000000000000000000000000000000000000..f8bd0b10b7e6e60310b34b9e6bcb8953b571c293 --- /dev/null +++ b/dnn/src/cuda/elemwise/kimpl/MUL_dt_bfloat16.cu @@ -0,0 +1,17 @@ +/** + * \file dnn/src/cuda/elemwise/kimpl/MUL_dt_bfloat16.cu + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + */ +// generated by gen_elemwise_kern_impls.py +#if !MEGDNN_DISABLE_FLOAT16 +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(MUL, cb) +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_CTYPE dt_bfloat16 +#include "../kern_impl.inl" +#endif diff --git a/dnn/src/cuda/elemwise/kimpl/NEGATE_dt_bfloat16.cu b/dnn/src/cuda/elemwise/kimpl/NEGATE_dt_bfloat16.cu new file mode 100644 index 0000000000000000000000000000000000000000..8f2aa43d51069214d514ebec2d23a457fcb4198c --- /dev/null +++ b/dnn/src/cuda/elemwise/kimpl/NEGATE_dt_bfloat16.cu @@ -0,0 +1,17 @@ +/** + * \file dnn/src/cuda/elemwise/kimpl/NEGATE_dt_bfloat16.cu + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + */ +// generated by gen_elemwise_kern_impls.py +#if !MEGDNN_DISABLE_FLOAT16 +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(NEGATE, cb) +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_CTYPE dt_bfloat16 +#include "../kern_impl.inl" +#endif diff --git a/dnn/src/cuda/elemwise/kimpl/POW_dt_bfloat16.cu b/dnn/src/cuda/elemwise/kimpl/POW_dt_bfloat16.cu new file mode 100644 index 0000000000000000000000000000000000000000..65ef8109efbbf249f49cbf62793b7770ea39cb4c --- /dev/null +++ b/dnn/src/cuda/elemwise/kimpl/POW_dt_bfloat16.cu @@ -0,0 +1,17 @@ +/** + * \file dnn/src/cuda/elemwise/kimpl/POW_dt_bfloat16.cu + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ */ +// generated by gen_elemwise_kern_impls.py +#if !MEGDNN_DISABLE_FLOAT16 +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(POW, cb) +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_CTYPE dt_bfloat16 +#include "../kern_impl.inl" +#endif diff --git a/dnn/src/cuda/elemwise/kimpl/RELU_dt_bfloat16.cu b/dnn/src/cuda/elemwise/kimpl/RELU_dt_bfloat16.cu new file mode 100644 index 0000000000000000000000000000000000000000..432eb921a0ac401615f479818dd7641e9cd30ade --- /dev/null +++ b/dnn/src/cuda/elemwise/kimpl/RELU_dt_bfloat16.cu @@ -0,0 +1,17 @@ +/** + * \file dnn/src/cuda/elemwise/kimpl/RELU_dt_bfloat16.cu + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + */ +// generated by gen_elemwise_kern_impls.py +#if !MEGDNN_DISABLE_FLOAT16 +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(RELU, cb) +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_CTYPE dt_bfloat16 +#include "../kern_impl.inl" +#endif diff --git a/dnn/src/cuda/elemwise/kimpl/ROUND_dt_bfloat16.cu b/dnn/src/cuda/elemwise/kimpl/ROUND_dt_bfloat16.cu new file mode 100644 index 0000000000000000000000000000000000000000..5af0313bfd62ac4fd7aa50e2863d62bfe97b0cab --- /dev/null +++ b/dnn/src/cuda/elemwise/kimpl/ROUND_dt_bfloat16.cu @@ -0,0 +1,17 @@ +/** + * \file dnn/src/cuda/elemwise/kimpl/ROUND_dt_bfloat16.cu + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + */ +// generated by gen_elemwise_kern_impls.py +#if !MEGDNN_DISABLE_FLOAT16 +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(ROUND, cb) +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_CTYPE dt_bfloat16 +#include "../kern_impl.inl" +#endif diff --git a/dnn/src/cuda/elemwise/kimpl/SIGMOID_GRAD_dt_bfloat16.cu b/dnn/src/cuda/elemwise/kimpl/SIGMOID_GRAD_dt_bfloat16.cu new file mode 100644 index 0000000000000000000000000000000000000000..692c3d64ca5ecf198fde8ecb9e9703cba831e6cd --- /dev/null +++ b/dnn/src/cuda/elemwise/kimpl/SIGMOID_GRAD_dt_bfloat16.cu @@ -0,0 +1,17 @@ +/** + * \file dnn/src/cuda/elemwise/kimpl/SIGMOID_GRAD_dt_bfloat16.cu + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ */ +// generated by gen_elemwise_kern_impls.py +#if !MEGDNN_DISABLE_FLOAT16 +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SIGMOID_GRAD, cb) +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_CTYPE dt_bfloat16 +#include "../kern_impl.inl" +#endif diff --git a/dnn/src/cuda/elemwise/kimpl/SIGMOID_dt_bfloat16.cu b/dnn/src/cuda/elemwise/kimpl/SIGMOID_dt_bfloat16.cu new file mode 100644 index 0000000000000000000000000000000000000000..b94f4a5fba996555e2514d08243d2e15140e6299 --- /dev/null +++ b/dnn/src/cuda/elemwise/kimpl/SIGMOID_dt_bfloat16.cu @@ -0,0 +1,17 @@ +/** + * \file dnn/src/cuda/elemwise/kimpl/SIGMOID_dt_bfloat16.cu + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + */ +// generated by gen_elemwise_kern_impls.py +#if !MEGDNN_DISABLE_FLOAT16 +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SIGMOID, cb) +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_CTYPE dt_bfloat16 +#include "../kern_impl.inl" +#endif diff --git a/dnn/src/cuda/elemwise/kimpl/SIN_dt_bfloat16.cu b/dnn/src/cuda/elemwise/kimpl/SIN_dt_bfloat16.cu new file mode 100644 index 0000000000000000000000000000000000000000..a3f23bcffb8c86c7c553df232732e86e94b89396 --- /dev/null +++ b/dnn/src/cuda/elemwise/kimpl/SIN_dt_bfloat16.cu @@ -0,0 +1,17 @@ +/** + * \file dnn/src/cuda/elemwise/kimpl/SIN_dt_bfloat16.cu + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + */ +// generated by gen_elemwise_kern_impls.py +#if !MEGDNN_DISABLE_FLOAT16 +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SIN, cb) +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_CTYPE dt_bfloat16 +#include "../kern_impl.inl" +#endif diff --git a/dnn/src/cuda/elemwise/kimpl/SUB_dt_bfloat16.cu b/dnn/src/cuda/elemwise/kimpl/SUB_dt_bfloat16.cu new file mode 100644 index 0000000000000000000000000000000000000000..a72052b6e914243ed69987b99ba094f3d7e153bc --- /dev/null +++ b/dnn/src/cuda/elemwise/kimpl/SUB_dt_bfloat16.cu @@ -0,0 +1,17 @@ +/** + * \file dnn/src/cuda/elemwise/kimpl/SUB_dt_bfloat16.cu + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ */ +// generated by gen_elemwise_kern_impls.py +#if !MEGDNN_DISABLE_FLOAT16 +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SUB, cb) +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_CTYPE dt_bfloat16 +#include "../kern_impl.inl" +#endif diff --git a/dnn/src/cuda/elemwise/kimpl/SWITCH_GT0_dt_bfloat16.cu b/dnn/src/cuda/elemwise/kimpl/SWITCH_GT0_dt_bfloat16.cu new file mode 100644 index 0000000000000000000000000000000000000000..c5daf3dfbf0341d9581161570db890fd26bb4d32 --- /dev/null +++ b/dnn/src/cuda/elemwise/kimpl/SWITCH_GT0_dt_bfloat16.cu @@ -0,0 +1,17 @@ +/** + * \file dnn/src/cuda/elemwise/kimpl/SWITCH_GT0_dt_bfloat16.cu + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + */ +// generated by gen_elemwise_kern_impls.py +#if !MEGDNN_DISABLE_FLOAT16 +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SWITCH_GT0, cb) +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_CTYPE dt_bfloat16 +#include "../kern_impl.inl" +#endif diff --git a/dnn/src/cuda/elemwise/kimpl/TANH_GRAD_dt_bfloat16.cu b/dnn/src/cuda/elemwise/kimpl/TANH_GRAD_dt_bfloat16.cu new file mode 100644 index 0000000000000000000000000000000000000000..a048f507620b8bda26fa9b7c6c8a751e8942c7a4 --- /dev/null +++ b/dnn/src/cuda/elemwise/kimpl/TANH_GRAD_dt_bfloat16.cu @@ -0,0 +1,17 @@ +/** + * \file dnn/src/cuda/elemwise/kimpl/TANH_GRAD_dt_bfloat16.cu + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + */ +// generated by gen_elemwise_kern_impls.py +#if !MEGDNN_DISABLE_FLOAT16 +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(TANH_GRAD, cb) +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_CTYPE dt_bfloat16 +#include "../kern_impl.inl" +#endif diff --git a/dnn/src/cuda/elemwise/kimpl/TANH_dt_bfloat16.cu b/dnn/src/cuda/elemwise/kimpl/TANH_dt_bfloat16.cu new file mode 100644 index 0000000000000000000000000000000000000000..fd7c0528016433234f13c120e38a9b4df5694380 --- /dev/null +++ b/dnn/src/cuda/elemwise/kimpl/TANH_dt_bfloat16.cu @@ -0,0 +1,17 @@ +/** + * \file dnn/src/cuda/elemwise/kimpl/TANH_dt_bfloat16.cu + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ */
+// generated by gen_elemwise_kern_impls.py
+#if !MEGDNN_DISABLE_FLOAT16
+#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(TANH, cb)
+#define KERN_IMPL_ARITY 1
+#define KERN_IMPL_CTYPE dt_bfloat16
+#include "../kern_impl.inl"
+#endif
diff --git a/dnn/src/cuda/elemwise/kimpl/TRUE_DIV_dt_bfloat16.cu b/dnn/src/cuda/elemwise/kimpl/TRUE_DIV_dt_bfloat16.cu
new file mode 100644
index 0000000000000000000000000000000000000000..9618c03744cd5cdaa75117fcf823716907ef43ab
--- /dev/null
+++ b/dnn/src/cuda/elemwise/kimpl/TRUE_DIV_dt_bfloat16.cu
@@ -0,0 +1,17 @@
+/**
+ * \file dnn/src/cuda/elemwise/kimpl/TRUE_DIV_dt_bfloat16.cu
+ * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
+ *
+ * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ */
+// generated by gen_elemwise_kern_impls.py
+#if !MEGDNN_DISABLE_FLOAT16
+#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(TRUE_DIV, cb)
+#define KERN_IMPL_ARITY 2
+#define KERN_IMPL_CTYPE dt_bfloat16
+#include "../kern_impl.inl"
+#endif
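Every generated kernel file in this run has the same three-knob shape; the heavy lifting lives in the shared ../kern_impl.inl, which each .cu configures via macros before including it once. Below is a self-contained reduction of that idea; instantiate_kernel and the string-valued KERN_IMPL_MODE are illustrative stand-ins, not the real macros (the real KERN_IMPL_MODE expands a callback over elemwise mode tags):

#include <cstdio>

// Stand-in for the shared body that ../kern_impl.inl provides in megdnn:
// one instantiation per (mode, arity, ctype) triple, selected entirely by
// macros defined before the include.
template <typename ctype, int arity>
void instantiate_kernel(const char* mode) {
    std::printf("kernel: mode=%s arity=%d sizeof(ctype)=%zu\n", mode, arity,
                sizeof(ctype));
}

// Each generated .cu file reduces to exactly this: define the knobs, then
// pull in the shared implementation once.
#define KERN_IMPL_MODE "TRUE_DIV"
#define KERN_IMPL_ARITY 2
#define KERN_IMPL_CTYPE float  // dt_bfloat16 in the real kernels

int main() {
    instantiate_kernel<KERN_IMPL_CTYPE, KERN_IMPL_ARITY>(KERN_IMPL_MODE);
}

Keeping the per-mode files trivial like this means adding a dtype touches only the generator script, not forty hand-written kernels.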
diff --git a/dnn/src/cuda/elemwise/special_kimpl/special_dt_bfloat16.cu b/dnn/src/cuda/elemwise/special_kimpl/special_dt_bfloat16.cu
new file mode 100644
index 0000000000000000000000000000000000000000..139793a872d6409dde9eea80992867ccaa0ef834
--- /dev/null
+++ b/dnn/src/cuda/elemwise/special_kimpl/special_dt_bfloat16.cu
@@ -0,0 +1,18 @@
+/**
+ * \file dnn/src/cuda/elemwise/special_kimpl/special_dt_bfloat16.cu
+ * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
+ *
+ * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ */
+// generated by gen_elemwise_special_kern_impls.py
+#if !MEGDNN_DISABLE_FLOAT16
+#include "../special_kerns.inl"
+INST(::megdnn::dtype::BFloat16)
+#undef INST
+}
+}
+#endif
diff --git a/dnn/src/cuda/elemwise_helper.cpp b/dnn/src/cuda/elemwise_helper.cpp
index 15791f6ac5c506476fe1c91c399eae0057bb26be..e64cc593fcc55a38214738328013a9babb9f6168 100644
--- a/dnn/src/cuda/elemwise_helper.cpp
+++ b/dnn/src/cuda/elemwise_helper.cpp
@@ -141,6 +141,9 @@ INST_FOR_CTYPE
 #define ct dt_float16
 INST_FOR_CTYPE
 #undef ct
+#define ct dt_bfloat16
+INST_FOR_CTYPE
+#undef ct
 #define ct dt_int8
 INST_FOR_CTYPE
 #undef ct
diff --git a/dnn/src/cuda/elemwise_helper.cuh b/dnn/src/cuda/elemwise_helper.cuh
index 14bf22a90c9a72563f26f35eb1f32ca5a84a3ee5..d9c5e5d169ba90542800ed52c6ec346adba360b2 100644
--- a/dnn/src/cuda/elemwise_helper.cuh
+++ b/dnn/src/cuda/elemwise_helper.cuh
@@ -68,6 +68,17 @@ namespace elemwise_intl {
         return t;
     }
 
+    struct __attribute__((aligned(8))) bhalf4 {
+        dt_bfloat16 x, y, z, w;
+    };
+
+    __device__ __forceinline__ bhalf4 make_bhalf4(dt_bfloat16 x, dt_bfloat16 y,
+                                                  dt_bfloat16 z, dt_bfloat16 w) {
+        bhalf4 t;
+        t.x = x, t.y = y, t.z = z, t.w = w;
+        return t;
+    }
+
 #define INST(_ctype, _vect_type) \
     template <>                  \
     class VectTypeTrait<_ctype> { \
@@ -87,6 +98,7 @@ namespace elemwise_intl {
 INST(dt_uint8, uchar4);
 INST(dt_float32, float4);
 INST(dt_float16, half4);
+INST(dt_bfloat16, bhalf4);
 INST(dt_int32, int4);
 INST(dt_int16, short4);
 #undef as_raw
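The bhalf4 struct added above mirrors the builtin half4/float4 vector types: four 16-bit lanes forced to 8-byte size and alignment, so the elemwise kernels can move four bfloat16 values per 64-bit memory transaction instead of four separate 16-bit ones. A host-side sketch of the same layout reasoning, with a hypothetical bf16 stand-in for dt_bfloat16:

#include <cstdint>
#include <cstdio>

// Hypothetical 16-bit storage type standing in for dt_bfloat16.
struct bf16 { uint16_t bits; };

// Mirror of the bhalf4 idea: pack four bf16 values into one 8-byte,
// 8-byte-aligned struct so a single 64-bit load/store covers all four.
struct alignas(8) bhalf4_sketch {
    bf16 x, y, z, w;
};

int main() {
    static_assert(sizeof(bhalf4_sketch) == 8, "4 x 16-bit packs into 64 bits");
    static_assert(alignof(bhalf4_sketch) == 8, "aligned for vectorized access");
    std::printf("size=%zu align=%zu\n", sizeof(bhalf4_sketch),
                alignof(bhalf4_sketch));
}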
diff --git a/dnn/src/cuda/indexing_multi_axis_vec/kern_apply_opr_incr.cu b/dnn/src/cuda/indexing_multi_axis_vec/kern_apply_opr_incr.cu
index 02ba292749e4f755bdcd2e56a83400d734e8598f..74874f23d7559ea8514056011c9f1f49d4f584af 100644
--- a/dnn/src/cuda/indexing_multi_axis_vec/kern_apply_opr_incr.cu
+++ b/dnn/src/cuda/indexing_multi_axis_vec/kern_apply_opr_incr.cu
@@ -17,6 +17,11 @@ __device__ void atomicAdd(megdnn::dt_float16 *, megdnn::dt_float16) {
     __trap();
     ((int*)0)[0] = 1;
 }
+
+__device__ void atomicAdd(megdnn::dt_bfloat16 *, megdnn::dt_bfloat16) {
+    __trap();
+    ((int*)0)[0] = 1;
+}
 #endif
 
 __device__ void atomicAdd(megdnn::dt_int8 *, megdnn::dt_int8) {
diff --git a/dnn/src/cuda/matrix_mul/algos.cpp b/dnn/src/cuda/matrix_mul/algos.cpp
index 9ef84c06c28be4798dfc01200a0b3b6e1e19450d..38598335e01845d3e9901e16d89ec68ed478defd 100644
--- a/dnn/src/cuda/matrix_mul/algos.cpp
+++ b/dnn/src/cuda/matrix_mul/algos.cpp
@@ -29,6 +29,10 @@ MatrixMulForwardImpl::AlgoPack::AlgoPack() {
     all_algos.push_back(&cublas_lt);
 #endif
     all_algos.push_back(&naive);
+#if !MEGDNN_DISABLE_FLOAT16
+    cublas_bfloat16 = std::make_unique<AlgoBFloat16>(&cublas);
+    all_algos.push_back(cublas_bfloat16.get());
+#endif
 }
 
 MatrixMulForwardImpl::AlgoPack MatrixMulForwardImpl::sm_algo_pack;
diff --git a/dnn/src/cuda/matrix_mul/algos.h b/dnn/src/cuda/matrix_mul/algos.h
index 8ea190e99590399a88a5c3baa275280570e82f06..56e968c065025af51accb20f198c779e733016c0 100644
--- a/dnn/src/cuda/matrix_mul/algos.h
+++ b/dnn/src/cuda/matrix_mul/algos.h
@@ -15,6 +15,7 @@
 #include "src/cuda/matrix_mul/opr_impl.h"
 #include <cuda.h>
+#include <memory>
 #if CUDA_VERSION >= 10010
 #include <cublasLt.h>
 #endif
@@ -140,6 +141,24 @@ public:
     bool is_reproducible() const override { return true; }
 };
 
+#if !MEGDNN_DISABLE_FLOAT16
+class MatrixMulForwardImpl::AlgoBFloat16 final : public AlgoBase {
+public:
+    AlgoBFloat16(MatrixMulForwardImpl::AlgoBase*);
+    bool is_available(const SizeArgs& args) const override;
+    size_t get_workspace_in_bytes(const SizeArgs& args) const override;
+    const char* name() const override { return m_name.c_str(); }
+    void exec(const ExecArgs& args) const override;
+    bool is_reproducible() const override { return true; }
+
+private:
+    MatrixMulForwardImpl::AlgoBase* m_algorithm = nullptr;
+    std::string m_name;
+    WorkspaceBundle get_workspace_bundle(void* ptr, const SizeArgs& args) const;
+    SizeArgs float_args(const SizeArgs& args) const;
+};
+#endif
+
 class MatrixMulForwardImpl::AlgoPack {
     AlgoPack(const AlgoPack&) = delete;
     AlgoPack& operator=(const AlgoPack&) = delete;
@@ -154,7 +173,9 @@ public:
 #if CUDA_VERSION >= 10010
     AlgoCuBlasLt cublas_lt;
 #endif
-
+#if !MEGDNN_DISABLE_FLOAT16
+    std::unique_ptr<AlgoBFloat16> cublas_bfloat16;
+#endif
     std::vector<AlgoBase*> all_algos;
 };
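The registration above constructs the new algorithm around &cublas: AlgoBFloat16 owns no kernel of its own but forwards to an existing float32 algorithm after dtype conversion, which is why its reported name is prefixed with the wrapped algorithm's name. A minimal sketch of that wrapper pattern; Algo, F32Matmul, and BF16Wrapper are invented illustrations, not megdnn types:

#include <cstdio>
#include <string>

struct Algo {
    virtual ~Algo() = default;
    virtual std::string name() const = 0;
    virtual void exec() const = 0;
};

struct F32Matmul final : Algo {
    std::string name() const override { return "CUBLAS"; }
    void exec() const override { std::printf("f32 matmul\n"); }
};

// Wrapper: converts operands, delegates, converts back.
struct BF16Wrapper final : Algo {
    explicit BF16Wrapper(const Algo* inner) : m_inner(inner) {}
    std::string name() const override {
        return "MATMUL_BFLOAT16:" + m_inner->name();
    }
    void exec() const override {
        std::printf("bf16 -> f32 convert\n");
        m_inner->exec();  // reuse the float32 implementation unchanged
        std::printf("f32 -> bf16 convert\n");
    }
    const Algo* m_inner;
};

int main() {
    F32Matmul cublas;
    BF16Wrapper bf16(&cublas);
    std::printf("%s\n", bf16.name().c_str());
    bf16.exec();
}

The real implementation of this wrapper follows immediately below.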
diff --git a/dnn/src/cuda/matrix_mul/bfloat16.cpp b/dnn/src/cuda/matrix_mul/bfloat16.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..7d97c21ea0400a0573dc768d9c2c3b26d34e0ae2
--- /dev/null
+++ b/dnn/src/cuda/matrix_mul/bfloat16.cpp
@@ -0,0 +1,91 @@
+/**
+ * \file dnn/src/cuda/matrix_mul/bfloat16.cpp
+ * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
+ *
+ * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ */
+
+#include "src/cuda/handle.h"
+#include "src/cuda/matrix_mul/algos.h"
+#include "src/cuda/utils.h"
+
+using namespace megdnn;
+using namespace cuda;
+
+MatrixMulForwardImpl::AlgoBFloat16::AlgoBFloat16(
+        MatrixMulForwardImpl::AlgoBase* algorithm)
+        : m_algorithm(algorithm) {
+    megdnn_assert_internal(algorithm);
+    m_name = ssprintf("MATMUL_BFLOAT16:%s", m_algorithm->name());
+}
+
+MatrixMulForwardImpl::AlgoBase::SizeArgs
+MatrixMulForwardImpl::AlgoBFloat16::float_args(const SizeArgs& args) const {
+    auto new_args = args;
+    auto change_dtype = [](TensorLayout& layout) {
+        if (layout.dtype == dtype::BFloat16()) {
+            layout.dtype = dtype::Float32();
+        }
+    };
+    change_dtype(new_args.layout_a);
+    change_dtype(new_args.layout_b);
+    change_dtype(new_args.layout_c);
+    return new_args;
+}
+
+bool MatrixMulForwardImpl::AlgoBFloat16::is_available(
+        const SizeArgs& args) const {
+    auto fargs = float_args(args);
+    return args.layout_a.dtype == dtype::BFloat16() &&
+           m_algorithm->is_available(fargs);
+}
+
+WorkspaceBundle MatrixMulForwardImpl::AlgoBFloat16::get_workspace_bundle(
+        void* ptr, const SizeArgs& args) const {
+    auto fargs = float_args(args);
+    SmallVector<size_t> sizes;
+    auto get_workspace = [&sizes](const TensorLayout& src) {
+        TensorLayout dst = src;
+        if (dst.dtype == dtype::BFloat16()) {
+            dst.dtype = dtype::Float32();
+            sizes.push_back(dst.span().dist_byte());
+        }
+    };
+    get_workspace(args.layout_a);
+    get_workspace(args.layout_b);
+    get_workspace(args.layout_c);
+    sizes.push_back(m_algorithm->get_workspace_in_bytes(fargs));
+    return {ptr, std::move(sizes)};
+}
+
+size_t MatrixMulForwardImpl::AlgoBFloat16::get_workspace_in_bytes(
+        const SizeArgs& args) const {
+    return get_workspace_bundle(nullptr, args).total_size_in_bytes();
+}
+
+void MatrixMulForwardImpl::AlgoBFloat16::exec(const ExecArgs& args) const {
+    TensorND a = args.tensor_a;
+    TensorND b = args.tensor_b;
+    TensorND c = args.tensor_c;
+    auto bundle = get_workspace_bundle(args.workspace.raw_ptr, args);
+    auto ctypecvt = CompTypeCvter<dtype::BFloat16, dtype::Float32>(
+            args.opr->handle(), &bundle);
+    ctypecvt.src_to_comp_type(args.tensor_a, a)
+            .src_to_comp_type(args.tensor_b, b)
+            .src_to_comp_type(args.tensor_c, c);
+    {
+        auto matmul_opr =
+                args.opr->handle()->create_operator<MatrixMulForward>();
+        matmul_opr->param() = args.opr->param();
+        matmul_opr->param().compute_mode = Param::ComputeMode::DEFAULT;
+        matmul_opr->execution_policy() = {m_algorithm};
+        matmul_opr->exec(a, b, c, ctypecvt.workspace());
+    }
+    ctypecvt.comp_to_dst_type(c, args.tensor_c);
+}
+
+// vim: syntax=cpp.doxygen
diff --git a/dnn/src/cuda/matrix_mul/cublas_lt.cpp b/dnn/src/cuda/matrix_mul/cublas_lt.cpp
index 42f6bca22f741168c245940e48aaf3efdf724ee7..4764e91ea9309ee85fba46ec0dcd264ba55401f9 100644
--- a/dnn/src/cuda/matrix_mul/cublas_lt.cpp
+++ b/dnn/src/cuda/matrix_mul/cublas_lt.cpp
@@ -18,10 +18,11 @@ using namespace megdnn;
 using namespace cuda;
 
 bool MatrixMulForwardImpl::AlgoCuBlasLt::is_available(
-        const SizeArgs &args) const {
+        const SizeArgs& args) const {
     if (args.opr->param().format != param::MatrixMul::Format::DEFAULT)
         return false;
-    if (args.layout_a.dtype.enumv() == DTypeEnum::Quantized4Asymm)
+    if (args.layout_a.dtype.enumv() == DTypeEnum::Quantized4Asymm ||
+        args.layout_a.dtype.enumv() == DTypeEnum::BFloat16)
         return false;
     CUBLASLTMatmulDesc::SizeArgs ltArgs(args);
     return CUBLASLTMatmulDesc(ltArgs).is_available(ltArgs, INT_MAX);
diff --git a/dnn/src/cuda/matrix_mul/opr_impl.h b/dnn/src/cuda/matrix_mul/opr_impl.h
index b7ea9361761af0e5a44eb1f3e1ca14e71a026bdc..cc75bd9db93153df9dea4d9cd908be8b1573e58f 100644
--- a/dnn/src/cuda/matrix_mul/opr_impl.h
+++ b/dnn/src/cuda/matrix_mul/opr_impl.h
@@ -47,6 +47,9 @@ public:
     class AlgoCuBlasLt;
 #endif
     class AlgoNaive;
+#if !MEGDNN_DISABLE_FLOAT16
+    class AlgoBFloat16;
+#endif
     class AlgoPack;
 
     static const AlgoPack& algo_pack() {
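Both the matmul wrapper above and the pooling changes below account for workspace the same way: one float32 shadow buffer per BFloat16 operand, plus whatever the wrapped kernel itself needs, summed with a null base pointer at query time. A sketch of that accounting; the 512-byte chunk alignment here is a hypothetical choice for illustration, not a value taken from this diff:

#include <cstddef>
#include <cstdio>
#include <vector>

// Sum the float32 shadow buffers (one per BFloat16 tensor) plus the inner
// algorithm's own scratch; get_workspace_in_bytes() is exactly this sum,
// computed before any memory exists.
size_t total_workspace(const std::vector<size_t>& f32_shadow_bytes,
                       size_t inner_algo_bytes) {
    size_t total = 0;
    for (size_t s : f32_shadow_bytes)
        total += (s + 511) & ~static_cast<size_t>(511);  // round up each chunk
    return total + inner_algo_bytes;
}

int main() {
    // e.g. a 128x128 matmul: three f32 shadow tensors (65536 B each)
    // plus 1 KiB of scratch for the wrapped algorithm.
    std::printf("%zu bytes total\n",
                total_workspace({65536, 65536, 65536}, 1024));
}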
diff --git a/dnn/src/cuda/pooling/opr_impl.cpp b/dnn/src/cuda/pooling/opr_impl.cpp
index 3d9b351ea0b1e6b0a454288fc2530a7305cd159c..34ced0a81e87d5445654f8a71a626099d8dd0519 100644
--- a/dnn/src/cuda/pooling/opr_impl.cpp
+++ b/dnn/src/cuda/pooling/opr_impl.cpp
@@ -10,59 +10,83 @@
  */
 #include "src/cuda/pooling/opr_impl.h"
-#include "src/cuda/utils.h"
 #include "./pooling2d_int8_cdiv4hwn4.cuh"
+#include "src/cuda/utils.h"
 
 namespace megdnn {
 namespace cuda {
 
-void PoolingForwardImpl::setup_descs(const TensorLayout &src,
-                                     const TensorLayout &dst)
-{
+void PoolingForwardImpl::setup_descs(const TensorLayout& src,
+                                     const TensorLayout& dst) {
     src_desc.set(src, param().format);
     dst_desc.set(dst, param().format);
     pooling_desc.set(this->param());
 }
 
-void PoolingForwardImpl::exec(_megdnn_tensor_in src,
-                              _megdnn_tensor_out dst,
-                              _megdnn_workspace workspace)
-{
-    check_exec(src.layout, dst.layout, workspace.size);
-    using Format = param::Pooling::Format;
-    if (param().format == Format::CHWN4) {
-        pooling2d::Param kern_param;
-        size_t c = src.layout[0], hi = src.layout[1], wi = src.layout[2],
-               n = src.layout[3], ho = dst.layout[1], wo = dst.layout[2];
-        c = c * 4;
-        size_t ph = param().pad_h, pw = param().pad_w;
-        size_t window_h = param().window_h, window_w = param().window_w;
-        size_t sh = param().stride_h, sw = param().stride_w;
-        kern_param.n = n, kern_param.c = c, kern_param.hi = hi,
-        kern_param.wi = wi, kern_param.ho = ho, kern_param.wo = wo,
-        kern_param.ph = ph, kern_param.pw = pw, kern_param.window_h = window_h,
-        kern_param.window_w = window_w, kern_param.sh = sh, kern_param.sw = sw;
-        auto&& stream = cuda_stream(handle());
-        return pooling2d::_do_pooling2d_int8_cdiv4hwn4(
-                src.compatible_ptr<int8_t>(), dst.compatible_ptr<int8_t>(),
-                kern_param, stream, static_cast<uint32_t>(param().mode));
+WorkspaceBundle PoolingForwardImpl::get_workspace_bundle(
+        void* ptr, const TensorLayout& src, const TensorLayout& dst) const {
+    SmallVector<size_t> sizes;
+    TensorLayout fsrc = src;
+    TensorLayout fdst = dst;
+    auto get_workspace = [&sizes](TensorLayout& layout) {
+        if (layout.dtype == dtype::BFloat16()) {
+            layout.dtype = dtype::Float32();
+            sizes.push_back(layout.span().dist_byte());
+        }
+    };
+    get_workspace(fsrc);
+    get_workspace(fdst);
+    return {ptr, std::move(sizes)};
+}
+
+void PoolingForwardImpl::exec(_megdnn_tensor_in ssrc, _megdnn_tensor_out sdst,
+                              _megdnn_workspace sworkspace) {
+    check_exec(ssrc.layout, sdst.layout, sworkspace.size);
+    TensorND src = ssrc;
+    TensorND dst = sdst;
+    auto wsb =
+            get_workspace_bundle(sworkspace.raw_ptr, ssrc.layout, sdst.layout);
+    auto ctypecvt = CompTypeCvter<dtype::BFloat16, dtype::Float32>(
+            concrete_handle(this->handle()), &wsb);
+    if (ssrc.layout.dtype.enumv() == DTypeTrait<dtype::BFloat16>::enumv) {
+        ctypecvt.src_to_comp_type(ssrc, src).src_to_comp_type(sdst, dst);
+    }
+    {
+        using Format = param::Pooling::Format;
+        if (param().format == Format::CHWN4) {
+            pooling2d::Param kern_param;
+            size_t c = src.layout[0], hi = src.layout[1], wi = src.layout[2],
+                   n = src.layout[3], ho = dst.layout[1], wo = dst.layout[2];
+            c = c * 4;
+            size_t ph = param().pad_h, pw = param().pad_w;
+            size_t window_h = param().window_h, window_w = param().window_w;
+            size_t sh = param().stride_h, sw = param().stride_w;
+            kern_param.n = n, kern_param.c = c, kern_param.hi = hi,
+            kern_param.wi = wi, kern_param.ho = ho, kern_param.wo = wo,
+            kern_param.ph = ph, kern_param.pw = pw,
+            kern_param.window_h = window_h, kern_param.window_w = window_w,
+            kern_param.sh = sh, kern_param.sw = sw;
+            auto&& stream = cuda_stream(handle());
+            return pooling2d::_do_pooling2d_int8_cdiv4hwn4(
+                    src.compatible_ptr<int8_t>(), dst.compatible_ptr<int8_t>(),
+                    kern_param, stream, static_cast<uint32_t>(param().mode));
+        }
+        auto handle = cudnn_handle(this->handle());
+        setup_descs(src.layout, dst.layout);
+        dt_float32 alpha = 1.0f, beta = 0.0f;
+        cudnn_check(cudnnPoolingForward(handle, pooling_desc.desc, &alpha,
+                                        src_desc.desc, src.raw_ptr, &beta,
+                                        dst_desc.desc, dst.raw_ptr));
+    }
+    if (ssrc.layout.dtype.enumv() == DTypeTrait<dtype::BFloat16>::enumv) {
+        ctypecvt.comp_to_dst_type(dst, sdst);
     }
-    auto handle = cudnn_handle(this->handle());
-    setup_descs(src.layout, dst.layout);
-    dt_float32 alpha = 1.0f, beta = 0.0f;
-    cudnn_check(cudnnPoolingForward(handle,
-                                    pooling_desc.desc,
-                                    &alpha,
-                                    src_desc.desc, src.raw_ptr,
-                                    &beta,
-                                    dst_desc.desc, dst.raw_ptr));
 }
 
-void PoolingBackwardImpl::setup_descs(const TensorLayout &src,
-                                      const TensorLayout &dst,
-                                      const TensorLayout &diff,
-                                      const TensorLayout &grad)
-{
+void PoolingBackwardImpl::setup_descs(const TensorLayout& src,
+                                      const TensorLayout& dst,
+                                      const TensorLayout& diff,
+                                      const TensorLayout& grad) {
     src_desc.set(src);
     dst_desc.set(dst);
     diff_desc.set(diff);
@@ -70,27 +94,62 @@ void PoolingBackwardImpl::setup_descs(const TensorLayout &src,
     pooling_desc.set(this->param());
 }
 
-void PoolingBackwardImpl::exec(_megdnn_tensor_in src,
-                               _megdnn_tensor_in dst,
-                               _megdnn_tensor_in diff,
-                               _megdnn_tensor_out grad,
-                               _megdnn_workspace workspace)
-{
-    check_exec(src.layout, dst.layout, diff.layout, grad.layout, workspace.size);
+WorkspaceBundle PoolingBackwardImpl::get_workspace_bundle(
+        void* ptr, const TensorLayout& src, const TensorLayout& dst,
+        const TensorLayout& diff, const TensorLayout& grad) const {
+    SmallVector<size_t> sizes;
+    TensorLayout fsrc = src;
+    TensorLayout fdst = dst;
+    TensorLayout fdiff = diff;
+    TensorLayout fgrad = grad;
+    auto get_workspace = [&sizes](TensorLayout& layout) {
+        if (layout.dtype == dtype::BFloat16()) {
+            layout.dtype = dtype::Float32();
+            sizes.push_back(layout.span().dist_byte());
+        }
+    };
+    get_workspace(fsrc);
+    get_workspace(fdst);
+    get_workspace(fdiff);
+    get_workspace(fgrad);
+    return {ptr, std::move(sizes)};
+}
+
+void PoolingBackwardImpl::exec(_megdnn_tensor_in ssrc, _megdnn_tensor_in sdst,
+                               _megdnn_tensor_in sdiff,
+                               _megdnn_tensor_out sgrad,
+                               _megdnn_workspace sworkspace) {
+    check_exec(ssrc.layout, sdst.layout, sdiff.layout, sgrad.layout,
+               sworkspace.size);
     auto handle = cudnn_handle(this->handle());
-    setup_descs(src.layout, dst.layout, diff.layout, grad.layout);
-    float alpha = 1.0f, beta = 0.0f;
-    cudnn_check(cudnnPoolingBackward(handle,
-                                     pooling_desc.desc,
-                                     &alpha,
-                                     dst_desc.desc, dst.raw_ptr,
-                                     diff_desc.desc, diff.raw_ptr,
-                                     src_desc.desc, src.raw_ptr,
-                                     &beta,
+    TensorND src = ssrc;
+    TensorND dst = sdst;
+    TensorND diff = sdiff;
+    TensorND grad = sgrad;
+    auto wsb = get_workspace_bundle(sworkspace.raw_ptr, ssrc.layout,
+                                    sdst.layout, sdiff.layout, sgrad.layout);
+    auto ctypecvt = CompTypeCvter<dtype::BFloat16, dtype::Float32>(
+            concrete_handle(this->handle()), &wsb);
+    if (ssrc.layout.dtype.enumv() == DTypeTrait<dtype::BFloat16>::enumv) {
+        ctypecvt.src_to_comp_type(ssrc, src)
+                .src_to_comp_type(sdst, dst)
+                .src_to_comp_type(sdiff, diff)
+                .src_to_comp_type(sgrad, grad);
+    }
+    {
+        setup_descs(src.layout, dst.layout, diff.layout, grad.layout);
+        float alpha = 1.0f, beta = 0.0f;
+        cudnn_check(cudnnPoolingBackward(
+                handle, pooling_desc.desc, &alpha, dst_desc.desc, dst.raw_ptr,
+                diff_desc.desc, diff.raw_ptr, src_desc.desc, src.raw_ptr, &beta,
                 grad_desc.desc, grad.raw_ptr));
+    }
+    if (ssrc.layout.dtype.enumv() == DTypeTrait<dtype::BFloat16>::enumv) {
+        ctypecvt.comp_to_dst_type(grad, sgrad);
+    }
 }
 
-} // namespace cuda
-} // namespace megdnn
+}  // namespace cuda
+}  // namespace megdnn
 
 // vim: syntax=cpp.doxygen
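These fallback paths are viable because the dtype conversion they pivot on is nearly free: bfloat16 is the upper 16 bits of an IEEE-754 float32, so widening is a shift and narrowing is a round-to-nearest-even on the 16 discarded bits, matching the BFLOAT16_ROUND_TIES_TO_EVEN policy of the header. A host-side sketch of both directions (NaN handling omitted for brevity):

#include <cstdint>
#include <cstdio>
#include <cstring>

// Widen: place the 16 stored bits in the top half of a float32 word.
float bf16_bits_to_f32(uint16_t b) {
    uint32_t u = uint32_t(b) << 16;
    float f;
    std::memcpy(&f, &u, sizeof(f));
    return f;
}

// Narrow: round the low 16 bits to nearest, ties to even, then truncate.
uint16_t f32_to_bf16_bits(float f) {
    uint32_t u;
    std::memcpy(&u, &f, sizeof(u));
    uint32_t rounding = 0x7fff + ((u >> 16) & 1);  // ties go to even lsb
    return uint16_t((u + rounding) >> 16);
}

int main() {
    float x = 3.14159f;
    uint16_t b = f32_to_bf16_bits(x);
    std::printf("%f -> 0x%04x -> %f\n", x, unsigned(b), bf16_bits_to_f32(b));
}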
TensorLayout &, - const TensorLayout &, - const TensorLayout &, - const TensorLayout &) override { - return 0; - } - private: - TensorDesc src_desc, dst_desc, diff_desc, grad_desc; - PoolingDesc pooling_desc; - void setup_descs(const TensorLayout &src, - const TensorLayout &dst, - const TensorLayout &diff, - const TensorLayout &grad); +class PoolingBackwardImpl final : public PoolingBackward { +public: + using PoolingBackward::PoolingBackward; + void exec(_megdnn_tensor_in src, _megdnn_tensor_in dst, + _megdnn_tensor_in diff, _megdnn_tensor_out grad, + _megdnn_workspace workspace) override; + size_t get_workspace_in_bytes(const TensorLayout& src, + const TensorLayout& dst, + const TensorLayout& diff, + const TensorLayout& grad) override { + return get_workspace_bundle(nullptr, src, dst, diff, grad) + .total_size_in_bytes(); + } +private: + TensorDesc src_desc, dst_desc, diff_desc, grad_desc; + PoolingDesc pooling_desc; + void setup_descs(const TensorLayout& src, const TensorLayout& dst, + const TensorLayout& diff, const TensorLayout& grad); + WorkspaceBundle get_workspace_bundle(void* ptr, const TensorLayout& src, + const TensorLayout& dst, + const TensorLayout& diff, + const TensorLayout& grad) const; }; } // namespace cuda diff --git a/dnn/src/cuda/roi_align/roi_align.cu b/dnn/src/cuda/roi_align/roi_align.cu index 3f368334be7516095a176b3bd4bb542d16837bbe..de3ac56c10ec85e73f055ded0919ccea0cf00a81 100644 --- a/dnn/src/cuda/roi_align/roi_align.cu +++ b/dnn/src/cuda/roi_align/roi_align.cu @@ -175,6 +175,7 @@ void backward_proxy(const int nthreads, const T* top_diff, const int, const T*, T*, cudaStream_t); INST(dt_float32) INST(dt_float16) +INST(dt_bfloat16) #undef INST } // namespace roi_align diff --git a/dnn/src/cuda/roi_pooling/roi_pooling.cu b/dnn/src/cuda/roi_pooling/roi_pooling.cu index cda81366fcfc466bdbd4a57e7e56dc60e70202bd..6af2bc217e080a3fe4dd8bc0e869659eca9635a7 100644 --- a/dnn/src/cuda/roi_pooling/roi_pooling.cu +++ b/dnn/src/cuda/roi_pooling/roi_pooling.cu @@ -211,6 +211,7 @@ void backward_proxy(const int nthreads, const T* top_diff, T*, const T*, cudaStream_t); INST(dt_float32) INST(dt_float16) +INST(dt_bfloat16) #undef INST } // namespace roi_pooling diff --git a/dnn/src/cuda/type_cvt/kern.cu b/dnn/src/cuda/type_cvt/kern.cu index 0cab603992cac420c5282476012e947709031a16..dd52b8c7752709c0760377c6f01c696037e3a4f0 100644 --- a/dnn/src/cuda/type_cvt/kern.cu +++ b/dnn/src/cuda/type_cvt/kern.cu @@ -242,6 +242,7 @@ void typecvt_kern_n2n(const TensorND& dest, const TensorND& src, cb(dtype_src, dt_uint8) \ cb(dtype_src, dt_float32) \ cb(dtype_src, dt_float16) \ + cb(dtype_src, dt_bfloat16) \ #define MEGDNN_FOREACH_QUANTIZED_DTYPE_WITH_DTYPE_SRC(dtype_src, cb) \ cb(dtype_src, dt_quint8) \ @@ -263,6 +264,7 @@ void typecvt_kern_n2n(const TensorND& dest, const TensorND& src, cb(dt_uint8) \ cb(dt_float32) \ cb(dt_float16) \ + cb(dt_bfloat16) \ #define MEGDNN_FOREACH_QUANTIZED_CTYPE(cb) \ cb(dt_quint8) \ diff --git a/dnn/src/cuda/utils.cuh b/dnn/src/cuda/utils.cuh index c0348c6cece8aad24d05c321d94af372bcdde86b..cccf6495167973f66eb6c611cc03287434f792bf 100644 --- a/dnn/src/cuda/utils.cuh +++ b/dnn/src/cuda/utils.cuh @@ -234,6 +234,29 @@ MEGDNN_DEVICE void atomic_add(dt_float16* address, dt_float16 val) { #endif } +template <> +MEGDNN_DEVICE void atomic_add(dt_bfloat16* address, dt_bfloat16 val) { + unsigned int* address_as_ui = reinterpret_cast( + reinterpret_cast(address) - + (reinterpret_cast(address) & 2)); + unsigned int old = *address_as_ui; + unsigned int assumed; + + do { + assumed = 
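The `INST(dt_bfloat16)` lines and the new `cb(dtype_src, dt_bfloat16)` entries are the same X-macro idiom: the dtype list is written once as a macro over a callback, and each expansion site decides what every entry becomes, an explicit template instantiation in the `.cu` files, a switch case or kernel table elsewhere. A toy version with hypothetical names (`FOREACH_CTYPE`, `kern`):

```cpp
#include <cstddef>

// The ctype list exists in exactly one place...
#define FOREACH_CTYPE(cb) \
    cb(float)             \
    cb(double)            \
    cb(long double)

template <typename T>
void kern(const T* src, T* dst, std::size_t n) {
    for (std::size_t i = 0; i < n; ++i) dst[i] = src[i] + src[i];
}

// ...and each user supplies its own callback, like INST(dt_bfloat16) above.
#define INST(ct) template void kern<ct>(const ct*, ct*, std::size_t);
FOREACH_CTYPE(INST)
#undef INST
```

Supporting a new dtype then reduces to appending one `cb(...)` line per list, which is most of what this diff does outside the kernels themselves.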
old; + unsigned short data = reinterpret_cast(address) & 2 + ? (old >> 16) + : (old & 0xffff); + dt_bfloat16 hsum = *reinterpret_cast(&data); + hsum += val; + data = *reinterpret_cast(&hsum); + old = reinterpret_cast(address) & 2 + ? (old & 0xffff) | (data << 16) + : (old & 0xffff0000) | data; + old = ::atomicCAS(address_as_ui, assumed, old); + } while (assumed != old); +} + static inline MEGDNN_DEVICE void dot_prod(int a, int b, int c, int& d) { #if __CUDA_ARCH__ >= 610 // clang-format off diff --git a/dnn/src/cuda/warp_perspective/backward_data.cpp b/dnn/src/cuda/warp_perspective/backward_data.cpp index 9b400514dc1b2f413c0395d5cc6f9b0ebbef5684..272364a1909250e8c65a3802d7c8fe54a5c67ed3 100644 --- a/dnn/src/cuda/warp_perspective/backward_data.cpp +++ b/dnn/src/cuda/warp_perspective/backward_data.cpp @@ -17,71 +17,101 @@ namespace megdnn { namespace cuda { -void WarpPerspectiveBackwardDataImpl::exec(_megdnn_tensor_in mat, - _megdnn_tensor_in diff, - _megdnn_tensor_out grad, - _megdnn_workspace workspace) -{ - check_exec(mat.layout, diff.layout, grad.layout, workspace.size); - auto stream = cuda_stream(this->handle()); - auto N = grad.layout.shape[0], - C = grad.layout.shape[1], - IH = grad.layout.shape[2], - IW = grad.layout.shape[3], - OH = diff.layout.shape[2], - OW = diff.layout.shape[3]; - auto bval = param().border_val; - auto bmode = warp_perspective::get_bmode(param().bmode); +WorkspaceBundle WarpPerspectiveBackwardDataImpl::get_workspace_bundle( + void* ptr, const TensorLayout& mat, const TensorLayout& diff, + const TensorLayout& grad) const { + SmallVector sizes; + TensorLayout fmat = mat; + TensorLayout fdiff = diff; + TensorLayout fgrad = grad; + auto get_workspace = [&sizes](TensorLayout& layout) { + if (layout.dtype == dtype::BFloat16()) { + layout.dtype = dtype::Float32(); + sizes.push_back(layout.span().dist_byte()); + } + }; + get_workspace(fmat); + get_workspace(fdiff); + get_workspace(fgrad); + sizes.push_back(get_float32_workspace_in_bytes(fmat, fdiff, fgrad)); + return {ptr, std::move(sizes)}; +} - size_t batch_x_channel_size = N * C; - size_t max_batch_x_channel = max_batch_x_channel_size(); - if(batch_x_channel_size <= max_batch_x_channel) { - warp_perspective::backward_data_proxy( - mat.ptr(), - diff.ptr(), - grad.ptr(), - reinterpret_cast(workspace.raw_ptr), - N, C, IH, IW, OH, OW, bval, - bmode, stream); - } else { - dt_float32* mat_ptr = mat.ptr(); - dt_float32* diff_ptr = diff.ptr(); - dt_float32* grad_ptr = grad.ptr(); - size_t max_batch_size = max_batch_x_channel / C; - while (N > 0){ - size_t curr_batch_size = N > max_batch_size ? 
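`atomicCAS` only exists for 32- and 64-bit words, so the new `atomic_add` specialization for `dt_bfloat16` does its read-modify-write through the aligned 32-bit word containing the 16-bit element and retries until no other thread has raced it. A self-contained CUDA sketch of the same technique, with the template arguments of the casts written out and a plain `bf16` struct (truncating conversions) standing in for `dt_bfloat16`:

```cuda
#include <cstdint>

struct bf16 { unsigned short bits; };  // stand-in for dt_bfloat16

__device__ float bf16_to_f32(bf16 x) {
    return __uint_as_float(static_cast<unsigned int>(x.bits) << 16);
}

__device__ bf16 f32_to_bf16(float f) {  // truncating; the real dtype rounds
    return bf16{static_cast<unsigned short>(__float_as_uint(f) >> 16)};
}

__device__ void atomic_add_bf16(bf16* address, bf16 val) {
    // Round the address down to the containing 4-byte word; bf16 is
    // 2-byte aligned, so the element is either the low or the high half.
    unsigned int* word = reinterpret_cast<unsigned int*>(
            reinterpret_cast<std::uintptr_t>(address) & ~std::uintptr_t(3));
    bool hi = reinterpret_cast<std::uintptr_t>(address) & 2;
    unsigned int old = *word, assumed;
    do {
        assumed = old;
        unsigned short cur = hi ? (old >> 16) : (old & 0xffffu);
        bf16 sum = f32_to_bf16(bf16_to_f32(bf16{cur}) + bf16_to_f32(val));
        unsigned int repl =
                hi ? (old & 0x0000ffffu) |
                             (static_cast<unsigned int>(sum.bits) << 16)
                   : (old & 0xffff0000u) | sum.bits;
        old = atomicCAS(word, assumed, repl);  // returns the word's real value
    } while (assumed != old);  // a racing writer costs one more iteration
}
```

The loop converges because `atomicCAS` hands back the word's actual contents, so contention never loses an update, it only forces another pass.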
max_batch_size : N; +void WarpPerspectiveBackwardDataImpl::exec(_megdnn_tensor_in smat, + _megdnn_tensor_in sdiff, + _megdnn_tensor_out sgrad, + _megdnn_workspace sworkspace) { + check_exec(smat.layout, sdiff.layout, sgrad.layout, sworkspace.size); + TensorND mat = smat; + TensorND diff = sdiff; + TensorND grad = sgrad; + auto bundle = get_workspace_bundle(sworkspace.raw_ptr, smat.layout, + sdiff.layout, sgrad.layout); + auto ctypecvt = CompTypeCvter( + concrete_handle(this->handle()), &bundle); + if (sgrad.layout.dtype.enumv() == DTypeTrait::enumv) { + ctypecvt.src_to_comp_type(smat, mat) + .src_to_comp_type(sdiff, diff) + .src_to_comp_type(sgrad, grad); + } + { + auto workspace = ctypecvt.workspace(); + auto stream = cuda_stream(this->handle()); + auto N = grad.layout.shape[0], C = grad.layout.shape[1], + IH = grad.layout.shape[2], IW = grad.layout.shape[3], + OH = diff.layout.shape[2], OW = diff.layout.shape[3]; + auto bval = param().border_val; + auto bmode = warp_perspective::get_bmode(param().bmode); + + size_t batch_x_channel_size = N * C; + size_t max_batch_x_channel = max_batch_x_channel_size(); + if (batch_x_channel_size <= max_batch_x_channel) { warp_perspective::backward_data_proxy( - mat_ptr, diff_ptr, grad_ptr, - reinterpret_cast(workspace.raw_ptr), - curr_batch_size, C, IH, IW, OH, OW, bval, - bmode, stream); + mat.ptr(), diff.ptr(), + grad.ptr(), + reinterpret_cast(workspace.raw_ptr), N, C, IH, IW, + OH, OW, bval, bmode, stream); + } else { + dt_float32* mat_ptr = mat.ptr(); + dt_float32* diff_ptr = diff.ptr(); + dt_float32* grad_ptr = grad.ptr(); + size_t max_batch_size = max_batch_x_channel / C; + while (N > 0) { + size_t curr_batch_size = + N > max_batch_size ? max_batch_size : N; + warp_perspective::backward_data_proxy( + mat_ptr, diff_ptr, grad_ptr, + reinterpret_cast(workspace.raw_ptr), + curr_batch_size, C, IH, IW, OH, OW, bval, bmode, + stream); - if( N <= max_batch_size) { - break; - } - else { - N -= max_batch_size; - mat_ptr += curr_batch_size * mat.layout.stride[0]; - diff_ptr += curr_batch_size * diff.layout.stride[0]; - grad_ptr += curr_batch_size * grad.layout.stride[0]; + if (N <= max_batch_size) { + break; + } else { + N -= max_batch_size; + mat_ptr += curr_batch_size * mat.layout.stride[0]; + diff_ptr += curr_batch_size * diff.layout.stride[0]; + grad_ptr += curr_batch_size * grad.layout.stride[0]; + } } } } + if (sgrad.layout.dtype.enumv() == DTypeTrait::enumv) { + ctypecvt.comp_to_dst_type(grad, sgrad); + } } -size_t WarpPerspectiveBackwardDataImpl::get_workspace_in_bytes( - const TensorLayout & /* mat */, - const TensorLayout &diff, - const TensorLayout &grad) -{ - auto N = grad.shape[0], C = grad.shape[1], - IH = grad.shape[2], IW = grad.shape[3]; +size_t WarpPerspectiveBackwardDataImpl::get_float32_workspace_in_bytes( + const TensorLayout& /* mat */, const TensorLayout& diff, + const TensorLayout& grad) const { + auto N = grad.shape[0], C = grad.shape[1], IH = grad.shape[2], + IW = grad.shape[3]; auto OH = diff.shape[2], OW = diff.shape[3]; auto bmode = warp_perspective::get_bmode(param().bmode); size_t max_batch_size = N; size_t max_batch_x_channel = max_batch_x_channel_size(); - if(N * C > max_batch_x_channel) { + if (N * C > max_batch_x_channel) { max_batch_size = max_batch_x_channel / C; } @@ -90,7 +120,7 @@ size_t WarpPerspectiveBackwardDataImpl::get_workspace_in_bytes( return res; } -} // namespace cuda -} // namespace megdnn +} // namespace cuda +} // namespace megdnn // vim: syntax=cpp.doxygen diff --git 
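The slicing loop above exists because the kernel indexes batch and channel together and that product is capped by whatever `max_batch_x_channel_size()` reports. When `N * C` exceeds the cap, the proxy runs per batch slice and each pointer advances by `curr_batch_size * stride[0]`. The shape of that loop as a standalone sketch, assuming a hypothetical cap and a stub in place of `backward_data_proxy`:

```cpp
#include <algorithm>
#include <cstddef>

constexpr std::size_t kMaxBatchXChannel = 65535;  // hypothetical launch limit

// Toy stand-in for the real kernel launch on one batch slice.
void run_slice(const float* in, float* out, std::size_t n, std::size_t c) {
    for (std::size_t i = 0; i < n * c; ++i) out[i] += in[i];
}

// Assumes C <= kMaxBatchXChannel; this sketch does not re-check that.
void run_in_slices(const float* in, float* out, std::size_t N, std::size_t C,
                   std::size_t in_stride0, std::size_t out_stride0) {
    std::size_t max_batch = kMaxBatchXChannel / C;  // largest N per launch
    while (N > 0) {
        std::size_t cur = std::min(N, max_batch);
        run_slice(in, out, cur, C);
        if (N <= max_batch) break;   // that was the final slice
        N -= cur;                    // cur == max_batch on non-final slices
        in += cur * in_stride0;      // stride[0]: elements per batch item
        out += cur * out_stride0;
    }
}
```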
a/dnn/src/cuda/warp_perspective/backward_mat.cpp b/dnn/src/cuda/warp_perspective/backward_mat.cpp index 2db43d2133c40b0e1afca9971cc7eb6c9d2dc007..48964c2cbe29a859ef9f3b406e441ee598fe64d5 100644 --- a/dnn/src/cuda/warp_perspective/backward_mat.cpp +++ b/dnn/src/cuda/warp_perspective/backward_mat.cpp @@ -17,62 +17,96 @@ namespace megdnn { namespace cuda { -void WarpPerspectiveBackwardMatImpl::exec(_megdnn_tensor_in src, - _megdnn_tensor_in mat, - _megdnn_tensor_in diff, - _megdnn_tensor_out grad, - _megdnn_workspace workspace) -{ - check_exec(src.layout, mat.layout, diff.layout, grad.layout, - workspace.size); - auto stream = cuda_stream(this->handle()); - auto N = src.layout.shape[0], - C = src.layout.shape[1], - IH = src.layout.shape[2], - IW = src.layout.shape[3], - OH = diff.layout.shape[2], - OW = diff.layout.shape[3]; - auto bval = param().border_val; - auto bmode = warp_perspective::get_bmode(param().bmode); +WorkspaceBundle WarpPerspectiveBackwardMatImpl::get_workspace_bundle( + void* ptr, const TensorLayout& src, const TensorLayout& mat, + const TensorLayout& diff, const TensorLayout& grad) const { + SmallVector sizes; + TensorLayout fsrc = src; + TensorLayout fmat = mat; + TensorLayout fdiff = diff; + TensorLayout fgrad = grad; + auto get_workspace = [&sizes](TensorLayout& layout) { + if (layout.dtype == dtype::BFloat16()) { + layout.dtype = dtype::Float32(); + sizes.push_back(layout.span().dist_byte()); + } + }; + get_workspace(fsrc); + get_workspace(fmat); + get_workspace(fdiff); + get_workspace(fgrad); + return {ptr, std::move(sizes)}; +} - size_t batch_x_channel_size = N * C; - size_t max_batch_x_channel = max_batch_x_channel_size(); - if(batch_x_channel_size <= max_batch_x_channel) { - warp_perspective::backward_mat_proxy(src.ptr(), - mat.ptr(), - diff.ptr(), - grad.ptr(), - N, C, IH, IW, OH, OW, bval, - bmode, stream); - } else { - dt_float32* src_ptr = src.ptr(); - dt_float32* mat_ptr = mat.ptr(); - dt_float32* diff_ptr = diff.ptr(); - dt_float32* grad_ptr = grad.ptr(); - size_t max_batch_size = max_batch_x_channel / C; - while (N > 0){ - size_t curr_batch_size = N > max_batch_size ? 
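Each `get_workspace_bundle` in this patch builds its slot list the same way: a small lambda retypes any BFloat16 layout to Float32 and reserves bytes for the converted copy, and any Float32-path workspace (as in backward-data's `get_float32_workspace_in_bytes`) is pushed last. The accounting, sketched with a plain vector in place of `WorkspaceBundle`:

```cpp
#include <cstddef>
#include <initializer_list>
#include <vector>

enum class DType { Float32, BFloat16 };

struct Layout {
    std::size_t nr_elems;
    DType dtype;
    std::size_t span_bytes() const {
        return nr_elems * (dtype == DType::Float32 ? 4 : 2);
    }
};

// Mirrors the get_workspace lambda: only BFloat16 tensors need a slot,
// sized for their Float32 copy; the layout is retyped in place so the
// kernel-facing descriptors are built as Float32.
// e.g. comp_type_slots({&fsrc, &fmat, &fdiff, &fgrad})
std::vector<std::size_t> comp_type_slots(std::initializer_list<Layout*> ls) {
    std::vector<std::size_t> sizes;
    for (Layout* l : ls) {
        if (l->dtype == DType::BFloat16) {
            l->dtype = DType::Float32;
            sizes.push_back(l->span_bytes());
        }
    }
    return sizes;
}
```

With the bundle doing the arithmetic, the header-side `get_workspace_in_bytes` overrides collapse to a single `total_size_in_bytes()` call.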
max_batch_size : N; - warp_perspective::backward_mat_proxy(src_ptr, - mat_ptr, diff_ptr, grad_ptr, - curr_batch_size, C, IH, IW, OH, OW, bval, - bmode, stream); +void WarpPerspectiveBackwardMatImpl::exec(_megdnn_tensor_in ssrc, + _megdnn_tensor_in smat, + _megdnn_tensor_in sdiff, + _megdnn_tensor_out sgrad, + _megdnn_workspace sworkspace) { + check_exec(ssrc.layout, smat.layout, sdiff.layout, sgrad.layout, + sworkspace.size); + TensorND src = ssrc; + TensorND mat = smat; + TensorND diff = sdiff; + TensorND grad = sgrad; + auto bundle = get_workspace_bundle(sworkspace.raw_ptr, ssrc.layout, + smat.layout, sdiff.layout, sgrad.layout); + auto ctypecvt = CompTypeCvter( + concrete_handle(this->handle()), &bundle); + if (ssrc.layout.dtype.enumv() == DTypeTrait::enumv) { + ctypecvt.src_to_comp_type(ssrc, src) + .src_to_comp_type(smat, mat) + .src_to_comp_type(sdiff, diff) + .src_to_comp_type(sgrad, grad); + } + { + auto stream = cuda_stream(this->handle()); + auto N = src.layout.shape[0], C = src.layout.shape[1], + IH = src.layout.shape[2], IW = src.layout.shape[3], + OH = diff.layout.shape[2], OW = diff.layout.shape[3]; + auto bval = param().border_val; + auto bmode = warp_perspective::get_bmode(param().bmode); - if( N <= max_batch_size) { - break; - } - else { - N -= max_batch_size; - src_ptr += curr_batch_size * src.layout.stride[0]; - mat_ptr += curr_batch_size * mat.layout.stride[0]; - diff_ptr += curr_batch_size * diff.layout.stride[0]; - grad_ptr += curr_batch_size * grad.layout.stride[0]; + size_t batch_x_channel_size = N * C; + size_t max_batch_x_channel = max_batch_x_channel_size(); + if (batch_x_channel_size <= max_batch_x_channel) { + warp_perspective::backward_mat_proxy( + src.ptr(), mat.ptr(), + diff.ptr(), grad.ptr(), N, C, IH, + IW, OH, OW, bval, bmode, stream); + } else { + dt_float32* src_ptr = src.ptr(); + dt_float32* mat_ptr = mat.ptr(); + dt_float32* diff_ptr = diff.ptr(); + dt_float32* grad_ptr = grad.ptr(); + size_t max_batch_size = max_batch_x_channel / C; + while (N > 0) { + size_t curr_batch_size = + N > max_batch_size ? 
max_batch_size : N; + warp_perspective::backward_mat_proxy( + src_ptr, mat_ptr, diff_ptr, grad_ptr, curr_batch_size, + C, IH, IW, OH, OW, bval, bmode, stream); + + if (N <= max_batch_size) { + break; + } else { + N -= max_batch_size; + src_ptr += curr_batch_size * src.layout.stride[0]; + mat_ptr += curr_batch_size * mat.layout.stride[0]; + diff_ptr += curr_batch_size * diff.layout.stride[0]; + grad_ptr += curr_batch_size * grad.layout.stride[0]; + } } } } + + if (ssrc.layout.dtype.enumv() == DTypeTrait::enumv) { + ctypecvt.comp_to_dst_type(grad, sgrad); + } } -} // namespace cuda -} // namespace megdnn +} // namespace cuda +} // namespace megdnn // vim: syntax=cpp.doxygen diff --git a/dnn/src/cuda/warp_perspective/forward.cpp b/dnn/src/cuda/warp_perspective/forward.cpp index 5c465ec30b044f7c92f4c4ea1bbe1c1f593635dd..52823132ee4023c9719df96be2a31627cddf660e 100644 --- a/dnn/src/cuda/warp_perspective/forward.cpp +++ b/dnn/src/cuda/warp_perspective/forward.cpp @@ -85,104 +85,158 @@ void warp_perspective_cv_exec(_megdnn_tensor_in src, _megdnn_tensor_in mat, } // namespace warp_perspective -void WarpPerspectiveForwardImpl::exec(_megdnn_tensor_in src, - _megdnn_tensor_in mat, - _megdnn_tensor_in mat_idx, - _megdnn_tensor_out dst, - _megdnn_workspace workspace) { - check_exec_allow_nhwc_mat_idx(src.layout, mat.layout, mat_idx.layout, - dst.layout, workspace.size); - auto stream = cuda_stream(this->handle()); - bool is_nhwc = param().format == param::WarpPerspective::Format::NHWC; - - if (is_nhwc && param().imode != Param::InterpolationMode::LINEAR) { - // use opencv impl only for nhwc and non-linear interp - megdnn_assert(!mat_idx.raw_ptr, - "mat_idx is not supported in NHWC case with " - "non-linear interpolation"); - warp_perspective::warp_perspective_cv_exec( - src, mat, dst, param().border_val, - warp_perspective::get_bmode(param().bmode), - warp_perspective::get_imode(param().imode), workspace, stream); - - return; +WorkspaceBundle WarpPerspectiveForwardImpl::get_workspace_bundle( + void* ptr, const TensorLayout& src, const TensorLayout& mat, + const TensorLayout& mat_idx, const TensorLayout& dst) const { + MEGDNN_MARK_USED_VAR(mat_idx); + SmallVector sizes; + TensorLayout fsrc = src; + TensorLayout fmat = mat; + TensorLayout fdst = dst; + auto get_workspace = [&sizes](TensorLayout& layout) { + if (layout.dtype == dtype::BFloat16()) { + layout.dtype = dtype::Float32(); + sizes.push_back(layout.span().dist_byte()); + } + }; + get_workspace(fsrc); + get_workspace(fmat); + get_workspace(fdst); + if (param().format == param::WarpPerspective::Format::NHWC) { + //! use double for the workspace dtype as float may cause + //! 
accuracy problems + sizes.push_back(mat.total_nr_elems() * sizeof(double)); } - megdnn_assert(warp::is_dnn_available(src.layout, mat.layout, dst.layout, - param().imode, param().format)); - size_t C, IH, IW, OH, OW; - if (is_nhwc) { - C = src.layout.shape[3]; - IH = src.layout.shape[1]; - IW = src.layout.shape[2]; - OH = dst.layout.shape[1]; - OW = dst.layout.shape[2]; - } else if (param().format == Param::Format::NCHW4) { - C = src.layout.shape[1] * 4; - IH = src.layout.shape[2]; - IW = src.layout.shape[3]; - OH = dst.layout.shape[2]; - OW = dst.layout.shape[3]; - } else { - megdnn_assert(param().format == param::WarpPerspective::Format::NCHW, - "invalid warp_perspective format"); - C = src.layout.shape[1]; - IH = src.layout.shape[2]; - IW = src.layout.shape[3]; - OH = dst.layout.shape[2]; - OW = dst.layout.shape[3]; + return {ptr, std::move(sizes)}; +} + +void WarpPerspectiveForwardImpl::exec(_megdnn_tensor_in ssrc, + _megdnn_tensor_in smat, + _megdnn_tensor_in smat_idx, + _megdnn_tensor_out sdst, + _megdnn_workspace sworkspace) { + check_exec_allow_nhwc_mat_idx(ssrc.layout, smat.layout, smat_idx.layout, + sdst.layout, sworkspace.size); + + TensorND src = ssrc; + TensorND mat = smat; + TensorND mat_idx = smat_idx; + TensorND dst = sdst; + auto bundle = + get_workspace_bundle(sworkspace.raw_ptr, ssrc.layout, smat.layout, + smat_idx.layout, sdst.layout); + auto ctypecvt = CompTypeCvter( + concrete_handle(this->handle()), &bundle); + if (ssrc.layout.dtype.enumv() == DTypeTrait::enumv) { + ctypecvt.src_to_comp_type(ssrc, src) + .src_to_comp_type(smat, mat) + .src_to_comp_type(sdst, dst); } - megdnn_assert(param().imode == Param::InterpolationMode::LINEAR, - "unsupported interpolation mode for NCHW format"); - auto bval = param().border_val; - auto bmode = warp_perspective::get_bmode(param().bmode); - - if (src.layout.dtype == dtype::Float32{}) { - warp_perspective::forward_proxy( - is_nhwc, src.ptr(), mat.ptr(), - mat_idx.raw_ptr ? 
mat_idx.ptr() : nullptr, - dst.ptr(), src.layout[0], mat.layout[0], C, IH, IW, - OH, OW, bval, bmode, async_error_info(handle()), - m_error_tracker, stream); - } else if (MEGDNN_FLOAT16_SELECT(src.layout.dtype == dtype::Float16(), - false)) { + + { + auto stream = cuda_stream(this->handle()); + bool is_nhwc = param().format == param::WarpPerspective::Format::NHWC; + + if (is_nhwc && param().imode != Param::InterpolationMode::LINEAR) { + // use opencv impl only for nhwc and non-linear interp + megdnn_assert(!mat_idx.raw_ptr, + "mat_idx is not supported in NHWC case with " + "non-linear interpolation"); + warp_perspective::warp_perspective_cv_exec( + src, mat, dst, param().border_val, + warp_perspective::get_bmode(param().bmode), + warp_perspective::get_imode(param().imode), + ctypecvt.workspace(), stream); + + } else { + megdnn_assert(warp::is_dnn_available(src.layout, mat.layout, + dst.layout, param().imode, + param().format)); + size_t C, IH, IW, OH, OW; + if (is_nhwc) { + C = src.layout.shape[3]; + IH = src.layout.shape[1]; + IW = src.layout.shape[2]; + OH = dst.layout.shape[1]; + OW = dst.layout.shape[2]; + } else if (param().format == Param::Format::NCHW4) { + C = src.layout.shape[1] * 4; + IH = src.layout.shape[2]; + IW = src.layout.shape[3]; + OH = dst.layout.shape[2]; + OW = dst.layout.shape[3]; + } else { + megdnn_assert( + param().format == param::WarpPerspective::Format::NCHW, + "invalid warp_perspective format"); + C = src.layout.shape[1]; + IH = src.layout.shape[2]; + IW = src.layout.shape[3]; + OH = dst.layout.shape[2]; + OW = dst.layout.shape[3]; + } + megdnn_assert(param().imode == Param::InterpolationMode::LINEAR, + "unsupported interpolation mode for NCHW format"); + auto bval = param().border_val; + auto bmode = warp_perspective::get_bmode(param().bmode); + + if (src.layout.dtype == dtype::Float32{}) { + warp_perspective::forward_proxy( + is_nhwc, src.ptr(), mat.ptr(), + mat_idx.raw_ptr ? mat_idx.ptr() : nullptr, + dst.ptr(), src.layout[0], mat.layout[0], C, + IH, IW, OH, OW, bval, bmode, async_error_info(handle()), + m_error_tracker, stream); + } else if (MEGDNN_FLOAT16_SELECT( + src.layout.dtype == dtype::Float16(), false)) { #ifndef MEGDNN_DISABLE_FLOAT16 - warp_perspective::forward_proxy( - is_nhwc, src.ptr(), mat.ptr(), - mat_idx.raw_ptr ? mat_idx.ptr() : nullptr, - dst.ptr(), src.layout[0], mat.layout[0], C, IH, IW, - OH, OW, static_cast(bval), bmode, - async_error_info(handle()), m_error_tracker, stream); + warp_perspective::forward_proxy( + is_nhwc, src.ptr(), mat.ptr(), + mat_idx.raw_ptr ? mat_idx.ptr() : nullptr, + dst.ptr(), src.layout[0], mat.layout[0], C, + IH, IW, OH, OW, static_cast(bval), bmode, + async_error_info(handle()), m_error_tracker, stream); #endif - } else if (src.layout.dtype == dtype::Uint8()) { - warp_perspective::forward_proxy( - is_nhwc, src.ptr(), mat.ptr(), - mat_idx.raw_ptr ? mat_idx.ptr() : nullptr, - dst.ptr(), src.layout[0], mat.layout[0], C, IH, IW, - OH, OW, bval, bmode, async_error_info(handle()), - m_error_tracker, stream); - } else if (src.layout.dtype == dtype::Int8()) { - megdnn_assert(!is_nhwc, - "WarpPerspective on CUDA does not support NHWC + Int8"); - warp_perspective::forward_proxy( - false, src.ptr(), mat.ptr(), - mat_idx.raw_ptr ? 
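`forward_proxy` is a template over the element type, so `exec` has to map the runtime `src.layout.dtype` onto one concrete instantiation, one branch per supported dtype, with `megdnn_throw` as the fallback. The skeleton of that dispatch, with a hypothetical `DTypeEnum` subset and a copy loop in place of the real proxy:

```cpp
#include <cstddef>
#include <cstdint>
#include <stdexcept>

enum class DTypeEnum { Float32, Uint8 };

template <typename ctype>
void forward_proxy(const ctype* src, ctype* dst, std::size_t n) {
    for (std::size_t i = 0; i < n; ++i) dst[i] = src[i];
}

// Runtime tag -> compile-time instantiation; unhandled tags throw, like
// the "unsupported dtype" fallback in the diff.
void exec(DTypeEnum dt, const void* src, void* dst, std::size_t n) {
    switch (dt) {
        case DTypeEnum::Float32:
            forward_proxy(static_cast<const float*>(src),
                          static_cast<float*>(dst), n);
            break;
        case DTypeEnum::Uint8:
            forward_proxy(static_cast<const std::uint8_t*>(src),
                          static_cast<std::uint8_t*>(dst), n);
            break;
        default:
            throw std::runtime_error("unsupported dtype");
    }
}
```

BFloat16 never needs its own branch here: by the time control reaches the dispatch, the `CompTypeCvter` has already turned those tensors into Float32.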
mat_idx.ptr() : nullptr, - dst.ptr(), src.layout[0], mat.layout[0], C, IH, IW, OH, - OW, bval /* implicit float -> int8 conversion, should be safe */ - , - bmode, async_error_info(handle()), m_error_tracker, stream); - } else if (src.layout.dtype.enumv() == DTypeEnum::QuantizedS8) { - megdnn_assert(param().format == Param::Format::NCHW4, - "WarpPerspective on CUDA supports NCHW4 + QuantizedS8 only"); - warp_perspective::forward_proxy_nchw4( - src.compatible_ptr(), mat.ptr(), - mat_idx.raw_ptr ? mat_idx.ptr() : nullptr, - dst.compatible_ptr(), src.layout[0], mat.layout[0], - C, IH, IW, OH, OW, bval, bmode, async_error_info(handle()), - m_error_tracker, stream); - } else { - megdnn_throw( - ssprintf("unsupported dtype: %s", src.layout.dtype.name())); + } else if (src.layout.dtype == dtype::Uint8()) { + warp_perspective::forward_proxy( + is_nhwc, src.ptr(), mat.ptr(), + mat_idx.raw_ptr ? mat_idx.ptr() : nullptr, + dst.ptr(), src.layout[0], mat.layout[0], C, + IH, IW, OH, OW, bval, bmode, async_error_info(handle()), + m_error_tracker, stream); + } else if (src.layout.dtype == dtype::Int8()) { + megdnn_assert( + !is_nhwc, + "WarpPerspective on CUDA does not support NHWC + Int8"); + warp_perspective::forward_proxy( + false, src.ptr(), mat.ptr(), + mat_idx.raw_ptr ? mat_idx.ptr() : nullptr, + dst.ptr(), src.layout[0], mat.layout[0], C, IH, + IW, OH, OW, + bval /* implicit float -> int8 conversion, should be + safe */ + , + bmode, async_error_info(handle()), m_error_tracker, + stream); + } else if (src.layout.dtype.enumv() == DTypeEnum::QuantizedS8) { + megdnn_assert(param().format == Param::Format::NCHW4, + "WarpPerspective on CUDA supports NCHW4 + " + "QuantizedS8 only"); + warp_perspective::forward_proxy_nchw4( + src.compatible_ptr(), mat.ptr(), + mat_idx.raw_ptr ? mat_idx.ptr() : nullptr, + dst.compatible_ptr(), src.layout[0], + mat.layout[0], C, IH, IW, OH, OW, bval, bmode, + async_error_info(handle()), m_error_tracker, stream); + } else { + megdnn_throw(ssprintf("unsupported dtype: %s", + src.layout.dtype.name())); + } + } + } + if (ssrc.layout.dtype.enumv() == DTypeTrait::enumv) { + ctypecvt.comp_to_dst_type(dst, sdst); } } diff --git a/dnn/src/cuda/warp_perspective/opr_impl.h b/dnn/src/cuda/warp_perspective/opr_impl.h index bd88f4461e0b91fb7919236967107c41c771a93e..e4b36747ca0ff28fa27d80cfecd730788a18f8b0 100644 --- a/dnn/src/cuda/warp_perspective/opr_impl.h +++ b/dnn/src/cuda/warp_perspective/opr_impl.h @@ -12,66 +12,82 @@ #include "megdnn/oprs.h" #include "src/cuda/cudnn_wrapper.h" +#include "src/cuda/utils.h" namespace megdnn { namespace cuda { -class WarpPerspectiveForwardImpl final: public WarpPerspectiveForward { +class WarpPerspectiveForwardImpl final : public WarpPerspectiveForward { void* m_error_tracker = nullptr; - public: - using WarpPerspectiveForward::WarpPerspectiveForward; - void exec(_megdnn_tensor_in src, - _megdnn_tensor_in mat, - _megdnn_tensor_in mat_idx, - _megdnn_tensor_out dst, - _megdnn_workspace workspace) override; - size_t get_workspace_in_bytes(const TensorLayout &, - const TensorLayout &mat, - const TensorLayout &, - const TensorLayout &) override { - if (param().format == param::WarpPerspective::Format::NHWC) { - //! use double for the workspace dtype as float may cause - //! 
accuracy problems - return mat.total_nr_elems() * sizeof(double); - } - return 0; - } - void set_error_tracker(void* tracker) override { - m_error_tracker = tracker; - } +public: + using WarpPerspectiveForward::WarpPerspectiveForward; + void exec(_megdnn_tensor_in src, _megdnn_tensor_in mat, + _megdnn_tensor_in mat_idx, _megdnn_tensor_out dst, + _megdnn_workspace workspace) override; + size_t get_workspace_in_bytes(const TensorLayout& src, + const TensorLayout& mat, + const TensorLayout& mat_idx, + const TensorLayout& dst) override { + return get_workspace_bundle(nullptr, src, mat, mat_idx, dst) + .total_size_in_bytes(); + } + + void set_error_tracker(void* tracker) override { + m_error_tracker = tracker; + } + +private: + WorkspaceBundle get_workspace_bundle(void* ptr, const TensorLayout& src, + const TensorLayout& mat, + const TensorLayout& mat_idx, + const TensorLayout& dst) const; }; -class WarpPerspectiveBackwardDataImpl final: public WarpPerspectiveBackwardData { - public: - using WarpPerspectiveBackwardData::WarpPerspectiveBackwardData; - void exec(_megdnn_tensor_in mat, - _megdnn_tensor_in diff, - _megdnn_tensor_out grad, - _megdnn_workspace workspace) override; - size_t get_workspace_in_bytes(const TensorLayout &mat, - const TensorLayout &diff, - const TensorLayout &grad) override; +class WarpPerspectiveBackwardDataImpl final + : public WarpPerspectiveBackwardData { +public: + using WarpPerspectiveBackwardData::WarpPerspectiveBackwardData; + void exec(_megdnn_tensor_in mat, _megdnn_tensor_in diff, + _megdnn_tensor_out grad, _megdnn_workspace workspace) override; + size_t get_workspace_in_bytes(const TensorLayout& mat, + const TensorLayout& diff, + const TensorLayout& grad) override { + return get_workspace_bundle(nullptr, mat, diff, grad) + .total_size_in_bytes(); + } + +private: + WorkspaceBundle get_workspace_bundle(void* ptr, const TensorLayout& mat, + const TensorLayout& diff, + const TensorLayout& grad) const; + size_t get_float32_workspace_in_bytes(const TensorLayout& mat, + const TensorLayout& diff, + const TensorLayout& grad) const; }; -class WarpPerspectiveBackwardMatImpl final: public WarpPerspectiveBackwardMat { - public: - using WarpPerspectiveBackwardMat::WarpPerspectiveBackwardMat; - void exec(_megdnn_tensor_in src, - _megdnn_tensor_in mat, - _megdnn_tensor_in diff, - _megdnn_tensor_out grad, - _megdnn_workspace workspace) override; - size_t get_workspace_in_bytes(const TensorLayout &, - const TensorLayout &, - const TensorLayout &, - const TensorLayout &) override - { - return 0; - } +class WarpPerspectiveBackwardMatImpl final : public WarpPerspectiveBackwardMat { +public: + using WarpPerspectiveBackwardMat::WarpPerspectiveBackwardMat; + void exec(_megdnn_tensor_in src, _megdnn_tensor_in mat, + _megdnn_tensor_in diff, _megdnn_tensor_out grad, + _megdnn_workspace workspace) override; + size_t get_workspace_in_bytes(const TensorLayout& src, + const TensorLayout& mat, + const TensorLayout& diff, + const TensorLayout& grad) override { + return get_workspace_bundle(nullptr, src, mat, diff, grad) + .total_size_in_bytes(); + } + +private: + WorkspaceBundle get_workspace_bundle(void* ptr, const TensorLayout& src, + const TensorLayout& mat, + const TensorLayout& diff, + const TensorLayout& grad) const; }; -} // namespace cuda -} // namespace megdnn +} // namespace cuda +} // namespace megdnn // vim: syntax=cpp.doxygen diff --git a/dnn/src/naive/convolution/convolution.cpp b/dnn/src/naive/convolution/convolution.cpp index 
e5c457018bc47be67a22ba8dbfe2eb06651b556a..498d21a7e1408952c52ee2ce986661cd673d8039 100644 --- a/dnn/src/naive/convolution/convolution.cpp +++ b/dnn/src/naive/convolution/convolution.cpp @@ -59,6 +59,9 @@ void ConvolutionForwardImpl::exec(_megdnn_tensor_in src, DISPATCH(QuantizedS8, QuantizedS32, dt_int8, dt_int32, dt_int32); MEGDNN_INC_FLOAT16(DISPATCH_CMODE(Float16, Float16, dt_float16, dt_float16, dt_float32, ComputeMode::FLOAT32)); + MEGDNN_INC_FLOAT16(DISPATCH_CMODE(BFloat16, BFloat16, dt_bfloat16, + dt_bfloat16, dt_float32, + ComputeMode::FLOAT32)); DISPATCH(Quantized8Asymm, QuantizedS32, dt_quint8, dt_qint32, dt_qint32); DISPATCH(QuantizedS8, QuantizedS8, dt_int8, dt_int8, dt_int32); #undef DISPATCH @@ -77,7 +80,7 @@ size_t ConvolutionBackwardDataImpl::get_workspace_in_bytes(const TensorLayout& f auto grad_dt = grad.dtype.enumv(); auto diff_dt = diff.dtype.enumv(); #if !MEGDNN_DISABLE_FLOAT16 - if (flt_dt == DTypeEnum::Float16) { + if (flt_dt == DTypeEnum::Float16 || flt_dt == DTypeEnum::BFloat16) { megdnn_assert(flt_dt == grad_dt && flt_dt == diff_dt); workspace_size = grad.span().dist_elem() * dtype::Float32().size(); } @@ -128,6 +131,20 @@ void ConvolutionBackwardDataImpl::exec(_megdnn_tensor_in filter, type_cvt->exec(grad_fp32, grad); return; } + if (filter.layout.dtype == dtype::BFloat16() && + cmode == ComputeMode::FLOAT32) { + TensorND grad_fp32; + grad_fp32.layout = grad.layout; + grad_fp32.layout.dtype = dtype::Float32(); + grad_fp32.raw_ptr = workspace.raw_ptr; + auto&& type_cvt = handle()->create_operator(); + type_cvt->exec(grad, grad_fp32); + MEGDNN_DISPATCH_CPU_KERN_OPR( + (convolution::backward_data( + filter, diff, grad_fp32, filter_meta));); + type_cvt->exec(grad_fp32, grad); + return; + } #endif auto flt_dt = filter.layout.dtype.enumv(); auto grad_dt = grad.layout.dtype.enumv(); @@ -174,7 +191,7 @@ size_t ConvolutionBackwardFilterImpl::get_workspace_in_bytes( auto src_dt = src.dtype.enumv(); auto grad_dt = grad.dtype.enumv(); auto diff_dt = diff.dtype.enumv(); - if (src_dt == DTypeEnum::Float16) { + if (src_dt == DTypeEnum::Float16 || src_dt == DTypeEnum::BFloat16) { megdnn_assert(src_dt == grad_dt && src_dt == diff_dt); workspace_size = grad.span().dist_elem() * dtype::Float32().size(); } @@ -221,6 +238,22 @@ void ConvolutionBackwardFilterImpl::exec(_megdnn_tensor_in src, type_cvt->exec(grad_fp32, grad); return; } + if (src.layout.dtype == dtype::BFloat16() && + cmode == ComputeMode::FLOAT32) { + TensorND grad_fp32; + grad_fp32.layout = grad.layout; + grad_fp32.layout.dtype = dtype::Float32(); + grad_fp32.raw_ptr = workspace.raw_ptr; + auto&& type_cvt = handle()->create_operator(); + type_cvt->exec(grad, grad_fp32); + MEGDNN_DISPATCH_CPU_KERN_OPR( + (convolution::backward_filter(src, diff, grad_fp32, + filter_meta));); + type_cvt->exec(grad_fp32, grad); + return; + } + #endif megdnn_assert_internal(0); diff --git a/dnn/src/naive/elemwise/kimpl/ABS_GRAD_dt_bfloat16.cpp b/dnn/src/naive/elemwise/kimpl/ABS_GRAD_dt_bfloat16.cpp new file mode 100644 index 0000000000000000000000000000000000000000..93919198b511cafae255d23cc0aa3b1380e2f07b --- /dev/null +++ b/dnn/src/naive/elemwise/kimpl/ABS_GRAD_dt_bfloat16.cpp @@ -0,0 +1,17 @@ +/** + * \file dnn/src/naive/elemwise/kimpl/ABS_GRAD_dt_bfloat16.cpp + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. 
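The naive BFloat16 backward paths never accumulate in bf16: under `ComputeMode::FLOAT32` they point a Float32 view of `grad` at the workspace, `TypeCvt` in, accumulate, and `TypeCvt` back once. The reason is precision: bfloat16 stores only 7 fraction bits, so summing many small contributions directly would round most of them away. A self-contained sketch of the scratch round trip (truncating narrow for brevity):

```cpp
#include <cstddef>
#include <cstdint>
#include <cstring>

struct bf16 { std::uint16_t bits; };

static float widen(bf16 x) {
    std::uint32_t u = std::uint32_t(x.bits) << 16;
    float f;
    std::memcpy(&f, &u, sizeof(f));
    return f;
}

static bf16 narrow(float f) {
    std::uint32_t u;
    std::memcpy(&u, &f, sizeof(u));
    return bf16{std::uint16_t(u >> 16)};
}

// Like the grad_fp32 workspace view: widen once, accumulate in float,
// narrow once at the end.
void accumulate(const float* contributions, std::size_t n, bf16* grad) {
    float acc = widen(*grad);     // TypeCvt grad -> grad_fp32
    for (std::size_t i = 0; i < n; ++i) acc += contributions[i];
    *grad = narrow(acc);          // TypeCvt grad_fp32 -> grad
}
```

For scale: at a running value of 256, bf16's spacing is 2, so a stream of 0.5-sized updates applied directly in bf16 would all be lost; in the Float32 scratch they survive until the single final narrowing.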
+ * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + */ +// generated by gen_elemwise_kern_impls.py +#if !MEGDNN_DISABLE_FLOAT16 +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(ABS_GRAD, cb) +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_CTYPE dt_bfloat16 +#include "../kern_impl.inl" +#endif diff --git a/dnn/src/naive/elemwise/kimpl/ABS_dt_bfloat16.cpp b/dnn/src/naive/elemwise/kimpl/ABS_dt_bfloat16.cpp new file mode 100644 index 0000000000000000000000000000000000000000..7e311563cca0d6cb2e9f7c39ffa08277dfb505bd --- /dev/null +++ b/dnn/src/naive/elemwise/kimpl/ABS_dt_bfloat16.cpp @@ -0,0 +1,17 @@ +/** + * \file dnn/src/naive/elemwise/kimpl/ABS_dt_bfloat16.cpp + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + */ +// generated by gen_elemwise_kern_impls.py +#if !MEGDNN_DISABLE_FLOAT16 +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(ABS, cb) +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_CTYPE dt_bfloat16 +#include "../kern_impl.inl" +#endif diff --git a/dnn/src/naive/elemwise/kimpl/ACOS_dt_bfloat16.cpp b/dnn/src/naive/elemwise/kimpl/ACOS_dt_bfloat16.cpp new file mode 100644 index 0000000000000000000000000000000000000000..cc60c400413271a5a85cd78b854a39eb4ad7aa07 --- /dev/null +++ b/dnn/src/naive/elemwise/kimpl/ACOS_dt_bfloat16.cpp @@ -0,0 +1,17 @@ +/** + * \file dnn/src/naive/elemwise/kimpl/ACOS_dt_bfloat16.cpp + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + */ +// generated by gen_elemwise_kern_impls.py +#if !MEGDNN_DISABLE_FLOAT16 +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(ACOS, cb) +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_CTYPE dt_bfloat16 +#include "../kern_impl.inl" +#endif diff --git a/dnn/src/naive/elemwise/kimpl/ADD_dt_bfloat16.cpp b/dnn/src/naive/elemwise/kimpl/ADD_dt_bfloat16.cpp new file mode 100644 index 0000000000000000000000000000000000000000..91c80e1c6a738eafd114e1402fbb7be7058a712f --- /dev/null +++ b/dnn/src/naive/elemwise/kimpl/ADD_dt_bfloat16.cpp @@ -0,0 +1,17 @@ +/** + * \file dnn/src/naive/elemwise/kimpl/ADD_dt_bfloat16.cpp + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ */ +// generated by gen_elemwise_kern_impls.py +#if !MEGDNN_DISABLE_FLOAT16 +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(ADD, cb) +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_CTYPE dt_bfloat16 +#include "../kern_impl.inl" +#endif diff --git a/dnn/src/naive/elemwise/kimpl/ASIN_dt_bfloat16.cpp b/dnn/src/naive/elemwise/kimpl/ASIN_dt_bfloat16.cpp new file mode 100644 index 0000000000000000000000000000000000000000..94d512dcc2e3abf8c56e4bfdcd6b18c26aed63b9 --- /dev/null +++ b/dnn/src/naive/elemwise/kimpl/ASIN_dt_bfloat16.cpp @@ -0,0 +1,17 @@ +/** + * \file dnn/src/naive/elemwise/kimpl/ASIN_dt_bfloat16.cpp + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + */ +// generated by gen_elemwise_kern_impls.py +#if !MEGDNN_DISABLE_FLOAT16 +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(ASIN, cb) +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_CTYPE dt_bfloat16 +#include "../kern_impl.inl" +#endif diff --git a/dnn/src/naive/elemwise/kimpl/ATAN2_dt_bfloat16.cpp b/dnn/src/naive/elemwise/kimpl/ATAN2_dt_bfloat16.cpp new file mode 100644 index 0000000000000000000000000000000000000000..cb378ee9d24762b932ef4d4028ce06f80004c711 --- /dev/null +++ b/dnn/src/naive/elemwise/kimpl/ATAN2_dt_bfloat16.cpp @@ -0,0 +1,17 @@ +/** + * \file dnn/src/naive/elemwise/kimpl/ATAN2_dt_bfloat16.cpp + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + */ +// generated by gen_elemwise_kern_impls.py +#if !MEGDNN_DISABLE_FLOAT16 +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(ATAN2, cb) +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_CTYPE dt_bfloat16 +#include "../kern_impl.inl" +#endif diff --git a/dnn/src/naive/elemwise/kimpl/CEIL_dt_bfloat16.cpp b/dnn/src/naive/elemwise/kimpl/CEIL_dt_bfloat16.cpp new file mode 100644 index 0000000000000000000000000000000000000000..2276a5637dfaab266ed8a3d8331e1eb29e0a5589 --- /dev/null +++ b/dnn/src/naive/elemwise/kimpl/CEIL_dt_bfloat16.cpp @@ -0,0 +1,17 @@ +/** + * \file dnn/src/naive/elemwise/kimpl/CEIL_dt_bfloat16.cpp + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
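Each generated `kimpl` file is nothing but four `#define`s and an `#include` of a shared body, so every (mode, ctype) pair becomes its own translation unit and the heavy elemwise template instantiation compiles in parallel. A single-file approximation of the mechanism (the real `kern_impl.inl` also keys off `KERN_IMPL_MODE` and `KERN_IMPL_ARITY`):

```cpp
#include <cstddef>

// Shared body, playing the role of "../kern_impl.inl".
template <typename T>
void kern(const T* src, T* dst, std::size_t n) {
    for (std::size_t i = 0; i < n; ++i) dst[i] = -src[i];
}

// One generated file == set the knobs, include the body, instantiate once.
#define KERN_IMPL_CTYPE float  // dt_bfloat16 in the real generated sources
template void kern<KERN_IMPL_CTYPE>(const KERN_IMPL_CTYPE*, KERN_IMPL_CTYPE*,
                                    std::size_t);
#undef KERN_IMPL_CTYPE
```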
+ */ +// generated by gen_elemwise_kern_impls.py +#if !MEGDNN_DISABLE_FLOAT16 +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(CEIL, cb) +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_CTYPE dt_bfloat16 +#include "../kern_impl.inl" +#endif diff --git a/dnn/src/naive/elemwise/kimpl/COND_LEQ_MOV_dt_bfloat16.cpp b/dnn/src/naive/elemwise/kimpl/COND_LEQ_MOV_dt_bfloat16.cpp new file mode 100644 index 0000000000000000000000000000000000000000..b76b9ede1fb4b5c744c5fe6f0a95c26055d666c4 --- /dev/null +++ b/dnn/src/naive/elemwise/kimpl/COND_LEQ_MOV_dt_bfloat16.cpp @@ -0,0 +1,17 @@ +/** + * \file dnn/src/naive/elemwise/kimpl/COND_LEQ_MOV_dt_bfloat16.cpp + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + */ +// generated by gen_elemwise_kern_impls.py +#if !MEGDNN_DISABLE_FLOAT16 +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(COND_LEQ_MOV, cb) +#define KERN_IMPL_ARITY 3 +#define KERN_IMPL_CTYPE dt_bfloat16 +#include "../kern_impl.inl" +#endif diff --git a/dnn/src/naive/elemwise/kimpl/COS_dt_bfloat16.cpp b/dnn/src/naive/elemwise/kimpl/COS_dt_bfloat16.cpp new file mode 100644 index 0000000000000000000000000000000000000000..17889e8d283a01610888e24f33cd3280a04f4409 --- /dev/null +++ b/dnn/src/naive/elemwise/kimpl/COS_dt_bfloat16.cpp @@ -0,0 +1,17 @@ +/** + * \file dnn/src/naive/elemwise/kimpl/COS_dt_bfloat16.cpp + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + */ +// generated by gen_elemwise_kern_impls.py +#if !MEGDNN_DISABLE_FLOAT16 +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(COS, cb) +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_CTYPE dt_bfloat16 +#include "../kern_impl.inl" +#endif diff --git a/dnn/src/naive/elemwise/kimpl/EQ_dt_bfloat16.cpp b/dnn/src/naive/elemwise/kimpl/EQ_dt_bfloat16.cpp new file mode 100644 index 0000000000000000000000000000000000000000..181a8cee5ed4c38fa02f934f2d971d688c866a05 --- /dev/null +++ b/dnn/src/naive/elemwise/kimpl/EQ_dt_bfloat16.cpp @@ -0,0 +1,17 @@ +/** + * \file dnn/src/naive/elemwise/kimpl/EQ_dt_bfloat16.cpp + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ */ +// generated by gen_elemwise_kern_impls.py +#if !MEGDNN_DISABLE_FLOAT16 +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(EQ, cb) +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_CTYPE dt_bfloat16 +#include "../kern_impl.inl" +#endif diff --git a/dnn/src/naive/elemwise/kimpl/ERFCINV_dt_bfloat16.cpp b/dnn/src/naive/elemwise/kimpl/ERFCINV_dt_bfloat16.cpp new file mode 100644 index 0000000000000000000000000000000000000000..baf9859a5418305fe8ef52bbcfd69bb5da0064de --- /dev/null +++ b/dnn/src/naive/elemwise/kimpl/ERFCINV_dt_bfloat16.cpp @@ -0,0 +1,17 @@ +/** + * \file dnn/src/naive/elemwise/kimpl/ERFCINV_dt_bfloat16.cpp + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + */ +// generated by gen_elemwise_kern_impls.py +#if !MEGDNN_DISABLE_FLOAT16 +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(ERFCINV, cb) +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_CTYPE dt_bfloat16 +#include "../kern_impl.inl" +#endif diff --git a/dnn/src/naive/elemwise/kimpl/ERFC_dt_bfloat16.cpp b/dnn/src/naive/elemwise/kimpl/ERFC_dt_bfloat16.cpp new file mode 100644 index 0000000000000000000000000000000000000000..249e937dd3175cabdaff0512dbc5e3a094f3ccbc --- /dev/null +++ b/dnn/src/naive/elemwise/kimpl/ERFC_dt_bfloat16.cpp @@ -0,0 +1,17 @@ +/** + * \file dnn/src/naive/elemwise/kimpl/ERFC_dt_bfloat16.cpp + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + */ +// generated by gen_elemwise_kern_impls.py +#if !MEGDNN_DISABLE_FLOAT16 +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(ERFC, cb) +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_CTYPE dt_bfloat16 +#include "../kern_impl.inl" +#endif diff --git a/dnn/src/naive/elemwise/kimpl/ERFINV_dt_bfloat16.cpp b/dnn/src/naive/elemwise/kimpl/ERFINV_dt_bfloat16.cpp new file mode 100644 index 0000000000000000000000000000000000000000..e9f1096aeef12957eda1d30e61ee8ca76a1102a3 --- /dev/null +++ b/dnn/src/naive/elemwise/kimpl/ERFINV_dt_bfloat16.cpp @@ -0,0 +1,17 @@ +/** + * \file dnn/src/naive/elemwise/kimpl/ERFINV_dt_bfloat16.cpp + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ */ +// generated by gen_elemwise_kern_impls.py +#if !MEGDNN_DISABLE_FLOAT16 +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(ERFINV, cb) +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_CTYPE dt_bfloat16 +#include "../kern_impl.inl" +#endif diff --git a/dnn/src/naive/elemwise/kimpl/ERF_dt_bfloat16.cpp b/dnn/src/naive/elemwise/kimpl/ERF_dt_bfloat16.cpp new file mode 100644 index 0000000000000000000000000000000000000000..b83de8794d4a6a219ab93841d435cf03a8d61ccb --- /dev/null +++ b/dnn/src/naive/elemwise/kimpl/ERF_dt_bfloat16.cpp @@ -0,0 +1,17 @@ +/** + * \file dnn/src/naive/elemwise/kimpl/ERF_dt_bfloat16.cpp + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + */ +// generated by gen_elemwise_kern_impls.py +#if !MEGDNN_DISABLE_FLOAT16 +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(ERF, cb) +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_CTYPE dt_bfloat16 +#include "../kern_impl.inl" +#endif diff --git a/dnn/src/naive/elemwise/kimpl/EXPM1_dt_bfloat16.cpp b/dnn/src/naive/elemwise/kimpl/EXPM1_dt_bfloat16.cpp new file mode 100644 index 0000000000000000000000000000000000000000..b397721f4b029a2490456a494714237249301855 --- /dev/null +++ b/dnn/src/naive/elemwise/kimpl/EXPM1_dt_bfloat16.cpp @@ -0,0 +1,17 @@ +/** + * \file dnn/src/naive/elemwise/kimpl/EXPM1_dt_bfloat16.cpp + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + */ +// generated by gen_elemwise_kern_impls.py +#if !MEGDNN_DISABLE_FLOAT16 +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(EXPM1, cb) +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_CTYPE dt_bfloat16 +#include "../kern_impl.inl" +#endif diff --git a/dnn/src/naive/elemwise/kimpl/EXP_dt_bfloat16.cpp b/dnn/src/naive/elemwise/kimpl/EXP_dt_bfloat16.cpp new file mode 100644 index 0000000000000000000000000000000000000000..97eb022f979de30e152c665086654c117e87c792 --- /dev/null +++ b/dnn/src/naive/elemwise/kimpl/EXP_dt_bfloat16.cpp @@ -0,0 +1,17 @@ +/** + * \file dnn/src/naive/elemwise/kimpl/EXP_dt_bfloat16.cpp + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ */ +// generated by gen_elemwise_kern_impls.py +#if !MEGDNN_DISABLE_FLOAT16 +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(EXP, cb) +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_CTYPE dt_bfloat16 +#include "../kern_impl.inl" +#endif diff --git a/dnn/src/naive/elemwise/kimpl/FAST_TANH_GRAD_dt_bfloat16.cpp b/dnn/src/naive/elemwise/kimpl/FAST_TANH_GRAD_dt_bfloat16.cpp new file mode 100644 index 0000000000000000000000000000000000000000..91d4e7e97e32da9ffcb28f659ddeb7bd5d79d25f --- /dev/null +++ b/dnn/src/naive/elemwise/kimpl/FAST_TANH_GRAD_dt_bfloat16.cpp @@ -0,0 +1,17 @@ +/** + * \file dnn/src/naive/elemwise/kimpl/FAST_TANH_GRAD_dt_bfloat16.cpp + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + */ +// generated by gen_elemwise_kern_impls.py +#if !MEGDNN_DISABLE_FLOAT16 +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(FAST_TANH_GRAD, cb) +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_CTYPE dt_bfloat16 +#include "../kern_impl.inl" +#endif diff --git a/dnn/src/naive/elemwise/kimpl/FAST_TANH_dt_bfloat16.cpp b/dnn/src/naive/elemwise/kimpl/FAST_TANH_dt_bfloat16.cpp new file mode 100644 index 0000000000000000000000000000000000000000..cf7f86a68a9036701323a2c8c07528956a4e5803 --- /dev/null +++ b/dnn/src/naive/elemwise/kimpl/FAST_TANH_dt_bfloat16.cpp @@ -0,0 +1,17 @@ +/** + * \file dnn/src/naive/elemwise/kimpl/FAST_TANH_dt_bfloat16.cpp + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + */ +// generated by gen_elemwise_kern_impls.py +#if !MEGDNN_DISABLE_FLOAT16 +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(FAST_TANH, cb) +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_CTYPE dt_bfloat16 +#include "../kern_impl.inl" +#endif diff --git a/dnn/src/naive/elemwise/kimpl/FLOOR_DIV_dt_bfloat16.cpp b/dnn/src/naive/elemwise/kimpl/FLOOR_DIV_dt_bfloat16.cpp new file mode 100644 index 0000000000000000000000000000000000000000..ec884116773b0d960d9ea4188fa5dcbff2ad0933 --- /dev/null +++ b/dnn/src/naive/elemwise/kimpl/FLOOR_DIV_dt_bfloat16.cpp @@ -0,0 +1,17 @@ +/** + * \file dnn/src/naive/elemwise/kimpl/FLOOR_DIV_dt_bfloat16.cpp + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ */ +// generated by gen_elemwise_kern_impls.py +#if !MEGDNN_DISABLE_FLOAT16 +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(FLOOR_DIV, cb) +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_CTYPE dt_bfloat16 +#include "../kern_impl.inl" +#endif diff --git a/dnn/src/naive/elemwise/kimpl/FLOOR_dt_bfloat16.cpp b/dnn/src/naive/elemwise/kimpl/FLOOR_dt_bfloat16.cpp new file mode 100644 index 0000000000000000000000000000000000000000..11f74c8679ac95209dcc176288757f9c7f43dc01 --- /dev/null +++ b/dnn/src/naive/elemwise/kimpl/FLOOR_dt_bfloat16.cpp @@ -0,0 +1,17 @@ +/** + * \file dnn/src/naive/elemwise/kimpl/FLOOR_dt_bfloat16.cpp + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + */ +// generated by gen_elemwise_kern_impls.py +#if !MEGDNN_DISABLE_FLOAT16 +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(FLOOR, cb) +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_CTYPE dt_bfloat16 +#include "../kern_impl.inl" +#endif diff --git a/dnn/src/naive/elemwise/kimpl/FUSE_ADD_H_SWISH_dt_bfloat16.cpp b/dnn/src/naive/elemwise/kimpl/FUSE_ADD_H_SWISH_dt_bfloat16.cpp new file mode 100644 index 0000000000000000000000000000000000000000..3d2dae2d63ec69167bf0731cde37f4b4ddc3812f --- /dev/null +++ b/dnn/src/naive/elemwise/kimpl/FUSE_ADD_H_SWISH_dt_bfloat16.cpp @@ -0,0 +1,17 @@ +/** + * \file dnn/src/naive/elemwise/kimpl/FUSE_ADD_H_SWISH_dt_bfloat16.cpp + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + */ +// generated by gen_elemwise_kern_impls.py +#if !MEGDNN_DISABLE_FLOAT16 +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(FUSE_ADD_H_SWISH, cb) +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_CTYPE dt_bfloat16 +#include "../kern_impl.inl" +#endif diff --git a/dnn/src/naive/elemwise/kimpl/FUSE_ADD_RELU_dt_bfloat16.cpp b/dnn/src/naive/elemwise/kimpl/FUSE_ADD_RELU_dt_bfloat16.cpp new file mode 100644 index 0000000000000000000000000000000000000000..9f4645c47dce94e83db82b186bcef65717fab2b7 --- /dev/null +++ b/dnn/src/naive/elemwise/kimpl/FUSE_ADD_RELU_dt_bfloat16.cpp @@ -0,0 +1,17 @@ +/** + * \file dnn/src/naive/elemwise/kimpl/FUSE_ADD_RELU_dt_bfloat16.cpp + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ */ +// generated by gen_elemwise_kern_impls.py +#if !MEGDNN_DISABLE_FLOAT16 +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(FUSE_ADD_RELU, cb) +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_CTYPE dt_bfloat16 +#include "../kern_impl.inl" +#endif diff --git a/dnn/src/naive/elemwise/kimpl/FUSE_ADD_SIGMOID_dt_bfloat16.cpp b/dnn/src/naive/elemwise/kimpl/FUSE_ADD_SIGMOID_dt_bfloat16.cpp new file mode 100644 index 0000000000000000000000000000000000000000..0f2e33867891b574c1dbaf3ed045889ff4a711c8 --- /dev/null +++ b/dnn/src/naive/elemwise/kimpl/FUSE_ADD_SIGMOID_dt_bfloat16.cpp @@ -0,0 +1,17 @@ +/** + * \file dnn/src/naive/elemwise/kimpl/FUSE_ADD_SIGMOID_dt_bfloat16.cpp + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + */ +// generated by gen_elemwise_kern_impls.py +#if !MEGDNN_DISABLE_FLOAT16 +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(FUSE_ADD_SIGMOID, cb) +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_CTYPE dt_bfloat16 +#include "../kern_impl.inl" +#endif diff --git a/dnn/src/naive/elemwise/kimpl/FUSE_ADD_TANH_dt_bfloat16.cpp b/dnn/src/naive/elemwise/kimpl/FUSE_ADD_TANH_dt_bfloat16.cpp new file mode 100644 index 0000000000000000000000000000000000000000..2a33b0b003f250b68804e61f68af321a120fa70c --- /dev/null +++ b/dnn/src/naive/elemwise/kimpl/FUSE_ADD_TANH_dt_bfloat16.cpp @@ -0,0 +1,17 @@ +/** + * \file dnn/src/naive/elemwise/kimpl/FUSE_ADD_TANH_dt_bfloat16.cpp + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + */ +// generated by gen_elemwise_kern_impls.py +#if !MEGDNN_DISABLE_FLOAT16 +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(FUSE_ADD_TANH, cb) +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_CTYPE dt_bfloat16 +#include "../kern_impl.inl" +#endif diff --git a/dnn/src/naive/elemwise/kimpl/FUSE_MUL_ADD3_dt_bfloat16.cpp b/dnn/src/naive/elemwise/kimpl/FUSE_MUL_ADD3_dt_bfloat16.cpp new file mode 100644 index 0000000000000000000000000000000000000000..62ce1ed32d12623b54e43feb9a81831774353b52 --- /dev/null +++ b/dnn/src/naive/elemwise/kimpl/FUSE_MUL_ADD3_dt_bfloat16.cpp @@ -0,0 +1,17 @@ +/** + * \file dnn/src/naive/elemwise/kimpl/FUSE_MUL_ADD3_dt_bfloat16.cpp + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ */ +// generated by gen_elemwise_kern_impls.py +#if !MEGDNN_DISABLE_FLOAT16 +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(FUSE_MUL_ADD3, cb) +#define KERN_IMPL_ARITY 3 +#define KERN_IMPL_CTYPE dt_bfloat16 +#include "../kern_impl.inl" +#endif diff --git a/dnn/src/naive/elemwise/kimpl/H_SWISH_GRAD_dt_bfloat16.cpp b/dnn/src/naive/elemwise/kimpl/H_SWISH_GRAD_dt_bfloat16.cpp new file mode 100644 index 0000000000000000000000000000000000000000..234889c42b5a1eff41c2d3eb71d6b8c9585ff4fc --- /dev/null +++ b/dnn/src/naive/elemwise/kimpl/H_SWISH_GRAD_dt_bfloat16.cpp @@ -0,0 +1,17 @@ +/** + * \file dnn/src/naive/elemwise/kimpl/H_SWISH_GRAD_dt_bfloat16.cpp + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + */ +// generated by gen_elemwise_kern_impls.py +#if !MEGDNN_DISABLE_FLOAT16 +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(H_SWISH_GRAD, cb) +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_CTYPE dt_bfloat16 +#include "../kern_impl.inl" +#endif diff --git a/dnn/src/naive/elemwise/kimpl/H_SWISH_dt_bfloat16.cpp b/dnn/src/naive/elemwise/kimpl/H_SWISH_dt_bfloat16.cpp new file mode 100644 index 0000000000000000000000000000000000000000..ad33240efad1a59a4ae94ebe4dbde26bb85ec0d9 --- /dev/null +++ b/dnn/src/naive/elemwise/kimpl/H_SWISH_dt_bfloat16.cpp @@ -0,0 +1,17 @@ +/** + * \file dnn/src/naive/elemwise/kimpl/H_SWISH_dt_bfloat16.cpp + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + */ +// generated by gen_elemwise_kern_impls.py +#if !MEGDNN_DISABLE_FLOAT16 +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(H_SWISH, cb) +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_CTYPE dt_bfloat16 +#include "../kern_impl.inl" +#endif diff --git a/dnn/src/naive/elemwise/kimpl/LEQ_dt_bfloat16.cpp b/dnn/src/naive/elemwise/kimpl/LEQ_dt_bfloat16.cpp new file mode 100644 index 0000000000000000000000000000000000000000..adbaee2a76bb4f7bcf5b74de353b592c164ce43a --- /dev/null +++ b/dnn/src/naive/elemwise/kimpl/LEQ_dt_bfloat16.cpp @@ -0,0 +1,17 @@ +/** + * \file dnn/src/naive/elemwise/kimpl/LEQ_dt_bfloat16.cpp + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ */ +// generated by gen_elemwise_kern_impls.py +#if !MEGDNN_DISABLE_FLOAT16 +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(LEQ, cb) +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_CTYPE dt_bfloat16 +#include "../kern_impl.inl" +#endif diff --git a/dnn/src/naive/elemwise/kimpl/LOG1P_dt_bfloat16.cpp b/dnn/src/naive/elemwise/kimpl/LOG1P_dt_bfloat16.cpp new file mode 100644 index 0000000000000000000000000000000000000000..0d62d7d069969e8e5335bf0d4e6d8a20c103ee18 --- /dev/null +++ b/dnn/src/naive/elemwise/kimpl/LOG1P_dt_bfloat16.cpp @@ -0,0 +1,17 @@ +/** + * \file dnn/src/naive/elemwise/kimpl/LOG1P_dt_bfloat16.cpp + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + */ +// generated by gen_elemwise_kern_impls.py +#if !MEGDNN_DISABLE_FLOAT16 +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(LOG1P, cb) +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_CTYPE dt_bfloat16 +#include "../kern_impl.inl" +#endif diff --git a/dnn/src/naive/elemwise/kimpl/LOG_SUM_EXP_dt_bfloat16.cpp b/dnn/src/naive/elemwise/kimpl/LOG_SUM_EXP_dt_bfloat16.cpp new file mode 100644 index 0000000000000000000000000000000000000000..5a96d87ebb9a2030375c70469050e710cddbd5dd --- /dev/null +++ b/dnn/src/naive/elemwise/kimpl/LOG_SUM_EXP_dt_bfloat16.cpp @@ -0,0 +1,17 @@ +/** + * \file dnn/src/naive/elemwise/kimpl/LOG_SUM_EXP_dt_bfloat16.cpp + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + */ +// generated by gen_elemwise_kern_impls.py +#if !MEGDNN_DISABLE_FLOAT16 +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(LOG_SUM_EXP, cb) +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_CTYPE dt_bfloat16 +#include "../kern_impl.inl" +#endif diff --git a/dnn/src/naive/elemwise/kimpl/LOG_dt_bfloat16.cpp b/dnn/src/naive/elemwise/kimpl/LOG_dt_bfloat16.cpp new file mode 100644 index 0000000000000000000000000000000000000000..1365e97f8eab59ae75c96dff26653e7378c07e29 --- /dev/null +++ b/dnn/src/naive/elemwise/kimpl/LOG_dt_bfloat16.cpp @@ -0,0 +1,17 @@ +/** + * \file dnn/src/naive/elemwise/kimpl/LOG_dt_bfloat16.cpp + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ */ +// generated by gen_elemwise_kern_impls.py +#if !MEGDNN_DISABLE_FLOAT16 +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(LOG, cb) +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_CTYPE dt_bfloat16 +#include "../kern_impl.inl" +#endif diff --git a/dnn/src/naive/elemwise/kimpl/LT_dt_bfloat16.cpp b/dnn/src/naive/elemwise/kimpl/LT_dt_bfloat16.cpp new file mode 100644 index 0000000000000000000000000000000000000000..31f9cfeab08c6a001a7bb7c52a5dcb3df4da8fed --- /dev/null +++ b/dnn/src/naive/elemwise/kimpl/LT_dt_bfloat16.cpp @@ -0,0 +1,17 @@ +/** + * \file dnn/src/naive/elemwise/kimpl/LT_dt_bfloat16.cpp + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + */ +// generated by gen_elemwise_kern_impls.py +#if !MEGDNN_DISABLE_FLOAT16 +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(LT, cb) +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_CTYPE dt_bfloat16 +#include "../kern_impl.inl" +#endif diff --git a/dnn/src/naive/elemwise/kimpl/MAX_dt_bfloat16.cpp b/dnn/src/naive/elemwise/kimpl/MAX_dt_bfloat16.cpp new file mode 100644 index 0000000000000000000000000000000000000000..faebe1ded6a40db6b66d80a803d628657eabd51c --- /dev/null +++ b/dnn/src/naive/elemwise/kimpl/MAX_dt_bfloat16.cpp @@ -0,0 +1,17 @@ +/** + * \file dnn/src/naive/elemwise/kimpl/MAX_dt_bfloat16.cpp + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + */ +// generated by gen_elemwise_kern_impls.py +#if !MEGDNN_DISABLE_FLOAT16 +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(MAX, cb) +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_CTYPE dt_bfloat16 +#include "../kern_impl.inl" +#endif diff --git a/dnn/src/naive/elemwise/kimpl/MIN_dt_bfloat16.cpp b/dnn/src/naive/elemwise/kimpl/MIN_dt_bfloat16.cpp new file mode 100644 index 0000000000000000000000000000000000000000..848e72cebaf128f8f8a3f4ee42cc37b6ec3a5572 --- /dev/null +++ b/dnn/src/naive/elemwise/kimpl/MIN_dt_bfloat16.cpp @@ -0,0 +1,17 @@ +/** + * \file dnn/src/naive/elemwise/kimpl/MIN_dt_bfloat16.cpp + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ */ +// generated by gen_elemwise_kern_impls.py +#if !MEGDNN_DISABLE_FLOAT16 +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(MIN, cb) +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_CTYPE dt_bfloat16 +#include "../kern_impl.inl" +#endif diff --git a/dnn/src/naive/elemwise/kimpl/MOD_dt_bfloat16.cpp b/dnn/src/naive/elemwise/kimpl/MOD_dt_bfloat16.cpp new file mode 100644 index 0000000000000000000000000000000000000000..a90406572fc26281b8a9a8dc582cfaf8bb7f8dea --- /dev/null +++ b/dnn/src/naive/elemwise/kimpl/MOD_dt_bfloat16.cpp @@ -0,0 +1,17 @@ +/** + * \file dnn/src/naive/elemwise/kimpl/MOD_dt_bfloat16.cpp + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + */ +// generated by gen_elemwise_kern_impls.py +#if !MEGDNN_DISABLE_FLOAT16 +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(MOD, cb) +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_CTYPE dt_bfloat16 +#include "../kern_impl.inl" +#endif diff --git a/dnn/src/naive/elemwise/kimpl/MUL_dt_bfloat16.cpp b/dnn/src/naive/elemwise/kimpl/MUL_dt_bfloat16.cpp new file mode 100644 index 0000000000000000000000000000000000000000..242a884fb9ef9b04d4b1337399e80fd0e0ca64ed --- /dev/null +++ b/dnn/src/naive/elemwise/kimpl/MUL_dt_bfloat16.cpp @@ -0,0 +1,17 @@ +/** + * \file dnn/src/naive/elemwise/kimpl/MUL_dt_bfloat16.cpp + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + */ +// generated by gen_elemwise_kern_impls.py +#if !MEGDNN_DISABLE_FLOAT16 +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(MUL, cb) +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_CTYPE dt_bfloat16 +#include "../kern_impl.inl" +#endif diff --git a/dnn/src/naive/elemwise/kimpl/NEGATE_dt_bfloat16.cpp b/dnn/src/naive/elemwise/kimpl/NEGATE_dt_bfloat16.cpp new file mode 100644 index 0000000000000000000000000000000000000000..720df0190e032c9e93db27dd75535d31f932a560 --- /dev/null +++ b/dnn/src/naive/elemwise/kimpl/NEGATE_dt_bfloat16.cpp @@ -0,0 +1,17 @@ +/** + * \file dnn/src/naive/elemwise/kimpl/NEGATE_dt_bfloat16.cpp + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ */ +// generated by gen_elemwise_kern_impls.py +#if !MEGDNN_DISABLE_FLOAT16 +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(NEGATE, cb) +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_CTYPE dt_bfloat16 +#include "../kern_impl.inl" +#endif diff --git a/dnn/src/naive/elemwise/kimpl/POW_dt_bfloat16.cpp b/dnn/src/naive/elemwise/kimpl/POW_dt_bfloat16.cpp new file mode 100644 index 0000000000000000000000000000000000000000..ee339df611ec0ef5a745d3c7e34e039c28c91e7c --- /dev/null +++ b/dnn/src/naive/elemwise/kimpl/POW_dt_bfloat16.cpp @@ -0,0 +1,17 @@ +/** + * \file dnn/src/naive/elemwise/kimpl/POW_dt_bfloat16.cpp + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + */ +// generated by gen_elemwise_kern_impls.py +#if !MEGDNN_DISABLE_FLOAT16 +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(POW, cb) +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_CTYPE dt_bfloat16 +#include "../kern_impl.inl" +#endif diff --git a/dnn/src/naive/elemwise/kimpl/RELU_dt_bfloat16.cpp b/dnn/src/naive/elemwise/kimpl/RELU_dt_bfloat16.cpp new file mode 100644 index 0000000000000000000000000000000000000000..47be86081755afb8bbdb6f3d95eb42cfcb37d4d5 --- /dev/null +++ b/dnn/src/naive/elemwise/kimpl/RELU_dt_bfloat16.cpp @@ -0,0 +1,17 @@ +/** + * \file dnn/src/naive/elemwise/kimpl/RELU_dt_bfloat16.cpp + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + */ +// generated by gen_elemwise_kern_impls.py +#if !MEGDNN_DISABLE_FLOAT16 +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(RELU, cb) +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_CTYPE dt_bfloat16 +#include "../kern_impl.inl" +#endif diff --git a/dnn/src/naive/elemwise/kimpl/ROUND_dt_bfloat16.cpp b/dnn/src/naive/elemwise/kimpl/ROUND_dt_bfloat16.cpp new file mode 100644 index 0000000000000000000000000000000000000000..20d52e01bebceb00ca7e54798e6ce0b6580ed059 --- /dev/null +++ b/dnn/src/naive/elemwise/kimpl/ROUND_dt_bfloat16.cpp @@ -0,0 +1,17 @@ +/** + * \file dnn/src/naive/elemwise/kimpl/ROUND_dt_bfloat16.cpp + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ */ +// generated by gen_elemwise_kern_impls.py +#if !MEGDNN_DISABLE_FLOAT16 +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(ROUND, cb) +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_CTYPE dt_bfloat16 +#include "../kern_impl.inl" +#endif diff --git a/dnn/src/naive/elemwise/kimpl/SIGMOID_GRAD_dt_bfloat16.cpp b/dnn/src/naive/elemwise/kimpl/SIGMOID_GRAD_dt_bfloat16.cpp new file mode 100644 index 0000000000000000000000000000000000000000..15570f144b9cb8dadbd93946029a299d52bfc019 --- /dev/null +++ b/dnn/src/naive/elemwise/kimpl/SIGMOID_GRAD_dt_bfloat16.cpp @@ -0,0 +1,17 @@ +/** + * \file dnn/src/naive/elemwise/kimpl/SIGMOID_GRAD_dt_bfloat16.cpp + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + */ +// generated by gen_elemwise_kern_impls.py +#if !MEGDNN_DISABLE_FLOAT16 +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SIGMOID_GRAD, cb) +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_CTYPE dt_bfloat16 +#include "../kern_impl.inl" +#endif diff --git a/dnn/src/naive/elemwise/kimpl/SIGMOID_dt_bfloat16.cpp b/dnn/src/naive/elemwise/kimpl/SIGMOID_dt_bfloat16.cpp new file mode 100644 index 0000000000000000000000000000000000000000..45bc43db9486b7cfb91685398bb1b3dbce91074b --- /dev/null +++ b/dnn/src/naive/elemwise/kimpl/SIGMOID_dt_bfloat16.cpp @@ -0,0 +1,17 @@ +/** + * \file dnn/src/naive/elemwise/kimpl/SIGMOID_dt_bfloat16.cpp + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + */ +// generated by gen_elemwise_kern_impls.py +#if !MEGDNN_DISABLE_FLOAT16 +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SIGMOID, cb) +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_CTYPE dt_bfloat16 +#include "../kern_impl.inl" +#endif diff --git a/dnn/src/naive/elemwise/kimpl/SIN_dt_bfloat16.cpp b/dnn/src/naive/elemwise/kimpl/SIN_dt_bfloat16.cpp new file mode 100644 index 0000000000000000000000000000000000000000..d72bb36e4b48f1a779e3bc6c5dfbbc83b38a4da7 --- /dev/null +++ b/dnn/src/naive/elemwise/kimpl/SIN_dt_bfloat16.cpp @@ -0,0 +1,17 @@ +/** + * \file dnn/src/naive/elemwise/kimpl/SIN_dt_bfloat16.cpp + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ */ +// generated by gen_elemwise_kern_impls.py +#if !MEGDNN_DISABLE_FLOAT16 +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SIN, cb) +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_CTYPE dt_bfloat16 +#include "../kern_impl.inl" +#endif diff --git a/dnn/src/naive/elemwise/kimpl/SUB_dt_bfloat16.cpp b/dnn/src/naive/elemwise/kimpl/SUB_dt_bfloat16.cpp new file mode 100644 index 0000000000000000000000000000000000000000..0e300c07d95f09644d70091021b5322ffbc258ab --- /dev/null +++ b/dnn/src/naive/elemwise/kimpl/SUB_dt_bfloat16.cpp @@ -0,0 +1,17 @@ +/** + * \file dnn/src/naive/elemwise/kimpl/SUB_dt_bfloat16.cpp + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + */ +// generated by gen_elemwise_kern_impls.py +#if !MEGDNN_DISABLE_FLOAT16 +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SUB, cb) +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_CTYPE dt_bfloat16 +#include "../kern_impl.inl" +#endif diff --git a/dnn/src/naive/elemwise/kimpl/SWITCH_GT0_dt_bfloat16.cpp b/dnn/src/naive/elemwise/kimpl/SWITCH_GT0_dt_bfloat16.cpp new file mode 100644 index 0000000000000000000000000000000000000000..a050a0c3a186def785acdb72cee842194b0807ad --- /dev/null +++ b/dnn/src/naive/elemwise/kimpl/SWITCH_GT0_dt_bfloat16.cpp @@ -0,0 +1,17 @@ +/** + * \file dnn/src/naive/elemwise/kimpl/SWITCH_GT0_dt_bfloat16.cpp + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + */ +// generated by gen_elemwise_kern_impls.py +#if !MEGDNN_DISABLE_FLOAT16 +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SWITCH_GT0, cb) +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_CTYPE dt_bfloat16 +#include "../kern_impl.inl" +#endif diff --git a/dnn/src/naive/elemwise/kimpl/TANH_GRAD_dt_bfloat16.cpp b/dnn/src/naive/elemwise/kimpl/TANH_GRAD_dt_bfloat16.cpp new file mode 100644 index 0000000000000000000000000000000000000000..d363fb41c8c3ec1d46991ed26df20c6457c5d254 --- /dev/null +++ b/dnn/src/naive/elemwise/kimpl/TANH_GRAD_dt_bfloat16.cpp @@ -0,0 +1,17 @@ +/** + * \file dnn/src/naive/elemwise/kimpl/TANH_GRAD_dt_bfloat16.cpp + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ */ +// generated by gen_elemwise_kern_impls.py +#if !MEGDNN_DISABLE_FLOAT16 +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(TANH_GRAD, cb) +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_CTYPE dt_bfloat16 +#include "../kern_impl.inl" +#endif diff --git a/dnn/src/naive/elemwise/kimpl/TANH_dt_bfloat16.cpp b/dnn/src/naive/elemwise/kimpl/TANH_dt_bfloat16.cpp new file mode 100644 index 0000000000000000000000000000000000000000..10531f047348c0c069f62571f935fb320f97adb1 --- /dev/null +++ b/dnn/src/naive/elemwise/kimpl/TANH_dt_bfloat16.cpp @@ -0,0 +1,17 @@ +/** + * \file dnn/src/naive/elemwise/kimpl/TANH_dt_bfloat16.cpp + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + */ +// generated by gen_elemwise_kern_impls.py +#if !MEGDNN_DISABLE_FLOAT16 +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(TANH, cb) +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_CTYPE dt_bfloat16 +#include "../kern_impl.inl" +#endif diff --git a/dnn/src/naive/elemwise/kimpl/TRUE_DIV_dt_bfloat16.cpp b/dnn/src/naive/elemwise/kimpl/TRUE_DIV_dt_bfloat16.cpp new file mode 100644 index 0000000000000000000000000000000000000000..91b95d884334b980b977df7123a2d0d9caf5a630 --- /dev/null +++ b/dnn/src/naive/elemwise/kimpl/TRUE_DIV_dt_bfloat16.cpp @@ -0,0 +1,17 @@ +/** + * \file dnn/src/naive/elemwise/kimpl/TRUE_DIV_dt_bfloat16.cpp + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ */ +// generated by gen_elemwise_kern_impls.py +#if !MEGDNN_DISABLE_FLOAT16 +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(TRUE_DIV, cb) +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_CTYPE dt_bfloat16 +#include "../kern_impl.inl" +#endif diff --git a/dnn/src/naive/matrix_mul/opr_impl.cpp b/dnn/src/naive/matrix_mul/opr_impl.cpp index eab96e927a458417429407577489302df029cf53..70ba4c493e936224896a44644606f615bbe46564 100644 --- a/dnn/src/naive/matrix_mul/opr_impl.cpp +++ b/dnn/src/naive/matrix_mul/opr_impl.cpp @@ -66,6 +66,13 @@ void dispatch_ta_tb(_megdnn_tensor_in A, _megdnn_tensor_in B, } else if (param.compute_mode == Param::ComputeMode::FLOAT32) { cb(dt_float16, dt_float16, dt_float32); } + } else if (A.layout.dtype == dtype::BFloat16()) { + using Param = MatrixMul::Param; + if (param.compute_mode == Param::ComputeMode::DEFAULT) { + cb(dt_bfloat16, dt_bfloat16, dt_bfloat16); + } else if (param.compute_mode == Param::ComputeMode::FLOAT32) { + cb(dt_bfloat16, dt_bfloat16, dt_float32); + } #endif } else if (A.layout.dtype == dtype::Int8() && C.layout.dtype == dtype::Int16()) { diff --git a/dnn/src/naive/pooling/opr_impl.cpp b/dnn/src/naive/pooling/opr_impl.cpp index 7ddba4dbdecabec5087a2404974879b30e51ff1f..b0a5222d3a9d95bcc9137f6acef780aed6e53aea 100644 --- a/dnn/src/naive/pooling/opr_impl.cpp +++ b/dnn/src/naive/pooling/opr_impl.cpp @@ -494,11 +494,56 @@ void PoolingForwardImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_out dst, MIDOUT_END(); } -void PoolingBackwardImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_in dst, - _megdnn_tensor_in diff, _megdnn_tensor_out grad, +WorkspaceBundle PoolingBackwardImpl::get_workspace_bundle( + void* ptr, const TensorLayout& src, const TensorLayout& dst, + const TensorLayout& diff, const TensorLayout& grad) const { + SmallVector sizes; + TensorLayout fsrc = src; + TensorLayout fdst = dst; + TensorLayout fdiff = diff; + TensorLayout fgrad = grad; + auto get_workspace = [&sizes](TensorLayout& layout) { + if (MEGDNN_FLOAT16_SELECT(layout.dtype == dtype::BFloat16(), false)) { + layout.dtype = dtype::Float32(); + sizes.push_back(layout.span().dist_byte()); + } + }; + get_workspace(fsrc); + get_workspace(fdst); + get_workspace(fdiff); + get_workspace(fgrad); + return {ptr, std::move(sizes)}; +} + +size_t PoolingBackwardImpl::get_workspace_in_bytes( + const TensorLayout& src, const TensorLayout& dst, + const TensorLayout& diff, const TensorLayout& grad) { + return get_workspace_bundle(nullptr, src, dst, diff, grad) + .total_size_in_bytes(); +} + +void PoolingBackwardImpl::exec(_megdnn_tensor_in ssrc, _megdnn_tensor_in sdst, + _megdnn_tensor_in sdiff, + _megdnn_tensor_out sgrad, _megdnn_workspace workspace) { - check_exec(src.layout, dst.layout, diff.layout, grad.layout, + check_exec(ssrc.layout, sdst.layout, sdiff.layout, sgrad.layout, workspace.size); + TensorND src = ssrc; + TensorND dst = sdst; + TensorND diff = sdiff; + TensorND grad = sgrad; +#if !MEGDNN_DISABLE_FLOAT16 + auto wsb = get_workspace_bundle(workspace.raw_ptr, ssrc.layout, sdst.layout, + sdiff.layout, sgrad.layout); + auto ctypecvt = CompTypeCvter( + static_cast(handle()), &wsb); + if (ssrc.layout.dtype.enumv() == DTypeTrait::enumv) { + ctypecvt.src_to_comp_type(ssrc, src) + .src_to_comp_type(sdst, dst) + .src_to_comp_type(sdiff, diff) + .src_to_comp_type(sgrad, grad); + } +#endif size_t c_pos, spatial_pos; if (param().format == Param::Format::NCHW) { c_pos = 1; @@ -520,7 +565,7 @@ void PoolingBackwardImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_in dst, 
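(The PoolingBackward rewrite above stages BFloat16 tensors through Float32 workspace copies: src_to_comp_type widens every operand before the float kernel in the hunk that continues below runs, and comp_to_dst_type narrows only the gradient back afterwards. A minimal standalone sketch of that staging pattern, using hypothetical to_comp/to_dst helpers in place of megdnn's TypeCvt-based CompTypeCvter, with bfloat16 held as raw uint16_t and truncation where the real library rounds:

    #include <cstdint>
    #include <cstring>
    #include <vector>

    static float bf16_to_f32(uint16_t b) {
        uint32_t u = static_cast<uint32_t>(b) << 16;  // bfloat16 is the top 16 bits of a float
        float f;
        std::memcpy(&f, &u, sizeof(f));
        return f;
    }

    static uint16_t f32_to_bf16(float f) {
        uint32_t u;
        std::memcpy(&u, &f, sizeof(u));
        return static_cast<uint16_t>(u >> 16);  // truncate low mantissa bits (sketch only)
    }

    // "src_to_comp_type": copy a bfloat16 buffer into a float32 workspace.
    std::vector<float> to_comp(const std::vector<uint16_t>& src) {
        std::vector<float> dst(src.size());
        for (size_t i = 0; i < src.size(); ++i) dst[i] = bf16_to_f32(src[i]);
        return dst;  // the existing float kernel runs on this copy
    }

    // "comp_to_dst_type": narrow the computed gradient back to bfloat16.
    void to_dst(const std::vector<float>& comp, std::vector<uint16_t>& dst) {
        for (size_t i = 0; i < comp.size(); ++i) dst[i] = f32_to_bf16(comp[i]);
    }

Only the output needs the narrowing pass; the widened inputs are scratch and vanish with the workspace bundle, which is why get_workspace_bundle above reserves one float32-sized span per BFloat16 tensor.)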
MEGDNN_DISPATCH_CPU_KERN(static_cast(handle()), \ Func( \ sptr, dptr, diffptr, gradptr, N, C, IH, \ - IW, OH, OW, PH, PW, SH, SW, FH, FW)); + IW, OH, OW, PH, PW, SH, SW, FH, FW)); \ #define DISPATCH_WITH_FUNC(Func, ctype) \ switch (param().format) { \ @@ -542,27 +587,33 @@ void PoolingBackwardImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_in dst, auto sptr = src.ptr(), dptr = dst.ptr(), \ diffptr = diff.ptr(), gradptr = grad.ptr(); \ DISPATCH_WITH_FUNC(pooling_backward_avg_impl, ctype); \ - return; \ + break; \ } \ case Mode::AVERAGE_COUNT_EXCLUDE_PADDING: { \ auto sptr = src.ptr(), dptr = dst.ptr(), \ diffptr = diff.ptr(), gradptr = grad.ptr(); \ DISPATCH_WITH_FUNC(pooling_backward_avg_expd_impl, ctype); \ - return; \ + break; \ } \ case Mode::MAX: { \ auto sptr = src.ptr(), dptr = dst.ptr(), \ diffptr = diff.ptr(), gradptr = grad.ptr(); \ DISPATCH_WITH_FUNC(pooling_backward_max_impl, ctype); \ - return; \ + break; \ } \ + default: \ + megdnn_assert_internal(0); \ } \ } MEGDNN_FOREACH_COMPUTING_DTYPE(cb) #undef cb #undef DISPATCH_WITH_FUNC_AND_IDX_GETTER #undef DISPATCH_WITH_FUNC - megdnn_assert_internal(0); +#if !MEGDNN_DISABLE_FLOAT16 + if (sgrad.layout.dtype.enumv() == DTypeTrait::enumv) { + ctypecvt.comp_to_dst_type(grad, sgrad); + } +#endif } } // namespace naive diff --git a/dnn/src/naive/pooling/opr_impl.h b/dnn/src/naive/pooling/opr_impl.h index ddbff5f9a4b671d0e1c1b4622e4bf2f1214105a5..19ff371b29207fc55686ba7131d81901e780b61d 100644 --- a/dnn/src/naive/pooling/opr_impl.h +++ b/dnn/src/naive/pooling/opr_impl.h @@ -10,6 +10,7 @@ */ #pragma once #include "megdnn/oprs.h" +#include "src/common/utils.h" namespace megdnn { namespace naive { @@ -25,20 +26,21 @@ class PoolingForwardImpl: public PoolingForward { } }; -class PoolingBackwardImpl: public PoolingBackward { - public: - using PoolingBackward::PoolingBackward; - void exec(_megdnn_tensor_in src, - _megdnn_tensor_in dst, - _megdnn_tensor_in diff, - _megdnn_tensor_out grad, - _megdnn_workspace workspace) override; - size_t get_workspace_in_bytes(const TensorLayout &, - const TensorLayout &, - const TensorLayout &, - const TensorLayout &) override { - return 0; - } +class PoolingBackwardImpl : public PoolingBackward { +public: + using PoolingBackward::PoolingBackward; + void exec(_megdnn_tensor_in src, _megdnn_tensor_in dst, + _megdnn_tensor_in diff, _megdnn_tensor_out grad, + _megdnn_workspace workspace) override; + size_t get_workspace_in_bytes(const TensorLayout&, const TensorLayout&, + const TensorLayout&, + const TensorLayout&) override; + +private: + WorkspaceBundle get_workspace_bundle(void* ptr, const TensorLayout&, + const TensorLayout&, + const TensorLayout&, + const TensorLayout&) const; }; } // namespace naive diff --git a/dnn/src/naive/rng/opr_impl.cpp b/dnn/src/naive/rng/opr_impl.cpp index 47476e66fb474a974ddb1bb58c95dae703ab49fd..6ceb285827f9fdffa35686ffda3b4c34d28884b6 100644 --- a/dnn/src/naive/rng/opr_impl.cpp +++ b/dnn/src/naive/rng/opr_impl.cpp @@ -38,6 +38,15 @@ namespace { } #endif +#if !MEGDNN_DISABLE_FLOAT16 + template<> + dt_bfloat16 uniform_int2float(uint64_t x) { + union U { uint16_t i; dt_bfloat16 f; U(): f(0) {} } u; + u.i = (0x7F << 7) | (x >> 57); + return dt_bfloat16(2.f) - u.f; + } +#endif + template void fill_uniform(Xoroshiro128plus *rng, ctype *dst, size_t size) { for (size_t i = 0; i < size; ++ i) { diff --git a/dnn/src/naive/warp_perspective/opr_impl.cpp b/dnn/src/naive/warp_perspective/opr_impl.cpp index 1969022ffda66f59f2c995ad92da042372601b22..f580d2c1fe18290350442b919ffdb5d0acb6b36c 100644 
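(The uniform_int2float specialization just above relies on a bit trick: (0x7F << 7) writes sign 0 and exponent 127 into a bfloat16 pattern, which is exactly 1.0, and OR-ing the top 7 bits of the 64-bit random word into the mantissa yields a uniform value in [1, 2); subtracting from 2 maps that onto (0, 1]. A standalone demonstration of the same construction on the raw bit pattern, plain C++ with the bfloat16 widened to float only for printing:

    #include <cstdint>
    #include <cstdio>
    #include <cstring>
    #include <random>

    // Mirror of uniform_int2float<dt_bfloat16>: 7 random bits -> (0, 1].
    float uniform_bf16(uint64_t x) {
        // sign = 0, exponent = 127 (bias), mantissa = top 7 bits of x
        uint16_t bits = static_cast<uint16_t>((0x7F << 7) | (x >> 57));
        uint32_t u = static_cast<uint32_t>(bits) << 16;  // bfloat16 -> float bits
        float f;
        std::memcpy(&f, &u, sizeof(f));  // f is uniform in [1, 2)
        return 2.0f - f;                 // map [1, 2) onto (0, 1]
    }

    int main() {
        std::mt19937_64 rng(42);
        for (int i = 0; i < 4; ++i)
            std::printf("%g\n", uniform_bf16(rng()));
    }

Starting from [1, 2) makes every representable mantissa pattern equally likely and avoids a division when normalizing the random integer.)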
--- a/dnn/src/naive/warp_perspective/opr_impl.cpp +++ b/dnn/src/naive/warp_perspective/opr_impl.cpp @@ -237,34 +237,69 @@ void WarpPerspectiveForwardImpl::exec(_megdnn_tensor_in src, _megdnn_workspace workspace) { check_exec_allow_nhwc_mat_idx(src.layout, mat.layout, mat_idx.layout, dst.layout, workspace.size); - size_t batch = dst.layout[0]; - if (param().format == Format::NHWCD4) { - size_t oh = dst.layout[1]; + #define cb(dt, ct, mct) \ case DTypeTrait
::enumv: { \ auto kparam = KernParam::from_tensors( \ param().format, param().bmode, param().border_val, src, mat, \ mat_idx, dst, workspace); \ auto run = [kparam, this](size_t index, size_t) { \ - kern_naive_nhwcd4(kparam, index); \ + kern_naive(kparam, index); \ }; \ - MEGDNN_DISPATCH_MULTI_THREAD_CPU_KERN_OPR(run, batch* oh); \ + MEGDNN_DISPATCH_MULTI_THREAD_CPU_KERN_OPR(run, kparam.oh* batch); \ return; \ } - switch (src.layout.dtype.enumv()) { - cb(dtype::Float32, float, float); - MEGDNN_INC_FLOAT16(cb(dtype::Float16, dt_float16, dt_float16)); - cb(dtype::Quantized8Asymm, uint8_t, float); - cb(dtype::QuantizedS8, int8_t, float); - default: - megdnn_throw(ssprintf("Unsupported input DType in " - "WarpPerspective: %s", - src.layout.dtype.name()) - .c_str()); - } -#undef cb +#define KERN_CD4(ct, mct) \ + auto kparam = KernParam::from_tensors( \ + param().format, param().bmode, param().border_val, src, mat, \ + mat_idx, dst, workspace); \ + auto run = [kparam, this](size_t index, size_t) { \ + kern_naive_nhwcd4(kparam, index); \ + }; \ + MEGDNN_DISPATCH_MULTI_THREAD_CPU_KERN_OPR(run, batch* oh); + +#define KERN(ct, mct) \ + auto kparam = KernParam::from_tensors( \ + param().format, param().bmode, param().border_val, src, mat, \ + mat_idx, dst, workspace); \ + auto run = [kparam, this](size_t index, size_t) { \ + kern_naive(kparam, index); \ + }; \ + MEGDNN_DISPATCH_MULTI_THREAD_CPU_KERN_OPR(run, kparam.oh* batch); + +#define DISPATCH_ST(dt, ct, mct, kern) \ + if (src.layout.dtype.enumv() == DTypeTrait
::enumv) { \ + kern(ct, mct); \ + return; \ + } + +#define DISPATCH_ST_MT(dt, ct, kern) \ + if (src.layout.dtype.enumv() == DTypeTrait
::enumv) { \ + if (mat.layout.dtype.enumv() == DTypeTrait::enumv) { \ + kern(ct, float); \ + return; \ + } else { \ + kern(ct, ct); \ + return; \ + } \ + } + + if (param().format == Format::NHWCD4) { + size_t oh = dst.layout[1]; + DISPATCH_ST(dtype::Float32, float, float, KERN_CD4); + DISPATCH_ST(dtype::Quantized8Asymm, uint8_t, float, KERN_CD4); + DISPATCH_ST(dtype::QuantizedS8, int8_t, float, KERN_CD4); + + MEGDNN_INC_FLOAT16( + DISPATCH_ST_MT(dtype::Float16, dt_float16, KERN_CD4)); + MEGDNN_INC_FLOAT16( + DISPATCH_ST_MT(dtype::BFloat16, dt_bfloat16, KERN_CD4)); + megdnn_throw(ssprintf("Unsupported input DType in " + "WarpPerspective: %s", + src.layout.dtype.name()) + .c_str()); } if (warp::is_cv_available(src.layout, mat.layout, dst.layout, param().imode, param().format) && @@ -286,32 +321,75 @@ void WarpPerspectiveForwardImpl::exec(_megdnn_tensor_in src, * input type is float16. */ -#define cb(dt, ct, mct) \ - case DTypeTrait
::enumv: { \ - auto kparam = KernParam::from_tensors( \ - param().format, param().bmode, param().border_val, src, mat, \ - mat_idx, dst, workspace); \ - auto run = [kparam, this](size_t index, size_t) { \ - kern_naive(kparam, index); \ - }; \ - MEGDNN_DISPATCH_MULTI_THREAD_CPU_KERN_OPR(run, kparam.oh* batch); \ - return; \ + DISPATCH_ST(dtype::Float32, float, float, KERN); + DISPATCH_ST(dtype::Int8, int8_t, float, KERN); + DISPATCH_ST(dtype::QuantizedS8, int8_t, float, KERN); + DISPATCH_ST(dtype::Uint8, uint8_t, float, KERN); + DISPATCH_ST(dtype::Quantized8Asymm, uint8_t, float, KERN); + + MEGDNN_INC_FLOAT16(DISPATCH_ST_MT(dtype::Float16, dt_float16, KERN)); + MEGDNN_INC_FLOAT16(DISPATCH_ST_MT(dtype::BFloat16, dt_bfloat16, KERN)); + megdnn_throw(ssprintf("Unsupported input DType in " + "WarpPerspective: %s", + src.layout.dtype.name()) + .c_str()); } +#undef DISPATCH_ST_MT +#undef DISPATCH_ST +#undef KERN +#undef KERN_CD4 +} - switch (src.layout.dtype.enumv()) { - cb(dtype::Float32, float, float); - MEGDNN_INC_FLOAT16(cb(dtype::Float16, dt_float16, float)); - cb(dtype::Int8, int8_t, float); - cb(dtype::QuantizedS8, int8_t, float); - cb(dtype::Uint8, uint8_t, float); - cb(dtype::Quantized8Asymm, uint8_t, float); - default: - megdnn_throw(ssprintf("Unsupported input DType in " - "WarpPerspective: %s", - src.layout.dtype.name()) - .c_str()); +template +void WarpPerspectiveBackwardDataImpl::kern_naive(const KernParam& kern_param) { + const int N = kern_param.n, C = kern_param.c, + IH = kern_param.ih, IW = kern_param.iw; + const int OH = kern_param.oh, OW = kern_param.ow; + const ctype* hptr_ = kern_param.hptr; + const mtype* mptr_ = kern_param.mptr; + ctype* sptr_ = kern_param.sptr; + auto hptr = hptr_; + auto mptr = mptr_; + auto sptr = sptr_; + std::memset(sptr, 0, sizeof(ctype) * N * C * IH * IW); + rep(n, N) { + rep(oh, OH) rep(ow, OW) { + float numeratorw = mptr[0] * ow + mptr[1] * oh + mptr[2]; + float numeratorh = mptr[3] * ow + mptr[4] * oh + mptr[5]; + float denominator = mptr[6] * ow + mptr[7] * oh + mptr[8]; + float alphaw = numeratorw / denominator; + float alphah = numeratorh / denominator; + + int iw0 = get_real_coord(std::floor(alphaw) + 0, IW); + int iw1 = get_real_coord(std::floor(alphaw) + 1, IW); + int ih0 = get_real_coord(std::floor(alphah) + 0, IH); + int ih1 = get_real_coord(std::floor(alphah) + 1, IH); + + alphaw -= floor(alphaw); + alphah -= floor(alphah); + rep(c, C) { + float hidden = hptr[c * OH * OW + oh * OW + ow]; + if (iw0 != -1 && ih0 != -1) { + sptr[c * IH * IW + ih0 * IW + iw0] += + (1.0f - alphaw) * (1.0f - alphah) * hidden; + } + if (iw0 != -1 && ih1 != -1) { + sptr[c * IH * IW + ih1 * IW + iw0] += + (1.0f - alphaw) * alphah * hidden; + } + if (iw1 != -1 && ih0 != -1) { + sptr[c * IH * IW + ih0 * IW + iw1] += + alphaw * (1.0f - alphah) * hidden; + } + if (iw1 != -1 && ih1 != -1) { + sptr[c * IH * IW + ih1 * IW + iw1] += + alphaw * alphah * hidden; + } + } } -#undef cb + sptr += C * IH * IW; + hptr += C * OH * OW; + mptr += 3 * 3; } } @@ -322,57 +400,130 @@ void WarpPerspectiveBackwardDataImpl::exec(_megdnn_tensor_in mat, check_exec(mat.layout, diff.layout, grad.layout, workspace.size); megdnn_assert(param().format == param::WarpPerspective::Format::NCHW, "invalid warp_perspective format"); - const int N = grad.layout.shape[0], C = grad.layout.shape[1], - IH = grad.layout.shape[2], IW = grad.layout.shape[3]; - const int OH = diff.layout.shape[2], OW = diff.layout.shape[3]; - const float* hptr_ = diff.ptr(); - const float* mptr_ = mat.ptr(); - float* sptr_ = 
grad.ptr(); - auto kern = [=]() { - auto hptr = hptr_, mptr = mptr_; - auto sptr = sptr_; - std::memset(sptr, 0, sizeof(float) * N * C * IH * IW); - rep(n, N) { - rep(oh, OH) rep(ow, OW) { - float numeratorw = mptr[0] * ow + mptr[1] * oh + mptr[2]; - float numeratorh = mptr[3] * ow + mptr[4] * oh + mptr[5]; - float denominator = mptr[6] * ow + mptr[7] * oh + mptr[8]; - float alphaw = numeratorw / denominator; - float alphah = numeratorh / denominator; +#define DISPATCH_ST_MT(dt, ct) \ + if (diff.layout.dtype.enumv() == DTypeTrait
::enumv) { \ + if (mat.layout.dtype.enumv() == DTypeTrait::enumv) { \ + auto kparam = KernParam::from_tensors(mat, diff, grad); \ + MEGDNN_DISPATCH_CPU_KERN_OPR(kern_naive(kparam)); \ + return; \ + } else { \ + auto kparam = KernParam::from_tensors(mat, diff, grad); \ + MEGDNN_DISPATCH_CPU_KERN_OPR(kern_naive(kparam)); \ + return; \ + } \ + } + DISPATCH_ST_MT(dtype::Float32, dt_float32); + MEGDNN_INC_FLOAT16(DISPATCH_ST_MT(dtype::BFloat16, dt_bfloat16)); + megdnn_throw(ssprintf("Unsupported input DType in " + "WarpPerspective: %s", + diff.layout.dtype.name()) + .c_str()); +#undef DISPATCH_ST_MT +} - int iw0 = get_real_coord(std::floor(alphaw) + 0, IW); - int iw1 = get_real_coord(std::floor(alphaw) + 1, IW); - int ih0 = get_real_coord(std::floor(alphah) + 0, IH); - int ih1 = get_real_coord(std::floor(alphah) + 1, IH); +template +void WarpPerspectiveBackwardMatImpl::kern_naive( + const KernParam& kern_param) { + const int N = kern_param.n, C = kern_param.c, IH = kern_param.ih, + IW = kern_param.iw; + const int OH = kern_param.oh, OW = kern_param.ow; - alphaw -= floor(alphaw); - alphah -= floor(alphah); - rep(c, C) { - float hidden = hptr[c * OH * OW + oh * OW + ow]; - if (iw0 != -1 && ih0 != -1) { - sptr[c * IH * IW + ih0 * IW + iw0] += - (1.0f - alphaw) * (1.0f - alphah) * hidden; - } - if (iw0 != -1 && ih1 != -1) { - sptr[c * IH * IW + ih1 * IW + iw0] += - (1.0f - alphaw) * alphah * hidden; - } - if (iw1 != -1 && ih0 != -1) { - sptr[c * IH * IW + ih0 * IW + iw1] += - alphaw * (1.0f - alphah) * hidden; - } - if (iw1 != -1 && ih1 != -1) { - sptr[c * IH * IW + ih1 * IW + iw1] += - alphaw * alphah * hidden; - } + auto hptr = kern_param.hptr; + auto sptr = kern_param.sptr; + auto mptr = kern_param.mptr; + auto res = kern_param.res; + auto border_val = kern_param.border_val; + std::memset(res, 0, sizeof(float) * N * 3 * 3); + rep(n, N) { + rep(oh, OH) rep(ow, OW) { + float numeratorw = mptr[0] * ow + mptr[1] * oh + mptr[2]; + float numeratorh = mptr[3] * ow + mptr[4] * oh + mptr[5]; + float denominator = mptr[6] * ow + mptr[7] * oh + mptr[8]; + float denominator2 = denominator * denominator; + float alphaw = numeratorw / denominator; + float alphah = numeratorh / denominator; + + int iw0 = get_real_coord(std::floor(alphaw) + 0, IW); + int iw1 = get_real_coord(std::floor(alphaw) + 1, IW); + int ih0 = get_real_coord(std::floor(alphah) + 0, IH); + int ih1 = get_real_coord(std::floor(alphah) + 1, IH); + + alphaw -= floor(alphaw); + alphah -= floor(alphah); + rep(c, C) { + float b = border_val; + float hidden = hptr[c * OH * OW + oh * OW + ow]; + float dalphaw = 0; + dalphaw -= ((ih0 != -1 && iw0 != -1) + ? sptr[c * IH * IW + ih0 * IW + iw0] + : b) * + (1.0f - alphah); + dalphaw += ((ih0 != -1 && iw1 != -1) + ? sptr[c * IH * IW + ih0 * IW + iw1] + : b) * + (1.0f - alphah); + dalphaw -= ((ih1 != -1 && iw0 != -1) + ? sptr[c * IH * IW + ih1 * IW + iw0] + : b) * + alphah; + dalphaw += ((ih1 != -1 && iw1 != -1) + ? sptr[c * IH * IW + ih1 * IW + iw1] + : b) * + alphah; + float dalphah = 0; + dalphah -= ((ih0 != -1 && iw0 != -1) + ? sptr[c * IH * IW + ih0 * IW + iw0] + : b) * + (1.0f - alphaw); + dalphah -= ((ih0 != -1 && iw1 != -1) + ? sptr[c * IH * IW + ih0 * IW + iw1] + : b) * + alphaw; + dalphah += ((ih1 != -1 && iw0 != -1) + ? sptr[c * IH * IW + ih1 * IW + iw0] + : b) * + (1.0f - alphaw); + dalphah += ((ih1 != -1 && iw1 != -1) + ? 
sptr[c * IH * IW + ih1 * IW + iw1] + : b) * + alphaw; + // printf("dalphaw=%f dalphah=%f\n", dalphaw, dalphaw); + float dw[9], dh[9]; + // dw[i] = d(iw)/d(mat[i]) + dw[0] = ow / denominator; + dw[1] = oh / denominator; + dw[2] = 1.0f / denominator; + dw[3] = 0.0f; + dw[4] = 0.0f; + dw[5] = 0.0f; + float ddenominatorw = -numeratorw / denominator2; + dw[6] = ow * ddenominatorw; + dw[7] = oh * ddenominatorw; + dw[8] = 1.0f * ddenominatorw; + // dh[i] = d(ih)/d(mat[i]) + dh[0] = 0.0f; + dh[1] = 0.0f; + dh[2] = 0.0f; + dh[3] = ow / denominator; + dh[4] = oh / denominator; + dh[5] = 1.0f / denominator; + float ddenominatorh = -numeratorh / denominator2; + dh[6] = ow * ddenominatorh; + dh[7] = oh * ddenominatorh; + dh[8] = 1.0f * ddenominatorh; + rep(i, 9) { + float delta = + hidden * dalphaw * dw[i] + hidden * dalphah * dh[i]; + if (std::isfinite(delta)) + res[i] += delta; } } - sptr += C * IH * IW; - hptr += C * OH * OW; - mptr += 3 * 3; } - }; - MEGDNN_DISPATCH_CPU_KERN_OPR(kern()); + hptr += C * OH * OW; + sptr += C * IH * IW; + mptr += 3 * 3; + res += 3 * 3; + } } void WarpPerspectiveBackwardMatImpl::exec(_megdnn_tensor_in src, @@ -382,107 +533,27 @@ void WarpPerspectiveBackwardMatImpl::exec(_megdnn_tensor_in src, _megdnn_workspace workspace) { check_exec(src.layout, mat.layout, diff.layout, grad.layout, workspace.size); - auto N = src.layout.shape[0], C = src.layout.shape[1], - IH = src.layout.shape[2], IW = src.layout.shape[3]; - auto OH = diff.layout.shape[2], OW = diff.layout.shape[3]; - auto hptr_ = diff.ptr(), sptr_ = src.ptr(), - mptr_ = mat.ptr(), res_ = grad.ptr(); - auto border_val = param().border_val; - auto kern = [=]() { - auto hptr = hptr_, sptr = sptr_, mptr = mptr_, res = res_; - std::memset(res, 0, sizeof(float) * N * 3 * 3); - rep(n, N) { - rep(oh, OH) rep(ow, OW) { - float numeratorw = mptr[0] * ow + mptr[1] * oh + mptr[2]; - float numeratorh = mptr[3] * ow + mptr[4] * oh + mptr[5]; - float denominator = mptr[6] * ow + mptr[7] * oh + mptr[8]; - float denominator2 = denominator * denominator; - float alphaw = numeratorw / denominator; - float alphah = numeratorh / denominator; - - int iw0 = get_real_coord(std::floor(alphaw) + 0, IW); - int iw1 = get_real_coord(std::floor(alphaw) + 1, IW); - int ih0 = get_real_coord(std::floor(alphah) + 0, IH); - int ih1 = get_real_coord(std::floor(alphah) + 1, IH); - - alphaw -= floor(alphaw); - alphah -= floor(alphah); - rep(c, C) { - float b = border_val; - float hidden = hptr[c * OH * OW + oh * OW + ow]; - float dalphaw = 0; - dalphaw -= ((ih0 != -1 && iw0 != -1) - ? sptr[c * IH * IW + ih0 * IW + iw0] - : b) * - (1.0f - alphah); - dalphaw += ((ih0 != -1 && iw1 != -1) - ? sptr[c * IH * IW + ih0 * IW + iw1] - : b) * - (1.0f - alphah); - dalphaw -= ((ih1 != -1 && iw0 != -1) - ? sptr[c * IH * IW + ih1 * IW + iw0] - : b) * - alphah; - dalphaw += ((ih1 != -1 && iw1 != -1) - ? sptr[c * IH * IW + ih1 * IW + iw1] - : b) * - alphah; - float dalphah = 0; - dalphah -= ((ih0 != -1 && iw0 != -1) - ? sptr[c * IH * IW + ih0 * IW + iw0] - : b) * - (1.0f - alphaw); - dalphah -= ((ih0 != -1 && iw1 != -1) - ? sptr[c * IH * IW + ih0 * IW + iw1] - : b) * - alphaw; - dalphah += ((ih1 != -1 && iw0 != -1) - ? sptr[c * IH * IW + ih1 * IW + iw0] - : b) * - (1.0f - alphaw); - dalphah += ((ih1 != -1 && iw1 != -1) - ? 
sptr[c * IH * IW + ih1 * IW + iw1] - : b) * - alphaw; - // printf("dalphaw=%f dalphah=%f\n", dalphaw, dalphaw); - float dw[9], dh[9]; - // dw[i] = d(iw)/d(mat[i]) - dw[0] = ow / denominator; - dw[1] = oh / denominator; - dw[2] = 1.0f / denominator; - dw[3] = 0.0f; - dw[4] = 0.0f; - dw[5] = 0.0f; - float ddenominatorw = -numeratorw / denominator2; - dw[6] = ow * ddenominatorw; - dw[7] = oh * ddenominatorw; - dw[8] = 1.0f * ddenominatorw; - // dh[i] = d(ih)/d(mat[i]) - dh[0] = 0.0f; - dh[1] = 0.0f; - dh[2] = 0.0f; - dh[3] = ow / denominator; - dh[4] = oh / denominator; - dh[5] = 1.0f / denominator; - float ddenominatorh = -numeratorh / denominator2; - dh[6] = ow * ddenominatorh; - dh[7] = oh * ddenominatorh; - dh[8] = 1.0f * ddenominatorh; - rep(i, 9) { - float delta = hidden * dalphaw * dw[i] + - hidden * dalphah * dh[i]; - if (std::isfinite(delta)) - res[i] += delta; - } - } - } - hptr += C * OH * OW; - sptr += C * IH * IW; - mptr += 3 * 3; - res += 3 * 3; - } - }; - MEGDNN_DISPATCH_CPU_KERN_OPR(kern()); +#define DISPATCH_ST_MT(dt, ct) \ + if (src.layout.dtype.enumv() == DTypeTrait
::enumv) { \ + if (mat.layout.dtype.enumv() == DTypeTrait::enumv) { \ + auto kparam = KernParam::from_tensors( \ + param().border_val, src, mat, diff, grad); \ + MEGDNN_DISPATCH_CPU_KERN_OPR(kern_naive(kparam)); \ + return; \ + } else { \ + auto kparam = KernParam::from_tensors( \ + param().border_val, src, mat, diff, grad); \ + MEGDNN_DISPATCH_CPU_KERN_OPR(kern_naive(kparam)); \ + return; \ + } \ + } + DISPATCH_ST_MT(dtype::Float32, dt_float32); + MEGDNN_INC_FLOAT16(DISPATCH_ST_MT(dtype::BFloat16, dt_bfloat16)); + megdnn_throw(ssprintf("Unsupported input DType in " + "WarpPerspective: %s", + diff.layout.dtype.name()) + .c_str()); +#undef DISPATCH_ST_MT } // vim: syntax=cpp.doxygen diff --git a/dnn/src/naive/warp_perspective/opr_impl.h b/dnn/src/naive/warp_perspective/opr_impl.h index 9ef0ca514e45750ac7ac3ac2deed51b1bd893906..d2a5dc7f8f2dd1ad528e0833e0700947071f2cc2 100644 --- a/dnn/src/naive/warp_perspective/opr_impl.h +++ b/dnn/src/naive/warp_perspective/opr_impl.h @@ -76,7 +76,8 @@ class WarpPerspectiveForwardImpl: public WarpPerspectiveForward { } if (src.layout.dtype.enumv() == DTypeEnum::Float32 || MEGDNN_FLOAT16_SELECT( - src.layout.dtype.enumv() == DTypeEnum::Float16, + (src.layout.dtype.enumv() == DTypeEnum::Float16 || + src.layout.dtype.enumv() == DTypeEnum::BFloat16), false) || src.layout.dtype.enumv() == DTypeEnum::Int8 || src.layout.dtype.enumv() == DTypeEnum::Uint8 || @@ -123,6 +124,27 @@ class WarpPerspectiveForwardImpl: public WarpPerspectiveForward { }; class WarpPerspectiveBackwardDataImpl : public WarpPerspectiveBackwardData { +protected: + template + struct KernParam { + size_t n, c, ih, iw, oh, ow; + ctype *sptr, *hptr; + mtype* mptr; + + static KernParam from_tensors(_megdnn_tensor_in mat, + _megdnn_tensor_in diff, + _megdnn_tensor_out grad) { + KernParam ret; + ret.n = grad.layout.shape[0], ret.c = grad.layout.shape[1], + ret.ih = grad.layout.shape[2], ret.iw = grad.layout.shape[3]; + ret.oh = diff.layout.shape[2], ret.ow = diff.layout.shape[3]; + ret.hptr = diff.ptr(); + ret.mptr = mat.ptr(); + ret.sptr = grad.ptr(); + return ret; + } + }; + public: using WarpPerspectiveBackwardData::WarpPerspectiveBackwardData; void exec(_megdnn_tensor_in mat, _megdnn_tensor_in diff, @@ -131,9 +153,36 @@ public: const TensorLayout&) override { return 0; } +private: + template + void kern_naive(const KernParam& kern_param); }; class WarpPerspectiveBackwardMatImpl : public WarpPerspectiveBackwardMat { +protected: + template + struct KernParam { + size_t n, c, ih, iw, oh, ow; + ctype *sptr, *hptr; + mtype* mptr, *res; + float border_val; + static KernParam from_tensors(float border_val_, _megdnn_tensor_in src, + _megdnn_tensor_in mat, + _megdnn_tensor_in diff, + _megdnn_tensor_out grad) { + KernParam ret; + ret.border_val = border_val_; + ret.n = src.layout.shape[0], ret.c = src.layout.shape[1], + ret.ih = src.layout.shape[2], ret.iw = src.layout.shape[3]; + ret.oh = diff.layout.shape[2], ret.ow = diff.layout.shape[3]; + ret.hptr = diff.ptr(); + ret.mptr = mat.ptr(); + ret.sptr = src.ptr(); + ret.res = grad.ptr(); + return ret; + } + }; + public: using WarpPerspectiveBackwardMat::WarpPerspectiveBackwardMat; void exec(_megdnn_tensor_in src, _megdnn_tensor_in mat, @@ -144,6 +193,10 @@ public: const TensorLayout&) override { return 0; } + +private: + template + void kern_naive(const KernParam& kern_param); }; #define UNPACK_WARP_PERSPECTIVE_FWD_KERN_PARAM(p) \ diff --git a/dnn/test/common/dtype.cpp b/dnn/test/common/dtype.cpp index 
6aa504c34a9c10462020d04ef0f614cc5043845b..92f1a8dcda6485b4beb6373674590e913749c50b 100644 --- a/dnn/test/common/dtype.cpp +++ b/dnn/test/common/dtype.cpp @@ -10,9 +10,12 @@ */ #include "megdnn/dtype.h" +#include "megdnn/dtype/bfloat16.hpp" #include #include "test/common/fix_gtest_on_platforms_without_exception.inl" +#include + TEST(TestDType, SizeCheck) { ASSERT_EQ(static_cast(1), ::megdnn::dtype::Int8().size()); ASSERT_EQ(static_cast(1), ::megdnn::dtype::IntB2().size(1)); @@ -97,4 +100,128 @@ TEST(TestDType, TestQuantizedS4) { EXPECT_ANY_THROW(DType::from_enum(DTypeEnum::QuantizedS4)); } +TEST(TestDType, BFLOAT16) { + using namespace megdnn; + using namespace half_bfloat16; + //! Basic bfloat16 dtype tests using RNE round. + bfloat16 m1(2.3515625f), m2(2.351563f), m3(229), m4(-311); + ASSERT_FLOAT_EQ(static_cast(m1), 2.34375f); + ASSERT_FLOAT_EQ(static_cast(m2), 2.359375f); + ASSERT_FLOAT_EQ(static_cast(m3), 229.f); + ASSERT_FLOAT_EQ(static_cast(m4), -312.f); + m3 = -2.3515625f; + m4 = -2.351563f; + ASSERT_FLOAT_EQ(static_cast(m1), static_cast(-m3)); + ASSERT_FLOAT_EQ(static_cast(m2), static_cast(-m4)); + m3 = 2.34375f; + m3 += m2; + m4 = m1; + m4 *= m2; + ASSERT_FLOAT_EQ(static_cast(m3), 4.6875f); + ASSERT_FLOAT_EQ(static_cast(m4), 5.53125f); + m3 -= 2.359375f; + m4 /= 2.359375f; + ASSERT_FLOAT_EQ(static_cast(m3), 2.328125f); + ASSERT_FLOAT_EQ(static_cast(m4), 2.34375f); + m3++; + ++m3; + m4++; + ++m4; + ASSERT_FLOAT_EQ(static_cast(m3), 4.3125f); + ASSERT_FLOAT_EQ(static_cast(m4), 4.34375f); + m3--; + --m3; + m4--; + --m4; + ASSERT_FLOAT_EQ(static_cast(m3), 2.3125f); + ASSERT_FLOAT_EQ(static_cast(m4), 2.34375f); + + //! Comparison operators + ASSERT_TRUE(m1 == m4 && m1 >= m4 && m1 <= m4); + ASSERT_TRUE(m3 != m4 && m4 > m3); + ASSERT_FALSE(m2 < m4); + + //! Arithmetic operators + ASSERT_FLOAT_EQ(m1 + m2, 4.703125f); + ASSERT_FLOAT_EQ(m4 - 3.43281f, -1.08906f); + ASSERT_FLOAT_EQ(-2.34f * m3, -5.41125f); + ASSERT_FLOAT_EQ(9.92625f / m1, 4.2352f); + + //! 
Basic mathematical operations
+ bfloat16 b1(-0.5f), b2(0.5f), b3(7.21875);
+ ASSERT_FLOAT_EQ(abs(b1), abs(b2));
+ ASSERT_FLOAT_EQ(acos(b1), 2.094395f);
+ ASSERT_FLOAT_EQ(acosh(b3), 2.66499658f);
+ ASSERT_FLOAT_EQ(asin(b1), -0.523599f);
+ ASSERT_FLOAT_EQ(asinh(b1), -0.48121183f);
+ ASSERT_FLOAT_EQ(atan(b1), -0.4636476f);
+ ASSERT_FLOAT_EQ(atan2(b1, b3), -0.06915362f);
+ ASSERT_FLOAT_EQ(cbrt(b1), -0.79370053f);
+ ASSERT_FLOAT_EQ(static_cast<float>(ceil(b1)), 0.0f);
+ ASSERT_FLOAT_EQ(cos(b1), 0.87758255f);
+ ASSERT_FLOAT_EQ(cosh(b1), 1.12762594f);
+ ASSERT_FLOAT_EQ(erf(b1), -0.52049988f);
+ ASSERT_FLOAT_EQ(erfc(b1), 1.52049988f);
+ ASSERT_FLOAT_EQ(exp(b2), 1.64872122f);
+ ASSERT_FLOAT_EQ(exp2(b2), 1.41421356f);
+ ASSERT_FLOAT_EQ(expm1(b2), 0.64872127f);
+ ASSERT_FLOAT_EQ(fdim(b2, b1), 1.0f);
+ ASSERT_FLOAT_EQ(floor(b1), -1.0f);
+ ASSERT_FLOAT_EQ(fma(b1, b2, b1), -0.75f);
+ ASSERT_FLOAT_EQ(fmax(b1, b2), 0.5f);
+ ASSERT_FLOAT_EQ(fmin(b1, b2), -0.5f);
+ ASSERT_FLOAT_EQ(fmod(b3, b2), 0.21875f);
+ ASSERT_FLOAT_EQ(hypot(b2, b3), 7.23604530f);
+ ASSERT_FLOAT_EQ(lgamma(b1), 1.26551212f);
+ ASSERT_FLOAT_EQ(log(b3), 1.97668183f);
+ ASSERT_FLOAT_EQ(log10(b3), 0.85846198f);
+ ASSERT_FLOAT_EQ(log1p(b3), 2.10641813f);
+ ASSERT_FLOAT_EQ(log2(b3), 2.85174904f);
+ ASSERT_FLOAT_EQ(lrint(b3), 7.f);
+ ASSERT_EQ(lround(b1), -1);
+ ASSERT_EQ(lround(b2), 1);
+ ASSERT_TRUE(isnan(nanh("")));
+ ASSERT_FLOAT_EQ(nearbyint(b3), 7.f);
+ ASSERT_FLOAT_EQ(pow(b3, 2.53f), 148.56237793f);
+ ASSERT_FLOAT_EQ(remainder(b3, b2), 0.21875f);
+ ASSERT_FLOAT_EQ(sin(b1), -0.47942555f);
+ ASSERT_FLOAT_EQ(sinh(b1), -0.52109528f);
+ ASSERT_FLOAT_EQ(sqrt(b3), 2.68677306f);
+ ASSERT_FLOAT_EQ(tan(b3), 1.35656071f);
+ ASSERT_FLOAT_EQ(tanh(b3), 0.99999893f);
+ ASSERT_FLOAT_EQ(tgamma(b3), 1088.50023434f);
+ ASSERT_FLOAT_EQ(trunc(b1), 0.0f);
+ ASSERT_FLOAT_EQ(trunc(b3), 7.0f);
+ ASSERT_FLOAT_EQ(static_cast<float>(copysign(b1, b2)), 0.5f);
+ int i = 0;
+ ASSERT_FLOAT_EQ(static_cast<float>(frexp(b3, &i)), 0.90234375f);
+ ASSERT_EQ(i, 3);
+ ASSERT_EQ(ilogb(b3), 2);
+ ASSERT_FLOAT_EQ(static_cast<float>(ldexp(b3, 4)), 115.50f);
+ ASSERT_FLOAT_EQ(static_cast<float>(logb(b3)), 2.f);
+ bfloat16 bf(0.f);
+ ASSERT_FLOAT_EQ(static_cast<float>(modf(b3, &bf)), 0.21875f);
+ ASSERT_FLOAT_EQ(static_cast<float>(bf), 7.f);
+ ASSERT_FLOAT_EQ(static_cast<float>(nextafter(b2, b3)), 0.50390625f);
+ ASSERT_FLOAT_EQ(static_cast<float>(nextafter(b2, b1)), 0.49804688f);
+ ASSERT_TRUE(signbit(b1));
+ ASSERT_FALSE(signbit(b2));
+
+ //! Special values (Inf/NaN).
+ //! float -> bfloat16
+ float finf = std::numeric_limits<float>::infinity(),
+ fnan = std::numeric_limits<float>::quiet_NaN();
+ bfloat16 bfinf(finf), bfnan(fnan);
+ ASSERT_TRUE(isinf(bfinf));
+ ASSERT_FALSE(isfinite(bfinf));
+ ASSERT_TRUE(isnan(bfnan));
+ ASSERT_FALSE(isnormal(bfnan));
+
+ //! bfloat16 -> float
+ bfinf = std::numeric_limits<bfloat16>::infinity();
+ finf = bfinf;
+ ASSERT_TRUE(std::isinf(finf));
+ ASSERT_FALSE(std::isfinite(finf));
+}
+
// vim: syntax=cpp.doxygen
diff --git a/dnn/test/common/extra_impl_helper.cpp b/dnn/test/common/extra_impl_helper.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..aa3d5486b6ffbc8c65840533052a451d63e93e7d
--- /dev/null
+++ b/dnn/test/common/extra_impl_helper.cpp
@@ -0,0 +1,49 @@
+/**
+ * \file test/common/extra_impl_helper.cpp
+ * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
+ *
+ * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
+ * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + */ + +#include "test/common/extra_impl_helper.h" + +namespace megdnn { +namespace test { + +template <> +std::function extra_impl_helper( + Handle* h, const AddUpdate::Param& p) { + auto impl = [](const TensorNDArray& tensors, Handle* h, + const AddUpdate::Param& p) { + auto fp32_opr = h->create_operator(); + auto type_cvt = h->create_operator(); + fp32_opr->param() = p; + + TensorNDArray fp32_tensors; + for (size_t i = 0; i < tensors.size(); ++i) { + auto layout = tensors[i].layout; + layout.dtype = dtype::Float32(); + fp32_tensors.emplace_back(malloc(layout.span().dist_byte()), + layout); + type_cvt->exec(tensors[i], fp32_tensors[i]); + } + + fp32_opr->exec(fp32_tensors[0], fp32_tensors[1]); + + type_cvt->exec(fp32_tensors[0], tensors[0]); + + for (size_t i = 0; i < tensors.size(); ++i) { + free(fp32_tensors[i].raw_ptr); + } + }; + return std::bind(impl, std::placeholders::_1, h, std::cref(p)); +} + +} // namespace test +} // namespace megdnn + +// vim: syntax=cpp.doxygen diff --git a/dnn/test/common/extra_impl_helper.h b/dnn/test/common/extra_impl_helper.h new file mode 100644 index 0000000000000000000000000000000000000000..631483c5d13ba9616e973d5ac055425b026cc036 --- /dev/null +++ b/dnn/test/common/extra_impl_helper.h @@ -0,0 +1,63 @@ +/** + * \file test/common/extra_impl_helper.h + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ */ +#pragma once + +#include "test/common/opr_proxy.h" +#include "megdnn/oprs/general.h" +#include "megdnn/basic_types.h" +#include "megdnn/handle.h" + +namespace megdnn { +namespace test { + +template > +std::function extra_impl_helper( + Handle* h, const typename Opr::Param& p) { + auto impl = [](const TensorNDArray& tensors, Handle* h, + const typename Opr::Param& p) { + static_assert(NR_OUTPUTS <= OprTrait::arity, + "OutNumber should less than or equal to arity."); + Proxy proxy; + auto fp32_opr = h->create_operator(); + auto type_cvt = h->create_operator(); + fp32_opr->param() = p; + + TensorNDArray fp32_tensors; + for (size_t i = 0; i < tensors.size(); ++i) { + auto layout = tensors[i].layout; + layout.dtype = dtype::Float32(); + fp32_tensors.emplace_back(malloc(layout.span().dist_byte()), + layout); + type_cvt->exec(tensors[i], fp32_tensors[i]); + } + + proxy.exec(fp32_opr.get(), fp32_tensors); + + for (size_t i = fp32_tensors.size() - NR_OUTPUTS; + i < fp32_tensors.size(); ++i) { + type_cvt->exec(fp32_tensors[i], tensors[i]); + } + + for (size_t i = 0; i < tensors.size(); ++i) { + free(fp32_tensors[i].raw_ptr); + } + }; + return std::bind(impl, std::placeholders::_1, h, std::cref(p)); +} + +template <> +std::function extra_impl_helper( + Handle* h, const AddUpdate::Param& p); + +} // namespace test +} // namespace megdnn + +// vim: syntax=cpp.doxygen diff --git a/dnn/test/common/rng.h b/dnn/test/common/rng.h index 5bea1b12abbc0dda590a5d145875c27214706b88..7af67117573ae87f6e3d304ccfa19f6ff17690c2 100644 --- a/dnn/test/common/rng.h +++ b/dnn/test/common/rng.h @@ -12,6 +12,7 @@ #include "megdnn/dtype.h" #include "test/common/utils.h" +#include "test/common/random_state.h" #include #include @@ -41,6 +42,54 @@ private: std::vector m_sequence; }; +class BFloat16PeriodicalRNG : public RNG { +public: + BFloat16PeriodicalRNG() { + size_t bits = sizeof(dt_bfloat16) * 8; + size_t mantissa_bits = std::numeric_limits::digits - 1; + size_t exponent_bits = bits - mantissa_bits - 1; + for (size_t exp = 1u << (exponent_bits - 2); + exp < (1u << exponent_bits) - (1u << (exponent_bits - 2)); ++exp) { + for (size_t x = 0; x < 1u << mantissa_bits; ++x) { + size_t pos_num = (exp << mantissa_bits) + x; + size_t neg_num = + (1u << (bits - 1)) + (exp << mantissa_bits) + x; + union U { + U() {} + uint16_t i; + dt_bfloat16 f; + } i2f; + i2f.i = static_cast(pos_num); + m_sequence.push_back(i2f.f); + i2f.i = static_cast(neg_num); + m_sequence.push_back(i2f.f); + } + } + std::shuffle(m_sequence.begin(), m_sequence.end(), + RandomState::generator()); + } + + void gen(const TensorND& tensor) override { + megdnn_assert(tensor.layout.dtype.enumv() == DTypeTrait::enumv); + size_t nr_elems = tensor.layout.span().dist_elem(); + auto offset = tensor.layout.span().low_elem; + for (size_t i = 0; i < nr_elems; ++i) { + tensor.ptr()[offset + i] = get_single_val(); + } + } + + dt_bfloat16 get_single_val() { + if (m_offset >= m_sequence.size()) { + m_offset = 0; + } + return m_sequence[m_offset++]; + } + +private: + size_t m_offset = 0; + std::vector m_sequence; +}; + class IIDRNG : public RNG { public: void gen(const TensorND& tensor) override; diff --git a/dnn/test/cuda/add_update.cpp b/dnn/test/cuda/add_update.cpp index f92703a32dcad7bd299131aa364c1678a80e880c..328a2cce1c638d8745b638421300c98d730a0d20 100644 --- a/dnn/test/cuda/add_update.cpp +++ b/dnn/test/cuda/add_update.cpp @@ -25,10 +25,15 @@ TEST_F(CUDA, ADD_UPDATE) { checker.set_dtype(0, dtype::Float16()) .set_dtype(1, dtype::Float16()) .execs({{2, 3, 4}, {2, 3, 
4}}); + checker.set_dtype(0, dtype::BFloat16()) + .set_dtype(1, dtype::BFloat16()) + .execs({{2, 3, 4}, {2, 3, 4}}); checker.execl({{{2, 3, 4}, dtype::Float32()}, {{2, 3, 4}, {16, 4, 1}, dtype::Float32()}}); checker.execl({{{2, 3, 4}, dtype::Float16()}, {{2, 3, 4}, {16, 4, 1}, dtype::Float16()}}); + checker.execl({{{2, 3, 4}, dtype::BFloat16()}, + {{2, 3, 4}, {16, 4, 1}, dtype::BFloat16()}}); checker.execl({{{2, 3, 4}, {16, 4, 1}, dtype::Float32()}, {{2, 3, 4}, dtype::Float32()}}); @@ -46,7 +51,7 @@ TEST_F(CUDA, ADD_UPDATE) { checker.set_dtype(0, dtype::Uint8()) .set_dtype(1, dtype::Uint8()) .execs({{2, 3, 2}, {2, 3, 2}}); - // test scalar + // test scalar checker.set_dtype(0, dtype::Int8()) .set_dtype(1, dtype::Int8()) .execs({{1}, {1}}); diff --git a/dnn/test/cuda/convolution.cpp b/dnn/test/cuda/convolution.cpp index 0115e7e090a686240927c76fbdf31859d0b136e4..ff3e42e089cd3bf5e6806fa50335cc69a0b8ed09 100644 --- a/dnn/test/cuda/convolution.cpp +++ b/dnn/test/cuda/convolution.cpp @@ -108,6 +108,13 @@ TEST_F(CUDA, CONVOLUTION_FORWARD) .set_epsilon(1e-1) .set_param(arg.param) .execs({arg.src, arg.filter, {}}); + checker.set_dtype(0, dtype::BFloat16()) + .set_dtype(1, dtype::BFloat16()) + .set_dtype(2, dtype::BFloat16()) + .set_epsilon(1e-1) + .set_param(arg.param) + .execs({arg.src, arg.filter, {}}); + } } @@ -216,6 +223,13 @@ TEST_F(CUDA, CONVOLUTION_BACKWARD_DATA) .set_epsilon(1e-1) .set_param(arg.param) .exec(TensorLayoutArray{filter, dst, src}); + src.dtype = dst.dtype = filter.dtype = dtype::BFloat16(); + checker. + set_rng(0, &rng). + set_rng(1, &rng). + set_epsilon(1e-1). + set_param(arg.param). + exec(TensorLayoutArray{filter, dst, src}); } } } @@ -308,6 +322,12 @@ TEST_F(CUDA, CONVOLUTION_BACKWARD_FILTER) .set_epsilon(1e-1) .set_param(arg.param) .exec(TensorLayoutArray{src, dst, filter}); + src.dtype = dst.dtype = filter.dtype = dtype::BFloat16(); + checker.set_rng(0, &rng) + .set_rng(1, &rng) + .set_epsilon(1e-1) + .set_param(arg.param) + .exec(TensorLayoutArray{src, dst, filter}); } } diff --git a/dnn/test/cuda/elemwise.cpp b/dnn/test/cuda/elemwise.cpp index 13ef029a048b889bcb26e3fcd56ca2e2134bf430..1ae79d3b0a669cafa0b01bd6078079e5dc67a9cf 100644 --- a/dnn/test/cuda/elemwise.cpp +++ b/dnn/test/cuda/elemwise.cpp @@ -174,6 +174,105 @@ TEST_F(CUDA, ELEMWISE_IBYTE) { #undef RUN_TERNARY_IBYTE } +// from common/elemwise.cpp +TEST_F(CUDA, ELEMWISE_BFLOAT16) { + using Mode = ElemwiseForward::Param::Mode; + Checker checker(handle_cuda()); + + // unary +#define UNARY_TEST_CASE(_optr) \ + checker.set_param(Mode::_optr).execs({{1, 127}, {}}); \ + checker.set_param(Mode::_optr).execs({{1, 7}, {}}); + +#define BUILD_UNARY_TEST_CASE_FLOAT \ + UNARY_TEST_CASE(ABS) \ + UNARY_TEST_CASE(LOG) \ + UNARY_TEST_CASE(COS) \ + UNARY_TEST_CASE(SIN) \ + UNARY_TEST_CASE(FLOOR) \ + UNARY_TEST_CASE(CEIL) \ + UNARY_TEST_CASE(SIGMOID) \ + UNARY_TEST_CASE(EXP) \ + UNARY_TEST_CASE(TANH) \ + UNARY_TEST_CASE(FAST_TANH) \ + UNARY_TEST_CASE(RELU) \ + UNARY_TEST_CASE(ROUND) + + checker.set_dtype(0, dtype::BFloat16()); + checker.set_dtype(1, dtype::BFloat16()); + UniformFloatRNG rng0(1e-2, 6e1); + checker.set_rng(0, &rng0); + checker.set_epsilon(1e-2); + BUILD_UNARY_TEST_CASE_FLOAT + +#undef UNARY_TEST_CASE +#undef BUILD_UNARY_TEST_CASE_FLOAT + + // binary +#define BINARY_COMPLATE_TEST_CASE(_optr) \ + checker.set_param(Mode::_optr).execs({{3, 4, 7}, {3, 4, 7}, {}}); \ + checker.set_param(Mode::_optr).execs({{3, 4, 5, 7}, {1, 4, 1, 1}, {}}); \ + checker.set_param(Mode::_optr).execs({{1, 4, 1, 1}, {3, 4, 5, 7}, {}}); \ + 
diff --git a/dnn/test/cuda/convolution.cpp b/dnn/test/cuda/convolution.cpp
index 0115e7e090a686240927c76fbdf31859d0b136e4..ff3e42e089cd3bf5e6806fa50335cc69a0b8ed09 100644
--- a/dnn/test/cuda/convolution.cpp
+++ b/dnn/test/cuda/convolution.cpp
@@ -108,6 +108,13 @@ TEST_F(CUDA, CONVOLUTION_FORWARD)
                 .set_epsilon(1e-1)
                 .set_param(arg.param)
                 .execs({arg.src, arg.filter, {}});
+        checker.set_dtype(0, dtype::BFloat16())
+                .set_dtype(1, dtype::BFloat16())
+                .set_dtype(2, dtype::BFloat16())
+                .set_epsilon(1e-1)
+                .set_param(arg.param)
+                .execs({arg.src, arg.filter, {}});
+
     }
 }
@@ -216,6 +223,13 @@ TEST_F(CUDA, CONVOLUTION_BACKWARD_DATA)
                 .set_epsilon(1e-1)
                 .set_param(arg.param)
                 .exec(TensorLayoutArray{filter, dst, src});
+        src.dtype = dst.dtype = filter.dtype = dtype::BFloat16();
+        checker.set_rng(0, &rng)
+                .set_rng(1, &rng)
+                .set_epsilon(1e-1)
+                .set_param(arg.param)
+                .exec(TensorLayoutArray{filter, dst, src});
     }
 }
@@ -308,6 +322,12 @@ TEST_F(CUDA, CONVOLUTION_BACKWARD_FILTER)
                 .set_epsilon(1e-1)
                 .set_param(arg.param)
                 .exec(TensorLayoutArray{src, dst, filter});
+        src.dtype = dst.dtype = filter.dtype = dtype::BFloat16();
+        checker.set_rng(0, &rng)
+                .set_rng(1, &rng)
+                .set_epsilon(1e-1)
+                .set_param(arg.param)
+                .exec(TensorLayoutArray{src, dst, filter});
     }
 }
diff --git a/dnn/test/cuda/elemwise.cpp b/dnn/test/cuda/elemwise.cpp
index 13ef029a048b889bcb26e3fcd56ca2e2134bf430..1ae79d3b0a669cafa0b01bd6078079e5dc67a9cf 100644
--- a/dnn/test/cuda/elemwise.cpp
+++ b/dnn/test/cuda/elemwise.cpp
@@ -174,6 +174,105 @@ TEST_F(CUDA, ELEMWISE_IBYTE) {
 #undef RUN_TERNARY_IBYTE
 }
 
+// from common/elemwise.cpp
+TEST_F(CUDA, ELEMWISE_BFLOAT16) {
+    using Mode = ElemwiseForward::Param::Mode;
+    Checker<ElemwiseForward> checker(handle_cuda());
+
+    // unary
+#define UNARY_TEST_CASE(_optr)                            \
+    checker.set_param(Mode::_optr).execs({{1, 127}, {}}); \
+    checker.set_param(Mode::_optr).execs({{1, 7}, {}});
+
+#define BUILD_UNARY_TEST_CASE_FLOAT \
+    UNARY_TEST_CASE(ABS)            \
+    UNARY_TEST_CASE(LOG)            \
+    UNARY_TEST_CASE(COS)            \
+    UNARY_TEST_CASE(SIN)            \
+    UNARY_TEST_CASE(FLOOR)          \
+    UNARY_TEST_CASE(CEIL)           \
+    UNARY_TEST_CASE(SIGMOID)        \
+    UNARY_TEST_CASE(EXP)            \
+    UNARY_TEST_CASE(TANH)           \
+    UNARY_TEST_CASE(FAST_TANH)      \
+    UNARY_TEST_CASE(RELU)           \
+    UNARY_TEST_CASE(ROUND)
+
+    checker.set_dtype(0, dtype::BFloat16());
+    checker.set_dtype(1, dtype::BFloat16());
+    UniformFloatRNG rng0(1e-2, 6e1);
+    checker.set_rng(0, &rng0);
+    checker.set_epsilon(1e-2);
+    BUILD_UNARY_TEST_CASE_FLOAT
+
+#undef UNARY_TEST_CASE
+#undef BUILD_UNARY_TEST_CASE_FLOAT
+
+    // binary
+#define BINARY_COMPLATE_TEST_CASE(_optr)                                    \
+    checker.set_param(Mode::_optr).execs({{3, 4, 7}, {3, 4, 7}, {}});       \
+    checker.set_param(Mode::_optr).execs({{3, 4, 5, 7}, {1, 4, 1, 1}, {}}); \
+    checker.set_param(Mode::_optr).execs({{1, 4, 1, 1}, {3, 4, 5, 7}, {}}); \
+    checker.set_param(Mode::_optr).execs({{3, 4, 7}, {1, 4, 1}, {}});       \
+    checker.set_param(Mode::_optr).execs({{1, 4, 1}, {3, 4, 7}, {}});       \
+    checker.set_param(Mode::_optr).execs({{3, 4, 5, 7}, {1, 1, 1, 1}, {}}); \
+    checker.set_param(Mode::_optr).execs({{1, 1, 1, 1}, {3, 4, 5, 7}, {}}); \
+    checker.set_param(Mode::_optr).execs({{1, 7}, {1, 7}, {}});             \
+    checker.set_param(Mode::_optr).execs({{1, 2, 2}, {1, 2, 1}, {}});       \
+    checker.set_param(Mode::_optr).execs({{1, 2, 1}, {1, 2, 2}, {}});       \
+    checker.set_param(Mode::_optr).execs({{1, 2, 2}, {1, 1, 1}, {}});       \
+    checker.set_param(Mode::_optr).execs({{1, 1, 1}, {1, 2, 2}, {}});       \
+    checker.set_param(Mode::_optr).execs({{3, 4, 1}, {3, 4, 1}, {}});
+
+#define BUILD_BINARY_COMPLATE_TEST_CASE \
+    BINARY_COMPLATE_TEST_CASE(ADD)      \
+    BINARY_COMPLATE_TEST_CASE(MUL)      \
+    BINARY_COMPLATE_TEST_CASE(MAX)      \
+    BINARY_COMPLATE_TEST_CASE(MIN)      \
+    BINARY_COMPLATE_TEST_CASE(SUB)
+
+    UniformFloatRNG rng1(1e-5, 7e1);
+    checker.set_rng(0, &rng1);
+    checker.set_epsilon(1e-2);
+    checker.set_dtype(0, dtype::BFloat16());
+    checker.set_dtype(1, dtype::BFloat16());
+    BUILD_BINARY_COMPLATE_TEST_CASE
+
+#undef BINARY_COMPLATE_TEST_CASE
+#undef BUILD_BINARY_COMPLATE_TEST_CASE
+
+    // ternary
+#define TERNARY_COMPLATE_TEST_CASE(_optr)                                \
+    checker.set_param(Mode::_optr)                                       \
+            .execs({{3, 4, 7}, {3, 4, 7}, {3, 4, 7}, {}});               \
+    checker.set_param(Mode::_optr)                                       \
+            .execs({{1, 4, 1, 1}, {3, 4, 5, 7}, {1, 4, 1, 1}, {}});      \
+    checker.set_param(Mode::_optr)                                       \
+            .execs({{1, 4, 1}, {3, 4, 7}, {1, 4, 1}, {}});               \
+    checker.set_param(Mode::_optr)                                       \
+            .execs({{3, 4, 5, 7}, {3, 4, 5, 7}, {1, 1, 1, 1}, {}});      \
+    checker.set_param(Mode::_optr).execs({{1, 7}, {1, 7}, {1, 7}, {}});  \
+    checker.set_param(Mode::_optr)                                       \
+            .execs({{1, 2, 1}, {1, 2, 2}, {1, 2, 1}, {}});               \
+    checker.set_param(Mode::_optr)                                       \
+            .execs({{1, 2, 2}, {1, 2, 2}, {1, 1, 1}, {}});               \
+    checker.set_param(Mode::_optr).execs({{3, 4, 1}, {3, 4, 1}, {3, 4, 1}, {}});
+
+#define BUILD_TERNARY_COMPLATE_TEST_CASE \
+    TERNARY_COMPLATE_TEST_CASE(FUSE_MUL_ADD3)
+
+    UniformFloatRNG rng2(1e-5, 7e1);
+    checker.set_rng(0, &rng2);
+    checker.set_epsilon(1e-2);
+    checker.set_dtype(0, dtype::BFloat16());
+    checker.set_dtype(1, dtype::BFloat16());
+    checker.set_dtype(2, dtype::BFloat16());
+    BUILD_TERNARY_COMPLATE_TEST_CASE
+
+#undef TERNARY_COMPLATE_TEST_CASE
+#undef BUILD_TERNARY_COMPLATE_TEST_CASE
+}
+
 //! the memory of this test case is too large, sometimes will fail on tx1
 TEST_F(CUDA, ELEMWISE_BENCHMARK_DENSE) {
     constexpr size_t A = 256 * 1024 * 64,
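Why these tests all use epsilons around 1e-2 to 1e-1 while the Float32 variants stay at 5e-3 or tighter: BFloat16 trades significand for range. A standalone sanity check, assuming `std::numeric_limits` is specialized for `dt_bfloat16` as the dtype header implies:

```cpp
#include <cstdio>
#include <limits>
#include "megdnn/dtype.h"

int main() {
    using bf16 = megdnn::dt_bfloat16;
    // 8 significand bits (7 stored + 1 implicit) and Float32's 8-bit
    // exponent, so one ULP at 1.0 is 2^-7 ~= 7.8e-3 -- hence the loose
    // test epsilons.
    std::printf("digits  = %d\n", std::numeric_limits<bf16>::digits);
    std::printf("epsilon = %g\n",
                static_cast<float>(std::numeric_limits<bf16>::epsilon()));
    return 0;
}
```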
diff --git a/dnn/test/cuda/matrix_mul.cpp b/dnn/test/cuda/matrix_mul.cpp
index f00e466a31f2fc49efeda20718f4aec9212d5e1e..0233a17ed2854c49e00cf501f2e051bc7fb0538c 100644
--- a/dnn/test/cuda/matrix_mul.cpp
+++ b/dnn/test/cuda/matrix_mul.cpp
@@ -184,8 +184,7 @@ TEST_F(CUDA, MATRIX_MUL_INT8x8x32_NAIVE) {
     }
 }
 
-TEST_F(CUDA, MATRIX_MUL)
-{
+TEST_F(CUDA, MATRIX_MUL) {
     if (cuda::current_device_prop().major < 6) {
         printf("Skip CUDA.MATRIX_MUL test as current device doesn't support\n");
         return;
@@ -198,6 +197,7 @@ TEST_F(CUDA, MATRIX_MUL)
     std::vector<DType> dtype_array;
     dtype_array.push_back(dtype::Float32());
     dtype_array.push_back(dtype::Float16());
+    dtype_array.push_back(dtype::BFloat16());
     if (is_int_available)
         dtype_array.push_back(dtype::Int32());
@@ -216,12 +216,18 @@ TEST_F(CUDA, MATRIX_MUL)
             B = TensorShape{n, k};
         else
             B = TensorShape{k, n};
-        checker.set_param(param).
-                set_dtype(0, stype).
-                set_dtype(1, stype).
-                set_dtype(2, dtype).
-                set_epsilon(dtype == dtype::Float16() ? 5e-2 : 5e-3).
-                execs({A, B, {}});
+        if (dtype == dtype::BFloat16()) {
+            param.compute_mode = param::MatrixMul::ComputeMode::FLOAT32;
+        }
+        checker.set_param(param)
+                .set_dtype(0, stype)
+                .set_dtype(1, stype)
+                .set_dtype(2, dtype)
+                .set_epsilon(dtype == dtype::Float16() ||
+                                             dtype == dtype::BFloat16()
+                                     ? 5e-2
+                                     : 5e-3)
+                .execs({A, B, {}});
     }
 }
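The BFloat16 branch above switches the matmul to `ComputeMode::FLOAT32`. A reference loop sketching what that mode means numerically (a sketch of the semantics only, not MegDNN's actual kernel): inputs and outputs stay BFloat16, but products are accumulated in Float32.

```cpp
#include <cstddef>
#include "megdnn/dtype.h"

void gemm_bf16_fp32_acc(const megdnn::dt_bfloat16* A,
                        const megdnn::dt_bfloat16* B, megdnn::dt_bfloat16* C,
                        size_t M, size_t K, size_t N) {
    for (size_t m = 0; m < M; ++m) {
        for (size_t n = 0; n < N; ++n) {
            float acc = 0.f;  // Float32 accumulator: the point of the mode
            for (size_t k = 0; k < K; ++k)
                acc += static_cast<float>(A[m * K + k]) *
                       static_cast<float>(B[k * N + n]);
            // single final rounding instead of K intermediate ones
            C[m * N + n] = megdnn::dt_bfloat16(acc);
        }
    }
}
```

Accumulating directly in BFloat16 would lose low-order partial sums once the running total grows, which is why the mode exists and why the test still needs the 5e-2 epsilon.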
diff --git a/dnn/test/cuda/pooling.cpp b/dnn/test/cuda/pooling.cpp
index 1da117afe4ea83f504e7c3c816e6de35f23865c1..69a7ccc53adda216abfaa2d566a3b4cbbfa019ce 100644
--- a/dnn/test/cuda/pooling.cpp
+++ b/dnn/test/cuda/pooling.cpp
@@ -28,7 +28,7 @@ TEST_F(CUDA, POOLING_FORWARD)
 {
     auto args = pooling::get_args();
     using Format = param::Pooling::Format;
-    std::vector<DType> dtypes{dtype::Float16(), dtype::Float32()};
+    std::vector<DType> dtypes{dtype::Float16(), dtype::BFloat16(), dtype::Float32()};
     if (check_compute_capability(6, 0)) {
         // int pooling is supported only for Pascal or higher
         dtypes.push_back(dtype::Int8());
@@ -47,6 +47,8 @@ TEST_F(CUDA, POOLING_FORWARD)
             // different versions of cuDNN differs in rounding behavior;
             // setting eps to 1 to allow for rounding errors.
             checker.set_epsilon(1 + 1e-3);
+        } else if (dtype == dtype::BFloat16()) {
+            checker.set_epsilon(2e-2);
         } else {
             checker.set_epsilon(1e-2);
         }
@@ -75,7 +77,10 @@ TEST_F(CUDA, POOLING_FORWARD)
             // different versions of cuDNN differs in rounding behavior;
             // setting eps to 1 to allow for rounding errors.
             checker.set_epsilon(1 + 1e-3);
-        } else {
+        } else if (dtype == dtype::BFloat16()) {
+            checker.set_epsilon(2e-2);
+        } else {
             checker.set_epsilon(1e-2);
         }
         checker.set_param(param)
@@ -153,6 +158,12 @@ TEST_F(CUDA, POOLING_BACKWARD)
                 .set_epsilon(1e-2)
                 .exec(TensorShapeArray{ilayout, olayout, olayout, ilayout});
+        BFloat16PeriodicalRNG bf16_rng;
+        set_dtype(dtype::BFloat16());
+        checker.set_param(arg.param)
+                .set_rng(0, &bf16_rng)
+                .set_epsilon(1e-2)
+                .exec(TensorShapeArray{ilayout, olayout, olayout, ilayout});
     }
 
     /* add test for new Mode temporarily */
@@ -223,6 +234,12 @@ TEST_F(CUDA, POOLING_BACKWARD)
                 .set_epsilon(1e-2)
                 .exec(TensorShapeArray{ilayout, olayout, olayout, ilayout});
+        BFloat16PeriodicalRNG bf16_rng;
+        set_dtype(dtype::BFloat16());
+        checker.set_param(arg.param)
+                .set_rng(0, &bf16_rng)
+                .set_epsilon(1e-2)
+                .exec(TensorShapeArray{ilayout, olayout, olayout, ilayout});
     }
 }
diff --git a/dnn/test/cuda/roi_align.cpp b/dnn/test/cuda/roi_align.cpp
index 8c8a6cbf052e96fc864bbe870bdb506cca6b5891..84013d4f019a1eed7d8e46fe2810699a5413ce6b 100644
--- a/dnn/test/cuda/roi_align.cpp
+++ b/dnn/test/cuda/roi_align.cpp
@@ -85,6 +85,8 @@ TEST_F(CUDA, ROI_ALIGN_BACKWARD) {
     };
     run(dtype::Float32());
     run(dtype::Float16());
+    checker.set_epsilon(5e-2);
+    run(dtype::BFloat16());
 }
 
 } // namespace test
diff --git a/dnn/test/cuda/type_cvt.cpp b/dnn/test/cuda/type_cvt.cpp
index 8b319c715f90c17ab5916ee16454f306d44aaafa..af481993c1afd3898362d18791f43c73d2a9d462 100644
--- a/dnn/test/cuda/type_cvt.cpp
+++ b/dnn/test/cuda/type_cvt.cpp
@@ -106,6 +106,21 @@ TEST_F(CUDA, QUANTIZED_TYPECVT) {
     run(dtype::Quantized8Asymm(1e-3f, (uint8_t)18),
         dtype::QuantizedS32(7e-4f));
 }
 
+TEST_F(CUDA, TYPE_CVT_BFLOAT16) {
+    Checker<TypeCvt> checker(handle_cuda());
+    UniformFloatRNG rng(-20, 20);
+    checker.set_rng(0, &rng);
+    std::vector<DType> dtypes = {dtype::Float32(), dtype::Float16(),
+                                 dtype::Int32(), dtype::Int16(),
+                                 dtype::Int8()};
+    for (auto sdtype : dtypes) {
+        TensorLayout src({10, 10}, sdtype), dst({10, 10}, dtype::BFloat16());
+        checker.exec(TensorLayoutArray{src, dst});
+        TensorLayout src2({10, 10}, dtype::BFloat16()), dst2({10, 10}, sdtype);
+        checker.exec(TensorLayoutArray{src2, dst2});
+    }
+}
+
 #if MEGDNN_WITH_BENCHMARK
 TEST_F(CUDA, BENCHMARK_TYPE_CVT) {
     UniformIntRNG rng{-128, 127};
diff --git a/dnn/test/cuda/warp_perspective.cpp b/dnn/test/cuda/warp_perspective.cpp
index c5909909e5f618028a80a4896fd6fce4515d9955..6997676f0b9359d56f25076c3793b15ace682ec7 100644
--- a/dnn/test/cuda/warp_perspective.cpp
+++ b/dnn/test/cuda/warp_perspective.cpp
@@ -401,6 +401,91 @@ TEST_F(CUDA, WARP_PERSPECTIVE_BACKWARD_MAT)
     }
 }
 
+TEST_F(CUDA, WARP_PERSPECTIVE_FORWARD_BFLOAT16)
+{
+    using Param = WarpPerspective::Param;
+    Checker<WarpPerspectiveForward> checker(handle_cuda());
+    WarpPerspectiveMatRNG rng;
+    checker.set_rng(1, &rng);
+    checker.set_dtype(0, dtype::BFloat16())
+            .set_dtype(1, dtype::Float32())
+            .set_dtype(2, dtype::BFloat16());
+    for (auto bmode: {WarpPerspective::BorderMode::WRAP,
+                      WarpPerspective::BorderMode::REFLECT,
+                      WarpPerspective::BorderMode::REPLICATE,
+                      WarpPerspective::BorderMode::CONSTANT})
+    {
+        WarpPerspective::Param param;
+        param.border_val = 0.3f;
+        param.bmode = bmode;
+        param.imode = Param::InterpolationMode::LINEAR;
+
+        param.format = Param::Format::NHWC;
+        checker.set_param(param);
+        checker.set_epsilon(2.1).set_max_avg_error(4e-2);
+        checker.execs({{2, 10, 11, 3}, {2, 3, 3}, {2, 11, 12, 3}});
+
+        param.format = Param::Format::NCHW;
+        checker.set_param(param);
+        checker.execs({{2, 3, 10, 11}, {2, 3, 3}, {2, 3, 11, 12}});
+        checker.execs({{20, 3000, 10, 11}, {20, 3, 3}, {20, 3000, 11, 12}});
+    }
+}
+
+TEST_F(CUDA, WARP_PERSPECTIVE_BACKWARD_DATA_BFLOAT16)
+{
+    Checker<WarpPerspectiveBackwardData> checker(handle_cuda());
+    WarpPerspectiveMatRNG rng;
+    checker.set_rng(0, &rng)
+            .set_epsilon(1e-1)
+            .set_dtype(0, dtype::Float32())
+            .set_dtype(1, dtype::BFloat16())
+            .set_dtype(2, dtype::BFloat16());
+    for (int i = 0; i < 1; ++i) {
+        for (auto bmode: {WarpPerspective::BorderMode::WRAP,
+                          WarpPerspective::BorderMode::REFLECT,
+                          WarpPerspective::BorderMode::REPLICATE,
+                          WarpPerspective::BorderMode::CONSTANT})
+        {
+            WarpPerspective::Param param;
+            param.border_val = 0.3f;
+            param.bmode = bmode;
+            param.imode = param::WarpPerspective::InterpolationMode::LINEAR;
+            checker.set_param(param);
+            checker.execs({{2, 3, 3}, {2, 3, 11, 12}, {2, 3, 10, 11}});
+        }
+    }
+}
+
+TEST_F(CUDA, WARP_PERSPECTIVE_BACKWARD_MAT_BFLOAT16)
+{
+    Checker<WarpPerspectiveBackwardMat> checker(handle_cuda());
+    WarpPerspectiveMatRNG rng;
+    checker.set_rng(1, &rng)
+            .set_epsilon(1e-2)
+            .set_dtype(0, dtype::BFloat16())
+            .set_dtype(1, dtype::Float32())
+            .set_dtype(2, dtype::BFloat16())
+            .set_dtype(3, dtype::Float32());
+    for (int i = 0; i < 1; ++i) {
+        for (auto bmode: {WarpPerspective::BorderMode::WRAP,
+                          WarpPerspective::BorderMode::REFLECT,
+                          WarpPerspective::BorderMode::REPLICATE,
+                          WarpPerspective::BorderMode::CONSTANT})
+        {
+            WarpPerspective::Param param;
+            param.border_val = 0.3f;
+            param.imode = param::WarpPerspective::InterpolationMode::LINEAR;
+            param.bmode = bmode;
+            checker.set_param(param);
+            checker.execs({
+                    {1000, 3, 11, 12}, {1000, 3, 3},
+                    {1000, 3, 10, 11}, {1000, 3, 3}
+            });
+        }
+    }
+}
+
 TEST_F(CUDA, WARP_PERSPECTIVE_MAT_IDX) {
     warp_perspective::run_mat_idx_test(handle_cuda());
 }
diff --git a/dnn/test/naive/add_update.cpp b/dnn/test/naive/add_update.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..c0f523626202ff831084cb5d8fea5eec9306776c
--- /dev/null
+++ b/dnn/test/naive/add_update.cpp
@@ -0,0 +1,36 @@
+/**
+ * \file test/naive/add_update.cpp
+ * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
+ *
+ * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
+ * or implied.
+ */
+
+#include "test/naive/fixture.h"
+
+#include "megdnn/oprs.h"
+#include "test/common/checker.h"
+#include "test/common/extra_impl_helper.h"
+
+namespace megdnn {
+namespace test {
+
+TEST_F(NAIVE, ADD_UPDATE_BFLOAT16) {
+    Checker<AddUpdate> checker(handle(), false);
+    param::AddUpdate p{2, -1, 3};
+    auto extra_impl = extra_impl_helper<AddUpdate>(handle(), p);
+    checker.set_param(p)
+            .set_dtype(0, dtype::BFloat16())
+            .set_dtype(1, dtype::BFloat16())
+            .set_extra_opr_impl(extra_impl)
+            .execs({{2, 2, 3}, {2, 2, 3}});
+}
+
+} // namespace test
+} // namespace megdnn
+
+// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
diff --git a/dnn/test/naive/convolution.cpp b/dnn/test/naive/convolution.cpp
index 797328a964cd01f4afd340042999944519370f36..c6b9078135c1832d7af5e2b272c5d51a38a159fd 100644
--- a/dnn/test/naive/convolution.cpp
+++ b/dnn/test/naive/convolution.cpp
@@ -15,6 +15,7 @@
 #include "test/common/checker.h"
 #include "test/common/random_state.h"
 #include "test/common/convolution.h"
+#include "test/common/extra_impl_helper.h"
 
 using namespace megdnn;
 using namespace test;
@@ -244,4 +245,125 @@ TEST_F(NAIVE, CONVOLUTION_WITH_NCHW4) {
             .execs({{20, 27, 30, 30, 4}, {3, 4, 9, 1, 1, 4}, {}});
 }
 
+TEST_F(NAIVE, CONVOLUTION_BFLOAT16) {
+    Checker<Convolution> checker(handle(), false);
+    using Param = Convolution::Param;
+    Param param;
+    param.sparse = param::Convolution::Sparse::DENSE;
+    Param impl_param = param;
+
+    auto run = [&](size_t n, size_t ic, size_t ih, size_t iw, size_t oc,
+                   size_t fh, size_t fw) {
+        float scale = 1.0f / sqrt(ic * fh * fw);
+        UniformFloatRNG rng(scale, 2 * scale);
+        param.pad_h = param.pad_w = 1;
+        param.stride_h = param.stride_w = 1;
+        impl_param.pad_h = impl_param.pad_w = 1;
+        impl_param.stride_h = impl_param.stride_w = 1;
+        auto extra_impl =
+                extra_impl_helper<Convolution>(handle(), impl_param);
+        for (auto cmode :
+             std::vector<Param::ComputeMode>{Param::ComputeMode::DEFAULT,
+                                             Param::ComputeMode::FLOAT32}) {
+            param.compute_mode = cmode;
+            checker.set_param(param)
+                    .set_dtype(0, dtype::BFloat16())
+                    .set_dtype(1, dtype::BFloat16())
+                    // Use inferred output dtype.
+                    .set_dtype(2, {})
+                    .set_rng(0, &rng)
+                    .set_rng(1, &rng)
+                    .set_extra_opr_impl(extra_impl)
+                    .set_epsilon(1e-1)
+                    .execs({{n, ic, ih, iw}, {oc, ic, fh, fw}, {}});
+        }
+    };
+
+    run(1, 1, 20, 20, 5, 3, 3);
+    run(1, 2, 8, 7, 11, 3, 1);
+}
+
+TEST_F(NAIVE, CONVOLUTION_BACKWARD_DATA_BFLOAT16) {
+    Checker<ConvolutionBackwardData> checker(handle(), false);
+    using Param = ConvolutionBackwardData::Param;
+
+    Param param, impl_param;
+    param.sparse = Param::Sparse::DENSE;
+    auto run = [&](size_t n, size_t ic, size_t oh, size_t ow, size_t oc,
+                   size_t fh, size_t fw, size_t stride, size_t padding,
+                   const Param::ComputeMode& cmode =
+                           Param::ComputeMode::DEFAULT) {
+        param.pad_h = param.pad_w = padding;
+        param.stride_h = param.stride_w = stride;
+        param.dilate_h = param.dilate_w = 1;
+        param.compute_mode = cmode;
+
+        TensorLayout diff = TensorLayout{{n, oc, oh, ow}, dtype::BFloat16()};
+        TensorLayout grad;
+        TensorLayout filter;
+        filter = {{oc, ic, fh, fw}, dtype::BFloat16()};
+        {
+            auto opr = handle()->create_operator<ConvolutionBackwardData>();
+            opr->param() = param;
+            opr->deduce_layout(filter, diff, grad);
+        }
+        impl_param = param;
+        impl_param.compute_mode = Param::ComputeMode::DEFAULT;
+        auto extra_impl = extra_impl_helper<ConvolutionBackwardData>(
+                handle(), impl_param);
+        checker.set_param(param)
+                .set_extra_opr_impl(extra_impl)
+                .set_epsilon(1e-1)
+                .set_dtype(0, dtype::BFloat16())
+                .set_dtype(1, dtype::BFloat16());
+        checker.exec(TensorLayoutArray{filter, diff, grad});
+    };
+
+    run(4, 3, 10, 13, 5, 1, 1, 1, 0);
+    run(2, 1, 24, 43, 11, 3, 3, 2, 1, Param::ComputeMode::FLOAT32);
+}
+
+TEST_F(NAIVE, CONVOLUTION_BACKWARD_FILTER_BFLOAT16) {
+    using namespace convolution;
+    Checker<ConvolutionBackwardFilter> checker(handle(), false);
+    using Param = ConvolutionBackwardFilter::Param;
+    Param param;
+    Param impl_param = param;
+
+    auto run = [&](size_t n, size_t ic, size_t ih, size_t iw, size_t oc,
+                   size_t fh, size_t fw,
+                   const Param::ComputeMode& cmode =
+                           Param::ComputeMode::DEFAULT) {
+        auto src = TensorLayout({n, ic, ih, iw}, dtype::BFloat16());
+        auto filter = TensorLayout({oc, ic, fh, fw}, dtype::BFloat16());
+        TensorLayout dst;
+        {
+            // deduce the diff shape via a forward convolution
+            auto opr = handle()->create_operator<Convolution>();
+            opr->param() = param;
+            opr->deduce_layout(src, filter, dst);
+        }
+        float scale = 1.0f / sqrt(dst[2] * dst[3]);
+        UniformFloatRNG rng(scale, 2 * scale);
+        src.dtype = dst.dtype = filter.dtype = dtype::BFloat16();
+        param.compute_mode = cmode;
+        impl_param = param;
+        impl_param.compute_mode = Param::ComputeMode::DEFAULT;
+        auto extra_impl = extra_impl_helper<ConvolutionBackwardFilter>(
+                handle(), impl_param);
+        checker.set_rng(0, &rng)
+                .set_rng(1, &rng)
+                .set_dtype(0, dtype::BFloat16())
+                .set_dtype(1, dtype::BFloat16())
+                .set_epsilon(1e-1)
+                .set_extra_opr_impl(extra_impl)
+                .set_param(param)
+                .exec(TensorLayoutArray{src, dst, filter});
+    };
+
+    run(1, 2, 8, 7, 11, 3, 1);
+    run(1, 1, 20, 20, 5, 3, 3, Param::ComputeMode::FLOAT32);
+}
+
 // vim: syntax=cpp.doxygen
diff --git a/dnn/test/naive/matrix_mul.cpp b/dnn/test/naive/matrix_mul.cpp
index c04c377f635910821a8e41cb667dec1b110f02e8..2d04d01374131126b49c66dfbd971ea33f0324ba 100644
--- a/dnn/test/naive/matrix_mul.cpp
+++ b/dnn/test/naive/matrix_mul.cpp
@@ -12,8 +12,9 @@
 #include "megdnn/oprs/linalg.h"
 
 #include "test/common/checker.h"
-#include "test/common/random_state.h"
 #include "test/common/matrix_mul.h"
+#include "test/common/random_state.h"
+#include "test/common/extra_impl_helper.h"
 
 using namespace megdnn;
 using namespace test;
@@ -29,8 +30,8 @@ void run_matmul_mk_format(Handle* handle, param::MatrixMul::Format format,
     auto extra_impl = [](const TensorNDArray& tensors, param::MatrixMul param,
                          Handle* handle, size_t pack_size) {
         megdnn_assert((param.format == param::MatrixMul::Format::MK4 ||
-                      param.format == param::MatrixMul::Format::MK8) &&
-                      tensors.size() == 3);
+                       param.format == param::MatrixMul::Format::MK8) &&
+                      tensors.size() == 3);
         param::MatrixMul new_param = param;
         new_param.format = param::MatrixMul::Format::DEFAULT;
         size_t M = tensors[2].layout[0] * pack_size;
@@ -133,7 +134,7 @@ void run_matmul_mk_format(Handle* handle, param::MatrixMul::Format format,
 } // namespace
 
 TEST_F(NAIVE, MATRIX_MUL_QUANTIZED4x4x32) {
-    Checker<MatrixMul> checker(handle(), /* check_dispatch */false);
+    Checker<MatrixMul> checker(handle(), /* check_dispatch */ false);
     auto GenTensorValueQuint4 = [](const TensorShape& shape,
                                    dtype::Quantized4Asymm dtype,
                                    const std::vector<int>& values) {
@@ -186,49 +187,45 @@ TEST_F(NAIVE, MATRIX_MUL_QUANTIZED4x4x32) {
 }
 
 TEST_F(NAIVE, MATRIX_MUL_QUANTIZED8x8x32) {
-    Checker<MatrixMul> checker(handle(), /* check_dispatch */false);
+    Checker<MatrixMul> checker(handle(), /* check_dispatch */ false);
     MatrixMul::Param param;
     param.transposeA = false;
     param.transposeB = false;
     checker.set_param(param).exect(
-            Testcase{
-                TensorValue({4, 7}, dtype::Quantized8Asymm(0.1f, (uint8_t)128),
-                            {6, 97, 210, 47, 213, 246, 92,
-                             121, 132, 133, 37, 31, 87, 71,
-                             0, 5, 198, 11, 97, 141, 222,
-                             166, 76, 212, 190, 108, 245, 143}),
-                TensorValue({7, 5}, dtype::Quantized8Asymm(0.2f, (uint8_t)233),
-                            { 89, 207,  79, 135,  43,
-                              29, 235, 171,  40,  78,
-                             119, 145, 254, 162, 184,
-                             139, 248, 214, 201, 183,
-                             127,  75,  48, 200,  96,
-                             109,  63,  60, 100, 120,
-                             111, 182, 150, 227,  92}),
-                {}},
-            Testcase{
-                {},
-                {},
-                TensorValue({4, 5}, dtype::QuantizedS32(0.1f * 0.2f),
-                            {  2908, -36975,  -9180,  -3574,   8114,
-                              30496,  23588,  32433,  11467,  30974,
-                              36748,  -6939,  26715,  33787,  35329,
-                             -24486, -25049, -19828, -16627, -18972})});
+            Testcase{TensorValue(
+                             {4, 7}, dtype::Quantized8Asymm(0.1f, (uint8_t)128),
+                             {6,   97,  210, 47,  213, 246, 92,  121, 132, 133,
+                              37,  31,  87,  71,  0,   5,   198, 11,  97,  141,
+                              222, 166, 76,  212, 190, 108, 245, 143}),
+                     TensorValue({7, 5},
+                                 dtype::Quantized8Asymm(0.2f, (uint8_t)233),
+                                 {89,  207, 79,  135, 43,  29,  235, 171, 40,
+                                  78,  119, 145, 254, 162, 184, 139, 248, 214,
+                                  201, 183, 127, 75,  48,  200, 96,  109, 63,
+                                  60,  100, 120, 111, 182, 150, 227, 92}),
+                     {}},
+            Testcase{{},
+                     {},
+                     TensorValue({4, 5}, dtype::QuantizedS32(0.1f * 0.2f),
+                                 {2908,   -36975, -9180,  -3574,  8114,
+                                  30496,  23588,  32433,  11467,  30974,
+                                  36748,  -6939,  26715,  33787,  35329,
+                                  -24486, -25049, -19828, -16627, -18972})});
 
     param.transposeA = true;
     checker.set_param(param).exect(
-            Testcase{
-                TensorValue({2, 1}, dtype::Quantized8Asymm(0.7f, (uint8_t)128),
-                            {129, 129}),
-                TensorValue({2, 1}, dtype::Quantized8Asymm(0.4f, (uint8_t)128),
-                            {129, 129}),
-                {}
-            },
-            Testcase{
-                {},
-                {},
-                TensorValue({1, 1}, dtype::QuantizedS32(0.7f * 0.4f), {2})});
+            Testcase{TensorValue({2, 1},
+                                 dtype::Quantized8Asymm(0.7f, (uint8_t)128),
+                                 {129, 129}),
+                     TensorValue({2, 1},
+                                 dtype::Quantized8Asymm(0.4f, (uint8_t)128),
+                                 {129, 129}),
+                     {}},
+            Testcase{{},
+                     {},
+                     TensorValue({1, 1}, dtype::QuantizedS32(0.7f * 0.4f),
+                                 {2})});
 }
 
 TEST_F(NAIVE, MATRIX_MUL_MK4) {
@@ -241,4 +238,26 @@ TEST_F(NAIVE, MATRIX_MUL_MK8) {
             dtype::Int16(), dtype::Int16(), dtype::Int32());
 }
 
+TEST_F(NAIVE, MATRIX_MUL_BFLOAT16) {
+    Checker<MatrixMul> checker(handle(), /* check_dispatch */ false);
+    MatrixMul::Param param, fp32_param;
+    fp32_param = param;
+    param.compute_mode = param::MatrixMul::ComputeMode::FLOAT32;
+    checker.set_param(param);
+    checker.set_dtype(0, dtype::BFloat16());
+    checker.set_dtype(1, dtype::BFloat16());
+    checker.set_dtype(2, dtype::BFloat16());
+    auto extra_impl = extra_impl_helper<MatrixMul>(handle(), fp32_param);
+
+    checker.set_extra_opr_impl(extra_impl);
+    checker.set_epsilon(1.5e-2);
+    UniformFloatRNG frng{1e-2, 5.f};
+    checker.set_rng(0, &frng);
+    checker.set_rng(1, &frng);
+    checker.execs({{8, 8}, {8, 8}, {}});
+    param.compute_mode = param::MatrixMul::ComputeMode::DEFAULT;
+    checker.set_param(param);
+    checker.execs({{8, 8}, {8, 8}, {}});
+}
+
 // vim: syntax=cpp.doxygen
diff --git a/dnn/test/naive/type_cvt.cpp b/dnn/test/naive/type_cvt.cpp
index 7b68e1242fd53d2d7bf465f81d6c593bf4a4fe56..dfcfee8a59b8e96a2a08645a8896289d667b8fab 100644
--- a/dnn/test/naive/type_cvt.cpp
+++ b/dnn/test/naive/type_cvt.cpp
@@ -117,4 +117,54 @@ TEST_F(NAIVE, TYPECVT_QINT4) {
     });
 }
 
+TEST_F(NAIVE, TYPECVT_BFLOAT16) {
+    Checker<TypeCvt> checker(handle(), false);
+
+    checker.exect(
+            Testcase{TensorValue({1, 1, 2, 4}, dtype::Float32(),
+                                 {
+                                         0.19921875,          // 0x3E4C0000
+                                         0.19970703125,       // 0x3E4C8000
+                                         0.1997108459472656,  // 0x3E4C8100
+                                         0.1997032165527344,  // 0x3E4C7F00
+                                         0.2001953125,        // 0x3E4D0000
+                                         0.20068359375,       // 0x3E4D8000
+                                         0.2006874084472656,  // 0x3E4D8100
+                                         0.2006797790527344   // 0x3E4D7F00
+                                 }),
+                     {}},
+            Testcase{{},
+                     TensorValue({1, 1, 2, 4}, dtype::BFloat16(),
+                                 {
+                                         0.19921875,    // 0x3E4C
+                                         0.19921875,    // 0x3E4C
+                                         0.2001953125,  // 0x3E4D
+                                         0.19921875,    // 0x3E4C
+                                         0.2001953125,  // 0x3E4D
+                                         0.201171875,   // 0x3E4E
+                                         0.201171875,   // 0x3E4E
+                                         0.2001953125   // 0x3E4D
+                                 })});
+    checker.exect(
+            Testcase{TensorValue({1, 1, 2, 2}, dtype::Float32(),
+                                 {
+                                         -123456.f,  // 0xC7F12000
+                                         -123648.f,  // 0xC7F18000
+                                         -123136.f,  // 0xC7F08000
+                                         -124160.f   // 0xC7F28000
+                                 }),
+                     {}},
+            Testcase{{},
+                     TensorValue({1, 1, 2, 2}, dtype::BFloat16(),
+                                 {
+                                         -123392.f,  // 0xC7F1
+                                         -123904.f,  // 0xC7F2
+                                         -122880.f,  // 0xC7F0
+                                         -123904.f   // 0xC7F2
+                                 })});
+}
+
 // vim: syntax=cpp.doxygen
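The expected values above follow round-to-nearest-even truncation to the top 16 bits of the Float32 pattern. A self-contained sketch of that conversion (assumption: this matches megdnn's bfloat16 with `BFLOAT16_ROUND_TIES_TO_EVEN == 1`; NaN handling is omitted for brevity):

```cpp
#include <cstdint>
#include <cstring>

uint16_t f32_to_bf16_bits(float f) {
    uint32_t u;
    std::memcpy(&u, &f, sizeof(u));
    uint32_t lsb = (u >> 16) & 1;  // low bit of the part we keep
    u += 0x7fffu + lsb;            // round to nearest, ties to even
    return static_cast<uint16_t>(u >> 16);
}
// Reproduces the table above, e.g.:
//   0x3E4C8000 (0.19970703125) -> 0x3E4C  (tie, kept bit already even)
//   0x3E4C8100                 -> 0x3E4D  (just above the tie, rounds up)
//   0xC7F12000 (-123456.f)     -> 0xC7F1  (-123392.f)
```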
diff --git a/dnn/test/naive/warp_perspective.cpp b/dnn/test/naive/warp_perspective.cpp
index 852df87892765cf5f01890eb7a5ab344d09b93ad..0d8f6ae8b22bb76a12e17ea4583e431ee7ed8837 100644
--- a/dnn/test/naive/warp_perspective.cpp
+++ b/dnn/test/naive/warp_perspective.cpp
@@ -15,6 +15,7 @@
 #include "test/common/warp_perspective.h"
 #include "megdnn/tensor_format.h"
 #include "test/common/benchmarker.h"
+#include "test/common/extra_impl_helper.h"
 
 using namespace megdnn;
 using namespace test;
@@ -150,7 +151,6 @@ TEST_F(NAIVE, WARP_PERSPECTIVE_NCHW4) {
     }
 }
 
-
 TEST_F(NAIVE, WARP_PERSPECTIVE) {
     Checker<WarpPerspective> checker(handle(), false);
     WarpPerspective::Param param;
@@ -455,5 +455,63 @@ TEST_F(NAIVE_BENCHMARK_MULTI_THREADS, BENCHMARK_WARP_PERSPECTIVE) {
 }
 #endif
 
+TEST_F(NAIVE, WARP_PERSPECTIVE_BFLOAT16) {
+    Checker<WarpPerspective> checker(handle(), false);
+    WarpPerspective::Param p;
+    p.bmode = WarpPerspective::Param::BorderMode::BORDER_REFLECT;
+    p.imode = WarpPerspective::Param::InterpolationMode::LINEAR;
+    p.format = WarpPerspective::Param::Format::NCHW;
+
+    auto extra_impl = extra_impl_helper<WarpPerspective>(handle(), p);
+    checker.set_param(p)
+            .set_epsilon(1e-1)
+            .set_dtype(0, dtype::BFloat16())
+            .set_dtype(1, dtype::Float32())
+            .set_dtype(2, dtype::BFloat16())
+            .set_extra_opr_impl(extra_impl)
+            .execs({{1, 1, 3, 3}, {1, 3, 3}, {1, 1, 2, 2}})
+            .execs({{1000, 2, 10, 11}, {1000, 3, 3}, {1000, 2, 12, 13}});
+}
+
+TEST_F(NAIVE, WARP_PERSPECTIVE_BACKWARD_DATA_BFLOAT16) {
+    Checker<WarpPerspectiveBackwardData> checker(handle(), false);
+    WarpPerspectiveBackwardData::Param p;
+    p.bmode = WarpPerspectiveBackwardData::Param::BorderMode::BORDER_REFLECT;
+    p.imode = WarpPerspectiveBackwardData::Param::InterpolationMode::LINEAR;
+    p.format = WarpPerspectiveBackwardData::Param::Format::NCHW;
+
+    auto extra_impl =
+            extra_impl_helper<WarpPerspectiveBackwardData>(handle(), p);
+    checker.set_param(p)
+            .set_dtype(0, dtype::Float32())
+            .set_dtype(1, dtype::BFloat16())
+            .set_dtype(2, dtype::BFloat16())
+            .set_extra_opr_impl(extra_impl)
+            .set_epsilon(1e-1)
+            .execs({{1, 3, 3}, {1, 1, 2, 2}, {1, 1, 3, 3}});
+}
+
+TEST_F(NAIVE, WARP_PERSPECTIVE_BACKWARD_MAT_BFLOAT16) {
+    Checker<WarpPerspectiveBackwardMat> checker(handle(), false);
+    WarpPerspectiveBackwardMat::Param p;
+    p.bmode = WarpPerspectiveBackwardMat::Param::BorderMode::BORDER_REFLECT;
+    p.imode = WarpPerspectiveBackwardMat::Param::InterpolationMode::LINEAR;
+    p.format = WarpPerspectiveBackwardMat::Param::Format::NCHW;
+    p.border_val = 0.3f;
+
+    auto extra_impl =
+            extra_impl_helper<WarpPerspectiveBackwardMat>(handle(), p);
+    checker.set_param(p)
+            .set_dtype(0, dtype::BFloat16())
+            .set_dtype(1, dtype::Float32())
+            .set_dtype(2, dtype::BFloat16())
+            .set_dtype(3, dtype::Float32())
+            .set_extra_opr_impl(extra_impl)
+            .set_epsilon(1e-1)
+            .execs({{1000, 3, 11, 12},
+                    {1000, 3, 3},
+                    {1000, 3, 10, 11},
+                    {1000, 3, 3}});
+}
+
 // vim: syntax=cpp.doxygen
diff --git a/python_module/CMakeLists.txt b/python_module/CMakeLists.txt
index 4d6059324f142f93baedd986ee07b8bfef36779c..26f573f64f2913e17df73d6d3c086f107692b57a 100644
--- a/python_module/CMakeLists.txt
+++ b/python_module/CMakeLists.txt
@@ -55,7 +55,7 @@ add_custom_command(
 add_custom_target(mgb_opr_py DEPENDS ${CMAKE_CURRENT_BINARY_DIR}/megengine/_internal/opr.py)
 
-set(SRCS src/cpp/craniotome.cpp src/cpp/function_replace.cpp src/cpp/intbx.cpp src/cpp/megbrain_config.cpp src/cpp/megbrain_pubapi.cpp src/cpp/megbrain_serialize.cpp src/cpp/megbrain_wrap.cpp src/cpp/mm_handler.cpp src/cpp/opr_defs.cpp src/cpp/opr_helper.cpp src/cpp/plugin.cpp src/cpp/python_helper.cpp)
+set(SRCS src/cpp/craniotome.cpp src/cpp/function_replace.cpp src/cpp/intbx.cpp src/cpp/bfloat16.cpp src/cpp/megbrain_config.cpp src/cpp/megbrain_pubapi.cpp src/cpp/megbrain_serialize.cpp src/cpp/megbrain_wrap.cpp src/cpp/mm_handler.cpp src/cpp/opr_defs.cpp src/cpp/opr_helper.cpp src/cpp/plugin.cpp src/cpp/python_helper.cpp)
 
 if(MGE_WITH_DISTRIBUTED)
     list(APPEND SRCS src/cpp/zmq_rpc.cpp)
diff --git a/python_module/megengine/_internal/__init__.py b/python_module/megengine/_internal/__init__.py
index 5106ffd968c6c23c31b35bc367a52fd002512be2..95a689652892b4b7f7557b3a968fe2177c242685 100644
--- a/python_module/megengine/_internal/__init__.py
+++ b/python_module/megengine/_internal/__init__.py
@@ -610,6 +610,10 @@ def get_opr_fp_graph_exec(comp_graph, output_vars):
 
 def to_mgb_supported_dtype(dtype_):
     """get the dtype supported by megbrain nearest to given dtype"""
-    if dtype.is_lowbit(dtype_) or dtype.is_quantize(dtype_):
+    if (
+        dtype.is_lowbit(dtype_)
+        or dtype.is_quantize(dtype_)
+        or dtype.is_bfloat16(dtype_)
+    ):
         return dtype_
     return _detail._to_mgb_supported_dtype(dtype_)
diff --git a/python_module/megengine/_internal/dtype.py b/python_module/megengine/_internal/dtype.py
index dc5a4220f5b3c78609370730730764d032952d50..7d0eb61a9a111b19fa0c383ae4bb27fc10940c9e 100644
--- a/python_module/megengine/_internal/dtype.py
+++ b/python_module/megengine/_internal/dtype.py
@@ -11,7 +11,7 @@ from typing import Union
 
 import numpy as np
 
-from .mgb import intb1, intb2, intb4
+from .mgb import bfloat16, intb1, intb2, intb4
 
 _QuantDtypeMetadata = collections.namedtuple(
     "QuantDtypeMetadata", ["name", "np_dtype_str", "is_unsigned", "qmin", "qmax",]
@@ -40,6 +40,10 @@ def is_lowbit(dtype):
     return (dtype is intb1) or (dtype is intb2) or (dtype is intb4)
 
 
+def is_bfloat16(dtype):
+    return dtype is bfloat16
+
+
 def get_scale(dtype):
     assert is_quantize(dtype)
     return dtype.metadata["mgb_dtype"]["scale"]
diff --git a/python_module/src/cpp/bfloat16.cpp b/python_module/src/cpp/bfloat16.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..712e5219c038f00381b77a476ea78d8b24fef7b6
--- /dev/null
+++ b/python_module/src/cpp/bfloat16.cpp
@@ -0,0 +1,296 @@
+/**
+ * \file python_module/src/cpp/bfloat16.cpp
+ *
+ * This file is part of MegBrain, a deep learning framework developed by
+ * Megvii.
+ *
+ * \brief numpy dtypes for bfloat16
+ *
+ * \copyright Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
+ *
+ */
+
+#include "megbrain/common.h"
+#include "megbrain/dtype.h"
+
+#include <Python.h>
+#include <structmember.h>
+
+#define NO_IMPORT_ARRAY 1
+#include "./numpy_incl.h"
+
+#pragma GCC diagnostic ignored "-Wmissing-field-initializers"
+
+namespace {
+
+struct BFloat16Type {
+    static int npy_typenum;
+    mgb::dt_bfloat16 value;
+
+    struct PyObj;
+    struct NpyType;
+
+    template <typename S, typename T>
+    struct NpyCast;
+};
+
+int BFloat16Type::npy_typenum;
+
+/* ==================== BFloat16Type::NpyCast ==================== */
+
+template <typename T>
+struct BFloat16Type::NpyCast<T, BFloat16Type> {
+    static void apply(void* from_, void* to_, npy_intp n, void* /*fromarr*/,
+                      void* /*toarr*/) {
+        auto from = static_cast<T*>(from_);
+        auto to = static_cast<BFloat16Type*>(to_);
+        for (npy_intp i = 0; i < n; ++i) {
+            float cur = static_cast<float>(from[i]);
+            to[i].value = cur;
+        }
+    }
+};
+
+template <typename T>
+struct BFloat16Type::NpyCast<BFloat16Type, T> {
+    static void apply(void* from_, void* to_, npy_intp n, void* /*fromarr*/,
+                      void* /*toarr*/) {
+        auto from = static_cast<BFloat16Type*>(from_);
+        auto to = static_cast<T*>(to_);
+        for (npy_intp i = 0; i < n; ++i) {
+            to[i] = from[i].value;
+        }
+    }
+};
+
+/* ==================== BFloat16Type::PyObj ==================== */
+struct BFloat16Type::PyObj {
+    PyObject_HEAD BFloat16Type obj;
+
+    static PyTypeObject py_type;
+
+    static PyObject* from_bfloat16(BFloat16Type val) {
+        auto p = reinterpret_cast<PyObj*>(py_type.tp_alloc(&py_type, 0));
+        p->obj.value = val.value;
+        return reinterpret_cast<PyObject*>(p);
+    }
+
+    static PyObject* py_new(PyTypeObject* type, PyObject* args, PyObject* kwds);
+    static PyObject* py_repr(PyObject* obj);
+    static PyObject* py_richcompare(PyObject* a, PyObject* b, int op);
+};
+PyTypeObject BFloat16Type::PyObj::py_type;
+
+PyObject* BFloat16Type::PyObj::py_new(PyTypeObject* type, PyObject* args,
+                                      PyObject* kwds) {
+    PyObj* self;
+    Py_ssize_t size;
+
+    self = (PyObj*)type->tp_alloc(type, 0);
+
+    size = PyTuple_GET_SIZE(args);
+    if (size > 1) {
+        PyErr_SetString(PyExc_TypeError, "BFloat16Type only has 1 parameter");
+        return NULL;
+    }
+    PyObject* x = PyTuple_GET_ITEM(args, 0);
+    if (PyObject_IsInstance(x, (PyObject*)&py_type)) {
+        Py_INCREF(x);
+        return x;
+    }
+
+    if (!PyFloat_Check(x)) {
+        PyErr_SetString(PyExc_TypeError,
+                        "BFloat16Type must be initialized with float");
+        return NULL;
+    }
+
+    const float s = PyFloat_AsDouble(x);
+
+    self->obj.value = s;
+
+    return (PyObject*)self;
+}
+
+PyObject* BFloat16Type::PyObj::py_repr(PyObject* obj) {
+    float fval = static_cast<float>(((PyObj*)obj)->obj.value);
+    return PyUnicode_FromString(mgb::ssprintf("%f", fval).c_str());
+}
+
+PyObject* BFloat16Type::PyObj::py_richcompare(PyObject* a, PyObject* b,
+                                              int op) {
+    mgb_assert(PyObject_IsInstance(a, (PyObject*)&py_type));
+    auto bval = PyFloat_AsDouble(b);
+    if (bval == -1 && PyErr_Occurred()) {
+        return NULL;
+    }
+    double aval = ((PyObj*)a)->obj.value;
+#define OP(py, op)           \
+    case py: {               \
+        if (aval op bval) {  \
+            Py_RETURN_TRUE;  \
+        } else {             \
+            Py_RETURN_FALSE; \
+        }                    \
+    }
+    switch (op) {
+        OP(Py_LT, <)
+        OP(Py_LE, <=)
+        OP(Py_EQ, ==)
+        OP(Py_NE, !=)
+        OP(Py_GT, >)
+        OP(Py_GE, >=)
+    };
+#undef OP
+    Py_RETURN_NOTIMPLEMENTED;
+}
+
+/* ==================== BFloat16Type::NpyType ==================== */
+struct BFloat16Type::NpyType {
+    static PyArray_ArrFuncs funcs;
+    static PyArray_Descr descr;
+
+    static bool init();
+
+    static void copyswap(void* dst, void* src, int swap, void* /*arr*/) {
+        if (src) {
+            mgb_assert(!swap);
+            memcpy(dst, src, sizeof(BFloat16Type));
+        }
+    }
+    static PyObject* getitem(void* data, void* ap) {
+        return BFloat16Type::PyObj::from_bfloat16(
+                *static_cast<BFloat16Type*>(data));
+    }
+    static int setitem(PyObject* op, void* ov, void* ap);
+};
+
+PyArray_ArrFuncs BFloat16Type::NpyType::funcs;
+PyArray_Descr BFloat16Type::NpyType::descr;
+
+int BFloat16Type::NpyType::setitem(PyObject* op, void* ov, void* ap) {
+    if (PyLong_Check(op)) {
+        int a = PyLong_AsLong(op);
+        static_cast<BFloat16Type*>(ov)->value = a;
+    } else if (PyFloat_Check(op)) {
+        float a = PyFloat_AsDouble(op);
+        static_cast<BFloat16Type*>(ov)->value = a;
+    } else if (PyObject_IsInstance(
+                       op, (PyObject*)(&(BFloat16Type::PyObj::py_type)))) {
+        static_cast<BFloat16Type*>(ov)->value = ((PyObj*)op)->obj.value;
+    } else {
+        PyErr_SetString(PyExc_ValueError,
+                        "input type must be int/float/bfloat16");
+        return -1;
+    }
+    return 0;
+}
+
+bool BFloat16Type::NpyType::init() {
+    descr = {PyObject_HEAD_INIT(0) & BFloat16Type::PyObj::py_type,
+             'V',  // kind
+             'f',  // type
+             '=',  // byteorder
+             NPY_NEEDS_PYAPI | NPY_USE_GETITEM | NPY_USE_SETITEM,
+             1,  // type num
+             sizeof(BFloat16Type),
+             alignof(BFloat16Type),
+             NULL,
+             NULL,
+             NULL,
+             &funcs};
+    Py_TYPE(&descr) = &PyArrayDescr_Type;
+    PyArray_InitArrFuncs(&funcs);
+    funcs.copyswap = copyswap;
+    funcs.getitem = getitem;
+    funcs.setitem = setitem;
+    npy_typenum = PyArray_RegisterDataType(&descr);
+
+#define REGISTER_CAST(From, To, From_descr, To_typenum, safe)         \
+    {                                                                 \
+        PyArray_Descr* from_descr = (From_descr);                     \
+        if (PyArray_RegisterCastFunc(from_descr, (To_typenum),        \
+                                     NpyCast<From, To>::apply) < 0) { \
+            return false;                                             \
+        }                                                             \
+        if (safe && PyArray_RegisterCanCast(from_descr, (To_typenum), \
+                                            NPY_NOSCALAR) < 0) {      \
+            return false;                                             \
+        }                                                             \
+    }
+#define REGISTER_INT_CASTS(bits)                                         \
+    REGISTER_CAST(npy_int##bits, BFloat16Type,                           \
+                  PyArray_DescrFromType(NPY_INT##bits),                  \
+                  BFloat16Type::npy_typenum, 1)                          \
+    REGISTER_CAST(BFloat16Type, npy_int##bits, &descr, NPY_INT##bits, 0) \
+    REGISTER_CAST(npy_uint##bits, BFloat16Type,                          \
+                  PyArray_DescrFromType(NPY_UINT##bits),                 \
+                  BFloat16Type::npy_typenum, 1)                          \
+    REGISTER_CAST(BFloat16Type, npy_uint##bits, &descr, NPY_UINT##bits, 0)
+
+    REGISTER_INT_CASTS(8)
+    REGISTER_INT_CASTS(16)
+    REGISTER_INT_CASTS(32)
+    REGISTER_INT_CASTS(64)
+    REGISTER_CAST(BFloat16Type, float, &descr, NPY_FLOAT, 0)
+    REGISTER_CAST(float, BFloat16Type, PyArray_DescrFromType(NPY_FLOAT),
+                  BFloat16Type::npy_typenum, 0)
+    REGISTER_CAST(BFloat16Type, double, &descr, NPY_DOUBLE, 1)
+    REGISTER_CAST(double, BFloat16Type, PyArray_DescrFromType(NPY_DOUBLE),
+                  BFloat16Type::npy_typenum, 0)
+    return true;
+}
+
+}  // anonymous namespace
+
+bool init_pytype_bfloat16() {
+    auto& py_type = BFloat16Type::PyObj::py_type;
+    py_type = {PyVarObject_HEAD_INIT(NULL, 0)};
+    py_type.tp_name = "megbrain._mgb.pybfloat16";
+    py_type.tp_basicsize = sizeof(BFloat16Type::PyObj);
+    py_type.tp_flags = Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE;
+    py_type.tp_doc = "bfloat16 type";
+    py_type.tp_new = BFloat16Type::PyObj::py_new;
+    py_type.tp_str = BFloat16Type::PyObj::py_repr;
+    py_type.tp_repr = BFloat16Type::PyObj::py_repr;
+    py_type.tp_richcompare = BFloat16Type::PyObj::py_richcompare;
+    py_type.tp_base = &PyGenericArrType_Type;
+    return PyType_Ready(&py_type) >= 0;
+}
+
+void register_pytype_bfloat16(PyObject* d, PyObject* m) {
+    Py_INCREF(&BFloat16Type::PyObj::py_type);
+    PyDict_SetItemString(d, "bfloat16_pytype",
+                         (PyObject*)&BFloat16Type::PyObj::py_type);
+    PyModule_AddObject(m, "bfloat16_pytype",
+                       (PyObject*)&BFloat16Type::PyObj::py_type);
+}
+
+//! called from swig init
+void _init_bfloat16_types(PyObject* m) {
+    if (m == NULL)
+        return;
+    PyObject* d = PyModule_GetDict(m);
+    PyArray_Descr* dtype;
+    if (!init_pytype_bfloat16())
+        return;
+    if (!BFloat16Type::NpyType::init())
+        return;
+    dtype = PyArray_DescrFromType(BFloat16Type::npy_typenum);
+    if (!dtype)
+        return;
+    {
+        PyObject* pytype = (PyObject*)(&BFloat16Type::PyObj::py_type);
+        Py_INCREF(pytype);
+        PyDict_SetItemString(d, "pybfloat16", pytype);
+    }
+    Py_INCREF(dtype);
+    PyDict_SetItemString(d, "bfloat16", (PyObject*)dtype);
+    register_pytype_bfloat16(d, m);
+    return;
+}
+
+int mgb::npy_num_bfloat16() {
+    return BFloat16Type::npy_typenum;
+}
+
+// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
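What the registration above buys on the C side, as a hypothetical snippet (assumes numpy's C API has been imported in the translation unit and the `_mgb` module initialized; the function name is illustrative). Because `PyArray_RegisterDataType` assigns the type num dynamically, callers fetch it at runtime via `mgb::npy_num_bfloat16()` rather than a compile-time `NPY_*` constant:

```cpp
PyObject* make_bf16_zeros(npy_intp n) {
    PyArray_Descr* bf16 = PyArray_DescrFromType(mgb::npy_num_bfloat16());
    if (!bf16)
        return nullptr;
    npy_intp dims[1] = {n};
    // PyArray_Zeros steals the descr reference; the casts registered via
    // PyArray_RegisterCastFunc make later conversion to NPY_FLOAT legal.
    return PyArray_Zeros(1, dims, bf16, /* is_f_order */ 0);
}
```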
diff --git a/python_module/src/cpp/megbrain_wrap.cpp b/python_module/src/cpp/megbrain_wrap.cpp
index efcea280c225d42d0e2f9f69f5d0d577c66635c5..61795deccbaa67d8f7b6ea4834cbb302bccfb050 100644
--- a/python_module/src/cpp/megbrain_wrap.cpp
+++ b/python_module/src/cpp/megbrain_wrap.cpp
@@ -905,6 +905,9 @@ SymbolVar fill_retain_dtype(SymbolVar var, PyObject *value) {
         case DTypeEnum::Float16:
             return var.fill_retain_dtype(
                     static_cast<float>(*tensor.ptr<dt_float16>()));
+        case DTypeEnum::BFloat16:
+            return var.fill_retain_dtype(
+                    static_cast<float>(*tensor.ptr<dt_bfloat16>()));
         // TODO: What does this mean?
         case DTypeEnum::Quantized8Asymm:
         case DTypeEnum::QuantizedS32:
diff --git a/python_module/src/cpp/numpy_incl.h b/python_module/src/cpp/numpy_incl.h
index b46b5dc3c71fb6c111790f9a699744c37325dffe..02a2b6936da36abdefde8023a007936dd0079327 100644
--- a/python_module/src/cpp/numpy_incl.h
+++ b/python_module/src/cpp/numpy_incl.h
@@ -19,10 +19,11 @@
     cb(2) \
     cb(4) \
 
-#define FOREACH_MGB_DTYPE_PAIR(cb) \
-    cb(IntB1, npy_num_intb1()) \
-    cb(IntB2, npy_num_intb2()) \
-    cb(IntB4, npy_num_intb4()) \
+#define FOREACH_MGB_DTYPE_PAIR(cb)   \
+    cb(IntB1, npy_num_intb1())       \
+    cb(IntB2, npy_num_intb2())       \
+    cb(IntB4, npy_num_intb4())       \
+    cb(BFloat16, npy_num_bfloat16())
 
 namespace mgb {
     //! numpy type num for intb2 type
@@ -30,6 +31,7 @@ namespace mgb {
     int npy_num_intb##n();
     FOREACH_MGB_LOW_BIT(DEFINE_NPY_INTBX)
 #undef DEFINE_NPY_INTBX
+    int npy_num_bfloat16();
 }
 
 // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
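A sketch of how `FOREACH_MGB_DTYPE_PAIR` is typically consumed (the function below is hypothetical, not part of this diff): each `cb(_dt, _npy)` expansion pairs a megdnn `DTypeEnum` member with its runtime numpy type num.

```cpp
#include "./numpy_incl.h"

int to_npy_typenum(megdnn::DTypeEnum e) {
    using namespace mgb;  // npy_num_intb* / npy_num_bfloat16 live in mgb
#define cb(_dt, _npy)                \
    if (e == megdnn::DTypeEnum::_dt) \
        return _npy;
    FOREACH_MGB_DTYPE_PAIR(cb)
#undef cb
    return -1;  // no numpy mapping for this dtype
}
```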
diff --git a/python_module/src/swig/mgb.i b/python_module/src/swig/mgb.i
index 1a838f5a2a465f3995f90976a7534dd9747721b5..df261cf383a639db451eb34650a8c253a2f54413 100644
--- a/python_module/src/swig/mgb.i
+++ b/python_module/src/swig/mgb.i
@@ -14,11 +14,13 @@
 #define SWIG_FILE_WITH_INIT 1
 void mgb_init_numpy();                  // implemented in python_helper.cpp
 void _init_intbx_types(PyObject *m);    // implemented in intbx.cpp
+void _init_bfloat16_types(PyObject *m); // implemented in bfloat16.cpp
 %}
 
 %init %{
     mgb_init_numpy();
     _init_intbx_types(m);
+    _init_bfloat16_types(m);
 %}
 
 %include "std_vector.i"
@@ -36,6 +38,7 @@
 import os
 intb1 = _mgb.intb1
 intb2 = _mgb.intb2
 intb4 = _mgb.intb4
+bfloat16 = _mgb.bfloat16
 %}
 
 %{
diff --git a/src/core/include/megbrain/dtype.h b/src/core/include/megbrain/dtype.h
index 96927298f58028c376a22c7145776526c8f840c1..57d2955c723cf78831efcfe67518bbdb3cca5441 100644
--- a/src/core/include/megbrain/dtype.h
+++ b/src/core/include/megbrain/dtype.h
@@ -18,6 +18,7 @@ namespace mgb {
 using ::megdnn::dt_byte;
 MEGDNN_INC_FLOAT16(using ::megdnn::dt_float16;)
+MEGDNN_INC_FLOAT16(using ::megdnn::dt_bfloat16;)
 using ::megdnn::dt_float32;
 using ::megdnn::dt_int8;
 using ::megdnn::dt_uint8;
diff --git a/src/opr/impl/loop/forward.cpp b/src/opr/impl/loop/forward.cpp
index 019765890e79595afad7e8b1e2e8781d09aad7db..d4e878d1bde193961a6907aeeb7ae4645a7b7063 100644
--- a/src/opr/impl/loop/forward.cpp
+++ b/src/opr/impl/loop/forward.cpp
@@ -382,6 +382,18 @@ cg::OperatorNodeBase::NodeProp* Loop::do_make_node_prop() const {
                     ++ inext;
                 return std::max<int>(inext, 0);
             }
+            case DTypeEnum::BFloat16:
+            {
+                float iv;
+                if (val.dtype().enumv() == DTypeEnum::BFloat16)
+                    iv = val.ptr<dt_bfloat16>()[0];
+                else
+                    iv = val.ptr<dt_float32>()[0];
+                auto inext = std::ceil(iv);
+                if (iv == inext && contain_eq)
+                    ++ inext;
+                return std::max<int>(inext, 0);
+            }
 #endif
             case DTypeEnum::Byte:
                 break;
diff --git a/src/opr/impl/loop/impl.cpp b/src/opr/impl/loop/impl.cpp
index e5794c37d3d2b2b4a2df8f0cf77137c9d50937a1..45b36a19a2a3f708e1ad58c30276da7d2e8f9b90 100644
--- a/src/opr/impl/loop/impl.cpp
+++ b/src/opr/impl/loop/impl.cpp
@@ -229,6 +229,8 @@ MGB_DEFINE_OPR_CLASS(LoopImpl::DescImplBase::LoopCondManager::GetCondOpr,
 #if !MEGDNN_DISABLE_FLOAT16
             case DTypeEnum::Float16:
                 return std::abs(cond.ptr<dt_float16>()[0]) > 1e-5;
+            case DTypeEnum::BFloat16:
+                return std::abs(cond.ptr<dt_bfloat16>()[0]) > 1e-5;
 #endif
 
 #define cb(_dt) case DTypeTrait<_dt>::enumv: \
diff --git a/src/serialization/impl/dtype.fbs b/src/serialization/impl/dtype.fbs
index 1e5fbd9ced0312abbc904d3b6b497a5bb665280e..debc2d98f6347a9038c4e87c8f705afc78c721b1 100644
--- a/src/serialization/impl/dtype.fbs
+++ b/src/serialization/impl/dtype.fbs
@@ -20,6 +20,7 @@ enum DTypeEnum : byte {
     Quantized4Asymm,
     QuantizedS4,
     QuantizedS16,
+    BFloat16,
 }
 
 table LinearQuantizationParam {