diff --git a/dnn/include/megdnn/arch.h b/dnn/include/megdnn/arch.h
index dea0dfd891c0d0976fad50f8e9acf86720deb956..7e2a4341fff8a893a4dc137f78c929141c6bb60e 100644
--- a/dnn/include/megdnn/arch.h
+++ b/dnn/include/megdnn/arch.h
@@ -140,6 +140,16 @@
 #define MEGDNN_DEVICE
 #endif
 
+#if MEGDNN_CC_CUDA
+    #define MEGDNN_FORCE_INLINE __forceinline__
+#else
+#if __GNUC__ || __has_attribute(always_inline)
+    #define MEGDNN_FORCE_INLINE inline __attribute__((always_inline))
+#else
+    #define MEGDNN_FORCE_INLINE inline
+#endif
+#endif
+
 #if defined(_MSC_VER) || defined(WIN32)
 #define ATTR_ALIGNED(v) __declspec(align(v))
 #else
diff --git a/dnn/src/common/resize.cuh b/dnn/src/common/resize.cuh
index 8445dfb85e7ddab93ace3cab06f5cf19a3fb7958..5eb516dd2bcc04c5536fd03f0e93764c7a8dc179 100644
--- a/dnn/src/common/resize.cuh
+++ b/dnn/src/common/resize.cuh
@@ -13,18 +13,10 @@
 
 #include "megdnn/arch.h"
 
-#if MEGDNN_CC_HOST && !defined(__host__)
-#if __GNUC__ || __has_attribute(always_inline)
-#define __forceinline__ inline __attribute__((always_inline))
-#else
-#define __forceinline__ inline
-#endif
-#endif
-
 namespace megdnn {
 namespace resize {
 
-MEGDNN_HOST MEGDNN_DEVICE __forceinline__ void interpolate_cubic(
+MEGDNN_HOST MEGDNN_DEVICE MEGDNN_FORCE_INLINE void interpolate_cubic(
         float x, float* coeffs) {
     const float A = -0.75f;
 
diff --git a/dnn/src/common/rounding_converter.cuh b/dnn/src/common/rounding_converter.cuh
index 8d03f533b83b4754bdeed71f41c9dfb750931a6e..336f830974769b872d3f1b8c98964dc11773c1bd 100644
--- a/dnn/src/common/rounding_converter.cuh
+++ b/dnn/src/common/rounding_converter.cuh
@@ -12,17 +12,6 @@
 #pragma once
 #include "megdnn/dtype.h"
 
-#if MEGDNN_CC_HOST && !defined(__host__)
-#define MEGDNN_HOST_DEVICE_SELF_DEFINE
-#define __host__
-#define __device__
-#if __GNUC__ || __has_attribute(always_inline)
-#define __forceinline__ inline __attribute__((always_inline))
-#else
-#define __forceinline__ inline
-#endif
-#endif
-
 namespace megdnn {
 namespace rounding {
 
@@ -31,7 +20,8 @@ struct RoundingConverter;
 
 template <>
 struct RoundingConverter {
-    __host__ __device__ __forceinline__ float operator()(float x) const {
+    MEGDNN_HOST MEGDNN_DEVICE MEGDNN_FORCE_INLINE float operator()(
+            float x) const {
         return x;
     }
 };
@@ -40,7 +30,7 @@ struct RoundingConverter {
 
 template <>
 struct RoundingConverter {
-    __host__ __device__ __forceinline__ half_float::half operator()(
+    MEGDNN_HOST MEGDNN_DEVICE MEGDNN_FORCE_INLINE half_float::half operator()(
             float x) const {
         return static_cast(x);
     }
@@ -48,8 +38,8 @@ struct RoundingConverter {
 
 template <>
 struct RoundingConverter {
-    __host__ __device__ __forceinline__ half_bfloat16::bfloat16 operator()(
-            float x) const {
+    MEGDNN_HOST MEGDNN_DEVICE MEGDNN_FORCE_INLINE half_bfloat16::bfloat16
+    operator()(float x) const {
         return static_cast(x);
     }
 };
@@ -58,7 +48,8 @@ struct RoundingConverter {
 
 template <>
 struct RoundingConverter {
-    __host__ __device__ __forceinline__ int8_t operator()(float x) const {
+    MEGDNN_HOST MEGDNN_DEVICE MEGDNN_FORCE_INLINE int8_t
+    operator()(float x) const {
 #if MEGDNN_CC_HOST
         using std::round;
 #endif
@@ -68,11 +59,12 @@ struct RoundingConverter {
 
 template <>
 struct RoundingConverter {
-    __host__ __device__ __forceinline__ uint8_t operator()(float x) const {
+    MEGDNN_HOST MEGDNN_DEVICE MEGDNN_FORCE_INLINE uint8_t
+    operator()(float x) const {
 #if MEGDNN_CC_HOST
-        using std::round;
         using std::max;
         using std::min;
+        using std::round;
 #endif
         x = min(255.0f, max(0.0f, x)); //! FIXME!!! check other places
         return static_cast(round(x));
@@ -81,7 +73,8 @@ struct RoundingConverter {
 
 template <>
 struct RoundingConverter {
-    __host__ __device__ __forceinline__ dt_qint4 operator()(float x) const {
+    MEGDNN_HOST MEGDNN_DEVICE MEGDNN_FORCE_INLINE dt_qint4
+    operator()(float x) const {
 #if MEGDNN_CC_HOST
         using std::round;
 #endif
@@ -91,7 +84,8 @@ struct RoundingConverter {
 
 template <>
 struct RoundingConverter {
-    __host__ __device__ __forceinline__ dt_quint4 operator()(float x) const {
+    MEGDNN_HOST MEGDNN_DEVICE MEGDNN_FORCE_INLINE dt_quint4
+    operator()(float x) const {
 #if MEGDNN_CC_HOST
         using std::round;
 #endif