diff --git a/dnn/src/arm_common/elemwise_helper/kimpl/tanh.h b/dnn/src/arm_common/elemwise_helper/kimpl/tanh.h
index 10792dcb28138b56281d3b9e4a193b368da6e82c..13d92becaff1050bf1698a94ad19887534660727 100644
--- a/dnn/src/arm_common/elemwise_helper/kimpl/tanh.h
+++ b/dnn/src/arm_common/elemwise_helper/kimpl/tanh.h
@@ -30,18 +30,18 @@ struct TanhOpBase : UnaryOpBase<src_ctype, dst_ctype> {
 template <typename src_ctype, typename dst_ctype = src_ctype>
 struct TanhOp;
 
-#define OP(_ctype, _neon_type, _func_suffix, _simd_width) \
+#define OP(_ctype, _neon_type, _neon_type2, _func_suffix, _simd_width) \
     template <> \
     struct TanhOp<_ctype> : TanhOpBase<_ctype> { \
         using TanhOpBase::TanhOpBase; \
         using TanhOpBase::operator(); \
         constexpr static size_t SIMD_WIDTH = _simd_width; \
-        void operator()(const _neon_type& src, _ctype* dst) const { \
+        void operator()(const _neon_type2& src, _ctype* dst) const { \
             auto vitem = operator()(src); \
             vst1q_##_func_suffix(dst, vitem.val[0]); \
             vst1q_##_func_suffix(dst + SIMD_WIDTH, vitem.val[1]); \
         } \
-        _neon_type operator()(const _neon_type& src) const { \
+        _neon_type2 operator()(const _neon_type2& src) const { \
             auto one_val = vdupq_n_##_func_suffix(1.f); \
             auto two_val = vdupq_n_##_func_suffix(2.f); \
             auto val1 = src.val[0]; \
@@ -62,10 +62,23 @@ struct TanhOp;
             val2 = vsubq_##_func_suffix(one_val, val2); \
             return {{val1, val2}}; \
         } \
+        _neon_type operator()(const _neon_type& src) const { \
+            auto one_val = vdupq_n_##_func_suffix(1.f); \
+            auto two_val = vdupq_n_##_func_suffix(2.f); \
+            auto val1 = src; \
+            val1 = vmulq_##_func_suffix(two_val, val1); \
+            val1 = exp_ps_##_func_suffix(val1); \
+            val1 = vaddq_##_func_suffix(one_val, val1); \
+            auto rval1 = vrecpeq_##_func_suffix(val1); \
+            rval1 = vmulq_##_func_suffix(vrecpsq_##_func_suffix(val1, rval1), rval1); \
+            val1 = vmulq_##_func_suffix(two_val, rval1); \
+            val1 = vsubq_##_func_suffix(one_val, val1); \
+            return val1; \
+        } \
     };
 
-OP(dt_float32, float32x4x2_t, f32, 4)
+OP(dt_float32, float32x4_t, float32x4x2_t, f32, 4)
 #if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
-OP(__fp16, float16x8x2_t, f16, 8)
+OP(__fp16, float16x8_t, float16x8x2_t, f16, 8)
 #endif
 #undef OP
diff --git a/dnn/src/arm_common/handle.cpp b/dnn/src/arm_common/handle.cpp
index 05de3a2d88c2f2eb6fac75788449b59d4f6e3f73..e3e4c9df0e31d783b97b4baf8e56a28121352188 100644
--- a/dnn/src/arm_common/handle.cpp
+++ b/dnn/src/arm_common/handle.cpp
@@ -19,9 +19,12 @@
 #include "src/arm_common/elemwise/opr_impl.h"
 #include "src/arm_common/elemwise_multi_type/opr_impl.h"
 #include "src/arm_common/local/opr_impl.h"
+#include "src/arm_common/lstm/opr_impl.h"
+#include "src/arm_common/lstm_cell/opr_impl.h"
 #include "src/arm_common/pooling/opr_impl.h"
 #include "src/arm_common/reduce/opr_impl.h"
 #include "src/arm_common/resize/opr_impl.h"
+#include "src/arm_common/rnn_cell/opr_impl.h"
 #include "src/arm_common/separable_conv/opr_impl.h"
 #include "src/arm_common/separable_filter/opr_impl.h"
 #include "src/arm_common/type_cvt/opr_impl.h"
@@ -50,6 +53,9 @@ MEGDNN_SPECIALIZE_CREATE_OPERATOR(TypeCvt)
 MEGDNN_SPECIALIZE_CREATE_OPERATOR(Reduce)
 MEGDNN_SPECIALIZE_CREATE_OPERATOR(ConvBias)
 MEGDNN_SPECIALIZE_CREATE_OPERATOR(ConvolutionBackwardData)
+MEGDNN_SPECIALIZE_CREATE_OPERATOR(RNNCell)
+MEGDNN_SPECIALIZE_CREATE_OPERATOR(LSTMCell)
+MEGDNN_SPECIALIZE_CREATE_OPERATOR(LSTM)
 
 #pragma GCC diagnostic push
 #pragma GCC diagnostic ignored "-Wpragmas"
diff --git a/dnn/src/arm_common/lstm/lstm_utils.cpp b/dnn/src/arm_common/lstm/lstm_utils.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..de47b56e219bc855aeabd36716a7035b9a159158
--- /dev/null
+++ b/dnn/src/arm_common/lstm/lstm_utils.cpp
@@ -0,0 +1,107 @@
+/**
+ * \file dnn/src/arm_common/lstm/lstm_utils.cpp
+ * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
+ *
+ * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+ * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ */
+
+#include "./lstm_utils.h"
+#include "src/arm_common/lstm/opr_impl.h"
+#include "src/arm_common/lstm_cell/cell_kernel.h"
+#include "src/arm_common/lstm_cell/opr_impl.h"
+#include "src/naive/handle.h"
+
+using namespace megdnn;
+using namespace arm_common;
+
+LstmCellWeight::LstmCellWeight(
+        RefPtr weight_ptr, size_t hidden_size, size_t input_size, bool has_bias,
+        DType dtype) {
+    // weight_ih: [gate_hidden_size, input_size]
+    // weight_hh: [gate_hidden_size, hidden_size]
+    // bias_ih: [gate_hidden_size]
+    // bias_hh: [gate_hidden_size]
+
+    size_t gate_hidden_size = 4 * hidden_size;
+    TensorLayout weight_ih_layout{{gate_hidden_size, input_size}, dtype};
+    TensorLayout weight_hh_layout{{gate_hidden_size, hidden_size}, dtype};
+    TensorLayout bias_layout{{gate_hidden_size}, dtype};
+    m_weight_size = 0;
+    m_weight_ih = TensorND(weight_ih_layout, weight_ptr);
+    m_weight_size += weight_ih_layout.span().dist_byte();
+    weight_ptr += weight_ih_layout.span().dist_byte();
+    m_weight_hh = TensorND(weight_hh_layout, weight_ptr);
+    m_weight_size += weight_hh_layout.span().dist_byte();
+    weight_ptr += weight_hh_layout.span().dist_byte();
+    if (has_bias) {
+        m_bias_ih = TensorND(bias_layout, weight_ptr);
+        m_weight_size += bias_layout.span().dist_byte();
+        weight_ptr += bias_layout.span().dist_byte();
+        m_bias_hh = TensorND(bias_layout, weight_ptr);
+        m_weight_size += bias_layout.span().dist_byte();
+    }
+}
+
+LstmStates::LstmStates(
+        const SmallVector<RefPtr> ptr, size_t hidden_size, size_t batch_size,
+        DType dtype) {
+    auto& h_ptr = ptr[0];
+    auto& c_ptr = ptr[1];
+    TensorLayout layout{{batch_size, hidden_size}, dtype};
+    m_h = TensorND(layout, h_ptr);
+    m_c = TensorND(layout, c_ptr);
+    m_memory_size = layout.span().dist_byte();
+}
+
+TensorNDArray megdnn::arm_common::split_tensor(
+        _megdnn_tensor_in tensor, size_t nr_tensor, const TensorLayout& layout) {
+    megdnn_assert(
+            tensor.layout.span().dist_byte() == nr_tensor * layout.span().dist_byte());
+
+    TensorNDArray tensors;
+    auto ptr = tensor.get_ref_ptr();
+    for (size_t i = 0; i < nr_tensor; i++) {
+        tensors.push_back(TensorND(layout, ptr));
+        ptr += layout.span().dist_byte();
+    }
+    return tensors;
+}
+
+namespace megdnn {
+namespace arm_common {
+
+template <>
+void cell_opr_compute<LSTMCell, LstmStates>(
+        _megdnn_tensor_in input, _megdnn_tensor_in weight_ih,
+        _megdnn_tensor_in weight_hh, _megdnn_tensor_in bias_ih,
+        _megdnn_tensor_in bias_hh, const LstmStates& state_in, LstmStates& state_out,
+        Workspace cell_workspace, Handle* handle) {
+    auto opr = handle->create_operator<LSTMCell>();
+    TensorLayout gates, h_new, c_new;
+    opr->deduce_layout(
+            input.layout, weight_ih.layout, bias_ih.layout, state_in.m_h.layout,
+            weight_hh.layout, bias_hh.layout, state_in.m_c.layout, h_new, c_new, gates);
+
+    auto workspace_bundle = LstmCellCompute::get_workspace_bundle(
+            input.layout, weight_ih.layout, bias_ih.layout, state_in.m_h.layout,
+            weight_hh.layout, bias_hh.layout, state_in.m_c.layout, h_new, c_new, gates);
+
+    workspace_bundle.set(cell_workspace.raw_ptr);
+
+    TensorND gates_tensor{workspace_bundle.get(0), gates};
+    _megdnn_workspace new_workspace = {
+            static_cast<dt_byte*>(workspace_bundle.get(1)),
+            workspace_bundle.get_size(1)};
+
+    LstmCellCompute::run(
+            input, weight_ih, bias_ih, state_in.m_h, weight_hh, bias_hh, state_in.m_c,
+            state_out.m_h, state_out.m_c, gates_tensor, new_workspace, handle);
+}
+} // namespace arm_common
+} // namespace megdnn
+
+// vim: syntax=cpp.doxygen
diff --git a/dnn/src/arm_common/lstm/lstm_utils.h b/dnn/src/arm_common/lstm/lstm_utils.h
new file mode 100644
index 0000000000000000000000000000000000000000..4eb3c1e0eea79f39c2426f1f6ad2b3f668d371b1
--- /dev/null
+++ b/dnn/src/arm_common/lstm/lstm_utils.h
@@ -0,0 +1,259 @@
+/**
+ * \file dnn/src/arm_common/lstm/lstm_utils.h
+ * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
+ *
+ * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+ * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ */
+
+#pragma once
+
+#include "src/arm_common/lstm_cell/cell_kernel.h"
+#include "src/common/opr_delegate.h"
+#include "src/common/utils.h"
+#include "src/naive/handle.h"
+#include "src/naive/lstm/opr_impl.h"
+
+namespace megdnn {
+namespace arm_common {
+
+template <typename CellOpr, typename States>
+void cell_opr_compute(
+        _megdnn_tensor_in step_input, _megdnn_tensor_in weight_ih,
+        _megdnn_tensor_in weight_hh, _megdnn_tensor_in bias_ih,
+        _megdnn_tensor_in bias_hh, const States& state_in, States& state_out,
+        Workspace cell_workspace, Handle* handle);
+
+struct LstmCellWeight {
+    size_t m_weight_size = 0;
+
+    TensorND m_weight_ih, m_weight_hh, m_bias_ih, m_bias_hh;
+    // if no bias, will create dummy bias tensor from workspace
+    LstmCellWeight(
+            RefPtr weight_ptr, size_t hidden_size, size_t input_size, bool has_bias,
+            DType dtype);
+};
+
+struct LstmStates {
+    static size_t nr_states() { return 2; }
+    size_t m_memory_size;
+    TensorND m_h, m_c;
+    LstmStates(
+            const SmallVector<RefPtr> ptr, size_t hidden_size, size_t batch_size,
+            DType dtype);
+};
+
+TensorNDArray split_tensor(
+        _megdnn_tensor_in tensor, size_t nr_tensor, const TensorLayout& layout);
+
+template <typename CellWeight>
+SmallVector<CellWeight> get_all_cells(
+        size_t dir_size, size_t num_layers, size_t input_size, size_t hidden_size,
+        bool bias, _megdnn_tensor_in flatten_weights) {
+    SmallVector<CellWeight> cell_weights;
+    cell_weights.reserve(dir_size * num_layers);
+    auto weight_ptr = flatten_weights.get_ref_ptr();
+    for (size_t layer = 0; layer < num_layers; ++layer) {
+        for (size_t d = 0; d < dir_size; ++d) {
+            size_t cell_input_size = layer == 0 ? input_size : dir_size * hidden_size;
+            CellWeight cell_weight(
+                    weight_ptr, hidden_size, cell_input_size, bias,
+                    flatten_weights.layout.dtype);
+            weight_ptr += cell_weight.m_weight_size;
+            cell_weights.push_back(cell_weight);
+        }
+    }
+    return cell_weights;
+}
+
+template <typename States>
+SmallVector<States> get_all_status(
+        _megdnn_tensor_in hx, _megdnn_tensor_in cx, size_t hidden_size,
+        size_t batch_size, size_t num_layers, size_t dir_size, DType dtype) {
+    SmallVector<States> states;
+    auto hx_ptr = hx.get_ref_ptr();
+    auto cx_ptr = cx.get_ref_ptr();
+    for (size_t layer = 0; layer < num_layers * dir_size; ++layer) {
+        States state({hx_ptr, cx_ptr}, hidden_size, batch_size, dtype);
+        hx_ptr += state.m_memory_size;
+        cx_ptr += state.m_memory_size;
+        states.push_back(state);
+    }
+    return states;
+}
+
+template <typename CellOpr, typename CellWeight, typename States>
+void exec_kernel(
+        SmallVector<CellWeight>& cells, const TensorNDArray& inputs,
+        const SmallVector<States>& states_in, SmallVector<States>& states_out,
+        TensorNDArray& outputs, size_t num_layers, size_t dir_size, Handle* handle,
+        WorkspaceBundle workspace_bundle) {
+    megdnn_assert(cells.size() == num_layers * dir_size);
+    megdnn_assert(
+            states_in.size() == states_out.size() &&
+            states_in.size() == num_layers * dir_size);
+    megdnn_assert(outputs.size() == inputs.size());
+    //! two tmp state workspaces
+    megdnn_assert(workspace_bundle.nr_workspace() == 4 + States::nr_states());
+
+    size_t seq_len = inputs.size();
+    size_t batch_size = inputs[0].layout.shape[0];
+    size_t input_size = inputs[0].layout.shape[1];
+    size_t hidden_size = cells[0].m_weight_hh.layout.shape[1];
+
+    TensorLayout batch_output_layout{
+            {hidden_size}, outputs[0].layout.dtype};  // output hy
+    TensorLayout cell_output_layout{
+            {batch_size, hidden_size}, outputs[0].layout.dtype};  // output hy
+    TensorLayout seq_output_layout{
+            {batch_size, dir_size * hidden_size}, outputs[0].layout.dtype};
+    TensorLayout cell_first_input_layout{
+            {batch_size, input_size}, inputs[0].layout.dtype};  // input
+    TensorLayout cell_input_layout{
+            {batch_size, dir_size * hidden_size}, inputs[0].layout.dtype};
+    TensorLayout tmp_output_layout{
+            {seq_len, batch_size, dir_size * hidden_size}, outputs[0].layout.dtype};
+
+    //! workspace get
+    Workspace cell_workspace(
+            static_cast<dt_byte*>(workspace_bundle.get(0)),
+            workspace_bundle.get_size(0) + workspace_bundle.get_size(1));
+    auto&& tmp_inputs_1 = split_tensor(
+            TensorND{workspace_bundle.get(2), tmp_output_layout}, seq_len,
+            cell_input_layout);
+    auto&& tmp_outputs_1 = split_tensor(
+            TensorND{workspace_bundle.get(2), tmp_output_layout}, seq_len,
+            seq_output_layout);
+
+    auto&& tmp_inputs_2 = split_tensor(
+            TensorND{workspace_bundle.get(3), tmp_output_layout}, seq_len,
+            cell_input_layout);
+    auto&& tmp_outputs_2 = split_tensor(
+            TensorND{workspace_bundle.get(3), tmp_output_layout}, seq_len,
+            seq_output_layout);
+
+    using IoPair = std::pair<TensorNDArray, TensorNDArray>;
+    IoPair io_pair1 = {tmp_inputs_1, tmp_outputs_2};
+    IoPair io_pair2 = {tmp_inputs_2, tmp_outputs_1};
+    SmallVector<IoPair> io_pairs = {io_pair1, io_pair2};
+
+    SmallVector<RefPtr> ptr;
+    for (size_t index = 0; index < States::nr_states(); index++) {
+        ptr.push_back(workspace_bundle.get(4 + index));
+    }
+    auto&& tmp_state = States(ptr, hidden_size, batch_size, outputs[0].layout.dtype);
+
+    for (size_t layer = 0; layer < num_layers; layer++) {
+        auto layer_inputs = io_pairs[layer % 2].first;
+        auto layer_outputs = io_pairs[layer % 2].second;
+
+        //! the last layer writes directly to the output tensors
+        if (num_layers - 1 == layer) {
+            layer_outputs = outputs;
+        }
+        if (0 == layer) {
+            layer_inputs = inputs;
+        }
+        for (size_t d = 0; d < dir_size; ++d) {
+            size_t cell_idx = layer * dir_size + d;
+            auto& cell = cells[cell_idx];
+            auto& state_in_origin = states_in[cell_idx];
+            auto& state_out_origin = states_out[cell_idx];
+
+            auto state_in = state_in_origin;
+            auto state_out = tmp_state;
+
+            for (size_t i = 0; i < seq_len; ++i) {
+                size_t step = d == 0 ? i : seq_len - 1 - i;
+                auto& step_input = layer_inputs[step];
+                auto& step_output = layer_outputs[step];
+
+                if (i == seq_len - 1) {
+                    state_out = state_out_origin;
+                }
+                //! task 1
+                //! the cell opr dispatches its tasks internally, so nothing is
+                //! dispatched here
+                cell_opr_compute<CellOpr, States>(
+                        step_input, cell.m_weight_ih, cell.m_weight_hh, cell.m_bias_ih,
+                        cell.m_bias_hh, state_in, state_out, cell_workspace, handle);
+                //! task 2
+                //! copy the output to contiguous memory
+                auto copy_to_output = [=]() {
+                    //! if dir_size > 1 and batch_size > 1, reorder to the output
+                    size_t stride = batch_output_layout.span().dist_byte();
+                    if (dir_size > 1 && batch_size > 1) {
+                        int8_t* source = static_cast<int8_t*>(state_out.m_h.raw_ptr());
+                        int8_t* dst = static_cast<int8_t*>(step_output.raw_ptr()) +
+                                      d * stride;
+                        for (size_t b = 0; b < batch_size; b++) {
+                            memcpy(dst, source, stride);
+                            source += stride;
+                            dst += dir_size * stride;
+                        }
+                    } else {
+                        void* source = state_out.m_h.raw_ptr();
+                        int8_t* dst = static_cast<int8_t*>(step_output.raw_ptr()) +
+                                      d * stride;
+                        memcpy(dst, source, state_out.m_h.layout.span().dist_byte());
+                    }
+                };
+                MEGDNN_DISPATCH_CPU_KERN(
+                        static_cast<naive::HandleImpl*>(handle), copy_to_output());
+
+                //! state_in and state_out are read and written in place
+                if (0 == i) {
+                    state_in = tmp_state;
+                }
+            }
+        }
+    }
+}
+
+template <typename CellOpr>
+WorkspaceBundle get_workspace_bundle(
+        const TensorLayout& input, const TensorLayout& output,
+        const TensorLayout& flatten_weights, size_t hidden_size, size_t dir_size,
+        size_t states_size) {
+    size_t batch_size = input.shape[1];
+    size_t input_size = input.shape[2];
+    size_t gate_hidden_size = flatten_weights.shape[0];
+
+    // cell workspace
+    TensorLayout weight_ih{{gate_hidden_size, input_size}, flatten_weights.dtype};
+    TensorLayout weight_hh{
+            {gate_hidden_size, dir_size * hidden_size}, flatten_weights.dtype};
+    TensorLayout bias{{1, gate_hidden_size}, flatten_weights.dtype};
+    TensorLayout hx{{batch_size, dir_size * hidden_size}, input.dtype};
+
+    auto cell_opr = inplace_cpu_handle()->create_operator<CellOpr>();
+
+    TensorLayout h_new, c_new, gates;
+    cell_opr->deduce_layout(
+            input, weight_ih, bias, hx, weight_hh, bias, hx, h_new, c_new, gates);
+
+    SmallVector<size_t> workspaces;
+    //! the cell opr compute workspace
+    size_t cell_opr_workspace = cell_opr->get_workspace_in_bytes(
+            input, weight_ih, bias, hx, weight_hh, bias, hx, h_new, c_new, gates);
+    workspaces.push_back(gates.span().dist_byte());
+    workspaces.push_back(cell_opr_workspace);
+    //! double tmp output memory
+    size_t tmp_output_workspace = output.span().dist_byte();
+    workspaces.push_back(tmp_output_workspace);
+    workspaces.push_back(tmp_output_workspace);
+
+    //! tmp states memory
+    size_t tmp_state_workspace = hx.span().dist_byte();
+    for (size_t i = 0; i < states_size; i++) {
+        workspaces.push_back(tmp_state_workspace);
+    }
+    return {nullptr, workspaces};
+}
+
+} // namespace arm_common
+} // namespace megdnn
+
+// vim: syntax=cpp.doxygen
diff --git a/dnn/src/arm_common/lstm/opr_impl.cpp b/dnn/src/arm_common/lstm/opr_impl.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..4bd87381f8ab49a264d1a39869868a810c5d6ae2
--- /dev/null
+++ b/dnn/src/arm_common/lstm/opr_impl.cpp
@@ -0,0 +1,83 @@
+/**
+ * \file dnn/src/arm_common/lstm/opr_impl.cpp
+ * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
+ *
+ * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+ * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ */
+
+#include "src/arm_common/lstm/opr_impl.h"
+#include "./lstm_utils.h"
+#include "src/arm_common/lstm_cell/opr_impl.h"
+#include "src/naive/handle.h"
+
+#include "midout.h"
+
+MIDOUT_DECL(megdnn_arm_common_lstm)
+
+using namespace megdnn;
+using namespace arm_common;
+
+void LSTMImpl::exec(
+        _megdnn_tensor_in input, _megdnn_tensor_in hx, _megdnn_tensor_in cx,
+        _megdnn_tensor_in flatten_weights, _megdnn_tensor_out output,
+        _megdnn_tensor_out hy, _megdnn_tensor_out cy, _megdnn_tensor_out,
+        _megdnn_workspace workspace) {
+    MIDOUT_BEGIN(megdnn_arm_common_lstm, midout_iv(0)) {
+        size_t dir_size = param().bidirectional ? 2 : 1;
+        size_t num_layers = param().num_layers;
+        size_t hidden_size = param().hidden_size;
+
+        size_t seq_len = input.layout.shape[0];
+        size_t batch_size = input.layout.shape[1];
+        size_t input_size = input.layout.shape[2];
+
+        //! to support the input ptr changing in record mode, these tasks are
+        //! dispatched to the device
+        auto&& cell_weights = get_all_cells<LstmCellWeight>(
+                dir_size, num_layers, input_size, hidden_size, param().bias,
+                flatten_weights);
+        auto&& cell_states_in = get_all_status<LstmStates>(
+                hx, cx, hidden_size, batch_size, num_layers, dir_size, hx.layout.dtype);
+        auto&& cell_states_out = get_all_status<LstmStates>(
+                hy, cy, hidden_size, batch_size, num_layers, dir_size, hy.layout.dtype);
+        auto&& inputs = split_tensor(
+                input, seq_len,
+                TensorLayout{{batch_size, input_size}, input.layout.dtype});
+        auto&& outputs = split_tensor(
+                output, seq_len,
+                TensorLayout{
+                        {batch_size, dir_size * hidden_size}, output.layout.dtype});
+
+        auto workspace_bundle = get_workspace_bundle<LSTMCell>(
+                input.layout, output.layout, flatten_weights.layout, hidden_size,
+                dir_size, LstmStates::nr_states());
+
+        workspace_bundle.set(workspace.raw_ptr);
+        exec_kernel<LSTMCell, LstmCellWeight, LstmStates>(
+                cell_weights, inputs, cell_states_in, cell_states_out, outputs,
+                num_layers, dir_size, handle(), workspace_bundle);
+    }
+    MIDOUT_END();
+}
+
+size_t LSTMImpl::get_workspace_in_bytes(
+        const TensorLayout& input, const TensorLayout&, const TensorLayout&,
+        const TensorLayout& flatten_weights, const TensorLayout& output,
+        const TensorLayout&, const TensorLayout&, const TensorLayout&) {
+    MIDOUT_BEGIN(megdnn_arm_common_lstm, midout_iv(1)) {
+        size_t dir_size = param().bidirectional ? 2 : 1;
+        size_t hidden_size = param().hidden_size;
+
+        auto bundle = get_workspace_bundle<LSTMCell>(
+                input, output, flatten_weights, hidden_size, dir_size,
+                LstmStates::nr_states());
+        return bundle.total_size_in_bytes();
+    }
+    MIDOUT_END();
+}
+
+// vim: syntax=cpp.doxygen
diff --git a/dnn/src/arm_common/lstm/opr_impl.h b/dnn/src/arm_common/lstm/opr_impl.h
new file mode 100644
index 0000000000000000000000000000000000000000..33489934b90d5abe74c4ce3c407a7b88ed0dd48f
--- /dev/null
+++ b/dnn/src/arm_common/lstm/opr_impl.h
@@ -0,0 +1,43 @@
+/**
+ * \file dnn/src/arm_common/lstm/opr_impl.h
+ * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
+ *
+ * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+ * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ */
+
+#pragma once
+
+#include "src/common/utils.h"
+#include "src/naive/lstm/opr_impl.h"
+
+namespace megdnn {
+namespace arm_common {
+
+class LSTMImpl : public naive::LSTMImpl {
+public:
+    using naive::LSTMImpl::LSTMImpl;
+    void exec(
+            _megdnn_tensor_in input, _megdnn_tensor_in hx, _megdnn_tensor_in cx,
+            _megdnn_tensor_in flatten_weights, _megdnn_tensor_out output,
+            _megdnn_tensor_out hy, _megdnn_tensor_out cy,
+            _megdnn_tensor_out reserve_space, _megdnn_workspace workspace) override;
+
+    size_t get_workspace_in_bytes(
+            const TensorLayout& input, const TensorLayout& hx, const TensorLayout& cx,
+            const TensorLayout& flatten_weights, const TensorLayout& output,
+            const TensorLayout& hy, const TensorLayout& cy,
+            const TensorLayout& reserve_space) override;
+
+    //! arm_common only stores the output tensor; the other tensors are only
+    //! needed when computing gradients, so they are ignored here
+    size_t get_reserve_size_in_bytes(const TensorLayout&) override { return 1; }
+};
+
+} // namespace arm_common
+} // namespace megdnn
+
+// vim: syntax=cpp.doxygen
diff --git a/dnn/src/arm_common/lstm_cell/cell_kernel.cpp b/dnn/src/arm_common/lstm_cell/cell_kernel.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..a177e98f3ad81db819b97160c60575442e61855f
--- /dev/null
+++ b/dnn/src/arm_common/lstm_cell/cell_kernel.cpp
@@ -0,0 +1,273 @@
+/**
+ * \file dnn/src/arm_common/lstm_cell/cell_kernel.cpp
+ * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
+ *
+ * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+ * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ */
+
+#include "./cell_kernel.h"
+#include "src/arm_common/lstm_cell/opr_impl.h"
+#include "src/common/lstm_cell.h"
+#include "src/common/opr_delegate.h"
+#include "src/naive/handle.h"
+
+#include "src/arm_common/elemwise_helper/kimpl/sigmoid.h"
+#include "src/arm_common/elemwise_helper/kimpl/tanh.h"
+
+using namespace megdnn;
+using namespace arm_common;
+
+namespace {
+template <typename Op, bool bias>
+struct ElemwiseCompute {
+    static Op op;
+
+    static inline float32x4x2_t compute_8(
+            float* dst, float* tmp, float* ih, float* hh) {
+        float32x4_t dst0 = vld1q_f32(dst);
+        float32x4_t dst1 = vld1q_f32(dst + 4);
+        float32x4_t tmp0 = vld1q_f32(tmp);
+        float32x4_t tmp1 = vld1q_f32(tmp + 4);
+
+        auto mid0 = vaddq_f32(dst0, tmp0);
+        auto mid1 = vaddq_f32(dst1, tmp1);
+        float32x4_t out0, out1;
+        if (bias) {
+            float32x4_t ih0 = vld1q_f32(ih);
+            float32x4_t ih1 = vld1q_f32(ih + 4);
+            float32x4_t hh0 = vld1q_f32(hh);
+            float32x4_t hh1 = vld1q_f32(hh + 4);
+            auto midd0 = vaddq_f32(ih0, hh0);
+            auto midd1 = vaddq_f32(ih1, hh1);
+
+            out0 = vaddq_f32(mid0, midd0);
+            out1 = vaddq_f32(mid1, midd1);
+        } else {
+            out0 = mid0;
+            out1 = mid1;
+        }
+        return {{op(out0), op(out1)}};
+    }
+
+    static inline float32x4_t compute_4(float* dst, float* tmp, float* ih, float* hh) {
+        float32x4_t dst0 = vld1q_f32(dst);
+        float32x4_t tmp0 = vld1q_f32(tmp);
+
+        auto mid0 = vaddq_f32(dst0, tmp0);
+        float32x4_t out0;
+        if (bias) {
+            float32x4_t ih0 = vld1q_f32(ih);
+            float32x4_t hh0 = vld1q_f32(hh);
+            auto midd0 = vaddq_f32(ih0, hh0);
+
+            out0 = vaddq_f32(mid0, midd0);
+        } else {
+            out0 = mid0;
+        }
+        return op(out0);
+    }
+
+    static inline float compute_1(float* dst, float* tmp, float* ih, float* hh) {
+        float out;
+        if (bias) {
+            out = dst[0] + tmp[0] + ih[0] + hh[0];
+        } else {
+            out = dst[0] + tmp[0];
+        }
+        return op(out);
+    }
+};
+
+template <typename Op, bool bias>
+Op ElemwiseCompute<Op, bias>::op = Op();
+
+template <bool bias>
+void rnn_cell_elemwise_compute(
+        _megdnn_tensor_out dst, _megdnn_tensor_in tmp, _megdnn_tensor_in bias_ih,
+        _megdnn_tensor_in bias_hh, _megdnn_tensor_in cx, _megdnn_tensor_out h_new,
+        _megdnn_tensor_out c_new) {
+    size_t batch = dst.layout[0];
+    size_t batch_length = dst.layout.total_nr_elems() / batch;
+    size_t base_length = batch_length / 4;
+    float *ih_ptr_ = nullptr, *hh_ptr_ = nullptr;
+    float* dst_ptr_ = dst.ptr<float>();
+    float* tmp_ptr_ = tmp.ptr<float>();
+    if (bias) {
+        ih_ptr_ = bias_ih.ptr<float>();
+        hh_ptr_ = bias_hh.ptr<float>();
+    }
+    float* cx_ptr_ = cx.ptr<float>();
+    float* h_new_ptr_ = h_new.ptr<float>();
+    float* c_new_ptr_ = c_new.ptr<float>();
+
+    ElemwiseCompute<SigmoidOp<float>, bias> sigmoid_compute;
+    ElemwiseCompute<TanhOp<float>, bias> tanh_compute;
+    TanhOp<float> tanh_op;
+
+    for (size_t b = 0; b < batch; b++) {
+        float* dst_ptr = dst_ptr_ + b * batch_length;
+        float* tmp_ptr = tmp_ptr_ + b * batch_length;
+        float* ih_ptr = ih_ptr_;
+        float* hh_ptr = hh_ptr_;
+        float* cx_ptr = cx_ptr_ + b * base_length;
+        float* h_new_ptr = h_new_ptr_ + b * base_length;
+        float* c_new_ptr = c_new_ptr_ + b * base_length;
+        size_t index = 0;
+        for (; index + 7 < base_length; index += 8) {
+            auto out_i = sigmoid_compute.compute_8(dst_ptr, tmp_ptr, ih_ptr, hh_ptr);
+            auto out_f = sigmoid_compute.compute_8(
+                    dst_ptr + base_length, tmp_ptr + base_length, ih_ptr + base_length,
+                    hh_ptr + base_length);
+            auto out_g = tanh_compute.compute_8(
+                    dst_ptr + 2 * base_length, tmp_ptr + 2 * base_length,
+                    ih_ptr + 2 * base_length, hh_ptr + 2 * base_length);
+            auto out_o = sigmoid_compute.compute_8(
+                    dst_ptr + 3 * base_length, tmp_ptr + 3 * base_length,
+                    ih_ptr + 3 * base_length, hh_ptr + 3 * base_length);
+            float32x4_t cx_0 = vld1q_f32(cx_ptr);
+            float32x4_t cx_1 = vld1q_f32(cx_ptr + 4);
+
+            //! f * cx + i * g
+            auto c_new_0 = vaddq_f32(
+                    vmulq_f32(out_f.val[0], cx_0),
+                    vmulq_f32(out_i.val[0], out_g.val[0]));
+            auto c_new_1 = vaddq_f32(
+                    vmulq_f32(out_f.val[1], cx_1),
+                    vmulq_f32(out_i.val[1], out_g.val[1]));
+            vst1q_f32(c_new_ptr, c_new_0);
+            vst1q_f32(c_new_ptr + 4, c_new_1);
+
+            auto h_new_0 = vmulq_f32(tanh_op(c_new_0), out_o.val[0]);
+            auto h_new_1 = vmulq_f32(tanh_op(c_new_1), out_o.val[1]);
+
+            vst1q_f32(h_new_ptr, h_new_0);
+            vst1q_f32(h_new_ptr + 4, h_new_1);
+
+            dst_ptr += 8;
+            tmp_ptr += 8;
+            ih_ptr += 8;
+            hh_ptr += 8;
+            cx_ptr += 8;
+            c_new_ptr += 8;
+            h_new_ptr += 8;
+        }
+        for (; index + 3 < base_length; index += 4) {
+            auto out_i = sigmoid_compute.compute_4(dst_ptr, tmp_ptr, ih_ptr, hh_ptr);
+            auto out_f = sigmoid_compute.compute_4(
+                    dst_ptr + base_length, tmp_ptr + base_length, ih_ptr + base_length,
+                    hh_ptr + base_length);
+            auto out_g = tanh_compute.compute_4(
+                    dst_ptr + 2 * base_length, tmp_ptr + 2 * base_length,
+                    ih_ptr + 2 * base_length, hh_ptr + 2 * base_length);
+            auto out_o = sigmoid_compute.compute_4(
+                    dst_ptr + 3 * base_length, tmp_ptr + 3 * base_length,
+                    ih_ptr + 3 * base_length, hh_ptr + 3 * base_length);
+            float32x4_t cx_v = vld1q_f32(cx_ptr);
+
+            //! f * cx + i * g
+            auto c_new = vaddq_f32(vmulq_f32(out_f, cx_v), vmulq_f32(out_i, out_g));
+            vst1q_f32(c_new_ptr, c_new);
+
+            auto h_new = vmulq_f32(tanh_op(c_new), out_o);
+
+            vst1q_f32(h_new_ptr, h_new);
+
+            dst_ptr += 4;
+            tmp_ptr += 4;
+            ih_ptr += 4;
+            hh_ptr += 4;
+            cx_ptr += 4;
+            c_new_ptr += 4;
+            h_new_ptr += 4;
+        }
+        for (; index < base_length; index++) {
+            auto out_i = sigmoid_compute.compute_1(dst_ptr, tmp_ptr, ih_ptr, hh_ptr);
+            auto out_f = sigmoid_compute.compute_1(
+                    dst_ptr + base_length, tmp_ptr + base_length, ih_ptr + base_length,
+                    hh_ptr + base_length);
+            auto out_g = tanh_compute.compute_1(
+                    dst_ptr + 2 * base_length, tmp_ptr + 2 * base_length,
+                    ih_ptr + 2 * base_length, hh_ptr + 2 * base_length);
+            auto out_o = sigmoid_compute.compute_1(
+                    dst_ptr + 3 * base_length, tmp_ptr + 3 * base_length,
+                    ih_ptr + 3 * base_length, hh_ptr + 3 * base_length);
+            c_new_ptr[0] = out_f * cx_ptr[0] + out_i * out_g;
+            h_new_ptr[0] = tanh_op(c_new_ptr[0]) * out_o;
+
+            dst_ptr += 1;
+            tmp_ptr += 1;
+            ih_ptr += 1;
+            hh_ptr += 1;
+            cx_ptr += 1;
+            c_new_ptr += 1;
+            h_new_ptr += 1;
+        }
+    }
+}
+} // namespace
+
+void LstmCellCompute::run(
+        _megdnn_tensor_in input, _megdnn_tensor_in weight_ih, _megdnn_tensor_in bias_ih,
+        _megdnn_tensor_in hx, _megdnn_tensor_in weight_hh, _megdnn_tensor_in bias_hh,
+        _megdnn_tensor_in cx, _megdnn_tensor_out h_new, _megdnn_tensor_out c_new,
+        _megdnn_tensor_out gates, _megdnn_workspace workspace, Handle* handle) {
+    auto bundle = get_workspace_bundle(
+            input.layout, weight_ih.layout, bias_ih.layout, hx.layout, weight_hh.layout,
+            bias_hh.layout, cx.layout, h_new.layout, c_new.layout, gates.layout);
+    bundle.set(workspace.raw_ptr);
+    TensorND tmp{static_cast<void*>(bundle.get(0)), gates.layout};
+    auto matmul_workspace =
+            megdnn::Workspace{static_cast<dt_byte*>(bundle.get(1)), bundle.get_size(1)};
+    auto opr = handle->create_operator<MatrixMul>();
+    opr->param().transposeB = true;
+    //! the opr dispatches its compute tasks to the device, so record mode
+    //! performance will not be affected
+    opr->exec(input, weight_ih, tmp, matmul_workspace);
+    opr->exec(hx, weight_hh, gates, matmul_workspace);
+
+    //! the optimized post computation: nonlinear(tmp + dst + bias_ih + bias_hh)
+    if (bias_ih.layout.ndim != 0 && bias_hh.layout.ndim != 0) {
+        MEGDNN_DISPATCH_CPU_KERN(
+                static_cast<naive::HandleImpl*>(handle),
+                rnn_cell_elemwise_compute<true>(
+                        gates, tmp, bias_ih, bias_hh, cx, h_new, c_new));
+    } else {
+        megdnn_assert(bias_ih.layout.ndim == 0 && bias_hh.layout.ndim == 0);
+        MEGDNN_DISPATCH_CPU_KERN(
+                static_cast<naive::HandleImpl*>(handle),
+                rnn_cell_elemwise_compute<false>(
+                        gates, tmp, bias_ih, bias_hh, cx, h_new, c_new));
+    }
+}
+
+WorkspaceBundle LstmCellCompute::get_workspace_bundle(
+        const TensorLayout& input, const TensorLayout& weight_ih, const TensorLayout&,
+        const TensorLayout& hx, const TensorLayout& weight_hh, const TensorLayout&,
+        const TensorLayout&, const TensorLayout&, const TensorLayout&,
+        const TensorLayout& gates) {
+    auto opr = inplace_cpu_handle()->create_operator<MatrixMul>();
+    opr->param().transposeB = true;
+    size_t matmul_workspace = std::max(
+            opr->get_workspace_in_bytes(input, weight_ih, gates),
+            opr->get_workspace_in_bytes(hx, weight_hh, gates));
+    return WorkspaceBundle{nullptr, {gates.span().dist_byte(), matmul_workspace}};
+}
+
+bool LstmCellCompute::is_optimized(
+        const TensorLayout& input, const TensorLayout&, const TensorLayout& bias_ih,
+        const TensorLayout&, const TensorLayout&, const TensorLayout& bias_hh,
+        const TensorLayout&, const TensorLayout&, const TensorLayout&,
+        const TensorLayout& gates) {
+    if (input.dtype.enumv() == DTypeEnum::Float32 && gates[1] == bias_ih[1] &&
+        bias_ih[0] == 1 && bias_ih.eq_layout(bias_hh)) {
+        return true;
+    } else {
+        return false;
+    }
+}
+
+// vim: syntax=cpp.doxygen
diff --git a/dnn/src/arm_common/lstm_cell/cell_kernel.h b/dnn/src/arm_common/lstm_cell/cell_kernel.h
new file mode 100644
index 0000000000000000000000000000000000000000..ea2dca29ecc61405df4991b1cce3f603e11d79ed
--- /dev/null
+++ b/dnn/src/arm_common/lstm_cell/cell_kernel.h
@@ -0,0 +1,47 @@
+/**
+ * \file dnn/src/arm_common/lstm_cell/cell_kernel.h
+ * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
+ *
+ * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+ * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ */
+
+#pragma once
+
+#include "src/common/utils.h"
+#include "src/naive/handle.h"
+#include "src/naive/lstm_cell/opr_impl.h"
+
+namespace megdnn {
+namespace arm_common {
+
+struct LstmCellCompute {
+    static WorkspaceBundle get_workspace_bundle(
+            const TensorLayout& input, const TensorLayout& weight_ih,
+            const TensorLayout& bias_ih, const TensorLayout& hx,
+            const TensorLayout& weight_hh, const TensorLayout& bias_hh,
+            const TensorLayout& cx, const TensorLayout& h_new,
+            const TensorLayout& c_new, const TensorLayout& gates);
+
+    static bool is_optimized(
+            const TensorLayout& input, const TensorLayout& weight_ih,
+            const TensorLayout& bias_ih, const TensorLayout& hx,
+            const TensorLayout& weight_hh, const TensorLayout& bias_hh,
+            const TensorLayout& cx, const TensorLayout& h_new,
+            const TensorLayout& c_new, const TensorLayout& gates);
+
+    static void run(
+            _megdnn_tensor_in input, _megdnn_tensor_in weight_ih,
+            _megdnn_tensor_in bias_ih, _megdnn_tensor_in hx,
+            _megdnn_tensor_in weight_hh, _megdnn_tensor_in bias_hh,
+            _megdnn_tensor_in cx, _megdnn_tensor_out h_new, _megdnn_tensor_out c_new,
+            _megdnn_tensor_out gates, _megdnn_workspace workspace, Handle* handle);
+};
+
+} // namespace arm_common
+} // namespace megdnn
+
+// vim: syntax=cpp.doxygen
diff --git a/dnn/src/arm_common/lstm_cell/opr_impl.cpp b/dnn/src/arm_common/lstm_cell/opr_impl.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..29bcdc68bab0c8d76f67fa4457575327bbe3be96
--- /dev/null
+++ b/dnn/src/arm_common/lstm_cell/opr_impl.cpp
@@ -0,0 +1,71 @@
+/**
+ * \file dnn/src/arm_common/lstm_cell/opr_impl.cpp
+ * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
+ *
+ * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+ * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ */
+
+#include "src/arm_common/lstm_cell/opr_impl.h"
+#include "src/common/lstm_cell.h"
+#include "src/naive/handle.h"
+
+#include "./cell_kernel.h"
+
+#include "midout.h"
+
+MIDOUT_DECL(megdnn_arm_common_lstm_cell)
+
+using namespace megdnn;
+using namespace arm_common;
+
+void LSTMCellImpl::exec(
+        _megdnn_tensor_in input, _megdnn_tensor_in weight_ih, _megdnn_tensor_in bias_ih,
+        _megdnn_tensor_in hx, _megdnn_tensor_in weight_hh, _megdnn_tensor_in bias_hh,
+        _megdnn_tensor_in cx, _megdnn_tensor_out h_new, _megdnn_tensor_out c_new,
+        _megdnn_tensor_out gates, _megdnn_workspace workspace) {
+    //! only float32 and bias of shape {1, xx} take the optimized path
+    MIDOUT_BEGIN(megdnn_arm_common_lstm_cell, midout_iv(0)) {
+        if (!LstmCellCompute::is_optimized(
+                    input.layout, weight_ih.layout, bias_ih.layout, hx.layout,
+                    weight_hh.layout, bias_hh.layout, cx.layout, h_new.layout,
+                    c_new.layout, gates.layout)) {
+            naive::LSTMCellImpl::exec(
+                    input, weight_ih, bias_ih, hx, weight_hh, bias_hh, cx, h_new, c_new,
+                    gates, workspace);
+        } else {
+            LstmCellCompute::run(
+                    input, weight_ih, bias_ih, hx, weight_hh, bias_hh, cx, h_new, c_new,
+                    gates, workspace, handle());
+        }
+    }
+    MIDOUT_END();
+}
+
+size_t LSTMCellImpl::get_workspace_in_bytes(
+        const TensorLayout& input, const TensorLayout& weight_ih,
+        const TensorLayout& bias_ih, const TensorLayout& hx,
+        const TensorLayout& weight_hh, const TensorLayout& bias_hh,
+        const TensorLayout& cx, const TensorLayout& h_new, const TensorLayout& c_new,
+        const TensorLayout& gates) {
+    MIDOUT_BEGIN(megdnn_arm_common_lstm_cell, midout_iv(1)) {
+        if (!LstmCellCompute::is_optimized(
+                    input, weight_ih, bias_ih, hx, weight_hh, bias_hh, cx, h_new, c_new,
+                    gates)) {
+            return naive::LSTMCellImpl::get_workspace_in_bytes(
+                    input, weight_ih, bias_ih, hx, weight_hh, bias_hh, cx, h_new, c_new,
+                    gates);
+        } else {
+            return LstmCellCompute::get_workspace_bundle(
+                           input, weight_ih, bias_ih, hx, weight_hh, bias_hh, cx, h_new,
+                           c_new, gates)
+                    .total_size_in_bytes();
+        }
+    }
+    MIDOUT_END();
+}
+
+// vim: syntax=cpp.doxygen
diff --git a/dnn/src/arm_common/lstm_cell/opr_impl.h b/dnn/src/arm_common/lstm_cell/opr_impl.h
new file mode 100644
index 0000000000000000000000000000000000000000..fbc3d693160074c0ec6c24d70fa35ecf4cbd4b6d
--- /dev/null
+++ b/dnn/src/arm_common/lstm_cell/opr_impl.h
@@ -0,0 +1,41 @@
+/**
+ * \file dnn/src/arm_common/lstm_cell/opr_impl.h
+ * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
+ *
+ * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+ * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ */
+
+#pragma once
+
+#include "src/common/utils.h"
+#include "src/naive/lstm_cell/opr_impl.h"
+
+namespace megdnn {
+namespace arm_common {
+
+class LSTMCellImpl : public naive::LSTMCellImpl {
+public:
+    using naive::LSTMCellImpl::LSTMCellImpl;
+    void exec(
+            _megdnn_tensor_in input, _megdnn_tensor_in weight_ih,
+            _megdnn_tensor_in bias_ih, _megdnn_tensor_in hx,
+            _megdnn_tensor_in weight_hh, _megdnn_tensor_in bias_hh,
+            _megdnn_tensor_in cx, _megdnn_tensor_out h_new, _megdnn_tensor_out c_new,
+            _megdnn_tensor_out gates, _megdnn_workspace workspace) override;
+
+    size_t get_workspace_in_bytes(
+            const TensorLayout& input, const TensorLayout& weight_ih,
+            const TensorLayout& bias_ih, const TensorLayout& hx,
+            const TensorLayout& weight_hh, const TensorLayout& bias_hh,
+            const TensorLayout& cx, const TensorLayout& h_new,
+            const TensorLayout& c_new, const TensorLayout& gates) override;
+};
+
+} // namespace arm_common
+} // namespace megdnn
+
+// vim: syntax=cpp.doxygen
diff --git a/dnn/src/arm_common/rnn_cell/opr_impl.cpp b/dnn/src/arm_common/rnn_cell/opr_impl.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..74c4ea849c7fb554b88e9b9a7d9421c538aaab10
--- /dev/null
+++ b/dnn/src/arm_common/rnn_cell/opr_impl.cpp
@@ -0,0 +1,218 @@
+/**
+ * \file dnn/src/arm_common/rnn_cell/opr_impl.cpp
+ * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
+ *
+ * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+ * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ */
+
+#include "src/arm_common/rnn_cell/opr_impl.h"
+#include "src/common/utils.h"
+#include "src/naive/handle.h"
+
+#include "src/arm_common/elemwise_helper/kimpl/none.h"
+#include "src/arm_common/elemwise_helper/kimpl/relu.h"
+#include "src/arm_common/elemwise_helper/kimpl/tanh.h"
+
+#include "midout.h"
+
+MIDOUT_DECL(megdnn_arm_common_rnn_cell)
+
+using namespace megdnn;
+using namespace arm_common;
+
+namespace {
+ElemwiseForward* get_elemwise_opr() {
+    static CpuOprDelegationStorage<1> storage;
+    return storage.get<ElemwiseForward>();
+}
+
+template <typename Op>
+void elemwise_compute(
+        float* dst_ptr, float* tmp_ptr, float* ih_ptr, float* hh_ptr, size_t batch,
+        size_t length) {
+    const constexpr size_t SIMD_8 = 8;
+    const constexpr size_t SIMD_4 = 4;
+    Op op;
+    for (size_t b = 0; b < batch; b++) {
+        float* dst = dst_ptr + b * length;
+        float* tmp = tmp_ptr + b * length;
+        float* ih = ih_ptr;
+        float* hh = hh_ptr;
+        size_t index = 0;
+        for (; index + SIMD_8 - 1 < length; index += SIMD_8) {
+            float32x4_t dst0 = vld1q_f32(dst);
+            float32x4_t dst1 = vld1q_f32(dst + 4);
+
+            float32x4_t tmp0 = vld1q_f32(tmp);
+            float32x4_t tmp1 = vld1q_f32(tmp + 4);
+
+            float32x4_t ih0 = vld1q_f32(ih);
+            float32x4_t ih1 = vld1q_f32(ih + 4);
+
+            float32x4_t hh0 = vld1q_f32(hh);
+            float32x4_t hh1 = vld1q_f32(hh + 4);
+
+            auto mid0 = vaddq_f32(dst0, tmp0);
+            auto mid1 = vaddq_f32(dst1, tmp1);
+            auto midd0 = vaddq_f32(ih0, hh0);
+            auto midd1 = vaddq_f32(ih1, hh1);
+            auto out0 = vaddq_f32(mid0, midd0);
+            auto out1 = vaddq_f32(mid1, midd1);
+
+            vst1q_f32(dst, op(out0));
+            vst1q_f32(dst + 4, op(out1));
+
+            dst += SIMD_8;
+            tmp += SIMD_8;
+            ih += SIMD_8;
+            hh += SIMD_8;
+        }
+        for (; index + SIMD_4 - 1 < length; index += SIMD_4) {
+            float32x4_t dst0 = vld1q_f32(dst);
+            float32x4_t tmp0 = vld1q_f32(tmp);
+            float32x4_t ih0 = vld1q_f32(ih);
+            float32x4_t hh0 = vld1q_f32(hh);
+
+            auto mid0 = vaddq_f32(dst0, tmp0);
+            auto midd0 = vaddq_f32(ih0, hh0);
+            auto out0 = vaddq_f32(mid0, midd0);
+
+            vst1q_f32(dst, op(out0));
+
+            dst += SIMD_4;
+            tmp += SIMD_4;
+            ih += SIMD_4;
+            hh += SIMD_4;
+        }
+        for (; index < length; index++) {
+            auto out = dst[0] + tmp[0] + ih[0] + hh[0];
+            dst[0] = op(out);
+            dst++;
+            tmp++;
+            ih++;
+            hh++;
+        }
+    }
+}
+
+void rnn_cell_post_compute(
+        _megdnn_tensor_out dst, _megdnn_tensor_in tmp, _megdnn_tensor_in bias_ih,
+        _megdnn_tensor_in bias_hh, param::RNNCell::NonlineMode nonline_mode,
+        Handle* handle) {
+    using NonlineMode = param::RNNCell::NonlineMode;
+    megdnn_assert(
+            nonline_mode == NonlineMode::RELU || nonline_mode == NonlineMode::TANH ||
+                    nonline_mode == NonlineMode::IDENTITY,
+            "arm_common only supports the nonlinear modes RELU, TANH and IDENTITY.");
+    if (dst.layout.dtype.enumv() == DTypeEnum::Float32 &&
+        dst.layout[1] == bias_ih.layout[1] && bias_ih.layout[0] == 1 &&
+        bias_ih.layout.eq_layout(bias_hh.layout)) {
+        auto run = [=]() {
+            size_t batch = dst.layout[0];
+            size_t length = bias_ih.layout.total_nr_elems();
+            float* dst_ptr = dst.ptr<float>();
+            float* tmp_ptr = tmp.ptr<float>();
+            float* ih_ptr = bias_ih.ptr<float>();
+            float* hh_ptr = bias_hh.ptr<float>();
+            if (nonline_mode == NonlineMode::RELU) {
+                elemwise_compute<ReluOp<float>>(
+                        dst_ptr, tmp_ptr, ih_ptr, hh_ptr, batch, length);
+            } else if (nonline_mode == NonlineMode::TANH) {
+                elemwise_compute<TanhOp<float>>(
+                        dst_ptr, tmp_ptr, ih_ptr, hh_ptr, batch, length);
+            } else {
+                elemwise_compute<NoneOp<float>>(
+                        dst_ptr, tmp_ptr, ih_ptr, hh_ptr, batch, length);
+            }
+        };
+        MEGDNN_DISPATCH_CPU_KERN(static_cast<naive::HandleImpl*>(handle), run());
+    } else {
+        //! this opr must be created by an inplace handle
+        auto elem_opr = get_elemwise_opr();
+        auto run = [=]() {
+            elem_opr->param().mode = Elemwise::Param::Mode::ADD;
+            elem_opr->exec({dst, tmp}, dst);
+            elem_opr->exec({dst, bias_ih}, dst);
+            elem_opr->exec({dst, bias_hh}, dst);
+            // activation
+            switch (nonline_mode) {
+#define cb(_mode)                                                  \
+    case NonlineMode::_mode: {                                     \
+        elem_opr->param().mode = Elemwise::Param::Mode::_mode;     \
+        elem_opr->exec({dst}, dst);                                \
+        break;                                                     \
+    }
+                cb(RELU);
+                cb(TANH);
+#undef cb
+                case NonlineMode::IDENTITY:
+                    break;
+                default:
+                    megdnn_throw("unsupported nonlinear mode.");
+            }
+        };
+        MEGDNN_DISPATCH_CPU_KERN(static_cast<naive::HandleImpl*>(handle), run());
+    }
+}
+} // namespace
+
+WorkspaceBundle RNNCellImpl::get_workspace_bundle(
+        const TensorLayout& input, const TensorLayout& weight_ih, const TensorLayout&,
+        const TensorLayout& hx, const TensorLayout& weight_hh, const TensorLayout&,
+        const TensorLayout& dst) {
+    MIDOUT_BEGIN(megdnn_arm_common_rnn_cell, midout_iv(0)) {
+        auto opr = handle()->create_operator<MatrixMul>();
+        opr->param().transposeB = true;
+        auto matmul_workspace = std::max(
+                opr->get_workspace_in_bytes(input, weight_ih, dst),
+                opr->get_workspace_in_bytes(hx, weight_hh, dst));
+        auto tmp_workspace = dst.span().dist_byte();
+        return WorkspaceBundle{nullptr, {tmp_workspace, matmul_workspace}};
+    }
+    MIDOUT_END();
+}
+
+size_t RNNCellImpl::get_workspace_in_bytes(
+        const TensorLayout& input, const TensorLayout& weight_ih,
+        const TensorLayout& bias_ih, const TensorLayout& hx,
+        const TensorLayout& weight_hh, const TensorLayout& bias_hh,
+        const TensorLayout& dst) {
+    return get_workspace_bundle(input, weight_ih, bias_ih, hx, weight_hh, bias_hh, dst)
+            .total_size_in_bytes();
+}
+
+void RNNCellImpl::exec(
+        _megdnn_tensor_in input, _megdnn_tensor_in weight_ih, _megdnn_tensor_in bias_ih,
+        _megdnn_tensor_in hx, _megdnn_tensor_in weight_hh, _megdnn_tensor_in bias_hh,
+        _megdnn_tensor_out dst, _megdnn_workspace workspace) {
+    MIDOUT_BEGIN(megdnn_arm_common_rnn_cell, midout_iv(1)) {
+        auto bundle = get_workspace_bundle(
+                input.layout, weight_ih.layout, bias_ih.layout, hx.layout,
+                weight_hh.layout, bias_hh.layout, dst.layout);
+        bundle.set(workspace.raw_ptr);
+        auto nonline_mode = param().nonlineMode;
+
+        TensorND tmp{static_cast<void*>(bundle.get(0)), dst.layout};
+        auto new_workspace =
+                Workspace{static_cast<dt_byte*>(bundle.get(1)), bundle.get_size(1)};
+        //! this opr can not be created by the inplace handle
+        auto opr = handle()->create_operator<MatrixMul>();
+
+        opr->param().transposeB = true;
+        //! the opr dispatches its compute tasks to the device, so record mode
+        //! performance will not be affected
+        opr->exec(input, weight_ih, tmp, new_workspace);
+        opr->exec(hx, weight_hh, dst, new_workspace);
+
+        //! the optimized post computation: nonlinear(tmp + dst + bias_ih + bias_hh)
+        rnn_cell_post_compute(dst, tmp, bias_ih, bias_hh, nonline_mode, handle());
+    }
+    MIDOUT_END();
+}
+
+// vim: syntax=cpp.doxygen
diff --git a/dnn/src/arm_common/rnn_cell/opr_impl.h b/dnn/src/arm_common/rnn_cell/opr_impl.h
new file mode 100644
index 0000000000000000000000000000000000000000..4fd147ffb847fa04e13e16afe1a5eebbd6f0e3d4
--- /dev/null
+++ b/dnn/src/arm_common/rnn_cell/opr_impl.h
@@ -0,0 +1,43 @@
+/**
+ * \file dnn/src/arm_common/rnn_cell/opr_impl.h
+ * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
+ *
+ * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+ * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ */
+#pragma once
+#include "src/common/opr_delegate.h"
+#include "src/naive/rnn_cell/opr_impl.h"
+
+namespace megdnn {
+namespace arm_common {
+
+class RNNCellImpl : public naive::RNNCellImpl {
+public:
+    using naive::RNNCellImpl::RNNCellImpl;
+    void exec(
+            _megdnn_tensor_in input, _megdnn_tensor_in weight_ih,
+            _megdnn_tensor_in bias_ih, _megdnn_tensor_in hx,
+            _megdnn_tensor_in weight_hh, _megdnn_tensor_in bias_hh,
+            _megdnn_tensor_out dst, _megdnn_workspace workspace) override;
+    size_t get_workspace_in_bytes(
+            const TensorLayout& input, const TensorLayout& weight_ih,
+            const TensorLayout& bias_ih, const TensorLayout& hx,
+            const TensorLayout& weight_hh, const TensorLayout& bias_hh,
+            const TensorLayout& dst) override;
+
+private:
+    WorkspaceBundle get_workspace_bundle(
+            const TensorLayout& input, const TensorLayout& weight_ih,
+            const TensorLayout& bias_ih, const TensorLayout& hx,
+            const TensorLayout& weight_hh, const TensorLayout& bias_hh,
+            const TensorLayout& dst);
+};
+
+} // namespace arm_common
+} // namespace megdnn
+
+// vim: syntax=cpp.doxygen
diff --git a/dnn/src/naive/rnn_cell/opr_impl.h b/dnn/src/naive/rnn_cell/opr_impl.h
index 59e34cc341ffe0ea6cbc745603a39074025ec349..f79e8ef31cc5c9d6f6abc3cb51ad5fd9a3b09e9d 100644
--- a/dnn/src/naive/rnn_cell/opr_impl.h
+++ b/dnn/src/naive/rnn_cell/opr_impl.h
@@ -31,4 +31,6 @@ public:
 };
 
 } // namespace naive
-} // namespace megdnn
\ No newline at end of file
+} // namespace megdnn
+
+// vim: syntax=cpp.doxygen
diff --git a/dnn/test/arm_common/lstm.cpp b/dnn/test/arm_common/lstm.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..ddc45a104746434d466fe8328c4d535aab0a4d72
--- /dev/null
+++ b/dnn/test/arm_common/lstm.cpp
@@ -0,0 +1,225 @@
+/**
+ * \file dnn/test/arm_common/lstm.cpp
+ * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
+ *
+ * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ */
+#include "test/arm_common/fixture.h"
+
+#include "megdnn/oprs.h"
+#include "megdnn/oprs/general.h"
+#include "test/common/benchmarker.h"
+#include "test/common/checker.h"
+#include "test/common/task_record_check.h"
+
+using namespace megdnn;
+using namespace test;
+
+namespace {
+//! in arm_common the reserve tensor is not used
+void output_canonizer(const CheckerHelper::TensorValueArray& arr) {
+    const TensorND& reserve = arr.back();
+    TensorND& modif_reserve = const_cast<TensorND&>(reserve);
+    modif_reserve.layout = TensorLayout();
+}
+} // namespace
+
+TEST_F(ARM_COMMON, LSTMCell) {
+    Checker<LSTMCell> checker(handle());
+    checker.set_output_canonizer(output_canonizer);
+    checker.exec(
+            {{1, 10},
+             {40, 10},
+             {1, 40},
+             {1, 10},
+             {40, 10},
+             {1, 40},
+             {1, 10},
+             {},
+             {},
+             {}});
+    for (size_t batch : {2})
+        for (size_t n : {3, 4, 5, 23, 100})
+            for (size_t out : {3, 6, 25, 100}) {
+                checker.exec(
+                        {{batch, n},
+                         {out * 4, n},
+                         {1, out * 4},
+                         {batch, out},
+                         {out * 4, out},
+                         {1, out * 4},
+                         {batch, out},
+                         {},
+                         {},
+                         {}});
+                checker.exec(
+                        {{batch, n},
+                         {out * 4, n},
+                         {batch, out * 4},
+                         {batch, out},
+                         {out * 4, out},
+                         {batch, out * 4},
+                         {batch, out},
+                         {},
+                         {},
+                         {}});
+            }
+}
+
+TEST_F(ARM_COMMON, LSTMCellRecord) {
+    TaskRecordChecker<LSTMCell> checker(0);
+    checker.exec(
+            {{1, 10},
+             {40, 10},
+             {1, 40},
+             {1, 10},
+             {40, 10},
+             {1, 40},
+             {1, 10},
+             {},
+             {},
+             {}});
+}
+
+namespace {
+void test_lstm(bool bias, bool direction, Handle* handle) {
+    Checker<LSTM> checker(handle, true);
+    //! lstm uses tanh and exp; over many iterations the accumulated error can
+    //! exceed 1e-3
+    checker.set_epsilon(1e-2);
+    checker.set_output_canonizer(output_canonizer);
+    for (size_t input_size : {2, 8, 13})
+        for (size_t hidden_size : {1, 4, 17}) {
+            size_t dir_size = direction == false ? 1 : 2;
+            LSTM::Param param;
+            param.bidirectional = direction;
+            size_t gate_hidden_size = 4 * hidden_size;
+            param.bias = bias;
+            param.hidden_size = hidden_size;
+            for (size_t seq_len : {1, 3, 5})
+                for (size_t batch_size : {1, 2, 4})
+                    for (size_t number_layer : {1, 2, 4, 5, 8}) {
+                        size_t flatten_size = 0;
+                        for (size_t layer = 0; layer < number_layer; layer++) {
+                            for (size_t dir = 0; dir < dir_size; dir++) {
input_size + : dir_size * hidden_size; // ih + flatten_size += hidden_size; // hh + } + } + if (bias) { + flatten_size += 2 * dir_size * number_layer; + } + param.num_layers = number_layer; + checker.set_param(param).exec( + {{seq_len, batch_size, input_size}, // input + {number_layer * dir_size, batch_size, + hidden_size}, // hx + {number_layer * dir_size, batch_size, + hidden_size}, // hy + {gate_hidden_size, flatten_size}, // flat weight + {}, + {}, + {}, + {}}); + } + } +} + +} // namespace + +TEST_F(ARM_COMMON, LSTM_FORWARD_NO_BIAS_NO_DIRCTION) { + test_lstm(false, false, handle()); +} + +TEST_F(ARM_COMMON, LSTM_FORWARD_BIAS_NO_DIRCTION) { + test_lstm(true, false, handle()); +} + +TEST_F(ARM_COMMON, LSTM_FORWARD_DIRECTION_NO_BIAS) { + test_lstm(false, true, handle()); +} + +TEST_F(ARM_COMMON, LSTM_FORWARD_DIRECTION_BIAS) { + test_lstm(true, true, handle()); +} + +TEST_F(ARM_COMMON, LSTM_FORWARD_RECORD) { + TaskRecordChecker checker(0); + size_t input_size = 2; + size_t hidden_size = 2; + size_t gate_hidden_size = 4 * hidden_size; + LSTM::Param param; + param.bidirectional = false; + param.bias = false; + param.hidden_size = hidden_size; + + // checker.set_output_canonizer(output_canonizer); + for (size_t seq_len : {1, 3, 5}) + for (size_t batch_size : {1, 2, 4}) + for (size_t number_layer : {1, 2, 4, 5, 8}) { + param.num_layers = number_layer; + checker.set_param(param).exec( + {{seq_len, batch_size, input_size}, // input + {number_layer, batch_size, hidden_size}, // hx + {number_layer, batch_size, hidden_size}, // hy + {number_layer, gate_hidden_size, + input_size + hidden_size}, // flat weight + {}, + {}, + {}, + {}}); + } +} + +#if MEGDNN_WITH_BENCHMARK + +TEST_F(ARM_COMMON, BENCHMARK_LSTM_FORWARD) { + Benchmarker optimized_bench(handle()); + constexpr size_t RUNS = 20; + auto run = [&](size_t hidden_size, size_t input_size) { + optimized_bench.set_times(20).set_display(true); + size_t gate_hidden_size = 4 * hidden_size; + for (bool direction : {false, true}) { + LSTM::Param param; + param.hidden_size = hidden_size; + param.bidirectional = direction; + param.bias = false; + size_t dir_size = direction == false ? 1 : 2; + for (size_t seq_len : {1, 5, 8}) + for (size_t batch_size : {1, 8, 16}) + for (size_t number_layer : {1}) { + param.num_layers = number_layer; + size_t flatten_size = 0; + for (size_t layer = 0; layer < number_layer; layer++) { + for (size_t dir = 0; dir < dir_size; dir++) { + flatten_size += layer == 0 + ? input_size + : dir_size * hidden_size; // ih + flatten_size += hidden_size; // hh + } + } + optimized_bench.set_param(param).exec( + {{seq_len, batch_size, input_size}, // input + {number_layer * dir_size, batch_size, + hidden_size}, // hx + {number_layer * dir_size, batch_size, + hidden_size}, // hy + {gate_hidden_size, flatten_size}, // flat weight + {}, + {}, + {}, + {}}); + } + } + }; + run(512, 256); +} + +#endif +// vim: syntax=cpp.doxygen diff --git a/dnn/test/arm_common/rnn.cpp b/dnn/test/arm_common/rnn.cpp new file mode 100644 index 0000000000000000000000000000000000000000..ee935ca9fea94ecce2a3b5b543be8ccc08ddf779 --- /dev/null +++ b/dnn/test/arm_common/rnn.cpp @@ -0,0 +1,67 @@ +/** + * \file dnn/test/arm_common/rnn.cpp + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. 
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ */
+#include "test/arm_common/fixture.h"
+
+#include "megdnn/oprs.h"
+#include "test/common/benchmarker.h"
+#include "test/common/checker.h"
+#include "test/common/task_record_check.h"
+
+using namespace megdnn;
+using namespace test;
+
+TEST_F(ARM_COMMON, RNNCell) {
+    Checker<RNNCell> checker(handle());
+    using NonlineMode = param::RNNCell::NonlineMode;
+    param::RNNCell param;
+    for (auto mode : {NonlineMode::IDENTITY, NonlineMode::RELU, NonlineMode::TANH})
+        for (size_t batch : {1, 4})
+            for (size_t n : {3, 4, 5, 23, 100})
+                for (size_t h : {5, 23, 100})
+                    for (size_t out : {3, 6, 25, 100}) {
+                        param.nonlineMode = mode;
+                        checker.set_param(param);
+                        checker.exec(
+                                {{batch, n},
+                                 {out, n},
+                                 {1, out},
+                                 {batch, h},
+                                 {out, h},
+                                 {1, out},
+                                 {}});
+                        checker.exec(
+                                {{batch, n},
+                                 {out, n},
+                                 {batch, out},
+                                 {batch, h},
+                                 {out, h},
+                                 {batch, out},
+                                 {}});
+                    }
+}
+
+TEST_F(ARM_COMMON, RNNCellRecord) {
+    TaskRecordChecker<RNNCell> checker(0);
+    using NonlineMode = param::RNNCell::NonlineMode;
+    param::RNNCell param;
+    for (auto mode : {NonlineMode::IDENTITY, NonlineMode::RELU, NonlineMode::TANH}) {
+        param.nonlineMode = mode;
+        checker.set_param(param);
+        checker.exec({{1, 100}, {10, 100}, {1, 10}, {1, 100}, {10, 100}, {1, 10}, {}});
+        checker.exec({{1, 34}, {15, 34}, {1, 15}, {1, 34}, {15, 34}, {1, 15}, {}});
+        checker.exec({{1, 73}, {25, 73}, {1, 25}, {1, 73}, {25, 73}, {1, 25}, {}});
+    }
+}
+
+#if MEGDNN_WITH_BENCHMARK
+
+#endif
+// vim: syntax=cpp.doxygen
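
For readers of this patch: the new single-register `TanhOp::operator()` in `tanh.h` evaluates tanh through the identity tanh(x) = 1 - 2 / (1 + e^(2x)), replacing the division with a `vrecpeq_f32` estimate refined by one `vrecpsq_f32` Newton-Raphson step. A minimal scalar sketch of the same arithmetic (standalone, not part of the patch; `tanh_like_kernel` is an illustrative name):

#include <cmath>
#include <cstdio>

// Scalar model of the NEON kernel: tanh(x) = 1 - 2 / (1 + exp(2x)).
// In the vector code, 1/v comes from vrecpeq_f32 followed by one
// Newton-Raphson refinement r' = r * (2 - v * r) via vrecpsq_f32.
static float tanh_like_kernel(float x) {
    float v = 1.f + std::exp(2.f * x);  // val1 after exp_ps and vaddq
    float r = 1.f / v;                  // stands in for the recpe/recps pair
    return 1.f - 2.f * r;               // == (exp(2x) - 1) / (exp(2x) + 1)
}

int main() {
    for (float x : {-2.f, -0.5f, 0.f, 0.5f, 2.f})
        std::printf("x=%+.2f  kernel=%+.6f  std::tanh=%+.6f\n", x,
                    tanh_like_kernel(x), std::tanh(x));
    return 0;
}

The Newton-Raphson step matters because `vrecpeq_f32` alone only gives roughly 8 bits of precision; one refinement brings the reciprocal close to full single precision.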
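Similarly, the fused elementwise stage in `cell_kernel.cpp` assumes the four gate slices are laid out contiguously as [i, f, g, o] along the gate dimension (each slice of size `base_length`) and applies the standard LSTM update. A scalar sketch of one element under that assumption (helper names are illustrative, not part of the patch):

#include <cmath>
#include <cstdio>

static float sigmoid(float v) { return 1.f / (1.f + std::exp(-v)); }

// One element of the update that rnn_cell_elemwise_compute vectorizes.
// i, f, g, o are the summed gate pre-activations for one hidden unit:
// x @ W_ih^T + h @ W_hh^T (+ bias_ih + bias_hh), sliced per gate.
static void lstm_update(
        float i, float f, float g, float o, float c_prev, float* c_new,
        float* h_new) {
    *c_new = sigmoid(f) * c_prev + sigmoid(i) * std::tanh(g);  // f * cx + i * g
    *h_new = sigmoid(o) * std::tanh(*c_new);                   // o * tanh(c')
}

int main() {
    float c = 0.f, h = 0.f;
    lstm_update(0.1f, 0.2f, 0.3f, 0.4f, c, &c, &h);
    std::printf("c_new=%f h_new=%f\n", c, h);
    return 0;
}

In the kernel, the sigmoid/tanh activations are applied inside `compute_8`/`compute_4`/`compute_1` (via the `op(...)` call), so `out_i`, `out_f`, `out_g` and `out_o` are already activated when the `f * cx + i * g` combination runs.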