提交 14b65e4d 编写于 作者: M Megvii Engine Team

feat(dnn/cuda): add reduce_filter_and_update_bias

GitOrigin-RevId: 31b6e6b0abe2790029e63c9f91c64290a1801958
上级 2d4e62ef
......@@ -15,7 +15,7 @@
#include "./quint4x4x32_wmma/activation_u4.cuh"
#include "./quint4x4x32_wmma/reduce_with_scale_data.cuh"
#include "./reduce_with_scale_filter.cuh"
#include "./reduce_filter.cuh"
#include "./quint4x4x32_wmma/wmma_conv_integer_u4.cuh"
using namespace megdnn;
......
......@@ -25,7 +25,7 @@
*
**************************************************************************************************/
/**
* \file dnn/src/cuda/conv_bias/reduce_with_scale_filter.cu
* \file dnn/src/cuda/conv_bias/reduce_filter.cu
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
......@@ -36,9 +36,11 @@
* implied.
*/
#include "./reduce_with_scale_filter.cuh"
#include "src/cuda/reduce_helper.cuh"
#include "./reduce_filter.cuh"
#include "src/cuda/utils.cuh"
#include "src/cuda/integer_subbyte_utils.cuh"
#include "src/cuda/reduce_helper.cuh"
using namespace megdnn;
using namespace cuda;
......@@ -76,6 +78,38 @@ struct ReduceWithScaleInt4Op {
#endif
};
template <bool signedness>
struct ReduceUpdateBiasInt4Op {
typedef int32_t wtype;
const uint8_t* filter;
const int32_t* src_bias;
int32_t* dst_bias;
int32_t zero_point;
static const wtype INIT = 0;
#if MEGDNN_CC_CUDA
__host__ __device__ void write(uint32_t idx, wtype val) {
dst_bias[idx] = src_bias[idx] - val * zero_point;
}
__host__ __device__ static wtype apply(wtype a, wtype b) { return a + b; }
__device__ wtype read(uint32_t idx) {
constexpr uint32_t subbytes_per_pixel = 8;
const uint32_t* fptr =
(const uint32_t*)(filter + subbytes_per_pixel * idx / 2);
uint32_t val = *fptr;
int32_t ret = 0;
#pragma unroll
for (int j = 0; j < 8; j++) {
ret += integer_subbyte::unpack_integer_4bits<signedness>(val,
(j << 2));
}
return ret;
}
#endif
};
} // namespace
template <bool signedness>
......@@ -106,6 +140,31 @@ INST(false);
INST(true);
#undef INST
template <bool signedness>
void megdnn::cuda::do_dispatch_reduce_filter_and_update_bias_4bit(
const uint8_t* filter, const int32_t* src_bias, uint32_t rows,
uint32_t cols, int32_t* dst_bias, int32_t* workspace,
int32_t zero_point, cudaStream_t stream) {
ReduceUpdateBiasInt4Op<signedness> op;
op.filter = filter;
op.src_bias = src_bias;
op.dst_bias = dst_bias;
op.zero_point = zero_point;
run_reduce<ReduceUpdateBiasInt4Op<signedness>, false>(workspace, rows, cols,
1, stream, op);
}
#define INST(signedness) \
template void \
megdnn::cuda::do_dispatch_reduce_filter_and_update_bias_4bit<signedness>( \
const uint8_t* filter, const int32_t* src_bias, uint32_t rows, \
uint32_t cols, int32_t* dst_bias, int32_t* workspace, \
int32_t zero_point, cudaStream_t stream)
INST(false);
INST(true);
#undef INST
size_t megdnn::cuda::do_dispatch_reduce_workspace_in_bytes(size_t A, size_t B,
size_t C) {
return get_reduce_workspace_in_bytes<ReduceWithScaleInt4Op<false>>(A, B, C);
......
......@@ -25,7 +25,7 @@
*
**************************************************************************************************/
/**
* \file dnn/src/cuda/conv_bias/reduce_with_scale_filter.cuh
* \file dnn/src/cuda/conv_bias/reduce_filter.cuh
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
......@@ -36,16 +36,28 @@
* implied.
*/
#include "src/cuda/utils.cuh"
#include <stddef.h>
#include <stdint.h>
#include <cuda_runtime.h>
namespace megdnn {
namespace cuda {
template <bool signedness>
void do_dispatch_reduce_with_scale_filter_4bit(const uint8_t* src,
int32_t scale, uint32_t rows,
uint32_t cols, int32_t* dst,
cudaStream_t stream);
template <bool signedness>
void do_dispatch_reduce_filter_and_update_bias_4bit(
const uint8_t* filter, const int32_t* src_bias, uint32_t rows,
uint32_t cols, int32_t* dst_bias, int32_t* workspace, int zero_point,
cudaStream_t stream);
size_t do_dispatch_reduce_workspace_in_bytes(size_t A, size_t B, size_t C);
} // namespace cuda
} // namespace megdnn
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册