提交 55042195 编写于 作者: M Megvii Engine Team

chore(winograd): add Convolutionv2 param

GitOrigin-RevId: 1a9e2ea340f6eb37b6a03db53038ccc053e77635
上级 7191c4bd
......@@ -53,7 +53,7 @@ pdef('Axis').add_fields('int32', 'axis', 0)
'of convolution using CUDA/SASS. The channels are splitted to groups of 4 channels.'))
)
(pdef('Convolution', version=1).
(pdef('Convolution', version=1, is_legacy=True).
add_enum_alias('Mode', 'ConvolutionV0').
add_fields(
'uint32',
......@@ -78,6 +78,39 @@ pdef('Axis').add_fields('int32', 'axis', 0)
name_field='compute_mode')
)
(pdef('Convolution', version=2).
add_enum_alias('Mode', 'ConvolutionV0').
add_fields(
'uint32',
Doc('pad_h', 'padding on one side on the first dimension'), 0,
Doc('pad_w', 'padding on one side on the second dimension'), 0,
Doc('stride_h', 'kernel stride on the first dimension'), 1,
Doc('stride_w', 'kernel stride on the second dimension'), 1,
Doc('dilate_h', 'dilation (i.e. size of each zero-padded kernel block) '
'on the second dimension'), 1,
Doc('dilate_w', 'dilation (i.e. size of each zero-padded kernel block) '
'on the second dimension'), 1
).
add_enum_alias('Sparse', 'ConvolutionV0').
add_enum(Doc('Format', 'convolution data/filter/output format; see '
':class:`RelayoutFormat` for more details'),
'NCHW', 'NHWC', 'NHWCD4', 'NCHW4', 'NCHW8', 'NCHW32', 'NCHW88',
'NCHW44','NCHW44_DOT',
Doc('NCHW4_NCHW32', 'NCHW4_NCHW32 means input tensors are nchw4 layout, output tensor is nchw32 layout'),
Doc('NCHW32_NCHW4', 'NCHW32_NCHW4 means input tensors are nchw32 layout, output tensor is nchw4 layout'),
Doc('NCHW4_NCHW', 'NCHW4_NCHW means input tensors are nchw4 layout, output tensor is nchw layout'),
Doc('NHWC_NCHW', 'NHWC_NCHW means input tensors are nhwc layout, '
'output tensor is nchw layout'),
Doc('NHWC_NCHW4_IC_SMALL', 'NHWC_NCHW4_IC_SMALL means input tensors are nhwc(c < 4) layout, '
'output tensor is nchw4 layout, padding c=4'),
Doc('NCHW_NCHW4_IC_SMALL', 'NCHW_NCHW4_IC_SMALL means input tensors are nchw(c < 4) layout, '
'output tensor is nchw4 layout, padding c=4'),
Doc('CHWN4', 'CHWN4 is currently only used on Nvidia platform for fast implementation '
'of convolution using CUDA/SASS. The channels are splitted to groups of 4 channels.')).
add_enum_alias('ComputeMode', 'ConvolutionV1',name_field='compute_mode')
)
(pdef('MaskPropagate').
add_fields(
'uint32',
......@@ -137,10 +170,10 @@ pdef('Axis').add_fields('int32', 'axis', 0)
'on the second dimension'), 1,
Doc('dilate_w', 'dilation (i.e. size of each zero-padded kernel block) '
'on the second dimension'), 1).
add_enum_alias('ComputeMode', 'Convolution', name_field='compute_mode')
add_enum_alias('ComputeMode', 'ConvolutionV1', name_field='compute_mode')
)
(pdef('ConvBias', 'active(conv(x, w) + bias)', version=3).
(pdef('ConvBias', 'active(conv(x, w) + bias)', version=3, is_legacy=True).
add_enum_alias('NonlineMode', 'ConvBiasV0').
add_enum_alias('Mode', 'ConvolutionV0').
add_enum_alias('Sparse', 'ConvolutionV0').
......@@ -156,9 +189,26 @@ pdef('Axis').add_fields('int32', 'axis', 0)
Doc('dilate_w', 'dilation (i.e. size of each zero-padded kernel block) '
'on the second dimension'), 1,
Doc('output_block_size', 'detail meaning \see winograd in conv bias'), 0).
add_enum_alias('ComputeMode', 'Convolution', name_field='compute_mode')
add_enum_alias('ComputeMode', 'ConvolutionV1', name_field='compute_mode')
)
(pdef('ConvBias', 'active(conv(x, w) + bias)', version=4).
add_enum_alias('NonlineMode', 'ConvBiasV0').
add_enum_alias('Mode', 'ConvolutionV0').
add_enum_alias('Sparse', 'ConvolutionV0').
add_enum_alias('Format', 'Convolution').
add_fields(
'uint32',
Doc('pad_h', 'padding on one side on the first dimension'), 0,
Doc('pad_w', 'padding on one side on the second dimension'), 0,
Doc('stride_h', 'kernel stride on the first dimension'), 1,
Doc('stride_w', 'kernel stride on the second dimension'), 1,
Doc('dilate_h', 'dilation (i.e. size of each zero-padded kernel block) '
'on the second dimension'), 1,
Doc('dilate_w', 'dilation (i.e. size of each zero-padded kernel block) '
'on the second dimension'), 1).
add_enum_alias('ComputeMode', 'ConvolutionV1', name_field='compute_mode')
)
(pdef('SeparableConv').
add_enum_alias('Mode', 'ConvolutionV0').
add_enum('BorderMode', 'BORDER_REPLICATE', 'BORDER_REFLECT',
......@@ -172,7 +222,7 @@ pdef('Axis').add_fields('int32', 'axis', 0)
add_fields('uint32', 'pad_h', 0, 'pad_w', 0, 'stride_h', 1, 'stride_w', 1,
'window_h', 3, 'window_w', 3))
(pdef('Pooling').
(pdef('Pooling', version=0, is_legacy=True).
add_enum(
'Mode',
Doc('MAX', 'maximum value inside pooling window'),
......@@ -188,11 +238,23 @@ pdef('Axis').add_fields('int32', 'axis', 0)
add_enum_alias('Format', 'ConvolutionV0')
)
(pdef('AdaptivePooling').
add_enum_alias('Mode', 'Pooling').
(pdef('Pooling', version=1).
add_enum_alias('Mode','PoolingV0').
add_fields('uint32', 'pad_h', 0, 'pad_w', 0, 'stride_h', 2, 'stride_w', 2,
'window_h', 2, 'window_w', 2).
add_enum_alias('Format', 'Convolution')
)
(pdef('AdaptivePooling', version=0,is_legacy=True).
add_enum_alias('Mode', 'PoolingV0').
add_enum_alias('Format', 'ConvolutionV0')
)
(pdef('AdaptivePooling', version=1).
add_enum_alias('Mode', 'PoolingV0').
add_enum_alias('Format', 'Convolution')
)
(pdef('LRN',
'see ImageNet Classification with Deep Convolutional Neural Networks for'
' meaning of the fields').
......@@ -239,7 +301,7 @@ BORDER_MODES = [Doc('REPLICATE', 'aaaaaa|abcdefgh|hhhhhhh'),
Doc('CONSTANT', 'iiiiii|abcdefgh|iiiiiii'),
Doc('TRANSPARENT', ''),
Doc('ISOLATED', '')]
(pdef('WarpPerspective', version=1).
(pdef('WarpPerspective', version=1, is_legacy=True).
add_enum('InterpolationMode', *INTERP_MODES,
name_field='imode', default=1,
member_alias=[(i, 'INTER_{}'.format(i)) for i in INTERP_MODES]
......@@ -251,6 +313,13 @@ BORDER_MODES = [Doc('REPLICATE', 'aaaaaa|abcdefgh|hhhhhhh'),
add_enum_alias('Format', 'ConvolutionV0').
add_fields('float32', Doc('border_val', 'used for CONSTANT bmode'), '.0f'))
(pdef('WarpPerspective', version=2).
add_enum_alias('InterpolationMode','WarpPerspectiveV1',name_field="imode").
add_enum_alias('BorderMode','WarpPerspectiveV1',name_field="bmode").
add_enum_alias('Format', 'Convolution').
add_fields('float32', Doc('border_val', 'used for CONSTANT bmode'), '.0f'))
pdef('SpatialTfGridGenerator').add_enum('Mode', 'AFFINE')
pdef('SpatialTfSampler').add_enum('Mode', 'BILINEAR')
......@@ -420,9 +489,12 @@ pdef('ElemwiseMultiType').add_enum(
pdef('PowC', 'power with constant exponent').add_fields('float32', 'exp', 0)
(pdef('DctChannelSelect', '2d discrete cosine transform').add_enum_alias('Format', 'ConvolutionV0').
(pdef('DctChannelSelect', '2d discrete cosine transform', version=0, is_legacy=True).add_enum_alias('Format', 'ConvolutionV0').
add_enum('FastImpl', 'NONE', 'FIX_32_MASK').add_fields('int32', 'dct_block_size', 8))
(pdef('DctChannelSelect', '2d discrete cosine transform', version=1).add_enum_alias('Format', 'Convolution').
add_enum_alias('FastImpl', 'DctChannelSelectV0').add_fields('int32', 'dct_block_size', 8))
(pdef('MatrixMul', version=0, is_legacy=True).
add_fields('bool', 'transposeA', 'false', 'transposeB', 'false').
add_enum('DataType',
......@@ -695,34 +767,51 @@ pdef('UniformRNG').add_fields('uint64', 'seed', 0)
name_field = 'mode'))
(pdef('WarpAffine', version=0, is_legacy=True)
.add_enum_alias('InterpolationMode', 'WarpPerspective', name_field='imode')
.add_enum_alias('BorderMode', 'WarpPerspective', name_field='border_mode')
.add_enum_alias('InterpolationMode', 'WarpPerspectiveV1', name_field='imode')
.add_enum_alias('BorderMode', 'WarpPerspectiveV1', name_field='border_mode')
.add_fields('float32', Doc('border_val', 'used for CONSTANT bmode'), '.0f'))
(pdef('WarpAffine', version=1)
.add_enum_alias('InterpolationMode', 'WarpPerspective', name_field='imode')
.add_enum_alias('BorderMode', 'WarpPerspective', name_field='border_mode')
(pdef('WarpAffine', version=1, is_legacy=True)
.add_enum_alias('InterpolationMode', 'WarpPerspectiveV1', name_field='imode')
.add_enum_alias('BorderMode', 'WarpPerspectiveV1', name_field='border_mode')
.add_fields('float32', Doc('border_val', 'used for CONSTANT bmode'), '.0f')
.add_enum_alias('Format', 'ConvolutionV0', default=1))
(pdef('WarpAffine', version=2)
.add_enum_alias('InterpolationMode', 'WarpPerspectiveV1', name_field='imode')
.add_enum_alias('BorderMode', 'WarpPerspectiveV1', name_field='border_mode')
.add_fields('float32', Doc('border_val', 'used for CONSTANT bmode'), '.0f')
.add_enum_alias('Format', 'Convolution', default=1))
(pdef('GaussianBlur')
.add_enum_alias('BorderMode', 'WarpPerspective', name_field='border_mode')
.add_enum_alias('BorderMode', 'WarpPerspectiveV1', name_field='border_mode')
.add_fields('uint32', 'kernel_height', 0, 'kernel_width', 0)
.add_fields('float32','sigma_x', '0.f', 'sigma_y', '0.f'))
(pdef('Resize', version=0, is_legacy=True)
.add_enum_alias('InterpolationMode', 'WarpPerspective', name_field='imode'))
.add_enum_alias('InterpolationMode', 'WarpPerspectiveV1', name_field='imode'))
(pdef('Resize', version=1)
.add_enum_alias('InterpolationMode', 'WarpPerspective', name_field='imode')
(pdef('Resize', version=1, is_legacy=True)
.add_enum_alias('InterpolationMode', 'WarpPerspectiveV1', name_field='imode')
.add_enum_alias('Format', 'ConvolutionV0', default=1))
(pdef('Remap', version=0)
.add_enum_alias('InterpolationMode', 'WarpPerspective', name_field='imode')
.add_enum_alias('BorderMode', 'WarpPerspective', name_field='border_type')
(pdef('Resize', version=2)
.add_enum_alias('InterpolationMode', 'WarpPerspectiveV1', name_field='imode')
.add_enum_alias('Format', 'Convolution', default=1))
(pdef('Remap', version=0,is_legacy=True)
.add_enum_alias('InterpolationMode', 'WarpPerspectiveV1', name_field='imode')
.add_enum_alias('BorderMode', 'WarpPerspectiveV1', name_field='border_type')
.add_enum_alias('Format', 'ConvolutionV0', default=1)
.add_fields('float32', 'scalar', '0.f'))
(pdef('Remap', version=1)
.add_enum_alias('InterpolationMode', 'WarpPerspectiveV1', name_field='imode')
.add_enum_alias('BorderMode', 'WarpPerspectiveV1', name_field='border_type')
.add_enum_alias('Format', 'Convolution', default=1)
.add_fields('float32', 'scalar', '0.f'))
(pdef('Convolution3D').
add_enum('Mode', 'CROSS_CORRELATION', 'CONVOLUTION').
add_fields(
......@@ -879,13 +968,19 @@ when the ``I`` suffix is present.
)
(pdef('SeparableFilter').
(pdef('SeparableFilter', version=0, is_legacy=True).
add_enum_alias('Format', 'ConvolutionV0').
add_enum_alias('BorderMode', 'WarpPerspective').
add_enum_alias('BorderMode', 'WarpPerspectiveV1').
add_fields('bool', 'is_symm_kernel', 'true').
add_fields('uint32', 'ksize_h', 3, 'ksize_w', 3, 'anchor_h', 1, 'anchor_w', 1))
(pdef('SeparableFilter', version=1).
add_enum_alias('Format', 'Convolution').
add_enum_alias('BorderMode', 'WarpPerspectiveV1').
add_fields('bool', 'is_symm_kernel', 'true').
add_fields('uint32', 'ksize_h', 3, 'ksize_w', 3, 'anchor_h', 1, 'anchor_w', 1))
(pdef('LocalShare', 'Local share convolution').
(pdef('LocalShare', 'Local share convolution',version=0, is_legacy=True).
add_enum_alias('Mode', 'ConvolutionV0').
add_fields(
'uint32',
......@@ -902,10 +997,31 @@ when the ``I`` suffix is present.
).
add_enum_alias('Sparse', 'ConvolutionV0').
add_enum_alias('Format', 'ConvolutionV0').
add_enum_alias('ComputeMode', 'Convolution')
add_enum_alias('ComputeMode', 'ConvolutionV1')
)
(pdef('ROIAlign').
(pdef('LocalShare', 'Local share convolution', version=1).
add_enum_alias('Mode', 'ConvolutionV0').
add_fields(
'uint32',
Doc('pad_h', 'padding on one side on the first dimension'), 0,
Doc('pad_w', 'padding on one side on the second dimension'), 0,
Doc('stride_h', 'kernel stride on the first dimension'), 1,
Doc('stride_w', 'kernel stride on the second dimension'), 1,
Doc('dilate_h', 'dilation (i.e. size of each zero-padded kernel block) '
'on the second dimension'), 1,
Doc('dilate_w', 'dilation (i.e. size of each zero-padded kernel block) '
'on the second dimension'), 1,
Doc('spatial_groups_h', 'spatial groups on the first dimension'), 1,
Doc('spatial_groups_w', 'spatial groups on the second dimension'), 1
).
add_enum_alias('Sparse', 'ConvolutionV0').
add_enum_alias('Format', 'Convolution').
add_enum_alias('ComputeMode', 'ConvolutionV1')
)
(pdef('ROIAlign',version=0,is_legacy=True).
add_enum('Mode', 'MAX', 'AVERAGE', name_field='mode').
add_enum_alias('Format', 'ConvolutionV0').
add_fields('float32', 'spatial_scale', '1.0').
......@@ -916,6 +1032,19 @@ when the ``I`` suffix is present.
'sample_height', '2',
'sample_width', '2')
)
(pdef('ROIAlign', version=1).
add_enum_alias('Mode', 'ROIAlignV0', name_field='mode').
add_enum_alias('Format', 'Convolution').
add_fields('float32', 'spatial_scale', '1.0').
add_fields('float32', 'offset', '0.0').
add_fields('uint32',
'pooled_height', '1',
'pooled_width', '1',
'sample_height', '2',
'sample_width', '2')
)
(pdef('DeformablePSROIPooling').
add_fields('bool', 'no_trans', 'true').
add_fields('float32', 'spatial_scale', 1,
......@@ -926,7 +1055,7 @@ when the ``I`` suffix is present.
Doc('part_size', 'size of each deformable part'), 1,
Doc('sample_per_part', 'sample count of each bbox'), 1))
(pdef('BatchConvBias', 'Batch convolution (unshare weights on the batch dimension)').
(pdef('BatchConvBias', 'Batch convolution (unshare weights on the batch dimension)',version=0,is_legacy=True).
add_enum_alias('NonlineMode', 'ConvBiasV0').
add_enum_alias('Mode', 'ConvolutionV0').
add_fields(
......@@ -942,8 +1071,28 @@ when the ``I`` suffix is present.
).
add_enum_alias('Sparse', 'ConvolutionV0').
add_enum_alias('Format', 'ConvolutionV0').
add_enum_alias('ComputeMode', 'Convolution', name_field="compute_mode")
add_enum_alias('ComputeMode', 'ConvolutionV1', name_field="compute_mode")
)
(pdef('BatchConvBias', 'Batch convolution (unshare weights on the batch dimension)',version=1).
add_enum_alias('NonlineMode', 'ConvBiasV0').
add_enum_alias('Mode', 'ConvolutionV0').
add_fields(
'uint32',
Doc('pad_h', 'padding on one side on the first dimension'), 0,
Doc('pad_w', 'padding on one side on the second dimension'), 0,
Doc('stride_h', 'kernel stride on the first dimension'), 1,
Doc('stride_w', 'kernel stride on the second dimension'), 1,
Doc('dilate_h', 'dilation (i.e. size of each zero-padded kernel block) '
'on the second dimension'), 1,
Doc('dilate_w', 'dilation (i.e. size of each zero-padded kernel block) '
'on the second dimension'), 1,
).
add_enum_alias('Sparse', 'ConvolutionV0').
add_enum_alias('Format', 'Convolution').
add_enum_alias('ComputeMode', 'ConvolutionV1', name_field="compute_mode")
)
(pdef('FakeQuant').
add_fields('int32','qmin','-2147483648').
add_fields('int32','qmax','2147483647')
......
......@@ -68,7 +68,6 @@ ConvolutionForwardImpl::conv_bias_extra_data(const TensorLayout& src,
conv_param.stride_w,
conv_param.dilate_h,
conv_param.dilate_w,
0,
conv_param.compute_mode};
ret.convbias_opr->execution_policy() = {this->execution_policy().algo};
return ret;
......
......@@ -173,7 +173,6 @@ SmallVector<AlgoCategory> ConvBiasImpl::suggest_algo_category_order(
auto FH = param.filter_meta.spatial[0];
auto FW = param.filter_meta.spatial[1];
//! TODO: now winograd only support fast-run
//! nchw88 use mkl-dnn which algo is direct
if (param.filter_meta.format == param::ConvBias::Format::NCHW88) {
return {AlgoCategory::DIRECT, AlgoCategory::IM2COL};
......
......@@ -35,7 +35,7 @@ TEST(TestImperative, APlusB) {
}
TEST(TestImperative, Convolution) {
auto op = OprAttr::make("ConvolutionV1");
auto op = OprAttr::make("ConvolutionV2");
auto&& attr = op->cast_final_safe<OprAttr>();
using Param = opr::Convolution::Param;
using Policy = opr::Convolution::ExecutionPolicy;
......
......@@ -1752,7 +1752,6 @@ void FuseConvBiasNonlinPass::apply(OptState& state) const {
param.stride_w,
param.dilate_h,
param.dilate_w,
0,
param.compute_mode};
};
......
......@@ -1945,7 +1945,7 @@ void EnableNchwxxPass::fill_opr_convert_fun(size_t pack_c_size) {
megdnn::param::ConvBias::Format conv_bias_format =
megdnn::param::ConvBias::Format::NCHW88;
megdnn::param::Convolution::Format conv_format =
megdnn::param::ConvolutionV0::Format::NCHW88;
megdnn::param::Convolution::Format::NCHW88;
megdnn::param::Pooling::Format pooling_format =
megdnn::param::Pooling::Format::NCHW88;
std::string convter_pass_name = "conv_format_nchw88";
......@@ -1958,7 +1958,7 @@ void EnableNchwxxPass::fill_opr_convert_fun(size_t pack_c_size) {
src_to_nchwxx_mode = RelayoutMode::NCHW_TO_NCHW4;
src_to_nchw_mode = RelayoutMode::NCHW4_TO_NCHW;
conv_bias_format = megdnn::param::ConvBias::Format::NCHW44;
conv_format = megdnn::param::ConvolutionV0::Format::NCHW44;
conv_format = megdnn::param::Convolution::Format::NCHW44;
pooling_format = megdnn::param::Pooling::Format::NCHW44;
convter_pass_name = "conv_format_nchw44";
}
......@@ -2360,7 +2360,7 @@ EnableNchw44DotPass::make_nchw44_dot_converter() {
struct TestTransResult {
TransType trans_type;
RelayoutMode relayout_mod;
megdnn::param::ConvolutionV0::Format conv_format;
megdnn::param::Convolution::Format conv_format;
};
constexpr size_t pack_c_size = 4_z;
auto test_trans_nchw44_dot =
......
......@@ -18,7 +18,7 @@ decl_opr('Convolution',
params=[('param', 'Convolution'),
('execution_polity', 'ExecutionPolicy')],
desc='batched convolution on channeled 2D images',
version=1, has_out_dtype=True)
version=2, has_out_dtype=True)
decl_opr('ConvolutionBackwardData',
pyname='deconvolution_v0',
......@@ -51,7 +51,7 @@ decl_opr('ConvolutionBackwardData',
],
desc='batched deconvolution on channeled 2D images; the underlying '
'computation is in fact gradient of convolution w.r.t. data',
version=1)
version=2)
decl_opr('MaskConvolution',
inputs=[Doc('src',
......@@ -138,14 +138,14 @@ decl_opr('LRN',
decl_opr('Pooling',
inputs=['src'],
params='Pooling')
params='Pooling',version=1)
decl_opr('AdaptivePooling',
inputs=[Doc('src', 'input image, shape (n, c, ih, iw)'),
Doc('out_shape', 'output image shape, containing two elements specifying output height and width.')],
params='AdaptivePooling',
desc='Adaptive Pooling.'
'The output shape is (n, c, oh, ow), where (oh, ow) is given by *out_shape*.')
'The output shape is (n, c, oh, ow), where (oh, ow) is given by *out_shape*.',version=1)
decl_opr('ROIPooling', outputs=[0],
inputs=[Doc('src', 'input image, shape (n, c, ih, iw)'),
......@@ -215,7 +215,7 @@ decl_opr('ConvBiasForward',
('execution_policy', 'ExecutionPolicy')],
desc=('activation(convolution(src, filter) + bias) with specified '
'dtype'),
version=3, has_out_dtype=True)
version=4, has_out_dtype=True)
decl_opr('BatchNorm',
pyname='batch_norm',
......@@ -255,7 +255,7 @@ r"""
& iw=-pad_w+ow \\times stride_w \\\\
& grp_h = oh / (OH / spatial_groups_h) \\\\
& grp_w = ow / (OW / spatial_groups_w)
"""),
"""), version=1,
has_out_dtype=True)
decl_opr('ROIAlign', outputs=[0],
......@@ -270,7 +270,7 @@ decl_opr('ROIAlign', outputs=[0],
desc='ROI Align, see '
'Mask-RCNN: https://arxiv.org/pdf/1703.06870.pdf, '
'The output shape is (m, c, pooled_height, pooled_width), where (pooled_height, pooled_width) is given by '
'*Param*.')
'*Param*.',version=1)
decl_opr('DeformableConvForward',
pyname='deformable_conv',
......@@ -312,7 +312,7 @@ r"""
* filter_{n, oc, ic, kh, kw} \\\\
\\text{where} & ih=-pad_h+oh \\times stride_h \\\\
& iw=-pad_w+ow \\times stride_w
"""),
"""), version=1,
has_out_dtype=True)
decl_opr('FakeQuant',
......
......@@ -22,10 +22,84 @@
#include "megbrain/opr/dnn/tqt.h"
#include "megbrain/serialization/sereg.h"
#include "megdnn/opr_param_defs.h"
#include "megdnn/oprs/nn.h"
namespace mgb {
namespace serialization {
template <class MegDNNPooling = megdnn::Pooling>
struct MakePoolingCaller1 {
template <typename Opr>
static VarNode* make(const cg::VarNodeArray& inputs,
const typename MegDNNPooling::Param& param,
const OperatorNodeConfig& config) {
if (inputs.size() == 1) {
return Opr::make(inputs[0], param, config).node();
}
return nullptr;
}
};
template <class MegDNNROIALIGN = megdnn::ROIAlign>
struct MakeROIAlignCaller1 {
template <typename Opr>
static VarNode* make(const cg::VarNodeArray& inputs,
const typename MegDNNROIALIGN::Param& param,
const OperatorNodeConfig& config) {
if (inputs.size() == 2) {
return Opr::make(inputs[0],inputs[1], param, config).node();
} else {
return nullptr;
}
}
};
template <class MegDNNROIALIGN = megdnn::ROIAlignBackward>
struct MakeROIAlignCaller4 {
template <typename Opr>
static VarNode* make(const cg::VarNodeArray& inputs,
const typename MegDNNROIALIGN::Param& param,
const OperatorNodeConfig& config) {
if (inputs.size() == 4) {
return Opr::make(inputs[0], inputs[1], inputs[2], inputs[3],
param, config)
.node();
} else {
return nullptr;
}
}
};
template <class MegDNNPooling = megdnn::PoolingBackward>
struct MakePoolingBackwardCaller3 {
template <typename Opr>
static VarNode* make(const cg::VarNodeArray& inputs,
const typename MegDNNPooling::Param& param,
const OperatorNodeConfig& config) {
if (inputs.size() == 3) {
return Opr::make(inputs[0], inputs[1], inputs[2], param, config)
.node();
}
return nullptr;
}
};
template <class MegDNNPooling = megdnn::AdaptivePoolingBackward>
struct MakeAdaptivePoolingBackwardCaller3 {
template <typename Opr>
static VarNode* make(const cg::VarNodeArray& inputs,
const typename MegDNNPooling::Param& param,
const OperatorNodeConfig& config) {
if (inputs.size() == 4) {
return Opr::make(inputs[0], inputs[1], inputs[2], inputs[3],
param, config)
.node();
}
return nullptr;
}
};
template<class MegDNNConv = megdnn::Convolution>
struct MakeConvCaller2 {
template<typename Opr>
......@@ -41,6 +115,7 @@ namespace serialization {
return nullptr;
}
};
template<class MegDNNConv = megdnn::Convolution>
struct MakeConvCaller3 {
template<typename Opr>
......@@ -56,6 +131,7 @@ namespace serialization {
return nullptr;
}
};
template<class MegDNNConv = megdnn::Convolution>
struct MakeConvCaller4 {
template<typename Opr>
......@@ -71,6 +147,7 @@ namespace serialization {
return nullptr;
}
};
template<class MegDNNConv = megdnn::Convolution>
struct MakeConvCaller5 {
template <typename Opr>
......@@ -141,6 +218,75 @@ namespace serialization {
}
};
template <class Opr, class Maker0,
typename PoolingParam = megdnn::param::Pooling>
struct PoolingLoadDumpImpl {
static void dump(OprDumpContext& ctx,
const cg::OperatorNodeBase& opr_) {
auto&& opr = opr_.cast_final_safe<Opr>();
ctx.write_param<PoolingParam>(opr.param());
}
static VarNode* make(
const cg::VarNodeArray& inputs, const PoolingParam& param,
const OperatorNodeConfig& config) {
VarNode* ret = Maker0::template make<Opr>(inputs, param,
config);
mgb_assert(ret);
return ret;
}
static cg::OperatorNodeBase* load(OprLoadContext& ctx,
const cg::VarNodeArray& inputs,
const OperatorNodeConfig& config) {
auto param = ctx.read_param<PoolingParam>();
return make(inputs, param, config)->owner_opr();
}
};
template<>
struct OprLoadDumpImpl<opr::AdaptivePoolingBackward, 0>:
public PoolingLoadDumpImpl<opr::AdaptivePoolingBackward,
MakeAdaptivePoolingBackwardCaller3<megdnn::AdaptivePoolingBackward>,
megdnn::param::AdaptivePooling>
{};
template<>
struct OprLoadDumpImpl<opr::AdaptivePooling, 0>:
public PoolingLoadDumpImpl<opr::AdaptivePooling,
MakeROIAlignCaller1<megdnn::AdaptivePooling>,
megdnn::param::AdaptivePooling>
{};
template<>
struct OprLoadDumpImpl<opr::ROIAlign, 0>:
public PoolingLoadDumpImpl<opr::ROIAlign,
MakeROIAlignCaller1<megdnn::ROIAlign>,
megdnn::param::ROIAlign>
{};
template<>
struct OprLoadDumpImpl<opr::ROIAlignBackward, 0>:
public PoolingLoadDumpImpl<opr::ROIAlignBackward,
MakeROIAlignCaller4<megdnn::ROIAlignBackward>,
megdnn::param::ROIAlign>
{};
template<>
struct OprLoadDumpImpl<opr::Pooling, 0>:
public PoolingLoadDumpImpl<opr::Pooling,
MakePoolingCaller1<megdnn::Pooling>,
megdnn::param::Pooling>
{};
template<>
struct OprLoadDumpImpl<opr::PoolingBackward, 0>:
public PoolingLoadDumpImpl<opr::PoolingBackward,
MakePoolingBackwardCaller3<megdnn::PoolingBackward>,
megdnn::param::Pooling>
{};
template<>
struct OprLoadDumpImpl<opr::Convolution, 0>:
public ConvLoadDumpImpl<opr::Convolution,
......@@ -374,12 +520,12 @@ namespace serialization {
namespace opr {
using ConvolutionV1 = Convolution;
using ConvolutionBackwardDataV1 = ConvolutionBackwardData;
using ConvolutionBackwardFilterV1 = ConvolutionBackwardFilter;
MGB_SEREG_OPR(ConvolutionV1, 0);
MGB_SEREG_OPR(ConvolutionBackwardDataV1, 0);
MGB_SEREG_OPR(ConvolutionBackwardFilterV1, 0);
using ConvolutionV2 = Convolution;
using ConvolutionBackwardDataV2 = ConvolutionBackwardData;
using ConvolutionBackwardFilterV2 = ConvolutionBackwardFilter;
MGB_SEREG_OPR(ConvolutionV2, 0);
MGB_SEREG_OPR(ConvolutionBackwardDataV2, 0);
MGB_SEREG_OPR(ConvolutionBackwardFilterV2, 0);
MGB_SEREG_OPR(Images2Neibs, 1);
MGB_SEREG_OPR(Images2NeibsBackward, 2);
......@@ -400,12 +546,14 @@ namespace opr {
MGB_SEREG_OPR(LRN, 1);
MGB_SEREG_OPR(LRNBackward, 3);
MGB_SEREG_OPR(Pooling, 1);
MGB_SEREG_OPR(PoolingBackward, 3);
MGB_SEREG_OPR(AdaptivePooling, 2);
MGB_SEREG_OPR(AdaptivePoolingBackward, 4);
using PoolingV1 = Pooling;
using PoolingBackwardV1 = PoolingBackward;
MGB_SEREG_OPR(PoolingV1, 1);
MGB_SEREG_OPR(PoolingBackwardV1, 3);
using AdaptivePoolingV1 = AdaptivePooling;
using AdaptivePoolingBackwardV1 = AdaptivePoolingBackward;
MGB_SEREG_OPR(AdaptivePoolingV1, 2);
MGB_SEREG_OPR(AdaptivePoolingBackwardV1, 4);
MGB_SEREG_OPR(ROIPooling, 3);
MGB_SEREG_OPR(ROIPoolingBackward, 4);
......@@ -418,18 +566,23 @@ namespace opr {
MGB_SEREG_OPR(Convolution3DBackwardData, 0);
MGB_SEREG_OPR(Convolution3DBackwardFilter, 0);
using ConvBiasForwardV3 = ConvBiasForward;
MGB_SEREG_OPR(ConvBiasForwardV3, 0);
using ConvBiasForwardV4 = ConvBiasForward;
MGB_SEREG_OPR(ConvBiasForwardV4, 0);
MGB_SEREG_OPR(BatchNorm, 0);
MGB_SEREG_OPR(BatchNormBackward, 5);
MGB_SEREG_OPR(LocalShareForward, 0);
MGB_SEREG_OPR(LocalShareBackwardData, 0);
MGB_SEREG_OPR(LocalShareBackwardFilter, 0);
MGB_SEREG_OPR(ROIAlign, 2);
MGB_SEREG_OPR(ROIAlignBackward, 4);
using LocalShareForwardV1 = LocalShareForward;
using LocalShareBackwardDataV1 = LocalShareBackwardData;
using LocalShareBackwardFilterV1 = LocalShareBackwardFilter;
MGB_SEREG_OPR(LocalShareForwardV1, 0);
MGB_SEREG_OPR(LocalShareBackwardDataV1, 0);
MGB_SEREG_OPR(LocalShareBackwardFilterV1, 0);
using ROIAlignV1=ROIAlign;
using ROIAlignBackwardV1=ROIAlignBackward;
MGB_SEREG_OPR(ROIAlignV1, 2);
MGB_SEREG_OPR(ROIAlignBackwardV1, 4);
MGB_SEREG_OPR(DeformableConvForward, 0);
MGB_SEREG_OPR(DeformableConvBackwardData, 0);
MGB_SEREG_OPR(DeformableConvBackwardFilter, 0);
......@@ -437,7 +590,9 @@ namespace opr {
MGB_SEREG_OPR(DeformablePSROIPoolingForward, 3);
MGB_SEREG_OPR(DeformablePSROIPoolingBackward, 5);
MGB_SEREG_OPR(BatchConvBiasForward, 0);
using BatchConvBiasForwardV1 = BatchConvBiasForward;
MGB_SEREG_OPR(BatchConvBiasForwardV1, 0);
MGB_SEREG_OPR(FakeQuant, 3);
MGB_SEREG_OPR(FakeQuantBackward, 4);
MGB_SEREG_OPR(TQT, 2);
......
......@@ -12,7 +12,8 @@ decl_opr(
params='WarpPerspective',
desc='Apply perspective transformation to batched 2D images; '
'see http://docs.opencv.org/2.4/modules/imgproc/doc/geometric_transformations.html '
'for details on perspective transformations.')
'for details on perspective transformations.',
version=2)
decl_opr(
'WarpPerspective',
......@@ -62,7 +63,7 @@ decl_opr('Resize',
desc='Resize an image. '
'see http://docs.opencv.org/2.4/modules/imgproc/doc/geometric_transformations.html?highlight=resize#cv2.resize'
' for details.',
version=1)
version=2)
decl_opr(
'WarpAffine',
......@@ -77,7 +78,7 @@ decl_opr(
desc='Apply affine transformation to batched 2D images; '
'see http://docs.opencv.org/2.4/modules/imgproc/doc/geometric_transformations.html '
'for details on affine transformations.',
version=1)
version=2)
decl_opr(
'Remap',
......@@ -89,7 +90,8 @@ decl_opr(
params='Remap',
desc='Remap transformation to batched 2D images; '
'see https://docs.opencv.org/2.4/modules/imgproc/doc/geometric_transformations.html?highlight=remap'
'for details on remap transformations.')
'for details on remap transformations.',
version=1)
decl_raw_opr(
'dct_channel_select',
......
......@@ -9,8 +9,10 @@
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#include <type_traits>
#include "megbrain/opr/imgproc.h"
#include "megbrain/serialization/sereg.h"
#include "megdnn/opr_param_defs.h"
namespace mgb {
namespace serialization {
......@@ -38,6 +40,63 @@ namespace serialization {
}
};
template <>
struct OprMaker<opr::Remap, 0> {
using Opr = opr::Remap;
using Param = Opr::Param;
static cg::OperatorNodeBase* make(const Param& param,
const cg::VarNodeArray& inputs,
ComputingGraph& graph,
const OperatorNodeConfig& config) {
MGB_MARK_USED_VAR(graph);
if (inputs.size() == 2) {
return Opr::make(inputs[0], inputs[1], param, config)
.node()
->owner_opr();
} else {
return nullptr;
}
}
};
template<>
struct OprMaker<opr::RemapBackwardMat, 0> {
using Opr = opr::RemapBackwardMat;
using Param = Opr::Param;
static cg::OperatorNodeBase* make(const Param& param,
const cg::VarNodeArray& inputs,
ComputingGraph& graph,
const OperatorNodeConfig& config) {
MGB_MARK_USED_VAR(graph);
if (inputs.size() == 3) {
return Opr::make(inputs[0], inputs[1], inputs[2], param, config)
.node()
->owner_opr();
} else {
return nullptr;
}
}
};
template<>
struct OprMaker<opr::RemapBackwardData, 0> {
using Opr = opr::RemapBackwardData;
using Param = Opr::Param;
static cg::OperatorNodeBase* make(const Param& param,
const cg::VarNodeArray& inputs,
ComputingGraph& graph,
const OperatorNodeConfig& config) {
MGB_MARK_USED_VAR(graph);
if (inputs.size() == 3) {
return Opr::make(inputs[0], inputs[1], inputs[2], param, config)
.node()
->owner_opr();
} else {
return nullptr;
}
}
};
template <>
struct OprMaker<opr::DctChannelSelectForward, 0> {
using Opr = opr::DctChannelSelectForward;
......@@ -106,29 +165,35 @@ namespace serialization {
} // namespace serialization
namespace opr {
MGB_SEREG_OPR(WarpPerspective, 0);
MGB_SEREG_OPR(WarpPerspectiveBackwardData, 0);
MGB_SEREG_OPR(WarpPerspectiveBackwardMat, 0);
using WarpPerspectiveV2=WarpPerspective;
using WarpPerspectiveBackwardDataV2=WarpPerspectiveBackwardData;
using WarpPerspectiveBackwardMatV2=WarpPerspectiveBackwardMat;
MGB_SEREG_OPR(WarpPerspectiveV2, 0);
MGB_SEREG_OPR(WarpPerspectiveBackwardDataV2, 0);
MGB_SEREG_OPR(WarpPerspectiveBackwardMatV2, 0);
MGB_SEREG_OPR(Rotate, 1);
MGB_SEREG_OPR(CvtColor, 1);
MGB_SEREG_OPR(GaussianBlur, 1);
MGB_SEREG_OPR(ResizeBackward, 2);
MGB_SEREG_OPR(Remap, 2);
MGB_SEREG_OPR(RemapBackwardData, 3);
MGB_SEREG_OPR(RemapBackwardMat, 3);
using RemapV1=Remap;
using RemapBackwardDataV1=RemapBackwardData;
using RemapBackwardMatV1=RemapBackwardMat;
MGB_SEREG_OPR(RemapV1, 2);
MGB_SEREG_OPR(RemapBackwardDataV1, 3);
MGB_SEREG_OPR(RemapBackwardMatV1, 3);
//! current warp affine version
using WarpAffineV1 = opr::WarpAffine;
MGB_SEREG_OPR(WarpAffineV1, 3);
using WarpAffineV2 = opr::WarpAffine;
MGB_SEREG_OPR(WarpAffineV2, 3);
//! current resize version
using ResizeV1 = opr::Resize;
MGB_SEREG_OPR(ResizeV1, 2);
using ResizeV2 = opr::Resize;
MGB_SEREG_OPR(ResizeV2, 2);
MGB_SEREG_OPR(DctChannelSelect, 0);
using DctChannelSelectV1 = opr::DctChannelSelect;
MGB_SEREG_OPR(DctChannelSelectV1, 0);
} // namespace opr
......
......@@ -71,7 +71,6 @@ namespace {
void OprRegistry::add(const OprRegistry& record) {
auto&& sd = static_data();
auto persist_id = record.persist_type_id;
auto registry_ins = sd.id2reg.emplace(persist_id, record);
mgb_assert(registry_ins.second ||
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册