提交 4c13bc7e 编写于 作者: M Megvii Engine Team

feat(dnn/cuda): add nhwc int8 deconv

GitOrigin-RevId: ad361a0f812458115f9a7f809345bef6fe693962
上级 11f022ff
...@@ -9,6 +9,7 @@ genrule( ...@@ -9,6 +9,7 @@ genrule(
CUTLASS_WITH_LONG_PATH=true python3 $$GEN --operations gemm --type tensorop1688 $(@D) CUTLASS_WITH_LONG_PATH=true python3 $$GEN --operations gemm --type tensorop1688 $(@D)
CUTLASS_WITH_LONG_PATH=true python3 $$GEN --operations gemv --type simt $(@D) CUTLASS_WITH_LONG_PATH=true python3 $$GEN --operations gemv --type simt $(@D)
CUTLASS_WITH_LONG_PATH=true python3 $$GEN --operations deconv --type simt $(@D) CUTLASS_WITH_LONG_PATH=true python3 $$GEN --operations deconv --type simt $(@D)
CUTLASS_WITH_LONG_PATH=true python3 $$GEN --operations deconv --type tensorop8816 $(@D)
CUTLASS_WITH_LONG_PATH=true python3 $$GEN --operations conv2d --type simt $(@D) CUTLASS_WITH_LONG_PATH=true python3 $$GEN --operations conv2d --type simt $(@D)
CUTLASS_WITH_LONG_PATH=true python3 $$GEN --operations conv2d --type tensorop8816 $(@D) CUTLASS_WITH_LONG_PATH=true python3 $$GEN --operations conv2d --type tensorop8816 $(@D)
CUTLASS_WITH_LONG_PATH=true python3 $$GEN --operations conv2d --type tensorop8832 $(@D) CUTLASS_WITH_LONG_PATH=true python3 $$GEN --operations conv2d --type tensorop8832 $(@D)
......
...@@ -336,6 +336,9 @@ def GenerateConv2d(conv_kind, tile_descriptions, src_layout, flt_layout, dst_lay ...@@ -336,6 +336,9 @@ def GenerateConv2d(conv_kind, tile_descriptions, src_layout, flt_layout, dst_lay
swizzling_functor = SwizzlingFunctor.ConvFpropTrans swizzling_functor = SwizzlingFunctor.ConvFpropTrans
else: else:
swizzling_functor = SwizzlingFunctor.ConvFpropNCxHWx swizzling_functor = SwizzlingFunctor.ConvFpropNCxHWx
else:
if implicit_gemm_mode == ImplicitGemmMode.GemmTN:
swizzling_functor = SwizzlingFunctor.ConvDgradTrans
else: else:
swizzling_functor = SwizzlingFunctor.ConvDgradNCxHWx swizzling_functor = SwizzlingFunctor.ConvDgradNCxHWx
......
...@@ -36,6 +36,7 @@ if __name__ == "__main__": ...@@ -36,6 +36,7 @@ if __name__ == "__main__":
write_op_list(f, "gemm", "tensorop884") write_op_list(f, "gemm", "tensorop884")
write_op_list(f, "gemv", "simt") write_op_list(f, "gemv", "simt")
write_op_list(f, "deconv", "simt") write_op_list(f, "deconv", "simt")
write_op_list(f, "deconv", "tensorop8816")
write_op_list(f, "conv2d", "simt") write_op_list(f, "conv2d", "simt")
write_op_list(f, "conv2d", "tensorop8816") write_op_list(f, "conv2d", "tensorop8816")
write_op_list(f, "conv2d", "tensorop8832") write_op_list(f, "conv2d", "tensorop8832")
......
...@@ -445,6 +445,53 @@ def GenerateDeconv_Simt(args): ...@@ -445,6 +445,53 @@ def GenerateDeconv_Simt(args):
use_special_optimization) use_special_optimization)
return operations return operations
def GenerateDeconv_TensorOp_8816(args):
  """Build the list of int8 NHWC deconv (Dgrad) operations for the 8x8x16
  tensor-op instruction.

  One operation is produced per (filter layout, tile shape) combination. All
  of them use the DeconvDoubleUpsampling special optimization and the GemmTN
  implicit gemm mode. `args` is accepted for signature parity with the other
  generators and is not read here.
  """
  operations = []

  # (source layout, filter layout, alignment passed to GenerateConv2d twice)
  layouts = [
    (LayoutType.TensorNHWC, LayoutType.TensorCK4RS4, 32),
    (LayoutType.TensorNHWC, LayoutType.TensorCK8RS8, 64),
    (LayoutType.TensorNHWC, LayoutType.TensorCK16RS16, 128),
  ]

  math_instructions = [
    MathInstruction(
      [8, 8, 16],
      DataType.s8, DataType.s8, DataType.s32,
      OpcodeClass.TensorOp,
      MathOperation.multiply_add_saturate),
  ]

  dst_layouts = [LayoutType.TensorNHWC]
  dst_types = [DataType.s8]

  use_special_optimization = SpecialOptimizeDesc.DeconvDoubleUpsampling

  min_cc = 75
  max_cc = 1024

  cuda_major = 10
  cuda_minor = 2

  for math_inst in math_instructions:
    for src_layout, flt_layout, align_bits in layouts:
      for dst_type, dst_layout in zip(dst_types, dst_layouts):
        # fresh tile descriptions are created for every layout combination
        tile_descriptions = [
          TileDescription([128, 32, 32], 1, [2, 1, 1], math_inst, min_cc, max_cc),
          TileDescription([64, 16, 32], 2, [1, 1, 1], math_inst, min_cc, max_cc),
        ]
        for tile in tile_descriptions:
          # tiles whose threadblock N extent is 16 use the smaller
          # destination alignment
          if tile.threadblock_shape[1] == 16:
            dst_align = 32
          else:
            dst_align = 64
          operations += GenerateConv2d(ConvKind.Dgrad, [tile], src_layout, flt_layout,
                                       dst_layout, dst_type, min_cc, align_bits,
                                       align_bits, dst_align, use_special_optimization,
                                       ImplicitGemmMode.GemmTN, False, cuda_major,
                                       cuda_minor)

  return operations
################################################################################ ################################################################################
# parameters # parameters
# Edge - for tiles, the edges represent the length of one side # Edge - for tiles, the edges represent the length of one side
...@@ -820,9 +867,12 @@ def GenerateConv2dOperations(args): ...@@ -820,9 +867,12 @@ def GenerateConv2dOperations(args):
return GenerateConv2d_TensorOp_8832(args) return GenerateConv2d_TensorOp_8832(args)
def GenerateDeconvOperations(args): def GenerateDeconvOperations(args):
assert args.type == "simt", "operation deconv only support" \ if args.type == "simt":
"simt. (got:{})".format(args.type)
return GenerateDeconv_Simt(args) return GenerateDeconv_Simt(args)
else:
assert args.type == "tensorop8816", "operation deconv only support" \
"simt and tensorop8816. (got:{})".format(args.type)
return GenerateDeconv_TensorOp_8816(args)
def GenerateGemmOperations(args): def GenerateGemmOperations(args):
if args.type == "tensorop884": if args.type == "tensorop884":
......
...@@ -280,6 +280,9 @@ class LayoutType(enum.Enum): ...@@ -280,6 +280,9 @@ class LayoutType(enum.Enum):
TensorC32RSK32 = enum_auto() TensorC32RSK32 = enum_auto()
TensorC64RSK64 = enum_auto() TensorC64RSK64 = enum_auto()
TensorK4RSC4 = enum_auto() TensorK4RSC4 = enum_auto()
TensorCK4RS4 = enum_auto()
TensorCK8RS8 = enum_auto()
TensorCK16RS16 = enum_auto()
# #
LayoutTag = { LayoutTag = {
...@@ -304,6 +307,9 @@ LayoutTag = { ...@@ -304,6 +307,9 @@ LayoutTag = {
LayoutType.TensorNC64HW64: 'cutlass::layout::TensorNCxHWx<64>', LayoutType.TensorNC64HW64: 'cutlass::layout::TensorNCxHWx<64>',
LayoutType.TensorC64RSK64: 'cutlass::layout::TensorCxRSKx<64>', LayoutType.TensorC64RSK64: 'cutlass::layout::TensorCxRSKx<64>',
LayoutType.TensorK4RSC4: 'cutlass::layout::TensorKxRSCx<4>', LayoutType.TensorK4RSC4: 'cutlass::layout::TensorKxRSCx<4>',
LayoutType.TensorCK4RS4: 'cutlass::layout::TensorCKxRSx<4>',
LayoutType.TensorCK8RS8: 'cutlass::layout::TensorCKxRSx<8>',
LayoutType.TensorCK16RS16: 'cutlass::layout::TensorCKxRSx<16>',
} }
# #
...@@ -342,6 +348,9 @@ ShortLayoutTypeNames = { ...@@ -342,6 +348,9 @@ ShortLayoutTypeNames = {
LayoutType.TensorC32RSK32: 'c32rsk32', LayoutType.TensorC32RSK32: 'c32rsk32',
LayoutType.TensorC64RSK64: 'c64rsk64', LayoutType.TensorC64RSK64: 'c64rsk64',
LayoutType.TensorK4RSC4: 'k4rsc4', LayoutType.TensorK4RSC4: 'k4rsc4',
LayoutType.TensorCK4RS4: 'ck4rs4',
LayoutType.TensorCK8RS8: 'ck8rs8',
LayoutType.TensorCK16RS16: 'ck16rs16',
} }
# #
...@@ -484,6 +493,7 @@ class SwizzlingFunctor(enum.Enum): ...@@ -484,6 +493,7 @@ class SwizzlingFunctor(enum.Enum):
ConvFpropNCxHWx = enum_auto() ConvFpropNCxHWx = enum_auto()
ConvFpropTrans = enum_auto() ConvFpropTrans = enum_auto()
ConvDgradNCxHWx = enum_auto() ConvDgradNCxHWx = enum_auto()
ConvDgradTrans = enum_auto()
# #
SwizzlingFunctorTag = { SwizzlingFunctorTag = {
...@@ -494,6 +504,7 @@ SwizzlingFunctorTag = { ...@@ -494,6 +504,7 @@ SwizzlingFunctorTag = {
SwizzlingFunctor.ConvFpropNCxHWx: 'cutlass::conv::threadblock::ConvolutionFpropNCxHWxThreadblockSwizzle', SwizzlingFunctor.ConvFpropNCxHWx: 'cutlass::conv::threadblock::ConvolutionFpropNCxHWxThreadblockSwizzle',
SwizzlingFunctor.ConvFpropTrans: 'cutlass::conv::threadblock::ConvolutionFpropTransThreadblockSwizzle', SwizzlingFunctor.ConvFpropTrans: 'cutlass::conv::threadblock::ConvolutionFpropTransThreadblockSwizzle',
SwizzlingFunctor.ConvDgradNCxHWx: 'cutlass::conv::threadblock::ConvolutionDgradNCxHWxThreadblockSwizzle', SwizzlingFunctor.ConvDgradNCxHWx: 'cutlass::conv::threadblock::ConvolutionDgradNCxHWxThreadblockSwizzle',
SwizzlingFunctor.ConvDgradTrans: 'cutlass::conv::threadblock::ConvolutionDgradTransThreadblockSwizzle',
} }
################################################################################################### ###################################################################################################
......
...@@ -155,6 +155,7 @@ if(MGE_WITH_CUDA) ...@@ -155,6 +155,7 @@ if(MGE_WITH_CUDA)
gen_cutlass_kimpl(gemm tensorop1688 CUTLASS_SOURCES) gen_cutlass_kimpl(gemm tensorop1688 CUTLASS_SOURCES)
gen_cutlass_kimpl(gemv simt CUTLASS_SOURCES) gen_cutlass_kimpl(gemv simt CUTLASS_SOURCES)
gen_cutlass_kimpl(deconv simt CUTLASS_SOURCES) gen_cutlass_kimpl(deconv simt CUTLASS_SOURCES)
gen_cutlass_kimpl(deconv tensorop8816 CUTLASS_SOURCES)
gen_cutlass_kimpl(conv2d simt CUTLASS_SOURCES) gen_cutlass_kimpl(conv2d simt CUTLASS_SOURCES)
gen_cutlass_kimpl(conv2d tensorop8816 CUTLASS_SOURCES) gen_cutlass_kimpl(conv2d tensorop8816 CUTLASS_SOURCES)
gen_cutlass_kimpl(conv2d tensorop8832 CUTLASS_SOURCES) gen_cutlass_kimpl(conv2d tensorop8832 CUTLASS_SOURCES)
......
...@@ -36,6 +36,12 @@ ConvolutionBackwardDataImpl::AlgoPack::AlgoPack() { ...@@ -36,6 +36,12 @@ ConvolutionBackwardDataImpl::AlgoPack::AlgoPack() {
int8_algos.push_back(&algo); int8_algos.push_back(&algo);
} }
fill_int8_imma_algos();
for (auto&& algo : int8_nhwc_imma) {
all_algos.push_back(&algo);
int8_algos.push_back(&algo);
}
int8_algos.push_back(&int8_nchw_dotprod); int8_algos.push_back(&int8_nchw_dotprod);
all_algos.push_back(&int8_nchw_dotprod); all_algos.push_back(&int8_nchw_dotprod);
......
...@@ -40,7 +40,8 @@ public: ...@@ -40,7 +40,8 @@ public:
CUDA_BFLOAT16, CUDA_BFLOAT16,
CUDA_GROUP_CONV_GENERAL, CUDA_GROUP_CONV_GENERAL,
CUDA_IMPLICIT_GEMM_NCHW4_DOTPROD_INT8, CUDA_IMPLICIT_GEMM_NCHW4_DOTPROD_INT8,
CUDA_IMPLICIT_GEMM_NCHW_DOTPROD_INT8 CUDA_IMPLICIT_GEMM_NCHW_DOTPROD_INT8,
CUDA_IMPLICIT_GEMM_NHWC_IMMA_INT8
}; };
using Mapper = std::unordered_map<AlgorithmDesc, AlgoBase*>; using Mapper = std::unordered_map<AlgorithmDesc, AlgoBase*>;
...@@ -299,11 +300,53 @@ private: ...@@ -299,11 +300,53 @@ private:
const void* get_available_op(const SizeArgs& args) const; const void* get_available_op(const SizeArgs& args) const;
}; };
/*!
 * \brief int8 NHWC backward-data (deconv) algorithm backed by an implicit
 * gemm operation selected through \p AlgoParam.
 */
class ConvolutionBackwardDataImpl::AlgoInt8NHWCIMMAImplicitGemm final
        : public AlgoBase {
public:
    //! tile configuration that identifies one concrete kernel variant
    struct AlgoParam {
        int threadblock_m;
        int threadblock_n;
        int threadblock_k;
        int warp_m;
        int warp_n;
        int warp_k;
        int stage;
        //! interleave factor of the reordered filter (4, 8 or 16)
        int access_size;
        //! suffix appended to the algorithm name; const so that it can be
        //! called on const AlgoParam objects as well
        std::string to_string() const {
            return ssprintf("_%dX%dX%d_%dX%dX%d_%dstage_%d", threadblock_m,
                            threadblock_n, threadblock_k, warp_m, warp_n,
                            warp_k, stage, access_size);
        }
    };
    // m_algo_param is declared before m_name, so it is already initialized
    // when the name is formatted from it
    AlgoInt8NHWCIMMAImplicitGemm(AlgoParam algo_param)
            : m_algo_param{algo_param},
              m_name{ssprintf("INT8_NHWC_IMMA_IMPLICIT_GEMM%s",
                              m_algo_param.to_string().c_str())} {}
    bool is_available(const SizeArgs& args) const override;
    size_t get_workspace_in_bytes(const SizeArgs& args) const override;
    void exec(const ExecArgs& args) const override;
    const char* name() const override { return m_name.c_str(); }
    AlgoAttribute attribute() const override {
        return AlgoAttribute::REPRODUCIBLE;
    }
    MEGDNN_DECL_ALGO_TYPE(CUDA_IMPLICIT_GEMM_NHWC_IMMA_INT8)

private:
    WorkspaceBundle get_workspace_bundle(dt_byte* raw_ptr,
                                         const SizeArgs& args) const;
    //! returns the matching operation, or nullptr if none is registered
    const void* get_available_op(const SizeArgs& args) const;
    //! writes the filter, reordered with the given interleave factor, into
    //! \p reordered_filter
    void reorder_filter(const ExecArgs& args, const int interleaved,
                        int8_t* reordered_filter) const;
    AlgoParam m_algo_param;
    std::string m_name;
};
class ConvolutionBackwardDataImpl::AlgoPack : NonCopyableObj { class ConvolutionBackwardDataImpl::AlgoPack : NonCopyableObj {
// defined in cudnn.cpp // defined in cudnn.cpp
void fill_cudnn_algos(); void fill_cudnn_algos();
// defined in implicit_gemm_int8_nchw4_dp4a.cpp // defined in implicit_gemm_int8_nchw4_dp4a.cpp
void fill_int8_dp4a_algos(); void fill_int8_dp4a_algos();
// defined in implicit_gemm_int8_nhwc_imma.cpp
void fill_int8_imma_algos();
AlgoBase::Mapper m_all_algos_map; AlgoBase::Mapper m_all_algos_map;
...@@ -318,6 +361,7 @@ public: ...@@ -318,6 +361,7 @@ public:
AlgoGroupConvGeneral group; AlgoGroupConvGeneral group;
std::vector<AlgoInt8NCHW4DotProdImplicitGemm> int8_nchw4_dotprod; std::vector<AlgoInt8NCHW4DotProdImplicitGemm> int8_nchw4_dotprod;
AlgoInt8NCHWDotProdImplicitGemm int8_nchw_dotprod; AlgoInt8NCHWDotProdImplicitGemm int8_nchw_dotprod;
std::vector<AlgoInt8NHWCIMMAImplicitGemm> int8_nhwc_imma;
std::vector<AlgoBase*> std::vector<AlgoBase*>
//! all algorithms //! all algorithms
......
...@@ -11,6 +11,7 @@ ...@@ -11,6 +11,7 @@
*/ */
#include "src/cuda/convolution/backward_data/deconv_int8_helper.cuh" #include "src/cuda/convolution/backward_data/deconv_int8_helper.cuh"
#include "src/cuda/transpose_utils.cuh"
using namespace megdnn; using namespace megdnn;
using namespace cuda; using namespace cuda;
...@@ -21,7 +22,6 @@ using namespace deconv; ...@@ -21,7 +22,6 @@ using namespace deconv;
namespace { namespace {
//
__global__ void reorder_filter_nc4hw4_to_n4hwc4_kernel( __global__ void reorder_filter_nc4hw4_to_n4hwc4_kernel(
int8_t* __restrict__ dst, const int8_t* __restrict__ src, uint32_t OC, int8_t* __restrict__ dst, const int8_t* __restrict__ src, uint32_t OC,
uint32_t IC, uint32_t FHFW) { uint32_t IC, uint32_t FHFW) {
...@@ -30,32 +30,55 @@ __global__ void reorder_filter_nc4hw4_to_n4hwc4_kernel( ...@@ -30,32 +30,55 @@ __global__ void reorder_filter_nc4hw4_to_n4hwc4_kernel(
const int32_t fhfw = blockIdx.x * BLOCKSIZE_Y + threadIdx.x; const int32_t fhfw = blockIdx.x * BLOCKSIZE_Y + threadIdx.x;
if (fhfw < FHFW && icb < IC / 4) { if (fhfw < FHFW && icb < IC / 4) {
int src0 = *reinterpret_cast<const int*>( int src_value[4], dst_value[4];
src + (ocb * 4 + 0) * IC * FHFW + (icb * FHFW + fhfw) * 4); #pragma unroll
int src1 = *reinterpret_cast<const int*>( for (int i = 0; i < 4; i++) {
src + (ocb * 4 + 1) * IC * FHFW + (icb * FHFW + fhfw) * 4); src_value[i] = *reinterpret_cast<const int*>(
int src2 = *reinterpret_cast<const int*>( src + (ocb * 4 + i) * IC * FHFW + (icb * FHFW + fhfw) * 4);
src + (ocb * 4 + 2) * IC * FHFW + (icb * FHFW + fhfw) * 4); }
int src3 = *reinterpret_cast<const int*>(
src + (ocb * 4 + 3) * IC * FHFW + (icb * FHFW + fhfw) * 4);
// transpose 4x4 // transpose 4x4
int dst01_lo = __byte_perm(src0, src1, 0x5140); transpose_int8_interleavedx4<4, int>(src_value, dst_value);
int dst01_hi = __byte_perm(src0, src1, 0x7362);
int dst23_lo = __byte_perm(src2, src3, 0x5140);
int dst23_hi = __byte_perm(src2, src3, 0x7362);
int dst0 = __byte_perm(dst01_lo, dst23_lo, 0x5410);
int dst1 = __byte_perm(dst01_lo, dst23_lo, 0x7632);
int dst2 = __byte_perm(dst01_hi, dst23_hi, 0x5410);
int dst3 = __byte_perm(dst01_hi, dst23_hi, 0x7632);
#pragma unroll
for (int i = 0; i < 4; i++) {
*reinterpret_cast<int*>( *reinterpret_cast<int*>(
dst + (ocb * FHFW * IC + fhfw * IC + icb * 4 + 0) * 4) = dst0; dst + (ocb * FHFW * IC + fhfw * IC + icb * 4 + i) * 4) =
*reinterpret_cast<int*>( dst_value[i];
dst + (ocb * FHFW * IC + fhfw * IC + icb * 4 + 1) * 4) = dst1; }
*reinterpret_cast<int*>( }
dst + (ocb * FHFW * IC + fhfw * IC + icb * 4 + 2) * 4) = dst2; }
*reinterpret_cast<int*>(
dst + (ocb * FHFW * IC + fhfw * IC + icb * 4 + 3) * 4) = dst3; template <uint32_t interleaved, typename vec_type>
__global__ void reorder_filter_nhwc_to_cnxhwx_kernel(
int8_t* __restrict__ dst, const int8_t* __restrict__ src, uint32_t OC,
uint32_t IC, uint32_t FHFW) {
uint32_t lane = threadIdx.x + blockIdx.x * blockDim.x;
const int32_t ocb = lane / (FHFW * IC / 4);
const int32_t fhfw_icb = lane % (FHFW * IC / 4);
const int32_t fhfw = fhfw_icb / (IC / 4);
const int32_t icb = fhfw_icb % (IC / 4);
if (ocb < OC / interleaved && fhfw < FHFW) {
int src_value[interleaved];
vec_type dst_value[4];
#pragma unroll
for (int i = 0; i < interleaved; i++) {
src_value[i] = *reinterpret_cast<const int*>(
src + (ocb * interleaved + i) * FHFW * IC + fhfw * IC +
icb * 4);
}
transpose_int8_interleavedx4<interleaved, vec_type>(src_value,
dst_value);
#pragma unroll
for (int i = 0; i < 4; i++) {
*reinterpret_cast<vec_type*>(dst + (icb * 4 + i) * FHFW * OC +
(ocb * FHFW + fhfw) * interleaved) =
dst_value[i];
}
} }
} }
...@@ -73,4 +96,27 @@ void megdnn::cuda::deconv::reorder_filter_nc4hw4_to_n4hwc4( ...@@ -73,4 +96,27 @@ void megdnn::cuda::deconv::reorder_filter_nc4hw4_to_n4hwc4(
after_kernel_launch(); after_kernel_launch();
} }
/*!
 * Launches the filter reorder kernel on \p stream.
 *
 * \param dst destination buffer for the reordered filter
 * \param src source filter
 * \param OC number of output channels
 * \param IC number of input channels
 * \param FH filter height
 * \param FW filter width
 * \param interleaved interleave factor; 4 and 8 select the matching kernel,
 *      any other value falls through to the 16-wide kernel
 * \param stream CUDA stream the kernel is enqueued on
 */
void megdnn::cuda::deconv::reorder_filter_nhwc_to_cnxhwx(
        int8_t* dst, const int8_t* src, uint32_t OC, uint32_t IC, uint32_t FH,
        uint32_t FW, uint32_t interleaved, cudaStream_t stream) {
    // one thread per (oc block, filter pixel, group of 4 input channels)
    int32_t vthreads = OC / interleaved * IC / 4 * FH * FW;
    // Nothing to reorder. Without this guard nr_threads would be 0, which
    // makes DIVUP divide by zero and requests an invalid launch configuration.
    if (vthreads == 0) {
        return;
    }
    int32_t nr_threads = std::min(256, vthreads);
    int32_t nr_blocks = DIVUP(vthreads, nr_threads);
    if (interleaved == 4) {
        reorder_filter_nhwc_to_cnxhwx_kernel<4, int>
                <<<nr_blocks, nr_threads, 0, stream>>>(dst, src, OC, IC,
                                                       FH * FW);
    } else if (interleaved == 8) {
        reorder_filter_nhwc_to_cnxhwx_kernel<8, int2>
                <<<nr_blocks, nr_threads, 0, stream>>>(dst, src, OC, IC,
                                                       FH * FW);
    } else {
        reorder_filter_nhwc_to_cnxhwx_kernel<16, int4>
                <<<nr_blocks, nr_threads, 0, stream>>>(dst, src, OC, IC,
                                                       FH * FW);
    }
    after_kernel_launch();
}
// vim: syntax=cuda.doxygen // vim: syntax=cuda.doxygen
...@@ -20,6 +20,10 @@ void reorder_filter_nc4hw4_to_n4hwc4(int8_t* dst, const int8_t* src, ...@@ -20,6 +20,10 @@ void reorder_filter_nc4hw4_to_n4hwc4(int8_t* dst, const int8_t* src,
uint32_t OC, uint32_t IC, uint32_t FH, uint32_t OC, uint32_t IC, uint32_t FH,
uint32_t FW, cudaStream_t stream); uint32_t FW, cudaStream_t stream);
void reorder_filter_nhwc_to_cnxhwx(int8_t* dst, const int8_t* src, uint32_t OC,
uint32_t IC, uint32_t FH, uint32_t FW,
uint32_t interleaved, cudaStream_t stream);
} // namespace deconv } // namespace deconv
} // namespace cuda } // namespace cuda
} // namespace megdnn } // namespace megdnn
......
/**
 * \file
 * dnn/src/cuda/convolution/backward_data/implicit_gemm_int8_nhwc_imma.cpp
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
 * implied.
 */
#include "src/cuda/convolution/backward_data/algo.h"
#include "src/cuda/convolution/backward_data/deconv_int8_helper.cuh"
#include "src/cuda/convolution_helper/parameter.cuh"
#include "src/cuda/cutlass/singleton.h"
#include "src/cuda/utils.h"
using namespace megdnn;
using namespace cuda;
const void*
ConvolutionBackwardDataImpl::AlgoInt8NHWCIMMAImplicitGemm::get_available_op(
const SizeArgs& args) const {
using namespace cutlass::library;
auto&& fm = args.filter_meta;
size_t sh = fm.stride[0], sw = fm.stride[1];
cutlass::conv::SpecialOptimizeDesc special_optimization =
(sh == 2 && sw == 2) ? cutlass::conv::SpecialOptimizeDesc::
DECONV_DOUBLE_UPSAMPLING
: cutlass::conv::SpecialOptimizeDesc::NONE;
LayoutTypeID filter_layout;
if (m_algo_param.access_size == 16) {
filter_layout = LayoutTypeID::kTensorCK16RS16;
} else if (m_algo_param.access_size == 8) {
filter_layout = LayoutTypeID::kTensorCK8RS8;
} else {
megdnn_assert(m_algo_param.access_size == 4, "invalid access_size: %d",
m_algo_param.access_size);
filter_layout = LayoutTypeID::kTensorCK4RS4;
}
ConvolutionKey key{
cutlass::conv::Operator::kDgrad,
NumericTypeID::kS8,
LayoutTypeID::kTensorNHWC,
NumericTypeID::kS8,
filter_layout,
NumericTypeID::kS8,
LayoutTypeID::kTensorNHWC,
NumericTypeID::kS32,
LayoutTypeID::kTensorNHWC,
cutlass::conv::ConvType::kConvolution,
m_algo_param.threadblock_m,
m_algo_param.threadblock_n,
m_algo_param.threadblock_k,
m_algo_param.warp_m,
m_algo_param.warp_n,
m_algo_param.warp_k,
8,
8,
16,
cutlass::epilogue::EpilogueType::kBiasAddLinearCombinationClamp,
m_algo_param.stage,
special_optimization,
false};
return (void*)Singleton::get().operation_table.find_op(key);
}
bool ConvolutionBackwardDataImpl::AlgoInt8NHWCIMMAImplicitGemm::is_available(
const SizeArgs& args) const {
auto&& fm = args.filter_meta;
if (fm.format != Param::Format::NHWC)
return false;
if (!args.grad_layout->is_contiguous() ||
!args.diff_layout->is_contiguous()) {
return false;
}
bool available = true;
auto src_dtype = args.diff_layout->dtype,
filter_dtype = args.filter_layout->dtype,
dst_dtype = args.grad_layout->dtype;
size_t co = args.diff_layout->operator[](3);
size_t ci = args.grad_layout->operator[](3);
available &= (src_dtype.enumv() == DTypeEnum::QuantizedS8 &&
filter_dtype.enumv() == DTypeEnum::QuantizedS8 &&
dst_dtype.enumv() == DTypeEnum::QuantizedS8);
// TODO support group deconv int8
available &= (fm.group == 1);
// mode must be cross correlation
available &= !fm.should_flip;
// mode must be 2D
available &= fm.spatial_ndim == 2;
// TODO: support dialtion
available &= (fm.dilation[0] == 1 && fm.dilation[1] == 1);
// FIXME: too large filter size is not supported now
size_t kMaxFilterPixels =
848 / (m_algo_param.warp_k / m_algo_param.access_size) - 1;
available &= fm.spatial[0] * fm.spatial[1] <= kMaxFilterPixels;
// ci should be aligned with 4, and co should be aligned with
// algo_param.access_size
available &= ((ci % 4 == 0) && (co % m_algo_param.access_size == 0));
available &= (get_available_op(args) != nullptr);
// only support sm_75 or later, platform should have imma int8 support
available &= is_compute_capability_required(7, 5);
return available;
}
/*!
 * Describes the workspace of this algorithm: a single chunk as large as the
 * byte span of the filter layout, based at \p raw_ptr.
 */
WorkspaceBundle
ConvolutionBackwardDataImpl::AlgoInt8NHWCIMMAImplicitGemm::get_workspace_bundle(
        dt_byte* raw_ptr, const SizeArgs& args) const {
    // room for one reordered copy of the filter
    const size_t reordered_filter_bytes =
            args.filter_layout->span().dist_byte();
    return WorkspaceBundle{raw_ptr, {reordered_filter_bytes}};
}
/*!
 * Total workspace size in bytes, taken from the workspace bundle built with a
 * null base pointer.
 */
size_t ConvolutionBackwardDataImpl::AlgoInt8NHWCIMMAImplicitGemm::
        get_workspace_in_bytes(const SizeArgs& args) const {
    auto bundle = get_workspace_bundle(nullptr, args);
    return bundle.total_size_in_bytes();
}
void ConvolutionBackwardDataImpl::AlgoInt8NHWCIMMAImplicitGemm::exec(
const ExecArgs& args) const {
auto&& param = args.opr->param();
auto&& fm = args.filter_meta;
size_t n = args.diff_layout->operator[](0),
co = args.diff_layout->operator[](3),
ho = args.diff_layout->operator[](1),
wo = args.diff_layout->operator[](2);
size_t ci = args.grad_layout->operator[](3),
hi = args.grad_layout->operator[](1),
wi = args.grad_layout->operator[](2);
size_t fh = fm.spatial[0], fw = fm.spatial[1];
size_t sh = fm.stride[0], sw = fm.stride[1];
size_t ph = fm.padding[0], pw = fm.padding[1];
size_t dh = param.dilate_h, dw = param.dilate_w;
auto&& stream = cuda_stream(args.opr->handle());
int8_t* filter_ptr = nullptr;
// TODO: weight preprocess
{
filter_ptr = reinterpret_cast<int8_t*>(args.workspace.raw_ptr);
// reformat filter from nc4hw4 to n4hwc4
reorder_filter(args, m_algo_param.access_size, filter_ptr);
}
float diff_scale =
args.diff_layout->dtype.param<dtype::QuantizedS8>().scale,
filter_scale =
args.filter_layout->dtype.param<dtype::QuantizedS8>().scale,
grad_scale =
args.grad_layout->dtype.param<dtype::QuantizedS8>().scale;
// \note these constants of cutlass epilogue will be passed to struct
// `ConvolutionArguments` by pointer and interpreted as ElementCompute*,
// a different dtype here results in undefined epilogue behaviors
float alpha = diff_scale * filter_scale / grad_scale, beta = 0.f,
gamma = 0.f, delta = 0.f;
using namespace cutlass::library;
const Operation* op = (const Operation*)get_available_op(args);
// gcc prints warnings when size_t values are implicitly narrowed to int
cutlass::conv::Conv2dProblemSize problem_size{
int(n), int(hi), int(wi), int(ci),
int(co), int(fh), int(fw), int(ho),
int(wo), int(ph), int(pw), int(sh),
int(sw), int(dh), int(dw), cutlass::conv::Mode::kCrossCorrelation};
cutlass::library::ConvolutionArguments conv_args{
problem_size, args.diff_tensor->compatible_ptr<int8_t>(),
filter_ptr, nullptr,
nullptr, args.grad_tensor->compatible_ptr<int8_t>(),
&alpha, &beta,
&gamma, &delta,
nullptr, nullptr,
nullptr, nullptr};
cutlass_check(op->run(&conv_args, nullptr, stream));
after_kernel_launch();
}
void ConvolutionBackwardDataImpl::AlgoInt8NHWCIMMAImplicitGemm::reorder_filter(
const ExecArgs& args, const int interleaved,
int8_t* reordered_filter) const {
auto&& fm = args.filter_meta;
size_t co = args.diff_layout->operator[](3);
size_t ci = args.grad_layout->operator[](3);
size_t fh = fm.spatial[0], fw = fm.spatial[1];
auto&& stream = cuda_stream(args.opr->handle());
megdnn::cuda::deconv::reorder_filter_nhwc_to_cnxhwx(
reordered_filter, args.filter_tensor->compatible_ptr<int8_t>(), co,
ci, fh, fw, interleaved, stream);
}
/*!
 * Registers the int8 NHWC IMMA deconv algorithms: two tile configurations,
 * each with access sizes 4, 8 and 16, in that order.
 */
void ConvolutionBackwardDataImpl::AlgoPack::fill_int8_imma_algos() {
    using AlgoParam = AlgoInt8NHWCIMMAImplicitGemm::AlgoParam;
    static const int kAccessSizes[] = {4, 8, 16};
    // threadblock 64x16x32, warp 64x16x32, 2 stages
    for (int access_size : kAccessSizes) {
        int8_nhwc_imma.emplace_back(
                AlgoParam{64, 16, 32, 64, 16, 32, 2, access_size});
    }
    // threadblock 128x32x32, warp 64x32x32, 1 stage
    for (int access_size : kAccessSizes) {
        int8_nhwc_imma.emplace_back(
                AlgoParam{128, 32, 32, 64, 32, 32, 1, access_size});
    }
}
// vim: syntax=cpp.doxygen
...@@ -99,6 +99,7 @@ public: ...@@ -99,6 +99,7 @@ public:
class AlgoBFloat16; class AlgoBFloat16;
class AlgoInt8NCHW4DotProdImplicitGemm; class AlgoInt8NCHW4DotProdImplicitGemm;
class AlgoInt8NCHWDotProdImplicitGemm; class AlgoInt8NCHWDotProdImplicitGemm;
class AlgoInt8NHWCIMMAImplicitGemm;
class AlgoPack; class AlgoPack;
......
...@@ -60,6 +60,7 @@ void initialize_all_gemm_tensorop884_operations(Manifest& manifest); ...@@ -60,6 +60,7 @@ void initialize_all_gemm_tensorop884_operations(Manifest& manifest);
void initialize_all_gemm_tensorop1688_operations(Manifest& manifest); void initialize_all_gemm_tensorop1688_operations(Manifest& manifest);
void initialize_all_conv2d_tensorop8816_operations(Manifest& manifest); void initialize_all_conv2d_tensorop8816_operations(Manifest& manifest);
void initialize_all_conv2d_tensorop8832_operations(Manifest& manifest); void initialize_all_conv2d_tensorop8832_operations(Manifest& manifest);
void initialize_all_deconv_tensorop8816_operations(Manifest& manifest);
#endif #endif
void initialize_all(Manifest& manifest) { void initialize_all(Manifest& manifest) {
...@@ -71,6 +72,7 @@ void initialize_all(Manifest& manifest) { ...@@ -71,6 +72,7 @@ void initialize_all(Manifest& manifest) {
initialize_all_gemm_tensorop1688_operations(manifest); initialize_all_gemm_tensorop1688_operations(manifest);
initialize_all_conv2d_tensorop8816_operations(manifest); initialize_all_conv2d_tensorop8816_operations(manifest);
initialize_all_conv2d_tensorop8832_operations(manifest); initialize_all_conv2d_tensorop8832_operations(manifest);
initialize_all_deconv_tensorop8816_operations(manifest);
#endif #endif
} }
......
...@@ -100,6 +100,9 @@ enum class LayoutTypeID { ...@@ -100,6 +100,9 @@ enum class LayoutTypeID {
kTensorNC64HW64, kTensorNC64HW64,
kTensorC64RSK64, kTensorC64RSK64,
kTensorK4RSC4, kTensorK4RSC4,
kTensorCK4RS4,
kTensorCK8RS8,
kTensorCK16RS16,
kInvalid kInvalid
}; };
...@@ -225,6 +228,7 @@ enum class ThreadblockSwizzleID { ...@@ -225,6 +228,7 @@ enum class ThreadblockSwizzleID {
kConvolutionFpropNCxHWx, kConvolutionFpropNCxHWx,
kConvolutionFpropTrans, kConvolutionFpropTrans,
kConvolutionDgradNCxHWx, kConvolutionDgradNCxHWx,
kConvolutionDgradTrans,
kInvalid kInvalid
}; };
......
...@@ -340,6 +340,21 @@ struct LayoutMap<cutlass::layout::TensorKxRSCx<4>> { ...@@ -340,6 +340,21 @@ struct LayoutMap<cutlass::layout::TensorKxRSCx<4>> {
static LayoutTypeID const kId = LayoutTypeID::kTensorK4RSC4; static LayoutTypeID const kId = LayoutTypeID::kTensorK4RSC4;
}; };
template <>
struct LayoutMap<cutlass::layout::TensorCKxRSx<4>> {
static LayoutTypeID const kId = LayoutTypeID::kTensorCK4RS4;
};
template <>
struct LayoutMap<cutlass::layout::TensorCKxRSx<8>> {
static LayoutTypeID const kId = LayoutTypeID::kTensorCK8RS8;
};
template <>
struct LayoutMap<cutlass::layout::TensorCKxRSx<16>> {
static LayoutTypeID const kId = LayoutTypeID::kTensorCK16RS16;
};
///////////////////////////////////////////////////////////////////////////////////////////////// /////////////////////////////////////////////////////////////////////////////////////////////////
template <typename T> template <typename T>
...@@ -556,6 +571,13 @@ struct ThreadblockSwizzleMap< ...@@ -556,6 +571,13 @@ struct ThreadblockSwizzleMap<
ThreadblockSwizzleID::kConvolutionDgradNCxHWx; ThreadblockSwizzleID::kConvolutionDgradNCxHWx;
}; };
template <>
struct ThreadblockSwizzleMap<
conv::threadblock::ConvolutionDgradTransThreadblockSwizzle> {
static ThreadblockSwizzleID const kId =
ThreadblockSwizzleID::kConvolutionDgradTrans;
};
///////////////////////////////////////////////////////////////////////////////////////////////// /////////////////////////////////////////////////////////////////////////////////////////////////
template <typename Element, typename Layout> template <typename Element, typename Layout>
......
...@@ -533,7 +533,10 @@ static struct { ...@@ -533,7 +533,10 @@ static struct {
{LayoutTypeID::kTensorC16RSK16, "c16rsk16"}, {LayoutTypeID::kTensorC16RSK16, "c16rsk16"},
{LayoutTypeID::kTensorC32RSK32, "c32rsk32"}, {LayoutTypeID::kTensorC32RSK32, "c32rsk32"},
{LayoutTypeID::kTensorC64RSK64, "c64rsk64"}, {LayoutTypeID::kTensorC64RSK64, "c64rsk64"},
{LayoutTypeID::kTensorK4RSC4, "k4rsC4"}, {LayoutTypeID::kTensorK4RSC4, "k4rsc4"},
{LayoutTypeID::kTensorCK4RS4, "ck4rs4"},
{LayoutTypeID::kTensorCK8RS8, "ck8rs8"},
{LayoutTypeID::kTensorCK16RS16, "ck16rs16"},
{LayoutTypeID::kUnknown, "*"}, {LayoutTypeID::kUnknown, "*"},
{LayoutTypeID::kInvalid, nullptr}}; {LayoutTypeID::kInvalid, nullptr}};
...@@ -1499,6 +1502,8 @@ static struct { ...@@ -1499,6 +1502,8 @@ static struct {
ThreadblockSwizzleID::kConvolutionFpropTrans}, ThreadblockSwizzleID::kConvolutionFpropTrans},
{"convolution_dgrad_ncxhwx", "ConvolutionDgradNCxHWxThreadblockSwizzle", {"convolution_dgrad_ncxhwx", "ConvolutionDgradNCxHWxThreadblockSwizzle",
ThreadblockSwizzleID::kConvolutionDgradNCxHWx}, ThreadblockSwizzleID::kConvolutionDgradNCxHWx},
// text name must be unique per enumerant; the NCxHWx dgrad entry already uses
// "convolution_dgrad_ncxhwx"
{"convolution_dgrad_trans", "ConvolutionDgradTransThreadblockSwizzle",
 ThreadblockSwizzleID::kConvolutionDgradTrans},
}; };
/// Converts a ThreadblockSwizzleID enumerant to a string /// Converts a ThreadblockSwizzleID enumerant to a string
......
/**
 * \file dnn/src/cuda/transpose_utils.cuh
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
 * implied.
 */
#if MEGDNN_CC_CUDA
#pragma once
#include "src/cuda/utils.cuh"
namespace megdnn {
namespace cuda {
/*!
 * Byte-wise 4x4 transpose. Each of the four inputs is treated as four packed
 * bytes; output k collects byte k of src0, src1, src2 and src3, in that
 * order from the least significant byte upwards.
 */
MEGDNN_DEVICE __forceinline__ void transpose_int8_4x4_impl(
        const int src0, const int src1, const int src2, const int src3,
        int& dst0, int& dst1, int& dst2, int& dst3) {
    // stage 1: interleave the bytes of each row pair
    //   0x5140 -> bytes {a0, b0, a1, b1}, 0x7362 -> bytes {a2, b2, a3, b3}
    const int pair01_low = __byte_perm(src0, src1, 0x5140);
    const int pair23_low = __byte_perm(src2, src3, 0x5140);
    const int pair01_high = __byte_perm(src0, src1, 0x7362);
    const int pair23_high = __byte_perm(src2, src3, 0x7362);
    // stage 2: merge the 16-bit halves of the two pairs
    //   0x5410 takes the low halves, 0x7632 the high halves
    dst0 = __byte_perm(pair01_low, pair23_low, 0x5410);
    dst1 = __byte_perm(pair01_low, pair23_low, 0x7632);
    dst2 = __byte_perm(pair01_high, pair23_high, 0x5410);
    dst3 = __byte_perm(pair01_high, pair23_high, 0x7632);
}
/*!
 * Transposes `interleaved` packed 4-byte words into four vectors.
 *
 * dst[k] receives byte k of every input word. Words are handled in groups of
 * four; group g fills the g-th int component of each output vector
 * (x, y, z, w). Only the <4, int>, <8, int2> and <16, int4> specializations
 * below are defined.
 */
template <uint32_t interleaved, typename vec_type>
MEGDNN_DEVICE __forceinline__ void transpose_int8_interleavedx4(
        const int src[interleaved], vec_type (&dst)[4]);

//! 4 words in, 4 plain ints out
template <>
MEGDNN_DEVICE __forceinline__ void transpose_int8_interleavedx4<4, int>(
        const int src[4], int (&dst)[4]) {
    int &out0 = dst[0], &out1 = dst[1], &out2 = dst[2], &out3 = dst[3];
    transpose_int8_4x4_impl(src[0], src[1], src[2], src[3], out0, out1, out2,
                            out3);
}

//! 8 words in; words 0..3 fill .x, words 4..7 fill .y
template <>
MEGDNN_DEVICE __forceinline__ void transpose_int8_interleavedx4<8, int2>(
        const int src[8], int2 (&dst)[4]) {
    int2 &out0 = dst[0], &out1 = dst[1], &out2 = dst[2], &out3 = dst[3];
    transpose_int8_4x4_impl(src[0], src[1], src[2], src[3], out0.x, out1.x,
                            out2.x, out3.x);
    transpose_int8_4x4_impl(src[4], src[5], src[6], src[7], out0.y, out1.y,
                            out2.y, out3.y);
}

//! 16 words in; each group of four fills .x, .y, .z and .w in turn
template <>
MEGDNN_DEVICE __forceinline__ void transpose_int8_interleavedx4<16, int4>(
        const int src[16], int4 (&dst)[4]) {
    int4 &out0 = dst[0], &out1 = dst[1], &out2 = dst[2], &out3 = dst[3];
    transpose_int8_4x4_impl(src[0], src[1], src[2], src[3], out0.x, out1.x,
                            out2.x, out3.x);
    transpose_int8_4x4_impl(src[4], src[5], src[6], src[7], out0.y, out1.y,
                            out2.y, out3.y);
    transpose_int8_4x4_impl(src[8], src[9], src[10], src[11], out0.z, out1.z,
                            out2.z, out3.z);
    transpose_int8_4x4_impl(src[12], src[13], src[14], src[15], out0.w, out1.w,
                            out2.w, out3.w);
}
} // namespace cuda
} // namespace megdnn
#endif
// vim: ft=cpp syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
...@@ -469,7 +469,6 @@ std::vector<TestArg> convolution::get_args_int8_nchw4_conv_bwd_data() { ...@@ -469,7 +469,6 @@ std::vector<TestArg> convolution::get_args_int8_nchw4_conv_bwd_data() {
return args; return args;
} }
std::vector<TestArg> convolution::get_args_int8_nchw_conv_bwd_data() { std::vector<TestArg> convolution::get_args_int8_nchw_conv_bwd_data() {
std::vector<TestArg> args; std::vector<TestArg> args;
param::Convolution cur_param; param::Convolution cur_param;
...@@ -511,6 +510,46 @@ std::vector<TestArg> convolution::get_args_int8_nchw_conv_bwd_data() { ...@@ -511,6 +510,46 @@ std::vector<TestArg> convolution::get_args_int8_nchw_conv_bwd_data() {
return args; return args;
} }
/*!
 * \brief Build the argument list for the int8 NHWC backward-data tests.
 *
 * Enumerates dense, stride-2 CROSS_CORRELATION cases over batch, input
 * channels, output channels, spatial size, square kernel size and padding,
 * then appends one extra stride-1 / pad-1 case. Every entry is constructed
 * from (param, shape {b, h, w, ic}, filter shape {oc, f, f, ic}).
 *
 * \return the generated test arguments
 */
std::vector<TestArg> convolution::get_args_int8_nhwc_conv_bwd_data() {
    std::vector<TestArg> args;
    param::Convolution cur_param;
    // clang-format off
    for (auto mode : {param::Convolution::Mode::CROSS_CORRELATION}) {
    for (size_t b : {64, 16}) {
    for (size_t ic : {16, 32}) {
    for (size_t oc : {16, 32}) {
    for (size_t h : {8}) {
    for (size_t w : {8, 11}) {
    for (size_t kernel_size : {3, 4, 5, 7}) {
    for (int p : {0, static_cast<int>(kernel_size / 2)}) {
    for (size_t s : {2}) {
        // Cap the batch for large kernels only. The cap is kept in a local
        // instead of being written back to the loop variable `b`, so it does
        // not leak into the following (smaller-kernel) combinations.
        const size_t batch = kernel_size >= 7 ? std::min(b, 32_z) : b;
        const size_t f = kernel_size;

        cur_param.mode = mode;
        cur_param.format = param::Convolution::Format::NHWC;
        cur_param.sparse = param::Convolution::Sparse::DENSE;
        cur_param.pad_h = cur_param.pad_w = p;
        cur_param.stride_h = cur_param.stride_w = s;

        args.emplace_back(cur_param, TensorShape{batch, h, w, ic},
                          TensorShape{oc, f, f, ic});
    } } } } } } } } }
    // clang-format on

    // One additional unit-stride case; mode/format/sparse are the values left
    // in cur_param by the loops above.
    cur_param.pad_h = cur_param.pad_w = 1;
    cur_param.stride_h = cur_param.stride_w = 1;
    args.emplace_back(cur_param, TensorShape{16, 8, 11, 16},
                      TensorShape{16, 3, 3, 16});
    return args;
}
void convolution::test_conv_config_combinations( void convolution::test_conv_config_combinations(
int k_size, Handle* handle, bool test_int8, bool test_backward, int k_size, Handle* handle, bool test_int8, bool test_backward,
bool is_cuda, ConvEPSGetter eps_getter, bool use_io16xc32) { bool is_cuda, ConvEPSGetter eps_getter, bool use_io16xc32) {
......
...@@ -50,6 +50,7 @@ std::vector<TestArg> get_dilated_args(); ...@@ -50,6 +50,7 @@ std::vector<TestArg> get_dilated_args();
std::vector<TestArg> get_chanwise_args(); std::vector<TestArg> get_chanwise_args();
std::vector<TestArg> get_args_int8_nchw4_conv_bwd_data(); std::vector<TestArg> get_args_int8_nchw4_conv_bwd_data();
std::vector<TestArg> get_args_int8_nchw_conv_bwd_data(); std::vector<TestArg> get_args_int8_nchw_conv_bwd_data();
std::vector<TestArg> get_args_int8_nhwc_conv_bwd_data();
//! \param stage 0 for fwd, 1 for bwd data, 2 for bwd filter //! \param stage 0 for fwd, 1 for bwd data, 2 for bwd filter
using ConvEPSGetter = using ConvEPSGetter =
......
...@@ -386,6 +386,69 @@ TEST_F(CUDA, CONVOLUTION_BACKWARD_DATA_INT8_NCHW_DP4A) { ...@@ -386,6 +386,69 @@ TEST_F(CUDA, CONVOLUTION_BACKWARD_DATA_INT8_NCHW_DP4A) {
} }
} }
#if CUDA_VERSION >= 10020
//! Runs every listed INT8_NHWC_IMMA_IMPLICIT_GEMM backward-data configuration
//! against all cases from get_args_int8_nhwc_conv_bwd_data(). The test is
//! skipped when is_compute_capability_required(7, 5) reports false.
TEST_F(CUDA, CONVOLUTION_BACKWARD_DATA_INT8_NHWC_IMMA) {
    if (!cuda::is_compute_capability_required(7, 5)) {
        printf("Skip CUDA.CONVOLUTION_BACKWARD_DATA_INT8_NHWC_IMMA test as "
               "current device doesn't support\n");
        return;
    }
    using namespace convolution;

    //! Tile configuration that is encoded into the algorithm name.
    struct AlgoParam {
        int threadblock_m;
        int threadblock_n;
        int threadblock_k;
        int warp_m;
        int warp_n;
        int warp_k;
        int stage;
        int access_size;
        //! Suffix appended to "INT8_NHWC_IMMA_IMPLICIT_GEMM".
        std::string to_string() {
            return ssprintf("_%dX%dX%d_%dX%dX%d_%dstage_%d", threadblock_m,
                            threadblock_n, threadblock_k, warp_m, warp_n,
                            warp_k, stage, access_size);
        }
    };

    // Two tile shapes, each with access sizes 4, 8 and 16.
    std::vector<AlgoParam> all_params = {
            {64, 16, 32, 64, 16, 32, 2, 4},
            {64, 16, 32, 64, 16, 32, 2, 8},
            {64, 16, 32, 64, 16, 32, 2, 16},
            {128, 32, 32, 64, 32, 32, 1, 4},
            {128, 32, 32, 64, 32, 32, 1, 8},
            {128, 32, 32, 64, 32, 32, 1, 16},
    };

    std::vector<TestArg> args = get_args_int8_nhwc_conv_bwd_data();

    for (AlgoParam& algo_param : all_params) {
        Checker<ConvolutionBackwardData> checker(handle_cuda());

        // Pin the checker to the one algorithm under test.
        const std::string algo_name =
                ssprintf("INT8_NHWC_IMMA_IMPLICIT_GEMM%s",
                         algo_param.to_string().c_str());
        checker.set_before_exec_callback(
                AlgoChecker<ConvolutionBackwardData>(algo_name.c_str()));
        checker.set_epsilon(1 + 1e-3);
        checker.set_max_avg_error(1e-1);

        for (auto&& arg : args) {
            // A fresh generator is created for every case, as before.
            UniformIntRNG rng(-3, 3);

            TensorLayout src_layout(arg.src, dtype::QuantizedS8{1.2f});
            TensorLayout filter_layout(arg.filter, dtype::QuantizedS8{1.3f});
            TensorLayout dst_layout;
            dst_layout.dtype = dtype::QuantizedS8{1.2f};
            {
                // A forward Convolution operator is used only to deduce the
                // shape of dst_layout; it is released before the checker runs.
                auto fwd_opr = handle_cuda()->create_operator<Convolution>();
                fwd_opr->param() = arg.param;
                fwd_opr->deduce_layout(src_layout, filter_layout, dst_layout);
            }

            checker.set_rng(0, &rng);
            checker.set_rng(1, &rng);
            checker.set_param(arg.param);
            // Operand order handed to the backward-data operator:
            // filter, forward output, forward input.
            checker.exec(
                    TensorLayoutArray{filter_layout, dst_layout, src_layout});
        }
    }
}
#endif
TEST_F(CUDA, CONVOLUTION_BACKWARD_DATA_FAILED_CUDNN7_5) { TEST_F(CUDA, CONVOLUTION_BACKWARD_DATA_FAILED_CUDNN7_5) {
// BRAIN-481 failed on architectures 7.0, remove the following if statement, // BRAIN-481 failed on architectures 7.0, remove the following if statement,
// when cudnn fixed the problem. // when cudnn fixed the problem.
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册