提交 2a2c1275 编写于 作者: M Megvii Engine Team

fix(dnn): correctly using MEGDNN_DISABLE_FLOAT16 directives

GitOrigin-RevId: c6b124f195c9fc3a830bb058797d7d5619aad72d
上级 4a863160
......@@ -15,6 +15,7 @@ using conv_fun = std::function<void(
const WorkspaceBundle& bundle, const ConvBiasImpl::NCBKernParam& kern_param,
const ConvBiasImpl::NCBKernIndex& ncb_index, const CpuNDRange& workspace_ids)>;
MIDOUT_DECL(megdnn_arm_common_conv_bias_fp16_nchw88)
MIDOUT_DECL(megdnn_arm_common_conv_bias_fp16_nchw88_stride1)
namespace {
static WorkspaceBundle get_bundle(const ConvBiasImpl::NCBKernSizeParam& param) {
......
......@@ -14,7 +14,7 @@ struct RoundingConverter<float> {
}
};
#ifndef MEGDNN_DISABLE_FLOAT16
#if !MEGDNN_DISABLE_FLOAT16
template <>
struct RoundingConverter<half_float::half> {
......@@ -32,7 +32,7 @@ struct RoundingConverter<half_bfloat16::bfloat16> {
}
};
#endif // #ifdef MEGDNN_DISABLE_FLOAT16
#endif // #if !MEGDNN_DISABLE_FLOAT16
template <>
struct RoundingConverter<int8_t> {
......
......@@ -295,7 +295,7 @@ void WarpPerspectiveForwardImpl::exec(
m_error_tracker, stream);
} else if (DNN_FLOAT16_SELECT(
src.layout.dtype == dtype::Float16(), false)) {
#ifndef MEGDNN_DISABLE_FLOAT16
#if !MEGDNN_DISABLE_FLOAT16
warp_perspective::forward_proxy(
is_nhwc, src.ptr<dt_float16>(), mat.ptr<dt_float32>(),
mat_idx.raw_ptr() ? mat_idx.ptr<int>() : nullptr,
......@@ -563,7 +563,7 @@ void WarpPerspectiveForwardImpl::exec(
m_error_tracker, stream);
} else if (DNN_FLOAT16_SELECT(
src.layout.dtype == dtype::Float16(), false)) {
#ifndef MEGDNN_DISABLE_FLOAT16
#if !MEGDNN_DISABLE_FLOAT16
SmallVector<size_t> workspace_sizes{sizeof(dt_float16*) * srcs.size()};
WorkspaceBundle workspace_cpu(nullptr, workspace_sizes);
auto total_workspace_size = workspace_cpu.total_size_in_bytes();
......
......@@ -1924,7 +1924,7 @@ void forward_proxy_nchw64(
cudaStream_t);
INST(float)
INST(uint8_t)
#ifndef MEGDNN_DISABLE_FLOAT16
#if !MEGDNN_DISABLE_FLOAT16
INST(dt_float16)
#endif
INST(int8_t)
......@@ -1936,7 +1936,7 @@ INST(int8_t)
int, int, int, ctype, BorderMode, megcore::AsyncErrorInfo*, void*, \
cudaStream_t);
INST(float)
#ifndef MEGDNN_DISABLE_FLOAT16
#if !MEGDNN_DISABLE_FLOAT16
INST(dt_float16)
#endif
#undef INST
......
......@@ -73,7 +73,7 @@ struct powci_general_even {
template <size_t size>
struct float_itype;
#ifndef MEGDNN_DISABLE_FLOAT16
#if !MEGDNN_DISABLE_FLOAT16
template <>
struct float_itype<2> {
using type = uint16_t;
......
......@@ -84,7 +84,7 @@ ResizeImpl::KernParam<ctype> ResizeImpl::KernParam<ctype>::from_tensors(
#define INST(_dtype) template struct ResizeImpl::KernParam<_dtype>;
INST(dt_float32);
#ifndef MEGDNN_DISABLE_FLOAT16
#if !MEGDNN_DISABLE_FLOAT16
INST(dt_float16);
#endif
INST(dt_int8);
......
......@@ -15,7 +15,7 @@ MegRay::DType mgb::opr::get_megray_dtype(megdnn::DType dtype) {
return MegRay::DType::MEGRAY_INT32;
case DTypeEnum::Float32:
return MegRay::DType::MEGRAY_FLOAT32;
#ifndef MEGDNN_DISABLE_FLOAT16
#if !MEGDNN_DISABLE_FLOAT16
case DTypeEnum::Float16:
return MegRay::DType::MEGRAY_FLOAT16;
#endif
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册