提交 4e9b1c4e 编写于 作者: M Megvii Engine Team

feat(dnn): add rrconv wgrad, support int32 and uint8 region mask

GitOrigin-RevId: 0da9b3bca86ba6ae11289e451b224aca348f647b
上级 9e020d23
......@@ -19,6 +19,7 @@ genrule(
CUTLASS_WITH_LONG_PATH=true python3 $$GEN --operations dwconv2d_dgrad --type tensorop884 $(@D)
CUTLASS_WITH_LONG_PATH=true python3 $$GEN --operations dwconv2d_wgrad --type simt $(@D)
CUTLASS_WITH_LONG_PATH=true python3 $$GEN --operations dwconv2d_wgrad --type tensorop884 $(@D)
CUTLASS_WITH_LONG_PATH=true python3 $$GEN --operations rrconv2d_wgrad --type simt $(@D)
""",
tools = ["//brain/megbrain/dnn/scripts/cutlass_generator:generator.py"],
visibility = ["//visibility:public"],
......
......@@ -35,6 +35,8 @@ class Conv2dOperation:
without_shared_load=False,
required_cuda_ver_major=9,
required_cuda_ver_minor=2,
rin=None,
rout=None,
):
self.operation_kind = OperationKind.Conv2d
......@@ -54,6 +56,8 @@ class Conv2dOperation:
self.without_shared_load = without_shared_load
self.required_cuda_ver_major = required_cuda_ver_major
self.required_cuda_ver_minor = required_cuda_ver_minor
self.rin = rin
self.rout = rout
#
def accumulator_type(self):
......@@ -95,6 +99,8 @@ class Conv2dOperation:
conv_type_name = ""
if self.conv_type == ConvType.DepthwiseConvolution:
conv_type_name = "dw"
elif self.conv_type == ConvType.RegionRestrictedConvolution:
conv_type_name = "rr"
return "%s%s%s%s%s%s%s_%s" % (
ShortDataTypeNames[self.accumulator_type()],
......@@ -125,6 +131,9 @@ class Conv2dOperation:
elif self.src.element == self.flt.element:
extended_name = "${core_name}_${element_src}"
if self.rin != None:
extended_name += "_${element_rin}"
extended_name = SubstituteTemplate(
extended_name,
{
......@@ -132,6 +141,7 @@ class Conv2dOperation:
"element_flt": DataTypeNames[self.flt.element],
"element_dst": DataTypeNames[self.dst.element],
"core_name": self.core_name(),
"element_rin": DataTypeNames[self.rin.element],
},
)
......@@ -512,6 +522,115 @@ using Convolution_${operation_name} =
return SubstituteTemplate(self.template, values)
class EmitRegionRestrictedConvolutionBackwardFilterInstance:
def __init__(self):
self.template = """
// kernel instance "${operation_name}" generated by cutlass generator
using Convolution_${operation_name} =
typename cutlass::conv::device::RegionRestrictedConvolutionBackwardFilter<
${element_src},
${layout_src},
${element_diff},
${layout_diff},
${element_src_mask},
${layout_src_mask},
${element_output_mask},
${layout_output_mask},
${element_grad},
${layout_grad},
${element_accumulator},
${conv_type},
${opcode_class},
${arch},
cutlass::gemm::GemmShape<${threadblock_shape_m}, ${threadblock_shape_n}, ${threadblock_shape_k}>,
cutlass::gemm::GemmShape<${warp_shape_m}, ${warp_shape_n}, ${warp_shape_k}>,
cutlass::gemm::GemmShape<${instruction_shape_m}, ${instruction_shape_n}, ${instruction_shape_k}>,
${epilogue_functor}<
${element_grad},
${epilogue_vector_length},
${element_accumulator},
${element_epilogue}
>,
${swizzling_functor},
${stages},
${alignment_src},
${alignment_diff},
${alignment_src_mask},
${alignment_output_mask},
${special_optimization},
${math_operator},
${implicit_gemm_mode}>;
"""
def emit(self, operation):
warp_shape = [
int(
operation.tile_description.threadblock_shape[idx]
/ operation.tile_description.warp_count[idx]
)
for idx in range(3)
]
epilogue_vector_length = int(
min(operation.dst.alignment * DataTypeSize[operation.dst.element], 128)
/ DataTypeSize[operation.dst.element]
)
values = {
"operation_name": operation.procedural_name(),
"conv_type": ConvTypeTag[operation.conv_type],
"element_src": DataTypeTag[operation.src.element],
"layout_src": LayoutTag[operation.src.layout],
"element_diff": DataTypeTag[operation.flt.element],
"layout_diff": LayoutTag[operation.flt.layout],
"element_src_mask": DataTypeTag[operation.rin.element],
"layout_src_mask": LayoutTag[operation.rin.layout],
"element_output_mask": DataTypeTag[operation.rout.element],
"layout_output_mask": LayoutTag[operation.rout.layout],
"element_grad": DataTypeTag[operation.dst.element],
"layout_grad": LayoutTag[operation.dst.layout],
"element_accumulator": DataTypeTag[operation.accumulator_type()],
"opcode_class": OpcodeClassTag[
operation.tile_description.math_instruction.opcode_class
],
"arch": "cutlass::arch::Sm%d" % operation.arch,
"threadblock_shape_m": str(operation.tile_description.threadblock_shape[0]),
"threadblock_shape_n": str(operation.tile_description.threadblock_shape[1]),
"threadblock_shape_k": str(operation.tile_description.threadblock_shape[2]),
"warp_shape_m": str(warp_shape[0]),
"warp_shape_n": str(warp_shape[1]),
"warp_shape_k": str(warp_shape[2]),
"instruction_shape_m": str(
operation.tile_description.math_instruction.instruction_shape[0]
),
"instruction_shape_n": str(
operation.tile_description.math_instruction.instruction_shape[1]
),
"instruction_shape_k": str(
operation.tile_description.math_instruction.instruction_shape[2]
),
"epilogue_vector_length": str(epilogue_vector_length),
"epilogue_functor": EpilogueFunctorTag[operation.epilogue_functor],
"element_epilogue": str(DataTypeTag[operation.element_epilogue]),
"swizzling_functor": SwizzlingFunctorTag[operation.swizzling_functor],
"stages": str(operation.tile_description.stages),
"alignment_src": str(operation.src.alignment),
"alignment_diff": str(operation.flt.alignment),
"alignment_src_mask": str(operation.rin.alignment),
"alignment_output_mask": str(operation.rout.alignment),
"special_optimization": SpecialOptimizeDescTag[
operation.special_optimization
],
"math_operator": MathOperationTag[
operation.tile_description.math_instruction.math_operation
],
"implicit_gemm_mode": ImplicitGemmModeTag[operation.implicit_gemm_mode],
}
return SubstituteTemplate(self.template, values)
###################################################################################################
#
# Generator functions for all layouts
......@@ -540,7 +659,10 @@ def GenerateConv2d(
operations = []
element_epilogue = DataType.f32
if conv_type == ConvType.DepthwiseConvolution:
if (
conv_type == ConvType.DepthwiseConvolution
or conv_type == ConvType.RegionRestrictedConvolution
):
if conv_kind == ConvKind.Fprop:
swizzling_functor = SwizzlingFunctor.DepthwiseConvolutionFprop
elif conv_kind == ConvKind.Dgrad:
......@@ -680,6 +802,16 @@ def GenerateConv2d(
flt_layout,
int(flt_align / DataTypeSize[tile.math_instruction.element_a]),
)
rin = TensorDescription(
tile.math_instruction.element_rin,
src_layout,
int(src_align / DataTypeSize[tile.math_instruction.element_rin]),
)
rout = TensorDescription(
tile.math_instruction.element_rout,
dst_layout,
int(dst_align / DataTypeSize[tile.math_instruction.element_rout]),
)
bias = TensorDescription(
bias_type, dst_layout, max(1, int(32 / DataTypeSize[bias_type]))
)
......@@ -704,6 +836,8 @@ def GenerateConv2d(
without_shared_load,
required_cuda_ver_major,
required_cuda_ver_minor,
rin,
rout,
)
operations.append(new_operation)
if use_special_optimization != SpecialOptimizeDesc.NoneSpecialOpt:
......@@ -724,6 +858,8 @@ def GenerateConv2d(
without_shared_load,
required_cuda_ver_major,
required_cuda_ver_minor,
rin,
rout,
)
operations.append(new_operation)
return operations
......@@ -955,5 +1091,89 @@ void initialize_${operation_name}(Manifest &manifest) {
self.kernel_file.close()
class EmitRegionRestrictedConvSingleKernelWrapper:
def __init__(self, kernel_path, operation, short_path=False):
self.kernel_path = kernel_path
self.operation = operation
self.short_path = short_path
# Now only support wgrad
assert self.operation.conv_kind == ConvKind.Wgrad
self.instance_emitter = EmitRegionRestrictedConvolutionBackwardFilterInstance()
self.convolution_name = "RegionRestrictedConvolutionBackwardFilterOperation"
self.header_template = """
#if __CUDACC_VER_MAJOR__ > ${required_cuda_ver_major} || (__CUDACC_VER_MAJOR__ == ${required_cuda_ver_major} && __CUDACC_VER_MINOR__ >= ${required_cuda_ver_minor})
// ignore warning of cutlass
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wunused-parameter"
#pragma GCC diagnostic ignored "-Wstrict-aliasing"
#pragma GCC diagnostic ignored "-Wuninitialized"
#pragma GCC diagnostic ignored "-Wmaybe-uninitialized"
#include "cutlass/convolution/device/convolution.h"
#include "src/cuda/cutlass/manifest.h"
#include "src/cuda/cutlass/convolution_operation.h"
"""
self.instance_template = """
${operation_instance}
"""
self.manifest_template = """
namespace cutlass {
namespace library {
void initialize_${operation_name}(Manifest &manifest) {
manifest.append(new ${convolution_name}<Convolution_${operation_name}>(
"${operation_name}"
));
}
} // namespace library
} // namespace cutlass
"""
self.epilogue_template = """
#pragma GCC diagnostic pop
#endif
"""
#
def __enter__(self):
if self.short_path:
self.kernel_path = os.path.join(self.kernel_path, "%s.cu" % GlobalCnt.cnt)
GlobalCnt.cnt += 1
else:
self.kernel_path = os.path.join(
self.kernel_path, "%s.cu" % self.operation.procedural_name()
)
self.kernel_file = open(self.kernel_path, "w")
return self
#
def emit(self):
self.kernel_file.write(
SubstituteTemplate(
self.instance_template,
{"operation_instance": self.instance_emitter.emit(self.operation)},
)
)
# emit manifest helper
manifest = SubstituteTemplate(
self.manifest_template,
{
"operation_name": self.operation.procedural_name(),
"convolution_name": self.convolution_name,
},
)
self.kernel_file.write(manifest)
#
def __exit__(self, exception_type, exception_value, traceback):
self.kernel_file.close()
###################################################################################################
###################################################################################################
......@@ -64,4 +64,5 @@ if __name__ == "__main__":
write_merge_file_name(f, "dwconv2d_dgrad", "tensorop884", 4)
write_merge_file_name(f, "dwconv2d_wgrad", "simt", 2)
write_merge_file_name(f, "dwconv2d_wgrad", "tensorop884", 4)
write_merge_file_name(f, "rrconv2d_wgrad", "simt", 2)
f.write("]")
......@@ -1260,6 +1260,218 @@ def GenerateDwconv2d_Simt(args, conv_kind):
return operations
def GenerateRegionRestrictedconv2d_Simt(args, conv_kind):
################################################################################
# warps per threadblock
################################################################################
warpsPerThreadblocks = []
for warpsPerThreadblock0 in warpsPerThreadblockEdge:
for warpsPerThreadblock1 in warpsPerThreadblockEdge:
if (
warpsPerThreadblock0 / warpsPerThreadblock1 <= warpsPerThreadblockRatio
and warpsPerThreadblock1 / warpsPerThreadblock0
<= warpsPerThreadblockRatio
and warpsPerThreadblock0 * warpsPerThreadblock1
<= warpsPerThreadblockMax
):
warpsPerThreadblocks.append(
[warpsPerThreadblock0, warpsPerThreadblock1]
)
################################################################################
# warp shapes
################################################################################
warpNumThreads = 32
warpShapes = []
for warp0 in warpShapeEdges:
for warp1 in warpShapeEdges:
if (
warp0 / warp1 <= warpShapeRatio
and warp1 / warp0 <= warpShapeRatio
and warp0 * warp1 <= warpShapeMax
and warp0 * warp1 > warpShapeMin
):
warpShapes.append([warp0, warp1])
# sgemm
(
precisionType,
precisionBits,
threadblockMaxElements,
threadblockTilesL0,
) = precisions["s"]
layouts = [(LayoutType.TensorNCHW, LayoutType.TensorNCHW)]
math_instructions = [
MathInstruction(
[1, 1, 1],
DataType.f32,
DataType.f32,
DataType.f32,
OpcodeClass.Simt,
MathOperation.multiply_add,
DataType.s32,
DataType.s32,
),
MathInstruction(
[1, 1, 1],
DataType.f32,
DataType.f32,
DataType.f32,
OpcodeClass.Simt,
MathOperation.multiply_add,
DataType.s8,
DataType.s8,
),
]
min_cc = 50
max_cc = 1024
dst_layouts = [LayoutType.TensorNCHW]
dst_types = [DataType.f32]
if conv_kind == ConvKind.Wgrad:
alignment_constraints = [32]
else:
alignment_constraints = [128, 32]
operations = []
for math_inst in math_instructions:
tile_descriptions = [
TileDescription([128, 128, 8], 1, [4, 2, 1], math_inst, min_cc, max_cc),
TileDescription([128, 64, 8], 1, [2, 2, 1], math_inst, min_cc, max_cc),
TileDescription([64, 128, 8], 1, [2, 2, 1], math_inst, min_cc, max_cc),
TileDescription([128, 32, 8], 1, [2, 1, 1], math_inst, min_cc, max_cc),
TileDescription([32, 128, 8], 1, [1, 2, 1], math_inst, min_cc, max_cc),
TileDescription([64, 64, 8], 1, [2, 1, 1], math_inst, min_cc, max_cc),
TileDescription([32, 64, 8], 1, [1, 1, 1], math_inst, min_cc, max_cc),
TileDescription([64, 32, 8], 1, [1, 1, 1], math_inst, min_cc, max_cc),
TileDescription([32, 32, 8], 1, [1, 1, 1], math_inst, min_cc, max_cc),
]
for warpsPerThreadblock in warpsPerThreadblocks:
for warpShape in warpShapes:
warpThreadsM = 0
if warpShape[0] > warpShape[1]:
warpThreadsM = 8
else:
warpThreadsM = 4
warpThreadsN = warpNumThreads / warpThreadsM
# skip shapes with conflicting rectangularity
# they are unlikely to be fastest
blockG = warpsPerThreadblock[0] > warpsPerThreadblock[1]
blockL = warpsPerThreadblock[0] < warpsPerThreadblock[1]
warpG = warpShape[0] > warpShape[1]
warpL = warpShape[0] < warpShape[1]
blockG2 = warpsPerThreadblock[0] > warpsPerThreadblock[1] * 2
blockL2 = warpsPerThreadblock[0] * 2 < warpsPerThreadblock[1]
warpG2 = warpShape[0] > warpShape[1] * 2
warpL2 = warpShape[0] * 2 < warpShape[1]
if blockG2 and warpL:
continue
if blockL2 and warpG:
continue
if warpG2 and blockL:
continue
if warpL2 and blockG:
continue
# check threadblock ratios and max
threadblockTile = [
warpShape[0] * warpsPerThreadblock[0],
warpShape[1] * warpsPerThreadblock[1],
]
if threadblockTile[0] * threadblockTile[1] > threadblockMaxElements:
continue
if threadblockTile[0] > threadblockEdgeMax:
continue
if threadblockTile[1] > threadblockEdgeMax:
continue
totalThreads = (
warpNumThreads * warpsPerThreadblock[0] * warpsPerThreadblock[1]
)
# calculate unroll
# ensure that every iteration at least a full load of A,B are done
unrollMin = 8
unrollMin0 = totalThreads // threadblockTile[0]
unrollMin1 = totalThreads // threadblockTile[1]
unroll = max(unrollMin, unrollMin0, unrollMin1)
threadTileM = warpShape[0] // warpThreadsM
threadTileN = warpShape[1] // warpThreadsN
if threadTileM < 2 or threadTileN < 2:
continue
if threadTileM * threadTileN * precisionBits > 8 * 8 * 32:
continue
# epilogue currently only supports N < WarpNumThreads
if threadblockTile[1] < warpNumThreads:
continue
# limit smem
smemBitsA = threadblockTile[0] * unroll * 2 * precisionBits
smemBitsB = threadblockTile[1] * unroll * 2 * precisionBits
smemKBytes = (smemBitsA + smemBitsB) / 8 / 1024
if smemKBytes > 48:
continue
tile = TileDescription(
[threadblockTile[0], threadblockTile[1], unroll],
1,
[
threadblockTile[0] // warpShape[0],
threadblockTile[1] // warpShape[1],
1,
],
math_inst,
min_cc,
max_cc,
)
def filter(t: TileDescription) -> bool:
nonlocal tile
return (
t.threadblock_shape[0] == tile.threadblock_shape[0]
and t.threadblock_shape[1] == tile.threadblock_shape[1]
and t.threadblock_shape[2] == tile.threadblock_shape[2]
and t.warp_count[0] == tile.warp_count[0]
and t.warp_count[1] == tile.warp_count[1]
and t.warp_count[2] == tile.warp_count[2]
and t.stages == tile.stages
)
if not any(t for t in tile_descriptions if filter(t)):
continue
for layout in layouts:
for dst_type, dst_layout in zip(dst_types, dst_layouts):
for alignment_src in alignment_constraints:
operations += GenerateConv2d(
ConvType.RegionRestrictedConvolution,
conv_kind,
[tile],
layout[0],
layout[1],
dst_layout,
dst_type,
min_cc,
alignment_src,
32,
32,
SpecialOptimizeDesc.NoneSpecialOpt,
ImplicitGemmMode.GemmNT
if conv_kind == ConvKind.Wgrad
else ImplicitGemmMode.GemmTN,
)
return operations
#
def GenerateDwconv2d_TensorOp_884(args, conv_kind):
layouts = [(LayoutType.TensorNCHW, LayoutType.TensorNCHW)]
......@@ -1644,6 +1856,14 @@ def GenerateDwconv2dWgradOperations(args):
return GenerateDwconv2d_TensorOp_884(args, ConvKind.Wgrad)
def GenerateRegionRestrictedconv2dWgradOperations(args):
assert args.type == "simt", (
"operation RegionRestrictedconv2d wgrad only support"
"simt. (got:{})".format(args.type)
)
return GenerateRegionRestrictedconv2d_Simt(args, ConvKind.Wgrad)
def GenerateGemmOperations(args):
if args.type == "tensorop884":
return GeneratesGemm_TensorOp_884(args)
......@@ -1698,6 +1918,8 @@ def ConcatFile(
sub_string_1 = sub_string_2 = "simt"
if "dwconv2d_" in operations:
filtered_operations = operations[:2] + operations[9:]
if "rrconv2d_" in operations:
filtered_operations = operations[:2] + operations[9:]
elif ("conv2d" in operations) or ("deconv" in operations):
filtered_operations = "cutlass"
else:
......@@ -1893,6 +2115,7 @@ if __name__ == "__main__":
"dwconv2d_fprop",
"dwconv2d_dgrad",
"dwconv2d_wgrad",
"rrconv2d_wgrad",
],
required=True,
help="Specifies the operation to generate (gemm, gemv, conv2d, deconv, dwconv2d_fprop, dwconv2d_dgrad, dwconv2d_wgrad)",
......@@ -1928,9 +2151,11 @@ if __name__ == "__main__":
operations = GenerateDwconv2dFpropOperations(args)
elif args.operations == "dwconv2d_dgrad":
operations = GenerateDwconv2dDgradOperations(args)
else:
assert args.operations == "dwconv2d_wgrad", "invalid operation"
elif args.operations == "dwconv2d_wgrad":
operations = GenerateDwconv2dWgradOperations(args)
else:
assert args.operations == "rrconv2d_wgrad", "invalid operation"
operations = GenerateRegionRestrictedconv2dWgradOperations(args)
if (
args.operations == "conv2d"
......@@ -1974,6 +2199,42 @@ if __name__ == "__main__":
required_cuda_ver_minor,
epilogue,
)
elif args.operations == "rrconv2d_wgrad":
for operation in operations:
with EmitRegionRestrictedConvSingleKernelWrapper(
args.output, operation, short_path
) as emitter:
emitter.emit()
head = EmitRegionRestrictedConvSingleKernelWrapper(
args.output, operations[0], short_path
).header_template
required_cuda_ver_major = operations[0].required_cuda_ver_major
required_cuda_ver_minor = operations[0].required_cuda_ver_minor
epilogue = EmitRegionRestrictedConvSingleKernelWrapper(
args.output, operations[0], short_path
).epilogue_template
if "tensorop" in args.type:
ConcatFile(
4,
args.output,
args.operations,
args.type,
head,
required_cuda_ver_major,
required_cuda_ver_minor,
epilogue,
)
else:
ConcatFile(
2,
args.output,
args.operations,
args.type,
head,
required_cuda_ver_major,
required_cuda_ver_minor,
epilogue,
)
elif args.operations == "gemm":
for operation in operations:
with EmitGemmSingleKernelWrapper(
......
......@@ -532,6 +532,7 @@ class ConvType(enum.Enum):
Local = enum_auto()
LocalShare = enum_auto()
DepthwiseConvolution = enum_auto()
RegionRestrictedConvolution = enum_auto()
ConvTypeTag = {
......@@ -540,6 +541,8 @@ ConvTypeTag = {
ConvType.Local: "cutlass::conv::ConvType::kLocal",
ConvType.LocalShare: "cutlass::conv::ConvType::kLocalShare",
ConvType.DepthwiseConvolution: "cutlass::conv::ConvType::kDepthwiseConvolution",
# RegionRestrictedConvolution using the same conv type with Depthwise
ConvType.RegionRestrictedConvolution: "cutlass::conv::ConvType::kDepthwiseConvolution",
}
#
......@@ -640,6 +643,8 @@ class MathInstruction:
element_accumulator,
opcode_class,
math_operation=MathOperation.multiply_add,
element_rin=DataType.s32,
element_rout=DataType.s32,
):
self.instruction_shape = instruction_shape
self.element_a = element_a
......@@ -647,6 +652,8 @@ class MathInstruction:
self.element_accumulator = element_accumulator
self.opcode_class = opcode_class
self.math_operation = math_operation
self.element_rin = element_rin
self.element_rout = element_rout
#
......
......@@ -85,4 +85,7 @@ cutlass_gen_list = [
"dwconv2d_wgrad_tensorop884_2.cu",
"dwconv2d_wgrad_tensorop884_3.cu",
"all_dwconv2d_wgrad_tensorop884_operations.cu",
"rrconv2d_wgrad_simt_0.cu",
"rrconv2d_wgrad_simt_1.cu",
"all_rrconv2d_wgrad_simt_operations.cu",
]
\ No newline at end of file
......@@ -188,6 +188,7 @@ if(MGE_WITH_CUDA)
gen_cutlass_kimpl(dwconv2d_dgrad tensorop884 CUTLASS_SOURCES)
gen_cutlass_kimpl(dwconv2d_wgrad simt CUTLASS_SOURCES)
gen_cutlass_kimpl(dwconv2d_wgrad tensorop884 CUTLASS_SOURCES)
gen_cutlass_kimpl(rrconv2d_wgrad simt CUTLASS_SOURCES)
list(PREPEND CUSOURCES ${CUTLASS_SOURCES})
# Compile the following file first, the priority_compile_opr.txt is generated by
......
......@@ -452,6 +452,86 @@ public:
///////////////////////////////////////////////////////////////////////////////////////////////////
template <typename Operator_>
class RegionRestrictedConvolutionBackwardFilterOperation
: public ConvolutionBackwardFilterOperationBase<Operator_> {
public:
using Operator = Operator_;
using ElementSrc = typename Operator::ElementSrc;
using LayoutSrc = typename Operator::LayoutSrc;
using ElementDiff = typename Operator::ElementDiff;
using LayoutDiff = typename Operator::LayoutDiff;
using ElementGrad = typename Operator::ElementGrad;
using LayoutGrad = typename Operator::LayoutGrad;
using ElementAccumulator = typename Operator::ElementAccumulator;
using ElementCompute = typename Operator::EpilogueOutputOp::ElementCompute;
using OperatorArguments = typename Operator::Arguments;
using ElementRin = typename Operator::ElementMaskInput;
using LayoutRin = typename Operator::LayoutMaskInput;
using ElementRout = typename Operator::ElementMaskOutput;
using LayoutRout = typename Operator::LayoutMaskOutput;
RegionRestrictedConvolutionBackwardFilterOperation(
char const* name = "unknown_gemm")
: ConvolutionBackwardFilterOperationBase<Operator_>(name) {
/// rin in description -> rin in C++ template
this->m_description.rin = make_TensorDescription<ElementRin, LayoutRin>(
Operator::kAlignmentMaskInput);
/// rout in description -> rout in C++ template
this->m_description.rout = make_TensorDescription<ElementRout, LayoutRout>(
Operator::kAlignmentMaskOutput);
this->m_description.without_shared_load = false;
}
virtual Status run(
void const* arguments_ptr, void* device_workspace = nullptr,
cudaStream_t stream = nullptr) const {
cutlass::conv::Operator conv_op = this->m_description.conv_op;
ConvolutionArguments const* conv_args =
reinterpret_cast<ConvolutionArguments const*>(arguments_ptr);
const auto& ps = conv_args->problem_size;
OperatorArguments args;
args.problem_size = ps;
/// src in convolution arguments -> ref_src
args.ref_src = {
static_cast<ElementSrc*>(const_cast<void*>(conv_args->src)),
LayoutSrc::packed(implicit_gemm_tensor_b_extent(conv_op, ps))};
/// filter in convolution arguments -> ref_diff
args.ref_diff = {
static_cast<ElementDiff*>(const_cast<void*>(conv_args->filter)),
LayoutDiff::packed(implicit_gemm_tensor_a_extent(conv_op, ps))};
/// dst in convolution arguments -> ref_grad
args.ref_grad = {
static_cast<ElementGrad*>(conv_args->dst),
LayoutGrad::packed(implicit_gemm_tensor_c_extent(conv_op, ps))};
/// rin in convolution arguments -> ref_mask_input
args.ref_mask_input = {
static_cast<ElementRin*>(const_cast<void*>(conv_args->rin)),
LayoutRin::packed(implicit_gemm_tensor_rin_extent(conv_op, ps))};
/// rout in convolution arguments -> ref_mask_output
args.ref_mask_output = {
static_cast<ElementRout*>(const_cast<void*>(conv_args->rout)),
LayoutRout::packed(implicit_gemm_tensor_rout_extent(conv_op, ps))};
args.output_op = init_epilogue_param<typename Operator::EpilogueOutputOp>().get(
conv_args);
Operator op;
Status status = op.initialize(args, device_workspace);
if (status != Status::kSuccess) {
return status;
}
return op.run(stream);
}
};
///////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace library
} // namespace cutlass
......
......@@ -50,6 +50,7 @@ void initialize_all_deconv_simt_operations(Manifest& manifest);
void initialize_all_dwconv2d_fprop_simt_operations(Manifest& manifest);
void initialize_all_dwconv2d_dgrad_simt_operations(Manifest& manifest);
void initialize_all_dwconv2d_wgrad_simt_operations(Manifest& manifest);
void initialize_all_rrconv2d_wgrad_simt_operations(Manifest& manifest);
#if defined(CUTLASS_ARCH_MMA_SM70_SUPPORTED) && CUTLASS_ARCH_MMA_SM70_SUPPORTED
void initialize_all_gemm_tensorop884_operations(Manifest& manifest);
void initialize_all_dwconv2d_fprop_tensorop884_operations(Manifest& manifest);
......@@ -70,6 +71,7 @@ void initialize_all(Manifest& manifest) {
initialize_all_dwconv2d_fprop_simt_operations(manifest);
initialize_all_dwconv2d_dgrad_simt_operations(manifest);
initialize_all_dwconv2d_wgrad_simt_operations(manifest);
initialize_all_rrconv2d_wgrad_simt_operations(manifest);
#if defined(CUTLASS_ARCH_MMA_SM70_SUPPORTED) && CUTLASS_ARCH_MMA_SM70_SUPPORTED
initialize_all_gemm_tensorop884_operations(manifest);
initialize_all_dwconv2d_fprop_tensorop884_operations(manifest);
......
......@@ -471,6 +471,10 @@ struct ConvolutionDescription : public OperationDescription {
conv::SpecialOptimizeDesc special_optimization;
conv::ImplicitGemmMode gemm_mode;
bool without_shared_load;
// only used by rrconv
TensorDescription rin;
TensorDescription rout;
};
/////////////////////////////////////////////////////////////////////////////////////////////////
......@@ -499,6 +503,10 @@ struct ConvolutionArguments {
/// Host pointer to extra param struct
void const* extra_param;
// only used by rrconv, default: nullptr
void const* rin = nullptr;
void const* rout = nullptr;
};
/////////////////////////////////////////////////////////////////////////////////////////////////
......
......@@ -118,6 +118,11 @@ ConvolutionKey get_convolution_key_from_desc(const ConvolutionDescription& desc)
key.alignment_filter = desc.filter.alignment;
key.without_shared_load = desc.without_shared_load;
key.element_rin = desc.rin.element;
key.layout_rin = desc.rin.layout;
key.element_rout = desc.rout.element;
key.layout_rout = desc.rout.layout;
return key;
}
......
......@@ -201,6 +201,12 @@ struct ConvolutionKey {
bool without_shared_load;
// only used by rrconv
library::NumericTypeID element_rin = library::NumericTypeID::kInvalid;
library::LayoutTypeID layout_rin = library::LayoutTypeID::kInvalid;
library::NumericTypeID element_rout = library::NumericTypeID::kInvalid;
library::LayoutTypeID layout_rout = library::LayoutTypeID::kInvalid;
inline bool operator==(ConvolutionKey const& rhs) const {
return (conv_op == rhs.conv_op) && (element_src == rhs.element_src) &&
(layout_src == rhs.layout_src) &&
......@@ -223,7 +229,9 @@ struct ConvolutionKey {
(special_optimization == rhs.special_optimization) &&
(alignment_src == rhs.alignment_src) &&
(alignment_filter == rhs.alignment_filter) &&
(without_shared_load == rhs.without_shared_load);
(without_shared_load == rhs.without_shared_load) &&
(element_rin == rhs.element_rin) && (layout_rin == rhs.layout_rin) &&
(element_rout == rhs.element_rout) && (layout_rout == rhs.layout_rout);
}
inline bool operator!=(ConvolutionKey const& rhs) const { return !(*this == rhs); }
......@@ -260,7 +268,11 @@ struct ConvolutionKey {
"\n special_optimization: " + to_string(special_optimization) +
"\n alignment_src: " + std::to_string(alignment_src) +
"\n alignment_filter: " + std::to_string(alignment_filter) +
"\n without_shared_load: " + to_string(without_shared_load) + "\n}";
"\n without_shared_load: " + to_string(without_shared_load) +
"\n element_rin: " + to_string(element_rin) +
"\n layout_rin: " + to_string(layout_rin) +
"\n element_rout: " + to_string(element_rout) +
"\n layout_rout: " + to_string(layout_rout) + "\n}";
}
};
......@@ -293,6 +305,10 @@ struct ConvolutionKeyHasher {
.update(&key.alignment_src, sizeof(key.alignment_src))
.update(&key.alignment_filter, sizeof(key.alignment_filter))
.update(&key.without_shared_load, sizeof(key.without_shared_load))
.update(&key.element_rin, sizeof(key.element_rin))
.update(&key.layout_rin, sizeof(key.layout_rin))
.update(&key.element_rout, sizeof(key.element_rout))
.update(&key.layout_rout, sizeof(key.layout_rout))
.digest();
}
};
......
#include "src/cuda/region_restricted_convolution/opr_impl.h"
#include "src/cuda/cutlass/singleton.h"
#include "src/cuda/region_restricted_convolution/chanwise/depthwise_large_filter.cuh"
#include "src/cuda/region_restricted_convolution/chanwise/kern.cuh"
#include "src/cuda/utils.h"
......@@ -6,6 +7,7 @@
using namespace megdnn;
using namespace cuda;
using namespace region_restricted_convolution;
using namespace cutlass::library;
/* ============== RegionRestrictedConvolutionForwardImpl ============== */
void RegionRestrictedConvolutionForwardImpl::exec(
......@@ -113,7 +115,137 @@ size_t RegionRestrictedConvolutionBackwardFilterImpl::get_workspace_in_bytes(
void RegionRestrictedConvolutionBackwardFilterImpl::exec(
_megdnn_tensor_in src, _megdnn_tensor_in diff, _megdnn_tensor_in rin,
_megdnn_tensor_in rout, _megdnn_tensor_out grad, _megdnn_workspace workspace) {
megdnn_throw("Region Restricted Conv BackwardFilter unimplemented");
auto fm = check_exec(
src.layout, diff.layout, rin.layout, rout.layout, grad.layout,
workspace.size);
megdnn_assert(
fm.group > 1 && src.layout.dtype.category() == DTypeCategory::FLOAT &&
param().compute_mode == Param::ComputeMode::DEFAULT &&
fm.spatial_ndim == 2 && fm.icpg == 1 && fm.ocpg == 1 &&
fm.dilation[0] == 1 && fm.dilation[1] == 1 && !fm.should_flip &&
param().stride_h == 1 && param().stride_w == 1);
int hi = src.layout.operator[](2), wi = src.layout.operator[](3);
int n = diff.layout.operator[](0), ho = diff.layout.operator[](2),
wo = diff.layout.operator[](3);
int co = fm.group, ci = co, groups = co;
int fh = fm.spatial[0], fw = fm.spatial[1];
int sh = fm.stride[0], sw = fm.stride[1];
int ph = fm.padding[0], pw = fm.padding[1];
int dh = 0, dw = 0;
// check if channelwise convolution
megdnn_assert(fm.icpg == 1 && fm.ocpg == 1);
auto stream = cuda_stream(handle());
float alpha = 1.f;
float beta = 0.f;
ConvolutionKey key;
int threadblock_shape_n = 128;
int warp_shape_m = 32;
int warp_shape_n = 64;
if (grad.layout.operator[](3) % 8 < 4) {
threadblock_shape_n = 64;
warp_shape_m = 64;
warp_shape_n = 32;
}
if (rin.layout.dtype == dtype::Int32() && rout.layout.dtype == dtype::Int32()) {
key = {
cutlass::conv::Operator::kWgrad,
NumericTypeID::kF32,
LayoutTypeID::kTensorNCHW,
NumericTypeID::kF32,
LayoutTypeID::kTensorNCHW,
NumericTypeID::kF32,
LayoutTypeID::kTensorNCHW,
NumericTypeID::kF32,
LayoutTypeID::kTensorNCHW,
NumericTypeID::kF32,
cutlass::conv::ConvType::kDepthwiseConvolution,
128,
threadblock_shape_n,
8,
warp_shape_m,
warp_shape_n,
8,
1,
1,
1,
cutlass::epilogue::EpilogueType::kLinearCombination,
1,
cutlass::conv::SpecialOptimizeDesc::NONE,
1,
1,
false,
NumericTypeID::kS32,
LayoutTypeID::kTensorNCHW,
NumericTypeID::kS32,
LayoutTypeID::kTensorNCHW,
};
} else if (
rin.layout.dtype == dtype::Uint8() && rout.layout.dtype == dtype::Uint8()) {
key = {
cutlass::conv::Operator::kWgrad,
NumericTypeID::kF32,
LayoutTypeID::kTensorNCHW,
NumericTypeID::kF32,
LayoutTypeID::kTensorNCHW,
NumericTypeID::kF32,
LayoutTypeID::kTensorNCHW,
NumericTypeID::kF32,
LayoutTypeID::kTensorNCHW,
NumericTypeID::kF32,
cutlass::conv::ConvType::kDepthwiseConvolution,
128,
threadblock_shape_n,
8,
warp_shape_m,
warp_shape_n,
8,
1,
1,
1,
cutlass::epilogue::EpilogueType::kLinearCombination,
1,
cutlass::conv::SpecialOptimizeDesc::NONE,
1,
1,
false,
NumericTypeID::kS8,
LayoutTypeID::kTensorNCHW,
NumericTypeID::kS8,
LayoutTypeID::kTensorNCHW,
};
} else {
megdnn_throw(ssprintf(
"don't support region restricted type rin: %s, rout: %s",
rin.layout.dtype.name(), rout.layout.dtype.name())
.c_str());
}
const Operation* op =
(const Operation*)Singleton::get().operation_table.find_op(key);
cutlass::conv::Conv2dProblemSize problem_size{
n, hi, wi, ci, co, fh, fw, ho,
wo, ph, pw, sh, sw, dh, dw, cutlass::conv::Mode::kCrossCorrelation,
1, // split k slices, always 1
groups, // groups
};
cutlass::library::ConvolutionArguments conv_args{
problem_size, src.raw_ptr(), diff.raw_ptr(), nullptr,
nullptr, grad.raw_ptr(), &alpha, &beta,
nullptr, nullptr, nullptr, nullptr,
nullptr, nullptr, rin.raw_ptr(), rout.raw_ptr()};
cutlass_check(op->run(&conv_args, nullptr, stream));
after_kernel_launch();
}
// vim: syntax=cpp.doxygen
......@@ -465,6 +465,206 @@ TEST_F(CUDA, BENCHMARK_REGION_RESTRICTED_CONV_FORWARD_LARGE_FILTER_UINT8) {
run_bench(64, 384, 32, 32, 31, 31, 1, 1, 10);
}
TEST_F(CUDA, BENCHMARK_REGION_RESTRICTED_CONV_BACKWARD_FILTER_FP32) {
require_compute_capability(7, 5);
Benchmarker<ConvolutionBackwardFilter> bencher(handle_cuda());
bencher.set_display(false);
bencher.set_before_exec_callback(AlgoChecker<ConvolutionBackwardFilter>(
"FLOAT32_NCHW_FMA_IMPLICIT_BATCHED_GEMM_128X128X8_32X64X8_2stage"));
Benchmarker<RegionRestrictedConvolutionBackwardFilter> rr_bencher(handle_cuda());
rr_bencher.set_display(false);
ConvolutionBackwardFilter::Param param;
param.format = ConvolutionBackwardFilter::Param::Format::NCHW;
param.sparse = ConvolutionBackwardFilter::Param::Sparse::GROUP;
RegionRestrictedConvolutionBackwardFilter::Param rr_param;
rr_param.format = RegionRestrictedConvolutionBackwardFilter::Param::Format::NCHW;
rr_param.sparse = RegionRestrictedConvolutionBackwardFilter::Param::Sparse::GROUP;
UniformIntRNG r_rng{1, 3};
auto run_bench = [&](size_t batch, size_t g, size_t hi, size_t wi, size_t fh,
size_t fw, size_t sh, size_t sw, size_t nr_times) {
param.pad_h = fh / 2;
param.pad_w = fw / 2;
param.stride_h = sh;
param.stride_w = sw;
rr_param.pad_h = fh / 2;
rr_param.pad_w = fw / 2;
rr_param.stride_h = sh;
rr_param.stride_w = sw;
bencher.set_param(param)
.set_dtype(0, dtype::Float32())
.set_dtype(1, dtype::Float32())
.set_dtype(2, dtype::Float32())
.set_dtype(4, dtype::Float32());
bencher.proxy()->target_execution_policy = {};
bencher.set_times(nr_times);
rr_bencher.set_param(rr_param)
.set_dtype(0, dtype::Float32())
.set_dtype(1, dtype::Float32())
.set_dtype(2, dtype::Int32())
.set_dtype(3, dtype::Int32());
rr_bencher.set_rng(2, &r_rng).set_rng(3, &r_rng);
rr_bencher.set_times(nr_times);
size_t ho = infer_conv_shape(hi, fh, sh, param.pad_h);
size_t wo = infer_conv_shape(wi, fw, sw, param.pad_w);
TensorShape src{batch, g, hi, wi}, diff{batch, g, ho, wo}, rin{batch, hi, wi},
rout{batch, ho, wo}, grad{g, 1, 1, fh, fw};
float bandwith = static_cast<float>(
src.total_nr_elems() + diff.total_nr_elems() +
grad.total_nr_elems()) /
(1024 * 1024 * 1024) * 1e3;
float rr_bandwith = static_cast<float>(
src.total_nr_elems() + diff.total_nr_elems() +
rin.total_nr_elems() + rout.total_nr_elems() +
grad.total_nr_elems()) /
(1024 * 1024 * 1024) * 1e3;
auto time_in_ms = bencher.execs({src, diff, grad}) / nr_times;
auto ops = 2.0 * batch * g * hi * wi * fh * fw / (time_in_ms * 1e-3) * 1e-12;
auto rr_time_in_ms = rr_bencher.execs({src, diff, rin, rout, grad}) / nr_times;
auto rr_ops =
2.0 * batch * g * hi * wi * fh * fw / (rr_time_in_ms * 1e-3) * 1e-12;
printf("[DGRAD]RegionRestrictedDepthwiseLargeFilter vs DepthwiseLargeFilter: "
"src=%s, "
"diff=%s, grad=%s\n"
"time: %.2f ms, time(rr): %.2f ms, perf: %.2fTops, perf(rr): %.2f Tops\n"
"bandwidth: %.2fGB/s, bandwidth(rr): %.2fGB/s, speedup: %.2f.\n",
src.to_string().c_str(), diff.to_string().c_str(),
grad.to_string().c_str(), time_in_ms, rr_time_in_ms, ops, rr_ops,
bandwith * 4 / time_in_ms, rr_bandwith * 4 / rr_time_in_ms,
time_in_ms / rr_time_in_ms);
};
run_bench(64, 384, 32, 32, 3, 3, 1, 1, 1000);
run_bench(64, 384, 32, 32, 5, 5, 1, 1, 1000);
run_bench(64, 384, 32, 32, 7, 7, 1, 1, 1000);
run_bench(64, 384, 32, 32, 9, 9, 1, 1, 1000);
run_bench(64, 384, 32, 32, 11, 11, 1, 1, 1000);
run_bench(64, 384, 32, 32, 13, 13, 1, 1, 1000);
run_bench(64, 384, 32, 32, 15, 15, 1, 1, 1000);
run_bench(64, 384, 32, 32, 17, 17, 1, 1, 1000);
run_bench(64, 384, 32, 32, 19, 19, 1, 1, 1000);
run_bench(64, 384, 32, 32, 21, 21, 1, 1, 1000);
run_bench(64, 384, 32, 32, 23, 23, 1, 1, 1000);
run_bench(64, 384, 32, 32, 25, 25, 1, 1, 1000);
run_bench(64, 384, 32, 32, 27, 27, 1, 1, 1000);
run_bench(64, 384, 32, 32, 29, 29, 1, 1, 1000);
run_bench(64, 384, 32, 32, 31, 31, 1, 1, 1000);
}
TEST_F(CUDA, BENCHMARK_REGION_RESTRICTED_CONV_BACKWARD_FILTER_FP32_RINT8) {
require_compute_capability(7, 5);
Benchmarker<ConvolutionBackwardFilter> bencher(handle_cuda());
bencher.set_display(false);
bencher.set_before_exec_callback(AlgoChecker<ConvolutionBackwardFilter>(
"FLOAT32_NCHW_FMA_IMPLICIT_BATCHED_GEMM_128X128X8_32X64X8_2stage"));
Benchmarker<RegionRestrictedConvolutionBackwardFilter> rr_bencher(handle_cuda());
rr_bencher.set_display(false);
ConvolutionBackwardFilter::Param param;
param.format = ConvolutionBackwardFilter::Param::Format::NCHW;
param.sparse = ConvolutionBackwardFilter::Param::Sparse::GROUP;
RegionRestrictedConvolutionBackwardFilter::Param rr_param;
rr_param.format = RegionRestrictedConvolutionBackwardFilter::Param::Format::NCHW;
rr_param.sparse = RegionRestrictedConvolutionBackwardFilter::Param::Sparse::GROUP;
UniformIntRNG r_rng{1, 3};
auto run_bench = [&](size_t batch, size_t g, size_t hi, size_t wi, size_t fh,
size_t fw, size_t sh, size_t sw, size_t nr_times) {
param.pad_h = fh / 2;
param.pad_w = fw / 2;
param.stride_h = sh;
param.stride_w = sw;
rr_param.pad_h = fh / 2;
rr_param.pad_w = fw / 2;
rr_param.stride_h = sh;
rr_param.stride_w = sw;
bencher.set_param(param)
.set_dtype(0, dtype::Float32())
.set_dtype(1, dtype::Float32())
.set_dtype(2, dtype::Float32())
.set_dtype(4, dtype::Float32());
bencher.proxy()->target_execution_policy = {};
bencher.set_times(nr_times);
rr_bencher.set_param(rr_param)
.set_dtype(0, dtype::Float32())
.set_dtype(1, dtype::Float32())
.set_dtype(2, dtype::Uint8())
.set_dtype(3, dtype::Uint8());
rr_bencher.set_rng(2, &r_rng).set_rng(3, &r_rng);
rr_bencher.set_times(nr_times);
size_t ho = infer_conv_shape(hi, fh, sh, param.pad_h);
size_t wo = infer_conv_shape(wi, fw, sw, param.pad_w);
TensorShape src{batch, g, hi, wi}, diff{batch, g, ho, wo}, rin{batch, hi, wi},
rout{batch, ho, wo}, grad{g, 1, 1, fh, fw};
float bandwith = static_cast<float>(
src.total_nr_elems() + diff.total_nr_elems() +
grad.total_nr_elems()) /
(1024 * 1024 * 1024) * 1e3;
float rr_bandwith = static_cast<float>(
src.total_nr_elems() + diff.total_nr_elems() +
rin.total_nr_elems() + rout.total_nr_elems() +
grad.total_nr_elems()) /
(1024 * 1024 * 1024) * 1e3;
auto time_in_ms = bencher.execs({src, diff, grad}) / nr_times;
auto ops = 2.0 * batch * g * hi * wi * fh * fw / (time_in_ms * 1e-3) * 1e-12;
auto rr_time_in_ms = rr_bencher.execs({src, diff, rin, rout, grad}) / nr_times;
auto rr_ops =
2.0 * batch * g * hi * wi * fh * fw / (rr_time_in_ms * 1e-3) * 1e-12;
printf("[DGRAD]RegionRestrictedDepthwiseLargeFilter vs DepthwiseLargeFilter: "
"src=%s, "
"diff=%s, grad=%s\n"
"time: %.2f ms, time(rr): %.2f ms, perf: %.2fTops, perf(rr): %.2f Tops\n"
"bandwidth: %.2fGB/s, bandwidth(rr): %.2fGB/s, speedup: %.2f.\n",
src.to_string().c_str(), diff.to_string().c_str(),
grad.to_string().c_str(), time_in_ms, rr_time_in_ms, ops, rr_ops,
bandwith * 4 / time_in_ms, rr_bandwith * 4 / rr_time_in_ms,
time_in_ms / rr_time_in_ms);
};
run_bench(64, 384, 32, 32, 3, 3, 1, 1, 1000);
run_bench(64, 384, 32, 32, 5, 5, 1, 1, 1000);
run_bench(64, 384, 32, 32, 7, 7, 1, 1, 1000);
run_bench(64, 384, 32, 32, 9, 9, 1, 1, 1000);
run_bench(64, 384, 32, 32, 11, 11, 1, 1, 1000);
run_bench(64, 384, 32, 32, 13, 13, 1, 1, 1000);
run_bench(64, 384, 32, 32, 15, 15, 1, 1, 1000);
run_bench(64, 384, 32, 32, 17, 17, 1, 1, 1000);
run_bench(64, 384, 32, 32, 19, 19, 1, 1, 1000);
run_bench(64, 384, 32, 32, 21, 21, 1, 1, 1000);
run_bench(64, 384, 32, 32, 23, 23, 1, 1, 1000);
run_bench(64, 384, 32, 32, 25, 25, 1, 1, 1000);
run_bench(64, 384, 32, 32, 27, 27, 1, 1, 1000);
run_bench(64, 384, 32, 32, 29, 29, 1, 1, 1000);
run_bench(64, 384, 32, 32, 31, 31, 1, 1, 1000);
}
#endif
TEST_F(CUDA, REGION_RESTRICTED_CONV_BWD_DATA_FP32) {
......@@ -585,6 +785,125 @@ TEST_F(CUDA, REGION_RESTRICTED_CONV_BWD_DATA_FP32_RIN_EQ_ROUT) {
}
}
TEST_F(CUDA, REGION_RESTRICTED_CONV_BWD_FILTER_FP32) {
Checker<RegionRestrictedConvolutionBackwardFilter> checker(handle_cuda());
for (auto dt : std::vector<DType>{dtype::Int32(), dtype::Uint8()}) {
auto run = [&checker, &dt](
size_t n, size_t g, size_t ih, size_t fh, size_t padding,
size_t stride) {
RegionRestrictedConvolutionBackwardFilter::Param cur_param;
cur_param.mode = RegionRestrictedConvolutionBackwardFilter::Param::Mode::
CROSS_CORRELATION;
cur_param.compute_mode = RegionRestrictedConvolutionBackwardFilter::Param::
ComputeMode::DEFAULT;
cur_param.sparse =
RegionRestrictedConvolutionBackwardFilter::Param::Sparse::GROUP;
checker.set_dtype(0, dtype::Float32())
.set_dtype(1, dtype::Float32())
.set_dtype(2, dt)
.set_dtype(3, dt);
float scale = 64.f / sqrt(fh * fh);
UniformFloatRNG rng(scale, 2 * scale);
UniformIntRNG r_rng{1, 2};
checker.set_rng(0, &rng).set_rng(1, &rng).set_rng(2, &r_rng).set_rng(
3, &r_rng);
cur_param.pad_h = cur_param.pad_w = padding;
cur_param.stride_h = cur_param.stride_w = stride;
size_t oh = (ih + 2 * padding - fh + 1) / stride;
checker.set_param(cur_param).execs({
{n, g * 1, ih, ih}, // src
{n, g * 1, oh, oh}, // diff
{n, ih, ih}, // rin
{n, oh, oh}, // rout
{g, 1, 1, fh, fh} // grad
});
};
if (dt == dtype::Int32()) {
run(4, 8, 32, 5, 5 / 2, 1);
run(1, 2, 2, 2, 0, 1);
run(1, 2, 3, 3, 0, 1);
run(1, 2, 4, 4, 0, 1);
run(1, 2, 5, 5, 0, 1);
run(1, 2, 6, 6, 0, 1);
run(1, 2, 7, 7, 0, 1);
}
run(4, 8, 32, 7, 7 / 2, 1);
run(4, 8, 32, 9, 9 / 2, 1);
run(4, 8, 32, 11, 11 / 2, 1);
run(4, 8, 32, 13, 13 / 2, 1);
run(4, 8, 32, 15, 15 / 2, 1);
run(4, 8, 32, 17, 17 / 2, 1);
run(4, 8, 32, 19, 19 / 2, 1);
run(4, 8, 32, 21, 21 / 2, 1);
run(4, 8, 32, 23, 23 / 2, 1);
run(4, 8, 32, 25, 25 / 2, 1);
run(4, 8, 32, 27, 27 / 2, 1);
run(4, 8, 32, 29, 29 / 2, 1);
run(4, 8, 32, 31, 31 / 2, 1);
}
}
TEST_F(CUDA, REGION_RESTRICTED_CONV_BWD_FILTER_FP32_RIN_EQ_ROUT) {
Checker<RegionRestrictedConvolutionBackwardFilter> checker(handle_cuda());
for (auto dt : std::vector<DType>{dtype::Int32(), dtype::Uint8()}) {
auto run = [&checker, &dt](
size_t n, size_t g, size_t ih, size_t fh, size_t padding,
size_t stride) {
RegionRestrictedConvolutionBackwardFilter::Param cur_param;
cur_param.mode = RegionRestrictedConvolutionBackwardFilter::Param::Mode::
CROSS_CORRELATION;
cur_param.compute_mode = RegionRestrictedConvolutionBackwardFilter::Param::
ComputeMode::DEFAULT;
cur_param.sparse =
RegionRestrictedConvolutionBackwardFilter::Param::Sparse::GROUP;
checker.set_dtype(0, dtype::Float32())
.set_dtype(1, dtype::Float32())
.set_dtype(2, dt)
.set_dtype(3, dt);
float scale = 64.f / sqrt(fh * fh);
UniformFloatRNG rng(scale, 2 * scale);
UniformIntRNG r_rng{1, 1};
checker.set_rng(0, &rng).set_rng(1, &rng).set_rng(2, &r_rng).set_rng(
3, &r_rng);
cur_param.pad_h = cur_param.pad_w = padding;
cur_param.stride_h = cur_param.stride_w = stride;
size_t oh = (ih + 2 * padding - fh + 1) / stride;
checker.set_param(cur_param).execs({
{n, g * 1, ih, ih}, // src
{n, g * 1, oh, oh}, // diff
{n, ih, ih}, // rin
{n, oh, oh}, // rout
{g, 1, 1, fh, fh} // grad
});
};
if (dt == dtype::Int32()) {
run(4, 8, 32, 5, 5 / 2, 1);
run(1, 2, 2, 2, 0, 1);
run(1, 2, 3, 3, 0, 1);
run(1, 2, 4, 4, 0, 1);
run(1, 2, 5, 5, 0, 1);
run(1, 2, 6, 6, 0, 1);
run(1, 2, 7, 7, 0, 1);
}
run(4, 8, 32, 7, 7 / 2, 1);
run(4, 8, 32, 9, 9 / 2, 1);
run(4, 8, 32, 11, 11 / 2, 1);
run(4, 8, 32, 13, 13 / 2, 1);
run(4, 8, 32, 15, 15 / 2, 1);
run(4, 8, 32, 17, 17 / 2, 1);
run(4, 8, 32, 19, 19 / 2, 1);
run(4, 8, 32, 21, 21 / 2, 1);
run(4, 8, 32, 23, 23 / 2, 1);
run(4, 8, 32, 25, 25 / 2, 1);
run(4, 8, 32, 27, 27 / 2, 1);
run(4, 8, 32, 29, 29 / 2, 1);
run(4, 8, 32, 31, 31 / 2, 1);
}
}
} // namespace test
} // namespace megdnn
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册