From a5fad7d07ca4fa916d9d93aaaee9ce85c2bf56c3 Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Thu, 10 Sep 2020 13:31:57 +0800 Subject: [PATCH] feat(dnn): add compile for riscv64 GitOrigin-RevId: fa0c1635273339702cabcf1fbbe8a53636cc56ab --- CMakeLists.txt | 7 ++ dnn/src/common/postprocess.h | 21 +++++ dnn/src/common/postprocess_helper.h | 80 +++++++++++++++++++ dnn/src/common/relayout_helper.h | 9 ++- dnn/src/fallback/conv_bias/common.h | 11 +-- dnn/src/fallback/conv_bias/conv1x1/algos.cpp | 5 +- .../conv_bias/conv1x1/algos_conv1x1_gemv.cpp | 13 +-- .../conv_bias/conv1x1/conv1x1_strategy.h | 2 + .../fallback/conv_bias/im2col/strategy_base.h | 2 + .../im2col/strategy_default_nchw44.cpp | 6 +- dnn/src/fallback/conv_bias/opr_impl.cpp | 6 ++ dnn/src/fallback/conv_bias/opr_impl.h | 1 + dnn/test/common/mask_conv.h | 3 +- dnn/test/cpu/mask_conv.cpp | 2 + dnn/test/cpu/matrix_mul.cpp | 2 + dnn/test/cpu/relayout.cpp | 3 +- dnn/test/cuda/mask_conv.cpp | 3 +- dnn/test/fallback/elemwise.cpp | 3 +- dnn/test/fallback/elemwise_multi_type.cpp | 4 +- dnn/test/fallback/relayout.cpp | 3 +- dnn/test/fallback/roi_copy.cpp | 3 +- toolchains/riscv64-linux-gnu.toolchain.cmake | 18 +++++ 22 files changed, 179 insertions(+), 28 deletions(-) create mode 100644 dnn/src/common/postprocess.h create mode 100644 dnn/src/common/postprocess_helper.h create mode 100644 toolchains/riscv64-linux-gnu.toolchain.cmake diff --git a/CMakeLists.txt b/CMakeLists.txt index 65a00af2..1404220a 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -117,6 +117,8 @@ if(CMAKE_TOOLCHAIN_FILE) else() message(FATAL_ERROR "Unsupported IOS_ARCH.") endif() + elseif(RISCV_TOOLCHAIN_ROOT) + set(MGE_ARCH "riscv64") elseif(NOT "${ARM_CROSS_BUILD_ARCH}" STREQUAL "") set(MGE_ARCH ${ARM_CROSS_BUILD_ARCH}) else() @@ -664,6 +666,11 @@ if(MGE_ARCH STREQUAL "aarch64") endif() +if(MGE_ARCH STREQUAL "riscv64") + set(MEGDNN_RISCV64 1) + set(MEGDNN_64_BIT 1) +endif() + set (CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${MARCH}") set(MGB_ENABLE_IMPERATIVE ${MGE_BUILD_IMPERATIVE_RT}) diff --git a/dnn/src/common/postprocess.h b/dnn/src/common/postprocess.h new file mode 100644 index 00000000..6438a4d6 --- /dev/null +++ b/dnn/src/common/postprocess.h @@ -0,0 +1,21 @@ +/** + * \file dnn/src/common/postprocess.h + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. + */ +#pragma once +namespace megdnn { +enum class PostprocessMode : uint8_t { + FLOAT = 0, ///< support all biasmode and no_nonlinemode + NO_PROCESS, ///< support non bias and identity + QUANTIZED, ///< support NOBIAS ,BROADCAST_CHANNEL_BIAS and relu hswish + ///< identify nonline mode + ADD_BIAS, ///< only add bias +}; +} \ No newline at end of file diff --git a/dnn/src/common/postprocess_helper.h b/dnn/src/common/postprocess_helper.h new file mode 100644 index 00000000..70725083 --- /dev/null +++ b/dnn/src/common/postprocess_helper.h @@ -0,0 +1,80 @@ +/** + * \file dnn/src/common/postprocess_helper.h + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. + */ + +#pragma once + +#include "megdnn/basic_types.h" +#include "midout.h" +#include "src/common/postprocess.h" + +namespace { +#define POST_PROCESS_UNUSED_VAR() \ + MEGDNN_MARK_USED_VAR(conv_dst_ptr); \ + MEGDNN_MARK_USED_VAR(bias_ptr); \ + MEGDNN_MARK_USED_VAR(dst_ptr); \ + MEGDNN_MARK_USED_VAR(bias_mode); \ + MEGDNN_MARK_USED_VAR(nonlineMode); \ + MEGDNN_MARK_USED_VAR(bias_type); \ + MEGDNN_MARK_USED_VAR(dst_type); \ + MEGDNN_MARK_USED_VAR(N); \ + MEGDNN_MARK_USED_VAR(OC); \ + MEGDNN_MARK_USED_VAR(OH); \ + MEGDNN_MARK_USED_VAR(OW); \ + MEGDNN_MARK_USED_VAR(pack_oc_size) + +template +struct PostProcess { + static void run(void* conv_dst_ptr, const void* bias_ptr, void* dst_ptr, + megdnn::BiasMode bias_mode, megdnn::NonlineMode nonlineMode, + megdnn::DType bias_type, megdnn::DType dst_type, size_t N, + size_t OC, size_t OH, size_t OW, size_t pack_oc_size = 1) { + POST_PROCESS_UNUSED_VAR(); + megdnn_throw("not impl PostProcess"); + } +}; + +template +struct PostProcess { + static void run(void* conv_dst_ptr, void* bias_ptr, void* dst_ptr, + megdnn::BiasMode bias_mode, megdnn::NonlineMode nonlineMode, + megdnn::DType bias_type, megdnn::DType dst_type, size_t N, + size_t OC, size_t OH, size_t OW, size_t pack_oc_size = 1) { + POST_PROCESS_UNUSED_VAR(); + megdnn_throw("not impl PostProcess"); + } +}; + +template +struct PostProcess { + static void run(void* conv_dst_ptr, const void* bias_ptr, void* dst_ptr, + megdnn::BiasMode bias_mode, megdnn::NonlineMode nonlineMode, + megdnn::DType bias_type, megdnn::DType dst_type, size_t N, + size_t OC, size_t OH, size_t OW, size_t pack_oc_size = 1) { + POST_PROCESS_UNUSED_VAR(); + megdnn_throw("not impl PostProcess"); + } +}; + +template +struct PostProcess { + static void run(void* conv_dst_ptr, void* bias_ptr, void* dst_ptr, + megdnn::BiasMode bias_mode, megdnn::NonlineMode nonlineMode, + megdnn::DType bias_type, megdnn::DType dst_type, size_t N, + size_t OC, size_t OH, size_t OW, size_t pack_oc_size = 1) { + POST_PROCESS_UNUSED_VAR(); + megdnn_throw("not impl PostProcess"); + } +}; + +} // namespace diff --git a/dnn/src/common/relayout_helper.h b/dnn/src/common/relayout_helper.h index 129a923b..f00b7bae 100644 --- a/dnn/src/common/relayout_helper.h +++ b/dnn/src/common/relayout_helper.h @@ -6,7 +6,8 @@ * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. */ #pragma once @@ -42,8 +43,12 @@ namespace transpose_fallback { #if MEGDNN_X86 constexpr size_t BLOCK_LINE_SIZE_BYTES = 64; -#elif MEGDNN_AARCH64 || MEGDNN_ARMV7 +#elif MEGDNN_AARCH64 || MEGDNN_ARMV7 /*BEGIN-INLINE-INTERNAL*/ || \ + MEGDNN_MIPS /*END-INLINE-INTERNAL*/ constexpr size_t BLOCK_LINE_SIZE_BYTES = 32; +#elif MEGDNN_RISCV64 +//! ref U54-MC arch +constexpr size_t BLOCK_LINE_SIZE_BYTES = 64; #else #error "unknown megdnn arch" #endif diff --git a/dnn/src/fallback/conv_bias/common.h b/dnn/src/fallback/conv_bias/common.h index 60f18a6d..75cc155c 100644 --- a/dnn/src/fallback/conv_bias/common.h +++ b/dnn/src/fallback/conv_bias/common.h @@ -6,12 +6,14 @@ * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. */ #pragma once #include #include "megdnn/oprs.h" +#include "src/common/postprocess.h" #include "src/common/utils.h" namespace megdnn { @@ -157,13 +159,6 @@ private: \ mutable std::string m_name; \ uint32_t m_tile_size; -enum class PostprocessMode : uint8_t { - FLOAT = 0, ///< support all biasmode and no_nonlinemode - NO_PROCESS, ///< support non bias and identity - QUANTIZED, ///< support NOBIAS ,BROADCAST_CHANNEL_BIAS and relu hswish - ///< identify nonline mode - ADD_BIAS, ///< only add bias -}; } // namespace megdnn // vim: syntax=cpp.doxygen diff --git a/dnn/src/fallback/conv_bias/conv1x1/algos.cpp b/dnn/src/fallback/conv_bias/conv1x1/algos.cpp index 8dea1707..0ff7852b 100644 --- a/dnn/src/fallback/conv_bias/conv1x1/algos.cpp +++ b/dnn/src/fallback/conv_bias/conv1x1/algos.cpp @@ -24,6 +24,8 @@ #include "src/x86/conv_bias/postprocess_helper.h" #elif (MEGDNN_ARMV7 || MEGDNN_AARCH64) #include "src/arm_common/conv_bias/postprocess_helper.h" +#else +#include "src/common/postprocess_helper.h" #endif #include "midout.h" @@ -106,7 +108,7 @@ ConvBiasImpl::AlgoConv1x1::get_kerns_according_packmode( WorkspaceBundle whole_bundle = get_bundle_according_packmode(param); //! NO_PACK not implement get_bundle - WorkspaceBundle matmul_bundle ={nullptr,{}}; + WorkspaceBundle matmul_bundle = {nullptr, {}}; if (pack_mode == MatrixMulImpl::AlgoBase::PackMode::NO_PACK) { matmul_bundle = {nullptr, {0, 0, m_matmul_algo->get_workspace(matmul_param)}}; @@ -281,7 +283,6 @@ bool ConvBiasImpl::AlgoConv1x1::usable(const NCBKernSizeParam& param, return false; } - bool ConvBiasImpl::AlgoConv1x1::is_preferred( const NCBKernSizeParam& param) const { size_t OH = param.osz[0]; diff --git a/dnn/src/fallback/conv_bias/conv1x1/algos_conv1x1_gemv.cpp b/dnn/src/fallback/conv_bias/conv1x1/algos_conv1x1_gemv.cpp index 70ec3abe..bd40dd24 100644 --- a/dnn/src/fallback/conv_bias/conv1x1/algos_conv1x1_gemv.cpp +++ b/dnn/src/fallback/conv_bias/conv1x1/algos_conv1x1_gemv.cpp @@ -25,9 +25,11 @@ #include "src/x86/conv_bias/postprocess_helper.h" #elif (MEGDNN_ARMV7 || MEGDNN_AARCH64) #include "src/arm_common/conv_bias/postprocess_helper.h" -#include "src/arm_common/matrix_mul/fp32/exec_sgemv.h" #include "src/arm_common/matrix_mul/fp16/hgemv.h" +#include "src/arm_common/matrix_mul/fp32/exec_sgemv.h" #include "src/arm_common/matrix_mul/int8/gemv.h" +#else +#include "src/common/postprocess_helper.h" #endif #include "midout.h" @@ -249,7 +251,7 @@ size_t ConvBiasImpl::AlgoConv1x1Gemv::get_oc_tile_size_heuristic( } size_t ConvBiasImpl::AlgoConv1x1Gemv::get_workspace( - const NCBKernSizeParam& param) const { + const NCBKernSizeParam& param) const { MIDOUT_BEGIN(megdnn_fallback_conv1x1_gemv, midout_iv("AlgoConv1x1Gemv::get_workspace"_hash)) { size_t compt_oc_block_size = get_oc_tile_size_heuristic(param); @@ -335,7 +337,8 @@ ConvBiasImpl::AlgoConv1x1Gemv::dispatch_kerns( #else #if !MEGDNN_DISABLE_FLOAT16 cb1(param::ConvBias::Format::NCHW, dt_float16, dt_float16, - PostprocessMode::NO_PROCESS, "NCHW::GEMV::FLOAT16_FLOAT16"_hash); + PostprocessMode::NO_PROCESS, + "NCHW::GEMV::FLOAT16_FLOAT16"_hash); #endif #endif cb3(param::ConvBias::Format::NCHW, dt_int8, dt_int32, dt_int32, @@ -361,7 +364,7 @@ ConvBiasImpl::AlgoConv1x1Gemv::dispatch_kerns( dt_uint8, PostprocessMode::QUANTIZED, "NCHW::GEMV::QUINT8x8x32_QUINT8"_hash); break; - //!no support nchw44 8x8x16 + //! no support nchw44 8x8x16 case param::ConvBias::Format::NCHW44: cb1(param::ConvBias::Format::NCHW44, dt_float32, dt_float32, PostprocessMode::FLOAT, "NCHW44::GEMV::FLOAT"_hash); @@ -377,7 +380,7 @@ ConvBiasImpl::AlgoConv1x1Gemv::dispatch_kerns( dt_int8, PostprocessMode::QUANTIZED, "NCHW44::GEMV::QINT8x8x32_QINT8"_hash); break; - //!no support nchw44-dot 8x8x16 + //! no support nchw44-dot 8x8x16 case param::ConvBias::Format::NCHW44_DOT: cb3(param::ConvBias::Format::NCHW44_DOT, dt_int8, dt_int32, dt_int32, dt_int8, dt_int32, dt_int32, diff --git a/dnn/src/fallback/conv_bias/conv1x1/conv1x1_strategy.h b/dnn/src/fallback/conv_bias/conv1x1/conv1x1_strategy.h index eaf62551..54cbe6e2 100644 --- a/dnn/src/fallback/conv_bias/conv1x1/conv1x1_strategy.h +++ b/dnn/src/fallback/conv_bias/conv1x1/conv1x1_strategy.h @@ -19,6 +19,8 @@ #include "src/x86/conv_bias/postprocess_helper.h" #elif (MEGDNN_ARMV7 || MEGDNN_AARCH64) #include "src/arm_common/conv_bias/postprocess_helper.h" +#else +#include "src/common/postprocess_helper.h" #endif namespace megdnn { diff --git a/dnn/src/fallback/conv_bias/im2col/strategy_base.h b/dnn/src/fallback/conv_bias/im2col/strategy_base.h index 7c7279b2..816106b4 100644 --- a/dnn/src/fallback/conv_bias/im2col/strategy_base.h +++ b/dnn/src/fallback/conv_bias/im2col/strategy_base.h @@ -16,6 +16,8 @@ #include "src/x86/conv_bias/postprocess_helper.h" #elif (MEGDNN_ARMV7 || MEGDNN_AARCH64) #include "src/arm_common/conv_bias/postprocess_helper.h" +#else +#include "src/common/postprocess_helper.h" #endif using namespace megdnn; #if MEGDNN_X86 diff --git a/dnn/src/fallback/conv_bias/im2col/strategy_default_nchw44.cpp b/dnn/src/fallback/conv_bias/im2col/strategy_default_nchw44.cpp index ff4eab52..699d6661 100644 --- a/dnn/src/fallback/conv_bias/im2col/strategy_default_nchw44.cpp +++ b/dnn/src/fallback/conv_bias/im2col/strategy_default_nchw44.cpp @@ -12,10 +12,10 @@ #include "src/fallback/convolution/img2col_helper.h" #if MEGDNN_X86 #include "src/x86/conv_bias/postprocess_helper.h" -#endif - -#if (MEGDNN_ARMV7 || MEGDNN_AARCH64) +#elif (MEGDNN_ARMV7 || MEGDNN_AARCH64) #include "src/arm_common/conv_bias/postprocess_helper.h" +#else +#include "src/common/postprocess_helper.h" #endif using namespace megdnn; diff --git a/dnn/src/fallback/conv_bias/opr_impl.cpp b/dnn/src/fallback/conv_bias/opr_impl.cpp index 7f7b72f0..45b09bc4 100644 --- a/dnn/src/fallback/conv_bias/opr_impl.cpp +++ b/dnn/src/fallback/conv_bias/opr_impl.cpp @@ -74,6 +74,10 @@ public: } #endif +//! As we haven't riscv64 postprocess yet, im2col and conv1x1 can not pass ci +//! test. so we just disable all im2col and conv1x1 in riscv64 +//! FIXME: remove it when impl postprocess for riscv64 +#if !MEGDNN_RISCV64 for (size_t ohw_tile_size : {192, 384, 96, 48, 24}) { refhold.emplace_back(new AlgoIm2col( static_cast(algo), @@ -86,6 +90,8 @@ public: oc_tile_size)); all_algos.emplace_back(refhold.back().get()); } +#endif + #if 0 //! As these algos maybe very slow, it will make fastrun search slow, so //! we disable it, but for the test of strategyhelper, we just keep it. diff --git a/dnn/src/fallback/conv_bias/opr_impl.h b/dnn/src/fallback/conv_bias/opr_impl.h index b228b5ab..082bb2e5 100644 --- a/dnn/src/fallback/conv_bias/opr_impl.h +++ b/dnn/src/fallback/conv_bias/opr_impl.h @@ -50,6 +50,7 @@ public: _megdnn_tensor_in bias, _megdnn_tensor_in z, _megdnn_tensor_out dst, const PreprocessedFilter*, _megdnn_workspace workspace) override; + bool is_thread_safe() const override { return true; } void exec_preprocess(const TensorLayout& src_layout, _megdnn_tensor_in filter, diff --git a/dnn/test/common/mask_conv.h b/dnn/test/common/mask_conv.h index 4bd9c4bb..0ab37bd9 100644 --- a/dnn/test/common/mask_conv.h +++ b/dnn/test/common/mask_conv.h @@ -74,7 +74,7 @@ void mask_conv_test(Handle* handle) { arg[8], arg[9], arg[10], arg[11], arg[12]); } } - +#if MEGDNN_WITH_BENCHMARK void mask_conv_benchmark(Handle* handle) { auto benchmark = [&](size_t N, size_t IC, size_t OC, size_t IH, size_t IW, size_t FH, size_t FW, size_t SH, size_t SW, size_t PH, @@ -113,5 +113,6 @@ void mask_conv_benchmark(Handle* handle) { arg[7], arg[8], arg[9], arg[10], arg[11], arg[12]); } } +#endif } // namespace diff --git a/dnn/test/cpu/mask_conv.cpp b/dnn/test/cpu/mask_conv.cpp index 16e07c98..c215d515 100644 --- a/dnn/test/cpu/mask_conv.cpp +++ b/dnn/test/cpu/mask_conv.cpp @@ -25,9 +25,11 @@ TEST_F(CPU, MASK_CONV) { mask_conv_test(handle()); } +#if MEGDNN_WITH_BENCHMARK TEST_F(CPU, MASK_CONV_BENCHMARK) { mask_conv_benchmark(handle()); } +#endif TEST_F(CPU, MASK_PROPAGATE) { param::MaskPropagate mask_param; diff --git a/dnn/test/cpu/matrix_mul.cpp b/dnn/test/cpu/matrix_mul.cpp index 11d6db66..cdec9092 100644 --- a/dnn/test/cpu/matrix_mul.cpp +++ b/dnn/test/cpu/matrix_mul.cpp @@ -17,6 +17,7 @@ using namespace megdnn; using namespace test; +#if MEGDNN_WITH_BENCHMARK namespace { void sgemm_sgemv_like(const float* __restrict A, const float* __restrict B, @@ -70,6 +71,7 @@ TEST_F(CPU, BENCHMARK_MATRIX_MUL) { run(m, nk, nk); } } +#endif TEST_F(CPU, MATRIX_MUL) { matrix_mul::check_matrix_mul(dtype::Float32{}, dtype::Float32{}, diff --git a/dnn/test/cpu/relayout.cpp b/dnn/test/cpu/relayout.cpp index 6bc16c79..05087428 100644 --- a/dnn/test/cpu/relayout.cpp +++ b/dnn/test/cpu/relayout.cpp @@ -31,6 +31,7 @@ TYPED_TEST(CPU_RELAYOUT, run) { } } +#if MEGDNN_WITH_BENCHMARK TEST_F(CPU, BENCHMARK_RELAYOUT_CV) { relayout::run_cv_benchmark(handle()); } @@ -55,6 +56,6 @@ TEST_F(CPU, BENCHMARK_RELAYOUT) { ASSERT_LE(cpu_time * 5, naive_time); } } - +#endif // vim: syntax=cpp.doxygen diff --git a/dnn/test/cuda/mask_conv.cpp b/dnn/test/cuda/mask_conv.cpp index 38277ee0..468a7e80 100644 --- a/dnn/test/cuda/mask_conv.cpp +++ b/dnn/test/cuda/mask_conv.cpp @@ -22,10 +22,11 @@ using namespace test; TEST_F(CUDA, MASK_CONV) { mask_conv_test(handle_cuda()); } - +#if MEGDNN_WITH_BENCHMARK TEST_F(CUDA, MASK_CONV_BENCHMARK) { mask_conv_benchmark(handle_cuda()); } +#endif TEST_F(CUDA, MASK_PROPAGATE) { Checker checker(handle_cuda()); diff --git a/dnn/test/fallback/elemwise.cpp b/dnn/test/fallback/elemwise.cpp index 47f47ce7..a4bb00a9 100644 --- a/dnn/test/fallback/elemwise.cpp +++ b/dnn/test/fallback/elemwise.cpp @@ -27,7 +27,7 @@ TYPED_TEST_CASE(FALLBACK_ELEMWISE, elemwise::test_types); TYPED_TEST(FALLBACK_ELEMWISE, run) { elemwise::run_test(this->handle()); } - +#if MEGDNN_WITH_BENCHMARK TEST_F(FALLBACK, BENCHMARK_ELEMWISE) { auto naive_handle = create_cpu_handle(2); auto run = [&](const TensorShape &shp0, const TensorShape &shp1) { @@ -72,6 +72,7 @@ TEST_F(FALLBACK, BENCHMARK_ELEMWISE) { // non-contig, fallback to naive run({1024, 1024, 32}, {1024, 1, 32}); } +#endif // vim: syntax=cpp.doxygen diff --git a/dnn/test/fallback/elemwise_multi_type.cpp b/dnn/test/fallback/elemwise_multi_type.cpp index 0d6ac7b0..4fb18606 100644 --- a/dnn/test/fallback/elemwise_multi_type.cpp +++ b/dnn/test/fallback/elemwise_multi_type.cpp @@ -25,7 +25,7 @@ TYPED_TEST_CASE(FALLBACK_ELEMWISE_MULTI_TYPE, elemwise_multi_type::test_types); TYPED_TEST(FALLBACK_ELEMWISE_MULTI_TYPE, run) { elemwise_multi_type::run_test(this->handle()); } - +#if MEGDNN_WITH_BENCHMARK TEST_F(FALLBACK, ELEMWISE_MULTI_TYPE_BENCHMARK_FMA3_INT16x32x32x32) { Benchmarker bench{handle()}; bench.set_param({ElemwiseMultiType::Mode::FUSE_MUL_ADD3_INT16x32x32x32}); @@ -64,5 +64,5 @@ TEST_F(FALLBACK, ELEMWISE_MULTI_TYPE_BENCHMARK_FMA3_IXxf32xf32xI8) { (1024.0 * 1024.0 * 1024.0)); } } - +#endif // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} diff --git a/dnn/test/fallback/relayout.cpp b/dnn/test/fallback/relayout.cpp index c09683bf..ac0990a4 100644 --- a/dnn/test/fallback/relayout.cpp +++ b/dnn/test/fallback/relayout.cpp @@ -31,7 +31,7 @@ TYPED_TEST(FALLBACK_RELAYOUT, run) { relayout::run_test(this->handle()); } } - +#if MEGDNN_WITH_BENCHMARK TEST_F(FALLBACK, BENCHMARK_RELAYOUT_CV) { relayout::run_cv_benchmark(handle()); } @@ -160,5 +160,6 @@ TEST_F(FALLBACK, BENCHMARK_RELAYOUT) { } } } +#endif // vim: syntax=cpp.doxygen diff --git a/dnn/test/fallback/roi_copy.cpp b/dnn/test/fallback/roi_copy.cpp index 63055db7..e383d86c 100644 --- a/dnn/test/fallback/roi_copy.cpp +++ b/dnn/test/fallback/roi_copy.cpp @@ -34,7 +34,7 @@ TEST_F(FALLBACK, ROICOPY) { } } - +#if MEGDNN_WITH_BENCHMARK TEST_F(FALLBACK, BENCHMARK_ROICOPY) { auto run = [&](const TensorShapeArray& shapes) { Benchmarker benchmarker(handle()); @@ -62,6 +62,7 @@ TEST_F(FALLBACK, BENCHMARK_ROICOPY) { run(shapes); } +#endif } // namespace test diff --git a/toolchains/riscv64-linux-gnu.toolchain.cmake b/toolchains/riscv64-linux-gnu.toolchain.cmake new file mode 100644 index 00000000..d90ad0c4 --- /dev/null +++ b/toolchains/riscv64-linux-gnu.toolchain.cmake @@ -0,0 +1,18 @@ +set(CMAKE_SYSTEM_NAME Linux) +set(CMAKE_SYSTEM_PROCESSOR riscv64) +set(RISCV_CROSS_BUILD_ARCH riscv64) + +if(DEFINED ENV{RISCV_TOOLCHAIN_ROOT}) + file(TO_CMAKE_PATH $ENV{RISCV_TOOLCHAIN_ROOT} RISCV_TOOLCHAIN_ROOT) +else() + message(FATAL_ERROR "RISCV_TOOLCHAIN_ROOT env must be defined") +endif() + +set(RISCV_TOOLCHAIN_ROOT ${RISCV_TOOLCHAIN_ROOT} CACHE STRING "root path to riscv toolchain") + +set(CMAKE_C_COMPILER "${RISCV_TOOLCHAIN_ROOT}/bin/riscv64-unknown-linux-gnu-gcc") +set(CMAKE_CXX_COMPILER "${RISCV_TOOLCHAIN_ROOT}/bin/riscv64-unknown-linux-gnu-g++") +set(CMAKE_FIND_ROOT_PATH "${RISCV_TOOLCHAIN_ROOT}/riscv64-unknown-linux-gnu") +set(CMAKE_FIND_ROOT_PATH_MODE_PROGRAM NEVER) +set(CMAKE_FIND_ROOT_PATH_MODE_LIBRARY ONLY) +set(CMAKE_FIND_ROOT_PATH_MODE_INCLUDE ONLY) \ No newline at end of file -- GitLab