Commit ed7fa104 authored by Megvii Engine Team

feat(fallback): move direct multi_thread_common helper to fallback

GitOrigin-RevId: 27ed93e4c1d56d550c006a470bb4c95ee5ff2032
Parent: 8871ad74
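In short, the commit relocates the MultithreadDirectConvCommon helper from the arm_common backend into the shared fallback backend, so every call site switches its include path and namespace qualifier while the helper's API stays the same. A minimal before/after sketch of the call-site change, with the template arguments and function name taken from the hunks below and the surrounding code abbreviated:

    // before: the helper lived in the ARM backend
    #include "src/arm_common/conv_bias/direct/multi_thread_common.h"
    auto wbundle = arm_common::MultithreadDirectConvCommon<float, float>::
            get_bundle_stride(param, large_group);

    // after: the helper is shared through the fallback backend
    #include "src/fallback/conv_bias/direct/multi_thread_common.h"
    auto wbundle = fallback::MultithreadDirectConvCommon<float, float>::
            get_bundle_stride(param, large_group);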
@@ -12,8 +12,8 @@
 #include "src/aarch64/conv_bias/fp16/algos.h"
 #include "src/aarch64/conv_bias/fp16/stride2_kern.h"
-#include "src/arm_common/conv_bias/direct/multi_thread_common.h"
 #include "src/arm_common/conv_bias/postprocess_helper.h"
+#include "src/fallback/conv_bias/direct/multi_thread_common.h"

 using namespace megdnn;
 using namespace aarch64;
@@ -43,7 +43,7 @@ size_t ConvBiasImpl::AlgoF16DirectStride2::get_workspace(
         const NCBKernSizeParam& param) const {
     MIDOUT_BEGIN(megdnn_aarch64_conv_bias_stride2_conv2357_fp16, 0, 1) {
         bool large_group = param.filter_meta.group >= param.nr_threads;
-        auto wbundle = arm_common::MultithreadDirectConvCommon<
+        auto wbundle = fallback::MultithreadDirectConvCommon<
                 dt_float16, __fp16>::get_bundle_stride(param, large_group);
         return wbundle.total_size_in_bytes();
     }
@@ -83,7 +83,7 @@ SmallVector<ConvBiasImpl::NCBKern> ConvBiasImpl::AlgoF16DirectStride2::get_kimpls(
         conv = fp16::conv_stride2::do_conv_7x7_stride2;
     }
-    WorkspaceBundle bundle = arm_common::MultithreadDirectConvCommon<
+    WorkspaceBundle bundle = fallback::MultithreadDirectConvCommon<
             dt_float16, __fp16>::get_bundle_stride(param, large_group);
     SmallVector<NCBKern> ret_kerns;
@@ -98,13 +98,13 @@ SmallVector<ConvBiasImpl::NCBKern> ConvBiasImpl::AlgoF16DirectStride2::get_kimpls(
             size_t OC = fm.ocpg;
             bundle.set(kern_param.workspace_ptr);
             for (size_t ic = 0; ic < IC; ic++) {
-                arm_common::MultithreadDirectConvCommon<dt_float16, __fp16>::
+                fallback::MultithreadDirectConvCommon<dt_float16, __fp16>::
                         copy_padding_kern_stride(
                                 bundle, kern_param, ncb_index,
                                 {ncb_index.thread_id, 0, ic});
             }
             for (size_t oc = 0; oc < OC; oc++) {
-                arm_common::MultithreadDirectConvCommon<dt_float16, __fp16>::
+                fallback::MultithreadDirectConvCommon<dt_float16, __fp16>::
                         do_conv_kern_stride(
                                 bundle, kern_param, ncb_index, conv,
                                 {ncb_index.thread_id, 0, oc});
@@ -116,7 +116,7 @@ SmallVector<ConvBiasImpl::NCBKern> ConvBiasImpl::AlgoF16DirectStride2::get_kimpls(
                 const NCBKernParam& kern_param,
                 const NCBKernIndex& ncb_index) mutable {
             bundle.set(kern_param.workspace_ptr);
-            arm_common::MultithreadDirectConvCommon<dt_float16, __fp16>::
+            fallback::MultithreadDirectConvCommon<dt_float16, __fp16>::
                     copy_padding_kern_stride(
                             bundle, kern_param, ncb_index, ncb_index.ndrange_id);
         };
@@ -125,7 +125,7 @@ SmallVector<ConvBiasImpl::NCBKern> ConvBiasImpl::AlgoF16DirectStride2::get_kimpls(
                 const NCBKernParam& kern_param,
                 const NCBKernIndex& ncb_index) mutable {
             bundle.set(kern_param.workspace_ptr);
-            arm_common::MultithreadDirectConvCommon<dt_float16, __fp16>::
+            fallback::MultithreadDirectConvCommon<dt_float16, __fp16>::
                     do_conv_kern_stride(
                             bundle, kern_param, ncb_index, conv, ncb_index.ndrange_id);
         };
...
@@ -11,9 +11,9 @@
 #include "src/aarch64/conv_bias/fp32/algos.h"
 #include "src/aarch64/conv_bias/fp32/stride2_kern.h"
-#include "src/arm_common/conv_bias/direct/multi_thread_common.h"
 #include "src/arm_common/conv_bias/postprocess_helper.h"
 #include "src/fallback/conv_bias/common.h"
+#include "src/fallback/conv_bias/direct/multi_thread_common.h"

 #include "midout.h"
@@ -42,8 +42,9 @@ size_t ConvBiasImpl::AlgoF32DirectStride2::get_workspace(
         const NCBKernSizeParam& param) const {
     MIDOUT_BEGIN(megdnn_aarch64_conv_bias_stride2_conv2357_fp32, 0, 1) {
         bool large_group = param.filter_meta.group >= param.nr_threads;
-        auto wbundle = arm_common::MultithreadDirectConvCommon<
-                float, float>::get_bundle_stride(param, large_group);
+        auto wbundle =
+                fallback::MultithreadDirectConvCommon<float, float>::get_bundle_stride(
+                        param, large_group);
         return wbundle.total_size_in_bytes();
     }
     MIDOUT_END();
@@ -82,7 +83,7 @@ SmallVector<ConvBiasImpl::NCBKern> ConvBiasImpl::AlgoF32DirectStride2::get_kimpls(
     }
     WorkspaceBundle bundle =
-            arm_common::MultithreadDirectConvCommon<float, float>::get_bundle_stride(
+            fallback::MultithreadDirectConvCommon<float, float>::get_bundle_stride(
                     param, large_group);
     SmallVector<NCBKern> ret_kerns;
@@ -97,13 +98,13 @@ SmallVector<ConvBiasImpl::NCBKern> ConvBiasImpl::AlgoF32DirectStride2::get_kimpls(
             size_t OC = fm.ocpg;
             bundle.set(kern_param.workspace_ptr);
             for (size_t ic = 0; ic < IC; ic++) {
-                arm_common::MultithreadDirectConvCommon<float, float>::
+                fallback::MultithreadDirectConvCommon<float, float>::
                         copy_padding_kern_stride(
                                 bundle, kern_param, ncb_index,
                                 {ncb_index.thread_id, 0, ic});
             }
             for (size_t oc = 0; oc < OC; oc++) {
-                arm_common::MultithreadDirectConvCommon<float, float>::
+                fallback::MultithreadDirectConvCommon<float, float>::
                         do_conv_kern_stride(
                                 bundle, kern_param, ncb_index, conv,
                                 {ncb_index.thread_id, 0, oc});
@@ -115,7 +116,7 @@ SmallVector<ConvBiasImpl::NCBKern> ConvBiasImpl::AlgoF32DirectStride2::get_kimpls(
                 const NCBKernParam& kern_param,
                 const NCBKernIndex& ncb_index) mutable {
             bundle.set(kern_param.workspace_ptr);
-            arm_common::MultithreadDirectConvCommon<float, float>::
+            fallback::MultithreadDirectConvCommon<float, float>::
                     copy_padding_kern_stride(
                             bundle, kern_param, ncb_index, ncb_index.ndrange_id);
         };
@@ -124,7 +125,7 @@ SmallVector<ConvBiasImpl::NCBKern> ConvBiasImpl::AlgoF32DirectStride2::get_kimpls(
                 const NCBKernParam& kern_param,
                 const NCBKernIndex& ncb_index) mutable {
             bundle.set(kern_param.workspace_ptr);
-            arm_common::MultithreadDirectConvCommon<float, float>::do_conv_kern_stride(
+            fallback::MultithreadDirectConvCommon<float, float>::do_conv_kern_stride(
                     bundle, kern_param, ncb_index, conv, ncb_index.ndrange_id);
         };
         ret_kerns.push_back({do_conv, {group, N, OC}});
...
@@ -10,7 +10,6 @@
  */
 #include "src/arm_common/conv_bias/f16/algos.h"
-#include "src/arm_common/conv_bias/direct/multi_thread_common.h"
 #include "src/arm_common/conv_bias/f16/direct.h"
 #include "src/arm_common/conv_bias/f16/do_conv_stride1.h"
 #include "src/arm_common/conv_bias/f16/strategy.h"
@@ -18,6 +17,7 @@
 #include "src/arm_common/conv_bias/postprocess_helper.h"
 #include "src/common/opr_delegate.h"
 #include "src/fallback/conv_bias/common.h"
+#include "src/fallback/conv_bias/direct/multi_thread_common.h"
 #if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
 #include "midout.h"
 MIDOUT_DECL(megdnn_arm_common_winograd_fp16)
@@ -187,8 +187,9 @@ bool ConvBiasImpl::AlgoF16Direct::usable(
 size_t ConvBiasImpl::AlgoF16Direct::get_workspace(const NCBKernSizeParam& param) const {
     MIDOUT_BEGIN(megdnn_arm_common_conv_bias_fp16_kimpl, 0, 1) {
         bool large_group = param.filter_meta.group >= param.nr_threads;
-        auto wbundle = MultithreadDirectConvCommon<dt_float16, __fp16>::get_bundle(
-                param, large_group);
+        auto wbundle =
+                fallback::MultithreadDirectConvCommon<dt_float16, __fp16>::get_bundle(
+                        param, large_group);
         return wbundle.total_size_in_bytes();
     }
     MIDOUT_END();
@@ -204,7 +205,7 @@ SmallVector<ConvBiasImpl::NCBKern> ConvBiasImpl::AlgoF16Direct::get_kimpls(
     size_t group = fm.group;
     bool large_group = group >= param.nr_threads;
     WorkspaceBundle bundle =
-            MultithreadDirectConvCommon<dt_float16, __fp16>::get_bundle(
+            fallback::MultithreadDirectConvCommon<dt_float16, __fp16>::get_bundle(
                     param, large_group);
     SmallVector<NCBKern> ret_kerns;
     //! When group >= nr_threads, treat it as large_group, each thread process
@@ -220,17 +221,20 @@ SmallVector<ConvBiasImpl::NCBKern> ConvBiasImpl::AlgoF16Direct::get_kimpls(
             bundle.set(kern_param.workspace_ptr);
             if (fm.should_flip) {
                 for (size_t oc = 0; oc < OC; oc++) {
-                    MultithreadDirectConvCommon<dt_float16, __fp16>::weight_flip_kern(
-                            bundle, kern_param, ncb_index,
-                            {ncb_index.thread_id, 0, oc});
+                    fallback::MultithreadDirectConvCommon<dt_float16, __fp16>::
+                            weight_flip_kern(
+                                    bundle, kern_param, ncb_index,
+                                    {ncb_index.thread_id, 0, oc});
                 }
             }
             for (size_t ic = 0; ic < IC; ic++) {
-                MultithreadDirectConvCommon<dt_float16, __fp16>::copy_padding_kern(
-                        bundle, kern_param, ncb_index, {ncb_index.thread_id, 0, ic});
+                fallback::MultithreadDirectConvCommon<dt_float16, __fp16>::
+                        copy_padding_kern(
+                                bundle, kern_param, ncb_index,
+                                {ncb_index.thread_id, 0, ic});
             }
             for (size_t oc = 0; oc < OC; oc++) {
-                MultithreadDirectConvCommon<dt_float16, __fp16>::do_conv_kern(
+                fallback::MultithreadDirectConvCommon<dt_float16, __fp16>::do_conv_kern(
                         bundle, kern_param, ncb_index, fp16::conv_bias::kern_direct_f16,
                         {ncb_index.thread_id, 0, oc});
             }
@@ -242,8 +246,9 @@ SmallVector<ConvBiasImpl::NCBKern> ConvBiasImpl::AlgoF16Direct::get_kimpls(
                 const NCBKernParam& kern_param,
                 const NCBKernIndex& ncb_index) mutable {
             bundle.set(kern_param.workspace_ptr);
-            MultithreadDirectConvCommon<dt_float16, __fp16>::weight_flip_kern(
-                    bundle, kern_param, ncb_index, ncb_index.ndrange_id);
+            fallback::MultithreadDirectConvCommon<dt_float16, __fp16>::
+                    weight_flip_kern(
+                            bundle, kern_param, ncb_index, ncb_index.ndrange_id);
         };
         ret_kerns.push_back({weight_flip, {group, 1_z, OC}});
     }
@@ -251,15 +256,16 @@ SmallVector<ConvBiasImpl::NCBKern> ConvBiasImpl::AlgoF16Direct::get_kimpls(
                 const NCBKernParam& kern_param,
                 const NCBKernIndex& ncb_index) mutable {
             bundle.set(kern_param.workspace_ptr);
-            MultithreadDirectConvCommon<dt_float16, __fp16>::copy_padding_kern(
-                    bundle, kern_param, ncb_index, ncb_index.ndrange_id);
+            fallback::MultithreadDirectConvCommon<dt_float16, __fp16>::
+                    copy_padding_kern(
+                            bundle, kern_param, ncb_index, ncb_index.ndrange_id);
         };
         ret_kerns.push_back({copy_padding, {group, N, IC}});
         auto do_conv = [bundle](
                 const NCBKernParam& kern_param,
                 const NCBKernIndex& ncb_index) mutable {
             bundle.set(kern_param.workspace_ptr);
-            MultithreadDirectConvCommon<dt_float16, __fp16>::do_conv_kern(
+            fallback::MultithreadDirectConvCommon<dt_float16, __fp16>::do_conv_kern(
                     bundle, kern_param, ncb_index, fp16::conv_bias::kern_direct_f16,
                     ncb_index.ndrange_id);
         };
@@ -324,9 +330,8 @@ SmallVector<ConvBiasImpl::NCBKern> ConvBiasImpl::AlgoF16DirectStride1::get_kimpls(
     }
     SWITCH_KERN();
-    WorkspaceBundle bundle =
-            MultithreadDirectConvCommon<dt_float16, __fp16>::get_bundle_stride(
-                    param, large_group);
+    WorkspaceBundle bundle = fallback::MultithreadDirectConvCommon<
+            dt_float16, __fp16>::get_bundle_stride(param, large_group);
     SmallVector<NCBKern> ret_kerns;
     //! When group >= nr_threads, treat it as large_group, each thread process
     //! one group for better performance
@@ -340,15 +345,16 @@ SmallVector<ConvBiasImpl::NCBKern> ConvBiasImpl::AlgoF16DirectStride1::get_kimpls(
             size_t OC = fm.ocpg;
             bundle.set(kern_param.workspace_ptr);
             for (size_t ic = 0; ic < IC; ic++) {
-                MultithreadDirectConvCommon<dt_float16, __fp16>::
+                fallback::MultithreadDirectConvCommon<dt_float16, __fp16>::
                         copy_padding_kern_stride(
                                 bundle, kern_param, ncb_index,
                                 {ncb_index.thread_id, 0, ic});
             }
             for (size_t oc = 0; oc < OC; oc++) {
-                MultithreadDirectConvCommon<dt_float16, __fp16>::do_conv_kern_stride(
-                        bundle, kern_param, ncb_index, conv_kern_function,
-                        {ncb_index.thread_id, 0, oc});
+                fallback::MultithreadDirectConvCommon<dt_float16, __fp16>::
+                        do_conv_kern_stride(
+                                bundle, kern_param, ncb_index, conv_kern_function,
+                                {ncb_index.thread_id, 0, oc});
             }
         };
         ret_kerns.push_back({exec_one_group, {group, N, 1_z}});
@@ -357,17 +363,19 @@ SmallVector<ConvBiasImpl::NCBKern> ConvBiasImpl::AlgoF16DirectStride1::get_kimpls(
                 const NCBKernParam& kern_param,
                 const NCBKernIndex& ncb_index) mutable {
             bundle.set(kern_param.workspace_ptr);
-            MultithreadDirectConvCommon<dt_float16, __fp16>::copy_padding_kern_stride(
-                    bundle, kern_param, ncb_index, ncb_index.ndrange_id);
+            fallback::MultithreadDirectConvCommon<dt_float16, __fp16>::
+                    copy_padding_kern_stride(
+                            bundle, kern_param, ncb_index, ncb_index.ndrange_id);
         };
         ret_kerns.push_back({copy_padding, {group, N, IC}});
         auto do_conv = [bundle, conv_kern_function](
                 const NCBKernParam& kern_param,
                 const NCBKernIndex& ncb_index) mutable {
             bundle.set(kern_param.workspace_ptr);
-            MultithreadDirectConvCommon<dt_float16, __fp16>::do_conv_kern_stride(
-                    bundle, kern_param, ncb_index, conv_kern_function,
-                    ncb_index.ndrange_id);
+            fallback::MultithreadDirectConvCommon<dt_float16, __fp16>::
+                    do_conv_kern_stride(
+                            bundle, kern_param, ncb_index, conv_kern_function,
+                            ncb_index.ndrange_id);
         };
         ret_kerns.push_back({do_conv, {group, N, OC}});
     }
@@ -378,9 +386,8 @@ size_t ConvBiasImpl::AlgoF16DirectStride1::get_workspace(
         const NCBKernSizeParam& param) const {
     MIDOUT_BEGIN(megdnn_arm_common_conv_bias_fp16_kimpl, 1, 1) {
         bool large_group = param.filter_meta.group >= param.nr_threads;
-        auto bundle =
-                MultithreadDirectConvCommon<dt_float16, __fp16>::get_bundle_stride(
-                        param, large_group);
+        auto bundle = fallback::MultithreadDirectConvCommon<
+                dt_float16, __fp16>::get_bundle_stride(param, large_group);
         return bundle.total_size_in_bytes();
     }
     MIDOUT_END();
...
@@ -11,7 +11,6 @@
  */
 #include "src/arm_common/conv_bias/fp32/algos.h"
-#include "src/arm_common/conv_bias/direct/multi_thread_common.h"
 #include "src/arm_common/conv_bias/fp32/direct.h"
 #include "src/arm_common/conv_bias/fp32/do_conv_stride1.h"
 #include "src/arm_common/conv_bias/fp32/do_conv_stride2.h"
@@ -20,6 +19,7 @@
 #include "src/arm_common/conv_bias/postprocess_helper.h"
 #include "src/common/opr_delegate.h"
 #include "src/fallback/conv_bias/common.h"
+#include "src/fallback/conv_bias/direct/multi_thread_common.h"

 #include "midout.h"
@@ -343,7 +343,7 @@ bool ConvBiasImpl::AlgoF32Direct::usable(
 size_t ConvBiasImpl::AlgoF32Direct::get_workspace(const NCBKernSizeParam& param) const {
     MIDOUT_BEGIN(megdnn_arm_common_conv_bias_f32_kimpl, 0, 1) {
         bool large_group = param.filter_meta.group >= param.nr_threads;
-        auto wbundle = MultithreadDirectConvCommon<float, float>::get_bundle(
+        auto wbundle = fallback::MultithreadDirectConvCommon<float, float>::get_bundle(
                 param, large_group);
         return wbundle.total_size_in_bytes();
     }
@@ -359,7 +359,8 @@ SmallVector<ConvBiasImpl::NCBKern> ConvBiasImpl::AlgoF32Direct::get_kimpls(
     size_t group = fm.group;
     bool large_group = group >= param.nr_threads;
     WorkspaceBundle bundle =
-            MultithreadDirectConvCommon<float, float>::get_bundle(param, large_group);
+            fallback::MultithreadDirectConvCommon<float, float>::get_bundle(
+                    param, large_group);
     SmallVector<NCBKern> ret_kerns;
     //! When group >= nr_threads, treat it as large_group, each thread process
     //! one group for better performance
@@ -374,17 +375,18 @@ SmallVector<ConvBiasImpl::NCBKern> ConvBiasImpl::AlgoF32Direct::get_kimpls(
             bundle.set(kern_param.workspace_ptr);
             if (fm.should_flip) {
                 for (size_t oc = 0; oc < OC; oc++) {
-                    MultithreadDirectConvCommon<float, float>::weight_flip_kern(
-                            bundle, kern_param, ncb_index,
-                            {ncb_index.thread_id, 0, oc});
+                    fallback::MultithreadDirectConvCommon<float, float>::
+                            weight_flip_kern(
+                                    bundle, kern_param, ncb_index,
+                                    {ncb_index.thread_id, 0, oc});
                 }
             }
             for (size_t ic = 0; ic < IC; ic++) {
-                MultithreadDirectConvCommon<float, float>::copy_padding_kern(
+                fallback::MultithreadDirectConvCommon<float, float>::copy_padding_kern(
                         bundle, kern_param, ncb_index, {ncb_index.thread_id, 0, ic});
             }
             for (size_t oc = 0; oc < OC; oc++) {
-                MultithreadDirectConvCommon<float, float>::do_conv_kern(
+                fallback::MultithreadDirectConvCommon<float, float>::do_conv_kern(
                         bundle, kern_param, ncb_index, fp32::conv_bias::kern_direct,
                         {ncb_index.thread_id, 0, oc});
             }
@@ -396,7 +398,7 @@ SmallVector<ConvBiasImpl::NCBKern> ConvBiasImpl::AlgoF32Direct::get_kimpls(
                 const NCBKernParam& kern_param,
                 const NCBKernIndex& ncb_index) mutable {
             bundle.set(kern_param.workspace_ptr);
-            MultithreadDirectConvCommon<float, float>::weight_flip_kern(
+            fallback::MultithreadDirectConvCommon<float, float>::weight_flip_kern(
                     bundle, kern_param, ncb_index, ncb_index.ndrange_id);
         };
         ret_kerns.push_back({weight_flip, {group, 1_z, OC}});
@@ -405,7 +407,7 @@ SmallVector<ConvBiasImpl::NCBKern> ConvBiasImpl::AlgoF32Direct::get_kimpls(
                 const NCBKernParam& kern_param,
                 const NCBKernIndex& ncb_index) mutable {
             bundle.set(kern_param.workspace_ptr);
-            MultithreadDirectConvCommon<float, float>::copy_padding_kern(
+            fallback::MultithreadDirectConvCommon<float, float>::copy_padding_kern(
                     bundle, kern_param, ncb_index, ncb_index.ndrange_id);
         };
         ret_kerns.push_back({copy_padding, {group, N, IC}});
@@ -413,7 +415,7 @@ SmallVector<ConvBiasImpl::NCBKern> ConvBiasImpl::AlgoF32Direct::get_kimpls(
                 const NCBKernParam& kern_param,
                 const NCBKernIndex& ncb_index) mutable {
             bundle.set(kern_param.workspace_ptr);
-            MultithreadDirectConvCommon<float, float>::do_conv_kern(
+            fallback::MultithreadDirectConvCommon<float, float>::do_conv_kern(
                     bundle, kern_param, ncb_index, fp32::conv_bias::kern_direct,
                     ncb_index.ndrange_id);
         };
@@ -452,8 +454,9 @@ size_t ConvBiasImpl::AlgoF32DirectStride1::get_workspace(
         const NCBKernSizeParam& param) const {
     MIDOUT_BEGIN(megdnn_arm_common_conv_bias_f32_kimpl, 1, 1) {
         bool large_group = param.filter_meta.group >= param.nr_threads;
-        auto bundle = MultithreadDirectConvCommon<float, float>::get_bundle_stride(
-                param, large_group);
+        auto bundle =
+                fallback::MultithreadDirectConvCommon<float, float>::get_bundle_stride(
+                        param, large_group);
         return bundle.total_size_in_bytes();
     }
     MIDOUT_END();
@@ -492,7 +495,7 @@ SmallVector<ConvBiasImpl::NCBKern> ConvBiasImpl::AlgoF32DirectStride1::get_kimpls(
     SWITCH_KERN_STR1();
     WorkspaceBundle bundle =
-            MultithreadDirectConvCommon<float, float>::get_bundle_stride(
+            fallback::MultithreadDirectConvCommon<float, float>::get_bundle_stride(
                     param, large_group);
     SmallVector<NCBKern> ret_kerns;
     //! When group >= nr_threads, treat it as large_group, each thread process
@@ -507,13 +510,16 @@ SmallVector<ConvBiasImpl::NCBKern> ConvBiasImpl::AlgoF32DirectStride1::get_kimpls(
             size_t OC = fm.ocpg;
             bundle.set(kern_param.workspace_ptr);
             for (size_t ic = 0; ic < IC; ic++) {
-                MultithreadDirectConvCommon<float, float>::copy_padding_kern_stride(
-                        bundle, kern_param, ncb_index, {ncb_index.thread_id, 0, ic});
+                fallback::MultithreadDirectConvCommon<float, float>::
+                        copy_padding_kern_stride(
+                                bundle, kern_param, ncb_index,
+                                {ncb_index.thread_id, 0, ic});
             }
             for (size_t oc = 0; oc < OC; oc++) {
-                MultithreadDirectConvCommon<float, float>::do_conv_kern_stride(
-                        bundle, kern_param, ncb_index, conv_kern_function,
-                        {ncb_index.thread_id, 0, oc});
+                fallback::MultithreadDirectConvCommon<float, float>::
+                        do_conv_kern_stride(
+                                bundle, kern_param, ncb_index, conv_kern_function,
+                                {ncb_index.thread_id, 0, oc});
             }
         };
         ret_kerns.push_back({exec_one_group, {group, N, 1_z}});
@@ -522,15 +528,16 @@ SmallVector<ConvBiasImpl::NCBKern> ConvBiasImpl::AlgoF32DirectStride1::get_kimpls(
                 const NCBKernParam& kern_param,
                 const NCBKernIndex& ncb_index) mutable {
             bundle.set(kern_param.workspace_ptr);
-            MultithreadDirectConvCommon<float, float>::copy_padding_kern_stride(
-                    bundle, kern_param, ncb_index, ncb_index.ndrange_id);
+            fallback::MultithreadDirectConvCommon<float, float>::
+                    copy_padding_kern_stride(
+                            bundle, kern_param, ncb_index, ncb_index.ndrange_id);
         };
         ret_kerns.push_back({copy_padding, {group, N, IC}});
         auto do_conv = [bundle, conv_kern_function](
                 const NCBKernParam& kern_param,
                 const NCBKernIndex& ncb_index) mutable {
             bundle.set(kern_param.workspace_ptr);
-            MultithreadDirectConvCommon<float, float>::do_conv_kern_stride(
+            fallback::MultithreadDirectConvCommon<float, float>::do_conv_kern_stride(
                     bundle, kern_param, ncb_index, conv_kern_function,
                     ncb_index.ndrange_id);
         };
@@ -570,8 +577,9 @@ size_t ConvBiasImpl::AlgoF32DirectStride2::get_workspace(
         const NCBKernSizeParam& param) const {
     MIDOUT_BEGIN(megdnn_arm_common_conv_bias_f32_kimpl, 2, 1) {
         bool large_group = param.filter_meta.group >= param.nr_threads;
-        auto bundle = MultithreadDirectConvCommon<float, float>::get_bundle_stride(
-                param, large_group);
+        auto bundle =
+                fallback::MultithreadDirectConvCommon<float, float>::get_bundle_stride(
+                        param, large_group);
         return bundle.total_size_in_bytes();
     }
     MIDOUT_END();
@@ -609,7 +617,7 @@ SmallVector<ConvBiasImpl::NCBKern> ConvBiasImpl::AlgoF32DirectStride2::get_kimpls(
     SWITCH_KERN_STR2();
     WorkspaceBundle bundle =
-            MultithreadDirectConvCommon<float, float>::get_bundle_stride(
+            fallback::MultithreadDirectConvCommon<float, float>::get_bundle_stride(
                    param, large_group);
     SmallVector<NCBKern> ret_kerns;
     //! When group >= nr_threads, treat it as large_group, each thread process
@@ -624,13 +632,16 @@ SmallVector<ConvBiasImpl::NCBKern> ConvBiasImpl::AlgoF32DirectStride2::get_kimpls(
             size_t OC = fm.ocpg;
             bundle.set(kern_param.workspace_ptr);
             for (size_t ic = 0; ic < IC; ic++) {
-                MultithreadDirectConvCommon<float, float>::copy_padding_kern_stride(
-                        bundle, kern_param, ncb_index, {ncb_index.thread_id, 0, ic});
+                fallback::MultithreadDirectConvCommon<float, float>::
+                        copy_padding_kern_stride(
+                                bundle, kern_param, ncb_index,
+                                {ncb_index.thread_id, 0, ic});
             }
             for (size_t oc = 0; oc < OC; oc++) {
-                MultithreadDirectConvCommon<float, float>::do_conv_kern_stride(
-                        bundle, kern_param, ncb_index, conv_kern_function,
-                        {ncb_index.thread_id, 0, oc});
+                fallback::MultithreadDirectConvCommon<float, float>::
+                        do_conv_kern_stride(
+                                bundle, kern_param, ncb_index, conv_kern_function,
+                                {ncb_index.thread_id, 0, oc});
             }
         };
         ret_kerns.push_back({exec_one_group, {group, N, 1_z}});
@@ -639,15 +650,16 @@ SmallVector<ConvBiasImpl::NCBKern> ConvBiasImpl::AlgoF32DirectStride2::get_kimpls(
                 const NCBKernParam& kern_param,
                 const NCBKernIndex& ncb_index) mutable {
             bundle.set(kern_param.workspace_ptr);
-            MultithreadDirectConvCommon<float, float>::copy_padding_kern_stride(
-                    bundle, kern_param, ncb_index, ncb_index.ndrange_id);
+            fallback::MultithreadDirectConvCommon<float, float>::
+                    copy_padding_kern_stride(
+                            bundle, kern_param, ncb_index, ncb_index.ndrange_id);
         };
         ret_kerns.push_back({copy_padding, {group, N, IC}});
         auto do_conv = [bundle, conv_kern_function](
                 const NCBKernParam& kern_param,
                 const NCBKernIndex& ncb_index) mutable {
             bundle.set(kern_param.workspace_ptr);
-            MultithreadDirectConvCommon<float, float>::do_conv_kern_stride(
+            fallback::MultithreadDirectConvCommon<float, float>::do_conv_kern_stride(
                     bundle, kern_param, ncb_index, conv_kern_function,
                     ncb_index.ndrange_id);
         };
...
 /**
- * \file dnn/src/arm_common/conv_bias/direct/multi_thread_common.cpp
+ * \file dnn/src/fallback/conv_bias/direct/multi_thread_common.cpp
  * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  *
  * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
@@ -9,12 +9,14 @@
  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  */
-#include "src/arm_common/conv_bias/direct/multi_thread_common.h"
-#include "src/arm_common/conv_bias/postprocess_helper.h"
+#include "multi_thread_common.h"
 #include "src/fallback/matrix_mul/opr_impl.h"

 using namespace megdnn;
-using namespace arm_common;
+using namespace fallback;
+#if MEGDNN_X86
+using namespace x86;
+#endif

 namespace {
 bool need_dst_copy(const megdnn::fallback::ConvBiasImpl::NCBKernSizeParam& param) {
@@ -354,8 +356,8 @@ void MultithreadDirectConvCommon<io_ctype, compute_ctype>::do_conv_kern_stride(
             kern_param.nonlineMode, kern_param.bias_type, kern_param.dst_type, 1_z, 1_z,
             OH, OW);
 };
-template class megdnn::arm_common::MultithreadDirectConvCommon<float, float>;
+template class megdnn::fallback::MultithreadDirectConvCommon<float, float>;
 #if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
-template class megdnn::arm_common::MultithreadDirectConvCommon<dt_float16, __fp16>;
+template class megdnn::fallback::MultithreadDirectConvCommon<dt_float16, __fp16>;
 #endif
 // vim: syntax=cpp.doxygen
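The explicit instantiations kept at the bottom of this .cpp matter because the member definitions of the class template live in this translation unit: without the `template class megdnn::fallback::MultithreadDirectConvCommon<...>;` lines, other files that only include the header would fail to link for those type pairs. A minimal standalone illustration of the same pattern, with hypothetical names that are not part of the MegEngine code:

    // helper.h -- declaration only; definition lives in helper.cpp
    template <class io_ctype, class compute_ctype>
    struct Helper {
        static compute_ctype widen(io_ctype v);
    };

    // helper.cpp -- define the member, then emit code for the pairs callers use,
    // mirroring the two "template class ..." lines in the diff above
    template <class io_ctype, class compute_ctype>
    compute_ctype Helper<io_ctype, compute_ctype>::widen(io_ctype v) {
        return static_cast<compute_ctype>(v);
    }
    template struct Helper<float, float>;
    template struct Helper<short, float>;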
 /**
- * \file dnn/src/arm_common/conv_bias/direct/multi_thread_common.h
+ * \file dnn/src/fallback/conv_bias/direct/multi_thread_common.h
  * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  *
  * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
@@ -10,11 +10,20 @@
  */
 #pragma once

-#include "src/arm_common/conv_bias/opr_impl.h"
+#include "src/fallback/conv_bias/opr_impl.h"
 #include "src/fallback/matrix_mul/opr_impl.h"

+#if MEGDNN_X86
+#include "src/x86/conv_bias/postprocess_helper.h"
+#elif (MEGDNN_ARMV7 || MEGDNN_AARCH64)
+#include "src/arm_common/conv_bias/postprocess_helper.h"
+#else
+//! TODO: optimize common postprocess_helper with general intrinsic
+#include "src/common/postprocess_helper.h"
+#endif
+
 namespace megdnn {
-namespace arm_common {
+namespace fallback {

 template <class io_ctype, class compute_ctype>
 class MultithreadDirectConvCommon {
@@ -53,7 +62,7 @@ public:
             const CpuNDRange& workspace_ids);
 };
-}  // namespace arm_common
+}  // namespace fallback
 }  // namespace megdnn

 // vim: syntax=cpp.doxygen
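Because the relocated header is now shared by every CPU backend, it can no longer hard-code the arm_common post-process helper; the new preprocessor block above picks the x86, ARM, or generic helper once, at compile time. A standalone sketch of the same selection technique, with hypothetical macro, header, and function names that are not the MegEngine API:

    // platform_dispatch_sketch.h -- illustrative only, not part of the commit
    #pragma once
    #if defined(BUILD_X86)
    #include "x86_postprocess.h"      // would provide postprocess() tuned for SSE/AVX
    #elif defined(BUILD_ARM)
    #include "arm_postprocess.h"      // would provide postprocess() tuned for NEON
    #else
    #include "generic_postprocess.h"  // portable reference implementation
    #endif
    // Shared template code calls postprocess() without knowing the platform;
    // the right definition is chosen by whichever header the build selects.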
@@ -42,7 +42,7 @@ GI_FLOAT32_t GiReintInt32ToFloat32(GI_INT32_t Vector) {
 #elif defined(GI_SSE2_INTRINSICS)
     return _mm_castsi128_ps(Vector);
 #else
-    return (GI_FLOAT32_t)In;
+    return (GI_FLOAT32_t)Vector;
 #endif
 }
@@ -53,7 +53,7 @@ GI_FLOAT32_t GiReintUint32ToFloat32(GI_UINT32_t Vector) {
 #elif defined(GI_SSE2_INTRINSICS)
     return _mm_castsi128_ps(Vector);
 #else
-    return (GI_FLOAT32_t)In;
+    return (GI_FLOAT32_t)Vector;
 #endif
 }
...
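The two hunks above are a straight bug fix: the generic branch (neither NEON nor SSE2) cast an identifier `In` that does not exist in either function, while the parameter is named `Vector`, so the branch could not compile whenever the generic path was selected. A hedged reconstruction of one corrected function for context; the NEON branch is not shown in the hunk and is assumed here:

    GI_FLOAT32_t GiReintInt32ToFloat32(GI_INT32_t Vector) {
    #if defined(GI_NEON_INTRINSICS)
        return vreinterpretq_f32_s32(Vector);  // assumed NEON branch, not in the hunk
    #elif defined(GI_SSE2_INTRINSICS)
        return _mm_castsi128_ps(Vector);
    #else
        return (GI_FLOAT32_t)Vector;           // fixed: previously cast the undefined "In"
    #endif
    }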
 /**
- * \file dnn/src/fallback/general_intrinsic/gi_float.h
+ * \file dnn/src/fallback/general_intrinsic/gi_int.h
  * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  *
  * Copyright (c) 2014-2022 Megvii Inc. All rights reserved.
...