Commit 20b42a8c authored by: Megvii Engine Team

fix(dnn): add naive lstm kernel

GitOrigin-RevId: f08ef810cf936768a022c10f226d80e499355659
Parent 2faa6ea5
......@@ -2059,7 +2059,7 @@ public:
_megdnn_tensor_in bias_ih, _megdnn_tensor_in hx,
_megdnn_tensor_in weight_hh, _megdnn_tensor_in bias_hh,
_megdnn_tensor_out dst, _megdnn_workspace workspace) = 0;
static void deduce_layout(
void deduce_layout(
const TensorLayout& input, const TensorLayout& weight_ih,
const TensorLayout& bias_ih, const TensorLayout& hx,
const TensorLayout& weight_hh, const TensorLayout& bias_hh,
......
......@@ -36,18 +36,13 @@ pdef('Axis').add_fields('int32', 'axis', 0)
add_enum(Doc('Format', 'convolution data/filter/output format; see '
':class:`RelayoutFormat` for more details'),
'NCHW = 0', 'NHWC = 1', 'NHWCD4 = 2', 'NCHW4 = 3', 'NCHW8 = 4', 'NCHW32 = 5', 'NCHW88 = 6',
'NCHW44 = 7', 'NCHW44_DOT = 8',
'NCHW44 = 7','NCHW44_DOT = 8',
Doc('NCHW_WINOGRAD = 9', 'NCHW layout with weights transformed by Winograd'),
Doc('NCHW88_WINOGRAD = 10',
'NCHW88 layout with weights transformed by Winograd'),
Doc('NCHW44_WINOGRAD = 11',
'NCHW44 layout with weights transformed by Winograd'),
Doc('NCHW4_NCHW32 = 12',
'NCHW4_NCHW32 means input tensors are nchw4 layout, output tensor is nchw32 layout'),
Doc('NCHW32_NCHW4 = 13',
'NCHW32_NCHW4 means input tensors are nchw32 layout, output tensor is nchw4 layout'),
Doc('NCHW4_NCHW = 14',
'NCHW4_NCHW means input tensors are nchw4 layout, output tensor is nchw layout'),
Doc('NCHW88_WINOGRAD = 10', 'NCHW88 layout with weights transformed by Winograd'),
Doc('NCHW44_WINOGRAD = 11', 'NCHW44 layout with weights transformed by Winograd'),
Doc('NCHW4_NCHW32 = 12', 'NCHW4_NCHW32 means input tensors are nchw4 layout, output tensor is nchw32 layout'),
Doc('NCHW32_NCHW4 = 13', 'NCHW32_NCHW4 means input tensors are nchw32 layout, output tensor is nchw4 layout'),
Doc('NCHW4_NCHW = 14', 'NCHW4_NCHW means input tensors are nchw4 layout, output tensor is nchw layout'),
Doc('NHWC_NCHW = 15', 'NHWC_NCHW means input tensors are nhwc layout, '
'output tensor is nchw layout'),
Doc('NHWC_NCHW4_IC_SMALL = 16', 'NHWC_NCHW4_IC_SMALL means input tensors are nhwc(c < 4) layout, '
......@@ -101,13 +96,10 @@ pdef('Axis').add_fields('int32', 'axis', 0)
add_enum(Doc('Format', 'convolution data/filter/output format; see '
':class:`RelayoutFormat` for more details'),
'NCHW = 0', 'NHWC = 1', 'NHWCD4 = 2', 'NCHW4 = 3', 'NCHW8 = 4', 'NCHW32 = 5', 'NCHW88 = 6',
'NCHW44 = 7', 'NCHW44_DOT = 8',
Doc('NCHW4_NCHW32 = 9',
'NCHW4_NCHW32 means input tensors are nchw4 layout, output tensor is nchw32 layout'),
Doc('NCHW32_NCHW4 = 10',
'NCHW32_NCHW4 means input tensors are nchw32 layout, output tensor is nchw4 layout'),
Doc('NCHW4_NCHW = 11',
'NCHW4_NCHW means input tensors are nchw4 layout, output tensor is nchw layout'),
'NCHW44 = 7','NCHW44_DOT = 8',
Doc('NCHW4_NCHW32 = 9', 'NCHW4_NCHW32 means input tensors are nchw4 layout, output tensor is nchw32 layout'),
Doc('NCHW32_NCHW4 = 10', 'NCHW32_NCHW4 means input tensors are nchw32 layout, output tensor is nchw4 layout'),
Doc('NCHW4_NCHW = 11', 'NCHW4_NCHW means input tensors are nchw4 layout, output tensor is nchw layout'),
Doc('NHWC_NCHW = 12', 'NHWC_NCHW means input tensors are nhwc layout, '
'output tensor is nchw layout'),
Doc('NHWC_NCHW4_IC_SMALL = 13', 'NHWC_NCHW4_IC_SMALL means input tensors are nhwc(c < 4) layout, '
......@@ -115,11 +107,11 @@ pdef('Axis').add_fields('int32', 'axis', 0)
Doc('NCHW_NCHW4_IC_SMALL = 14', 'NCHW_NCHW4_IC_SMALL means input tensors are nchw(c < 4) layout, '
'output tensor is nchw4 layout, padding c=4'),
Doc('CHWN4 = 15', 'CHWN4 is currently only used on Nvidia platform for fast implementation '
'of convolution using CUDA/SASS. The channels are split into groups of 4 channels.'),
'of convolution using CUDA/SASS. The channels are split into groups of 4 channels.'),
Doc('NCHW64 = 16', 'NCHW64 is designed for convolution implementation to utilize TensorCore '
'instructions for 4-bit integers on Nvidia platforms'),
'instructions for 4-bit integers on Nvidia platforms'),
Doc('NCHW4_NHWC = 17', 'NCHW4_NHWC means input tensors are nchw4 layout, output tensor is nhwc layout')).
add_enum_alias('ComputeMode', 'ConvolutionV1', name_field='compute_mode')
add_enum_alias('ComputeMode', 'ConvolutionV1',name_field='compute_mode')
)
......@@ -141,7 +133,7 @@ pdef('Axis').add_fields('int32', 'axis', 0)
add_enum_alias('ConvMode', 'ConvolutionV0', 'Mode').
add_enum('PoolMode', 'AVERAGE = 0', 'MAX = 1').
add_enum('NonlineMode', 'IDENTITY = 0', 'RELU = 1', 'SIGMOID = 2').
add_fields('uint32', 'pool_shape_h', 1, 'pool_shape_w', 1, 'pool_stride_h', 1, 'pool_stride_w', 1,
add_fields('uint32', 'pool_shape_h', 1, 'pool_shape_w', 1, 'pool_stride_h', 1, 'pool_stride_w', 1, \
'pool_pad_h', 0, 'pool_pad_w', 0, 'conv_stride_h', 1, 'conv_stride_w', 1, 'conv_pad_h', 0, 'conv_pad_w', 0))
(pdef('ConvBias', 'legacy conv_bias', version=0, is_legacy=True).
......@@ -224,8 +216,8 @@ pdef('Axis').add_fields('int32', 'axis', 0)
(pdef('SeparableConv').
add_enum_alias('Mode', 'ConvolutionV0').
add_enum('BorderMode', 'BORDER_REPLICATE = 0', 'BORDER_REFLECT = 1',
'BORDER_REFLECT_101 = 2', 'BORDER_WRAP = 3',
'BORDER_CONSTANT = 4', 'BORDER_TRANSPARENT = 5', 'BORDER_ISOLATED = 6').
'BORDER_REFLECT_101 = 2','BORDER_WRAP = 3',
'BORDER_CONSTANT = 4', 'BORDER_TRANSPARENT = 5','BORDER_ISOLATED = 6').
add_fields('bool', 'is_symm_kernel', 'true').
add_fields('uint32', 'pad_h', 0, 'pad_w', 0, 'stride_h', 1, 'stride_w', 1,
'ksize_h', 3, 'ksize_w', 3, 'anchor_h', 1, 'anchor_w', 1))
......@@ -255,7 +247,7 @@ pdef('Axis').add_fields('int32', 'axis', 0)
)
(pdef('Pooling', version=1).
add_enum_alias('Mode', 'PoolingV0').
add_enum_alias('Mode','PoolingV0').
add_fields('uint32', 'pad_h', 0, 'pad_w', 0, 'stride_h', 2, 'stride_w', 2,
'window_h', 2, 'window_w', 2).
add_enum_alias('Format', 'Convolution')
......@@ -310,8 +302,7 @@ pdef('Axis').add_fields('int32', 'axis', 0)
).
add_fields('float32', 'scale', '1.f'))
INTERP_MODES = ['NEAREST = 0', 'LINEAR = 1',
'AREA = 2', 'CUBIC = 3', 'LANCZOS4 = 4']
INTERP_MODES = ['NEAREST = 0', 'LINEAR = 1', 'AREA = 2', 'CUBIC = 3', 'LANCZOS4 = 4']
BORDER_MODES = [Doc('REPLICATE = 0', 'aaaaaa|abcdefgh|hhhhhhh'),
Doc('REFLECT = 1', 'fedcba|abcdefgh|hgfedcb'),
Doc('REFLECT_101 = 2', 'gfedcb|abcdefgh|gfedcba'),
......@@ -332,8 +323,8 @@ BORDER_MODES = [Doc('REPLICATE = 0', 'aaaaaa|abcdefgh|hhhhhhh'),
add_fields('float32', Doc('border_val', 'used for CONSTANT bmode'), '.0f'))
(pdef('WarpPerspective', version=2).
add_enum_alias('InterpolationMode', 'WarpPerspectiveV1', name_field="imode").
add_enum_alias('BorderMode', 'WarpPerspectiveV1', name_field="bmode").
add_enum_alias('InterpolationMode','WarpPerspectiveV1',name_field="imode").
add_enum_alias('BorderMode','WarpPerspectiveV1',name_field="bmode").
add_enum_alias('Format', 'Convolution').
add_fields('float32', Doc('border_val', 'used for CONSTANT bmode'), '.0f'))
......@@ -408,7 +399,7 @@ pdef('Elemwise').add_enum(
Doc('RMULH = 43', 'binary: rounded higher l bits of x * y, where l is the bit '
'length of x.'),
Doc('ATAN2 = 44', 'binary: atan2(y,x)'),
Doc('ATAN2 = 44','binary: atan2(y,x)'),
Doc('ERF = 45', 'unary: erf(x)'),
Doc('ERFINV = 46', 'unary: inverse function of erf(x)'),
Doc('ERFC = 47', 'unary: erfc(x)'),
......@@ -643,7 +634,7 @@ Currently, ```DEFAULT``` mode means:
Doc('axis',
'axis along which reduction is performed; if INT_MAX is given, '
'reduce to given target shape (only used in megbrain)'),
(1 << 31)-1).
(1<<31)-1).
add_enum('DataType',
Doc('DEFAULT = 0',
'''
......@@ -698,7 +689,7 @@ Currently, ```DEFAULT``` mode means:
add_fields('int32',
Doc('axis',
'axis along which cumsum is performed, default with INT_MAX'),
(1 << 31)-1).
(1<<31)-1).
add_fields('bool',
Doc('exclusive',
'whether the current element is taken into account'),
......@@ -770,8 +761,7 @@ pdef('Sleep').add_fields('float32', Doc('time', 'time to sleep in seconds'), 0)
(pdef('UniformRNG', version=1).
add_fields('uint64', 'seed', 0).
add_fields(
'dtype', Doc(
'dtype', 'The dtype of output Tensor. Only support Float32.'),
'dtype', Doc('dtype', 'The dtype of output Tensor. Only support Float32.'),
'DTypeEnum::Float32'))
(pdef('GaussianRNG', version=0, is_legacy=True).
......@@ -782,8 +772,7 @@ pdef('Sleep').add_fields('float32', Doc('time', 'time to sleep in seconds'), 0)
add_fields('uint64', 'seed', 0).
add_fields('float32', 'mean', 0, 'std', 1).
add_fields(
'dtype', Doc(
'dtype', 'The dtype of output Tensor. Only support Float32.'),
'dtype', Doc('dtype', 'The dtype of output Tensor. Only support Float32.'),
'DTypeEnum::Float32'))
(pdef('GammaRNG').
......@@ -830,7 +819,7 @@ pdef('Sleep').add_fields('float32', Doc('time', 'time to sleep in seconds'), 0)
('YUV2GRAY_NV12', 'BT601_YUV2GRAY_NV12'),
('YUV2GRAY_YV12', 'BT601_YUV2GRAY_YV12'),
('YUV2GRAY_YU12', 'BT601_YUV2GRAY_YU12')],
name_field='mode'))
name_field = 'mode'))
(pdef('WarpAffine', version=0, is_legacy=True)
.add_enum_alias('InterpolationMode', 'WarpPerspectiveV1', name_field='imode')
......@@ -853,7 +842,7 @@ pdef('Sleep').add_fields('float32', Doc('time', 'time to sleep in seconds'), 0)
(pdef('GaussianBlur')
.add_enum_alias('BorderMode', 'WarpPerspectiveV1', name_field='border_mode')
.add_fields('uint32', 'kernel_height', 0, 'kernel_width', 0)
.add_fields('float32', 'sigma_x', '0.f', 'sigma_y', '0.f'))
.add_fields('float32','sigma_x', '0.f', 'sigma_y', '0.f'))
(pdef('Resize', version=0, is_legacy=True)
.add_enum_alias('InterpolationMode', 'WarpPerspectiveV1', name_field='imode'))
......@@ -866,7 +855,7 @@ pdef('Sleep').add_fields('float32', Doc('time', 'time to sleep in seconds'), 0)
.add_enum_alias('InterpolationMode', 'WarpPerspectiveV1', name_field='imode')
.add_enum_alias('Format', 'Convolution', default=1))
(pdef('Remap', version=0, is_legacy=True)
(pdef('Remap', version=0,is_legacy=True)
.add_enum_alias('InterpolationMode', 'WarpPerspectiveV1', name_field='imode')
.add_enum_alias('BorderMode', 'WarpPerspectiveV1', name_field='border_type')
.add_enum_alias('Format', 'ConvolutionV0', default=1)
......@@ -920,8 +909,8 @@ pdef('Sleep').add_fields('float32', Doc('time', 'time to sleep in seconds'), 0)
(pdef('SeparableConv3D').
add_enum_alias('Mode', 'Convolution3D').
add_enum('BorderMode', 'BORDER_REPLICATE = 0', 'BORDER_REFLECT = 1',
'BORDER_REFLECT_101 = 2', 'BORDER_WRAP = 3',
'BORDER_CONSTANT = 4', 'BORDER_TRANSPARENT = 5', 'BORDER_ISOLATED = 6').
'BORDER_REFLECT_101 = 2','BORDER_WRAP = 3',
'BORDER_CONSTANT = 4', 'BORDER_TRANSPARENT = 5','BORDER_ISOLATED = 6').
add_fields('bool', 'is_symm_kernel', 'true').
add_fields('uint32', 'pad_d', 0, 'pad_h', 0, 'pad_w', 0,
'stride_d', 0, 'stride_h', 1, 'stride_w', 1,
......@@ -1034,10 +1023,10 @@ Note: NCHW_NCHW4_WEIGHT will auto pad oc and ic, you should remove oc in later o
'NCHW_NCHW4 = 24',
'NCHW4_NCHW = 25',
'NCHW_NCHW4_WEIGHT = 26',
'NCHW_NCHW64 = 27',
'NCHW64_NCHW = 28',
'NCHW_NHWC = 29',
'NHWC_NCHW = 30',
'NCHW_NCHW64 = 27',
'NCHW64_NCHW = 28',
'NCHW_NHWC = 29',
'NHWC_NCHW = 30',
)
)
......@@ -1059,7 +1048,7 @@ Note: NCHW_NCHW4_WEIGHT will auto pad oc and ic, you should remove oc in later o
add_fields('bool', 'is_symm_kernel', 'true').
add_fields('uint32', 'ksize_h', 3, 'ksize_w', 3, 'anchor_h', 1, 'anchor_w', 1))
(pdef('LocalShare', 'Local share convolution', version=0, is_legacy=True).
(pdef('LocalShare', 'Local share convolution',version=0, is_legacy=True).
add_enum_alias('Mode', 'ConvolutionV0').
add_fields(
'uint32',
......@@ -1100,7 +1089,7 @@ Note: NCHW_NCHW4_WEIGHT will auto pad oc and ic, you should remove oc in later o
)
(pdef('ROIAlign', version=0, is_legacy=True).
(pdef('ROIAlign',version=0,is_legacy=True).
add_enum('Mode', 'MAX = 0', 'AVERAGE = 1', name_field='mode').
add_enum_alias('Format', 'ConvolutionV0').
add_fields('float32', 'spatial_scale', '1.0').
......@@ -1144,7 +1133,7 @@ Note: NCHW_NCHW4_WEIGHT will auto pad oc and ic, you should remove oc in later o
Doc('part_size', 'size of each deformable part'), 1,
Doc('sample_per_part', 'sample count of each bbox'), 1))
(pdef('BatchConvBias', 'Batch convolution (unshare weights on the batch dimension)', version=0, is_legacy=True).
(pdef('BatchConvBias', 'Batch convolution (unshare weights on the batch dimension)',version=0,is_legacy=True).
add_enum_alias('NonlineMode', 'ConvBiasV0').
add_enum_alias('Mode', 'ConvolutionV0').
add_fields(
......@@ -1163,7 +1152,7 @@ Note: NCHW_NCHW4_WEIGHT will auto pad oc and ic, you should remove oc in later o
add_enum_alias('ComputeMode', 'ConvolutionV1', name_field="compute_mode")
)
(pdef('BatchConvBias', 'Batch convolution (unshare weights on the batch dimension)', version=1).
(pdef('BatchConvBias', 'Batch convolution (unshare weights on the batch dimension)',version=1).
add_enum_alias('NonlineMode', 'ConvBiasV0').
add_enum_alias('Mode', 'ConvolutionV0').
add_fields(
......@@ -1183,8 +1172,8 @@ Note: NCHW_NCHW4_WEIGHT will auto pad oc and ic, you should remove oc in later o
)
(pdef('FakeQuant').
add_fields('int32', 'qmin', '-2147483648').
add_fields('int32', 'qmax', '2147483647')
add_fields('int32','qmin','-2147483648').
add_fields('int32','qmax','2147483647')
)
(pdef('TQT').
add_fields('int32', 'qmin', '-2147483648').
......@@ -1203,13 +1192,13 @@ PADDING_MODES = [Doc('REPLICATE = 0', 'aaaaaa|abcdefgh|hhhhhhh'),
Doc('REFLECT = 1', 'fedcba|abcdefgh|hgfedcb'),
Doc('CONSTANT = 2', 'iiiiii|abcdefgh|iiiiiii')]
(pdef('Padding').
add_fields('uint32', Doc('front_offset_dim0', 'offset in dim 0'), 0).
add_fields('uint32', Doc('front_offset_dim1', 'offset in dim 1'), 0).
add_fields('uint32', Doc('front_offset_dim2', 'offset in dim 2'), 0).
add_fields('uint32', Doc('front_offset_dim3', 'offset in dim 3'), 0).
add_fields('uint32', Doc('front_offset_dim4', 'offset in dim 4'), 0).
add_fields('uint32', Doc('front_offset_dim5', 'offset in dim 5'), 0).
add_fields('uint32', Doc('front_offset_dim6', 'offset in dim 6'), 0).
add_fields('uint32', Doc('front_offset_dim0','offset in dim 0'), 0).
add_fields('uint32', Doc('front_offset_dim1','offset in dim 1'), 0).
add_fields('uint32', Doc('front_offset_dim2','offset in dim 2'), 0).
add_fields('uint32', Doc('front_offset_dim3','offset in dim 3'), 0).
add_fields('uint32', Doc('front_offset_dim4','offset in dim 4'), 0).
add_fields('uint32', Doc('front_offset_dim5','offset in dim 5'), 0).
add_fields('uint32', Doc('front_offset_dim6','offset in dim 6'), 0).
add_fields('uint32', Doc('back_offset_dim0', 'back offset in dim0'), 0).
add_fields('uint32', Doc('back_offset_dim1', 'back offset in dim1'), 0).
add_fields('uint32', Doc('back_offset_dim2', 'back offset in dim2'), 0).
......@@ -1217,7 +1206,7 @@ PADDING_MODES = [Doc('REPLICATE = 0', 'aaaaaa|abcdefgh|hhhhhhh'),
add_fields('uint32', Doc('back_offset_dim4', 'back offset in dim4'), 0).
add_fields('uint32', Doc('back_offset_dim5', 'back offset in dim5'), 0).
add_fields('uint32', Doc('back_offset_dim6', 'back offset in dim6'), 0).
add_fields('float32', Doc('padding_val', 'param of padding opr'), 0).
add_fields('float32', Doc('padding_val','param of padding opr'), 0).
add_enum('PaddingMode', *PADDING_MODES,
name_field='padding_mode', default=2,
member_alias=[(i, 'PADDING_{}'.format(i)) for i in PADDING_MODES]
......@@ -1241,22 +1230,21 @@ PADDING_MODES = [Doc('REPLICATE = 0', 'aaaaaa|abcdefgh|hhhhhhh'),
)
(pdef('RNN').
add_fields('uint32', 'num_layers', '1').
add_fields('bool', 'bidirectional', 'false').
add_fields('bool', 'bias', 'true').
add_fields('uint32', 'hidden_size', '128').
add_fields('uint32', 'proj_size', '0').
add_fields('float32', 'dropout', '0.f').
add_fields('uint32', Doc('num_layers', 'Number of recurrent layers'), '1').
add_fields('bool', Doc('bidirectional', 'Whether the RNN is bidirectional'), 'false').
add_fields('bool', Doc('bias', 'Whether the layer uses bias weights b_ih and b_hh'), 'true').
add_fields('uint32', Doc('hidden_size', 'The number of features in the hidden state'), '128').
add_fields('float32', Doc('dropout', 'If non-zero, introduces a Dropout layer on the outputs of each RNN layer'), '0.f').
add_enum_alias('NonlineMode', 'RNNCell').
add_enum_alias('FwdMode', 'BN', name_field='fwd_mode')
)
(pdef('LSTM').
add_fields('uint32', 'num_layers', '1').
add_fields('bool', 'bidirectional', 'false').
add_fields('bool', 'bias', 'true').
add_fields('uint32', 'hidden_size', '128').
add_fields('uint32', 'proj_size', '0').
add_fields('float32', 'dropout', '0.f').
add_fields('uint32', Doc('num_layers', 'Number of recurrent layers'), '1').
add_fields('bool', Doc('bidirectional', 'Whether the LSTM is bidirectional'), 'false').
add_fields('bool', Doc('bias', 'Whether the layer uses bias weights b_ih and b_hh'), 'true').
add_fields('uint32', Doc('hidden_size', 'The number of features in the hidden state'), '128').
add_fields('uint32', Doc('proj_size', 'If non-zero, use an LSTM with projections of the given size'), '0').
add_fields('float32', Doc('dropout', 'If non-zero, introduces a Dropout layer on the outputs of each LSTM layer'), '0.f').
add_enum_alias('FwdMode', 'BN', name_field='fwd_mode')
)
......@@ -224,5 +224,6 @@ std::unique_ptr<Opr> Handle::create_operator() {
#define INST(opr) template std::unique_ptr<opr> Handle::create_operator();
MEGDNN_FOREACH_OPR_CLASS(INST)
#undef INST
// vim: syntax=cpp.doxygen
......@@ -10,19 +10,12 @@
*/
#include "megdnn/oprs.h"
#include "src/common/utils.h"
// #include "src/cuda/lstm/utils.h"
namespace megdnn {
/*size_t get_reserve_size(Handle* handle, megdnn::LSTMForward::Param& param, const
TensorLayout& input) { #if CUDNN_MAJOR >= 6 auto holder =
megdnn::cuda::lstm::get_RNNDescHolder_v6(handle, param, input); return
holder.reserveSpace_size; # else return 0; #endif
}*/
void LSTM::deduce_layout(
const TensorLayout& input, const TensorLayout& hx, const TensorLayout& cx,
const TensorLayout& flatten_weights, TensorLayout& output, TensorLayout& hy,
const TensorLayout& /*flatten_weights*/, TensorLayout& output, TensorLayout& hy,
TensorLayout& cy, TensorLayout& reserve_space) {
// input: [seq_len, batch_size, input_size]
// hx: [D * num_layers, batch_size, hidden_size]
......@@ -34,24 +27,30 @@ void LSTM::deduce_layout(
TensorShape{seq_len, batch_size, D * hidden_size}, input.dtype);
hy = TensorLayout(hx);
cy = TensorLayout(cx);
// reserve_space = {{get_reserve_size(this->handle(), param(), input)},
// dtype::Byte()};
reserve_space = {{get_reserve_size_in_bytes(input)}, dtype::Byte()};
reserve_space = {{get_reserve_size_in_bytes(input)}, input.dtype};
}
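// Illustrative deduction (hypothetical values): input {10, 4, 32}
// (seq_len=10, batch_size=4, input_size=32) with bidirectional (D=2) and
// hidden_size=64 yields output {10, 4, 128}; hy and cy take the layouts
// of hx and cx.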
void LSTM::check_exec(
const TensorLayout& input, const TensorLayout& hx, const TensorLayout& cx,
const TensorLayout& flatten_weights, const TensorLayout& output,
const TensorLayout& hy, const TensorLayout& cy,
const TensorLayout& reserve_space, size_t workspace_in_bytes) {
const TensorLayout& /*reserve_space*/, size_t /*workspace_in_bytes*/) {
auto errmsg = [&]() {
std::string msg;
msg.append("input=");
msg.append(input.to_string());
msg.append(", output=");
msg.append(output.to_string());
msg.append(", hx=");
msg.append(hx.to_string());
msg.append(", cx=");
msg.append(cx.to_string());
msg.append(", hy=");
msg.append(hy.to_string());
msg.append(", cy=");
msg.append(cy.to_string());
msg.append(", flatten_weights=");
msg.append(flatten_weights.to_string());
msg.append(", hidden_size=");
msg.append(std::to_string(param().hidden_size));
msg.append(", num_layers=");
......@@ -61,9 +60,29 @@ void LSTM::check_exec(
return msg;
};
size_t D = param().bidirectional ? 2 : 1;
size_t b = param().bias ? 1 : 0;
size_t num_layers = param().num_layers;
size_t input_size = input.shape[2];
size_t gate_hidden_size = 4 * param().hidden_size;
// first layer:  weight_ih_l[k][_reverse].shape = (4*hidden_size, input_size)
//               weight_hh_l[k][_reverse].shape = (4*hidden_size, hidden_size)
// other layers: weight_ih_l[k][_reverse].shape = (4*hidden_size, num_directions * hidden_size)
//               weight_hh_l[k][_reverse].shape = (4*hidden_size, hidden_size)
// bias vectors: 2 * num_directions * num_layers
// size_dim1 = D * (first layer) + (num_layers - 1) * (other layer) + bias
size_t size_dim1 = D * (input_size + param().hidden_size) +
(num_layers - 1) * D * ((D + 1) * param().hidden_size) +
b * 2 * D * num_layers;
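// e.g. (hypothetical values) num_layers=2, bidirectional (D=2), bias=true,
// input_size=32, hidden_size=64:
//   size_dim1 = 2*(32+64) + (2-1)*2*((2+1)*64) + 1*2*2*2 = 192 + 384 + 8 = 584,
// so flatten_weights is expected to have shape (4*64, 584) = (256, 584).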
#define ASSERT_BRIEF(_content) megdnn_assert(_content, "%s", errmsg().c_str());
ASSERT_BRIEF(input.ndim == 3)
ASSERT_BRIEF(output.ndim == 3)
ASSERT_BRIEF(flatten_weights.shape[0] == gate_hidden_size)
ASSERT_BRIEF(flatten_weights.shape[0] == size_dim1)
ASSERT_BRIEF(output.shape[0] == input.shape[0])
ASSERT_BRIEF(output.shape[1] == input.shape[1])
ASSERT_BRIEF(output.shape[2] == D * param().hidden_size)
ASSERT_BRIEF(hx.ndim == 3)
ASSERT_BRIEF(hx.shape[0] == D * num_layers)
ASSERT_BRIEF(hx.shape[1] == input.shape[1]) // batch_size
......@@ -72,14 +91,22 @@ void LSTM::check_exec(
ASSERT_BRIEF(cx.shape[0] == D * num_layers)
ASSERT_BRIEF(cx.shape[1] == input.shape[1]) // batch_size
ASSERT_BRIEF(cx.shape[2] == param().hidden_size)
ASSERT_BRIEF(hy.ndim == 3)
ASSERT_BRIEF(hy.shape[0] == D * num_layers)
ASSERT_BRIEF(hy.shape[1] == input.shape[1]) // batch_size
ASSERT_BRIEF(hy.shape[2] == param().hidden_size)
ASSERT_BRIEF(cy.ndim == 3)
ASSERT_BRIEF(cy.shape[0] == D * num_layers)
ASSERT_BRIEF(cy.shape[1] == input.shape[1]) // batch_size
ASSERT_BRIEF(cy.shape[2] == param().hidden_size)
#undef ASSERT_BRIEF
}
void LSTMBackward::deduce_layout(
const TensorLayout& x, const TensorLayout& y, const TensorLayout& hx,
const TensorLayout& cx, const TensorLayout& dy, const TensorLayout& dhy,
const TensorLayout& dcy, const TensorLayout& flatten_weights,
const TensorLayout& reserve_space, TensorLayout& dx, TensorLayout& dhx,
const TensorLayout& x, const TensorLayout& /*y*/, const TensorLayout& hx,
const TensorLayout& cx, const TensorLayout& /*dy*/, const TensorLayout& /*dhy*/,
const TensorLayout& /*dcy*/, const TensorLayout& flatten_weights,
const TensorLayout& /*reserve_space*/, TensorLayout& dx, TensorLayout& dhx,
TensorLayout& dcx, TensorLayout& dw) {
dx = x;
dhx = hx;
......@@ -87,12 +114,14 @@ void LSTMBackward::deduce_layout(
dw = flatten_weights;
}
// TODO: add shape check of BWD
void LSTMBackward::check_exec(
const TensorLayout& x, const TensorLayout& y, const TensorLayout& hx,
const TensorLayout& cx, const TensorLayout& dy, const TensorLayout& dhy,
const TensorLayout& dcy, const TensorLayout& flatten_weights,
const TensorLayout& reserve_space, const TensorLayout& dx,
const TensorLayout& dhx, const TensorLayout& dcx, const TensorLayout& dw,
size_t workspace_in_bytes) {}
const TensorLayout& /*x*/, const TensorLayout& /*y*/,
const TensorLayout& /*hx*/, const TensorLayout& /*cx*/,
const TensorLayout& /*dy*/, const TensorLayout& /*dhy*/,
const TensorLayout& /*dcy*/, const TensorLayout& /*flatten_weights*/,
const TensorLayout& /*reserve_space*/, const TensorLayout& /*dx*/,
const TensorLayout& /*dhx*/, const TensorLayout& /*dcx*/,
const TensorLayout& /*dw*/, size_t /*workspace_in_bytes*/) {}
} // namespace megdnn
\ No newline at end of file
......@@ -20,8 +20,6 @@ void LSTMCell::deduce_layout(
const TensorLayout& weight_hh, const TensorLayout& bias_hh,
const TensorLayout& cx, TensorLayout& h_new, TensorLayout& c_new,
TensorLayout& gates) {
// size_t batch_size = hx.shape[0];
// size_t hidden_size = hx.shape[1];
h_new = TensorLayout(hx, hx.dtype);
c_new = TensorLayout(cx, cx.dtype);
auto opr = handle()->create_operator<RNNCellForward>();
......@@ -36,6 +34,39 @@ void LSTMCell::check_exec(
const TensorLayout& cx, const TensorLayout& h_new, const TensorLayout& c_new,
const TensorLayout& gates, size_t workspace_in_bytes) {
TensorLayout h_new_expected, c_new_expected, gates_expected;
auto errmsg = [&]() {
std::string msg;
msg.append("input=");
msg.append(input.to_string());
msg.append(", weight_ih=");
msg.append(weight_ih.to_string());
msg.append(", bias_ih=");
msg.append(bias_ih.to_string());
msg.append(", hx=");
msg.append(hx.to_string());
msg.append(", weight_hh=");
msg.append(weight_hh.to_string());
msg.append(", bias_hh=");
msg.append(bias_hh.to_string());
msg.append(", cx=");
msg.append(cx.to_string());
return msg;
};
#define ASSERT_BRIEF(_content) megdnn_assert(_content, "%s", errmsg().c_str());
ASSERT_BRIEF(input.ndim == 2)
ASSERT_BRIEF(input.shape[1] == weight_ih.shape[1])
ASSERT_BRIEF(weight_ih.shape[0] == weight_hh.shape[0])
ASSERT_BRIEF(weight_hh.shape[0] == 4 * weight_hh.shape[1])
ASSERT_BRIEF(bias_ih.shape[0] == bias_hh.shape[0])
ASSERT_BRIEF(hx.ndim == 2)
ASSERT_BRIEF(hx.shape[0] == input.shape[0])
ASSERT_BRIEF(hx.shape[1] == cx.shape[1]) // hidden_size
ASSERT_BRIEF(cx.ndim == 2)
ASSERT_BRIEF(cx.shape[0] == input.shape[0])
ASSERT_BRIEF(cx.shape[1] == weight_hh.shape[1])
#undef ASSERT_BRIEF
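// Shapes that satisfy the checks above (hypothetical values): input {2, 3},
// weight_ih {16, 3}, weight_hh {16, 4}, bias_ih/bias_hh {16}, hx/cx {2, 4}
// (batch=2, input_size=3, hidden_size=4, gate_hidden_size=4*4=16).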
deduce_layout(
input, weight_ih, bias_ih, hx, weight_hh, bias_hh, cx, h_new_expected,
c_new_expected, gates_expected);
......@@ -57,15 +88,15 @@ size_t get_workspace_in_bytes(
const TensorLayout& input, const TensorLayout& weight_ih,
const TensorLayout& bias_ih, const TensorLayout& hx,
const TensorLayout& weight_hh, const TensorLayout& bias_hh,
const TensorLayout& cx, const TensorLayout& h_new, const TensorLayout& c_new,
const TensorLayout& gates, Handle* handle) {
const TensorLayout& /*cx*/, const TensorLayout& /*h_new*/,
const TensorLayout& /*c_new*/, const TensorLayout& gates, Handle* handle) {
TensorLayout tmp_layout;
auto opr = handle->create_operator<RNNCellForward>();
opr->param().nonlineMode = param::RNNCell::NonlineMode::IDENTITY;
opr->deduce_layout(input, weight_ih, bias_ih, hx, weight_hh, bias_hh, tmp_layout);
size_t rnn_cell_need = opr->get_workspace_in_bytes(
input, weight_ih, bias_ih, hx, weight_hh, bias_hh, gates);
size_t lstm_cell_need = tmp_layout.span().dist_byte();
size_t lstm_cell_need = 2 * tmp_layout.span().dist_byte();
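// exec() below stages two gates-sized buffers in the workspace
// (copy_gates at the base, tmp right after it), hence the factor of 2.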
return rnn_cell_need > lstm_cell_need ? rnn_cell_need : lstm_cell_need;
}
......@@ -76,37 +107,48 @@ void exec(
_megdnn_tensor_out gates, _megdnn_workspace workspace, Handle* handle) {
auto opr = handle->create_operator<RNNCellForward>();
opr->param().nonlineMode = param::RNNCell::NonlineMode::IDENTITY;
/*TensorLayout tmp_layout;
opr->deduce_layout(input.layout, weight_ih.layout,
hx.layout, weight_hh.layout,
bias.layout, tmp_layout);
auto workspace_ptr = workspace.raw_ptr;
// TensorND tmp{static_cast<void*>(workspace.raw_ptr), tmp_layout};
TensorND tmp{workspace_ptr, tmp_layout};
auto new_workspace = Workspace{workspace_ptr + tmp.layout.span().dist_byte(),
workspace.size -
tmp.layout.span().dist_byte()};*/
// opr->exec(input, weight_ih, hx, weight_hh, bias, tmp, new_workspace);
opr->exec(input, weight_ih, bias_ih, hx, weight_hh, bias_hh, gates, workspace);
// activation
// size_t batch_size = tmp.layout.shape[0];
size_t batch_size = hx.layout.shape[0];
size_t hidden_size = hx.layout.shape[1];
// sigmoid: i f o
// TensorLayout gates_ifo_layout{TensorShape({batch_size, hidden_size * 3}),
// tmp.layout.dtype};
TensorND tmp{static_cast<void*>(workspace.raw_ptr), gates.layout};
auto copy_opr = handle->create_operator<TypeCvtForward>();
TensorND copy_gates{static_cast<void*>(workspace.raw_ptr), gates.layout};
TensorLayout hidden_layout{TensorShape{hidden_size}, hx.layout.dtype};
TensorLayout gateinfo_layout{TensorShape{batch_size, hidden_size}, hx.layout.dtype};
for (size_t i = 0; i < batch_size; i++) {
for (size_t j = 0; j < 4; j++) {
TensorND half_step_states{
// output
static_cast<uint8_t*>(gates.raw_ptr()) +
(4 * i + j) * hidden_layout.span().dist_byte(),
hidden_layout};
TensorND half_step_output{
static_cast<uint8_t*>(copy_gates.raw_ptr()) +
j * gateinfo_layout.span().dist_byte() +
i * hidden_layout.span().dist_byte(),
hidden_layout};
copy_opr->exec(half_step_states, half_step_output);
}
}
void* workspace_ptr = workspace.raw_ptr + copy_gates.layout.span().dist_byte();
copy_opr->exec(copy_gates, gates);
// sigmoid: i f
TensorND tmp{static_cast<void*>(workspace_ptr), copy_gates.layout};
TensorLayout gates_ifo_layout{
TensorShape({batch_size, hidden_size * 3}), gates.layout.dtype};
TensorND gates_ifo_origin{gates.raw_ptr(), gates_ifo_layout};
TensorShape({batch_size, hidden_size * 2}), copy_gates.layout.dtype};
TensorND gates_ifo_origin{copy_gates.raw_ptr(), gates_ifo_layout};
TensorND gates_ifo{tmp.raw_ptr(), gates_ifo_layout};
auto sigmoid = handle->create_operator<ElemwiseForward>();
sigmoid->param().mode = Elemwise::Param::Mode::SIGMOID;
sigmoid->exec({gates_ifo_origin}, gates_ifo);
// tanh: g
TensorLayout g_layout{TensorShape({batch_size, hidden_size}), gates.layout.dtype};
TensorLayout g_layout{
TensorShape({batch_size, hidden_size}), copy_gates.layout.dtype};
TensorND g_origin{
static_cast<char*>(gates.raw_ptr()) + gates_ifo_layout.span().dist_byte(),
static_cast<char*>(copy_gates.raw_ptr()) +
gates_ifo_layout.span().dist_byte(),
g_layout};
TensorND g{
static_cast<char*>(tmp.raw_ptr()) + gates_ifo_layout.span().dist_byte(),
......@@ -114,13 +156,24 @@ void exec(
auto tanh = handle->create_operator<ElemwiseForward>();
tanh->param().mode = Elemwise::Param::Mode::TANH;
tanh->exec({g_origin}, g);
// sigmoid: o
TensorLayout three_gates_ifo_layout{
TensorShape({batch_size, hidden_size * 3}), copy_gates.layout.dtype};
TensorLayout o_layout{
TensorShape({batch_size, hidden_size}), copy_gates.layout.dtype};
TensorND o_origin{
static_cast<char*>(copy_gates.raw_ptr()) +
three_gates_ifo_layout.span().dist_byte(),
o_layout};
TensorND o{
static_cast<char*>(tmp.raw_ptr()) +
three_gates_ifo_layout.span().dist_byte(),
o_layout};
sigmoid->exec({o_origin}, o);
// extract i f o
TensorND i{static_cast<char*>(tmp.raw_ptr()), g_layout};
TensorND f{
static_cast<char*>(tmp.raw_ptr()) + g_layout.span().dist_byte(), g_layout};
TensorND o{
static_cast<char*>(tmp.raw_ptr()) + g_layout.span().dist_byte() * 2,
g_layout};
// calculate new cell state
auto elewise_mul_add = handle->create_operator<ElemwiseForward>();
elewise_mul_add->param().mode = Elemwise::Param::Mode::FUSE_MUL_ADD4;
......
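The lstm_cell exec above assembles a naive LSTM cell from elementwise primitives: the fused gates tensor is reordered gate by gate, i and f pass through SIGMOID, g through TANH, o through SIGMOID, and FUSE_MUL_ADD4 forms the new cell state; the new hidden state (not shown in the hunk) follows the standard o * tanh(c_new). A minimal scalar reference of that math, assuming the i/f/g/o gate order used above (the helper names are illustrative, not megdnn API):

#include <cmath>
#include <vector>

static inline float sigmoid(float x) { return 1.f / (1.f + std::exp(-x)); }

// One LSTM cell step for a single sample. gates holds
// x*W_ih^T + h*W_hh^T + b_ih + b_hh, laid out as [i | f | g | o],
// each chunk of length H (hidden_size).
void lstm_cell_reference(
        const std::vector<float>& gates, const std::vector<float>& cx,
        std::vector<float>& h_new, std::vector<float>& c_new, size_t H) {
    for (size_t k = 0; k < H; ++k) {
        float i = sigmoid(gates[0 * H + k]);
        float f = sigmoid(gates[1 * H + k]);
        float g = std::tanh(gates[2 * H + k]);
        float o = sigmoid(gates[3 * H + k]);
        c_new[k] = f * cx[k] + i * g;        // the FUSE_MUL_ADD4 step above
        h_new[k] = o * std::tanh(c_new[k]);  // new hidden state
    }
}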
......@@ -139,8 +139,12 @@ DEF(LayerNormForward, 6, true, true);
DEF(LayerNormBackward, 8, true, true);
DEF(DropoutForward, 3, true, true);
DEF(DropoutBackward, 3, true, true);
DEF(RNNCellForward, 6, true, true);
DEF(RNNCellForward, 7, true, true);
DEF(RNNForward, 6, true, true);
DEF(RNNBackward, 10, true, true);
DEF(LSTMCellForward, 10, true, true);
DEF(LSTMForward, 8, true, true);
DEF(LSTMBackward, 13, true, true);
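// The numeric argument counts the tensor parameters of each operator's
// exec() (workspace excluded), e.g. LSTMForward: input, hx, cx,
// flatten_weights, output, hy, cy, reserve_space -> 8.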
} // namespace megdnn
// vim: syntax=cpp.doxygen
......@@ -16,10 +16,8 @@ namespace megdnn {
void RNN::deduce_layout(
const TensorLayout& input, const TensorLayout& hx,
const TensorLayout& flatten_weights, TensorLayout& output, TensorLayout& hy,
const TensorLayout& /*flatten_weights*/, TensorLayout& output, TensorLayout& hy,
TensorLayout& reserve_space) {
// input: [seq_len, batch_size, input_size]
// hx: [D * num_layers, batch_size, hidden_size]
size_t seq_len = input.shape[0];
size_t batch_size = input.shape[1];
size_t D = param().bidirectional ? 2 : 1;
......@@ -27,22 +25,26 @@ void RNN::deduce_layout(
output = TensorLayout(
TensorShape{seq_len, batch_size, D * hidden_size}, input.dtype);
hy = TensorLayout(hx);
// reserve_space = {{get_reserve_size(this->handle(), param(), input)},
// dtype::Byte()};
reserve_space = {{get_reserve_size_in_bytes(input)}, dtype::Byte()};
reserve_space = {{get_reserve_size_in_bytes(input)}, input.dtype};
}
void RNN::check_exec(
const TensorLayout& input, const TensorLayout& hx,
const TensorLayout& flatten_weights, const TensorLayout& output,
const TensorLayout& hy, const TensorLayout& reserve_space,
size_t workspace_in_bytes) {
const TensorLayout& hy, const TensorLayout& /*reserve_space*/,
size_t /*workspace_in_bytes*/) {
auto errmsg = [&]() {
std::string msg;
msg.append("input=");
msg.append(input.to_string());
msg.append(", output=");
msg.append(output.to_string());
msg.append(", hx=");
msg.append(hx.to_string());
msg.append(", flatten_weights=");
msg.append(flatten_weights.to_string());
msg.append(", hy=");
msg.append(hy.to_string());
msg.append(", hidden_size=");
msg.append(std::to_string(param().hidden_size));
msg.append(", num_layers=");
......@@ -52,20 +54,38 @@ void RNN::check_exec(
return msg;
};
size_t D = param().bidirectional ? 2 : 1;
size_t b = param().bias ? 1 : 0;
size_t num_layers = param().num_layers;
size_t input_size = input.shape[2];
size_t gate_hidden_size = param().hidden_size;
// size_dim1 is computed the same way as in LSTM::check_exec
size_t size_dim1 = D * (input_size + param().hidden_size) +
(num_layers - 1) * D * ((D + 1) * param().hidden_size) +
b * 2 * D * num_layers;
#define ASSERT_BRIEF(_content) megdnn_assert(_content, "%s", errmsg().c_str());
ASSERT_BRIEF(hx.ndim == 3)
ASSERT_BRIEF(input.ndim == 3)
ASSERT_BRIEF(output.ndim == 3)
ASSERT_BRIEF(hy.ndim == 3)
ASSERT_BRIEF(flatten_weights.shape[0] == gate_hidden_size)
ASSERT_BRIEF(flatten_weights.shape[0] == size_dim1)
ASSERT_BRIEF(hx.shape[0] == D * num_layers)
ASSERT_BRIEF(hx.shape[1] == input.shape[1]) // batch_size
ASSERT_BRIEF(hx.shape[2] == param().hidden_size)
ASSERT_BRIEF(output.shape[0] == input.shape[0])
ASSERT_BRIEF(output.shape[1] == input.shape[1])
ASSERT_BRIEF(output.shape[2] == D * param().hidden_size)
ASSERT_BRIEF(hy.shape[0] == hx.shape[0])
ASSERT_BRIEF(hy.shape[1] == hx.shape[1])
ASSERT_BRIEF(hy.shape[2] == hx.shape[2])
#undef ASSERT_BRIEF
}
void RNNBackward::deduce_layout(
const TensorLayout& x, const TensorLayout& y, const TensorLayout& hx,
const TensorLayout& dy, const TensorLayout& dhy,
const TensorLayout& flatten_weights, const TensorLayout& reserve_space,
const TensorLayout& x, const TensorLayout& /*y*/, const TensorLayout& hx,
const TensorLayout& /*dy*/, const TensorLayout& /*dhy*/,
const TensorLayout& flatten_weights, const TensorLayout& /*reserve_space*/,
TensorLayout& dx, TensorLayout& dhx, TensorLayout& dw) {
dx = x;
dhx = hx;
......@@ -73,10 +93,11 @@ void RNNBackward::deduce_layout(
}
void RNNBackward::check_exec(
const TensorLayout& x, const TensorLayout& y, const TensorLayout& hx,
const TensorLayout& dy, const TensorLayout& dhy,
const TensorLayout& flatten_weights, const TensorLayout& reserve_space,
const TensorLayout& dx, const TensorLayout& dhx, const TensorLayout& dw,
size_t workspace_in_bytes) {}
const TensorLayout& /*x*/, const TensorLayout& /*y*/,
const TensorLayout& /*hx*/, const TensorLayout& /*dy*/,
const TensorLayout& /*dhy*/, const TensorLayout& /*flatten_weights*/,
const TensorLayout& /*reserve_space*/, const TensorLayout& /*dx*/,
const TensorLayout& /*dhx*/, const TensorLayout& /*dw*/,
size_t /*workspace_in_bytes*/) {}
} // namespace megdnn
......@@ -16,16 +16,11 @@ namespace megdnn {
void RNNCell::deduce_layout(
const TensorLayout& input, const TensorLayout& weight_ih,
const TensorLayout& bias_ih, const TensorLayout& hx,
const TensorLayout& weight_hh, const TensorLayout& bias_hh, TensorLayout& dst) {
// megdnn_assert(hx.ndim == 2);
const TensorLayout& /*bias_ih*/, const TensorLayout& hx,
const TensorLayout& /*weight_hh*/, const TensorLayout& /*bias_hh*/,
TensorLayout& dst) {
size_t batch_size = hx.shape[0];
// size_t hidden_size = weight_hh.shape[1];
size_t gate_hidden_size = weight_ih.shape[0];
// size_t input_size = weight_ih.shape[1];
// megdnn_assert(input.shape[1] == input_size);
// megdnn_assert(hx.shape[1] == hidden_size);
// megdnn_assert_eq_dtype(input, hx);
dst = TensorLayout(TensorShape({batch_size, gate_hidden_size}), input.dtype);
}
......@@ -36,6 +31,37 @@ void RNNCell::check_exec(
const TensorLayout& weight_hh, const TensorLayout& bias_hh,
const TensorLayout& dst, size_t workspace_in_bytes) {
TensorLayout dst_expected;
auto errmsg = [&]() {
std::string msg;
msg.append("input=");
msg.append(input.to_string());
msg.append(", weight_ih=");
msg.append(weight_ih.to_string());
msg.append(", bias_ih=");
msg.append(bias_ih.to_string());
msg.append(", hx=");
msg.append(hx.to_string());
msg.append(", weight_hh=");
msg.append(weight_hh.to_string());
msg.append(", bias_hh=");
msg.append(bias_hh.to_string());
msg.append(", dst=");
msg.append(dst.to_string());
return msg;
};
#define ASSERT_BRIEF(_content) megdnn_assert(_content, "%s", errmsg().c_str());
ASSERT_BRIEF(input.ndim == 2)
ASSERT_BRIEF(hx.ndim == 2)
ASSERT_BRIEF(hx.shape[0] == input.shape[0]) // batch
ASSERT_BRIEF(input.shape[1] == weight_ih.shape[1])
ASSERT_BRIEF(hx.shape[0] == dst.shape[0]) // batch
ASSERT_BRIEF(hx.shape[1] == dst.shape[1])
ASSERT_BRIEF(hx.shape[1] == weight_ih.shape[0]) // hidden_size
ASSERT_BRIEF(weight_ih.shape[0] == weight_hh.shape[0])
ASSERT_BRIEF(weight_hh.shape[0] == weight_hh.shape[1])
ASSERT_BRIEF(bias_ih.shape[0] == bias_hh.shape[0])
#undef ASSERT_BRIEF
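// Shapes that satisfy the checks above (hypothetical values): input {2, 3},
// weight_ih {4, 3}, weight_hh {4, 4}, bias_ih/bias_hh {4}, hx {2, 4},
// dst {2, 4} (batch=2, input_size=3, hidden_size=4).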
megdnn_assert_eq_dtype(input, dst);
megdnn_assert_eq_dtype(hx, dst);
deduce_layout(input, weight_ih, bias_ih, hx, weight_hh, bias_hh, dst_expected);
......@@ -53,12 +79,15 @@ namespace rnn_cell {
size_t get_workspace_in_bytes(
const TensorLayout& input, const TensorLayout& weight_ih,
const TensorLayout& bias_ih, const TensorLayout& hx,
const TensorLayout& weight_hh, const TensorLayout& bias_hh,
const TensorLayout& /*bias_ih*/, const TensorLayout& hx,
const TensorLayout& weight_hh, const TensorLayout& /*bias_hh*/,
const TensorLayout& dst, Handle* handle) {
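// Workspace layout: one dst-sized buffer holding the input*weight_ih^T
// product, plus the larger of the two MatrixMul workspace requirements
// (the two matmuls run sequentially and reuse the same region in exec()).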
auto opr = handle->create_operator<MatrixMulForward>();
opr->param().transposeB = true;
return dst.span().dist_byte() + opr->get_workspace_in_bytes(hx, weight_hh, dst);
return dst.span().dist_byte() +
std::max(
opr->get_workspace_in_bytes(hx, weight_hh, dst),
opr->get_workspace_in_bytes(input, weight_ih, dst));
}
void exec(
......@@ -74,14 +103,11 @@ void exec(
opr->param().transposeB = true;
opr->exec(input, weight_ih, tmp, new_workspace);
opr->exec(hx, weight_hh, dst, new_workspace);
// if (this->param().bias) add_bias(dst, tmp, bias, dst);
// if (this->param().bias) {
auto add_opr = handle->create_operator<ElemwiseForward>();
add_opr->param().mode = Elemwise::Param::Mode::ADD;
add_opr->exec({dst, tmp}, dst);
add_opr->exec({dst, bias_ih}, dst);
add_opr->exec({dst, bias_hh}, dst);
// }
// activation
using NonlineMode = param::RNNCell::NonlineMode;
......
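In equation form, rnn_cell::exec above computes dst = act(input * weight_ih^T + hx * weight_hh^T + bias_ih + bias_hh), with transposeB = true on both matmuls and the activation chosen by NonlineMode. A scalar sketch of the same computation (illustrative helper, with tanh standing in for the dispatched nonlinearity):

#include <cmath>
#include <vector>

// Naive RNNCell forward for one batch element. W_ih is [G][I] and W_hh is
// [G][H] (G = gate_hidden_size), matching the transposed matmuls above.
void rnn_cell_reference(
        const std::vector<float>& x, const std::vector<float>& h,
        const std::vector<std::vector<float>>& W_ih,
        const std::vector<std::vector<float>>& W_hh,
        const std::vector<float>& b_ih, const std::vector<float>& b_hh,
        std::vector<float>& dst) {
    for (size_t j = 0; j < dst.size(); ++j) {
        float acc = b_ih[j] + b_hh[j];
        for (size_t k = 0; k < x.size(); ++k)
            acc += x[k] * W_ih[j][k];  // input * weight_ih^T
        for (size_t k = 0; k < h.size(); ++k)
            acc += h[k] * W_hh[j][k];  // hx * weight_hh^T
        dst[j] = std::tanh(acc);       // NonlineMode activation
    }
}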
......@@ -160,29 +160,6 @@ void TensorDesc::set(
}
}
void TensorDesc::set_nd(const TensorLayout& layout, int pad) {
int nbDims = layout.ndim < pad ? pad : layout.ndim;
int dimA[nbDims], strideA[nbDims];
for (size_t i = 0; i < layout.ndim; ++i) {
dimA[i] = layout.shape[i];
// strideA[i] = layout.stride[i];
}
for (size_t i = layout.ndim; i < nbDims; ++i) {
dimA[i] = 1; // unused
// strideA[i] = 1;
}
// stride
for (size_t i = 0; i < nbDims; ++i) {
strideA[i] = 1;
for (size_t j = i + 1; j < nbDims; ++j) {
strideA[i] *= dimA[j];
}
}
cudnn_check(cudnnSetTensorNdDescriptor(
desc, to_cudnn_dtype(layout.dtype), nbDims, dimA, strideA));
}
std::string TensorDesc::to_string() {
cudnnDataType_t data_type;
int n;
......@@ -456,97 +433,6 @@ void Conv3DDesc::set(const param::Convolution3D& param, const size_t nr_group) {
desc, 3, padA, filterStrideA, dilationA, mode, CUDNN_DATA_FLOAT));
}
DropoutDesc::DropoutDesc() {
cudnn_check(cudnnCreateDropoutDescriptor(&desc));
}
DropoutDesc::~DropoutDesc() {
cudnn_check(cudnnDestroyDropoutDescriptor(desc));
}
void DropoutDesc::set(float dropout, Handle* handle, TensorND& state) {
cudnn_check(cudnnSetDropoutDescriptor(
desc, cudnn_handle(handle), dropout, state.raw_ptr(),
state.layout.span().dist_byte(), 0 // seed
));
}
void DropoutDesc::set_no_dropout(Handle* handle) {
cudnn_check(
cudnnSetDropoutDescriptor(desc, cudnn_handle(handle), 0, nullptr, 0, 0));
}
RNNDesc::RNNDesc() {
cudnn_check(cudnnCreateRNNDescriptor(&desc));
}
RNNDesc::~RNNDesc() {
cudnn_check(cudnnDestroyRNNDescriptor(desc));
}
void RNNDesc::set(
size_t input_size, size_t hidden_size, size_t proj_size, size_t num_layers,
bool bidirectional, bool bias, const megdnn::DType dtype, cudnnRNNMode_t mode,
DropoutDesc& dropout_desc, Handle* handle) {
cudnnRNNMode_t rnn_mode = mode;
cudnnRNNBiasMode_t bias_mode = bias ? CUDNN_RNN_DOUBLE_BIAS : CUDNN_RNN_NO_BIAS;
cudnnDirectionMode_t dir_mode =
bidirectional ? CUDNN_BIDIRECTIONAL : CUDNN_UNIDIRECTIONAL;
cudnnDataType_t math_prec;
// math precision
if (dtype.enumv() == DTypeEnum::Float16)
math_prec = CUDNN_DATA_HALF;
else
math_prec = CUDNN_DATA_FLOAT;
#if false // CUDNN_MAJOR >= 8
cudnn_check(cudnnSetRNNDescriptor_v8(
desc, CUDNN_RNN_ALGO_STANDARD, mode, bias_mode, dir_mode,
CUDNN_LINEAR_INPUT, to_cudnn_dtype(dtype), math_prec, CUDNN_DEFAULT_MATH,
input_size, hidden_size, proj_size, num_layers, dropout_desc.desc,
CUDNN_RNN_PADDED_IO_DISABLED));
#else
cudnn_check(cudnnSetRNNDescriptor_v6(
cudnn_handle(handle), desc, hidden_size, num_layers, dropout_desc.desc,
CUDNN_LINEAR_INPUT, dir_mode, mode, CUDNN_RNN_ALGO_STANDARD, math_prec));
#endif
}
RNNDataDesc::RNNDataDesc() {
cudnn_check(cudnnCreateRNNDataDescriptor(&desc));
}
RNNDataDesc::~RNNDataDesc() {
cudnn_check(cudnnDestroyRNNDataDescriptor(desc));
}
void RNNDataDesc::set(
int batchSize, int vectorSize, int maxSeqLength, const int* devSeqLengths,
DType dtype) {
// for now, all tensor are padded in python
// int seqLengthArray[batchSize];
// for (int i = 0; i < batchSize; ++i) seqLengthArray[i] = maxSeqLength;
cudnn_check(cudnnSetRNNDataDescriptor(
desc, to_cudnn_dtype(dtype), CUDNN_RNN_DATA_LAYOUT_SEQ_MAJOR_UNPACKED,
maxSeqLength, batchSize, vectorSize, devSeqLengths, nullptr));
}
RNNWeightFilterDesc::RNNWeightFilterDesc() {
cudnn_check(cudnnCreateFilterDescriptor(&desc));
}
RNNWeightFilterDesc::~RNNWeightFilterDesc() {
cudnn_check(cudnnDestroyFilterDescriptor(desc));
}
void RNNWeightFilterDesc::set(const TensorLayout& flatten_weights) {
int weight_elem_num = flatten_weights.total_nr_elems();
int dimW[] = {weight_elem_num, 1, 1};
cudnn_check(cudnnSetFilterNdDescriptor(
desc, to_cudnn_dtype(flatten_weights.dtype), CUDNN_TENSOR_NCHW, 3, dimW));
}
////////////////////////// CudnnAlgoPack //////////////////////////
#define V1(v) #v
......
......@@ -30,7 +30,6 @@ public:
void set(
const TensorLayout& layout,
const param::Convolution::Format = param::Convolution::Format::NCHW);
void set_nd(const TensorLayout& layout, int pad = 3); // at least 3 dimensions
std::string to_string();
~TensorDesc();
cudnnTensorDescriptor_t desc;
......@@ -122,44 +121,6 @@ public:
static const std::unordered_map<cudnnConvolutionFwdAlgo_t, Attr> conv3d_fwd_algos();
};
class DropoutDesc {
public:
DropoutDesc();
void set(float dropout, Handle* handle, TensorND& state);
void set_no_dropout(Handle* handle);
~DropoutDesc();
cudnnDropoutDescriptor_t desc;
};
class RNNDesc {
public:
RNNDesc();
void set(
size_t input_size, size_t hidden_size, size_t proj_size, size_t num_layers,
bool bidirectional, bool bias, const megdnn::DType dtype,
cudnnRNNMode_t mode, DropoutDesc& dropout_desc, Handle* handle);
~RNNDesc();
cudnnRNNDescriptor_t desc;
};
class RNNDataDesc {
public:
RNNDataDesc();
void set(
int batchSize, int vectorSize, int maxSeqLength, const int* devSeqLengths,
DType dtype);
~RNNDataDesc();
cudnnRNNDataDescriptor_t desc;
};
class RNNWeightFilterDesc {
public:
RNNWeightFilterDesc();
void set(const TensorLayout& flatten_weights);
~RNNWeightFilterDesc();
cudnnFilterDescriptor_t desc;
};
} // namespace cuda
} // namespace megdnn
......
......@@ -10,7 +10,7 @@
* implied.
*/
#include "src/common/handle_impl.h"
// #include "src/common/handle_impl.h"
#include "src/cuda/adaptive_pooling/opr_impl.h"
#include "src/cuda/add_update/opr_impl.h"
......@@ -52,8 +52,6 @@
#include "src/cuda/local_share/opr_impl.h"
#include "src/cuda/lrn/opr_impl.h"
#include "src/cuda/lsq/opr_impl.h"
#include "src/cuda/lstm/opr_impl.h"
#include "src/cuda/lstm_cell/opr_impl.h"
#include "src/cuda/mask_conv/opr_impl.h"
#include "src/cuda/matrix_inverse/opr_impl.h"
#include "src/cuda/matrix_mul/opr_impl.h"
......@@ -70,8 +68,6 @@
#include "src/cuda/repeat/opr_impl.h"
#include "src/cuda/resize/opr_impl.h"
#include "src/cuda/rng/opr_impl.h"
#include "src/cuda/rnn/opr_impl.h"
#include "src/cuda/rnn_cell/opr_impl.h"
#include "src/cuda/roi_align/opr_impl.h"
#include "src/cuda/roi_copy/opr_impl.h"
#include "src/cuda/roi_pooling/opr_impl.h"
......@@ -94,6 +90,7 @@
namespace megdnn {
namespace cuda {
// After the CUDA LSTM is added, the declaration of the CUDA backend should be restored:
// MEGDNN_FOREACH_OPR_CLASS(MEGDNN_SPECIALIZE_CREATE_OPERATOR)
MEGDNN_SPECIALIZE_CREATE_OPERATOR(ConvolutionForward);
MEGDNN_SPECIALIZE_CREATE_OPERATOR(ConvolutionBackwardData);
......@@ -222,6 +219,8 @@ MEGDNN_SPECIALIZE_CREATE_OPERATOR(PaddingForward);
MEGDNN_SPECIALIZE_CREATE_OPERATOR(PaddingBackward);
MEGDNN_SPECIALIZE_CREATE_OPERATOR(LayerNormForward);
MEGDNN_SPECIALIZE_CREATE_OPERATOR(LayerNormBackward);
MEGDNN_SPECIALIZE_CREATE_OPERATOR(DropoutForward);
MEGDNN_SPECIALIZE_CREATE_OPERATOR(DropoutBackward);
template <typename Opr>
std::unique_ptr<Opr> HandleImpl::create_operator() {
......@@ -232,9 +231,11 @@ std::unique_ptr<Opr> HandleImpl::create_operator() {
#define MEGDNN_INST_CREATE_OPERATOR(opr) \
template std::unique_ptr<megdnn::opr> HandleImpl::create_operator();
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wpragmas"
#pragma GCC diagnostic ignored "-Winstantiation-after-specialization"
MEGDNN_FOREACH_OPR_CLASS(MEGDNN_INST_CREATE_OPERATOR)
#pragma GCC diagnostic pop
} // namespace cuda
} // namespace megdnn
......
/**
* \file dnn/src/cuda/lstm/opr_impl.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#include "src/cuda/lstm/opr_impl.h"
#include "src/cuda/lstm/utils.h"
#include "src/cuda/utils.h"
#include <cudnn.h>
namespace megdnn {
namespace cuda {
void LSTMImpl::exec(
_megdnn_tensor_in input, _megdnn_tensor_in hx, _megdnn_tensor_in cx,
_megdnn_tensor_in flatten_weights, _megdnn_tensor_out output,
_megdnn_tensor_out hy, _megdnn_tensor_out cy, _megdnn_tensor_out reserve_space,
_megdnn_workspace workspace) {
Handle* handle = this->handle();
rnn::RNNForwardDescHolder_v6 desc_holder =
lstm::get_RNNDescHolder_v6(this->handle(), param(), input.layout);
auto x_desc_arr = rnn::get_descs(desc_holder.x_descs);
auto y_desc_arr = rnn::get_descs(desc_holder.y_descs);
RNNWeightFilterDesc w_desc;
w_desc.set(flatten_weights.layout);
if (param().fwd_mode == param::LSTM::FwdMode::TRAINING) {
cudnn_check(cudnnRNNForwardTraining(
cudnn_handle(handle), desc_holder.rnn_desc.desc, desc_holder.seq_len,
x_desc_arr.data(), input.raw_ptr(), desc_holder.hx_desc.desc,
hx.raw_ptr(), desc_holder.cx_desc.desc, cx.raw_ptr(), w_desc.desc,
flatten_weights.raw_ptr(), y_desc_arr.data(), output.raw_ptr(),
desc_holder.hy_desc.desc, hy.raw_ptr(), desc_holder.cy_desc.desc,
cy.raw_ptr(), workspace.raw_ptr, desc_holder.workspace_size,
reserve_space.raw_ptr(), desc_holder.reserveSpace_size));
} else {
cudnn_check(cudnnRNNForwardInference(
cudnn_handle(handle), desc_holder.rnn_desc.desc, desc_holder.seq_len,
x_desc_arr.data(), input.raw_ptr(), desc_holder.hx_desc.desc,
hx.raw_ptr(), desc_holder.cx_desc.desc, cx.raw_ptr(), w_desc.desc,
flatten_weights.raw_ptr(), y_desc_arr.data(), output.raw_ptr(),
desc_holder.hy_desc.desc, hy.raw_ptr(), desc_holder.cy_desc.desc,
cy.raw_ptr(), workspace.raw_ptr, desc_holder.workspace_size));
}
}
size_t LSTMImpl::get_workspace_in_bytes(
const TensorLayout& input, const TensorLayout& hx, const TensorLayout& cx,
const TensorLayout& flatten_weights, const TensorLayout& output,
const TensorLayout& hy, const TensorLayout& cy,
const TensorLayout& reserve_space) {
rnn::RNNForwardDescHolder_v6 desc_holder =
lstm::get_RNNDescHolder_v6(this->handle(), param(), input);
return desc_holder.workspace_size;
}
size_t LSTMImpl::get_reserve_size_in_bytes(const TensorLayout& input) {
rnn::RNNForwardDescHolder_v6 desc_holder =
lstm::get_RNNDescHolder_v6(this->handle(), param(), input);
return desc_holder.reserveSpace_size;
}
void LSTMBackwardImpl::exec(
_megdnn_tensor_in x, _megdnn_tensor_in y, _megdnn_tensor_in hx,
_megdnn_tensor_in cx, _megdnn_tensor_in dy, _megdnn_tensor_in dhy,
_megdnn_tensor_in dcy, _megdnn_tensor_in flatten_weights,
_megdnn_tensor_in reserve_space, _megdnn_tensor_out dx, _megdnn_tensor_out dhx,
_megdnn_tensor_out dcx, _megdnn_tensor_out dw, _megdnn_workspace workspace) {
Handle* handle = this->handle();
size_t seq_len = x.layout.shape[0];
auto desc_holder = lstm::get_RNNDescHolder_v6(handle, param(), x.layout);
auto x_desc_arr_ptr = rnn::get_descs(desc_holder.x_descs).data();
auto y_desc_arr_ptr = rnn::get_descs(desc_holder.y_descs).data();
RNNWeightFilterDesc w_desc;
w_desc.set(flatten_weights.layout);
cudnn_check(cudnnRNNBackwardData(
cudnn_handle(handle), desc_holder.rnn_desc.desc, seq_len, y_desc_arr_ptr,
y.raw_ptr(), y_desc_arr_ptr, dy.raw_ptr(), desc_holder.hy_desc.desc,
dhy.raw_ptr(), desc_holder.cy_desc.desc, dcy.raw_ptr(), w_desc.desc,
flatten_weights.raw_ptr(), desc_holder.hx_desc.desc, hx.raw_ptr(),
desc_holder.cx_desc.desc, cx.raw_ptr(), x_desc_arr_ptr, dx.raw_ptr(),
desc_holder.hx_desc.desc, dhx.raw_ptr(), desc_holder.cx_desc.desc,
dcx.raw_ptr(), workspace.raw_ptr, desc_holder.workspace_size,
reserve_space.raw_ptr(), desc_holder.reserveSpace_size));
cudnn_check(cudnnRNNBackwardWeights(
cudnn_handle(handle), desc_holder.rnn_desc.desc, seq_len, x_desc_arr_ptr,
x.raw_ptr(), desc_holder.hx_desc.desc, hx.raw_ptr(), y_desc_arr_ptr,
y.raw_ptr(), workspace.raw_ptr, desc_holder.workspace_size, w_desc.desc,
dw.raw_ptr(), reserve_space.raw_ptr(), desc_holder.reserveSpace_size));
}
size_t LSTMBackwardImpl::get_workspace_in_bytes(
const TensorLayout& x, const TensorLayout& y, const TensorLayout& hx,
const TensorLayout& cx, const TensorLayout& dy, const TensorLayout& dhy,
const TensorLayout& dcy, const TensorLayout& flatten_weights,
const TensorLayout& reserve_space, const TensorLayout& dx,
const TensorLayout& dhx, const TensorLayout& dcx, const TensorLayout& dw) {
auto desc_holder = lstm::get_RNNDescHolder_v6(this->handle(), param(), x);
return desc_holder.workspace_size;
}
} // namespace cuda
} // namespace megdnn
/**
* \file dnn/src/cuda/lstm/opr_impl.h
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#pragma once
#include "megdnn/oprs.h"
namespace megdnn {
namespace cuda {
class LSTMImpl : public LSTM {
public:
using LSTM::LSTM;
void exec(
_megdnn_tensor_in input, _megdnn_tensor_in hx, _megdnn_tensor_in cx,
_megdnn_tensor_in flatten_weights, _megdnn_tensor_out output,
_megdnn_tensor_out hy, _megdnn_tensor_out cy,
_megdnn_tensor_out reserve_space, _megdnn_workspace workspace);
size_t get_workspace_in_bytes(
const TensorLayout& input, const TensorLayout& hx, const TensorLayout& cx,
const TensorLayout& flatten_weights, const TensorLayout& output,
const TensorLayout& hy, const TensorLayout& cy,
const TensorLayout& reserve_space);
size_t get_reserve_size_in_bytes(const TensorLayout& input);
};
class LSTMBackwardImpl : public LSTMBackward {
public:
using LSTMBackward::LSTMBackward;
virtual void exec(
_megdnn_tensor_in x, _megdnn_tensor_in y, _megdnn_tensor_in hx,
_megdnn_tensor_in cx, _megdnn_tensor_in dy, _megdnn_tensor_in dhy,
_megdnn_tensor_in dcy, _megdnn_tensor_in flatten_weights,
_megdnn_tensor_in reserve_space, _megdnn_tensor_out dx,
_megdnn_tensor_out dhx, _megdnn_tensor_out dcx, _megdnn_tensor_out dw,
_megdnn_workspace workspace);
virtual size_t get_workspace_in_bytes(
const TensorLayout& x, const TensorLayout& y, const TensorLayout& hx,
const TensorLayout& cx, const TensorLayout& dy, const TensorLayout& dhy,
const TensorLayout& dcy, const TensorLayout& flatten_weights,
const TensorLayout& reserve_space, const TensorLayout& dx,
const TensorLayout& dhx, const TensorLayout& dcx, const TensorLayout& dw);
};
} // namespace cuda
} // namespace megdnn
\ No newline at end of file
/**
* \file dnn/src/cuda/lstm/utils.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#include "src/cuda/lstm/utils.h"
#include "src/cuda/utils.h"
#include <cudnn.h>
namespace megdnn {
namespace cuda {
namespace lstm {
RNNForwardDescHolder_v6 get_RNNDescHolder_v6(
Handle* handle, megdnn::LSTMForward::Param& _param, const TensorLayout& input) {
size_t seq_len = input.shape[0];
size_t batch_size = input.shape[1];
size_t input_size = input.shape[2];
cudnnRNNMode_t mode = CUDNN_LSTM;
using FwdMode = param::LSTM::FwdMode;
RNNForwardDescHolder_v6 desc_holder(
handle, seq_len, batch_size, _param.hidden_size, input_size,
_param.proj_size, _param.num_layers, _param.bidirectional, _param.bias,
input.dtype, mode);
return desc_holder;
}
} // namespace lstm
} // namespace cuda
} // namespace megdnn
\ No newline at end of file
/**
* \file dnn/src/cuda/lstm/utils.h
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#pragma once
#include "src/cuda/cudnn_wrapper.h"
#include "src/cuda/rnn/utils.h"
namespace megdnn {
namespace cuda {
namespace lstm {
using megdnn::cuda::rnn::RNNForwardDescHolder_v6;
RNNForwardDescHolder_v6 get_RNNDescHolder_v6(
Handle* handle, megdnn::LSTMForward::Param& _param, const TensorLayout& input);
} // namespace lstm
} // namespace cuda
} // namespace megdnn
\ No newline at end of file
/**
* \file dnn/src/cuda/lstm_cell/opr_impl.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#include "src/cuda/lstm_cell/opr_impl.h"
#include "megdnn/dtype.h"
#include "megdnn/oprs/base.h"
#include "src/common/lstm_cell.h"
#include "src/common/opr_delegate.h"
#include "src/common/utils.h"
namespace megdnn {
namespace cuda {
size_t LSTMCellImpl::get_workspace_in_bytes(
const TensorLayout& input, const TensorLayout& weight_ih,
const TensorLayout& bias_ih, const TensorLayout& hx,
const TensorLayout& weight_hh, const TensorLayout& bias_hh,
const TensorLayout& cx, const TensorLayout& h_new, const TensorLayout& c_new,
const TensorLayout& gates) {
return megdnn::lstm_cell::get_workspace_in_bytes(
input, weight_ih, bias_ih, hx, weight_hh, bias_hh, cx, h_new, c_new, gates,
handle());
}
void LSTMCellImpl::exec(
_megdnn_tensor_in input, _megdnn_tensor_in weight_ih, _megdnn_tensor_in bias_ih,
_megdnn_tensor_in hx, _megdnn_tensor_in weight_hh, _megdnn_tensor_in bias_hh,
_megdnn_tensor_in cx, _megdnn_tensor_out h_new, _megdnn_tensor_out c_new,
_megdnn_tensor_out gates, _megdnn_workspace workspace) {
megdnn::lstm_cell::exec(
input, weight_ih, bias_ih, hx, weight_hh, bias_hh, cx, h_new, c_new, gates,
workspace, handle());
}
} // namespace cuda
} // namespace megdnn
\ No newline at end of file
/**
* \file dnn/src/cuda/lstm_cell/opr_impl.h
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#pragma once
#include "megdnn/oprs.h"
#include "src/cuda/rnn_cell/opr_impl.h"
namespace megdnn {
namespace cuda {
class LSTMCellImpl : public LSTMCell {
public:
using LSTMCell::LSTMCell;
void exec(
_megdnn_tensor_in input, _megdnn_tensor_in weight_ih,
_megdnn_tensor_in bias_ih, _megdnn_tensor_in hx,
_megdnn_tensor_in weight_hh, _megdnn_tensor_in bias_hh,
_megdnn_tensor_in cx, _megdnn_tensor_out h_new, _megdnn_tensor_out c_new,
_megdnn_tensor_out gates, _megdnn_workspace workspace) override;
size_t get_workspace_in_bytes(
const TensorLayout& input, const TensorLayout& weight_ih,
const TensorLayout& bias_ih, const TensorLayout& hx,
const TensorLayout& weight_hh, const TensorLayout& bias_hh,
const TensorLayout& cx, const TensorLayout& h_new,
const TensorLayout& c_new, const TensorLayout& gates) override;
};
} // namespace cuda
} // namespace megdnn
\ No newline at end of file
/**
* \file dnn/src/cuda/rnn/opr_impl.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#include "src/cuda/rnn/opr_impl.h"
#include "src/common/rnn.h"
#include "src/cuda/utils.h"
//#include <cstring>
#include <cudnn.h>
#include <cstdlib>
#include <iostream>
namespace megdnn {
namespace cuda {
using namespace std;
void RNNImpl::exec(
_megdnn_tensor_in input, _megdnn_tensor_in hx,
_megdnn_tensor_in flatten_weights, _megdnn_tensor_out output,
_megdnn_tensor_out hy, _megdnn_tensor_out reserve_space,
_megdnn_workspace workspace) {
Handle* handle = this->handle();
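    // NOTE: the cuDNN-v8 cudnnRNNForward path below is compiled out via
    // "#if false"; the v6 descriptor-array API is always used for now.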
#if false // CUDNN_MAJOR >= 8
rnn::RNNForwardDescHolder desc_holder = this->get_desc_holder(input.layout);
void* workspace_ptr = workspace.raw_ptr;
void* reserveSpace_ptr = static_cast<uint8_t*>(workspace_ptr) + desc_holder.workspace_size;
cudnn_check(cudnnRNNForward(
cudnn_handle(handle), desc_holder.rnn_desc.desc, desc_holder.fwdMode, desc_holder.devSeqLengths,
desc_holder.x_desc.desc, input.raw_ptr(), desc_holder.y_desc.desc, output.raw_ptr(),
desc_holder.h_desc.desc, hx.raw_ptr(), hy.raw_ptr(),
desc_holder.h_desc.desc, nullptr, nullptr,
desc_holder.weight_size, flatten_weights.raw_ptr(), desc_holder.workspace_size, workspace_ptr,
desc_holder.reserveSpace_size, reserveSpace_ptr
));
#else
rnn::RNNForwardDescHolder_v6 desc_holder =
rnn::get_RNNDescHolder_v6(this->handle(), param(), input.layout);
auto x_desc_arr = rnn::get_descs(desc_holder.x_descs);
auto y_desc_arr = rnn::get_descs(desc_holder.y_descs);
RNNWeightFilterDesc w_desc;
w_desc.set(flatten_weights.layout);
if (param().fwd_mode == param::RNN::FwdMode::TRAINING) {
cudnn_check(cudnnRNNForwardTraining(
cudnn_handle(handle), desc_holder.rnn_desc.desc, desc_holder.seq_len,
x_desc_arr.data(), input.raw_ptr(), desc_holder.hx_desc.desc,
hx.raw_ptr(), desc_holder.cx_desc.desc, NULL, w_desc.desc,
flatten_weights.raw_ptr(), y_desc_arr.data(), output.raw_ptr(),
desc_holder.hy_desc.desc, hy.raw_ptr(), desc_holder.cy_desc.desc, NULL,
workspace.raw_ptr, desc_holder.workspace_size, reserve_space.raw_ptr(),
desc_holder.reserveSpace_size));
} else {
cudnn_check(cudnnRNNForwardInference(
cudnn_handle(handle), desc_holder.rnn_desc.desc, desc_holder.seq_len,
x_desc_arr.data(), input.raw_ptr(), desc_holder.hx_desc.desc,
hx.raw_ptr(), desc_holder.cx_desc.desc, nullptr, w_desc.desc,
flatten_weights.raw_ptr(), y_desc_arr.data(), output.raw_ptr(),
desc_holder.hy_desc.desc, hy.raw_ptr(), desc_holder.cy_desc.desc,
nullptr, workspace.raw_ptr, desc_holder.workspace_size));
}
#endif
}
size_t RNNImpl::get_workspace_in_bytes(
const TensorLayout& input, const TensorLayout& hx,
const TensorLayout& flatten_weights, const TensorLayout& output,
const TensorLayout& hy, const TensorLayout& reserve_space) {
#if false // CUDNN_MAJOR >= 8
rnn::RNNForwardDescHolder desc_holder = this->get_desc_holder(input);
#else
rnn::RNNForwardDescHolder_v6 desc_holder =
rnn::get_RNNDescHolder_v6(this->handle(), param(), input);
#endif
return desc_holder.workspace_size;
}
size_t RNNImpl::get_reserve_size_in_bytes(const TensorLayout& input) {
rnn::RNNForwardDescHolder_v6 desc_holder =
rnn::get_RNNDescHolder_v6(this->handle(), param(), input);
return desc_holder.reserveSpace_size;
}
/*rnn::RNNForwardDescHolder RNNImpl::get_desc_holder(const TensorLayout& input) {
Handle* handle = this->handle();
size_t seq_len = input.shape[0];
size_t batch_size = input.shape[1];
size_t input_size = input.shape[2];
auto _param = param();
cudnnRNNMode_t mode;
using NonlineMode = param::RNN::NonlineMode;
switch (_param.nonlineMode) {
case NonlineMode::RELU:
mode = CUDNN_RNN_RELU;
break;
case NonlineMode::TANH:
mode = CUDNN_RNN_TANH;
break;
}
cudnnForwardMode_t fwdMode = CUDNN_FWD_MODE_TRAINING;
using FwdMode = param::RNN::FwdMode;
switch (_param.fwd_mode) {
case FwdMode::TRAINING:
fwdMode = CUDNN_FWD_MODE_TRAINING;
break;
case FwdMode::INFERENCE:
fwdMode = CUDNN_FWD_MODE_INFERENCE;
break;
}
rnn::RNNForwardDescHolder desc_holder(
handle, seq_len, batch_size, _param.hidden_size, input_size,
_param.proj_size, _param.num_layers, _param.bidirectional, _param.bias,
input.dtype, mode, fwdMode);
return desc_holder;
}*/
void RNNBackwardImpl::exec(
_megdnn_tensor_in x, _megdnn_tensor_in y, _megdnn_tensor_in hx,
_megdnn_tensor_in dy, _megdnn_tensor_in dhy, _megdnn_tensor_in flatten_weights,
_megdnn_tensor_in reserve_space, _megdnn_tensor_out dx, _megdnn_tensor_out dhx,
_megdnn_tensor_out dw, _megdnn_workspace workspace) {
Handle* handle = this->handle();
size_t seq_len = x.layout.shape[0];
auto desc_holder = rnn::get_RNNDescHolder_v6(handle, param(), x.layout);
auto x_desc_arr_ptr = rnn::get_descs(desc_holder.x_descs).data();
auto y_desc_arr_ptr = rnn::get_descs(desc_holder.y_descs).data();
RNNWeightFilterDesc w_desc;
w_desc.set(flatten_weights.layout);
cudnn_check(cudnnRNNBackwardData(
cudnn_handle(handle), desc_holder.rnn_desc.desc, seq_len, y_desc_arr_ptr,
y.raw_ptr(), y_desc_arr_ptr, dy.raw_ptr(), desc_holder.hy_desc.desc,
dhy.raw_ptr(), desc_holder.cy_desc.desc, NULL, w_desc.desc,
flatten_weights.raw_ptr(), desc_holder.hx_desc.desc, hx.raw_ptr(),
desc_holder.cx_desc.desc, NULL, x_desc_arr_ptr, dx.raw_ptr(),
desc_holder.hx_desc.desc, dhx.raw_ptr(), desc_holder.cx_desc.desc, NULL,
workspace.raw_ptr, desc_holder.workspace_size, reserve_space.raw_ptr(),
desc_holder.reserveSpace_size));
cudnn_check(cudnnRNNBackwardWeights(
cudnn_handle(handle), desc_holder.rnn_desc.desc, seq_len, x_desc_arr_ptr,
x.raw_ptr(), desc_holder.hx_desc.desc, hx.raw_ptr(), y_desc_arr_ptr,
y.raw_ptr(), workspace.raw_ptr, desc_holder.workspace_size, w_desc.desc,
dw.raw_ptr(), reserve_space.raw_ptr(), desc_holder.reserveSpace_size));
}
size_t RNNBackwardImpl::get_workspace_in_bytes(
const TensorLayout& x, const TensorLayout& y, const TensorLayout& hx,
const TensorLayout& dy, const TensorLayout& dhy,
const TensorLayout& flatten_weights, const TensorLayout& reserve_space,
const TensorLayout& dx, const TensorLayout& dhx, const TensorLayout& dw) {
auto desc_holder = rnn::get_RNNDescHolder_v6(this->handle(), param(), x);
return desc_holder.workspace_size;
}
} // namespace cuda
} // namespace megdnn
/**
* \file dnn/src/cuda/rnn/opr_impl.h
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#pragma once
#include "megdnn/oprs.h"
#include "src/cuda/cudnn_wrapper.h"
#include "src/cuda/rnn/utils.h"
namespace megdnn {
namespace cuda {
class RNNImpl : public RNN {
public:
using RNN::RNN;
void exec(
_megdnn_tensor_in input, _megdnn_tensor_in hx,
_megdnn_tensor_in flatten_weights, _megdnn_tensor_out output,
_megdnn_tensor_out hy, _megdnn_tensor_out reserve_space,
_megdnn_workspace workspace);
size_t get_workspace_in_bytes(
const TensorLayout& input, const TensorLayout& hx,
const TensorLayout& flatten_weights, const TensorLayout& output,
const TensorLayout& hy, const TensorLayout& reserve_space);
size_t get_reserve_size_in_bytes(const TensorLayout& input);
// private:
// rnn::RNNForwardDescHolder get_desc_holder(const TensorLayout& input);
};
class RNNBackwardImpl : public RNNBackward {
public:
using RNNBackward::RNNBackward;
virtual void exec(
_megdnn_tensor_in x, _megdnn_tensor_in y, _megdnn_tensor_in hx,
_megdnn_tensor_in dy, _megdnn_tensor_in dhy,
_megdnn_tensor_in flatten_weights, _megdnn_tensor_in reserve_space,
_megdnn_tensor_out dx, _megdnn_tensor_out dhx, _megdnn_tensor_out dw,
_megdnn_workspace workspace);
virtual size_t get_workspace_in_bytes(
const TensorLayout& x, const TensorLayout& y, const TensorLayout& hx,
const TensorLayout& dy, const TensorLayout& dhy,
const TensorLayout& flatten_weights, const TensorLayout& reserve_space,
const TensorLayout& dx, const TensorLayout& dhx, const TensorLayout& dw);
};
} // namespace cuda
} // namespace megdnn
/**
* \file dnn/src/cuda/rnn/utils.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#include "src/cuda/rnn/utils.h"
#include "src/cuda/utils.h"
#include <cudnn.h>
namespace megdnn {
namespace cuda {
namespace rnn {
/*RNNForwardDescHolder::RNNForwardDescHolder(Handle* handle, size_t seq_len, size_t
batch_size, size_t hidden_size, size_t input_size, size_t proj_size, size_t num_layers,
bool bidirectional, bool bias, DType dtype, cudnnRNNMode_t _mode, cudnnForwardMode_t
_fwdMode) : mode(_mode), fwdMode(_fwdMode)
{
size_t D = bidirectional ? 2 : 1;
// TODO: set dropout to 0 in inference mode
dropout_desc.set_no_dropout(handle);
// seq len is unified (not packed)
// cuda_check(cudaMalloc((void**)&devSeqLengths, sizeof(int32_t) * batch_size));
devSeqLengths = (int32_t*)malloc(sizeof(int32_t) * batch_size);
for (size_t i = 0; i < batch_size; ++i) devSeqLengths[i] = seq_len;
// proj size should be smaller than hidden size according to cudnn api
// otherwise it is disabled
proj_size = (proj_size > hidden_size || proj_size == 0) ? hidden_size :
proj_size; rnn_desc.set( input_size, hidden_size, proj_size, num_layers, bidirectional,
bias, dtype, mode, dropout_desc, handle
);
x_desc.set(batch_size, input_size, seq_len, devSeqLengths, dtype);
y_desc.set(batch_size, D * proj_size, seq_len,
devSeqLengths, dtype);
h_desc.set_nd(TensorLayout(TensorShape{D * num_layers, batch_size, proj_size},
dtype));
cudnn_check(cudnnGetRNNWeightSpaceSize(cudnn_handle(handle), rnn_desc.desc,
&weight_size));
cudnn_check(cudnnGetRNNTempSpaceSizes(
cudnn_handle(handle), rnn_desc.desc, fwdMode, x_desc.desc,
&workspace_size, &reserveSpace_size
));
}
RNNForwardDescHolder::~RNNForwardDescHolder() {
// cuda_check(cudaFree(devSeqLengths));
free(devSeqLengths);
}*/
RNNForwardDescHolder_v6::RNNForwardDescHolder_v6(
Handle* handle, size_t seq_len, size_t batch_size, size_t hidden_size,
size_t input_size, size_t proj_size, size_t num_layers, bool bidirectional,
bool bias, DType dtype, cudnnRNNMode_t _mode)
: mode(_mode), seq_len(seq_len) {
size_t D = bidirectional ? 2 : 1;
// TODO: set dropout to 0 in inference mode
dropout_desc.set_no_dropout(handle);
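    // cuDNN only supports proj_size < hidden_size; a value of 0 (or anything
    // oversized) is treated as "projection disabled" by clamping to
    // hidden_size below.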
proj_size = (proj_size > hidden_size || proj_size == 0) ? hidden_size : proj_size;
rnn_desc.set(
input_size, hidden_size, proj_size, num_layers, bidirectional, bias, dtype,
mode, dropout_desc, handle);
x_descs.resize(seq_len);
y_descs.resize(seq_len);
for (size_t i = 0; i < seq_len; ++i) {
x_descs[i].set_nd(TensorLayout(TensorShape{batch_size, input_size}, dtype), 3);
y_descs[i].set_nd(
TensorLayout(TensorShape{batch_size, D * hidden_size}, dtype), 3);
}
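    // The v6 API takes one 3-D tensor descriptor per time step (set up in the
    // loop above); all sequences share the same length here (no packed
    // variable-length batches).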
#define SET_H(_var) \
_var.set_nd(TensorLayout( \
TensorShape{D * num_layers, batch_size, hidden_size}, dtype));
SET_H(hx_desc)
SET_H(cx_desc)
SET_H(hy_desc)
SET_H(cy_desc)
#undef SET_H
std::vector<cudnnTensorDescriptor_t> x_desc_arr = get_descs(x_descs);
cudnn_check(cudnnGetRNNWorkspaceSize(
cudnn_handle(handle), rnn_desc.desc, seq_len, x_desc_arr.data(),
&workspace_size));
cudnn_check(cudnnGetRNNTrainingReserveSize(
cudnn_handle(handle), rnn_desc.desc, seq_len, x_desc_arr.data(),
&reserveSpace_size));
}
RNNForwardDescHolder_v6 get_RNNDescHolder_v6(
Handle* handle, megdnn::RNNForward::Param& _param, const TensorLayout& input) {
size_t seq_len = input.shape[0];
size_t batch_size = input.shape[1];
size_t input_size = input.shape[2];
cudnnRNNMode_t mode;
using NonlineMode = param::RNN::NonlineMode;
switch (_param.nonlineMode) {
case NonlineMode::RELU:
mode = CUDNN_RNN_RELU;
break;
case NonlineMode::TANH:
mode = CUDNN_RNN_TANH;
break;
}
RNNForwardDescHolder_v6 desc_holder(
handle, seq_len, batch_size, _param.hidden_size, input_size,
_param.proj_size, _param.num_layers, _param.bidirectional, _param.bias,
input.dtype, mode);
return desc_holder;
}
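// Unwraps the TensorDesc objects into the raw cudnnTensorDescriptor_t array
// form expected by the v6 cudnnRNN* entry points.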
std::vector<cudnnTensorDescriptor_t> get_descs(const std::vector<TensorDesc>& descs) {
std::vector<cudnnTensorDescriptor_t> r;
r.reserve(descs.size());
for (auto& desc : descs) {
r.emplace_back(desc.desc);
}
return r;
}
} // namespace rnn
} // namespace cuda
} // namespace megdnn
\ No newline at end of file
/**
* \file dnn/src/cuda/rnn/utils.h
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#pragma once
#include "src/cuda/cudnn_wrapper.h"
namespace megdnn {
namespace cuda {
namespace rnn {
// cuDNN v8 descriptor holder; not used for now
/*struct RNNForwardDescHolder {
int32_t* devSeqLengths;
cudnnRNNMode_t mode;
cudnnForwardMode_t fwdMode;
RNNDesc rnn_desc;
DropoutDesc dropout_desc;
RNNDataDesc x_desc, y_desc;
TensorDesc h_desc;
size_t weight_size, workspace_size, reserveSpace_size;
RNNForwardDescHolder(Handle* handle, size_t seq_len, size_t batch_size, size_t
hidden_size, size_t input_size, size_t proj_size, size_t num_layers, bool bidirectional,
bool bias, DType dtype,
cudnnRNNMode_t _mode, cudnnForwardMode_t _fwdMode); ~RNNForwardDescHolder();
};*/
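// Descriptor bundle for the cuDNN-v6 RNN API: one x/y descriptor per time
// step plus hx/cx/hy/cy state descriptors; workspace and reserve-space sizes
// are queried once at construction.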
struct RNNForwardDescHolder_v6 {
cudnnRNNMode_t mode;
RNNDesc rnn_desc;
int seq_len;
DropoutDesc dropout_desc;
std::vector<TensorDesc> x_descs, y_descs;
TensorDesc hx_desc, cx_desc, hy_desc, cy_desc;
size_t workspace_size, reserveSpace_size;
RNNForwardDescHolder_v6(
Handle* handle, size_t seq_len, size_t batch_size, size_t hidden_size,
size_t input_size, size_t proj_size, size_t num_layers, bool bidirectional,
bool bias, DType dtype, cudnnRNNMode_t _mode);
};
RNNForwardDescHolder_v6 get_RNNDescHolder_v6(
Handle* handle, megdnn::RNNForward::Param& _param, const TensorLayout& input);
std::vector<cudnnTensorDescriptor_t> get_descs(const std::vector<TensorDesc>& descs);
} // namespace rnn
} // namespace cuda
} // namespace megdnn
\ No newline at end of file
/**
* \file dnn/src/cuda/rnn_cell/opr_impl.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#include "src/cuda/rnn_cell/opr_impl.h"
#include "src/common/rnn_cell.h"
namespace megdnn {
namespace cuda {
size_t RNNCellImpl::get_workspace_in_bytes(
const TensorLayout& input, const TensorLayout& weight_ih,
const TensorLayout& bias_ih, const TensorLayout& hx,
const TensorLayout& weight_hh, const TensorLayout& bias_hh,
const TensorLayout& dst) {
return megdnn::rnn_cell::get_workspace_in_bytes(
input, weight_ih, bias_ih, hx, weight_hh, bias_hh, dst, handle());
}
void RNNCellImpl::exec(
_megdnn_tensor_in input, _megdnn_tensor_in weight_ih, _megdnn_tensor_in bias_ih,
_megdnn_tensor_in hx, _megdnn_tensor_in weight_hh, _megdnn_tensor_in bias_hh,
_megdnn_tensor_out dst, _megdnn_workspace workspace) {
megdnn::rnn_cell::exec(
input, weight_ih, bias_ih, hx, weight_hh, bias_hh, dst, workspace,
param().nonlineMode, handle());
}
} // namespace cuda
} // namespace megdnn
\ No newline at end of file
/**
* \file dnn/src/cuda/rnn_cell/opr_impl.h
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#pragma once
#include "megdnn/oprs.h"
namespace megdnn {
namespace cuda {
class RNNCellImpl : public RNNCell {
public:
using RNNCell::RNNCell;
void exec(
_megdnn_tensor_in input, _megdnn_tensor_in weight_ih,
_megdnn_tensor_in bias_ih, _megdnn_tensor_in hx,
_megdnn_tensor_in weight_hh, _megdnn_tensor_in bias_hh,
_megdnn_tensor_out dst, _megdnn_workspace workspace) override;
size_t get_workspace_in_bytes(
const TensorLayout& input, const TensorLayout& weight_ih,
const TensorLayout& bias_ih, const TensorLayout& hx,
const TensorLayout& weight_hh, const TensorLayout& bias_hh,
const TensorLayout& dst) override;
/*
private:
void add_bias(_megdnn_tensor_in A,
_megdnn_tensor_in B,
_megdnn_tensor_in bias,
_megdnn_tensor_out C);
*/
};
} // namespace cuda
} // namespace megdnn
\ No newline at end of file
......@@ -9,7 +9,6 @@
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#include "src/naive/handle.h"
#include "src/common/handle_impl.h"
......@@ -140,4 +139,5 @@ MEGDNN_FOREACH_OPR_CLASS(MEGDNN_SPECIALIZE_CREATE_OPERATOR)
} // namespace naive
} // namespace megdnn
// vim: syntax=cpp.doxygen
// vim: syntax=cpp.doxygen
\ No newline at end of file
......@@ -12,6 +12,9 @@
#include "src/naive/rnn/funcs.h"
#include "src/naive/rnn/rnn.h"
#include "midout.h"
MIDOUT_DECL(megdnn_naive_lstm_fwd)
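// midout instrumentation: tagging the kernel body allows operators that a
// particular build never calls to be stripped at link time.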
namespace megdnn {
namespace naive {
using rnn::LSTMCellWeightWrapper;
......@@ -21,29 +24,32 @@ void LSTMImpl::exec(
_megdnn_tensor_in flatten_weights, _megdnn_tensor_out output,
_megdnn_tensor_out hy, _megdnn_tensor_out cy, _megdnn_tensor_out reserve_space,
_megdnn_workspace workspace) {
auto _param = param();
size_t D = _param.bidirectional ? 2 : 1;
size_t num_layers = _param.num_layers;
size_t input_size = input.layout.shape[2];
std::vector<LSTMCellWeightWrapper> cells;
size_t used_workspace_size = rnn::get_cells<LSTMCellWeightWrapper>(
D, num_layers, input_size, _param.hidden_size, _param.bias, cells,
flatten_weights, workspace);
MIDOUT_BEGIN(megdnn_naive_lstm_fwd) {
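        // Slice the flat weight buffer into per-(layer, direction) cell
        // weights, then run the shared RNN time loop with {h, c} as the
        // recurrent state pair.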
auto _param = param();
size_t D = _param.bidirectional ? 2 : 1;
size_t num_layers = _param.num_layers;
size_t input_size = input.layout.shape[2];
std::vector<LSTMCellWeightWrapper> cells;
size_t used_workspace_size = rnn::get_cells<LSTMCellWeightWrapper>(
D, num_layers, input_size, _param.hidden_size, _param.bias, cells,
flatten_weights, workspace);
Workspace new_workspace(
workspace.raw_ptr + used_workspace_size,
workspace.size - used_workspace_size);
TensorNDArray states = {hx, cx}, states_new = {hy, cy};
rnn::exec_internal<LSTMCellWeightWrapper, LSTMCellForward>(
cells, input, states, states_new, output, reserve_space, num_layers, D,
this->handle(), new_workspace);
Workspace new_workspace(
workspace.raw_ptr + used_workspace_size,
workspace.size - used_workspace_size);
TensorNDArray states = {hx, cx}, states_new = {hy, cy};
rnn::exec_internal<LSTMCellWeightWrapper, LSTMCellForward>(
cells, input, states, states_new, output, reserve_space, num_layers, D,
param::RNNCell::NonlineMode::IDENTITY, this->handle(), new_workspace);
}
MIDOUT_END();
}
size_t LSTMImpl::get_workspace_in_bytes(
const TensorLayout& input, const TensorLayout& hx, const TensorLayout& cx,
const TensorLayout& flatten_weights, const TensorLayout& output,
const TensorLayout& hy, const TensorLayout& cy,
const TensorLayout& reserve_space) {
const TensorLayout& input, const TensorLayout& /*hx*/,
const TensorLayout& /*cx*/, const TensorLayout& flatten_weights,
const TensorLayout& output, const TensorLayout& /*hy*/,
const TensorLayout& /*cy*/, const TensorLayout& /*reserve_space*/) {
size_t workspace_size = rnn::get_workspace_in_bytes<LSTMCellForward>(
input, flatten_weights, param().hidden_size, param().bidirectional ? 2 : 1,
this->handle());
......@@ -77,6 +83,7 @@ void LSTMBackwardImpl::exec(
size_t num_layers = param().num_layers;
size_t D = param().bidirectional ? 2 : 1;
size_t input_size = x.layout.shape[2];
size_t batch_size = x.layout.shape[1];
size_t hidden_size = param().hidden_size;
size_t used_workspace_size = 0;
......@@ -90,10 +97,27 @@ void LSTMBackwardImpl::exec(
Workspace new_workspace = Workspace(
workspace.raw_ptr + used_workspace_size,
workspace.size - used_workspace_size);
TensorNDArray states = {hx, cx};
std::vector<TensorNDArray> hx_param;
TensorLayout unfold_hx_layout{
TensorShape{batch_size, hidden_size}, hx.layout.dtype};
for (size_t layer = 0; layer < num_layers; ++layer) {
for (size_t d = 0; d < D; ++d) {
TensorNDArray unfold_hx;
size_t idx = layer * D + d;
size_t states_offset = idx * unfold_hx_layout.span().dist_byte();
for (size_t i = 0; i < states.size(); ++i) {
unfold_hx.push_back(TensorND{
static_cast<uint8_t*>(states[i].raw_ptr()) + states_offset,
unfold_hx_layout});
}
hx_param.push_back(unfold_hx);
}
}
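    // hx/cx are laid out as (num_layers * D, batch, hidden); the loop above
    // unfolds them into one (batch, hidden) view per cell so each cell's
    // backward pass sees its own initial states.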
used_workspace_size += rnn::get_inputs_for_exec<LSTMCellWeightWrapper>(
x, y, reserve_space, num_layers, D, hidden_size, cells, layer_inputs,
layer_outputs, cell_seq_states, param::RNNCell::NonlineMode::IDENTITY,
new_workspace);
x, y, hx_param, reserve_space, num_layers, D, hidden_size, cells,
layer_inputs, layer_outputs, cell_seq_states,
param::RNNCell::NonlineMode::IDENTITY, new_workspace);
// dhy arr, dhx arr
TensorNDArray dhy_arr = {dhy, dcy}, dhx_arr = {dhx, dcx};
......@@ -110,11 +134,12 @@ void LSTMBackwardImpl::exec(
}
size_t LSTMBackwardImpl::get_workspace_in_bytes(
const TensorLayout& x, const TensorLayout& y, const TensorLayout& hx,
const TensorLayout& cx, const TensorLayout& dy, const TensorLayout& dhy,
const TensorLayout& dcy, const TensorLayout& flatten_weights,
const TensorLayout& reserve_space, const TensorLayout& dx,
const TensorLayout& dhx, const TensorLayout& dcx, const TensorLayout& dw) {
const TensorLayout& x, const TensorLayout& y, const TensorLayout& /*hx*/,
const TensorLayout& /*cx*/, const TensorLayout& /*dy*/,
const TensorLayout& /*dhy*/, const TensorLayout& /*dcy*/,
const TensorLayout& flatten_weights, const TensorLayout& /*reserve_space*/,
const TensorLayout& /*dx*/, const TensorLayout& /*dhx*/,
const TensorLayout& /*dcx*/, const TensorLayout& /*dw*/) {
size_t D = param().bidirectional ? 2 : 1;
size_t num_layers = param().num_layers;
size_t hidden_size = param().hidden_size;
......@@ -142,5 +167,6 @@ size_t LSTMBackwardImpl::get_workspace_in_bytes(
return workspace_size;
}
} // namespace naive
} // namespace megdnn
} // namespace megdnn
\ No newline at end of file
// vim: syntax=cpp.doxygen
\ No newline at end of file
......@@ -22,14 +22,16 @@ public:
_megdnn_tensor_in input, _megdnn_tensor_in hx, _megdnn_tensor_in cx,
_megdnn_tensor_in flatten_weights, _megdnn_tensor_out output,
_megdnn_tensor_out hy, _megdnn_tensor_out cy,
_megdnn_tensor_out reserve_space, _megdnn_workspace workspace);
_megdnn_tensor_out reserve_space, _megdnn_workspace workspace) override;
size_t get_workspace_in_bytes(
const TensorLayout& input, const TensorLayout& hx, const TensorLayout& cx,
const TensorLayout& flatten_weights, const TensorLayout& output,
const TensorLayout& hy, const TensorLayout& cy,
const TensorLayout& reserve_space);
size_t get_reserve_size_in_bytes(const TensorLayout& input);
const TensorLayout& reserve_space) override;
size_t get_reserve_size_in_bytes(const TensorLayout& input) override;
bool is_thread_safe() const override { return true; }
};
class LSTMBackwardImpl : public LSTMBackward {
......@@ -42,14 +44,17 @@ public:
_megdnn_tensor_in dcy, _megdnn_tensor_in flatten_weights,
_megdnn_tensor_in reserve_space, _megdnn_tensor_out dx,
_megdnn_tensor_out dhx, _megdnn_tensor_out dcx, _megdnn_tensor_out dw,
_megdnn_workspace workspace);
_megdnn_workspace workspace) override;
bool is_thread_safe() const override { return true; }
virtual size_t get_workspace_in_bytes(
const TensorLayout& x, const TensorLayout& y, const TensorLayout& hx,
const TensorLayout& cx, const TensorLayout& dy, const TensorLayout& dhy,
const TensorLayout& dcy, const TensorLayout& flatten_weights,
const TensorLayout& reserve_space, const TensorLayout& dx,
const TensorLayout& dhx, const TensorLayout& dcx, const TensorLayout& dw);
const TensorLayout& dhx, const TensorLayout& dcx,
const TensorLayout& dw) override;
};
} // namespace naive
......
......@@ -19,7 +19,8 @@ void cell_opr_exec<LSTMCellForward>(
_megdnn_tensor_in input, _megdnn_tensor_in weight_ih,
_megdnn_tensor_in weight_hh, _megdnn_tensor_in bias_ih,
_megdnn_tensor_in bias_hh, const TensorNDArray& states,
TensorNDArray& states_new, _megdnn_workspace workspace, Handle* handle) {
TensorNDArray& states_new, _megdnn_workspace workspace,
param::RNNCell::NonlineMode /*nonline_mode*/, Handle* handle) {
auto opr = handle->create_operator<LSTMCellForward>();
TensorLayout gates, h_new, c_new;
opr->deduce_layout(
......
......@@ -11,6 +11,9 @@
#include "src/naive/lstm_cell/opr_impl.h"
#include "src/common/lstm_cell.h"
#include "midout.h"
MIDOUT_DECL(megdnn_naive_lstmcell_fwd)
namespace megdnn {
namespace naive {
size_t LSTMCellImpl::get_workspace_in_bytes(
......@@ -29,9 +32,12 @@ void LSTMCellImpl::exec(
_megdnn_tensor_in hx, _megdnn_tensor_in weight_hh, _megdnn_tensor_in bias_hh,
_megdnn_tensor_in cx, _megdnn_tensor_out h_new, _megdnn_tensor_out c_new,
_megdnn_tensor_out gates, _megdnn_workspace workspace) {
megdnn::lstm_cell::exec(
input, weight_ih, bias_ih, hx, weight_hh, bias_hh, cx, h_new, c_new, gates,
workspace, handle());
MIDOUT_BEGIN(megdnn_naive_lstmcell_fwd) {
megdnn::lstm_cell::exec(
input, weight_ih, bias_ih, hx, weight_hh, bias_hh, cx, h_new, c_new,
gates, workspace, handle());
}
MIDOUT_END();
}
} // namespace naive
......
......@@ -30,6 +30,8 @@ public:
const TensorLayout& weight_hh, const TensorLayout& bias_hh,
const TensorLayout& cx, const TensorLayout& h_new,
const TensorLayout& c_new, const TensorLayout& gates) override;
bool is_thread_safe() const override { return true; }
};
} // namespace naive
......
This diff has been collapsed.
This diff has been collapsed.
......@@ -22,6 +22,9 @@
#include <cstring>
#include "midout.h"
MIDOUT_DECL(megdnn_naive_rnn_fwd)
namespace megdnn {
namespace naive {
......@@ -32,33 +35,40 @@ void RNNImpl::exec(
_megdnn_tensor_in flatten_weights, _megdnn_tensor_out output,
_megdnn_tensor_out hy, _megdnn_tensor_out reserve_space,
_megdnn_workspace workspace) {
auto _param = param();
size_t D = _param.bidirectional ? 2 : 1;
size_t num_layers = _param.num_layers;
size_t input_size = input.layout.shape[2];
std::vector<RNNCellWeightWrapper> cells;
size_t used_workspace_size = rnn::get_cells<RNNCellWeightWrapper>(
D, num_layers, input_size, _param.hidden_size, _param.bias, cells,
flatten_weights, workspace);
Workspace new_workspace(
workspace.raw_ptr + used_workspace_size,
workspace.size - used_workspace_size);
TensorNDArray states, states_new;
states.push_back(hx);
states_new.push_back(hy);
rnn::exec_internal<RNNCellWeightWrapper, RNNCellForward>(
cells, input, states, states_new, output, reserve_space, num_layers, D,
this->handle(), new_workspace);
MIDOUT_BEGIN(megdnn_naive_rnn_fwd) {
auto _param = param();
size_t D = _param.bidirectional ? 2 : 1;
size_t num_layers = _param.num_layers;
size_t input_size = input.layout.shape[2];
std::vector<RNNCellWeightWrapper> cells;
size_t used_workspace_size = rnn::get_cells<RNNCellWeightWrapper>(
D, num_layers, input_size, _param.hidden_size, _param.bias, cells,
flatten_weights, workspace);
Workspace new_workspace(
workspace.raw_ptr + used_workspace_size,
workspace.size - used_workspace_size);
TensorNDArray states, states_new;
states.push_back(hx);
states_new.push_back(hy);
rnn::exec_internal<RNNCellWeightWrapper, RNNCellForward>(
cells, input, states, states_new, output, reserve_space, num_layers, D,
_param.nonlineMode, this->handle(), new_workspace);
}
MIDOUT_END();
}
size_t RNNImpl::get_workspace_in_bytes(
const TensorLayout& input, const TensorLayout& hx,
const TensorLayout& flatten_weights, const TensorLayout& output,
const TensorLayout& hy, const TensorLayout& reserve_space) {
const TensorLayout& /*hy*/, const TensorLayout& /*reserve_space*/) {
auto _param = param();
size_t D = _param.bidirectional ? 2 : 1;
size_t last_dim = std::max(input.shape[2], D * hx.shape[1]);
TensorLayout last_input = {{input.shape[0], input.shape[1], last_dim}, input.dtype};
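    // Layers above the first consume the previous layer's (possibly wider,
    // D * hidden) output as input, so the workspace is queried with the widest
    // per-step input layout as a conservative bound.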
size_t workspace_size = rnn::get_workspace_in_bytes<RNNCellForward>(
input, flatten_weights, param().hidden_size, param().bidirectional ? 2 : 1,
this->handle());
last_input, flatten_weights, param().hidden_size,
param().bidirectional ? 2 : 1, this->handle());
if (!param().bias) { // use fake bias (all 0)
TensorLayout bias_layout = {{param().hidden_size}, flatten_weights.dtype};
workspace_size += bias_layout.span().dist_byte();
......@@ -82,50 +92,23 @@ void RNNBackwardImpl::exec(
_megdnn_tensor_in reserve_space, _megdnn_tensor_out dx, _megdnn_tensor_out dhx,
_megdnn_tensor_out dw, _megdnn_workspace workspace) {
TensorNDArray layer_inputs;
// layer_inputs.push_back(x);
TensorNDArray layer_outputs;
std::vector<std::vector<TensorNDArray>> cell_seq_states;
size_t num_layers = param().num_layers;
size_t D = param().bidirectional ? 2 : 1;
// size_t seq_len = x.layout.shape[0];
// size_t batch_size = x.layout.shape[1];
size_t input_size = x.layout.shape[2];
size_t batch_size = x.layout.shape[1];
size_t hidden_size = param().hidden_size;
size_t used_workspace_size = 0;
// get cells
std::vector<RNNCellWeightWrapper> cells;
// workspace_ptr = static_cast<uint8_t*>(workspace_ptr) +
used_workspace_size += rnn::get_cells(
D, num_layers, input_size, hidden_size, param().bias, cells,
flatten_weights, workspace);
// extract intermediate states from the reserve space
/*for (int layer = 0; layer < num_layers; ++layer) {
TensorND layer_output{workspace_ptr, y.layout};
workspace_ptr = static_cast<uint8_t*>(workspace_ptr) +
layer_output.layout.span().dist_byte(); for (int d = 0; d < D; ++d) {
cell_seq_states.push_back(std::vector<TensorNDArray>());
// reverse direction is stored with reversed order of sequence order
for (int i = 0; i < seq_len; ++i) {
size_t step = i;
if (d == 1) step = seq_len - i - 1;
size_t offset = ((layer * D + d) * seq_len + step) *
cell_output_layout.span().dist_byte(); TensorND
hy{static_cast<uint8_t*>(reserve_space.raw_ptr) + offset, cell_output_layout};
// states
cell_seq_states[cell_seq_states.size() - 1].push_back({hy});
// output
offset = i * D * cell_output_layout.span().dist_byte();
memcpy(static_cast<uint8_t*>(layer_output.raw_ptr) + offset,
hy.raw_ptr, hy.layout.span().dist_byte());
}
}
cell_seq_outputs.push_back(layer_output);
if (layer != num_layers - 1) layer_inputs.push_back(layer_output);
}*/
// nonlinear mode
param::RNNCell::NonlineMode nonlineMode;
param::RNNCell::NonlineMode nonlineMode = param::RNNCell::NonlineMode::TANH;
using ModeRNN = param::RNN::NonlineMode;
using ModeRNNCell = param::RNNCell::NonlineMode;
switch (param().nonlineMode) {
......@@ -135,22 +118,34 @@ void RNNBackwardImpl::exec(
case ModeRNN::TANH:
nonlineMode = ModeRNNCell::TANH;
break;
case ModeRNN::IDENTITY:
break;
}
// get formatted inputs
Workspace new_workspace = Workspace(
workspace.raw_ptr + used_workspace_size,
workspace.size - used_workspace_size);
TensorLayout unfold_hx_layout{
TensorShape{batch_size, hidden_size}, hx.layout.dtype};
std::vector<TensorNDArray> hx_param;
for (size_t layer = 0; layer < num_layers; ++layer) {
for (size_t d = 0; d < D; ++d) {
TensorNDArray unfold_hx;
size_t idx = layer * D + d;
size_t states_offset = idx * unfold_hx_layout.span().dist_byte();
unfold_hx.push_back(TensorND{
static_cast<uint8_t*>(hx.raw_ptr()) + states_offset,
unfold_hx_layout});
hx_param.push_back(unfold_hx);
}
}
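    // Same unfolding as in the LSTM backward pass: one (batch, hidden) view of
    // hx per (layer, direction) cell.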
used_workspace_size += rnn::get_inputs_for_exec<RNNCellWeightWrapper>(
x, y, reserve_space, num_layers, D, hidden_size, cells, layer_inputs,
layer_outputs, cell_seq_states, nonlineMode, new_workspace);
x, y, hx_param, reserve_space, num_layers, D, hidden_size, cells,
layer_inputs, layer_outputs, cell_seq_states, nonlineMode, new_workspace);
// dhy arr, dhx arr
TensorNDArray dhy_arr = {dhy}, dhx_arr = {dhx};
// exec
/*size_t used_workspace_size = static_cast<uint8_t*>(workspace_ptr) -
static_cast<uint8_t*>((void*)workspace.raw_ptr);*/
new_workspace = Workspace(
workspace.raw_ptr + used_workspace_size,
workspace.size - used_workspace_size);
......@@ -161,10 +156,11 @@ void RNNBackwardImpl::exec(
}
size_t RNNBackwardImpl::get_workspace_in_bytes(
const TensorLayout& x, const TensorLayout& y, const TensorLayout& hx,
const TensorLayout& dy, const TensorLayout& dhy,
const TensorLayout& flatten_weights, const TensorLayout& reserve_space,
const TensorLayout& dx, const TensorLayout& dhx, const TensorLayout& dw) {
const TensorLayout& x, const TensorLayout& y, const TensorLayout& /*hx*/,
const TensorLayout& /*dy*/, const TensorLayout& /*dhy*/,
const TensorLayout& flatten_weights, const TensorLayout& /*reserve_space*/,
const TensorLayout& /*dx*/, const TensorLayout& /*dhx*/,
const TensorLayout& /*dw*/) {
size_t D = param().bidirectional ? 2 : 1;
size_t num_layers = param().num_layers;
size_t hidden_size = param().hidden_size;
......
......@@ -22,13 +22,14 @@ public:
_megdnn_tensor_in input, _megdnn_tensor_in hx,
_megdnn_tensor_in flatten_weights, _megdnn_tensor_out output,
_megdnn_tensor_out hy, _megdnn_tensor_out reserve_space,
_megdnn_workspace workspace);
_megdnn_workspace workspace) override;
size_t get_workspace_in_bytes(
const TensorLayout& input, const TensorLayout& hx,
const TensorLayout& flatten_weights, const TensorLayout& output,
const TensorLayout& hy, const TensorLayout& reserve_space);
size_t get_reserve_size_in_bytes(const TensorLayout& input);
const TensorLayout& hy, const TensorLayout& reserve_space) override;
size_t get_reserve_size_in_bytes(const TensorLayout& input) override;
bool is_thread_safe() const override { return true; }
};
class RNNBackwardImpl : public RNNBackward {
......@@ -40,13 +41,15 @@ public:
_megdnn_tensor_in dy, _megdnn_tensor_in dhy,
_megdnn_tensor_in flatten_weights, _megdnn_tensor_in reserve_space,
_megdnn_tensor_out dx, _megdnn_tensor_out dhx, _megdnn_tensor_out dw,
_megdnn_workspace workspace);
_megdnn_workspace workspace) override;
virtual size_t get_workspace_in_bytes(
const TensorLayout& x, const TensorLayout& y, const TensorLayout& hx,
const TensorLayout& dy, const TensorLayout& dhy,
const TensorLayout& flatten_weights, const TensorLayout& reserve_space,
const TensorLayout& dx, const TensorLayout& dhx, const TensorLayout& dw);
const TensorLayout& dx, const TensorLayout& dhx,
const TensorLayout& dw) override;
bool is_thread_safe() const override { return true; }
};
} // namespace naive
......
......@@ -218,7 +218,11 @@ void LSTMCellWeightWrapper::backward(
x, weight_ih, bias_ih, states[0], weight_hh, bias_hh, states[1], dstates[0],
dstates[1], gates_tensor,
new_workspace); // no information left in the workspace
    // gate order in memory: i, f, o, g
    // BUG: gate_grad is laid out as i_g, f_g, o_g, g_g, but it should be
    // i_g, f_g, g_g, o_g
    // The returned gradient includes both the horizontal (time) and vertical
    // (layer) gradients; as written, both are read from douts[1], which makes
    // the variable naming confusing.
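    // For reference, with i/f/o = sigmoid(.) and g = tanh(.):
    //     c_new = f * cx + i * g,    h_new = o * tanh(c_new)
    // so o_grad = dh * tanh(c_new) * o * (1 - o), and the cell gradient picks
    // up dh * o * (1 - tanh(c_new)^2) in addition to dcy; this is what the
    // SIGMOID_GRAD / TANH_GRAD / ADD elemwise calls below compute.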
TensorLayout single_gate = {{gates.shape[0], gates.shape[1] / 4}, gates.dtype};
TensorND i, f, o, g, i_grad, f_grad, o_grad,
g_grad; // grad refers to the grad of gates before activation
......@@ -239,8 +243,8 @@ void LSTMCellWeightWrapper::backward(
g_grad = {
static_cast<uint8_t*>(o_grad.raw_ptr()) + single_gate.span().dist_byte(),
single_gate};
// activation
auto elem_opr = handle->create_operator<ElemwiseForward>();
elem_opr->param().mode = Elemwise::Mode::SIGMOID;
elem_opr->exec({i}, i);
elem_opr->exec({f}, f);
......@@ -254,8 +258,8 @@ void LSTMCellWeightWrapper::backward(
mul_opr->exec({douts[0], tanh_cy}, dstates[0]);
elem_opr->param().mode = Elemwise::Mode::SIGMOID_GRAD;
elem_opr->exec({o, dstates[0]}, o_grad); // grad of gate o
// use dstates[0] as tmp tensor to store dhy * o
mul_opr->exec({douts[0], o}, dstates[0]);
elem_opr->param().mode = Elemwise::Mode::TANH_GRAD;
elem_opr->exec({tanh_cy, dstates[0]}, dstates[1]); // grad of cy from hy
elem_opr->param().mode = Elemwise::Mode::ADD;
......
......@@ -38,6 +38,7 @@ public:
_megdnn_tensor_out dwi, _megdnn_tensor_out dwh, _megdnn_tensor_out dbias,
_megdnn_workspace workspace) const;
virtual size_t num_states() const;
virtual ~CellWeightsWrapperBase() {}
};
class RNNCellWeightWrapper : public CellWeightsWrapperBase {
......
......@@ -19,8 +19,10 @@ void cell_opr_exec<RNNCellForward>(
_megdnn_tensor_in input, _megdnn_tensor_in weight_ih,
_megdnn_tensor_in weight_hh, _megdnn_tensor_in bias_ih,
_megdnn_tensor_in bias_hh, const TensorNDArray& states,
TensorNDArray& states_new, _megdnn_workspace workspace, Handle* handle) {
TensorNDArray& states_new, _megdnn_workspace workspace,
param::RNNCell::NonlineMode nonline_mode, Handle* handle) {
auto opr = handle->create_operator<RNNCellForward>();
opr->param().nonlineMode = nonline_mode;
opr->exec(
input, weight_ih, bias_ih, states[0], weight_hh, bias_hh, states_new[0],
workspace);
......
......@@ -11,6 +11,9 @@
#include "src/naive/rnn_cell/opr_impl.h"
#include "src/common/rnn_cell.h"
#include "midout.h"
MIDOUT_DECL(megdnn_naive_rnncell_fwd)
namespace megdnn {
namespace naive {
size_t RNNCellImpl::get_workspace_in_bytes(
......@@ -26,9 +29,12 @@ void RNNCellImpl::exec(
_megdnn_tensor_in input, _megdnn_tensor_in weight_ih, _megdnn_tensor_in bias_ih,
_megdnn_tensor_in hx, _megdnn_tensor_in weight_hh, _megdnn_tensor_in bias_hh,
_megdnn_tensor_out dst, _megdnn_workspace workspace) {
megdnn::rnn_cell::exec(
input, weight_ih, bias_ih, hx, weight_hh, bias_hh, dst, workspace,
param().nonlineMode, handle());
MIDOUT_BEGIN(megdnn_naive_rnncell_fwd) {
megdnn::rnn_cell::exec(
input, weight_ih, bias_ih, hx, weight_hh, bias_hh, dst, workspace,
param().nonlineMode, handle());
}
MIDOUT_END();
}
} // namespace naive
} // namespace megdnn
\ No newline at end of file
......@@ -27,6 +27,7 @@ public:
const TensorLayout& bias_ih, const TensorLayout& hx,
const TensorLayout& weight_hh, const TensorLayout& bias_hh,
const TensorLayout& dst) override;
bool is_thread_safe() const override { return true; }
};
} // namespace naive
......
......@@ -77,17 +77,18 @@ struct DeduceLayoutProxy<Opr, 6, false> {
};
template <typename Opr>
struct DeduceLayoutProxy<Opr, 6, true> {
static void deduce_layout(Opr* opr, TensorLayoutArray& layouts) {
megdnn_assert(layouts.size() == 6);
opr->deduce_layout(
layouts[0], layouts[1], layouts[2], layouts[3], layouts[4], layouts[5]);
    }
};
template <typename Opr>
struct DeduceLayoutProxy<Opr, 7, false> {
    static void deduce_layout(Opr*, TensorLayoutArray&) {}
};
template <typename Opr>
struct DeduceLayoutProxy<Opr, 7, true> {
    static void deduce_layout(Opr* opr, TensorLayoutArray& layouts) {
        megdnn_assert(layouts.size() == 7);
        opr->deduce_layout(
                layouts[0], layouts[1], layouts[2], layouts[3], layouts[4], layouts[5],
                layouts[6]);
    }
};
template <typename Opr>
......@@ -109,6 +110,38 @@ struct DeduceLayoutProxy<Opr, 9, true> {
layouts[6], layouts[7], layouts[8]);
}
};
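// The arity-10 and arity-13 specializations below cover the new LSTM
// operators: LSTMCell forward takes 10 tensors, LSTMBackward takes 13.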
template <typename Opr>
struct DeduceLayoutProxy<Opr, 10, true> {
static void deduce_layout(Opr* opr, TensorLayoutArray& layouts) {
megdnn_assert(layouts.size() == 10);
opr->deduce_layout(
layouts[0], layouts[1], layouts[2], layouts[3], layouts[4], layouts[5],
layouts[6], layouts[7], layouts[8], layouts[9]);
}
};
template <typename Opr>
struct DeduceLayoutProxy<Opr, 10, false> {
static void deduce_layout(Opr*, TensorLayoutArray&) {}
};
template <typename Opr>
struct DeduceLayoutProxy<Opr, 13, true> {
static void deduce_layout(Opr* opr, TensorLayoutArray& layouts) {
megdnn_assert(layouts.size() == 13);
opr->deduce_layout(
layouts[0], layouts[1], layouts[2], layouts[3], layouts[4], layouts[5],
layouts[6], layouts[7], layouts[8], layouts[9], layouts[10],
layouts[11], layouts[12]);
}
};
template <typename Opr>
struct DeduceLayoutProxy<Opr, 13, false> {
static void deduce_layout(Opr*, TensorLayoutArray&) {}
};
} // namespace test
} // namespace megdnn
......
......@@ -22,6 +22,44 @@ namespace test {
template <typename Opr, size_t Arity, bool has_workspace>
struct ExecProxy;
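// The arity-13 and arity-10 proxies below mirror the DeduceLayoutProxy
// additions (LSTMBackward and LSTMCell respectively); each sizes the
// workspace via get_workspace_in_bytes before calling exec.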
template <typename Opr>
struct ExecProxy<Opr, 13, true> {
WorkspaceWrapper W;
void exec(Opr* opr, const TensorNDArray& tensors) {
if (!W.valid()) {
W = WorkspaceWrapper(opr->handle(), 0);
}
W.update(opr->get_workspace_in_bytes(
tensors[0].layout, tensors[1].layout, tensors[2].layout,
tensors[3].layout, tensors[4].layout, tensors[5].layout,
tensors[6].layout, tensors[7].layout, tensors[8].layout,
tensors[9].layout, tensors[10].layout, tensors[11].layout,
tensors[12].layout));
opr->exec(
tensors[0], tensors[1], tensors[2], tensors[3], tensors[4], tensors[5],
tensors[6], tensors[7], tensors[8], tensors[9], tensors[10],
tensors[11], tensors[12], W.workspace());
}
};
template <typename Opr>
struct ExecProxy<Opr, 10, true> {
WorkspaceWrapper W;
void exec(Opr* opr, const TensorNDArray& tensors) {
if (!W.valid()) {
W = WorkspaceWrapper(opr->handle(), 0);
}
W.update(opr->get_workspace_in_bytes(
tensors[0].layout, tensors[1].layout, tensors[2].layout,
tensors[3].layout, tensors[4].layout, tensors[5].layout,
tensors[6].layout, tensors[7].layout, tensors[8].layout,
tensors[9].layout));
opr->exec(
tensors[0], tensors[1], tensors[2], tensors[3], tensors[4], tensors[5],
tensors[6], tensors[7], tensors[8], tensors[9], W.workspace());
}
};
template <typename Opr>
struct ExecProxy<Opr, 9, true> {
WorkspaceWrapper W;
......
/**
* \file dnn/test/common/rnn.h
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#pragma once
#include <vector>
#include "megdnn/basic_types.h"
#include "megdnn/opr_param_defs.h"
namespace megdnn {
namespace test {
namespace rnn {
struct TestArg {
param::RNN param;
TensorShape input, hx, flatten_weights;
TestArg(param::RNN param, TensorShape input, TensorShape hx,
TensorShape flatten_weights)
: param(param), input(input), hx(hx), flatten_weights(flatten_weights) {}
};
inline std::vector<TestArg> get_args() {
std::vector<TestArg> args;
size_t batch_size = 2;
size_t input_size = 3;
size_t hidden_size = 2;
size_t seq_len = 2;
size_t gate_hidden_size = hidden_size;
param::RNN param;
param.num_layers = 1;
param.bidirectional = false;
param.bias = false;
param.hidden_size = hidden_size;
param.nonlineMode = param::RNN::NonlineMode::RELU;
args.emplace_back(
param, TensorShape{seq_len, batch_size, input_size},
TensorShape{batch_size, hidden_size},
TensorShape{gate_hidden_size, input_size + hidden_size});
return args;
}
} // namespace rnn
} // namespace test
} // namespace megdnn
\ No newline at end of file
/**
* \file dnn/test/naive/lstm.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
// #include "test/common/lstm.h"
#include "megdnn/dtype.h"
#include "megdnn/oprs.h"
#include "test/common/checker.h"
#include "test/naive/fixture.h"
namespace megdnn {
namespace test {
TEST_F(NAIVE, LSTM_FORWARD) {
Checker<LSTM> checker(handle(), true);
size_t batch_size = 2;
size_t input_size = 3;
size_t hidden_size = 2;
size_t seq_len = 2;
size_t gate_hidden_size = 4 * hidden_size;
LSTM::Param param;
param.num_layers = 1;
param.bidirectional = false;
param.bias = false;
param.hidden_size = hidden_size;
checker.set_param(param).exect(
Testcase{
TensorValue(
{seq_len, batch_size, input_size}, dtype::Float32(),
{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}), // input
TensorValue(
{batch_size, hidden_size}, dtype::Float32(),
{1, 2, 3, 4}), // hx
TensorValue(
{batch_size, hidden_size}, dtype::Float32(),
{2, 3, 4, 5}), // cx
TensorValue(
{gate_hidden_size, input_size + hidden_size},
dtype::Float32(),
{3, 6, 1, 3, 2, 7, 2, 1, 3, 2, 1, 1, 3, 6,
1, 3, 2, 7, 2, 1, 3, 2, 1, 1, 9, 3, 5, 1,
9, 3, 5, 1, 9, 3, 5, 1, 9, 3, 5, 1}), // flatten_weights
{},
{},
{},
{}},
Testcase{
{},
{},
{},
{},
TensorValue(
{seq_len, batch_size, hidden_size}, dtype::Float32(),
{0.9951, 0.9993, 0.9999, 1.0000, 0.9993, 0.9999, 1.0000,
1.0000}), // output
TensorValue(
{batch_size, hidden_size}, dtype::Float32(),
{0.9993, 0.9999, 1.0000, 1.0000}), // hy
TensorValue(
{batch_size, hidden_size}, dtype::Float32(),
{4.0000, 5.0000, 6.0000, 7.0000}), // cy
TensorValue(
{2, 2, 2, 2}, dtype::Float32(),
{0.995054, 0.999328, 0.99990, 0.999987, 3., 4., 5., 6.,
0.999329, 0.999328, 0.99990, 1., 4., 5., 6.,
7.}) // reserve space
});
param.bidirectional = true;
checker.set_param(param).exect(
Testcase{
TensorValue(
{seq_len, batch_size, input_size}, dtype::Float32(),
{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}), // input
TensorValue(
{2, batch_size, hidden_size}, dtype::Float32(),
{1, 2, 3, 4, 5, 6, 7, 8}), // hx
TensorValue(
{2, batch_size, hidden_size}, dtype::Float32(),
{2, 3, 4, 5, 6, 7, 8, 9}), // cx
TensorValue(
{gate_hidden_size, 2 * (input_size + hidden_size)},
dtype::Float32(),
{3, 6, 1, 3, 2, 7, 2, 1, 3, 2, 1, 1, 3, 6, 1, 3, 2,
7, 2, 1, 3, 2, 1, 1, 9, 3, 5, 1, 9, 3, 5, 1, 9, 3,
5, 1, 9, 3, 5, 1, 3, 6, 1, 3, 2, 7, 2, 1, 3, 2, 1,
1, 3, 6, 1, 3, 2, 7, 2, 1, 3, 2, 1, 1, 9, 3, 5, 1,
9, 3, 5, 1, 9, 3, 5, 1, 9, 3, 5, 1}), // flatten_weights
{},
{},
{},
{}},
Testcase{
{},
{},
{},
{},
TensorValue(
{seq_len, batch_size, 2 * hidden_size}, dtype::Float32(),
{0.9951, 0.9993, 1.0000, 1.0000, 0.9999, 1.0000, 1.0000,
1.0000, 0.9993, 0.9999, 1.0000, 1.0000, 1.0000, 1.0000,
1.0000, 1.0000}), // output
TensorValue(
{2, batch_size, hidden_size}, dtype::Float32(),
{0.9993, 0.9999, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
1.0000}), // hy
TensorValue(
{2, batch_size, hidden_size}, dtype::Float32(),
{4.0000, 5.0000, 6.0000, 7.0000, 8.0000, 9.0000, 10.0000,
11.0000}), // cy
TensorValue(
{4, 2, 2, 2}, dtype::Float32(),
{0.995054, 0.999328, 0.99990, 0.999987, 3., 4.,
5., 6., 0.999329, 0.999328, 0.99990, 1.,
4., 5., 6., 7., 1., 0.999328,
0.99990, 0.999987, 7., 8., 9., 10.,
0.999329, 0.999328, 0.99990, 1., 8., 9.,
10., 11.}) // reserve space
});
param.num_layers = 2;
checker.set_param(param).exect(
Testcase{
TensorValue(
{seq_len, batch_size, input_size}, dtype::Float32(),
{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}), // input
TensorValue(
{4, batch_size, hidden_size}, dtype::Float32(),
{1, 2, 3, 4, 5, 6, 7, 8, 1, 2, 3, 4, 5, 6, 7, 8}), // hx
TensorValue(
{4, batch_size, hidden_size}, dtype::Float32(),
{2, 3, 4, 5, 6, 7, 8, 9, 2, 3, 4, 5, 6, 7, 8, 9}), // cx
TensorValue(
{8, 22}, dtype::Float32(),
{
3, 6, 1, 3, 2, 7, 2, 1, 3, 2, 1, 1, 3, 6, 1, 3,
2, 7, 2, 1, 3, 2, 1, 1, 9, 3, 5, 1, 9, 3, 5, 1,
9, 3, 5, 1, 9, 3, 5, 1, 3, 6, 1, 3, 2, 7, 2, 1,
3, 2, 1, 1, 3, 6, 1, 3, 2, 7, 2, 1, 3, 2, 1, 1,
9, 3, 5, 1, 9, 3, 5, 1, 9, 3, 5, 1, 9, 3, 5, 1,
3, 6, 1, 3, 2, 7, 2, 1, 3, 2, 1, 1, 2, 7, 2, 1,
3, 6, 1, 3, 2, 7, 2, 1, 3, 2, 1, 1, 2, 7, 2, 1,
3, 6, 1, 3, 2, 7, 2, 1, 3, 2, 1, 1, 2, 7, 2, 1,
3, 6, 1, 3, 2, 7, 2, 1, 3, 2, 1, 1, 2, 7, 2, 1,
9, 3, 5, 1, 9, 3, 5, 1, 9, 3, 5, 1, 9, 3, 5, 1,
9, 3, 5, 1, 9, 3, 5, 1, 9, 3, 5, 1, 9, 3, 5, 1,
}), // flatten_weights
{},
{},
{},
{}},
Testcase{
{},
{},
{},
{},
TensorValue(
{seq_len, batch_size, 2 * hidden_size}, dtype::Float32(),
{0.9951, 0.9993, 1.0000, 1.0000, 0.9999, 1.0000, 1.0000,
1.0000, 0.9993, 0.9999, 1.0000, 1.0000, 1.0000, 1.0000,
1.0000, 1.0000}), // output
TensorValue(
{4, batch_size, hidden_size}, dtype::Float32(),
{0.9993, 0.9999, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
1.0000, 0.9993, 0.9999, 1.0000, 1.0000, 1.0000, 1.0000,
1.0000, 1.0000}), // hy
TensorValue(
{4, batch_size, hidden_size}, dtype::Float32(),
{4.0000, 5.0000, 6.0000, 7.0000, 8.0000, 9.0000, 10.0000,
11.0000, 4.0000, 5.0000, 6.0000, 7.0000, 8.0000, 9.0000,
10.0000, 11.0000}), // cy
TensorValue(
{8, 2, 2, 2}, dtype::Float32(),
{
0.995054, 0.999328, 0.99990, 0.999987, 3.,
4., 5., 6., 0.999329, 0.999328,
0.99990, 1., 4., 5., 6.,
7., 1., 0.999328, 0.99990, 0.999987,
7., 8., 9., 10., 0.999329,
0.999328, 0.99990, 1., 8., 9.,
10., 11., 0.995054, 0.999328, 0.99990,
0.999987, 3., 4., 5., 6.,
0.999329, 0.999328, 0.99990, 1., 4.,
5., 6., 7., 1., 0.999328,
0.99990, 0.999987, 7., 8., 9.,
10., 0.999329, 0.999328, 0.99990, 1.,
8., 9., 10., 11.,
}) // reserve space
});
}
} // namespace test
} // namespace megdnn
\ No newline at end of file
/**
* \file dnn/test/naive/lstmcell.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#include "megdnn/dtype.h"
#include "megdnn/oprs.h"
#include "test/common/checker.h"
#include "test/naive/fixture.h"
namespace megdnn {
namespace test {
TEST_F(NAIVE, LSTMCELL) {
Checker<LSTMCell> checker(handle(), true);
for (size_t batch : {1, 4})
for (size_t n : {3, 4, 5, 23, 100})
for (size_t out : {3, 6, 25, 100}) {
checker.exec(
{{batch, n},
{out * 4, n},
{1, out * 4},
{batch, out},
{out * 4, out},
{1, out * 4},
{batch, out},
{},
{},
{}});
}
size_t batch_size = 2;
size_t input_size = 3;
size_t hidden_size = 2;
checker.exect(
Testcase{
TensorValue(
{batch_size, input_size}, dtype::Float32(),
{1, 2, 3, 4, 5, 6}), // input
TensorValue(
{4 * hidden_size, input_size}, dtype::Float32(),
{
0.3535, 0.3535, 0.3535, 0.3535, 0.3535, 0.3535,
0.3535, 0.3535, 0.3535, 0.3535, 0.3535, 0.3535,
0.3535, 0.3535, 0.3535, 0.3535, 0.3535, 0.3535,
0.3535, 0.3535, 0.3535, 0.3535, 0.3535, 0.3535,
}), // weight_ih
TensorValue(
{4 * hidden_size}, dtype::Float32(),
{0, 0, 0, 0, 0, 0, 0, 0}), // bias_ih
TensorValue(
{batch_size, hidden_size}, dtype::Float32(),
{1, 2, 3, 4}), // hx
TensorValue(
{4 * hidden_size, hidden_size}, dtype::Float32(),
{0.3535, 0.3535, 0.3535, 0.3535, 0.3535, 0.3535, 0.3535,
0.3535, 0.3535, 0.3535, 0.3535, 0.3535, 0.3535, 0.3535,
0.3535, 0.3535}), // weight_hh
TensorValue(
{4 * hidden_size}, dtype::Float32(),
{0, 0, 0, 0, 0, 0, 0, 0}), // bias_hh
TensorValue(
{batch_size, hidden_size}, dtype::Float32(),
{2, 3, 4, 5}), // cx
{},
{},
{}},
Testcase{
{},
{},
{},
{},
{},
{},
{},
TensorValue(
{batch_size, hidden_size}, dtype::Float32(),
{0.9541, 0.9593, 0.9995, 0.9996}), // hy
TensorValue(
{batch_size, hidden_size}, dtype::Float32(),
{2.8771, 3.8373, 4.9979, 5.9975}), // cy
TensorValue(
{batch_size, 4 * hidden_size}, dtype::Float32(),
{3.18198, 3.18198, 7.7781, 7.7781, 3.18198, 3.18198,
7.77817, 7.77817, 3.18198, 3.18198, 7.77817, 7.77817,
3.18198, 3.18198, 7.77817, 7.77817}), // gates
});
batch_size = 2;
input_size = 2;
hidden_size = 1;
checker.exect(
Testcase{
TensorValue(
{batch_size, input_size}, dtype::Float32(),
{1, 2, 3, 4}), // input
TensorValue(
{4 * hidden_size, input_size}, dtype::Float32(),
{0.3535, 0.3535, 0.3535, 0.3535, 0.3535, 0.3535, 0.3535,
0.3535}), // weight_ih
TensorValue(
{4 * hidden_size}, dtype::Float32(),
{0.3535, 0.3535, 0.3535, 0.3535}), // bias_ih
TensorValue(
{batch_size, hidden_size}, dtype::Float32(), {1, 2}), // hx
TensorValue(
{4 * hidden_size, hidden_size}, dtype::Float32(),
{0.3535, 0.3535, 0.3535, 0.3535}), // weight_hh
TensorValue(
{4 * hidden_size}, dtype::Float32(),
{0.3535, 0.3535, 0.3535, 0.3535}), // bias_hh
TensorValue(
{batch_size, hidden_size}, dtype::Float32(), {4, 5}), // cx
{},
{},
{}},
Testcase{
{},
{},
{},
{},
{},
{},
{},
TensorValue(
{batch_size, hidden_size}, dtype::Float32(),
{0.8927, 0.9799}), // hy
TensorValue(
{batch_size, hidden_size}, dtype::Float32(),
{4.4393, 5.8788}), // cy
TensorValue(
{batch_size, 4 * hidden_size}, dtype::Float32(),
{2.1210, 3.8885, 2.1210, 3.8885, 2.1210, 3.8885, 2.1210,
3.8885}), // gates
});
}
} // namespace test
} // namespace megdnn
\ No newline at end of file
......@@ -8,7 +8,6 @@
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#include "test/common/rnn.h"
#include "megdnn/dtype.h"
#include "megdnn/oprs.h"
#include "test/common/checker.h"
......@@ -17,22 +16,7 @@
namespace megdnn {
namespace test {
/*TEST_F(NAIVE, RNN) {
std::vector<rnn::TestArg> args = rnn::get_args();
Checker<RNN> checker(handle());
for (auto&& arg : args) {
checker.set_param(arg.param)
.set_dtype(0, dtype::Float32())
.set_dtype(1, dtype::Float32())
.set_dtype(2, dtype::Float32())
.set_dtype(3, dtype::Float32())
.set_dtype(4, dtype::Float32())
.set_dtype(5, dtype::Float32())
.execs({arg.input, arg.hx, arg.flatten_weights, {}, {}, {}});
}
}*/
TEST_F(NAIVE, RNN_HAND_MADE) {
TEST_F(NAIVE, RNN_FORWARD) {
Checker<RNN> checker(handle(), false);
size_t batch_size = 2;
size_t input_size = 3;
......@@ -49,14 +33,17 @@ TEST_F(NAIVE, RNN_HAND_MADE) {
Testcase{
TensorValue(
{seq_len, batch_size, input_size}, dtype::Float32(),
{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}), // input
{-0.66536, 0.08049, 0.12008, 0.63423, 1.37801, 0.02591,
0.09153, 0.82866, -1.70429, -1.26624, -0.06421,
0.35816}), // input
TensorValue(
{batch_size, hidden_size}, dtype::Float32(),
{2, 1, 3, 5}), // hx
{-3.19544, -1.24232, 1.99512, -0.25692}), // hx
TensorValue(
{gate_hidden_size, input_size + hidden_size},
dtype::Float32(),
{3, 6, 1, 3, 2, 7, 9, 3, 5, 1}), // weights
{0.35355, 0.35355, 0.35355, 0.35355, 0.35355, 0.35355,
0.35355, 0.35355, 0.35355, 0.35355}), // flatten_weights
{},
{},
{}},
......@@ -66,13 +53,54 @@ TEST_F(NAIVE, RNN_HAND_MADE) {
{},
TensorValue(
{seq_len, batch_size, hidden_size}, dtype::Float32(),
{39, 39, 90, 84, 300, 216, 546, 366}), // output
{0.0, 0.0, 1.3351, 1.3351, 0.0, 0.0, 0.6003,
0.6003}), // output
TensorValue(
{batch_size, hidden_size}, dtype::Float32(),
{21, 11, 42, 20}), // hy
{0.0, 0.0, 0.6003, 0.6003}), // hy
TensorValue(
{1, 2, 2, 2}, dtype::Float32(),
{2, 1, 3, 5, 21, 11, 42, 20}) // reserve space
{0.0, 0.0, 1.33512, 1.33512, 0.0, 0.0, 0.60031,
0.60031}) // reserve space
});
param.num_layers = 2;
checker.set_param(param).exect(
Testcase{
TensorValue(
{seq_len, batch_size, input_size}, dtype::Float32(),
{-0.66536, 0.08049, 0.12008, 0.63423, 1.37801, 0.02591,
0.09153, 0.82866, -1.70429, -1.26624, -0.06421,
0.35816}), // input
TensorValue(
{2, batch_size, hidden_size}, dtype::Float32(),
{-3.19544, -1.24232, 1.99512, -0.25692, -3.19544, -1.24232,
1.99512, -0.25692}), // hx
TensorValue(
{2, 9}, dtype::Float32(),
{0.35355, 0.35355, 0.35355, 0.35355, 0.35355, 0.35355,
0.35355, 0.35355, 0.35355, 0.35355, 0.35355, 0.35355,
0.35355, 0.35355, 0.35355, 0.35355, 0.35355,
0.35355}), // weights
{},
{},
{}},
Testcase{
{},
{},
{},
TensorValue(
{seq_len, batch_size, hidden_size}, dtype::Float32(),
{0.0, 0.0, 1.5586, 1.5586, 0.0, 0.0, 1.5266,
1.5266}), // output
TensorValue(
{2, batch_size, hidden_size}, dtype::Float32(),
{0.0, 0.0, 0.6003, 0.6003, 0.0, 0.0, 1.5266,
1.5266}), // hy
TensorValue(
{2, 2, 2, 2}, dtype::Float32(),
{0.0, 0.0, 1.33512, 1.33512, 0.0, 0.0, 0.60031, 0.60031,
0.0, 0.0, 1.55861, 1.55861, 0.0, 0.0, 1.52658,
1.52658}) // reserve space
});
}
......
/**
* \file dnn/test/naive/rnncell.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#include "megdnn/dtype.h"
#include "megdnn/oprs.h"
#include "test/common/checker.h"
#include "test/naive/fixture.h"
namespace megdnn {
namespace test {
TEST_F(NAIVE, RNNCELL) {
Checker<RNNCell> checker(handle(), false);
for (size_t batch : {1, 4})
for (size_t inp : {3, 4, 5, 23, 100})
for (size_t hidden : {3, 6, 25, 100}) {
checker.exec(
{{batch, inp},
{hidden, inp},
{1, hidden},
{batch, hidden},
{hidden, hidden},
{1, hidden},
{}});
}
size_t batch_size = 2;
size_t input_size = 3;
size_t hidden_size = 2;
RNNCell::Param param;
param.nonlineMode = param::RNNCell::NonlineMode::TANH;
checker.set_param(param).exect(
Testcase{
TensorValue(
{batch_size, input_size}, dtype::Float32(),
{1, 2, 3, 4, 5, 6}), // input
TensorValue(
{hidden_size, input_size}, dtype::Float32(),
{0.3535, 0.3535, 0.3535, 0.3535, 0.3535,
0.3535}), // weight_ih
TensorValue({1, hidden_size}, dtype::Float32(), {0, 0}), // bias_ih
TensorValue(
{batch_size, hidden_size}, dtype::Float32(),
{1, 2, 3, 4}), // hx
TensorValue(
{hidden_size, hidden_size}, dtype::Float32(),
{0.3535, 0.3535, 0.3535, 0.3535}), // weight_hh
TensorValue({1, hidden_size}, dtype::Float32(), {0, 0}), // bias_hh
{}},
Testcase{
{},
{},
{},
{},
{},
{},
TensorValue(
{batch_size, hidden_size}, dtype::Float32(),
{0.9966, 0.9966, 1.0, 1.0}), // dst
});
batch_size = 2;
input_size = 2;
hidden_size = 1;
param.nonlineMode = param::RNNCell::NonlineMode::RELU;
checker.set_param(param).exect(
Testcase{
TensorValue(
{batch_size, input_size}, dtype::Float32(),
{1, 2, 3, 4}), // input
TensorValue(
{hidden_size, input_size}, dtype::Float32(),
{0.3535, 0.3535}), // weight_ih
TensorValue(
{1, hidden_size}, dtype::Float32(), {0.3535}), // bias_ih
TensorValue(
{batch_size, hidden_size}, dtype::Float32(),
{-1, -2}), // hx
TensorValue(
{hidden_size, hidden_size}, dtype::Float32(),
{0.3535}), // weight_hh
TensorValue(
{1, hidden_size}, dtype::Float32(), {0.3535}), // bias_hh
{}},
Testcase{
{},
{},
{},
{},
{},
{},
TensorValue(
{batch_size, hidden_size}, dtype::Float32(),
{1.414, 2.4745}), // hy
});
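The RELU case reduces the same way, now with the 0.3535 biases included (a sketch, not part of the diff):

```python
import numpy as np

w = b = 0.3535
x = np.array([[1, 2], [3, 4]], dtype=np.float32)
hx = np.array([[-1], [-2]], dtype=np.float32)
hy = np.maximum(w * x.sum(1) + b + w * hx.sum(1) + b, 0)
print(np.round(hy, 4))  # [1.414  2.4745]
```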
}
} // namespace test
} // namespace megdnn
\ No newline at end of file
......@@ -11,6 +11,7 @@ import pytest
import megengine as mge
import megengine.functional as F
from megengine.device import get_device_count
from megengine.module import LSTM, RNN, LSTMCell, RNNCell
......@@ -20,6 +21,7 @@ def assert_tuple_equal(src, ref):
assert i == j
@pytest.mark.skipif(get_device_count("gpu") > 0, reason="no algorithm on cuda")
@pytest.mark.parametrize(
"batch_size, input_size, hidden_size, init_hidden",
[(3, 10, 20, True), (3, 10, 20, False), (1, 10, 20, False)],
......@@ -35,7 +37,7 @@ def test_rnn_cell(batch_size, input_size, hidden_size, init_hidden):
assert_tuple_equal(h_new.shape, (batch_size, hidden_size))
# is batch_size == 0 tolerated? it would cause an error in the slice operation xx[:, ...]
@pytest.mark.skipif(get_device_count("gpu") > 0, reason="no algorithm on cuda")
@pytest.mark.parametrize(
"batch_size, input_size, hidden_size, init_hidden",
[(3, 10, 20, True), (3, 10, 20, False), (1, 10, 20, False)],
......@@ -53,6 +55,7 @@ def test_lstm_cell(batch_size, input_size, hidden_size, init_hidden):
assert_tuple_equal(c_new.shape, (batch_size, hidden_size))
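For reference, the call pattern these cell tests exercise looks roughly like the following sketch — it assumes the megengine.module cells mirror the PyTorch-style signature (the exact forward signature is not shown in this hunk):

```python
import numpy as np
import megengine as mge
from megengine.module import LSTMCell

cell = LSTMCell(input_size=10, hidden_size=20)
x = mge.tensor(np.random.randn(3, 10).astype("float32"))
h0 = mge.tensor(np.zeros((3, 20), dtype="float32"))
c0 = mge.tensor(np.zeros((3, 20), dtype="float32"))
h1, c1 = cell(x, (h0, c0))  # each asserted to be (batch_size, hidden_size)
```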
@pytest.mark.skipif(get_device_count("gpu") > 0, reason="no algorithm on cuda")
@pytest.mark.parametrize(
"batch_size, seq_len, input_size, hidden_size, num_layers, bidirectional, init_hidden, batch_first",
[
......@@ -70,7 +73,6 @@ def test_lstm_cell(batch_size, input_size, hidden_size, init_hidden):
),
],
)
# (0, 1, 1, 1, 1, False, True, False)])
def test_rnn(
batch_size,
seq_len,
......@@ -113,6 +115,7 @@ def test_rnn(
)
@pytest.mark.skipif(get_device_count("gpu") > 0, reason="no algorithm on cuda")
@pytest.mark.parametrize(
"batch_size, seq_len, input_size, hidden_size, num_layers, bidirectional, init_hidden, batch_first",
[
......@@ -130,7 +133,6 @@ def test_rnn(
),
],
)
# (0, 1, 1, 1, 1, False, True, False)])
def test_lstm(
batch_size,
seq_len,
......@@ -175,7 +177,3 @@ def test_lstm(
assert_tuple_equal(
h_n[1].shape, (num_directions * num_layers, batch_size, hidden_size)
)
if __name__ == "__main__":
test_lstm(5, 10, 10, 20, 1, False, False, True)
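The shape assertions above imply the module returns the sequence output plus a (h_n, c_n) pair. A hedged usage sketch matching the removed `__main__` invocation, assuming a PyTorch-style module API:

```python
import numpy as np
import megengine as mge
from megengine.module import LSTM

# batch_size=5, seq_len=10, input_size=10, hidden_size=20,
# num_layers=1, unidirectional, no initial hidden state, batch_first=True
lstm = LSTM(input_size=10, hidden_size=20, num_layers=1, batch_first=True)
x = mge.tensor(np.random.randn(5, 10, 10).astype("float32"))
output, (h_n, c_n) = lstm(x)
# output: (5, 10, 20); h_n, c_n: (num_directions * num_layers, 5, 20)
```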
......@@ -123,7 +123,7 @@ MGB_IMPL_OPR_GRAD(LSTMCell) {
SymbolVar input(opr.input(0)), weight_ih(opr.input(1)), hx(opr.input(3)),
weight_hh(opr.input(4)), cx(opr.input(6));
SymbolVar h_out(opr.output(0)), c_out(opr.output(1)), gates(opr.output(2)),
h_og{out_grad.at(0)}, c_og{out_grad.at(1)}, tmp;
h_og{out_grad.at(0)}, c_og{out_grad.at(1)};
size_t ghs = gates.shape()[1] / 4; // gate_hidden_size
SymbolVarArray gates_array = Split::make(
gates, Split::Options::make_partition(gates, 1, {ghs, ghs, ghs, ghs}));
......@@ -141,7 +141,7 @@ MGB_IMPL_OPR_GRAD(LSTMCell) {
f_grad = Elemwise::make({f, c_og * cx}, Mode::SIGMOID_GRAD);
i_grad = Elemwise::make({i, c_og * g}, Mode::SIGMOID_GRAD);
g_grad = Elemwise::make({g, c_og * i}, Mode::TANH_GRAD);
SymbolVar rnn_cell_grad = Concat::make({i_grad, f_grad, o_grad, g_grad}, {-1});
SymbolVar rnn_cell_grad = Concat::make({i_grad, f_grad, o_grad, g_grad}, -1);
SymbolVar result;
if (wrt_idx < 6) {
......@@ -258,7 +258,6 @@ MGB_IMPL_OPR_GRAD(LSTM) {
SymbolVarArray grads = LSTMBackward::make(
opr.input(0), opr.output(0), opr.input(1), opr.input(2), out_grad.at(0),
out_grad.at(1), out_grad.at(2), opr.input(3), opr.output(3), opr.param());
SymbolVar res;
return grads.at(wrt_idx).node(); // input, hx, cx, weights
}
#endif
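The Elemwise calls in the LSTMCell gradient above encode the standard gate-gradient rules: SIGMOID_GRAD(y, dy) = y·(1-y)·dy and TANH_GRAD(y, dy) = (1-y²)·dy, applied to the post-activation gates i, f, o, g of c' = f·cx + i·g, h' = o·tanh(c'). A NumPy sketch of that math — the o-gate step and the accumulation of h's gradient into c_og fall outside the shown hunk, so they are inferred here:

```python
import numpy as np

def lstm_cell_gate_grads(i, f, o, g, cx, c_out, h_og, c_og):
    # i, f, o, g: post-activation gates; h_og, c_og: incoming output grads
    tc = np.tanh(c_out)
    o_grad = o * (1 - o) * (h_og * tc)       # inferred: SIGMOID_GRAD(o, h_og*tanh(c'))
    c_og = c_og + h_og * o * (1 - tc * tc)   # inferred: route dh through tanh(c')
    f_grad = f * (1 - f) * (c_og * cx)       # matches SIGMOID_GRAD(f, c_og * cx)
    i_grad = i * (1 - i) * (c_og * g)        # matches SIGMOID_GRAD(i, c_og * g)
    g_grad = (1 - g * g) * (c_og * i)        # matches TANH_GRAD(g, c_og * i)
    # matches Concat::make({i_grad, f_grad, o_grad, g_grad}, -1)
    return np.concatenate([i_grad, f_grad, o_grad, g_grad], axis=-1)
```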
......
......@@ -25,11 +25,11 @@ MGB_DEFINE_OPR_CLASS(
public:
using NonlineMode = Param::NonlineMode;
RNNCellForward(
MGE_WIN_DECLSPEC_FUC RNNCellForward(
VarNode* input, VarNode* weight_ih, VarNode* bias_ih, VarNode* hx,
VarNode* weight_hh, VarNode* bias_hh, const Param& param,
const OperatorNodeConfig& config);
static SymbolVar make(
MGE_WIN_DECLSPEC_FUC static SymbolVar make(
SymbolVar input, SymbolVar weight_ih, SymbolVar bias_ih, SymbolVar hx,
SymbolVar weight_hh, SymbolVar bias_hh, const Param& param = {},
const OperatorNodeConfig& config = {});
......@@ -39,11 +39,11 @@ using RNNCell = RNNCellForward;
MGB_DEFINE_OPR_CLASS(
LSTMCellForward, intl::MegDNNOprWrapperFwd<megdnn::LSTMCellForward>) // {
public:
LSTMCellForward(
MGE_WIN_DECLSPEC_FUC LSTMCellForward(
VarNode* input, VarNode* weight_ih, VarNode* bias_ih, VarNode* hx,
VarNode* weight_hh, VarNode* bias_hh, VarNode* cx, const Param& param,
const OperatorNodeConfig& config);
static SymbolVar make(
MGE_WIN_DECLSPEC_FUC static SymbolVar make(
SymbolVar input, SymbolVar weight_ih, SymbolVar bias_ih, SymbolVar hx,
SymbolVar weight_hh, SymbolVar bias_hh, SymbolVar cx,
const Param& param = {}, const OperatorNodeConfig& config = {});
......@@ -51,17 +51,11 @@ public:
using LSTMCell = LSTMCellForward;
MGB_DEFINE_OPR_CLASS(RNNForward, intl::MegDNNOprWrapperFwd<megdnn::RNNForward>) // {
/*private:
SymbolVarArray weight_ih_arr; // 1d, idx: direction * num_layers + layer
SymbolVarArray weight_hh_arr;
SymbolVarArray bias_arr;
*/
public:
RNNForward(
MGE_WIN_DECLSPEC_FUC RNNForward(
VarNode* input, VarNode* hx, VarNode* flatten_weights, const Param& param,
const OperatorNodeConfig& config);
static SymbolVar make(
MGE_WIN_DECLSPEC_FUC static SymbolVar make(
SymbolVar input, SymbolVar hx, SymbolVar flatten_weights,
const Param& param = {}, const OperatorNodeConfig& config = {});
};
......@@ -70,11 +64,11 @@ using RNN = RNNForward;
MGB_DEFINE_OPR_CLASS(
RNNBackward, intl::MegDNNOprWrapperBwd<megdnn::RNNBackward>) // {
public:
RNNBackward(
MGE_WIN_DECLSPEC_FUC RNNBackward(
VarNode* x, VarNode* y, VarNode* hx, VarNode* dy, VarNode* dhy,
VarNode* flatten_weights, VarNode* reserve_space, const Param& param,
const OperatorNodeConfig& config);
static SymbolVarArray make(
MGE_WIN_DECLSPEC_FUC static SymbolVarArray make(
SymbolVar x, SymbolVar y, SymbolVar hx, SymbolVar dy, SymbolVar dhy,
SymbolVar flatten_weights, SymbolVar reserve_space, const Param& param = {},
const OperatorNodeConfig& config = {});
......@@ -88,10 +82,10 @@ private:
MGB_DEFINE_OPR_CLASS(
LSTMForward, intl::MegDNNOprWrapperFwd<megdnn::LSTMForward>) // {
public:
LSTMForward(
MGE_WIN_DECLSPEC_FUC LSTMForward(
VarNode* input, VarNode* hx, VarNode* cx, VarNode* flatten_weights,
const Param& param, const OperatorNodeConfig& config);
static SymbolVar make(
MGE_WIN_DECLSPEC_FUC static SymbolVar make(
SymbolVar input, SymbolVar hx, SymbolVar cx, SymbolVar flatten_weights,
const Param& param = {}, const OperatorNodeConfig& config = {});
};
......@@ -100,11 +94,11 @@ using LSTM = LSTMForward;
MGB_DEFINE_OPR_CLASS(
LSTMBackward, intl::MegDNNOprWrapperBwd<megdnn::LSTMBackward>) // {
public:
LSTMBackward(
MGE_WIN_DECLSPEC_FUC LSTMBackward(
VarNode* x, VarNode* y, VarNode* hx, VarNode* cx, VarNode* dy, VarNode* dhy,
VarNode* dcy, VarNode* flatten_weights, VarNode* reserve_space,
const Param& param, const OperatorNodeConfig& config);
static SymbolVarArray make(
MGE_WIN_DECLSPEC_FUC static SymbolVarArray make(
SymbolVar x, SymbolVar y, SymbolVar hx, SymbolVar cx, SymbolVar dy,
SymbolVar dhy, SymbolVar dcy, SymbolVar flatten_weights,
SymbolVar reserve_space, const Param& param = {},
......