提交 369c2ccc 编写于 作者: M Megvii Engine Team

style(all): reformat c++ code

GitOrigin-RevId: 3ffd1b211f140e8ca05661a3c71801083ae4a951
上级 bfb30dcb

要显示的变更太多。

To preserve performance only 1000 of 1000+ files are displayed.
文件模式从 100755 更改为 100644
文件模式从 100755 更改为 100644
文件模式从 100755 更改为 100644
文件模式从 100755 更改为 100644
文件模式从 100755 更改为 100644
文件模式从 100755 更改为 100644
文件模式从 100755 更改为 100644
文件模式从 100755 更改为 100644
...@@ -23,9 +23,9 @@ ...@@ -23,9 +23,9 @@
#pragma GCC diagnostic ignored "-Wunused-parameter" #pragma GCC diagnostic ignored "-Wunused-parameter"
#pragma GCC diagnostic ignored "-Wdeprecated-declarations" #pragma GCC diagnostic ignored "-Wdeprecated-declarations"
#pragma GCC diagnostic ignored "-Wsign-compare" #pragma GCC diagnostic ignored "-Wsign-compare"
#include <hip/hip_runtime_api.h>
#include <hip/hip_runtime.h>
#include <hip/hip_fp16.h> #include <hip/hip_fp16.h>
#include <hip/hip_runtime.h>
#include <hip/hip_runtime_api.h>
#pragma GCC diagnostic pop #pragma GCC diagnostic pop
#if !defined(__HIP_PLATFORM_HCC__) #if !defined(__HIP_PLATFORM_HCC__)
......
...@@ -11,10 +11,10 @@ ...@@ -11,10 +11,10 @@
#pragma once #pragma once
#include "megdnn/thin/function.h"
#include "megcore_cdefs.h"
#include <cstddef> #include <cstddef>
#include <memory> #include <memory>
#include "megcore_cdefs.h"
#include "megdnn/thin/function.h"
#include "megdnn/internal/visibility_prologue.h" #include "megdnn/internal/visibility_prologue.h"
...@@ -26,36 +26,35 @@ namespace megcore { ...@@ -26,36 +26,35 @@ namespace megcore {
* the caller thread immediately. * the caller thread immediately.
*/ */
class CPUDispatcher { class CPUDispatcher {
public: public:
using Task = megdnn::thin_function<void()>; using Task = megdnn::thin_function<void()>;
using MultiThreadingTask = megdnn::thin_function<void(size_t, size_t)>; using MultiThreadingTask = megdnn::thin_function<void(size_t, size_t)>;
virtual ~CPUDispatcher() noexcept; virtual ~CPUDispatcher() noexcept;
/*! /*!
* \brief dispatch a task on the computing thread * \brief dispatch a task on the computing thread
* \param task the task that would be moved away * \param task the task that would be moved away
*/ */
virtual void dispatch(Task&& task) = 0; virtual void dispatch(Task&& task) = 0;
/*! /*!
* \brief dispatch a multithreading task on the computing thread * \brief dispatch a multithreading task on the computing thread
* \param task the task would be moved away * \param task the task would be moved away
* \param parallelism the parallelism of the task. * \param parallelism the parallelism of the task.
*/ */
virtual void dispatch(MultiThreadingTask&& task, virtual void dispatch(MultiThreadingTask&& task, size_t parallelism) = 0;
size_t parallelism) = 0;
/*!
/*! * \brief synchronize the calling thread with the computing thread
* \brief synchronize the calling thread with the computing thread */
*/ virtual void sync() = 0;
virtual void sync() = 0;
/*!
/*! * \brief the computing thread number.
* \brief the computing thread number. */
*/ virtual size_t nr_threads() = 0;
virtual size_t nr_threads() = 0;
}; };
} // namespace megcore } // namespace megcore
using MegcoreCPUDispatcher = megcore::CPUDispatcher; using MegcoreCPUDispatcher = megcore::CPUDispatcher;
...@@ -63,75 +62,62 @@ using MegcoreCPUDispatcher = megcore::CPUDispatcher; ...@@ -63,75 +62,62 @@ using MegcoreCPUDispatcher = megcore::CPUDispatcher;
* \brief Layer 1: device handle * \brief Layer 1: device handle
*/ */
struct megcoreDeviceContext; struct megcoreDeviceContext;
typedef struct megcoreDeviceContext *megcoreDeviceHandle_t; typedef struct megcoreDeviceContext* megcoreDeviceHandle_t;
megcoreStatus_t megcoreCreateDeviceHandle( megcoreStatus_t megcoreCreateDeviceHandle(
megcoreDeviceHandle_t *handle, megcoreDeviceHandle_t* handle, megcorePlatform_t platform, int deviceID = -1,
megcorePlatform_t platform,
int deviceID = -1,
unsigned int flags = 0); unsigned int flags = 0);
megcoreStatus_t megcoreDestroyDeviceHandle( megcoreStatus_t megcoreDestroyDeviceHandle(megcoreDeviceHandle_t handle);
megcoreDeviceHandle_t handle);
megcoreStatus_t megcoreGetPlatform(
megcoreStatus_t megcoreGetPlatform(megcoreDeviceHandle_t handle, megcoreDeviceHandle_t handle, megcorePlatform_t* platform);
megcorePlatform_t *platform); megcoreStatus_t megcoreGetDeviceID(megcoreDeviceHandle_t handle, int* deviceID);
megcoreStatus_t megcoreGetDeviceID(megcoreDeviceHandle_t handle, megcoreStatus_t megcoreGetMemAlignment(
int *deviceID); megcoreDeviceHandle_t handle, size_t* memAlignmentInBytes);
megcoreStatus_t megcoreGetMemAlignment(megcoreDeviceHandle_t handle,
size_t *memAlignmentInBytes);
megcoreStatus_t megcoreGetDeviceFlags( megcoreStatus_t megcoreGetDeviceFlags(
megcoreDeviceHandle_t handle, megcoreDeviceHandle_t handle, unsigned int* flags);
unsigned int *flags);
megcoreStatus_t megcoreActivate(megcoreDeviceHandle_t handle); megcoreStatus_t megcoreActivate(megcoreDeviceHandle_t handle);
megcoreStatus_t megcoreDeactivate(megcoreDeviceHandle_t handle); megcoreStatus_t megcoreDeactivate(megcoreDeviceHandle_t handle);
megcoreStatus_t megcoreMalloc(megcoreDeviceHandle_t handle, megcoreStatus_t megcoreMalloc(
void **devPtr, size_t sizeInBytes); megcoreDeviceHandle_t handle, void** devPtr, size_t sizeInBytes);
megcoreStatus_t megcoreFree(megcoreDeviceHandle_t handle, megcoreStatus_t megcoreFree(megcoreDeviceHandle_t handle, void* devPtr);
void *devPtr);
/** /**
* \brief Layer 2: computing handle * \brief Layer 2: computing handle
*/ */
struct megcoreComputingContext; struct megcoreComputingContext;
typedef struct megcoreComputingContext *megcoreComputingHandle_t; typedef struct megcoreComputingContext* megcoreComputingHandle_t;
megcoreStatus_t megcoreCreateComputingHandle( megcoreStatus_t megcoreCreateComputingHandle(
megcoreComputingHandle_t *compHandle, megcoreComputingHandle_t* compHandle, megcoreDeviceHandle_t devHandle,
megcoreDeviceHandle_t devHandle,
unsigned int flags = 0); unsigned int flags = 0);
megcoreStatus_t megcoreCreateComputingHandleWithCPUDispatcher( megcoreStatus_t megcoreCreateComputingHandleWithCPUDispatcher(
megcoreComputingHandle_t *compHandle, megcoreComputingHandle_t* compHandle, megcoreDeviceHandle_t devHandle,
megcoreDeviceHandle_t devHandle,
const std::shared_ptr<MegcoreCPUDispatcher>& dispatcher, const std::shared_ptr<MegcoreCPUDispatcher>& dispatcher,
unsigned int flags = 0); unsigned int flags = 0);
megcoreStatus_t megcoreDestroyComputingHandle( megcoreStatus_t megcoreDestroyComputingHandle(megcoreComputingHandle_t handle);
megcoreComputingHandle_t handle);
megcoreStatus_t megcoreGetDeviceHandle( megcoreStatus_t megcoreGetDeviceHandle(
megcoreComputingHandle_t compHandle, megcoreComputingHandle_t compHandle, megcoreDeviceHandle_t* devHandle);
megcoreDeviceHandle_t *devHandle);
megcoreStatus_t megcoreGetComputingFlags( megcoreStatus_t megcoreGetComputingFlags(
megcoreComputingHandle_t handle, megcoreComputingHandle_t handle, unsigned int* flags);
unsigned int *flags);
MegcoreCPUDispatcher* megcoreGetCPUDispatcher(megcoreComputingHandle_t handle); MegcoreCPUDispatcher* megcoreGetCPUDispatcher(megcoreComputingHandle_t handle);
megcoreStatus_t megcoreMemcpy( megcoreStatus_t megcoreMemcpy(
megcoreComputingHandle_t handle, megcoreComputingHandle_t handle, void* dst, const void* src, size_t sizeInBytes,
void *dst, const void *src, size_t sizeInBytes,
megcoreMemcpyKind_t kind); megcoreMemcpyKind_t kind);
megcoreStatus_t megcoreMemset( megcoreStatus_t megcoreMemset(
megcoreComputingHandle_t handle, megcoreComputingHandle_t handle, void* dst, int value, size_t sizeInBytes);
void *dst, int value, size_t sizeInBytes);
megcoreStatus_t megcoreSynchronize(megcoreComputingHandle_t handle); megcoreStatus_t megcoreSynchronize(megcoreComputingHandle_t handle);
/** /**
* \brief Miscellaneous * \brief Miscellaneous
*/ */
const char *megcoreGetErrorName(megcoreStatus_t status); const char* megcoreGetErrorName(megcoreStatus_t status);
#include "megdnn/internal/visibility_epilogue.h" #include "megdnn/internal/visibility_epilogue.h"
......
...@@ -33,8 +33,7 @@ megcoreStatus_t createComputingHandleWithAtlasContext( ...@@ -33,8 +33,7 @@ megcoreStatus_t createComputingHandleWithAtlasContext(
megcoreComputingHandle_t* compHandle, megcoreDeviceHandle_t devHandle, megcoreComputingHandle_t* compHandle, megcoreDeviceHandle_t devHandle,
unsigned int flags, const AtlasContext& ctx); unsigned int flags, const AtlasContext& ctx);
megcoreStatus_t getAtlasContext(megcoreComputingHandle_t handle, megcoreStatus_t getAtlasContext(megcoreComputingHandle_t handle, AtlasContext* ctx);
AtlasContext* ctx);
namespace atlas { namespace atlas {
//! convert acl error code to error string //! convert acl error code to error string
...@@ -47,12 +46,12 @@ inline megcoreStatus_t megcoreCreateComputingHandleWithACLStream( ...@@ -47,12 +46,12 @@ inline megcoreStatus_t megcoreCreateComputingHandleWithACLStream(
megcoreComputingHandle_t* compHandle, megcoreDeviceHandle_t devHandle, megcoreComputingHandle_t* compHandle, megcoreDeviceHandle_t devHandle,
unsigned int flags, aclrtStream stream) { unsigned int flags, aclrtStream stream) {
megcore::AtlasContext ctx{stream}; megcore::AtlasContext ctx{stream};
return megcore::createComputingHandleWithAtlasContext(compHandle, devHandle, return megcore::createComputingHandleWithAtlasContext(
flags, ctx); compHandle, devHandle, flags, ctx);
} }
inline megcoreStatus_t megcoreGetACLStream(megcoreComputingHandle_t handle, inline megcoreStatus_t megcoreGetACLStream(
aclrtStream* stream) { megcoreComputingHandle_t handle, aclrtStream* stream) {
megcore::AtlasContext ctx; megcore::AtlasContext ctx;
auto ret = megcore::getAtlasContext(handle, &ctx); auto ret = megcore::getAtlasContext(handle, &ctx);
*stream = ctx.stream; *stream = ctx.stream;
......
...@@ -34,8 +34,8 @@ megcoreStatus_t createComputingHandleWithCambriconContext( ...@@ -34,8 +34,8 @@ megcoreStatus_t createComputingHandleWithCambriconContext(
megcoreComputingHandle_t* compHandle, megcoreDeviceHandle_t devHandle, megcoreComputingHandle_t* compHandle, megcoreDeviceHandle_t devHandle,
unsigned int flags, const CambriconContext& ctx); unsigned int flags, const CambriconContext& ctx);
megcoreStatus_t getCambriconContext(megcoreComputingHandle_t handle, megcoreStatus_t getCambriconContext(
CambriconContext* ctx); megcoreComputingHandle_t handle, CambriconContext* ctx);
} // namespace megcore } // namespace megcore
...@@ -58,4 +58,3 @@ static inline megcoreStatus_t megcoreGetCNRTQueue( ...@@ -58,4 +58,3 @@ static inline megcoreStatus_t megcoreGetCNRTQueue(
#include "megdnn/internal/visibility_epilogue.h" #include "megdnn/internal/visibility_epilogue.h"
// vim: syntax=cpp.doxygen // vim: syntax=cpp.doxygen
...@@ -40,7 +40,6 @@ typedef enum { ...@@ -40,7 +40,6 @@ typedef enum {
megcoreErrorInternalError = 5, megcoreErrorInternalError = 5,
} megcoreStatus_t; } megcoreStatus_t;
/** /**
* \brief Memcpy kind * \brief Memcpy kind
*/ */
...@@ -70,6 +69,6 @@ struct AsyncErrorInfo { ...@@ -70,6 +69,6 @@ struct AsyncErrorInfo {
char msg[228]; char msg[228];
int msg_args[4]; int msg_args[4];
}; };
} // namespace megcore } // namespace megcore
// vim: syntax=cpp.doxygen // vim: syntax=cpp.doxygen
...@@ -33,8 +33,7 @@ megcoreStatus_t createComputingHandleWithCUDAContext( ...@@ -33,8 +33,7 @@ megcoreStatus_t createComputingHandleWithCUDAContext(
megcoreComputingHandle_t* compHandle, megcoreDeviceHandle_t devHandle, megcoreComputingHandle_t* compHandle, megcoreDeviceHandle_t devHandle,
unsigned int flags, const CudaContext& ctx); unsigned int flags, const CudaContext& ctx);
megcoreStatus_t getCUDAContext(megcoreComputingHandle_t handle, megcoreStatus_t getCUDAContext(megcoreComputingHandle_t handle, CudaContext* ctx);
CudaContext* ctx);
} // namespace megcore } // namespace megcore
...@@ -43,8 +42,8 @@ static inline megcoreStatus_t megcoreCreateComputingHandleWithCUDAStream( ...@@ -43,8 +42,8 @@ static inline megcoreStatus_t megcoreCreateComputingHandleWithCUDAStream(
unsigned int flags, cudaStream_t stream) { unsigned int flags, cudaStream_t stream) {
megcore::CudaContext ctx; megcore::CudaContext ctx;
ctx.stream = stream; ctx.stream = stream;
return megcore::createComputingHandleWithCUDAContext(compHandle, devHandle, return megcore::createComputingHandleWithCUDAContext(
flags, ctx); compHandle, devHandle, flags, ctx);
} }
static inline megcoreStatus_t megcoreGetCUDAStream( static inline megcoreStatus_t megcoreGetCUDAStream(
......
...@@ -23,7 +23,9 @@ struct ROCMContext { ...@@ -23,7 +23,9 @@ struct ROCMContext {
hipStream_t stream = nullptr; hipStream_t stream = nullptr;
static std::atomic_bool sm_miopen_algo_search; static std::atomic_bool sm_miopen_algo_search;
static inline bool enable_miopen_algo_search() { return sm_miopen_algo_search.load(); } static inline bool enable_miopen_algo_search() {
return sm_miopen_algo_search.load();
}
static inline void enable_miopen_algo_search(bool enable_algo_search) { static inline void enable_miopen_algo_search(bool enable_algo_search) {
sm_miopen_algo_search.store(enable_algo_search); sm_miopen_algo_search.store(enable_algo_search);
} }
...@@ -40,8 +42,7 @@ megcoreStatus_t createComputingHandleWithROCMContext( ...@@ -40,8 +42,7 @@ megcoreStatus_t createComputingHandleWithROCMContext(
megcoreComputingHandle_t* compHandle, megcoreDeviceHandle_t devHandle, megcoreComputingHandle_t* compHandle, megcoreDeviceHandle_t devHandle,
unsigned int flags, const ROCMContext& ctx); unsigned int flags, const ROCMContext& ctx);
megcoreStatus_t getROCMContext(megcoreComputingHandle_t handle, megcoreStatus_t getROCMContext(megcoreComputingHandle_t handle, ROCMContext* ctx);
ROCMContext* ctx);
// Set MIOpen algo search enabled or disabled // Set MIOpen algo search enabled or disabled
megcoreStatus_t enableMIOpenAlgoSearch(bool enable_algo_search = true); megcoreStatus_t enableMIOpenAlgoSearch(bool enable_algo_search = true);
...@@ -55,8 +56,8 @@ static inline megcoreStatus_t megcoreCreateComputingHandleWithROCMStream( ...@@ -55,8 +56,8 @@ static inline megcoreStatus_t megcoreCreateComputingHandleWithROCMStream(
unsigned int flags, hipStream_t stream) { unsigned int flags, hipStream_t stream) {
megcore::ROCMContext ctx; megcore::ROCMContext ctx;
ctx.stream = stream; ctx.stream = stream;
return megcore::createComputingHandleWithROCMContext(compHandle, devHandle, return megcore::createComputingHandleWithROCMContext(
flags, ctx); compHandle, devHandle, flags, ctx);
} }
static inline megcoreStatus_t megcoreGetROCMStream( static inline megcoreStatus_t megcoreGetROCMStream(
......
...@@ -10,7 +10,7 @@ ...@@ -10,7 +10,7 @@
*/ */
#pragma once #pragma once
#include "megdnn/version.h"
#include "megdnn/oprs.h" #include "megdnn/oprs.h"
#include "megdnn/version.h"
// vim: syntax=cpp.doxygen // vim: syntax=cpp.doxygen
...@@ -14,20 +14,20 @@ ...@@ -14,20 +14,20 @@
#include "megdnn/config/config.h" #include "megdnn/config/config.h"
#if defined(__GNUC__) || defined(__clang__) #if defined(__GNUC__) || defined(__clang__)
#if !defined (__clang__) #if !defined(__clang__)
// gcc specific // gcc specific
#define GCC_VERSION (__GNUC__ * 10000 + __GNUC_MINOR__ * 100 + __GNUC_PATCHLEVEL__) #define GCC_VERSION (__GNUC__ * 10000 + __GNUC_MINOR__ * 100 + __GNUC_PATCHLEVEL__)
#if GCC_VERSION < 40800 #if GCC_VERSION < 40800
#error "GCC version should be at least 4.8.0." #error "GCC version should be at least 4.8.0."
#endif // GCC_VERSION < 40800 #endif // GCC_VERSION < 40800
#endif // !defined(__clang__) #endif // !defined(__clang__)
#ifndef megdnn_trap #ifndef megdnn_trap
#define megdnn_trap() __builtin_trap() #define megdnn_trap() __builtin_trap()
#endif #endif
#define megdnn_likely(v) __builtin_expect(bool(v), 1) #define megdnn_likely(v) __builtin_expect(bool(v), 1)
#define megdnn_unlikely(v) __builtin_expect(bool(v), 0) #define megdnn_unlikely(v) __builtin_expect(bool(v), 0)
#if !defined(__clang__) && MEGDNN_ARMV7 && !defined(NDEBUG) #if !defined(__clang__) && MEGDNN_ARMV7 && !defined(NDEBUG)
//! Thumb2 limit code length //! Thumb2 limit code length
...@@ -36,123 +36,122 @@ ...@@ -36,123 +36,122 @@
#define MEGDNN_ALWAYS_INLINE inline __attribute__((__always_inline__)) #define MEGDNN_ALWAYS_INLINE inline __attribute__((__always_inline__))
#endif #endif
#define MEGDNN_DEPRECATED __attribute__((deprecated)) #define MEGDNN_DEPRECATED __attribute__((deprecated))
#define MEGDNN_PACKED __attribute__((packed)) #define MEGDNN_PACKED __attribute__((packed))
#define MEGDNN_CONSTEXPR constexpr #define MEGDNN_CONSTEXPR constexpr
#define MEGDNN_NOEXCEPT noexcept #define MEGDNN_NOEXCEPT noexcept
#define MEGDNN_STATIC_ASSERT static_assert #define MEGDNN_STATIC_ASSERT static_assert
#define MEGDNN_FINAL final #define MEGDNN_FINAL final
#define MEGDNN_NORETURN __attribute__((noreturn)) #define MEGDNN_NORETURN __attribute__((noreturn))
#define MEGDNN_WARN_UNUSED_RESULT __attribute__((warn_unused_result)) #define MEGDNN_WARN_UNUSED_RESULT __attribute__((warn_unused_result))
#define MEGDNN_ATTRIBUTE_TARGET(simd) __attribute__((target(simd))) #define MEGDNN_ATTRIBUTE_TARGET(simd) __attribute__((target(simd)))
#if defined(__clang_major__) && (__clang_major__ >= 7) #if defined(__clang_major__) && (__clang_major__ >= 7)
#define MEGDNN_LAMBDA_ATTRIBUTE_TARGET(simd) __attribute__((target(simd))) #define MEGDNN_LAMBDA_ATTRIBUTE_TARGET(simd) __attribute__((target(simd)))
#else #else
#define MEGDNN_LAMBDA_ATTRIBUTE_TARGET(simd) [[gnu::target(simd)]] #define MEGDNN_LAMBDA_ATTRIBUTE_TARGET(simd) [[gnu::target(simd)]]
#endif #endif
#define MEGDNN_NOINLINE __attribute__((noinline)) #define MEGDNN_NOINLINE __attribute__((noinline))
#define megdnn_isatty(x) isatty(x) #define megdnn_isatty(x) isatty(x)
#elif defined(__INTEL_COMPILER) || defined(_MSC_VER) #elif defined(__INTEL_COMPILER) || defined(_MSC_VER)
#ifndef megdnn_trap #ifndef megdnn_trap
#define megdnn_trap() __debugbreak() #define megdnn_trap() __debugbreak()
#endif #endif
#define megdnn_likely(v) (bool(v)) #define megdnn_likely(v) (bool(v))
#define megdnn_unlikely(v) (bool(v)) #define megdnn_unlikely(v) (bool(v))
#define MEGDNN_DEPRECATED #define MEGDNN_DEPRECATED
#define MEGDNN_PACKED #define MEGDNN_PACKED
#define MEGDNN_CONSTEXPR constexpr #define MEGDNN_CONSTEXPR constexpr
#define MEGDNN_NOEXCEPT noexcept #define MEGDNN_NOEXCEPT noexcept
#define MEGDNN_STATIC_ASSERT static_assert #define MEGDNN_STATIC_ASSERT static_assert
#define MEGDNN_FINAL final #define MEGDNN_FINAL final
#if defined(_MSC_VER) #if defined(_MSC_VER)
#define MEGDNN_NORETURN __declspec(noreturn) #define MEGDNN_NORETURN __declspec(noreturn)
#define MEGDNN_NOINLINE __declspec(noinline) #define MEGDNN_NOINLINE __declspec(noinline)
#else #else
#define MEGDNN_NORETURN #define MEGDNN_NORETURN
#define MEGDNN_FORCE_NOINLINE #define MEGDNN_FORCE_NOINLINE
#endif // _MSC_VER #endif // _MSC_VER
#define MEGDNN_WARN_UNUSED_RESULT #define MEGDNN_WARN_UNUSED_RESULT
#define megdnn_isatty(x) _isatty(x) #define megdnn_isatty(x) _isatty(x)
#else #else
#error "unknown compiler" #error "unknown compiler"
#endif // __GNUC__ #endif // __GNUC__
// __cpp_exceptions and __cpp_rtti is referred from // __cpp_exceptions and __cpp_rtti is referred from
// https://isocpp.org/std/standing-documentssd-6-sg10-feature-test-recommendations // https://isocpp.org/std/standing-documentssd-6-sg10-feature-test-recommendations
// gcc < 5 does not define __cpp_exceptions but __EXCEPTIONS, // gcc < 5 does not define __cpp_exceptions but __EXCEPTIONS,
// similar for __GXX_RTTI // similar for __GXX_RTTI
// _CPPUNWIND and _CPPRTTI is used by MSVC, see // _CPPUNWIND and _CPPRTTI is used by MSVC, see
// https://docs.microsoft.com/en-us/cpp/preprocessor/predefined-macrosview=vs-2019 // https://docs.microsoft.com/en-us/cpp/preprocessor/predefined-macrosview=vs-2019
#ifndef MEGDNN_ENABLE_EXCEPTIONS #ifndef MEGDNN_ENABLE_EXCEPTIONS
#if __cpp_exceptions || __EXCEPTIONS || \ #if __cpp_exceptions || __EXCEPTIONS || (defined(_MSC_VER) && defined(_CPPUNWIND))
(defined(_MSC_VER) && defined(_CPPUNWIND)) #define MEGDNN_ENABLE_EXCEPTIONS 1
#define MEGDNN_ENABLE_EXCEPTIONS 1 #else
#else #define MEGDNN_ENABLE_EXCEPTIONS 0
#define MEGDNN_ENABLE_EXCEPTIONS 0 #endif
#endif
#endif #endif
#ifndef MEGDNN_ENABLE_RTTI #ifndef MEGDNN_ENABLE_RTTI
#if __cpp_rtti || __GXX_RTTI || (defined(_MSC_VER) && defined(_CPPRTTI)) #if __cpp_rtti || __GXX_RTTI || (defined(_MSC_VER) && defined(_CPPRTTI))
#define MEGDNN_ENABLE_RTTI 1 #define MEGDNN_ENABLE_RTTI 1
#else #else
#define MEGDNN_ENABLE_RTTI 0 #define MEGDNN_ENABLE_RTTI 0
#endif #endif
#endif #endif
#ifdef __CUDACC__ #ifdef __CUDACC__
#define MEGDNN_CC_CUDA 1 #define MEGDNN_CC_CUDA 1
#undef MEGDNN_CONSTEXPR #undef MEGDNN_CONSTEXPR
#define MEGDNN_CONSTEXPR const #define MEGDNN_CONSTEXPR const
#if defined(__CUDACC_VER_MAJOR__) #if defined(__CUDACC_VER_MAJOR__)
#if __CUDACC_VER_MAJOR__ >= 9 #if __CUDACC_VER_MAJOR__ >= 9
#undef MEGDNN_STATIC_ASSERT #undef MEGDNN_STATIC_ASSERT
#define MEGDNN_STATIC_ASSERT(cond, msg) static_assert(cond, msg); #define MEGDNN_STATIC_ASSERT(cond, msg) static_assert(cond, msg);
#else #else
#undef MEGDNN_STATIC_ASSERT #undef MEGDNN_STATIC_ASSERT
#define MEGDNN_STATIC_ASSERT(cond, msg) #define MEGDNN_STATIC_ASSERT(cond, msg)
#endif #endif
#endif #endif
#define nullptr NULL #define nullptr NULL
#undef MEGDNN_FINAL #undef MEGDNN_FINAL
#define MEGDNN_FINAL #define MEGDNN_FINAL
#elif defined(__HIPCC__) #elif defined(__HIPCC__)
#define MEGDNN_CC_CUDA 1 #define MEGDNN_CC_CUDA 1
#else #else
#define MEGDNN_CC_HOST 1 #define MEGDNN_CC_HOST 1
#endif // __CUDACC__ #endif // __CUDACC__
// MEGDNN_HOST and MEGDNN_DEVICE // MEGDNN_HOST and MEGDNN_DEVICE
#if MEGDNN_CC_CUDA #if MEGDNN_CC_CUDA
#define MEGDNN_HOST __host__ #define MEGDNN_HOST __host__
#define MEGDNN_DEVICE __device__ #define MEGDNN_DEVICE __device__
#else #else
#define MEGDNN_HOST #define MEGDNN_HOST
#define MEGDNN_DEVICE #define MEGDNN_DEVICE
#endif #endif
#if MEGDNN_CC_CUDA #if MEGDNN_CC_CUDA
#define MEGDNN_FORCE_INLINE __forceinline__ #define MEGDNN_FORCE_INLINE __forceinline__
#else #else
#if __GNUC__ || __has_attribute(always_inline) #if __GNUC__ || __has_attribute(always_inline)
#define MEGDNN_FORCE_INLINE inline __attribute__((always_inline)) #define MEGDNN_FORCE_INLINE inline __attribute__((always_inline))
#else #else
#define MEGDNN_FORCE_INLINE inline #define MEGDNN_FORCE_INLINE inline
#endif #endif
#endif #endif
#if defined(_MSC_VER) || defined(WIN32) #if defined(_MSC_VER) || defined(WIN32)
#define ATTR_ALIGNED(v) __declspec(align(v)) #define ATTR_ALIGNED(v) __declspec(align(v))
#else #else
#define ATTR_ALIGNED(v) __attribute__((aligned(v))) #define ATTR_ALIGNED(v) __attribute__((aligned(v)))
#endif #endif
// vim: syntax=cpp.doxygen // vim: syntax=cpp.doxygen
...@@ -16,10 +16,10 @@ ...@@ -16,10 +16,10 @@
#include "megdnn/internal/defs.h" #include "megdnn/internal/defs.h"
#if MEGDNN_CC_HOST #if MEGDNN_CC_HOST
#include <cstdarg>
#include <string> #include <string>
#include <type_traits> #include <type_traits>
#include <vector> #include <vector>
#include <cstdarg>
#include "megdnn/thin/small_vector.h" #include "megdnn/thin/small_vector.h"
#endif // MEGDNN_CC_HOST #endif // MEGDNN_CC_HOST
...@@ -35,8 +35,7 @@ class ErrorHandler { ...@@ -35,8 +35,7 @@ class ErrorHandler {
protected: protected:
MEGDNN_NORETURN virtual void do_on_megdnn_error(const std::string& msg) = 0; MEGDNN_NORETURN virtual void do_on_megdnn_error(const std::string& msg) = 0;
MEGDNN_NORETURN virtual void do_on_tensor_reshape_error( MEGDNN_NORETURN virtual void do_on_tensor_reshape_error(const std::string& msg) {
const std::string& msg) {
on_megdnn_error(msg); on_megdnn_error(msg);
} }
...@@ -70,8 +69,9 @@ public: ...@@ -70,8 +69,9 @@ public:
#if MEGDNN_CC_HOST #if MEGDNN_CC_HOST
enum class LogLevel { DEBUG, INFO, WARN, ERROR }; enum class LogLevel { DEBUG, INFO, WARN, ERROR };
typedef void (*LogHandler)(LogLevel level, const char* file, const char* func, typedef void (*LogHandler)(
int line, const char* fmt, va_list ap); LogLevel level, const char* file, const char* func, int line, const char* fmt,
va_list ap);
/*! /*!
* \brief set the callback to receive all log messages * \brief set the callback to receive all log messages
...@@ -144,8 +144,7 @@ struct TensorLayout : public TensorShape { ...@@ -144,8 +144,7 @@ struct TensorLayout : public TensorShape {
ptrdiff_t low_elem, low_byte; ptrdiff_t low_elem, low_byte;
size_t high_elem, high_byte; size_t high_elem, high_byte;
Span(ptrdiff_t low_elem, ptrdiff_t low_byte, size_t high_elem, Span(ptrdiff_t low_elem, ptrdiff_t low_byte, size_t high_elem, size_t high_byte)
size_t high_byte)
: low_elem(low_elem), : low_elem(low_elem),
low_byte(low_byte), low_byte(low_byte),
high_elem(high_elem), high_elem(high_elem),
...@@ -235,11 +234,13 @@ struct TensorLayout : public TensorShape { ...@@ -235,11 +234,13 @@ struct TensorLayout : public TensorShape {
TensorLayout(const TensorShape& shape, DType dtype, Format format); TensorLayout(const TensorShape& shape, DType dtype, Format format);
//! creating layout with user-specified shape and stride. //! creating layout with user-specified shape and stride.
TensorLayout(const TensorShape& shape, const std::vector<ptrdiff_t>& stride, TensorLayout(
DType dtype); const TensorShape& shape, const std::vector<ptrdiff_t>& stride,
DType dtype);
TensorLayout(const TensorShape& shape, const std::vector<ptrdiff_t>& stride, TensorLayout(
DType dtype, Format format); const TensorShape& shape, const std::vector<ptrdiff_t>& stride, DType dtype,
Format format);
/* =================== inplace modifiers =================== */ /* =================== inplace modifiers =================== */
...@@ -310,8 +311,7 @@ struct TensorLayout : public TensorShape { ...@@ -310,8 +311,7 @@ struct TensorLayout : public TensorShape {
* *
* \throw TensorReshapeError if no stride exists for target shape. * \throw TensorReshapeError if no stride exists for target shape.
*/ */
TensorLayout reshape(const TensorShape& shape) const TensorLayout reshape(const TensorShape& shape) const MEGDNN_WARN_UNUSED_RESULT;
MEGDNN_WARN_UNUSED_RESULT;
/*! /*!
* \brief try to reshape to another view; return whether these two shapes * \brief try to reshape to another view; return whether these two shapes
...@@ -319,15 +319,14 @@ struct TensorLayout : public TensorShape { ...@@ -319,15 +319,14 @@ struct TensorLayout : public TensorShape {
* \return true iff there exists target stride so this layout can be * \return true iff there exists target stride so this layout can be
* converted to target shape and the elements can match. * converted to target shape and the elements can match.
*/ */
bool try_reshape(TensorLayout& output, bool try_reshape(TensorLayout& output, const TensorShape& shape) const
const TensorShape& shape) const MEGDNN_WARN_UNUSED_RESULT; MEGDNN_WARN_UNUSED_RESULT;
/*! /*!
* \brief Broadcast on dims with shape == 1 to match target *shape*. * \brief Broadcast on dims with shape == 1 to match target *shape*.
* \throw TensorReshapeError if could not be satisfied * \throw TensorReshapeError if could not be satisfied
*/ */
TensorLayout broadcast(const TensorShape& shape) const TensorLayout broadcast(const TensorShape& shape) const MEGDNN_WARN_UNUSED_RESULT;
MEGDNN_WARN_UNUSED_RESULT;
/*! /*!
* \brief Collapse consecutive axes with contiguous layout together * \brief Collapse consecutive axes with contiguous layout together
...@@ -441,8 +440,7 @@ struct Workspace { ...@@ -441,8 +440,7 @@ struct Workspace {
Workspace() : raw_ptr(NULL), size(0) {} Workspace() : raw_ptr(NULL), size(0) {}
Workspace(dt_byte* raw_ptr_, size_t size_) Workspace(dt_byte* raw_ptr_, size_t size_) : raw_ptr(raw_ptr_), size(size_) {}
: raw_ptr(raw_ptr_), size(size_) {}
template <typename T> template <typename T>
T* ptr(size_t offset_in_bytes = 0) const { T* ptr(size_t offset_in_bytes = 0) const {
...@@ -467,9 +465,8 @@ public: ...@@ -467,9 +465,8 @@ public:
* \param shape requested output shape * \param shape requested output shape
* \param user_data extra user data passed in DynOutMallocPolicyCall * \param user_data extra user data passed in DynOutMallocPolicyCall
*/ */
virtual TensorND alloc_output(size_t id, DType dtype, virtual TensorND alloc_output(
const TensorShape& shape, size_t id, DType dtype, const TensorShape& shape, void* user_data) = 0;
void* user_data) = 0;
/*! /*!
* \brief allocate workspace memory * \brief allocate workspace memory
...@@ -508,19 +505,15 @@ struct DynOutMallocPolicyCall { ...@@ -508,19 +505,15 @@ struct DynOutMallocPolicyCall {
*/ */
template <typename T = void, typename elem = T> template <typename T = void, typename elem = T>
T* alloc_workspace(size_t nr_elem) { T* alloc_workspace(size_t nr_elem) {
using real_elem = using real_elem = typename std::conditional<
typename std::conditional<std::is_same<elem, void>::value, std::is_same<elem, void>::value, uint8_t, elem>::type;
uint8_t, elem>::type; return static_cast<T*>(
return static_cast<T*>(policy->alloc_workspace( policy->alloc_workspace(nr_elem * sizeof(real_elem), user_data));
nr_elem * sizeof(real_elem), user_data));
} }
void free_workspace(void* ptr) { void free_workspace(void* ptr) { return policy->free_workspace(ptr, user_data); }
return policy->free_workspace(ptr, user_data);
}
}; };
template <typename T> template <typename T>
class EnumClassBit { class EnumClassBit {
std::underlying_type_t<T> m_val; std::underlying_type_t<T> m_val;
...@@ -528,8 +521,7 @@ class EnumClassBit { ...@@ -528,8 +521,7 @@ class EnumClassBit {
constexpr EnumClassBit(std::underlying_type_t<T> v) : m_val(v) {} constexpr EnumClassBit(std::underlying_type_t<T> v) : m_val(v) {}
public: public:
constexpr EnumClassBit(T v) constexpr EnumClassBit(T v) : m_val(static_cast<std::underlying_type_t<T>>(v)) {}
: m_val(static_cast<std::underlying_type_t<T>>(v)) {}
constexpr operator T() const { return static_cast<T>(m_val); } constexpr operator T() const { return static_cast<T>(m_val); }
...@@ -542,7 +534,7 @@ public: ...@@ -542,7 +534,7 @@ public:
DEF_OPR(&) DEF_OPR(&)
DEF_OPR(|) DEF_OPR(|)
DEF_OPR (^) DEF_OPR(^)
constexpr EnumClassBit operator~() const { return ~m_val; } constexpr EnumClassBit operator~() const { return ~m_val; }
...@@ -553,14 +545,13 @@ public: ...@@ -553,14 +545,13 @@ public:
} // namespace megdnn } // namespace megdnn
#define _MEGDNN_DECBO_SINGLE_OPR(cls, op) \ #define _MEGDNN_DECBO_SINGLE_OPR(cls, op) \
inline constexpr ::megdnn::EnumClassBit<cls> operator op(cls x, cls y) { \ inline constexpr ::megdnn::EnumClassBit<cls> operator op(cls x, cls y) { \
return ::megdnn::EnumClassBit<cls>(x) \ return ::megdnn::EnumClassBit<cls>(x) op ::megdnn::EnumClassBit<cls>(y); \
op ::megdnn::EnumClassBit<cls>(y); \ } \
} \ inline constexpr ::megdnn::EnumClassBit<cls> operator op( \
inline constexpr ::megdnn::EnumClassBit<cls> operator op( \ ::megdnn::EnumClassBit<cls> x, cls y) { \
::megdnn::EnumClassBit<cls> x, cls y) { \ return x op ::megdnn::EnumClassBit<cls>(y); \
return x op ::megdnn::EnumClassBit<cls>(y); \
} }
#define _MEGDNN_DECBO_SINGLE_OPR_ASSIGN(cls, op) \ #define _MEGDNN_DECBO_SINGLE_OPR_ASSIGN(cls, op) \
......
...@@ -14,14 +14,14 @@ ...@@ -14,14 +14,14 @@
#include "megbrain_build_config.h" #include "megbrain_build_config.h"
#if MGB_ENABLE_GETENV #if MGB_ENABLE_GETENV
#define MGB_GETENV ::std::getenv #define MGB_GETENV ::std::getenv
#else #else
#define MGB_GETENV(_name) static_cast<char*>(nullptr) #define MGB_GETENV(_name) static_cast<char*>(nullptr)
#endif #endif
#ifdef WIN32 #ifdef WIN32
#define unsetenv(_name) _putenv_s(_name, ""); #define unsetenv(_name) _putenv_s(_name, "");
#define setenv(name,value,overwrite) _putenv_s(name,value) #define setenv(name, value, overwrite) _putenv_s(name, value)
#endif #endif
namespace megdnn { namespace megdnn {
...@@ -32,8 +32,7 @@ namespace megdnn { ...@@ -32,8 +32,7 @@ namespace megdnn {
*/ */
template <class Opr, typename... Args> template <class Opr, typename... Args>
bool has_available_algo(Opr* opr, Args&&... args) { bool has_available_algo(Opr* opr, Args&&... args) {
const typename Opr::AlgoBase::SizeArgs size_args( const typename Opr::AlgoBase::SizeArgs size_args(opr, std::forward<Args>(args)...);
opr, std::forward<Args>(args)...);
for (auto i : Opr::algo_pack().all_algos) { for (auto i : Opr::algo_pack().all_algos) {
if (i->is_available(size_args)) { if (i->is_available(size_args)) {
return true; return true;
...@@ -42,6 +41,6 @@ bool has_available_algo(Opr* opr, Args&&... args) { ...@@ -42,6 +41,6 @@ bool has_available_algo(Opr* opr, Args&&... args) {
return false; return false;
} }
} } // namespace megdnn
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
...@@ -17,11 +17,11 @@ ...@@ -17,11 +17,11 @@
#include "megdnn/internal/visibility_prologue.h" #include "megdnn/internal/visibility_prologue.h"
namespace megdnn { namespace megdnn {
std::unique_ptr<Handle> make_cuda_handle_with_stream(cudaStream_t stream, std::unique_ptr<Handle> make_cuda_handle_with_stream(
int device_id = -1); cudaStream_t stream, int device_id = -1);
cudaStream_t get_cuda_stream(Handle *handle); cudaStream_t get_cuda_stream(Handle* handle);
} // namespace megdnn } // namespace megdnn
#include "megdnn/internal/visibility_epilogue.h" #include "megdnn/internal/visibility_epilogue.h"
// vim: syntax=cpp.doxygen // vim: syntax=cpp.doxygen
此差异已折叠。
...@@ -3,17 +3,22 @@ ...@@ -3,17 +3,22 @@
* *
* Copyright (c) 2012-2013 Christian Rau <rauy@users.sourceforge.net> * Copyright (c) 2012-2013 Christian Rau <rauy@users.sourceforge.net>
* *
* Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation * Permission is hereby granted, free of charge, to any person obtaining a copy of this
* files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, * software and associated documentation files (the "Software"), to deal in the Software
* modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the * without restriction, including without limitation the rights to use, copy, modify,
* Software is furnished to do so, subject to the following conditions: * merge, publish, distribute, sublicense, and/or sell copies of the Software, and to
* * permit persons to whom the Software is furnished to do so, subject to the following
* The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. * conditions:
* *
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE * The above copyright notice and this permission notice shall be included in all copies
* WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR * or substantial portions of the Software.
* COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, *
* ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A
* PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT
* HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF
* CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE
* OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
* *
* Version 1.11.0 * Version 1.11.0
* \file * \file
...@@ -41,8 +46,8 @@ ...@@ -41,8 +46,8 @@
#undef HALF_NOEXCEPT #undef HALF_NOEXCEPT
#undef HALF_NOTHROW #undef HALF_NOTHROW
#ifdef HALF_POP_WARNINGS #ifdef HALF_POP_WARNINGS
#pragma warning(pop) #pragma warning(pop)
#undef HALF_POP_WARNINGS #undef HALF_POP_WARNINGS
#endif #endif
// vim: syntax=cpp.doxygen // vim: syntax=cpp.doxygen
...@@ -3,17 +3,22 @@ ...@@ -3,17 +3,22 @@
* *
* Copyright (c) 2012-2013 Christian Rau <rauy@users.sourceforge.net> * Copyright (c) 2012-2013 Christian Rau <rauy@users.sourceforge.net>
* *
* Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation * Permission is hereby granted, free of charge, to any person obtaining a copy of this
* files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, * software and associated documentation files (the "Software"), to deal in the Software
* modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the * without restriction, including without limitation the rights to use, copy, modify,
* Software is furnished to do so, subject to the following conditions: * merge, publish, distribute, sublicense, and/or sell copies of the Software, and to
* permit persons to whom the Software is furnished to do so, subject to the following
* conditions:
* *
* The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. * The above copyright notice and this permission notice shall be included in all copies
* or substantial portions of the Software.
* *
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED,
* WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR * INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A
* COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, * PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT
* ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. * HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF
* CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE
* OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
* *
* Version 1.11.0 * Version 1.11.0
* \file * \file
...@@ -39,166 +44,164 @@ ...@@ -39,166 +44,164 @@
#include "megdnn/arch.h" #include "megdnn/arch.h"
/// Combined gcc version number. /// Combined gcc version number.
#define HALF_GNUC_VERSION (__GNUC__*100+__GNUC_MINOR__) #define HALF_GNUC_VERSION (__GNUC__ * 100 + __GNUC_MINOR__)
//check C++11 language features // check C++11 language features
#if defined(__clang__) //clang #if defined(__clang__) // clang
#if __has_feature(cxx_static_assert) && !defined(HALF_ENABLE_CPP11_STATIC_ASSERT) #if __has_feature(cxx_static_assert) && !defined(HALF_ENABLE_CPP11_STATIC_ASSERT)
#define HALF_ENABLE_CPP11_STATIC_ASSERT 1 #define HALF_ENABLE_CPP11_STATIC_ASSERT 1
#endif #endif
#if __has_feature(cxx_constexpr) && !defined(HALF_ENABLE_CPP11_CONSTEXPR) #if __has_feature(cxx_constexpr) && !defined(HALF_ENABLE_CPP11_CONSTEXPR)
#define HALF_ENABLE_CPP11_CONSTEXPR 1 #define HALF_ENABLE_CPP11_CONSTEXPR 1
#endif #endif
#if __has_feature(cxx_noexcept) && !defined(HALF_ENABLE_CPP11_NOEXCEPT) #if __has_feature(cxx_noexcept) && !defined(HALF_ENABLE_CPP11_NOEXCEPT)
#define HALF_ENABLE_CPP11_NOEXCEPT 1 #define HALF_ENABLE_CPP11_NOEXCEPT 1
#endif #endif
#if __has_feature(cxx_user_literals) && !defined(HALF_ENABLE_CPP11_USER_LITERALS) #if __has_feature(cxx_user_literals) && !defined(HALF_ENABLE_CPP11_USER_LITERALS)
#define HALF_ENABLE_CPP11_USER_LITERALS 1 #define HALF_ENABLE_CPP11_USER_LITERALS 1
#endif #endif
#if (defined(__GXX_EXPERIMENTAL_CXX0X__) || __cplusplus >= 201103L) && !defined(HALF_ENABLE_CPP11_LONG_LONG) #if (defined(__GXX_EXPERIMENTAL_CXX0X__) || __cplusplus >= 201103L) && \
#define HALF_ENABLE_CPP11_LONG_LONG 1 !defined(HALF_ENABLE_CPP11_LONG_LONG)
#endif #define HALF_ENABLE_CPP11_LONG_LONG 1
/*#elif defined(__INTEL_COMPILER) //Intel C++ #endif
#if __INTEL_COMPILER >= 1100 && !defined(HALF_ENABLE_CPP11_STATIC_ASSERT) ???????? /*#elif defined(__INTEL_COMPILER)
#define HALF_ENABLE_CPP11_STATIC_ASSERT 1 //Intel C++ #if __INTEL_COMPILER >= 1100 && !defined(HALF_ENABLE_CPP11_STATIC_ASSERT)
#endif ???????? #define HALF_ENABLE_CPP11_STATIC_ASSERT 1 #endif #if __INTEL_COMPILER >=
#if __INTEL_COMPILER >= 1300 && !defined(HALF_ENABLE_CPP11_CONSTEXPR) ???????? 1300 && !defined(HALF_ENABLE_CPP11_CONSTEXPR) ???????? #define
#define HALF_ENABLE_CPP11_CONSTEXPR 1 HALF_ENABLE_CPP11_CONSTEXPR 1 #endif #if __INTEL_COMPILER >= 1300 &&
#endif !defined(HALF_ENABLE_CPP11_NOEXCEPT) ???????? #define
#if __INTEL_COMPILER >= 1300 && !defined(HALF_ENABLE_CPP11_NOEXCEPT) ???????? HALF_ENABLE_CPP11_NOEXCEPT 1 #endif #if __INTEL_COMPILER >= 1100 &&
#define HALF_ENABLE_CPP11_NOEXCEPT 1 !defined(HALF_ENABLE_CPP11_LONG_LONG) ???????? #define
#endif HALF_ENABLE_CPP11_LONG_LONG 1 #endif*/
#if __INTEL_COMPILER >= 1100 && !defined(HALF_ENABLE_CPP11_LONG_LONG) ???????? #elif defined(__GNUC__) // gcc
#define HALF_ENABLE_CPP11_LONG_LONG 1 #if defined(__GXX_EXPERIMENTAL_CXX0X__) || __cplusplus >= 201103L
#endif*/ #if HALF_GNUC_VERSION >= 403 && !defined(HALF_ENABLE_CPP11_STATIC_ASSERT)
#elif defined(__GNUC__) //gcc #define HALF_ENABLE_CPP11_STATIC_ASSERT 1
#if defined(__GXX_EXPERIMENTAL_CXX0X__) || __cplusplus >= 201103L #endif
#if HALF_GNUC_VERSION >= 403 && !defined(HALF_ENABLE_CPP11_STATIC_ASSERT) #if HALF_GNUC_VERSION >= 406 && !defined(HALF_ENABLE_CPP11_CONSTEXPR)
#define HALF_ENABLE_CPP11_STATIC_ASSERT 1 #define HALF_ENABLE_CPP11_CONSTEXPR 1
#endif #endif
#if HALF_GNUC_VERSION >= 406 && !defined(HALF_ENABLE_CPP11_CONSTEXPR) #if HALF_GNUC_VERSION >= 406 && !defined(HALF_ENABLE_CPP11_NOEXCEPT)
#define HALF_ENABLE_CPP11_CONSTEXPR 1 #define HALF_ENABLE_CPP11_NOEXCEPT 1
#endif #endif
#if HALF_GNUC_VERSION >= 406 && !defined(HALF_ENABLE_CPP11_NOEXCEPT) #if HALF_GNUC_VERSION >= 407 && !defined(HALF_ENABLE_CPP11_USER_LITERALS)
#define HALF_ENABLE_CPP11_NOEXCEPT 1 #define HALF_ENABLE_CPP11_USER_LITERALS 1
#endif #endif
#if HALF_GNUC_VERSION >= 407 && !defined(HALF_ENABLE_CPP11_USER_LITERALS) #if !defined(HALF_ENABLE_CPP11_LONG_LONG)
#define HALF_ENABLE_CPP11_USER_LITERALS 1 #define HALF_ENABLE_CPP11_LONG_LONG 1
#endif #endif
#if !defined(HALF_ENABLE_CPP11_LONG_LONG) #endif
#define HALF_ENABLE_CPP11_LONG_LONG 1 #elif defined(_MSC_VER) // Visual C++
#endif #if _MSC_VER >= 1600 && !defined(HALF_ENABLE_CPP11_STATIC_ASSERT)
#endif #define HALF_ENABLE_CPP11_STATIC_ASSERT 1
#elif defined(_MSC_VER) //Visual C++ #endif
#if _MSC_VER >= 1600 && !defined(HALF_ENABLE_CPP11_STATIC_ASSERT) #if _MSC_VER >= 1310 && !defined(HALF_ENABLE_CPP11_LONG_LONG)
#define HALF_ENABLE_CPP11_STATIC_ASSERT 1 #define HALF_ENABLE_CPP11_LONG_LONG 1
#endif #endif
#if _MSC_VER >= 1310 && !defined(HALF_ENABLE_CPP11_LONG_LONG) #define HALF_POP_WARNINGS 1
#define HALF_ENABLE_CPP11_LONG_LONG 1 #pragma warning(push)
#endif //! 4521 and 4522 is multiple copy/assigment operator specified
#define HALF_POP_WARNINGS 1 #pragma warning(disable : 4099 4127 4146 4521 4522) // struct vs class, constant in if,
#pragma warning(push) // negative unsigned
//! 4521 and 4522 is multiple copy/assigment operator specified
#pragma warning(disable : 4099 4127 4146 4521 4522) //struct vs class, constant in if, negative unsigned
#endif #endif
//check C++11 library features // check C++11 library features
#include <utility> #include <utility>
#if defined(_LIBCPP_VERSION) //libc++ #if defined(_LIBCPP_VERSION) // libc++
#if defined(__GXX_EXPERIMENTAL_CXX0X__) || __cplusplus >= 201103 #if defined(__GXX_EXPERIMENTAL_CXX0X__) || __cplusplus >= 201103
#ifndef HALF_ENABLE_CPP11_TYPE_TRAITS #ifndef HALF_ENABLE_CPP11_TYPE_TRAITS
#define HALF_ENABLE_CPP11_TYPE_TRAITS 1 #define HALF_ENABLE_CPP11_TYPE_TRAITS 1
#endif #endif
#ifndef HALF_ENABLE_CPP11_CSTDINT #ifndef HALF_ENABLE_CPP11_CSTDINT
#define HALF_ENABLE_CPP11_CSTDINT 1 #define HALF_ENABLE_CPP11_CSTDINT 1
#endif #endif
#ifndef HALF_ENABLE_CPP11_CMATH #ifndef HALF_ENABLE_CPP11_CMATH
#define HALF_ENABLE_CPP11_CMATH 1 #define HALF_ENABLE_CPP11_CMATH 1
#endif #endif
#ifndef HALF_ENABLE_CPP11_HASH #ifndef HALF_ENABLE_CPP11_HASH
#define HALF_ENABLE_CPP11_HASH 1 #define HALF_ENABLE_CPP11_HASH 1
#endif #endif
#endif #endif
#elif defined(__GLIBCXX__) //libstdc++ #elif defined(__GLIBCXX__) // libstdc++
#if defined(__GXX_EXPERIMENTAL_CXX0X__) || __cplusplus >= 201103 #if defined(__GXX_EXPERIMENTAL_CXX0X__) || __cplusplus >= 201103
#ifdef __clang__ #ifdef __clang__
#if __GLIBCXX__ >= 20080606 && !defined(HALF_ENABLE_CPP11_TYPE_TRAITS) #if __GLIBCXX__ >= 20080606 && !defined(HALF_ENABLE_CPP11_TYPE_TRAITS)
#define HALF_ENABLE_CPP11_TYPE_TRAITS 1 #define HALF_ENABLE_CPP11_TYPE_TRAITS 1
#endif #endif
#if __GLIBCXX__ >= 20080606 && !defined(HALF_ENABLE_CPP11_CSTDINT) #if __GLIBCXX__ >= 20080606 && !defined(HALF_ENABLE_CPP11_CSTDINT)
#define HALF_ENABLE_CPP11_CSTDINT 1 #define HALF_ENABLE_CPP11_CSTDINT 1
#endif #endif
#if __GLIBCXX__ >= 20080606 && !defined(HALF_ENABLE_CPP11_CMATH) #if __GLIBCXX__ >= 20080606 && !defined(HALF_ENABLE_CPP11_CMATH)
#define HALF_ENABLE_CPP11_CMATH 1 #define HALF_ENABLE_CPP11_CMATH 1
#endif #endif
#if __GLIBCXX__ >= 20080606 && !defined(HALF_ENABLE_CPP11_HASH) #if __GLIBCXX__ >= 20080606 && !defined(HALF_ENABLE_CPP11_HASH)
#define HALF_ENABLE_CPP11_HASH 1 #define HALF_ENABLE_CPP11_HASH 1
#endif #endif
#else #else
#if HALF_GNUC_VERSION >= 403 && !defined(HALF_ENABLE_CPP11_CSTDINT) #if HALF_GNUC_VERSION >= 403 && !defined(HALF_ENABLE_CPP11_CSTDINT)
#define HALF_ENABLE_CPP11_CSTDINT 1 #define HALF_ENABLE_CPP11_CSTDINT 1
#endif #endif
#if HALF_GNUC_VERSION >= 403 && !defined(HALF_ENABLE_CPP11_CMATH) #if HALF_GNUC_VERSION >= 403 && !defined(HALF_ENABLE_CPP11_CMATH)
#define HALF_ENABLE_CPP11_CMATH 1 #define HALF_ENABLE_CPP11_CMATH 1
#endif #endif
#if HALF_GNUC_VERSION >= 403 && !defined(HALF_ENABLE_CPP11_HASH) #if HALF_GNUC_VERSION >= 403 && !defined(HALF_ENABLE_CPP11_HASH)
#define HALF_ENABLE_CPP11_HASH 1 #define HALF_ENABLE_CPP11_HASH 1
#endif #endif
#endif #endif
#endif #endif
#elif defined(_CPPLIB_VER) //Dinkumware/Visual C++ #elif defined(_CPPLIB_VER) // Dinkumware/Visual C++
#if _CPPLIB_VER >= 520 #if _CPPLIB_VER >= 520
#ifndef HALF_ENABLE_CPP11_TYPE_TRAITS #ifndef HALF_ENABLE_CPP11_TYPE_TRAITS
#define HALF_ENABLE_CPP11_TYPE_TRAITS 1 #define HALF_ENABLE_CPP11_TYPE_TRAITS 1
#endif #endif
#ifndef HALF_ENABLE_CPP11_CSTDINT #ifndef HALF_ENABLE_CPP11_CSTDINT
#define HALF_ENABLE_CPP11_CSTDINT 1 #define HALF_ENABLE_CPP11_CSTDINT 1
#endif #endif
#ifndef HALF_ENABLE_CPP11_HASH #ifndef HALF_ENABLE_CPP11_HASH
#define HALF_ENABLE_CPP11_HASH 1 #define HALF_ENABLE_CPP11_HASH 1
#endif #endif
#endif #endif
#if _CPPLIB_VER >= 610 #if _CPPLIB_VER >= 610
#ifndef HALF_ENABLE_CPP11_CMATH #ifndef HALF_ENABLE_CPP11_CMATH
#define HALF_ENABLE_CPP11_CMATH 1 #define HALF_ENABLE_CPP11_CMATH 1
#endif #endif
#endif #endif
#endif #endif
#undef HALF_GNUC_VERSION #undef HALF_GNUC_VERSION
//support constexpr // support constexpr
#if HALF_ENABLE_CPP11_CONSTEXPR #if HALF_ENABLE_CPP11_CONSTEXPR
#define HALF_CONSTEXPR constexpr #define HALF_CONSTEXPR constexpr
#define HALF_CONSTEXPR_CONST constexpr #define HALF_CONSTEXPR_CONST constexpr
#else #else
#define HALF_CONSTEXPR #define HALF_CONSTEXPR
#define HALF_CONSTEXPR_CONST const #define HALF_CONSTEXPR_CONST const
#endif #endif
//support noexcept // support noexcept
#if HALF_ENABLE_CPP11_NOEXCEPT #if HALF_ENABLE_CPP11_NOEXCEPT
#define HALF_NOEXCEPT noexcept #define HALF_NOEXCEPT noexcept
#define HALF_NOTHROW noexcept #define HALF_NOTHROW noexcept
#else #else
#define HALF_NOEXCEPT #define HALF_NOEXCEPT
#define HALF_NOTHROW throw() #define HALF_NOTHROW throw()
#endif #endif
#include <algorithm> #include <algorithm>
#include <limits>
#include <climits> #include <climits>
#include <cmath> #include <cmath>
#include <cstring> #include <cstring>
#include <ostream>
#include <istream> #include <istream>
#include <limits>
#include <ostream>
#if HALF_ENABLE_CPP11_TYPE_TRAITS #if HALF_ENABLE_CPP11_TYPE_TRAITS
#include <type_traits> #include <type_traits>
#endif #endif
#if HALF_ENABLE_CPP11_CSTDINT #if HALF_ENABLE_CPP11_CSTDINT
#include <cstdint> #include <cstdint>
#endif #endif
#if HALF_ENABLE_CPP11_HASH #if HALF_ENABLE_CPP11_HASH
#include <functional> #include <functional>
#endif #endif
// vim: syntax=cpp.doxygen // vim: syntax=cpp.doxygen
...@@ -12,8 +12,8 @@ ...@@ -12,8 +12,8 @@
#pragma once #pragma once
#include "megcore.h" #include "megcore.h"
#include "megdnn/config/config.h"
#include "megdnn/basic_types.h" #include "megdnn/basic_types.h"
#include "megdnn/config/config.h"
#include <functional> #include <functional>
#include <memory> #include <memory>
...@@ -24,150 +24,147 @@ namespace megdnn { ...@@ -24,150 +24,147 @@ namespace megdnn {
class OperatorBase; class OperatorBase;
class Handle { class Handle {
public: public:
enum class HandleType { enum class HandleType {
NAIVE = 0, NAIVE = 0,
FALLBACK = 1, FALLBACK = 1,
X86 = 2, X86 = 2,
ARM_COMMON = 3, ARM_COMMON = 3,
ARMV7 = 4, ARMV7 = 4,
AARCH64 = 5, AARCH64 = 5,
CUDA = 6, CUDA = 6,
ROCM = 11, ROCM = 11,
ATLAS = 13, ATLAS = 13,
CAMBRICON = 12, CAMBRICON = 12,
}; };
//! Device vendor //! Device vendor
enum class HandleVendorType : uint32_t { enum class HandleVendorType : uint32_t {
NOT_SPEC = 0, NOT_SPEC = 0,
MALI = 1, MALI = 1,
ADRENO = 2, ADRENO = 2,
CUDA = 3, CUDA = 3,
INTEL = 4, INTEL = 4,
POWERVR = 5, POWERVR = 5,
AMD = 6, AMD = 6,
}; };
protected: protected:
Handle(megcoreComputingHandle_t computing_handle, HandleType type); Handle(megcoreComputingHandle_t computing_handle, HandleType type);
public: public:
/** /**
* \brief Create a MegDNN handle from a MegCore Computing handle. * \brief Create a MegDNN handle from a MegCore Computing handle.
* *
* \param[in] computing_handle MegCore computing handle. Please note * \param[in] computing_handle MegCore computing handle. Please note
* that computing_handle would not be released when this Handle is * that computing_handle would not be released when this Handle is
* destructed * destructed
* \param[in] debug_level * \param[in] debug_level
* Applicable for CPU computing handle. * Applicable for CPU computing handle.
* 0 means taking the fastest possible code path; it may contains * 0 means taking the fastest possible code path; it may contains
* platform-specific instructions such as SSE for x86_64 or NEON for * platform-specific instructions such as SSE for x86_64 or NEON for
* armv7v7. * armv7v7.
* 1 means taking the fastest possible code path without * 1 means taking the fastest possible code path without
* platform-specific instructions in C++ code. Note that the compiled * platform-specific instructions in C++ code. Note that the compiled
* binary file still contains platform-specific codes. * binary file still contains platform-specific codes.
* 2 means taking the naive code path. Performance is severely * 2 means taking the naive code path. Performance is severely
* hampered, but it is less error-prone since the internal * hampered, but it is less error-prone since the internal
* implementation is rather straightforward. * implementation is rather straightforward.
* *
* **Debug level 1 and 2 should not be used in productions.** * **Debug level 1 and 2 should not be used in productions.**
*/ */
static std::unique_ptr<Handle> make( static std::unique_ptr<Handle> make(
megcoreComputingHandle_t computing_handle, megcoreComputingHandle_t computing_handle, int debug_level = 0);
int debug_level = 0);
#if MEGDNN_WITH_CUDA #if MEGDNN_WITH_CUDA
static std::unique_ptr<Handle> make_cuda_handle( static std::unique_ptr<Handle> make_cuda_handle(
megcoreComputingHandle_t computing_handle); megcoreComputingHandle_t computing_handle);
template <typename opr> template <typename opr>
std::unique_ptr<opr> create_cuda_operator(); std::unique_ptr<opr> create_cuda_operator();
#endif #endif
#if MEGDNN_WITH_ROCM #if MEGDNN_WITH_ROCM
static std::unique_ptr<Handle> make_rocm_handle( static std::unique_ptr<Handle> make_rocm_handle(
megcoreComputingHandle_t computing_handle); megcoreComputingHandle_t computing_handle);
template <typename opr> template <typename opr>
std::unique_ptr<opr> create_rocm_operator(); std::unique_ptr<opr> create_rocm_operator();
#endif #endif
virtual ~Handle(); virtual ~Handle();
/*! /*!
* \brief Get the underlying megcore computing handle. * \brief Get the underlying megcore computing handle.
*/ */
megcoreComputingHandle_t megcore_computing_handle() const { megcoreComputingHandle_t megcore_computing_handle() const {
return m_computing_handle; return m_computing_handle;
} }
/*! /*!
* \brief set a callback function to be invoked when this handle is * \brief set a callback function to be invoked when this handle is
* destructed, so associated resources can be released (e.g. * destructed, so associated resources can be released (e.g.
* computing handle) * computing handle)
* *
* This function can be called at most once. * This function can be called at most once.
*/ */
void set_destructor(const thin_function<void()> &d); void set_destructor(const thin_function<void()>& d);
/*! /*!
* \brief set a callback to be invoked when an operator is destructed * \brief set a callback to be invoked when an operator is destructed
* \param[in,out] cb the callback function; it would be set to the * \param[in,out] cb the callback function; it would be set to the
* previous callback function * previous callback function
*/ */
void set_opr_destruct_callback(thin_function<void(OperatorBase*)> &cb) { void set_opr_destruct_callback(thin_function<void(OperatorBase*)>& cb) {
cb.swap(m_on_opr_destructed); cb.swap(m_on_opr_destructed);
} }
void on_opr_destructed(OperatorBase* opr); void on_opr_destructed(OperatorBase* opr);
/** /**
* \brief Create operator of Opr type. * \brief Create operator of Opr type.
*/ */
template <typename Opr> template <typename Opr>
std::unique_ptr<Opr> create_operator(); std::unique_ptr<Opr> create_operator();
/* /*
* ============================================================= * =============================================================
* Users should call functions below to query memory requirement. * Users should call functions below to query memory requirement.
* ============================================================= * =============================================================
*/ */
/** /**
* \brief The internal data pointer of TensorND should be aligned to * \brief The internal data pointer of TensorND should be aligned to
* alignment_requirement() in bytes. * alignment_requirement() in bytes.
*/ */
virtual size_t alignment_requirement() const; virtual size_t alignment_requirement() const;
//! get alignment in bytes for rows of image 2D tensor format //! get alignment in bytes for rows of image 2D tensor format
virtual size_t image2d_pitch_alignment() const; virtual size_t image2d_pitch_alignment() const;
//! get vendor type //! get vendor type
virtual HandleVendorType vendor_type() const; virtual HandleVendorType vendor_type() const;
HandleType type() const { HandleType type() const { return m_handle_type; }
return m_handle_type;
} /**
* \brief Check is the layout satisfy cross device copy constraint.
/** * 1. The handle of the src and the dst is the same kind
* \brief Check is the layout satisfy cross device copy constraint. * 2. The dst is continguous.
* 1. The handle of the src and the dst is the same kind */
* 2. The dst is continguous. virtual bool check_cross_dev_copy_constraint(const TensorLayout& src);
*/
virtual bool check_cross_dev_copy_constraint(const TensorLayout &src); private:
static constexpr uint32_t ALIVE_MAGIC = 0x8595e9d2u;
private: volatile uint32_t m_alive_magic = ALIVE_MAGIC;
static constexpr uint32_t ALIVE_MAGIC = 0x8595e9d2u; megcoreComputingHandle_t m_computing_handle;
volatile uint32_t m_alive_magic = ALIVE_MAGIC; const HandleType m_handle_type;
megcoreComputingHandle_t m_computing_handle; thin_function<void()> m_destructor;
const HandleType m_handle_type; thin_function<void(OperatorBase*)> m_on_opr_destructed;
thin_function<void()> m_destructor;
thin_function<void(OperatorBase*)> m_on_opr_destructed; Handle() = delete;
Handle(const Handle& rhs) = delete;
Handle() = delete; Handle& operator=(const Handle& rhs) = delete;
Handle(const Handle &rhs) = delete;
Handle &operator=(const Handle &rhs) = delete;
}; };
} // namespace megdnn } // namespace megdnn
#include "megdnn/internal/visibility_epilogue.h" #include "megdnn/internal/visibility_epilogue.h"
......
...@@ -49,8 +49,9 @@ public: ...@@ -49,8 +49,9 @@ public:
mutable std::string m_input; mutable std::string m_input;
public: public:
Key(Handle* opr_handle, Algorithm::OprType opr_type, const TensorLayout* inp_layouts_ptr, Key(Handle* opr_handle, Algorithm::OprType opr_type,
size_t inp_layouts_size, const void* param_ptr = nullptr, size_t param_size = 0) const TensorLayout* inp_layouts_ptr, size_t inp_layouts_size,
const void* param_ptr = nullptr, size_t param_size = 0)
: m_handle{opr_handle}, : m_handle{opr_handle},
m_opr_type{static_cast<uint32_t>(opr_type)}, m_opr_type{static_cast<uint32_t>(opr_type)},
m_inp_layouts_ptr{inp_layouts_ptr}, m_inp_layouts_ptr{inp_layouts_ptr},
......
...@@ -16,20 +16,19 @@ ...@@ -16,20 +16,19 @@
* \brief iterate through small (usually used) ndim values * \brief iterate through small (usually used) ndim values
*/ */
#define MEGDNN_FOREACH_TENSOR_NDIM_SMALL(cb, ...) \ #define MEGDNN_FOREACH_TENSOR_NDIM_SMALL(cb, ...) \
cb(1 ,##__VA_ARGS__) cb(2 ,##__VA_ARGS__) cb(3 ,##__VA_ARGS__) cb(1, ##__VA_ARGS__) cb(2, ##__VA_ARGS__) cb(3, ##__VA_ARGS__)
/*! /*!
* \brief iterate through large (rarely used) ndim values * \brief iterate through large (rarely used) ndim values
*/ */
#define MEGDNN_FOREACH_TENSOR_NDIM_LARGE(cb, ...) \ #define MEGDNN_FOREACH_TENSOR_NDIM_LARGE(cb, ...) \
cb(4 ,##__VA_ARGS__) cb(5 ,##__VA_ARGS__) cb(6 ,##__VA_ARGS__) \ cb(4, ##__VA_ARGS__) cb(5, ##__VA_ARGS__) cb(6, ##__VA_ARGS__) cb(7, ##__VA_ARGS__)
cb(7, ##__VA_ARGS__)
/*! /*!
* \brief iterate through all ndim values * \brief iterate through all ndim values
*/ */
#define MEGDNN_FOREACH_TENSOR_NDIM(cb, ...) \ #define MEGDNN_FOREACH_TENSOR_NDIM(cb, ...) \
MEGDNN_FOREACH_TENSOR_NDIM_SMALL(cb ,##__VA_ARGS__) \ MEGDNN_FOREACH_TENSOR_NDIM_SMALL(cb, ##__VA_ARGS__) \
MEGDNN_FOREACH_TENSOR_NDIM_LARGE(cb ,##__VA_ARGS__) MEGDNN_FOREACH_TENSOR_NDIM_LARGE(cb, ##__VA_ARGS__)
// vim: syntax=cpp.doxygen // vim: syntax=cpp.doxygen
...@@ -11,14 +11,14 @@ ...@@ -11,14 +11,14 @@
// intentional no header guard here // intentional no header guard here
#include "megdnn/handle.h" #include "megdnn/handle.h"
#include "megdnn/oprs/base.h"
#include "megdnn/opr_param_defs.h" #include "megdnn/opr_param_defs.h"
#include "megdnn/opr_result_defs.h" #include "megdnn/opr_result_defs.h"
#include "megdnn/oprs/base.h"
#include "./visibility_prologue.h" #include "./visibility_prologue.h"
#include <limits>
#include <array> #include <array>
#include <limits>
#ifndef _megdnn_in #ifndef _megdnn_in
#define _megdnn_in #define _megdnn_in
...@@ -29,36 +29,37 @@ ...@@ -29,36 +29,37 @@
#endif #endif
#ifndef _megdnn_tensor_in #ifndef _megdnn_tensor_in
#define _megdnn_tensor_in const TensorND & #define _megdnn_tensor_in const TensorND&
#endif #endif
#ifndef _megdnn_tensor_out #ifndef _megdnn_tensor_out
#define _megdnn_tensor_out const TensorND & #define _megdnn_tensor_out const TensorND&
#endif #endif
#ifndef _megdnn_tensor_inout #ifndef _megdnn_tensor_inout
#define _megdnn_tensor_inout const TensorND & #define _megdnn_tensor_inout const TensorND&
#endif #endif
#ifndef _megdnn_workspace #ifndef _megdnn_workspace
#define _megdnn_workspace const Workspace & #define _megdnn_workspace const Workspace&
#endif #endif
#define DEF_OPR_IMPL_CTOR(_opr_name, _base_name) \ #define DEF_OPR_IMPL_CTOR(_opr_name, _base_name) \
public: \ public: \
_opr_name(Handle *handle): _base_name(handle) {} \ _opr_name(Handle* handle) : _base_name(handle) {}
#define DEF_OPR_IMPL(_opr_name, _base_name, _nr_inputs, _nr_outputs) \ #define DEF_OPR_IMPL(_opr_name, _base_name, _nr_inputs, _nr_outputs) \
DEF_OPR_IMPL_CTOR(_opr_name, _base_name) \ DEF_OPR_IMPL_CTOR(_opr_name, _base_name) \
static MEGDNN_CONSTEXPR int NR_INPUTS = _nr_inputs; \ static MEGDNN_CONSTEXPR int NR_INPUTS = _nr_inputs; \
static MEGDNN_CONSTEXPR int NR_OUTPUTS = _nr_outputs; \ static MEGDNN_CONSTEXPR int NR_OUTPUTS = _nr_outputs;
#define DEF_OPR_PARAM(_pname) \ #define DEF_OPR_PARAM(_pname) \
public: \ public: \
using Param = param::_pname; \ using Param = param::_pname; \
Param& param() { return m_param; } \ Param& param() { return m_param; } \
const Param& param() const { return m_param; } \ const Param& param() const { return m_param; } \
protected: \ \
Param m_param protected: \
Param m_param
// vim: syntax=cpp.doxygen // vim: syntax=cpp.doxygen
...@@ -20,4 +20,3 @@ ...@@ -20,4 +20,3 @@
#endif #endif
// vim: syntax=cpp.doxygen // vim: syntax=cpp.doxygen
...@@ -16,25 +16,21 @@ ...@@ -16,25 +16,21 @@
namespace megdnn { namespace megdnn {
namespace opr_result { namespace opr_result {
struct Checksum { struct Checksum {
uint32_t checksum; uint32_t checksum;
union { union {
int32_t iv; int32_t iv;
float fv; float fv;
} last_val; } last_val;
bool operator == (const Checksum &rhs) const { bool operator==(const Checksum& rhs) const {
return checksum == rhs.checksum && return checksum == rhs.checksum && last_val.iv == rhs.last_val.iv;
last_val.iv == rhs.last_val.iv; }
}
bool operator!=(const Checksum& rhs) const { return !operator==(rhs); }
bool operator != (const Checksum &rhs) const { };
return !operator==(rhs);
} } // namespace opr_result
}; } // namespace megdnn
} // namespace opr_result
} // namespace megdnn
// vim: syntax=cpp.doxygen // vim: syntax=cpp.doxygen
...@@ -12,11 +12,11 @@ ...@@ -12,11 +12,11 @@
#include "megdnn/oprs/cv.h" #include "megdnn/oprs/cv.h"
#include "megdnn/oprs/general.h" #include "megdnn/oprs/general.h"
#include "megdnn/oprs/imgproc.h"
#include "megdnn/oprs/linalg.h"
#include "megdnn/oprs/nn.h" #include "megdnn/oprs/nn.h"
#include "megdnn/oprs/nn_int.h" #include "megdnn/oprs/nn_int.h"
#include "megdnn/oprs/imgproc.h"
#include "megdnn/oprs/utils.h" #include "megdnn/oprs/utils.h"
#include "megdnn/oprs/linalg.h"
template <typename Opr> template <typename Opr>
struct OprArityTrait; struct OprArityTrait;
...@@ -53,6 +53,4 @@ INST_ARITY(megdnn::PoolingBackward, 3, 1); ...@@ -53,6 +53,4 @@ INST_ARITY(megdnn::PoolingBackward, 3, 1);
#undef INST_ARITY #undef INST_ARITY
// vim: syntax=cpp.doxygen // vim: syntax=cpp.doxygen
此差异已折叠。
...@@ -31,15 +31,17 @@ class FlipForward : public FlipBase { ...@@ -31,15 +31,17 @@ class FlipForward : public FlipBase {
DEF_OPR_IMPL(FlipForward, FlipBase, 1, 1); DEF_OPR_IMPL(FlipForward, FlipBase, 1, 1);
public: public:
virtual void exec(_megdnn_tensor_in src, _megdnn_tensor_out dst, virtual void exec(
_megdnn_workspace workspace) = 0; _megdnn_tensor_in src, _megdnn_tensor_out dst,
_megdnn_workspace workspace) = 0;
void deduce_layout(const TensorLayout& src, TensorLayout& dst); void deduce_layout(const TensorLayout& src, TensorLayout& dst);
virtual size_t get_workspace_in_bytes(const TensorLayout& src, virtual size_t get_workspace_in_bytes(
const TensorLayout& dst) = 0; const TensorLayout& src, const TensorLayout& dst) = 0;
protected: protected:
void check_exec(const TensorLayout& src, const TensorLayout& dst, void check_exec(
size_t workspace_in_bytes); const TensorLayout& src, const TensorLayout& dst,
size_t workspace_in_bytes);
}; };
using Flip = FlipForward; using Flip = FlipForward;
...@@ -56,15 +58,17 @@ class RotateForward : public RotateBase { ...@@ -56,15 +58,17 @@ class RotateForward : public RotateBase {
DEF_OPR_IMPL(RotateForward, RotateBase, 1, 1); DEF_OPR_IMPL(RotateForward, RotateBase, 1, 1);
public: public:
virtual void exec(_megdnn_tensor_in src, _megdnn_tensor_out dst, virtual void exec(
_megdnn_workspace workspace) = 0; _megdnn_tensor_in src, _megdnn_tensor_out dst,
_megdnn_workspace workspace) = 0;
void deduce_layout(const TensorLayout& src, TensorLayout& dst); void deduce_layout(const TensorLayout& src, TensorLayout& dst);
virtual size_t get_workspace_in_bytes(const TensorLayout& src, virtual size_t get_workspace_in_bytes(
const TensorLayout& dst) = 0; const TensorLayout& src, const TensorLayout& dst) = 0;
protected: protected:
void check_exec(const TensorLayout& src, const TensorLayout& dst, void check_exec(
size_t workspace_in_bytes); const TensorLayout& src, const TensorLayout& dst,
size_t workspace_in_bytes);
}; };
using Rotate = RotateForward; using Rotate = RotateForward;
...@@ -81,15 +85,17 @@ class ROICopyForward : public ROICopyBase { ...@@ -81,15 +85,17 @@ class ROICopyForward : public ROICopyBase {
DEF_OPR_IMPL(ROICopyForward, ROICopyBase, 1, 1); DEF_OPR_IMPL(ROICopyForward, ROICopyBase, 1, 1);
public: public:
virtual void exec(_megdnn_tensor_in src, _megdnn_tensor_out dst, virtual void exec(
_megdnn_workspace workspace) = 0; _megdnn_tensor_in src, _megdnn_tensor_out dst,
_megdnn_workspace workspace) = 0;
void deduce_layout(const TensorLayout& src, TensorLayout& dst); void deduce_layout(const TensorLayout& src, TensorLayout& dst);
virtual size_t get_workspace_in_bytes(const TensorLayout& src, virtual size_t get_workspace_in_bytes(
const TensorLayout& dst) = 0; const TensorLayout& src, const TensorLayout& dst) = 0;
protected: protected:
void check_exec(const TensorLayout& src, const TensorLayout& dst, void check_exec(
size_t workspace_in_bytes); const TensorLayout& src, const TensorLayout& dst,
size_t workspace_in_bytes);
}; };
using ROICopy = ROICopyForward; using ROICopy = ROICopyForward;
...@@ -106,15 +112,17 @@ class CvtColorForward : public CvtColorBase { ...@@ -106,15 +112,17 @@ class CvtColorForward : public CvtColorBase {
DEF_OPR_IMPL(CvtColorForward, CvtColorBase, 1, 1); DEF_OPR_IMPL(CvtColorForward, CvtColorBase, 1, 1);
public: public:
virtual void exec(_megdnn_tensor_in src, _megdnn_tensor_out dst, virtual void exec(
_megdnn_workspace workspace) = 0; _megdnn_tensor_in src, _megdnn_tensor_out dst,
_megdnn_workspace workspace) = 0;
void deduce_layout(const TensorLayout& src, TensorLayout& dst); void deduce_layout(const TensorLayout& src, TensorLayout& dst);
virtual size_t get_workspace_in_bytes(const TensorLayout& src, virtual size_t get_workspace_in_bytes(
const TensorLayout& dst) = 0; const TensorLayout& src, const TensorLayout& dst) = 0;
protected: protected:
void check_exec(const TensorLayout& src, const TensorLayout& dst, void check_exec(
size_t workspace_in_bytes); const TensorLayout& src, const TensorLayout& dst,
size_t workspace_in_bytes);
}; };
using CvtColor = CvtColorForward; using CvtColor = CvtColorForward;
...@@ -130,8 +138,9 @@ public: ...@@ -130,8 +138,9 @@ public:
using BorderMode = Param::BorderMode; using BorderMode = Param::BorderMode;
protected: protected:
void check_layout_fwd(const TensorLayout& src, const TensorLayout& trans, void check_layout_fwd(
const TensorLayout& dst); const TensorLayout& src, const TensorLayout& trans,
const TensorLayout& dst);
std::string param_msg() const; std::string param_msg() const;
int get_real_coord(int p, int len); int get_real_coord(int p, int len);
}; };
...@@ -148,15 +157,17 @@ public: ...@@ -148,15 +157,17 @@ public:
* \warning src, trans, border_value, dst should be contiguous * \warning src, trans, border_value, dst should be contiguous
* The size of trans is N * 2 * 3 * The size of trans is N * 2 * 3
*/ */
virtual void exec(_megdnn_tensor_in src, _megdnn_tensor_in trans, virtual void exec(
_megdnn_tensor_out dst, _megdnn_workspace workspace) = 0; _megdnn_tensor_in src, _megdnn_tensor_in trans, _megdnn_tensor_out dst,
virtual size_t get_workspace_in_bytes(const TensorLayout& src, _megdnn_workspace workspace) = 0;
const TensorLayout& trans, virtual size_t get_workspace_in_bytes(
const TensorLayout& dst) = 0; const TensorLayout& src, const TensorLayout& trans,
const TensorLayout& dst) = 0;
protected: protected:
void check_exec(const TensorLayout& src, const TensorLayout& trans, void check_exec(
const TensorLayout& dst, size_t workspace_in_bytes); const TensorLayout& src, const TensorLayout& trans, const TensorLayout& dst,
size_t workspace_in_bytes);
}; };
using WarpAffine = WarpAffineForward; using WarpAffine = WarpAffineForward;
...@@ -173,15 +184,17 @@ class GaussianBlurForward : public GaussianBlurBase { ...@@ -173,15 +184,17 @@ class GaussianBlurForward : public GaussianBlurBase {
DEF_OPR_IMPL(GaussianBlurForward, GaussianBlurBase, 1, 1); DEF_OPR_IMPL(GaussianBlurForward, GaussianBlurBase, 1, 1);
public: public:
virtual void exec(_megdnn_tensor_in src, _megdnn_tensor_out dst, virtual void exec(
_megdnn_workspace workspace) = 0; _megdnn_tensor_in src, _megdnn_tensor_out dst,
_megdnn_workspace workspace) = 0;
void deduce_layout(const TensorLayout& src, TensorLayout& dst); void deduce_layout(const TensorLayout& src, TensorLayout& dst);
virtual size_t get_workspace_in_bytes(const TensorLayout& src, virtual size_t get_workspace_in_bytes(
const TensorLayout& dst) = 0; const TensorLayout& src, const TensorLayout& dst) = 0;
protected: protected:
void check_exec(const TensorLayout& src, const TensorLayout& dst, void check_exec(
size_t workspace_in_bytes); const TensorLayout& src, const TensorLayout& dst,
size_t workspace_in_bytes);
}; };
using GaussianBlur = GaussianBlurForward; using GaussianBlur = GaussianBlurForward;
...@@ -212,15 +225,17 @@ class ResizeForward : public ResizeBase { ...@@ -212,15 +225,17 @@ class ResizeForward : public ResizeBase {
DEF_OPR_IMPL(ResizeForward, ResizeBase, 1, 1); DEF_OPR_IMPL(ResizeForward, ResizeBase, 1, 1);
public: public:
virtual void exec(_megdnn_tensor_in src, _megdnn_tensor_out dst, virtual void exec(
_megdnn_workspace workspace) = 0; _megdnn_tensor_in src, _megdnn_tensor_out dst,
_megdnn_workspace workspace) = 0;
virtual size_t get_workspace_in_bytes(const TensorLayout& src, virtual size_t get_workspace_in_bytes(
const TensorLayout& dst) = 0; const TensorLayout& src, const TensorLayout& dst) = 0;
protected: protected:
void check_exec(const TensorLayout& src, const TensorLayout& dst, void check_exec(
size_t workspace_in_bytes); const TensorLayout& src, const TensorLayout& dst,
size_t workspace_in_bytes);
}; };
using Resize = ResizeForward; using Resize = ResizeForward;
...@@ -228,15 +243,17 @@ class ResizeBackward : public ResizeBase { ...@@ -228,15 +243,17 @@ class ResizeBackward : public ResizeBase {
DEF_OPR_IMPL(ResizeBackward, ResizeBase, 1, 1); DEF_OPR_IMPL(ResizeBackward, ResizeBase, 1, 1);
public: public:
virtual void exec(_megdnn_tensor_in diff, _megdnn_tensor_out grad, virtual void exec(
_megdnn_workspace workspace) = 0; _megdnn_tensor_in diff, _megdnn_tensor_out grad,
_megdnn_workspace workspace) = 0;
virtual size_t get_workspace_in_bytes(const TensorLayout& diff, virtual size_t get_workspace_in_bytes(
const TensorLayout& mat) = 0; const TensorLayout& diff, const TensorLayout& mat) = 0;
protected: protected:
void check_exec(const TensorLayout& diff, const TensorLayout& mat, void check_exec(
size_t workspace_in_bytes); const TensorLayout& diff, const TensorLayout& mat,
size_t workspace_in_bytes);
}; };
/** /**
...@@ -251,29 +268,32 @@ public: ...@@ -251,29 +268,32 @@ public:
using BorderMode = Param::BorderMode; using BorderMode = Param::BorderMode;
protected: protected:
void check_layout_fwd(const TensorLayout& src, const TensorLayout& map_xy, void check_layout_fwd(
const TensorLayout& dst); const TensorLayout& src, const TensorLayout& map_xy,
void deduce_layout_fwd(const TensorLayout& src, const TensorLayout& map_xy, const TensorLayout& dst);
TensorLayout& dst); void deduce_layout_fwd(
const TensorLayout& src, const TensorLayout& map_xy, TensorLayout& dst);
}; };
class RemapForward : public RemapBase { class RemapForward : public RemapBase {
DEF_OPR_IMPL(RemapForward, RemapBase, 2, 1); DEF_OPR_IMPL(RemapForward, RemapBase, 2, 1);
public: public:
virtual void exec(_megdnn_tensor_in src, _megdnn_tensor_in map_xy, virtual void exec(
_megdnn_tensor_out dst, _megdnn_workspace workspace) = 0; _megdnn_tensor_in src, _megdnn_tensor_in map_xy, _megdnn_tensor_out dst,
_megdnn_workspace workspace) = 0;
void deduce_layout(const TensorLayout& src, const TensorLayout& map_xy, void deduce_layout(
TensorLayout& dst); const TensorLayout& src, const TensorLayout& map_xy, TensorLayout& dst);
virtual size_t get_workspace_in_bytes(const TensorLayout& src, virtual size_t get_workspace_in_bytes(
const TensorLayout& map_xy, const TensorLayout& src, const TensorLayout& map_xy,
const TensorLayout& dst) = 0; const TensorLayout& dst) = 0;
protected: protected:
void check_exec(const TensorLayout& src, const TensorLayout& map_xy, void check_exec(
const TensorLayout& dst, size_t workspace_in_bytes); const TensorLayout& src, const TensorLayout& map_xy,
const TensorLayout& dst, size_t workspace_in_bytes);
}; };
using Remap = RemapForward; using Remap = RemapForward;
...@@ -281,35 +301,37 @@ class RemapBackwardData : public RemapBase { ...@@ -281,35 +301,37 @@ class RemapBackwardData : public RemapBase {
DEF_OPR_IMPL(RemapBackwardData, RemapBase, 2, 1); DEF_OPR_IMPL(RemapBackwardData, RemapBase, 2, 1);
public: public:
virtual void exec(_megdnn_tensor_in map_xy, _megdnn_tensor_in diff, virtual void exec(
_megdnn_tensor_out grad, _megdnn_workspace workspace) = 0; _megdnn_tensor_in map_xy, _megdnn_tensor_in diff, _megdnn_tensor_out grad,
_megdnn_workspace workspace) = 0;
virtual size_t get_workspace_in_bytes(const TensorLayout& map_xy, virtual size_t get_workspace_in_bytes(
const TensorLayout& diff, const TensorLayout& map_xy, const TensorLayout& diff,
const TensorLayout& grad) = 0; const TensorLayout& grad) = 0;
protected: protected:
void check_exec(const TensorLayout& map_xy, const TensorLayout& diff, void check_exec(
const TensorLayout& grad, size_t workspace_in_bytes); const TensorLayout& map_xy, const TensorLayout& diff,
const TensorLayout& grad, size_t workspace_in_bytes);
}; };
class RemapBackwardMat : public RemapBase { class RemapBackwardMat : public RemapBase {
DEF_OPR_IMPL(RemapBackwardMat, RemapBase, 3, 1); DEF_OPR_IMPL(RemapBackwardMat, RemapBase, 3, 1);
public: public:
virtual void exec(_megdnn_tensor_in src, _megdnn_tensor_in map_xy, virtual void exec(
_megdnn_tensor_in diff, _megdnn_tensor_out grad, _megdnn_tensor_in src, _megdnn_tensor_in map_xy, _megdnn_tensor_in diff,
_megdnn_workspace workspace) = 0; _megdnn_tensor_out grad, _megdnn_workspace workspace) = 0;
virtual size_t get_workspace_in_bytes(const TensorLayout& src, virtual size_t get_workspace_in_bytes(
const TensorLayout& map_xy, const TensorLayout& src, const TensorLayout& map_xy,
const TensorLayout& diff, const TensorLayout& diff, const TensorLayout& grad) = 0;
const TensorLayout& grad) = 0;
protected: protected:
void check_exec(const TensorLayout& src, const TensorLayout& map_xy, void check_exec(
const TensorLayout& diff, const TensorLayout& grad, const TensorLayout& src, const TensorLayout& map_xy,
size_t workspace_in_bytes); const TensorLayout& diff, const TensorLayout& grad,
size_t workspace_in_bytes);
}; };
class SeparableFilterBase : public OperatorBase { class SeparableFilterBase : public OperatorBase {
...@@ -317,32 +339,34 @@ class SeparableFilterBase : public OperatorBase { ...@@ -317,32 +339,34 @@ class SeparableFilterBase : public OperatorBase {
DEF_OPR_PARAM(SeparableFilter); DEF_OPR_PARAM(SeparableFilter);
protected: protected:
void deduce_layout_fwd(const TensorLayout& src, void deduce_layout_fwd(
const TensorLayout& filter_x, const TensorLayout& src, const TensorLayout& filter_x,
const TensorLayout& filter_y, TensorLayout& dst); const TensorLayout& filter_y, TensorLayout& dst);
void check_layout_fwd(const TensorLayout& src, const TensorLayout& filter_x, void check_layout_fwd(
const TensorLayout& filter_y, const TensorLayout& src, const TensorLayout& filter_x,
const TensorLayout& dst); const TensorLayout& filter_y, const TensorLayout& dst);
}; };
class SeparableFilterForward : public SeparableFilterBase { class SeparableFilterForward : public SeparableFilterBase {
DEF_OPR_IMPL(SeparableFilterForward, SeparableFilterBase, 3, 1); DEF_OPR_IMPL(SeparableFilterForward, SeparableFilterBase, 3, 1);
public: public:
virtual void exec(_megdnn_tensor_in src, _megdnn_tensor_in filter_x, virtual void exec(
_megdnn_tensor_in filter_y, _megdnn_tensor_out dst, _megdnn_tensor_in src, _megdnn_tensor_in filter_x,
_megdnn_workspace workspace) = 0; _megdnn_tensor_in filter_y, _megdnn_tensor_out dst,
void deduce_layout(const TensorLayout& src, const TensorLayout& filter_x, _megdnn_workspace workspace) = 0;
const TensorLayout& filter_y, TensorLayout& dst); void deduce_layout(
virtual size_t get_workspace_in_bytes(const TensorLayout& src, const TensorLayout& src, const TensorLayout& filter_x,
const TensorLayout& filter_x, const TensorLayout& filter_y, TensorLayout& dst);
const TensorLayout& filter_y, virtual size_t get_workspace_in_bytes(
const TensorLayout& dst) = 0; const TensorLayout& src, const TensorLayout& filter_x,
const TensorLayout& filter_y, const TensorLayout& dst) = 0;
protected: protected:
void check_exec(const TensorLayout& src, const TensorLayout& filter_x, void check_exec(
const TensorLayout& filter_y, const TensorLayout& dst, const TensorLayout& src, const TensorLayout& filter_x,
size_t workspace_in_bytes); const TensorLayout& filter_y, const TensorLayout& dst,
size_t workspace_in_bytes);
}; };
using SeparableFilter = SeparableFilterForward; using SeparableFilter = SeparableFilterForward;
......
此差异已折叠。
...@@ -13,173 +13,162 @@ ...@@ -13,173 +13,162 @@
namespace megdnn { namespace megdnn {
class WarpPerspectiveBase: public OperatorBase { class WarpPerspectiveBase : public OperatorBase {
DEF_OPR_IMPL_CTOR(WarpPerspectiveBase, OperatorBase); DEF_OPR_IMPL_CTOR(WarpPerspectiveBase, OperatorBase);
DEF_OPR_PARAM(WarpPerspective); DEF_OPR_PARAM(WarpPerspective);
public:
using InterpolationMode = Param::InterpolationMode; public:
using BorderMode = Param::BorderMode; using InterpolationMode = Param::InterpolationMode;
using BorderMode = Param::BorderMode;
protected:
void check_layout_fwd(const TensorLayout &src, const TensorLayout &mat, protected:
const TensorLayout &dst) { void check_layout_fwd(
check_layout_fwd(src, mat, {}, dst); const TensorLayout& src, const TensorLayout& mat, const TensorLayout& dst) {
} check_layout_fwd(src, mat, {}, dst);
}
void check_layout_fwd(const TensorLayout &src, const TensorLayout &mat,
const TensorLayout &mat_idx, const TensorLayout &dst); void check_layout_fwd(
std::string param_msg() const; const TensorLayout& src, const TensorLayout& mat,
int get_real_coord(int p, int len); const TensorLayout& mat_idx, const TensorLayout& dst);
std::string param_msg() const;
int get_real_coord(int p, int len);
}; };
class WarpPerspectiveForward: public WarpPerspectiveBase { class WarpPerspectiveForward : public WarpPerspectiveBase {
DEF_OPR_IMPL(WarpPerspectiveForward, WarpPerspectiveBase, 0, 1); DEF_OPR_IMPL(WarpPerspectiveForward, WarpPerspectiveBase, 0, 1);
public:
/** public:
* \param[in] src (n, channel, in_height, in_width) /**
* \param[in] mat (n, 3, 3) * \param[in] src (n, channel, in_height, in_width)
* \param[out] dst (n, channel, out_height, out_width) * \param[in] mat (n, 3, 3)
* * \param[out] dst (n, channel, out_height, out_width)
* \see http://docs.opencv.org/2.4/modules/imgproc/doc/geometric_transformations.html?highlight=warpaffine *
* * \see
* denominator = mat[2][0]*w+mat[2][1]*h+mat[2][2] * http://docs.opencv.org/2.4/modules/imgproc/doc/geometric_transformations.html?highlight=warpaffine
* dst(h, w) = src((mat[1][0]*w+mat[1][1]*h+mat[1][2])/denominator, *
* (mat[0][0]*w+mat[0][1]*h+mat[0][2])/denominator) * denominator = mat[2][0]*w+mat[2][1]*h+mat[2][2]
* * dst(h, w) = src((mat[1][0]*w+mat[1][1]*h+mat[1][2])/denominator,
* src and dst can have different shapes, as long as their n and c agree. * (mat[0][0]*w+mat[0][1]*h+mat[0][2])/denominator)
* src, mat and dst should be contiguous. *
*/ * src and dst can have different shapes, as long as their n and c agree.
void exec(_megdnn_tensor_in src, * src, mat and dst should be contiguous.
_megdnn_tensor_in mat, */
_megdnn_tensor_out dst, void exec(
_megdnn_workspace workspace) { _megdnn_tensor_in src, _megdnn_tensor_in mat, _megdnn_tensor_out dst,
exec(src, mat, {}, dst, workspace); _megdnn_workspace workspace) {
} exec(src, mat, {}, dst, workspace);
}
/**
* \p src should have batch size m, and \p mat and \p mat_idx should /**
* both have batch size n. Each item in \p mat_idx must be in the range * \p src should have batch size m, and \p mat and \p mat_idx should
* of [0, m-1]. * both have batch size n. Each item in \p mat_idx must be in the range
* * of [0, m-1].
* \param mat_idx the indices of input image that each matrix in \p mat *
* should act on. It can also be empty and in such case \p mat * \param mat_idx the indices of input image that each matrix in \p mat
* should have the same batch size as \p src. * should act on. It can also be empty and in such case \p mat
*/ * should have the same batch size as \p src.
virtual void exec(_megdnn_tensor_in src, */
_megdnn_tensor_in mat, virtual void exec(
_megdnn_tensor_in mat_idx, _megdnn_tensor_in src, _megdnn_tensor_in mat, _megdnn_tensor_in mat_idx,
_megdnn_tensor_out dst, _megdnn_tensor_out dst, _megdnn_workspace workspace) = 0;
_megdnn_workspace workspace) = 0;
size_t get_workspace_in_bytes(
size_t get_workspace_in_bytes(const TensorLayout &src, const TensorLayout& src, const TensorLayout& mat, const TensorLayout& dst) {
const TensorLayout &mat, return get_workspace_in_bytes(src, mat, {}, dst);
const TensorLayout &dst) { }
return get_workspace_in_bytes(src, mat, {}, dst);
} virtual size_t get_workspace_in_bytes(
const TensorLayout& src, const TensorLayout& mat,
virtual size_t get_workspace_in_bytes(const TensorLayout &src, const TensorLayout& mat_idx, const TensorLayout& dst) = 0;
const TensorLayout &mat,
const TensorLayout &mat_idx, protected:
const TensorLayout &dst) = 0; void check_exec(
protected: const TensorLayout& src, const TensorLayout& mat,
void check_exec(const TensorLayout &src, const TensorLayout& mat_idx, const TensorLayout& dst,
const TensorLayout &mat, size_t workspace_in_bytes);
const TensorLayout &mat_idx,
const TensorLayout &dst, void check_exec_allow_nhwc_mat_idx(
size_t workspace_in_bytes); const TensorLayout& src, const TensorLayout& mat,
const TensorLayout& mat_idx, const TensorLayout& dst,
void check_exec_allow_nhwc_mat_idx(const TensorLayout &src, size_t workspace_in_bytes);
const TensorLayout &mat,
const TensorLayout &mat_idx,
const TensorLayout &dst,
size_t workspace_in_bytes);
}; };
using WarpPerspective = WarpPerspectiveForward; using WarpPerspective = WarpPerspectiveForward;
class WarpPerspectiveBackwardData: public WarpPerspectiveBase { class WarpPerspectiveBackwardData : public WarpPerspectiveBase {
DEF_OPR_IMPL(WarpPerspectiveBackwardData, WarpPerspectiveBase, 2, 1); DEF_OPR_IMPL(WarpPerspectiveBackwardData, WarpPerspectiveBase, 2, 1);
public:
/** public:
* \param[in] mat the `mat' parameter in WarpPerspectiveForward::exec /**
* \param[in] diff the backpropagated gradient wrt. dst * \param[in] mat the `mat' parameter in WarpPerspectiveForward::exec
* \param[out] grad the backpropagated gradient wrt. src * \param[in] diff the backpropagated gradient wrt. dst
* \param[out] workspace temporary workspace to perform backward * \param[out] grad the backpropagated gradient wrt. src
*/ * \param[out] workspace temporary workspace to perform backward
void exec(_megdnn_tensor_in mat, */
_megdnn_tensor_in diff, void exec(
_megdnn_tensor_out grad, _megdnn_tensor_in mat, _megdnn_tensor_in diff, _megdnn_tensor_out grad,
_megdnn_workspace workspace) { _megdnn_workspace workspace) {
exec(mat, {}, diff, grad, workspace); exec(mat, {}, diff, grad, workspace);
} }
virtual void exec(_megdnn_tensor_in mat, virtual void exec(
_megdnn_tensor_in mat_idx, _megdnn_tensor_in mat, _megdnn_tensor_in mat_idx, _megdnn_tensor_in diff,
_megdnn_tensor_in diff, _megdnn_tensor_out grad, _megdnn_workspace workspace) = 0;
_megdnn_tensor_out grad,
_megdnn_workspace workspace) = 0; size_t get_workspace_in_bytes(
const TensorLayout& mat, const TensorLayout& diff,
size_t get_workspace_in_bytes(const TensorLayout &mat, const TensorLayout& grad) {
const TensorLayout &diff, return get_workspace_in_bytes(mat, {}, diff, grad);
const TensorLayout &grad) { }
return get_workspace_in_bytes(mat, {}, diff, grad);
} virtual size_t get_workspace_in_bytes(
const TensorLayout& mat, const TensorLayout& mat_idx,
virtual size_t get_workspace_in_bytes(const TensorLayout &mat, const TensorLayout& diff, const TensorLayout& grad) = 0;
const TensorLayout &mat_idx,
const TensorLayout &diff, protected:
const TensorLayout &grad) = 0; void check_exec(
protected: const TensorLayout& mat, const TensorLayout& mat_idx,
void check_exec(const TensorLayout &mat, const TensorLayout& diff, const TensorLayout& grad,
const TensorLayout &mat_idx, size_t workspace_in_bytes);
const TensorLayout &diff,
const TensorLayout &grad,
size_t workspace_in_bytes);
}; };
class WarpPerspectiveBackwardMat: public WarpPerspectiveBase { class WarpPerspectiveBackwardMat : public WarpPerspectiveBase {
DEF_OPR_IMPL(WarpPerspectiveBackwardMat, WarpPerspectiveBase, 3, 1); DEF_OPR_IMPL(WarpPerspectiveBackwardMat, WarpPerspectiveBase, 3, 1);
public:
/** public:
* \param[in] src the `src' parameter in WarpPerspectiveForward::exec /**
* \param[in] mat the `mat' parameter in WarpPerspectiveForward::exec * \param[in] src the `src' parameter in WarpPerspectiveForward::exec
* \param[in] diff the backpropagated gradient wrt. dst * \param[in] mat the `mat' parameter in WarpPerspectiveForward::exec
* \param[out] grad the backpropagated gradient wrt. mat * \param[in] diff the backpropagated gradient wrt. dst
* \param[out] workspace temporary workspace to perform backward * \param[out] grad the backpropagated gradient wrt. mat
*/ * \param[out] workspace temporary workspace to perform backward
void exec(_megdnn_tensor_in src, */
_megdnn_tensor_in mat, void exec(
_megdnn_tensor_in diff, _megdnn_tensor_in src, _megdnn_tensor_in mat, _megdnn_tensor_in diff,
_megdnn_tensor_out grad, _megdnn_tensor_out grad, _megdnn_workspace workspace) {
_megdnn_workspace workspace) { exec(src, mat, {}, diff, grad, workspace);
exec(src, mat, {}, diff, grad, workspace); }
}
virtual void exec(
virtual void exec(_megdnn_tensor_in src, _megdnn_tensor_in src, _megdnn_tensor_in mat, _megdnn_tensor_in mat_idx,
_megdnn_tensor_in mat, _megdnn_tensor_in diff, _megdnn_tensor_out grad,
_megdnn_tensor_in mat_idx, _megdnn_workspace workspace) = 0;
_megdnn_tensor_in diff,
_megdnn_tensor_out grad, size_t get_workspace_in_bytes(
_megdnn_workspace workspace) = 0; const TensorLayout& src, const TensorLayout& mat, const TensorLayout& diff,
const TensorLayout& grad) {
size_t get_workspace_in_bytes(const TensorLayout &src, return get_workspace_in_bytes(src, mat, {}, diff, grad);
const TensorLayout &mat, }
const TensorLayout &diff,
const TensorLayout &grad) { virtual size_t get_workspace_in_bytes(
return get_workspace_in_bytes(src, mat, {}, diff, grad); const TensorLayout& src, const TensorLayout& mat,
} const TensorLayout& mat_idx, const TensorLayout& diff,
const TensorLayout& grad) = 0;
virtual size_t get_workspace_in_bytes(const TensorLayout &src,
const TensorLayout &mat, protected:
const TensorLayout &mat_idx, void check_exec(
const TensorLayout &diff, const TensorLayout& src, const TensorLayout& mat,
const TensorLayout &grad) = 0; const TensorLayout& mat_idx, const TensorLayout& diff,
protected: const TensorLayout& grad, size_t workspace_in_bytes);
void check_exec(const TensorLayout &src,
const TensorLayout &mat,
const TensorLayout &mat_idx,
const TensorLayout &diff,
const TensorLayout &grad,
size_t workspace_in_bytes);
}; };
class DctChannelSelectForward : public OperatorBase { class DctChannelSelectForward : public OperatorBase {
...@@ -194,37 +183,32 @@ public: ...@@ -194,37 +183,32 @@ public:
* \param[out] dst DctChannelSelectForward output, default fp32 nchw tensor * \param[out] dst DctChannelSelectForward output, default fp32 nchw tensor
* \param[out] workspace temporary workspace to perform forward * \param[out] workspace temporary workspace to perform forward
*/ */
virtual void exec(_megdnn_tensor_in src, virtual void exec(
_megdnn_tensor_in mask_offset, _megdnn_tensor_in src, _megdnn_tensor_in mask_offset,
_megdnn_tensor_in mask_val, _megdnn_tensor_in mask_val, _megdnn_tensor_out dst,
_megdnn_tensor_out dst, _megdnn_workspace workspace) = 0;
_megdnn_workspace workspace) = 0;
void deduce_layout(
void deduce_layout(const TensorLayout& src, const TensorLayout& src, const TensorLayout& mask_offset,
const TensorLayout& mask_offset, const TensorLayout& mask_val, TensorLayout& dst);
const TensorLayout& mask_val,
TensorLayout& dst); virtual size_t get_workspace_in_bytes(
const TensorLayout& src, const TensorLayout& mask_offset,
virtual size_t get_workspace_in_bytes(const TensorLayout& src, const TensorLayout& mask_val, const TensorLayout& dst) = 0;
const TensorLayout& mask_offset,
const TensorLayout& mask_val,
const TensorLayout& dst) = 0;
protected: protected:
void check_layout_fwd(const TensorLayout& src, void check_layout_fwd(
const TensorLayout& mask_offset, const TensorLayout& src, const TensorLayout& mask_offset,
const TensorLayout& mask_val, const TensorLayout& mask_val, const TensorLayout& dst);
const TensorLayout& dst);
void deduce_layout_fwd(
void deduce_layout_fwd(const TensorLayout& src, const TensorLayout& src, const TensorLayout& mask_offset,
const TensorLayout& mask_offset, const TensorLayout& mask_val, TensorLayout& dst);
const TensorLayout& mask_val,
TensorLayout& dst);
std::string param_msg() const; std::string param_msg() const;
}; };
} // namespace megdnn } // namespace megdnn
#include "megdnn/internal/opr_header_epilogue.h" #include "megdnn/internal/opr_header_epilogue.h"
......
...@@ -33,22 +33,22 @@ public: ...@@ -33,22 +33,22 @@ public:
* op(A) = A if transposeA is false, otherwise op(A) = A^t. * op(A) = A if transposeA is false, otherwise op(A) = A^t.
* op(B) = B if transposeB is false, otherwise op(B) = B^t. * op(B) = B if transposeB is false, otherwise op(B) = B^t.
*/ */
virtual void exec(_megdnn_tensor_in A, _megdnn_tensor_in B, virtual void exec(
_megdnn_tensor_out C, _megdnn_workspace workspace) = 0; _megdnn_tensor_in A, _megdnn_tensor_in B, _megdnn_tensor_out C,
void deduce_dtype(DType A, DType B, DType &C); _megdnn_workspace workspace) = 0;
void deduce_layout(const TensorLayout& A, const TensorLayout& B, void deduce_dtype(DType A, DType B, DType& C);
TensorLayout& C); void deduce_layout(const TensorLayout& A, const TensorLayout& B, TensorLayout& C);
virtual size_t get_workspace_in_bytes(const TensorLayout& A, virtual size_t get_workspace_in_bytes(
const TensorLayout& B, const TensorLayout& A, const TensorLayout& B, const TensorLayout& C) = 0;
const TensorLayout& C) = 0;
static Algorithm::OprType get_opr_type() { static Algorithm::OprType get_opr_type() {
return Algorithm::OprType::BATCHED_MATRIX_MUL_FORWARD; return Algorithm::OprType::BATCHED_MATRIX_MUL_FORWARD;
} }
protected: protected:
void check_exec(const TensorLayout& A, const TensorLayout& B, void check_exec(
const TensorLayout& C, size_t workspace_in_bytes); const TensorLayout& A, const TensorLayout& B, const TensorLayout& C,
size_t workspace_in_bytes);
}; };
using BatchedMatrixMul = BatchedMatrixMulForward; using BatchedMatrixMul = BatchedMatrixMulForward;
...@@ -70,24 +70,24 @@ public: ...@@ -70,24 +70,24 @@ public:
* op(A) = A if transposeA is false, otherwise op(A) = A^t. * op(A) = A if transposeA is false, otherwise op(A) = A^t.
* op(B) = B if transposeB is false, otherwise op(B) = B^t. * op(B) = B if transposeB is false, otherwise op(B) = B^t.
*/ */
virtual void exec(_megdnn_tensor_in A, _megdnn_tensor_in B, virtual void exec(
_megdnn_tensor_out C, _megdnn_workspace workspace) = 0; _megdnn_tensor_in A, _megdnn_tensor_in B, _megdnn_tensor_out C,
_megdnn_workspace workspace) = 0;
void deduce_dtype(DType A, DType B, DType& C); void deduce_dtype(DType A, DType B, DType& C);
void deduce_layout(const TensorLayout& A, const TensorLayout& B, void deduce_layout(const TensorLayout& A, const TensorLayout& B, TensorLayout& C);
TensorLayout& C); virtual size_t get_workspace_in_bytes(
virtual size_t get_workspace_in_bytes(const TensorLayout& A, const TensorLayout& A, const TensorLayout& B, const TensorLayout& C) = 0;
const TensorLayout& B,
const TensorLayout& C) = 0;
static size_t pack_size (const Param::Format format); static size_t pack_size(const Param::Format format);
static Algorithm::OprType get_opr_type() { static Algorithm::OprType get_opr_type() {
return Algorithm::OprType::MATRIX_MUL_FORWARD; return Algorithm::OprType::MATRIX_MUL_FORWARD;
} }
protected: protected:
void check_exec(const TensorLayout& A, const TensorLayout& B, void check_exec(
const TensorLayout& C, size_t workspace_in_bytes); const TensorLayout& A, const TensorLayout& B, const TensorLayout& C,
size_t workspace_in_bytes);
}; };
using MatrixMul = MatrixMulForward; using MatrixMul = MatrixMulForward;
...@@ -104,11 +104,11 @@ class MatrixInverse : public OperatorBase { ...@@ -104,11 +104,11 @@ class MatrixInverse : public OperatorBase {
DEF_OPR_PARAM(Empty); DEF_OPR_PARAM(Empty);
public: public:
virtual void exec(_megdnn_tensor_in src, _megdnn_tensor_out dst, virtual void exec(
_megdnn_workspace workspace) = 0; _megdnn_tensor_in src, _megdnn_tensor_out dst,
_megdnn_workspace workspace) = 0;
void deduce_layout(const TensorLayout& src, TensorLayout& dst); void deduce_layout(const TensorLayout& src, TensorLayout& dst);
size_t get_workspace_in_bytes(const TensorLayout& src, size_t get_workspace_in_bytes(const TensorLayout& src, const TensorLayout& dst);
const TensorLayout& dst);
protected: protected:
/*! /*!
...@@ -116,8 +116,7 @@ protected: ...@@ -116,8 +116,7 @@ protected:
* *
* Note that \p batch and \p n can be null * Note that \p batch and \p n can be null
*/ */
static void canonize_params(const TensorLayout& layout, size_t* batch, static void canonize_params(const TensorLayout& layout, size_t* batch, size_t* n);
size_t* n);
/*! /*!
* \brief canonize and validate input params for exec() impls * \brief canonize and validate input params for exec() impls
...@@ -125,11 +124,12 @@ protected: ...@@ -125,11 +124,12 @@ protected:
* Since get_workspace_in_bytes() would be called, \p batch and \p n can not * Since get_workspace_in_bytes() would be called, \p batch and \p n can not
* be null * be null
*/ */
void check_exec(const TensorLayout& src, const TensorLayout& dst, void check_exec(
_megdnn_workspace workspace, size_t* batch, size_t* n); const TensorLayout& src, const TensorLayout& dst,
_megdnn_workspace workspace, size_t* batch, size_t* n);
virtual size_t get_workspace_in_bytes(size_t batch, size_t n, virtual size_t get_workspace_in_bytes(
size_t dtype_size) = 0; size_t batch, size_t n, size_t dtype_size) = 0;
}; };
//! inner product of two vectors //! inner product of two vectors
...@@ -147,17 +147,17 @@ public: ...@@ -147,17 +147,17 @@ public:
* A, B, C must be contiguous. A and B must have the same 1-dimensional * A, B, C must be contiguous. A and B must have the same 1-dimensional
* shape and non-negative strides. C must be scalar. * shape and non-negative strides. C must be scalar.
*/ */
virtual void exec(_megdnn_tensor_in A, _megdnn_tensor_in B, virtual void exec(
_megdnn_tensor_out C, _megdnn_workspace workspace) = 0; _megdnn_tensor_in A, _megdnn_tensor_in B, _megdnn_tensor_out C,
void deduce_layout(const TensorLayout& A, const TensorLayout& B, _megdnn_workspace workspace) = 0;
TensorLayout& C); void deduce_layout(const TensorLayout& A, const TensorLayout& B, TensorLayout& C);
virtual size_t get_workspace_in_bytes(const TensorLayout& A, virtual size_t get_workspace_in_bytes(
const TensorLayout& B, const TensorLayout& A, const TensorLayout& B, const TensorLayout& C) = 0;
const TensorLayout& C) = 0;
protected: protected:
void check_exec(const TensorLayout& A, const TensorLayout& B, void check_exec(
const TensorLayout& C, size_t workspace_in_bytes); const TensorLayout& A, const TensorLayout& B, const TensorLayout& C,
size_t workspace_in_bytes);
}; };
using Dot = DotForward; using Dot = DotForward;
...@@ -193,23 +193,24 @@ public: ...@@ -193,23 +193,24 @@ public:
* if compute_uv is false (default to true). * if compute_uv is false (default to true).
* *
*/ */
virtual void exec(_megdnn_tensor_in src, _megdnn_tensor_out u, virtual void exec(
_megdnn_tensor_out s, _megdnn_tensor_out vt, _megdnn_tensor_in src, _megdnn_tensor_out u, _megdnn_tensor_out s,
_megdnn_workspace workspace) = 0; _megdnn_tensor_out vt, _megdnn_workspace workspace) = 0;
void deduce_layout(const TensorLayout& src, TensorLayout& u, void deduce_layout(
TensorLayout& s, TensorLayout& vt); const TensorLayout& src, TensorLayout& u, TensorLayout& s,
size_t get_workspace_in_bytes(const TensorLayout& src, TensorLayout& vt);
const TensorLayout& u, const TensorLayout& s, size_t get_workspace_in_bytes(
const TensorLayout& vt); const TensorLayout& src, const TensorLayout& u, const TensorLayout& s,
const TensorLayout& vt);
protected: protected:
static void canonize_params(const TensorLayout& layout, size_t* batch, static void canonize_params(
size_t* m, size_t* n); const TensorLayout& layout, size_t* batch, size_t* m, size_t* n);
virtual size_t get_workspace_in_bytes(size_t block_cnt, size_t m, size_t n, virtual size_t get_workspace_in_bytes(
size_t dtype_size) = 0; size_t block_cnt, size_t m, size_t n, size_t dtype_size) = 0;
void check_exec(const TensorLayout& src, const TensorLayout& u, void check_exec(
const TensorLayout& s, const TensorLayout& vt, const TensorLayout& src, const TensorLayout& u, const TensorLayout& s,
size_t workspace_in_bytes); const TensorLayout& vt, size_t workspace_in_bytes);
}; };
using SVD = SVDForward; using SVD = SVDForward;
......
此差异已折叠。
...@@ -36,7 +36,7 @@ public: ...@@ -36,7 +36,7 @@ public:
struct ModeTrait { struct ModeTrait {
uint32_t arity = 0; //!< number of inputs needed uint32_t arity = 0; //!< number of inputs needed
CheckDtypeFunc check_inp[MAX_ARITY]; CheckDtypeFunc check_inp[MAX_ARITY];
SetOrCheckDtypeFunc check_out; //!< dtype of output var SetOrCheckDtypeFunc check_out; //!< dtype of output var
bool need_specify_out_dtype = bool need_specify_out_dtype =
false; //!< the dtype should be setup externally, otherwise false; //!< the dtype should be setup externally, otherwise
//!< would be inferred by check_out(dtype, false) //!< would be inferred by check_out(dtype, false)
...@@ -46,13 +46,10 @@ public: ...@@ -46,13 +46,10 @@ public:
static const ModeTrait& from_mode(Mode mode); static const ModeTrait& from_mode(Mode mode);
}; };
virtual void exec(_megdnn_in const TensorNDArray& src, virtual void exec(_megdnn_in const TensorNDArray& src, _megdnn_tensor_out dst) = 0;
_megdnn_tensor_out dst) = 0;
//! get trait of current mode //! get trait of current mode
const ModeTrait& mode_trait() const { const ModeTrait& mode_trait() const { return ModeTrait::from_mode(m_param.mode); }
return ModeTrait::from_mode(m_param.mode);
}
//! deduce output layout //! deduce output layout
void deduce_layout(const TensorLayoutArray& src, TensorLayout& dst); void deduce_layout(const TensorLayoutArray& src, TensorLayout& dst);
...@@ -60,8 +57,8 @@ public: ...@@ -60,8 +57,8 @@ public:
protected: protected:
//! throw exception if incorrect layout; broadcast input shape to //! throw exception if incorrect layout; broadcast input shape to
//! output shape //! output shape
void check_layout_and_broadcast(const TensorLayoutPtrArray& src, void check_layout_and_broadcast(
const TensorLayout& dst); const TensorLayoutPtrArray& src, const TensorLayout& dst);
}; };
} // namespace megdnn } // namespace megdnn
......
...@@ -15,84 +15,97 @@ ...@@ -15,84 +15,97 @@
namespace megdnn { namespace megdnn {
//! base class for random number generators //! base class for random number generators
class RNGBase: public OperatorBase { class RNGBase : public OperatorBase {
DEF_OPR_IMPL_CTOR(RNGBase, OperatorBase); DEF_OPR_IMPL_CTOR(RNGBase, OperatorBase);
public:
virtual void exec(_megdnn_tensor_out dst, public:
_megdnn_workspace workspace) = 0; virtual void exec(_megdnn_tensor_out dst, _megdnn_workspace workspace) = 0;
virtual size_t get_workspace_in_bytes(const TensorLayout &dst) = 0; virtual size_t get_workspace_in_bytes(const TensorLayout& dst) = 0;
protected:
virtual void check_exec(const TensorLayout &dst, size_t workspace_in_bytes) = 0; protected:
virtual void check_exec(const TensorLayout& dst, size_t workspace_in_bytes) = 0;
}; };
//! sample from poisson distribution //! sample from poisson distribution
class PoissonRNG: public OperatorBase { class PoissonRNG : public OperatorBase {
DEF_OPR_IMPL(PoissonRNG, OperatorBase, 1, 1); DEF_OPR_IMPL(PoissonRNG, OperatorBase, 1, 1);
DEF_OPR_PARAM(PoissonRNG); DEF_OPR_PARAM(PoissonRNG);
public:
virtual void exec(_megdnn_tensor_in lam, public:
_megdnn_tensor_out dst, virtual void exec(
_megdnn_workspace workspace) = 0; _megdnn_tensor_in lam, _megdnn_tensor_out dst,
virtual size_t get_workspace_in_bytes(const TensorLayout &lam, _megdnn_workspace workspace) = 0;
const TensorLayout &dst) = 0; virtual size_t get_workspace_in_bytes(
protected: const TensorLayout& lam, const TensorLayout& dst) = 0;
void check_exec(const TensorLayout &lam, const TensorLayout &dst,
size_t workspace_in_bytes); protected:
void check_exec(
const TensorLayout& lam, const TensorLayout& dst,
size_t workspace_in_bytes);
}; };
//! sample from beta distribution //! sample from beta distribution
class BetaRNG: public OperatorBase { class BetaRNG : public OperatorBase {
DEF_OPR_IMPL(BetaRNG, OperatorBase, 2, 1); DEF_OPR_IMPL(BetaRNG, OperatorBase, 2, 1);
DEF_OPR_PARAM(BetaRNG); DEF_OPR_PARAM(BetaRNG);
public:
virtual void exec(_megdnn_tensor_in alpha, public:
_megdnn_tensor_in beta, virtual void exec(
_megdnn_tensor_out dst, _megdnn_tensor_in alpha, _megdnn_tensor_in beta, _megdnn_tensor_out dst,
_megdnn_workspace workspace) = 0; _megdnn_workspace workspace) = 0;
virtual size_t get_workspace_in_bytes(const TensorLayout &alpha, virtual size_t get_workspace_in_bytes(
const TensorLayout &beta, const TensorLayout &dst) = 0; const TensorLayout& alpha, const TensorLayout& beta,
protected: const TensorLayout& dst) = 0;
void check_exec(const TensorLayout &alpha, const TensorLayout &beta,
const TensorLayout &dst, size_t workspace_in_bytes); protected:
void check_exec(
const TensorLayout& alpha, const TensorLayout& beta,
const TensorLayout& dst, size_t workspace_in_bytes);
}; };
//! sample from gamma distribution //! sample from gamma distribution
class GammaRNG: public OperatorBase { class GammaRNG : public OperatorBase {
DEF_OPR_IMPL(GammaRNG, OperatorBase, 2, 1); DEF_OPR_IMPL(GammaRNG, OperatorBase, 2, 1);
DEF_OPR_PARAM(GammaRNG); DEF_OPR_PARAM(GammaRNG);
public:
virtual void exec(_megdnn_tensor_in shape, public:
_megdnn_tensor_in scale, virtual void exec(
_megdnn_tensor_out dst, _megdnn_tensor_in shape, _megdnn_tensor_in scale, _megdnn_tensor_out dst,
_megdnn_workspace workspace) = 0; _megdnn_workspace workspace) = 0;
virtual size_t get_workspace_in_bytes(const TensorLayout &shape, virtual size_t get_workspace_in_bytes(
const TensorLayout &scale, const TensorLayout &dst) = 0; const TensorLayout& shape, const TensorLayout& scale,
protected: const TensorLayout& dst) = 0;
void check_exec(const TensorLayout &shape, const TensorLayout &scale,
const TensorLayout &dst, size_t workspace_in_bytes); protected:
void check_exec(
const TensorLayout& shape, const TensorLayout& scale,
const TensorLayout& dst, size_t workspace_in_bytes);
}; };
//! sample from uniform distribution on the interval (0, 1] //! sample from uniform distribution on the interval (0, 1]
class UniformRNG: public RNGBase { class UniformRNG : public RNGBase {
DEF_OPR_IMPL(UniformRNG, RNGBase, 0, 1); DEF_OPR_IMPL(UniformRNG, RNGBase, 0, 1);
DEF_OPR_PARAM(UniformRNG); DEF_OPR_PARAM(UniformRNG);
protected:
void check_exec(const TensorLayout &dst, size_t workspace_in_bytes); protected:
void check_exec(const TensorLayout& dst, size_t workspace_in_bytes);
}; };
//! sample from gaussian distribution //! sample from gaussian distribution
class GaussianRNG: public RNGBase { class GaussianRNG : public RNGBase {
DEF_OPR_IMPL(GaussianRNG, RNGBase, 0, 1); DEF_OPR_IMPL(GaussianRNG, RNGBase, 0, 1);
DEF_OPR_PARAM(GaussianRNG); DEF_OPR_PARAM(GaussianRNG);
protected:
void check_exec(const TensorLayout &dst, size_t workspace_in_bytes); protected:
void check_exec(const TensorLayout& dst, size_t workspace_in_bytes);
}; };
class PermutationRNG: public RNGBase { class PermutationRNG : public RNGBase {
DEF_OPR_IMPL(PermutationRNG, RNGBase, 0, 1); DEF_OPR_IMPL(PermutationRNG, RNGBase, 0, 1);
DEF_OPR_PARAM(PermutationRNG); DEF_OPR_PARAM(PermutationRNG);
protected:
void check_exec(const TensorLayout &dst, size_t workspace_in_bytes); protected:
void check_exec(const TensorLayout& dst, size_t workspace_in_bytes);
}; };
class ShuffleRNGForward : public OperatorBase { class ShuffleRNGForward : public OperatorBase {
...@@ -100,18 +113,19 @@ class ShuffleRNGForward : public OperatorBase { ...@@ -100,18 +113,19 @@ class ShuffleRNGForward : public OperatorBase {
DEF_OPR_PARAM(ShuffleRNG); DEF_OPR_PARAM(ShuffleRNG);
public: public:
virtual void exec(_megdnn_tensor_in src, _megdnn_tensor_out dst, virtual void exec(
_megdnn_tensor_out indices, _megdnn_tensor_in src, _megdnn_tensor_out dst, _megdnn_tensor_out indices,
_megdnn_workspace workspace) = 0; _megdnn_workspace workspace) = 0;
void deduce_layout(const TensorLayout& src, TensorLayout& dst, void deduce_layout(
TensorLayout& indices); const TensorLayout& src, TensorLayout& dst, TensorLayout& indices);
virtual size_t get_workspace_in_bytes(const TensorLayout& src, virtual size_t get_workspace_in_bytes(
const TensorLayout& dst, const TensorLayout& src, const TensorLayout& dst,
const TensorLayout& indices) = 0; const TensorLayout& indices) = 0;
protected: protected:
void check_exec(const TensorLayout& src, const TensorLayout& dst, void check_exec(
const TensorLayout& indices, size_t workspace_in_bytes); const TensorLayout& src, const TensorLayout& dst,
const TensorLayout& indices, size_t workspace_in_bytes);
}; };
using ShuffleRNG = ShuffleRNGForward; using ShuffleRNG = ShuffleRNGForward;
...@@ -120,27 +134,29 @@ class ShuffleRNGBackward : public OperatorBase { ...@@ -120,27 +134,29 @@ class ShuffleRNGBackward : public OperatorBase {
DEF_OPR_PARAM(ShuffleRNG); DEF_OPR_PARAM(ShuffleRNG);
public: public:
virtual void exec(_megdnn_tensor_in diff, _megdnn_tensor_in indices, virtual void exec(
_megdnn_tensor_out grad, _megdnn_workspace workspace) = 0; _megdnn_tensor_in diff, _megdnn_tensor_in indices, _megdnn_tensor_out grad,
virtual size_t get_workspace_in_bytes(const TensorLayout& diff, _megdnn_workspace workspace) = 0;
const TensorLayout& indices, virtual size_t get_workspace_in_bytes(
const TensorLayout& grad) = 0; const TensorLayout& diff, const TensorLayout& indices,
const TensorLayout& grad) = 0;
protected: protected:
void check_exec(const TensorLayout& diff, const TensorLayout& indices, void check_exec(
const TensorLayout& grad, size_t workspace_in_bytes); const TensorLayout& diff, const TensorLayout& indices,
const TensorLayout& grad, size_t workspace_in_bytes);
}; };
/*! /*!
* \brief sleep for specific time on the computing device; useful for testing * \brief sleep for specific time on the computing device; useful for testing
* async problems * async problems
*/ */
class SleepForward: public OperatorBase { class SleepForward : public OperatorBase {
DEF_OPR_IMPL(SleepForward, OperatorBase, 0, 0); DEF_OPR_IMPL(SleepForward, OperatorBase, 0, 0);
DEF_OPR_PARAM(Sleep); DEF_OPR_PARAM(Sleep);
public: public:
virtual void exec() = 0; virtual void exec() = 0;
}; };
using Sleep = SleepForward; using Sleep = SleepForward;
...@@ -149,20 +165,19 @@ using Sleep = SleepForward; ...@@ -149,20 +165,19 @@ using Sleep = SleepForward;
* *
* data must be a one-dimensional contiguous tensor with dtype byte * data must be a one-dimensional contiguous tensor with dtype byte
*/ */
class ChecksumForward: public OperatorBase { class ChecksumForward : public OperatorBase {
DEF_OPR_PARAM(Empty); DEF_OPR_PARAM(Empty);
DEF_OPR_IMPL(ChecksumForward, OperatorBase, 0, 1); DEF_OPR_IMPL(ChecksumForward, OperatorBase, 0, 1);
public: public:
using Result = opr_result::Checksum; using Result = opr_result::Checksum;
virtual size_t get_workspace_in_bytes(const TensorLayout &data) = 0; virtual size_t get_workspace_in_bytes(const TensorLayout& data) = 0;
virtual Result exec(_megdnn_tensor_in data, virtual Result exec(_megdnn_tensor_in data, _megdnn_workspace workspace) = 0;
_megdnn_workspace workspace) = 0;
protected: protected:
void check_exec(const TensorLayout &layout, size_t workspace_in_bytes); void check_exec(const TensorLayout& layout, size_t workspace_in_bytes);
}; };
using Checksum = ChecksumForward; using Checksum = ChecksumForward;
...@@ -175,21 +190,22 @@ class MaxTensorDiff : public OperatorBase { ...@@ -175,21 +190,22 @@ class MaxTensorDiff : public OperatorBase {
DEF_OPR_PARAM(Empty); DEF_OPR_PARAM(Empty);
DEF_OPR_IMPL(MaxTensorDiff, OperatorBase, 0, 2); DEF_OPR_IMPL(MaxTensorDiff, OperatorBase, 0, 2);
public: public:
virtual size_t get_workspace_in_bytes(const TensorLayout& layout1, virtual size_t get_workspace_in_bytes(
const TensorLayout& layout2) = 0; const TensorLayout& layout1, const TensorLayout& layout2) = 0;
virtual float exec(_megdnn_tensor_in src1, _megdnn_tensor_in src2, virtual float exec(
_megdnn_workspace workspace) = 0; _megdnn_tensor_in src1, _megdnn_tensor_in src2,
_megdnn_workspace workspace) = 0;
protected: protected:
void check_exec(const TensorLayout& layout1, void check_exec(
const TensorLayout& layout2, size_t workspace_in_bytes); const TensorLayout& layout1, const TensorLayout& layout2,
size_t workspace_in_bytes);
}; };
bool check_bias_share_in_channel(
bool check_bias_share_in_channel(const TensorLayout& bias, const TensorLayout& bias, const param::ConvBias::Format format);
const param::ConvBias::Format format);
} // namespace megdnn } // namespace megdnn
......
...@@ -18,9 +18,9 @@ ...@@ -18,9 +18,9 @@
namespace megdnn { namespace megdnn {
enum class TensorFormat::Type { enum class TensorFormat::Type {
DEFAULT = 0, //!< see DefaultTensorFormat DEFAULT = 0, //!< see DefaultTensorFormat
IMAGE2D_PACK4 = 1, //!< see Image2DPack4TensorFormat IMAGE2D_PACK4 = 1, //!< see Image2DPack4TensorFormat
LOWBITS_ALIGNED_TO_BYTE = 2, //!< LOWBITS_ALIGNED_TO_BYTE = 2, //!<
}; };
class TensorFormat::ImplBase { class TensorFormat::ImplBase {
...@@ -33,8 +33,7 @@ public: ...@@ -33,8 +33,7 @@ public:
virtual bool is_contiguous_spec(const TensorLayout& layout) const = 0; virtual bool is_contiguous_spec(const TensorLayout& layout) const = 0;
virtual TensorLayout collapse_contiguous_spec( virtual TensorLayout collapse_contiguous_spec(const TensorLayout& layout) const = 0;
const TensorLayout& layout) const = 0;
virtual TensorLayout::Span span_spec(const TensorLayout& layout) const = 0; virtual TensorLayout::Span span_spec(const TensorLayout& layout) const = 0;
...@@ -79,8 +78,7 @@ public: ...@@ -79,8 +78,7 @@ public:
*/ */
bool is_contiguous_spec(const TensorLayout& layout) const override; bool is_contiguous_spec(const TensorLayout& layout) const override;
TensorLayout collapse_contiguous_spec( TensorLayout collapse_contiguous_spec(const TensorLayout& layout) const override;
const TensorLayout& layout) const override;
TensorLayout::Span span_spec(const TensorLayout& layout) const override; TensorLayout::Span span_spec(const TensorLayout& layout) const override;
...@@ -88,8 +86,7 @@ public: ...@@ -88,8 +86,7 @@ public:
void serialize_append(std::string& result) const override; void serialize_append(std::string& result) const override;
static TensorFormat make(); static TensorFormat make();
static TensorFormat deserialize(const Handle* handle, const void* buf, static TensorFormat deserialize(const Handle* handle, const void* buf, size_t size);
size_t size);
}; };
namespace detail { namespace detail {
...@@ -112,8 +109,8 @@ class Image2DTensorFormatBase : public TensorFormat::ImplBase { ...@@ -112,8 +109,8 @@ class Image2DTensorFormatBase : public TensorFormat::ImplBase {
size_t m_align_axis, m_align_size_in_elements_log2; size_t m_align_axis, m_align_size_in_elements_log2;
protected: protected:
Image2DTensorFormatBase(Type type, size_t align_axis, Image2DTensorFormatBase(
size_t align_size_in_elements); Type type, size_t align_axis, size_t align_size_in_elements);
virtual ~Image2DTensorFormatBase() = default; virtual ~Image2DTensorFormatBase() = default;
public: public:
...@@ -129,9 +126,7 @@ public: ...@@ -129,9 +126,7 @@ public:
size_t align_axis() const { return m_align_axis; } size_t align_axis() const { return m_align_axis; }
size_t align_size_in_elements_log2() const { size_t align_size_in_elements_log2() const { return m_align_size_in_elements_log2; }
return m_align_size_in_elements_log2;
}
std::string to_string() const override; std::string to_string() const override;
...@@ -145,6 +140,7 @@ public: ...@@ -145,6 +140,7 @@ public:
size_t image_height(const TensorLayout& layout) const; size_t image_height(const TensorLayout& layout) const;
void serialize_append(std::string& result) const override; void serialize_append(std::string& result) const override;
protected: protected:
struct SerializePack { struct SerializePack {
uint8_t align_axis; uint8_t align_axis;
...@@ -160,15 +156,14 @@ class Image2DPackedTensorFormatBase : public Image2DTensorFormatBase { ...@@ -160,15 +156,14 @@ class Image2DPackedTensorFormatBase : public Image2DTensorFormatBase {
* align COUNT, but mdl needs align size in bytes, which is equal to * align COUNT, but mdl needs align size in bytes, which is equal to
* (image_width align count) * sizeof(data_type) * pixel_size * (image_width align count) * sizeof(data_type) * pixel_size
*/ */
size_t image_pitch_alignment_in_bytes(size_t align_size_in_elements, size_t image_pitch_alignment_in_bytes(
const TensorLayout& layout) const; size_t align_size_in_elements, const TensorLayout& layout) const;
protected: protected:
Image2DPackedTensorFormatBase(Type type, size_t align_axis, Image2DPackedTensorFormatBase(
size_t align_size_in_elements, Type type, size_t align_axis, size_t align_size_in_elements,
Handle::HandleVendorType vendor_type) Handle::HandleVendorType vendor_type)
: detail::Image2DTensorFormatBase(type, align_axis, : detail::Image2DTensorFormatBase(type, align_axis, align_size_in_elements),
align_size_in_elements),
m_vendor_type(vendor_type) {} m_vendor_type(vendor_type) {}
virtual ~Image2DPackedTensorFormatBase() = default; virtual ~Image2DPackedTensorFormatBase() = default;
...@@ -197,13 +192,12 @@ public: ...@@ -197,13 +192,12 @@ public:
bool is_contiguous_spec(const TensorLayout& layout) const override; bool is_contiguous_spec(const TensorLayout& layout) const override;
TensorLayout collapse_contiguous_spec( TensorLayout collapse_contiguous_spec(const TensorLayout& layout) const override;
const TensorLayout& layout) const override;
}; };
using Image2DPack4TensorFormatBase = Image2DPackedTensorFormatBase<4>; using Image2DPack4TensorFormatBase = Image2DPackedTensorFormatBase<4>;
/*! /*!
* \brief used for tensors storing lowbit data * \brief used for tensors storing lowbit data
* *
* \param m_size_nbits size in bits of elements in the tensor * \param m_size_nbits size in bits of elements in the tensor
* \param m_align_size_in_bits aligned size in bits * \param m_align_size_in_bits aligned size in bits
...@@ -213,14 +207,14 @@ class LowbitsAlignedTensorFormatBase : public TensorFormat::ImplBase { ...@@ -213,14 +207,14 @@ class LowbitsAlignedTensorFormatBase : public TensorFormat::ImplBase {
size_t m_size_nbits, m_align_size_in_bits, m_align_size_in_elements; size_t m_size_nbits, m_align_size_in_bits, m_align_size_in_elements;
protected: //? protected: //?
LowbitsAlignedTensorFormatBase(Type type, size_t size_nbits, LowbitsAlignedTensorFormatBase(
size_t align_size_in_bits); Type type, size_t size_nbits, size_t align_size_in_bits);
virtual ~LowbitsAlignedTensorFormatBase() = default; virtual ~LowbitsAlignedTensorFormatBase() = default;
public: public:
size_t align_size_in_bits() const { return m_align_size_in_bits; } size_t align_size_in_bits() const { return m_align_size_in_bits; }
size_t size_nbits() const { return m_size_nbits; } size_t size_nbits() const { return m_size_nbits; }
std::string to_string() const override; std::string to_string() const override;
...@@ -238,8 +232,8 @@ public: ...@@ -238,8 +232,8 @@ public:
bool is_contiguous_spec(const TensorLayout& layout) const override; bool is_contiguous_spec(const TensorLayout& layout) const override;
TensorLayout collapse_contiguous_spec( TensorLayout collapse_contiguous_spec(const TensorLayout& layout) const override;
const TensorLayout& layout) const override;
protected: protected:
struct SerializePack { struct SerializePack {
uint8_t size_nbits; uint8_t size_nbits;
...@@ -254,16 +248,14 @@ protected: ...@@ -254,16 +248,14 @@ protected:
* *
* This is used for OpenCL. * This is used for OpenCL.
*/ */
class Image2DPack4TensorFormat final class Image2DPack4TensorFormat final : public detail::Image2DPack4TensorFormatBase {
: public detail::Image2DPack4TensorFormatBase {
public: public:
static constexpr Type TYPE = Type::IMAGE2D_PACK4; static constexpr Type TYPE = Type::IMAGE2D_PACK4;
//! for internal usage or test purposes //! for internal usage or test purposes
static TensorFormat make_raw(size_t align_axis, static TensorFormat make_raw(
size_t align_size_in_elements, size_t align_axis, size_t align_size_in_elements,
Handle::HandleVendorType vendor_type = Handle::HandleVendorType vendor_type = Handle::HandleVendorType::NOT_SPEC);
Handle::HandleVendorType::NOT_SPEC);
static TensorFormat make(size_t align_axis, const Handle* handle); static TensorFormat make(size_t align_axis, const Handle* handle);
...@@ -273,13 +265,11 @@ public: ...@@ -273,13 +265,11 @@ public:
* Note that the alignment may be different if deserialized on another * Note that the alignment may be different if deserialized on another
* handle * handle
*/ */
static TensorFormat deserialize(const Handle* handle, const void* buf, static TensorFormat deserialize(const Handle* handle, const void* buf, size_t size);
size_t size);
static bool is_valid_image(const TensorLayout& layout) { static bool is_valid_image(const TensorLayout& layout) {
if (layout.format.type() == TYPE) { if (layout.format.type() == TYPE) {
layout.format.as_impl<Image2DPack4TensorFormat>().assert_valid( layout.format.as_impl<Image2DPack4TensorFormat>().assert_valid(layout);
layout);
return true; return true;
} }
return false; return false;
...@@ -288,8 +278,9 @@ public: ...@@ -288,8 +278,9 @@ public:
TensorFormat change_axis(size_t axis) const override; TensorFormat change_axis(size_t axis) const override;
private: private:
Image2DPack4TensorFormat(size_t align_axis, size_t align_size_in_elements, Image2DPack4TensorFormat(
Handle::HandleVendorType vendor_type) size_t align_axis, size_t align_size_in_elements,
Handle::HandleVendorType vendor_type)
: detail::Image2DPack4TensorFormatBase( : detail::Image2DPack4TensorFormatBase(
TYPE, align_axis, align_size_in_elements, vendor_type) {} TYPE, align_axis, align_size_in_elements, vendor_type) {}
}; };
...@@ -306,13 +297,12 @@ public: ...@@ -306,13 +297,12 @@ public:
static TensorFormat make(size_t size_nbits); static TensorFormat make(size_t size_nbits);
static TensorFormat deserialize(const Handle* handle, const void* buf, static TensorFormat deserialize(const Handle* handle, const void* buf, size_t size);
size_t size);
static bool is_valid_layout(const TensorLayout& layout) { static bool is_valid_layout(const TensorLayout& layout) {
if (layout.format.type() == TYPE) { if (layout.format.type() == TYPE) {
layout.format.as_impl<LowbitsAlignedToBytesTensorFormat>() layout.format.as_impl<LowbitsAlignedToBytesTensorFormat>().assert_valid(
.assert_valid(layout); layout);
return true; return true;
} }
return false; return false;
...@@ -320,8 +310,7 @@ public: ...@@ -320,8 +310,7 @@ public:
private: private:
LowbitsAlignedToBytesTensorFormat(size_t size_nbits) LowbitsAlignedToBytesTensorFormat(size_t size_nbits)
: detail::LowbitsAlignedTensorFormatBase(TYPE, size_nbits, : detail::LowbitsAlignedTensorFormatBase(TYPE, size_nbits, BYTE_IN_BITS) {}
BYTE_IN_BITS) {}
}; };
} // namespace megdnn } // namespace megdnn
......
...@@ -167,13 +167,11 @@ public: ...@@ -167,13 +167,11 @@ public:
TensorIter(const TensorND& tensor) : m_tensor(tensor) {} TensorIter(const TensorND& tensor) : m_tensor(tensor) {}
Iter begin() const { Iter begin() const { return Iter::make(const_cast<TensorND&>(m_tensor), 0); }
return Iter::make(const_cast<TensorND&>(m_tensor), 0);
}
Iter end() const { Iter end() const {
return Iter::make(const_cast<TensorND&>(m_tensor), return Iter::make(
m_tensor.layout.total_nr_elems()); const_cast<TensorND&>(m_tensor), m_tensor.layout.total_nr_elems());
} }
}; };
/*! /*!
......
...@@ -11,19 +11,19 @@ ...@@ -11,19 +11,19 @@
#pragma once #pragma once
#include <type_traits> #include <cstdlib>
#include <functional> #include <functional>
#include <utility>
#include <memory> #include <memory>
#include <cstdlib> #include <type_traits>
#include <utility>
#include "megdnn/internal/visibility_prologue.h" #include "megdnn/internal/visibility_prologue.h"
namespace megdnn { namespace megdnn {
template<typename Signature> template <typename Signature>
using thin_function = ::std::function<Signature>; using thin_function = ::std::function<Signature>;
} // namespace megdnn } // namespace megdnn
#include "megdnn/internal/visibility_epilogue.h" #include "megdnn/internal/visibility_epilogue.h"
......
...@@ -58,18 +58,16 @@ protected: ...@@ -58,18 +58,16 @@ protected:
m_end_ptr(first_elm), m_end_ptr(first_elm),
m_capacity_ptr(static_cast<char*>(first_elm) + size) {} m_capacity_ptr(static_cast<char*>(first_elm) + size) {}
void grow_pod(void* first_elm_ptr, size_t min_sz_in_bytes, void grow_pod(void* first_elm_ptr, size_t min_sz_in_bytes, size_t type_size);
size_t type_size);
public: public:
size_t size_in_bytes() const { size_t size_in_bytes() const {
return size_t(static_cast<char*>(m_end_ptr) - return size_t(static_cast<char*>(m_end_ptr) - static_cast<char*>(m_begin_ptr));
static_cast<char*>(m_begin_ptr));
} }
size_t capacity_in_bytes() const { size_t capacity_in_bytes() const {
return size_t(static_cast<char*>(m_capacity_ptr) - return size_t(
static_cast<char*>(m_begin_ptr)); static_cast<char*>(m_capacity_ptr) - static_cast<char*>(m_begin_ptr));
} }
bool empty() const { return m_begin_ptr == m_end_ptr; } bool empty() const { return m_begin_ptr == m_end_ptr; }
...@@ -85,20 +83,15 @@ private: ...@@ -85,20 +83,15 @@ private:
U m_first_elm; U m_first_elm;
protected: protected:
SmallVectorTemplateCommon(size_t size) SmallVectorTemplateCommon(size_t size) : SmallVectorBase(&m_first_elm, size) {}
: SmallVectorBase(&m_first_elm, size) {}
void grow_pod(size_t min_sz_in_bytes, size_t type_size) { void grow_pod(size_t min_sz_in_bytes, size_t type_size) {
SmallVectorBase::grow_pod(&m_first_elm, min_sz_in_bytes, type_size); SmallVectorBase::grow_pod(&m_first_elm, min_sz_in_bytes, type_size);
} }
bool is_small() { bool is_small() { return m_begin_ptr == static_cast<const void*>(&m_first_elm); }
return m_begin_ptr == static_cast<const void*>(&m_first_elm);
}
void reset_to_small() { void reset_to_small() { m_begin_ptr = m_end_ptr = m_capacity_ptr = &m_first_elm; }
m_begin_ptr = m_end_ptr = m_capacity_ptr = &m_first_elm;
}
void set_end(T* p) { m_end_ptr = p; } void set_end(T* p) { m_end_ptr = p; }
...@@ -128,20 +121,12 @@ protected: ...@@ -128,20 +121,12 @@ protected:
public: public:
// forwarding iterator creation // forwarding iterator creation
iterator begin() { return static_cast<iterator>(m_begin_ptr); } iterator begin() { return static_cast<iterator>(m_begin_ptr); }
const_iterator begin() const { const_iterator begin() const { return static_cast<const_iterator>(m_begin_ptr); }
return static_cast<const_iterator>(m_begin_ptr); const_iterator cbegin() const { return static_cast<const_iterator>(m_begin_ptr); }
}
const_iterator cbegin() const {
return static_cast<const_iterator>(m_begin_ptr);
}
iterator end() { return static_cast<iterator>(m_end_ptr); } iterator end() { return static_cast<iterator>(m_end_ptr); }
const_iterator end() const { const_iterator end() const { return static_cast<const_iterator>(m_end_ptr); }
return static_cast<const_iterator>(m_end_ptr); const_iterator cend() const { return static_cast<const_iterator>(m_end_ptr); }
}
const_iterator cend() const {
return static_cast<const_iterator>(m_end_ptr);
}
reference at(size_type idx) { reference at(size_type idx) {
if (idx >= size()) { if (idx >= size()) {
...@@ -167,13 +152,9 @@ public: ...@@ -167,13 +152,9 @@ public:
// reverse iterator creation method. // reverse iterator creation method.
reverse_iterator rbegin() { return reverse_iterator(end()); } reverse_iterator rbegin() { return reverse_iterator(end()); }
const_reverse_iterator rbegin() const { const_reverse_iterator rbegin() const { return const_reverse_iterator(end()); }
return const_reverse_iterator(end());
}
reverse_iterator rend() { return reverse_iterator(begin()); } reverse_iterator rend() { return reverse_iterator(begin()); }
const_reverse_iterator rend() const { const_reverse_iterator rend() const { return const_reverse_iterator(begin()); }
return const_reverse_iterator(begin());
}
pointer data() { return pointer(begin()); } pointer data() { return pointer(begin()); }
const_pointer data() const { return const_pointer(begin()); } const_pointer data() const { return const_pointer(begin()); }
...@@ -207,8 +188,8 @@ protected: ...@@ -207,8 +188,8 @@ protected:
template <typename It1, typename It2> template <typename It1, typename It2>
static void uninitialized_move(It1 first, It1 last, It2 dest) { static void uninitialized_move(It1 first, It1 last, It2 dest) {
std::uninitialized_copy(std::make_move_iterator(first), std::uninitialized_copy(
std::make_move_iterator(last), dest); std::make_move_iterator(first), std::make_move_iterator(last), dest);
} }
template <typename It1, typename It2> template <typename It1, typename It2>
...@@ -293,9 +274,7 @@ protected: ...@@ -293,9 +274,7 @@ protected:
memcpy(dest, first, (last - first) * sizeof(T)); memcpy(dest, first, (last - first) * sizeof(T));
} }
void grow(size_t min_sz = 0) { void grow(size_t min_sz = 0) { this->grow_pod(min_sz * sizeof(T), sizeof(T)); }
this->grow_pod(min_sz * sizeof(T), sizeof(T));
}
public: public:
void push_back(const T& _elm) { void push_back(const T& _elm) {
...@@ -318,8 +297,7 @@ public: ...@@ -318,8 +297,7 @@ public:
* SmallVector<T, N> can be converted to SmallVectorImpl<T> to erase N * SmallVector<T, N> can be converted to SmallVectorImpl<T> to erase N
*/ */
template <typename T> template <typename T>
class SmallVectorImpl class SmallVectorImpl : public SmallVectorTemplateBase<T, std::is_pod<T>::value> {
: public SmallVectorTemplateBase<T, std::is_pod<T>::value> {
using SuperClass = SmallVectorTemplateBase<T, std::is_pod<T>::value>; using SuperClass = SmallVectorTemplateBase<T, std::is_pod<T>::value>;
public: public:
...@@ -329,8 +307,7 @@ public: ...@@ -329,8 +307,7 @@ public:
protected: protected:
explicit SmallVectorImpl(unsigned n) explicit SmallVectorImpl(unsigned n)
: SmallVectorTemplateBase<T, std::is_pod<T>::value>(n * sizeof(T)) { : SmallVectorTemplateBase<T, std::is_pod<T>::value>(n * sizeof(T)) {}
}
public: public:
SmallVectorImpl(const SmallVectorImpl&) = delete; SmallVectorImpl(const SmallVectorImpl&) = delete;
...@@ -354,8 +331,7 @@ public: ...@@ -354,8 +331,7 @@ public:
} else if (n > this->size()) { } else if (n > this->size()) {
if (this->capacity() < n) if (this->capacity() < n)
this->grow(n); this->grow(n);
for (auto it = this->end(), end = this->begin() + n; it != end; for (auto it = this->end(), end = this->begin() + n; it != end; ++it)
++it)
new (&*it) T(); new (&*it) T();
this->set_end(this->begin() + n); this->set_end(this->begin() + n);
} }
...@@ -389,10 +365,11 @@ public: ...@@ -389,10 +365,11 @@ public:
void swap(SmallVectorImpl<T>& rhs); void swap(SmallVectorImpl<T>& rhs);
/// Add the specified range to the end of the SmallVector. /// Add the specified range to the end of the SmallVector.
template <typename in_iter, template <
typename = typename std::enable_if<std::is_convertible< typename in_iter,
typename std::iterator_traits<in_iter>::iterator_category, typename = typename std::enable_if<std::is_convertible<
std::input_iterator_tag>::value>::type> typename std::iterator_traits<in_iter>::iterator_category,
std::input_iterator_tag>::value>::type>
void append(in_iter in_start, in_iter in_end) { void append(in_iter in_start, in_iter in_end) {
size_type num_inputs = std::distance(in_start, in_end); size_type num_inputs = std::distance(in_start, in_end);
// Grow allocated space if needed. // Grow allocated space if needed.
...@@ -432,10 +409,11 @@ public: ...@@ -432,10 +409,11 @@ public:
std::uninitialized_fill(this->begin(), this->end(), elm); std::uninitialized_fill(this->begin(), this->end(), elm);
} }
template <typename in_iter, template <
typename = typename std::enable_if<std::is_convertible< typename in_iter,
typename std::iterator_traits<in_iter>::iterator_category, typename = typename std::enable_if<std::is_convertible<
std::input_iterator_tag>::value>::type> typename std::iterator_traits<in_iter>::iterator_category,
std::input_iterator_tag>::value>::type>
void assign(in_iter in_start, in_iter in_end) { void assign(in_iter in_start, in_iter in_end) {
clear(); clear();
append(in_start, in_end); append(in_start, in_end);
...@@ -571,8 +549,7 @@ public: ...@@ -571,8 +549,7 @@ public:
std::fill_n(it, num_overwritten, elm); std::fill_n(it, num_overwritten, elm);
// Insert the non-overwritten middle part. // Insert the non-overwritten middle part.
std::uninitialized_fill_n(old_end, num_to_insert - num_overwritten, std::uninitialized_fill_n(old_end, num_to_insert - num_overwritten, elm);
elm);
return it; return it;
} }
...@@ -646,8 +623,7 @@ public: ...@@ -646,8 +623,7 @@ public:
if (megdnn_unlikely(this->m_end_ptr >= this->m_capacity_ptr)) { if (megdnn_unlikely(this->m_end_ptr >= this->m_capacity_ptr)) {
this->grow(); this->grow();
} }
new (static_cast<void*>(this->end())) new (static_cast<void*>(this->end())) T(std::forward<ArgTypes>(args)...);
T(std::forward<ArgTypes>(args)...);
this->set_end(this->end() + 1); this->set_end(this->end() + 1);
} }
...@@ -661,13 +637,11 @@ public: ...@@ -661,13 +637,11 @@ public:
return std::equal(this->begin(), this->end(), rhs.begin()); return std::equal(this->begin(), this->end(), rhs.begin());
} }
bool operator!=(const SmallVectorImpl<T>& rhs) const { bool operator!=(const SmallVectorImpl<T>& rhs) const { return !(*this == rhs); }
return !(*this == rhs);
}
bool operator<(const SmallVectorImpl<T>& rhs) const { bool operator<(const SmallVectorImpl<T>& rhs) const {
return std::lexicographical_compare(this->begin(), this->end(), return std::lexicographical_compare(
rhs.begin(), rhs.end()); this->begin(), this->end(), rhs.begin(), rhs.end());
} }
}; };
...@@ -698,15 +672,13 @@ void SmallVectorImpl<T>::swap(SmallVectorImpl<T>& rhs) { ...@@ -698,15 +672,13 @@ void SmallVectorImpl<T>::swap(SmallVectorImpl<T>& rhs) {
// Copy over the extra elms. // Copy over the extra elms.
if (this->size() > rhs.size()) { if (this->size() > rhs.size()) {
size_t elm_diff = this->size() - rhs.size(); size_t elm_diff = this->size() - rhs.size();
this->uninitialized_move(this->begin() + num_shared, this->end(), this->uninitialized_move(this->begin() + num_shared, this->end(), rhs.end());
rhs.end());
rhs.set_end(rhs.end() + elm_diff); rhs.set_end(rhs.end() + elm_diff);
this->destroy_range(this->begin() + num_shared, this->end()); this->destroy_range(this->begin() + num_shared, this->end());
this->set_end(this->begin() + num_shared); this->set_end(this->begin() + num_shared);
} else if (rhs.size() > this->size()) { } else if (rhs.size() > this->size()) {
size_t elm_diff = rhs.size() - this->size(); size_t elm_diff = rhs.size() - this->size();
this->uninitialized_move(rhs.begin() + num_shared, rhs.end(), this->uninitialized_move(rhs.begin() + num_shared, rhs.end(), this->end());
this->end());
this->set_end(this->end() + elm_diff); this->set_end(this->end() + elm_diff);
this->destroy_range(rhs.begin() + num_shared, rhs.end()); this->destroy_range(rhs.begin() + num_shared, rhs.end());
rhs.set_end(rhs.begin() + num_shared); rhs.set_end(rhs.begin() + num_shared);
...@@ -714,8 +686,7 @@ void SmallVectorImpl<T>::swap(SmallVectorImpl<T>& rhs) { ...@@ -714,8 +686,7 @@ void SmallVectorImpl<T>::swap(SmallVectorImpl<T>& rhs) {
} }
template <typename T> template <typename T>
SmallVectorImpl<T>& SmallVectorImpl<T>::operator=( SmallVectorImpl<T>& SmallVectorImpl<T>::operator=(const SmallVectorImpl<T>& rhs) {
const SmallVectorImpl<T>& rhs) {
if (this == &rhs) if (this == &rhs)
return *this; return *this;
size_t rhs_sz = rhs.size(); size_t rhs_sz = rhs.size();
...@@ -740,8 +711,7 @@ SmallVectorImpl<T>& SmallVectorImpl<T>::operator=( ...@@ -740,8 +711,7 @@ SmallVectorImpl<T>& SmallVectorImpl<T>::operator=(
} else if (cur_sz) { } else if (cur_sz) {
std::copy(rhs.begin(), rhs.begin() + cur_sz, this->begin()); std::copy(rhs.begin(), rhs.begin() + cur_sz, this->begin());
} }
std::uninitialized_copy(rhs.begin() + cur_sz, rhs.end(), std::uninitialized_copy(rhs.begin() + cur_sz, rhs.end(), this->begin() + cur_sz);
this->begin() + cur_sz);
this->set_end(this->begin() + rhs_sz); this->set_end(this->begin() + rhs_sz);
return *this; return *this;
} }
...@@ -785,8 +755,7 @@ SmallVectorImpl<T>& SmallVectorImpl<T>::operator=(SmallVectorImpl<T>&& rhs) { ...@@ -785,8 +755,7 @@ SmallVectorImpl<T>& SmallVectorImpl<T>::operator=(SmallVectorImpl<T>&& rhs) {
std::move(rhs.begin(), rhs.begin() + cur_sz, this->begin()); std::move(rhs.begin(), rhs.begin() + cur_sz, this->begin());
} }
this->uninitialized_move(rhs.begin() + cur_sz, rhs.end(), this->uninitialized_move(rhs.begin() + cur_sz, rhs.end(), this->begin() + cur_sz);
this->begin() + cur_sz);
this->set_end(this->begin() + rhs_sz); this->set_end(this->begin() + rhs_sz);
...@@ -826,8 +795,7 @@ class SmallVector : public SmallVectorImpl<T> { ...@@ -826,8 +795,7 @@ class SmallVector : public SmallVectorImpl<T> {
public: public:
SmallVector() : SmallVectorImpl<T>(N) {} SmallVector() : SmallVectorImpl<T>(N) {}
explicit SmallVector(size_t size, const T& value = T()) explicit SmallVector(size_t size, const T& value = T()) : SmallVectorImpl<T>(N) {
: SmallVectorImpl<T>(N) {
this->assign(size, value); this->assign(size, value);
} }
...@@ -901,15 +869,13 @@ namespace std { ...@@ -901,15 +869,13 @@ namespace std {
/// Implement std::swap in terms of SmallVector swap. /// Implement std::swap in terms of SmallVector swap.
template <typename T> template <typename T>
inline void swap(megdnn::SmallVectorImpl<T>& lhs, inline void swap(megdnn::SmallVectorImpl<T>& lhs, megdnn::SmallVectorImpl<T>& rhs) {
megdnn::SmallVectorImpl<T>& rhs) {
lhs.swap(rhs); lhs.swap(rhs);
} }
/// Implement std::swap in terms of SmallVector swap. /// Implement std::swap in terms of SmallVector swap.
template <typename T, unsigned N> template <typename T, unsigned N>
inline void swap(megdnn::SmallVector<T, N>& lhs, inline void swap(megdnn::SmallVector<T, N>& lhs, megdnn::SmallVector<T, N>& rhs) {
megdnn::SmallVector<T, N>& rhs) {
lhs.swap(rhs); lhs.swap(rhs);
} }
} // end namespace std } // end namespace std
......
...@@ -17,13 +17,13 @@ ...@@ -17,13 +17,13 @@
#include "megdnn/internal/visibility_prologue.h" #include "megdnn/internal/visibility_prologue.h"
namespace megdnn { namespace megdnn {
struct Version { struct Version {
int major, minor, patch; int major, minor, patch;
}; };
//! get megdnn version of the binary //! get megdnn version of the binary
Version get_version(); Version get_version();
} } // namespace megdnn
#include "megdnn/internal/visibility_epilogue.h" #include "megdnn/internal/visibility_epilogue.h"
......
...@@ -22,18 +22,17 @@ using namespace aarch64; ...@@ -22,18 +22,17 @@ using namespace aarch64;
/* ===================== stride-2 algo ===================== */ /* ===================== stride-2 algo ===================== */
MIDOUT_DECL(megdnn_aarch64_conv_bias_stride2_conv2357_fp16) MIDOUT_DECL(megdnn_aarch64_conv_bias_stride2_conv2357_fp16)
bool ConvBiasImpl::AlgoF16DirectStride2::usable(const NCBKernSizeParam& param, bool ConvBiasImpl::AlgoF16DirectStride2::usable(
AlgoSelectionStrategy) const { const NCBKernSizeParam& param, AlgoSelectionStrategy) const {
MIDOUT_BEGIN(megdnn_aarch64_conv_bias_stride2_conv2357_fp16, 0, 0) { MIDOUT_BEGIN(megdnn_aarch64_conv_bias_stride2_conv2357_fp16, 0, 0) {
auto&& fm = param.filter_meta; auto&& fm = param.filter_meta;
auto FH = fm.spatial[0]; auto FH = fm.spatial[0];
return param.filter_meta.format == param::Convolution::Format::NCHW && return param.filter_meta.format == param::Convolution::Format::NCHW &&
param.src_type.enumv() == DTypeEnum::Float16 && param.src_type.enumv() == DTypeEnum::Float16 &&
param.filter_type.enumv() == DTypeEnum::Float16 && param.filter_type.enumv() == DTypeEnum::Float16 &&
param.dst_type.enumv() == DTypeEnum::Float16 && param.dst_type.enumv() == DTypeEnum::Float16 && !fm.should_flip &&
!fm.should_flip && fm.spatial_ndim == 2 && fm.dilation[0] == 1 && fm.spatial_ndim == 2 && fm.dilation[0] == 1 && fm.dilation[1] == 1 &&
fm.dilation[1] == 1 && fm.stride[0] == 2 && fm.stride[1] == 2 && fm.stride[0] == 2 && fm.stride[1] == 2 && FH == fm.spatial[1] &&
FH == fm.spatial[1] &&
(FH == 2 || FH == 3 || FH == 5 || FH == 7); (FH == 2 || FH == 3 || FH == 5 || FH == 7);
} }
MIDOUT_END(); MIDOUT_END();
...@@ -52,8 +51,7 @@ size_t ConvBiasImpl::AlgoF16DirectStride2::get_workspace( ...@@ -52,8 +51,7 @@ size_t ConvBiasImpl::AlgoF16DirectStride2::get_workspace(
return 0; return 0;
} }
SmallVector<ConvBiasImpl::NCBKern> SmallVector<ConvBiasImpl::NCBKern> ConvBiasImpl::AlgoF16DirectStride2::dispatch_kerns(
ConvBiasImpl::AlgoF16DirectStride2::dispatch_kerns(
const NCBKernSizeParam& param) const { const NCBKernSizeParam& param) const {
MIDOUT_BEGIN(megdnn_aarch64_conv_bias_stride2_conv2357_fp16, 0, 2) { MIDOUT_BEGIN(megdnn_aarch64_conv_bias_stride2_conv2357_fp16, 0, 2) {
return get_kimpls(param); return get_kimpls(param);
...@@ -62,8 +60,7 @@ ConvBiasImpl::AlgoF16DirectStride2::dispatch_kerns( ...@@ -62,8 +60,7 @@ ConvBiasImpl::AlgoF16DirectStride2::dispatch_kerns(
return {}; return {};
} }
SmallVector<ConvBiasImpl::NCBKern> SmallVector<ConvBiasImpl::NCBKern> ConvBiasImpl::AlgoF16DirectStride2::get_kimpls(
ConvBiasImpl::AlgoF16DirectStride2::get_kimpls(
const NCBKernSizeParam& param) const { const NCBKernSizeParam& param) const {
auto fm = param.filter_meta; auto fm = param.filter_meta;
auto FH = fm.spatial[0]; auto FH = fm.spatial[0];
...@@ -72,8 +69,9 @@ ConvBiasImpl::AlgoF16DirectStride2::get_kimpls( ...@@ -72,8 +69,9 @@ ConvBiasImpl::AlgoF16DirectStride2::get_kimpls(
size_t OC = param.filter_meta.ocpg; size_t OC = param.filter_meta.ocpg;
size_t group = fm.group; size_t group = fm.group;
bool large_group = group >= param.nr_threads; bool large_group = group >= param.nr_threads;
using Func = std::function<void(const __fp16*, const __fp16*, __fp16*, using Func = std::function<void(
size_t, size_t, size_t, size_t, size_t)>; const __fp16*, const __fp16*, __fp16*, size_t, size_t, size_t, size_t,
size_t)>;
Func conv = nullptr; Func conv = nullptr;
if (FH == 2) { if (FH == 2) {
conv = fp16::conv_stride2::do_conv_2x2_stride2; conv = fp16::conv_stride2::do_conv_2x2_stride2;
...@@ -101,31 +99,35 @@ ConvBiasImpl::AlgoF16DirectStride2::get_kimpls( ...@@ -101,31 +99,35 @@ ConvBiasImpl::AlgoF16DirectStride2::get_kimpls(
bundle.set(kern_param.workspace_ptr); bundle.set(kern_param.workspace_ptr);
for (size_t ic = 0; ic < IC; ic++) { for (size_t ic = 0; ic < IC; ic++) {
arm_common::MultithreadDirectConvCommon<dt_float16, __fp16>:: arm_common::MultithreadDirectConvCommon<dt_float16, __fp16>::
copy_padding_kern_stride(bundle, kern_param, ncb_index, copy_padding_kern_stride(
{ncb_index.thread_id, 0, ic}); bundle, kern_param, ncb_index,
{ncb_index.thread_id, 0, ic});
} }
for (size_t oc = 0; oc < OC; oc++) { for (size_t oc = 0; oc < OC; oc++) {
arm_common::MultithreadDirectConvCommon<dt_float16, __fp16>:: arm_common::MultithreadDirectConvCommon<dt_float16, __fp16>::
do_conv_kern_stride(bundle, kern_param, ncb_index, conv, do_conv_kern_stride(
{ncb_index.thread_id, 0, oc}); bundle, kern_param, ncb_index, conv,
{ncb_index.thread_id, 0, oc});
} }
}; };
ret_kerns.push_back({exec_one_group, {group, N, 1_z}}); ret_kerns.push_back({exec_one_group, {group, N, 1_z}});
} else { } else {
auto copy_padding = [bundle](const NCBKernParam& kern_param, auto copy_padding = [bundle](
const NCBKernIndex& ncb_index) mutable { const NCBKernParam& kern_param,
const NCBKernIndex& ncb_index) mutable {
bundle.set(kern_param.workspace_ptr); bundle.set(kern_param.workspace_ptr);
arm_common::MultithreadDirectConvCommon<dt_float16, __fp16>:: arm_common::MultithreadDirectConvCommon<dt_float16, __fp16>::
copy_padding_kern_stride(bundle, kern_param, ncb_index, copy_padding_kern_stride(
ncb_index.ndrange_id); bundle, kern_param, ncb_index, ncb_index.ndrange_id);
}; };
ret_kerns.push_back({copy_padding, {group, N, IC}}); ret_kerns.push_back({copy_padding, {group, N, IC}});
auto do_conv = [bundle, conv](const NCBKernParam& kern_param, auto do_conv = [bundle, conv](
const NCBKernIndex& ncb_index) mutable { const NCBKernParam& kern_param,
const NCBKernIndex& ncb_index) mutable {
bundle.set(kern_param.workspace_ptr); bundle.set(kern_param.workspace_ptr);
arm_common::MultithreadDirectConvCommon<dt_float16, __fp16>:: arm_common::MultithreadDirectConvCommon<dt_float16, __fp16>::
do_conv_kern_stride(bundle, kern_param, ncb_index, conv, do_conv_kern_stride(
ncb_index.ndrange_id); bundle, kern_param, ncb_index, conv, ncb_index.ndrange_id);
}; };
ret_kerns.push_back({do_conv, {group, N, OC}}); ret_kerns.push_back({do_conv, {group, N, OC}});
} }
......
...@@ -18,13 +18,13 @@ namespace aarch64 { ...@@ -18,13 +18,13 @@ namespace aarch64 {
/* ===================== stride-2 algo ===================== */ /* ===================== stride-2 algo ===================== */
class ConvBiasImpl::AlgoF16DirectStride2 final : public AlgoBase { class ConvBiasImpl::AlgoF16DirectStride2 final : public AlgoBase {
SmallVector<NCBKern> get_kimpls(const NCBKernSizeParam& param) const; SmallVector<NCBKern> get_kimpls(const NCBKernSizeParam& param) const;
public: public:
AlgoAttribute attribute() const override { AlgoAttribute attribute() const override { return AlgoAttribute::REPRODUCIBLE; }
return AlgoAttribute::REPRODUCIBLE;
}
const char* name() const override { return "ARMV8F16STRD2"; } const char* name() const override { return "ARMV8F16STRD2"; }
bool usable(const NCBKernSizeParam& param, bool usable(
AlgoSelectionStrategy algo_selection_strategy) const override; const NCBKernSizeParam& param,
AlgoSelectionStrategy algo_selection_strategy) const override;
size_t get_workspace(const NCBKernSizeParam& param) const override; size_t get_workspace(const NCBKernSizeParam& param) const override;
......
...@@ -22,14 +22,14 @@ using FallbackConvBiasImpl = fallback::ConvBiasImpl; ...@@ -22,14 +22,14 @@ using FallbackConvBiasImpl = fallback::ConvBiasImpl;
class ConvBiasImpl::AlgoF32DirectStride2 final : public AlgoBase { class ConvBiasImpl::AlgoF32DirectStride2 final : public AlgoBase {
SmallVector<NCBKern> get_kimpls(const NCBKernSizeParam& param) const; SmallVector<NCBKern> get_kimpls(const NCBKernSizeParam& param) const;
public: public:
AlgoAttribute attribute() const override { AlgoAttribute attribute() const override { return AlgoAttribute::REPRODUCIBLE; }
return AlgoAttribute::REPRODUCIBLE;
}
const char* name() const override { return "ARMV8F32STRD2"; } const char* name() const override { return "ARMV8F32STRD2"; }
bool usable(const NCBKernSizeParam& param, bool usable(
AlgoSelectionStrategy algo_selection_strategy) const override; const NCBKernSizeParam& param,
AlgoSelectionStrategy algo_selection_strategy) const override;
size_t get_workspace(const NCBKernSizeParam& param) const override; size_t get_workspace(const NCBKernSizeParam& param) const override;
......
...@@ -12,8 +12,8 @@ ...@@ -12,8 +12,8 @@
#pragma once #pragma once
#include "src/aarch64/conv_bias/opr_impl.h" #include "src/aarch64/conv_bias/opr_impl.h"
#include "src/fallback/conv_bias/opr_impl.h"
#include "src/common/opr_delegate.h" #include "src/common/opr_delegate.h"
#include "src/fallback/conv_bias/opr_impl.h"
namespace megdnn { namespace megdnn {
namespace aarch64 { namespace aarch64 {
...@@ -25,18 +25,16 @@ class ConvBiasImpl::AlgoS8MatrixMul final : public AlgoBase { ...@@ -25,18 +25,16 @@ class ConvBiasImpl::AlgoS8MatrixMul final : public AlgoBase {
static void kimpl(const NCBKernParam& param, const NCBKernIndex& ncb_index); static void kimpl(const NCBKernParam& param, const NCBKernIndex& ncb_index);
public: public:
AlgoAttribute attribute() const override { AlgoAttribute attribute() const override { return AlgoAttribute::REPRODUCIBLE; }
return AlgoAttribute::REPRODUCIBLE;
}
const char* name() const override { return "S8MATMUL"; } const char* name() const override { return "S8MATMUL"; }
bool usable(const NCBKernSizeParam& param, bool usable(
AlgoSelectionStrategy algo_selection_strategy) const override; const NCBKernSizeParam& param,
AlgoSelectionStrategy algo_selection_strategy) const override;
size_t get_workspace(const NCBKernSizeParam& param) const override { size_t get_workspace(const NCBKernSizeParam& param) const override {
return get_bundle(param).total_size_in_bytes(); return get_bundle(param).total_size_in_bytes();
} }
SmallVector<NCBKern> dispatch_kerns( SmallVector<NCBKern> dispatch_kerns(const NCBKernSizeParam& param) const override {
const NCBKernSizeParam& param) const override {
size_t group = param.filter_meta.group; size_t group = param.filter_meta.group;
return {{kimpl, {group, 1_z, 1_z}}}; return {{kimpl, {group, 1_z, 1_z}}};
} }
......
...@@ -9,8 +9,8 @@ ...@@ -9,8 +9,8 @@
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/ */
#pragma once #pragma once
#include "src/common/utils.h"
#include "src/arm_common/conv_bias/opr_impl.h" #include "src/arm_common/conv_bias/opr_impl.h"
#include "src/common/utils.h"
namespace megdnn { namespace megdnn {
namespace aarch64 { namespace aarch64 {
......
此差异已折叠。
此差异已折叠。
此差异已折叠。
...@@ -17,8 +17,8 @@ namespace megdnn { ...@@ -17,8 +17,8 @@ namespace megdnn {
namespace aarch64 { namespace aarch64 {
namespace matmul { namespace matmul {
MEGDNN_REG_GEMM_STRATEGY(dt_int8, dt_int16, dt_int16, 8, 8, 8, false, true, MEGDNN_REG_GEMM_STRATEGY(
gemm_s4x4x16_s4_8x8x8); dt_int8, dt_int16, dt_int16, 8, 8, 8, false, true, gemm_s4x4x16_s4_8x8x8);
} // namespace matmul } // namespace matmul
} // namespace aarch64 } // namespace aarch64
......
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册