Commit 39d77fb5 authored by Megvii Engine Team

feat(arm): add arm rnn_cell/lstm_cell/lstm optimized kernel

GitOrigin-RevId: b9bb7352bc1921c7bbc4168fb3cdea6c6bb955c5
Parent 3ddc32d3
@@ -30,18 +30,18 @@ struct TanhOpBase : UnaryOpBase<src_ctype, dst_ctype> {
template <typename src_ctype, typename dst_type = src_ctype>
struct TanhOp;
#define OP(_ctype, _neon_type, _func_suffix, _simd_width) \
#define OP(_ctype, _neon_type, _neon_type2, _func_suffix, _simd_width) \
template <> \
struct TanhOp<_ctype> : TanhOpBase<_ctype> { \
using TanhOpBase::TanhOpBase; \
using TanhOpBase::operator(); \
constexpr static size_t SIMD_WIDTH = _simd_width; \
void operator()(const _neon_type& src, _ctype* dst) const { \
void operator()(const _neon_type2& src, _ctype* dst) const { \
auto vitem = operator()(src); \
vst1q_##_func_suffix(dst, vitem.val[0]); \
vst1q_##_func_suffix(dst + SIMD_WIDTH, vitem.val[1]); \
} \
_neon_type operator()(const _neon_type& src) const { \
_neon_type2 operator()(const _neon_type2& src) const { \
auto one_val = vdupq_n_##_func_suffix(1.f); \
auto two_val = vdupq_n_##_func_suffix(2.f); \
auto val1 = src.val[0]; \
@@ -62,10 +62,23 @@ struct TanhOp;
val2 = vsubq_##_func_suffix(one_val, val2); \
return {{val1, val2}}; \
} \
_neon_type operator()(const _neon_type& src) const { \
auto one_val = vdupq_n_##_func_suffix(1.f); \
auto two_val = vdupq_n_##_func_suffix(2.f); \
auto val1 = src; \
val1 = vmulq_##_func_suffix(two_val, val1); \
val1 = exp_ps_##_func_suffix(val1); \
val1 = vaddq_##_func_suffix(one_val, val1); \
auto rval1 = vrecpeq_##_func_suffix(val1); \
rval1 = vmulq_##_func_suffix(vrecpsq_##_func_suffix(val1, rval1), rval1); \
val1 = vmulq_##_func_suffix(two_val, rval1); \
val1 = vsubq_##_func_suffix(one_val, val1); \
return val1; \
} \
};
OP(dt_float32, float32x4x2_t, f32, 4)
OP(dt_float32, float32x4_t, float32x4x2_t, f32, 4)
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
OP(__fp16, float16x8x2_t, f16, 8)
OP(__fp16, float16x8_t, float16x8x2_t, f16, 8)
#endif
#undef OP
......
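For reference, the new single-vector overload above evaluates tanh through the identity tanh(x) = 1 - 2/(exp(2x) + 1), refining the NEON reciprocal estimate (vrecpeq) with one vrecpsq Newton-Raphson step. A minimal scalar sketch of the same identity (illustrative only; tanh_ref is not part of this patch):

#include <cmath>

// Scalar equivalent of the vectorized kernel above: tanh(x) = 1 - 2/(exp(2x) + 1).
static inline float tanh_ref(float x) {
    float denom = std::exp(2.f * x) + 1.f;  // exp(2x) + 1
    return 1.f - 2.f / denom;               // equals std::tanh(x)
}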
@@ -19,9 +19,12 @@
#include "src/arm_common/elemwise/opr_impl.h"
#include "src/arm_common/elemwise_multi_type/opr_impl.h"
#include "src/arm_common/local/opr_impl.h"
#include "src/arm_common/lstm/opr_impl.h"
#include "src/arm_common/lstm_cell/opr_impl.h"
#include "src/arm_common/pooling/opr_impl.h"
#include "src/arm_common/reduce/opr_impl.h"
#include "src/arm_common/resize/opr_impl.h"
#include "src/arm_common/rnn_cell/opr_impl.h"
#include "src/arm_common/separable_conv/opr_impl.h"
#include "src/arm_common/separable_filter/opr_impl.h"
#include "src/arm_common/type_cvt/opr_impl.h"
@@ -50,6 +53,9 @@ MEGDNN_SPECIALIZE_CREATE_OPERATOR(TypeCvt)
MEGDNN_SPECIALIZE_CREATE_OPERATOR(Reduce)
MEGDNN_SPECIALIZE_CREATE_OPERATOR(ConvBias)
MEGDNN_SPECIALIZE_CREATE_OPERATOR(ConvolutionBackwardData)
MEGDNN_SPECIALIZE_CREATE_OPERATOR(RNNCell)
MEGDNN_SPECIALIZE_CREATE_OPERATOR(LSTMCell)
MEGDNN_SPECIALIZE_CREATE_OPERATOR(LSTM)
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wpragmas"
......
/**
* \file dnn/src/arm_common/lstm/lstm_utils.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#include "./lstm_utils.h"
#include "src/arm_common/lstm/opr_impl.h"
#include "src/arm_common/lstm_cell/cell_kernel.h"
#include "src/arm_common/lstm_cell/opr_impl.h"
#include "src/naive/handle.h"
using namespace megdnn;
using namespace arm_common;
LstmCellWeight::LstmCellWeight(
RefPtr weight_ptr, size_t hidden_size, size_t input_size, bool has_bias,
DType dtype) {
// weight_ih: [gate_hidden_size, input_size]
// weight_hh: [gate_hidden_size, hidden_size]
// bias_ih: [gate_hidden_size]
// bias_hh: [gate_hidden_size]
size_t gate_hidden_size = 4 * hidden_size;
TensorLayout weight_ih_layout{{gate_hidden_size, input_size}, dtype};
TensorLayout weight_hh_layout{{gate_hidden_size, hidden_size}, dtype};
TensorLayout bias_layout{{gate_hidden_size}, dtype};
m_weight_size = 0;
m_weight_ih = TensorND(weight_ih_layout, weight_ptr);
m_weight_size += weight_ih_layout.span().dist_byte();
weight_ptr += weight_ih_layout.span().dist_byte();
m_weight_hh = TensorND(weight_hh_layout, weight_ptr);
m_weight_size += weight_hh_layout.span().dist_byte();
weight_ptr += weight_hh_layout.span().dist_byte();
if (has_bias) {
m_bias_ih = TensorND(bias_layout, weight_ptr);
m_weight_size += bias_layout.span().dist_byte();
weight_ptr += bias_layout.span().dist_byte();
m_bias_hh = TensorND(bias_layout, weight_ptr);
m_weight_size += bias_layout.span().dist_byte();
}
}
LstmStates::LstmStates(
const SmallVector<RefPtr> ptr, size_t hidden_size, size_t batch_size,
DType dtype) {
auto& h_ptr = ptr[0];
auto& c_ptr = ptr[1];
TensorLayout layout{{batch_size, hidden_size}, dtype};
m_h = TensorND(layout, h_ptr);
m_c = TensorND(layout, c_ptr);
m_memory_size = layout.span().dist_byte();
}
TensorNDArray megdnn::arm_common::split_tensor(
_megdnn_tensor_in tensor, size_t nr_tensor, const TensorLayout& layout) {
megdnn_assert(
tensor.layout.span().dist_byte() == nr_tensor * layout.span().dist_byte());
TensorNDArray tensors;
auto ptr = tensor.get_ref_ptr();
for (size_t i = 0; i < nr_tensor; i++) {
tensors.push_back(TensorND(layout, ptr));
ptr += layout.span().dist_byte();
}
return tensors;
}
namespace megdnn {
namespace arm_common {
template <>
void cell_opr_compute<LSTMCell, LstmStates>(
_megdnn_tensor_in input, _megdnn_tensor_in weight_ih,
_megdnn_tensor_in weight_hh, _megdnn_tensor_in bias_ih,
_megdnn_tensor_in bias_hh, const LstmStates& state_in, LstmStates& state_out,
Workspace cell_workspace, Handle* handle) {
auto opr = handle->create_operator<LSTMCellForward>();
TensorLayout gates, h_new, c_new;
opr->deduce_layout(
input.layout, weight_ih.layout, bias_ih.layout, state_in.m_h.layout,
weight_hh.layout, bias_hh.layout, state_in.m_c.layout, h_new, c_new, gates);
auto workspace_bundle = LstmCellCompute::get_workspace_bundle(
input.layout, weight_ih.layout, bias_ih.layout, state_in.m_h.layout,
weight_hh.layout, bias_hh.layout, state_in.m_c.layout, h_new, c_new, gates);
workspace_bundle.set(cell_workspace.raw_ptr);
TensorND gates_tensor{workspace_bundle.get(0), gates};
_megdnn_workspace new_workspace = {
static_cast<dt_byte*>(workspace_bundle.get(1)),
workspace_bundle.get_size(1)};
LstmCellCompute::run(
input, weight_ih, bias_ih, state_in.m_h, weight_hh, bias_hh, state_in.m_c,
state_out.m_h, state_out.m_c, gates_tensor, new_workspace, handle);
}
} // namespace arm_common
} // namespace megdnn
// vim: syntax=cpp.doxygen
/**
* \file dnn/src/arm_common/lstm/lstm_utils.h
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#pragma once
#include "src/arm_common/lstm_cell/cell_kernel.h"
#include "src/common/opr_delegate.h"
#include "src/common/utils.h"
#include "src/naive/handle.h"
#include "src/naive/lstm/opr_impl.h"
namespace megdnn {
namespace arm_common {
template <class CellOp, class States>
void cell_opr_compute(
_megdnn_tensor_in step_input, _megdnn_tensor_in weight_ih,
_megdnn_tensor_in weight_hh, _megdnn_tensor_in bias_ih,
_megdnn_tensor_in bias_hh, const States& state_in, States& state_out,
Workspace cell_workspace, Handle* handle);
struct LstmCellWeight {
size_t m_weight_size = 0;
TensorND m_weight_ih, m_weight_hh, m_bias_ih, m_bias_hh;
// if there is no bias, a dummy bias tensor will be created from the workspace
LstmCellWeight(
RefPtr weight_ptr, size_t hidden_size, size_t input_size, bool has_bias,
DType dtype);
};
struct LstmStates {
static size_t nr_states() { return 2; }
size_t m_memory_size;
TensorND m_h, m_c;
LstmStates(
const SmallVector<RefPtr> ptr, size_t hidden_size, size_t batch_size,
DType dtype);
};
TensorNDArray split_tensor(
_megdnn_tensor_in tensor, size_t nr_tensor, const TensorLayout& layout);
template <class CellWeight>
SmallVector<CellWeight> get_all_cells(
size_t dir_size, size_t num_layers, size_t input_size, size_t hidden_size,
bool bias, _megdnn_tensor_in flatten_weights) {
SmallVector<CellWeight> cell_weights;
cell_weights.reserve(dir_size * num_layers);
auto weight_ptr = flatten_weights.get_ref_ptr();
for (size_t layer = 0; layer < num_layers; ++layer) {
for (size_t d = 0; d < dir_size; ++d) {
size_t cell_input_size = layer == 0 ? input_size : dir_size * hidden_size;
CellWeight cell_weight(
weight_ptr, hidden_size, cell_input_size, bias,
flatten_weights.layout.dtype);
weight_ptr += cell_weight.m_weight_size;
cell_weights.push_back(cell_weight);
}
}
return cell_weights;
}
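As a sanity check on the packing above, the sketch below (illustrative only; lstm_flatten_weight_elems is a hypothetical helper, assuming LSTM weights with gate_hidden_size = 4 * hidden_size) counts the elements get_all_cells expects in flatten_weights:

// Hypothetical helper mirroring LstmCellWeight's layout: each cell packs
// weight_ih [4h, in], weight_hh [4h, h] and optionally bias_ih/bias_hh [4h].
size_t lstm_flatten_weight_elems(
        size_t num_layers, size_t dir_size, size_t input_size, size_t hidden_size,
        bool bias) {
    size_t gate_hidden = 4 * hidden_size;
    size_t total = 0;
    for (size_t layer = 0; layer < num_layers; ++layer) {
        size_t in = layer == 0 ? input_size : dir_size * hidden_size;
        for (size_t d = 0; d < dir_size; ++d) {
            total += gate_hidden * in;           // weight_ih
            total += gate_hidden * hidden_size;  // weight_hh
            if (bias)
                total += 2 * gate_hidden;        // bias_ih + bias_hh
        }
    }
    return total;
}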
template <class States>
SmallVector<States> get_all_status(
_megdnn_tensor_in hx, _megdnn_tensor_in cx, size_t hidden_size,
size_t batch_size, size_t num_layers, size_t dir_size, DType dtype) {
SmallVector<States> states;
auto hx_ptr = hx.get_ref_ptr();
auto cx_ptr = cx.get_ref_ptr();
for (size_t layer = 0; layer < num_layers * dir_size; ++layer) {
States state({hx_ptr, cx_ptr}, hidden_size, batch_size, dtype);
hx_ptr += state.m_memory_size;
cx_ptr += state.m_memory_size;
states.push_back(state);
}
return states;
}
template <class Cell, typename CellOpr, class States>
void exec_kernel(
SmallVector<Cell>& cells, const TensorNDArray& inputs,
const SmallVector<States>& states_in, SmallVector<States>& states_out,
TensorNDArray& outputs, size_t num_layers, size_t dir_size, Handle* handle,
WorkspaceBundle workspace_bundle) {
megdnn_assert(cells.size() == num_layers * dir_size);
megdnn_assert(
states_in.size() == states_out.size() &&
states_in.size() == num_layers * dir_size);
megdnn_assert(outputs.size() == inputs.size());
//! two cell workspaces, two tmp output buffers, plus one workspace per state
megdnn_assert(workspace_bundle.nr_workspace() == 4 + States::nr_states());
size_t seq_len = inputs.size();
size_t batch_size = inputs[0].layout.shape[0];
size_t input_size = inputs[0].layout.shape[1];
size_t hidden_size = cells[0].m_weight_hh.layout.shape[1];
TensorLayout batch_output_layout{
{hidden_size}, outputs[0].layout.dtype}; // one output row per batch element
TensorLayout cell_output_layout{
{batch_size, hidden_size}, outputs[0].layout.dtype}; // output hy
TensorLayout seq_output_layout{
{batch_size, dir_size * hidden_size}, outputs[0].layout.dtype};
TensorLayout cell_first_input_layout{
{batch_size, input_size}, inputs[0].layout.dtype}; // input
TensorLayout cell_input_layout{
{batch_size, dir_size * hidden_size}, inputs[0].layout.dtype};
TensorLayout tmp_output_layout{
{seq_len, batch_size, dir_size * hidden_size}, outputs[0].layout.dtype};
//! workspace get
Workspace cell_workspace(
static_cast<dt_byte*>(workspace_bundle.get(0)),
workspace_bundle.get_size(0) + workspace_bundle.get_size(1));
auto&& tmp_inputs_1 = split_tensor(
TensorND{workspace_bundle.get(2), tmp_output_layout}, seq_len,
cell_input_layout);
auto&& tmp_outputs_1 = split_tensor(
TensorND{workspace_bundle.get(2), tmp_output_layout}, seq_len,
seq_output_layout);
auto&& tmp_inputs_2 = split_tensor(
TensorND{workspace_bundle.get(3), tmp_output_layout}, seq_len,
cell_input_layout);
auto&& tmp_outputs_2 = split_tensor(
TensorND{workspace_bundle.get(3), tmp_output_layout}, seq_len,
seq_output_layout);
using IoPair = std::pair<TensorNDArray, TensorNDArray>;
IoPair io_pair1 = {tmp_inputs_1, tmp_outputs_2};
IoPair io_pair2 = {tmp_inputs_2, tmp_outputs_1};
SmallVector<IoPair> io_pairs = {io_pair1, io_pair2};
SmallVector<RefPtr> ptr;
for (size_t index = 0; index < States::nr_states(); index++) {
ptr.push_back(workspace_bundle.get(4 + index));
}
auto&& tmp_state = States(ptr, hidden_size, batch_size, outputs[0].layout.dtype);
for (size_t layer = 0; layer < num_layers; layer++) {
auto layer_inputs = io_pairs[layer % 2].first;
auto layer_outputs = io_pairs[layer % 2].second;
//! if this is the last layer, write directly to the output tensors
if (num_layers - 1 == layer) {
layer_outputs = outputs;
}
if (0 == layer) {
layer_inputs = inputs;
}
for (size_t d = 0; d < dir_size; ++d) {
size_t cell_idx = layer * dir_size + d;
auto& cell = cells[cell_idx];
auto& state_in_origin = states_in[cell_idx];
auto& state_out_origin = states_out[cell_idx];
auto state_in = state_in_origin;
auto state_out = tmp_state;
for (size_t i = 0; i < seq_len; ++i) {
size_t step = d == 0 ? i : seq_len - 1 - i;
auto& step_input = layer_inputs[step];
auto& step_output = layer_outputs[step];
if (i == seq_len - 1) {
state_out = state_out_origin;
}
//! task 1
//! the CellOp dispatches its tasks internally, so no task is dispatched here
cell_opr_compute<CellOpr, LstmStates>(
step_input, cell.m_weight_ih, cell.m_weight_hh, cell.m_bias_ih,
cell.m_bias_hh, state_in, state_out, cell_workspace, handle);
//! task 2
//! copy the output to contiguous space
auto copy_to_output = [=]() {
//! if dir_size > 1 and batch_size > 1, reorder into the output
size_t stride = batch_output_layout.span().dist_byte();
if (dir_size > 1 && batch_size > 1) {
int8_t* source = static_cast<int8_t*>(state_out.m_h.raw_ptr());
int8_t* dst = static_cast<int8_t*>(step_output.raw_ptr()) +
d * stride;
for (size_t b = 0; b < batch_size; b++) {
memcpy(dst, source, stride);
source += stride;
dst += dir_size * stride;
}
} else {
void* source = state_out.m_h.raw_ptr();
int8_t* dst = static_cast<int8_t*>(step_output.raw_ptr()) +
d * stride;
memcpy(dst, source, state_out.m_h.layout.span().dist_byte());
}
};
MEGDNN_DISPATCH_CPU_KERN(
static_cast<naive::HandleImpl*>(handle), copy_to_output());
//! state_in and state_out are read and written in place
if (0 == i) {
state_in = tmp_state;
}
}
}
}
}
template <typename CellOpr>
WorkspaceBundle get_workspace_bundle(
const TensorLayout& input, const TensorLayout& output,
const TensorLayout& flatten_weights, size_t hidden_size, size_t dir_size,
size_t states_size) {
size_t batch_size = input.shape[1];
size_t input_size = input.shape[2];
size_t gate_hidden_size = flatten_weights.shape[0];
// cell workspace
TensorLayout weight_ih{{gate_hidden_size, input_size}, flatten_weights.dtype};
TensorLayout weight_hh{
{gate_hidden_size, dir_size * hidden_size}, flatten_weights.dtype};
TensorLayout bias{{1, gate_hidden_size}, flatten_weights.dtype};
TensorLayout hx{{batch_size, dir_size * hidden_size}, input.dtype};
auto cell_opr = inplace_cpu_handle()->create_operator<CellOpr>();
TensorLayout h_new, c_new, gates;
cell_opr->deduce_layout(
input, weight_ih, bias, hx, weight_hh, bias, hx, h_new, c_new, gates);
SmallVector<size_t> workspaces;
//! the cell opr compute workspace
size_t cell_opr_workspace = cell_opr->get_workspace_in_bytes(
input, weight_ih, bias, hx, weight_hh, bias, hx, h_new, c_new, gates);
workspaces.push_back(gates.span().dist_byte());
workspaces.push_back(cell_opr_workspace);
//! two ping-pong tmp output buffers
size_t tmp_output_workspace = output.span().dist_byte();
workspaces.push_back(tmp_output_workspace);
workspaces.push_back(tmp_output_workspace);
//! tmp states memory
size_t tmp_state_workspace = hx.span().dist_byte();
for (size_t i = 0; i < states_size; i++) {
workspaces.push_back(tmp_state_workspace);
}
return {nullptr, workspaces};
}
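For readability, the slot layout this bundle hands to exec_kernel can be summarized as follows (descriptive note only, derived from the code above):

// Workspace slots produced above and consumed by exec_kernel():
//   slot 0                  : gates buffer for a single cell step
//   slot 1                  : cell operator workspace
//   slot 2, slot 3          : ping-pong buffers, each the size of the full output
//   slot 4 .. 4 + nr_states : one temporary state buffer per state (h, and c for LSTM)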
} // namespace arm_common
} // namespace megdnn
// vim: syntax=cpp.doxygen
/**
* \file dnn/src/arm_common/lstm/opr_impl.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#include "src/arm_common/lstm/opr_impl.h"
#include "./lstm_utils.h"
#include "src/arm_common/lstm_cell/opr_impl.h"
#include "src/naive/handle.h"
#include "midout.h"
MIDOUT_DECL(megdnn_arm_common_lstm)
using namespace megdnn;
using namespace arm_common;
void LSTMImpl::exec(
_megdnn_tensor_in input, _megdnn_tensor_in hx, _megdnn_tensor_in cx,
_megdnn_tensor_in flatten_weights, _megdnn_tensor_out output,
_megdnn_tensor_out hy, _megdnn_tensor_out cy, _megdnn_tensor_out,
_megdnn_workspace workspace) {
MIDOUT_BEGIN(megdnn_arm_common_lstm, midout_iv(0)) {
size_t dir_size = param().bidirectional ? 2 : 1;
size_t num_layers = param().num_layers;
size_t hidden_size = param().hidden_size;
size_t seq_len = input.layout.shape[0];
size_t batch_size = input.layout.shape[1];
size_t input_size = input.layout.shape[2];
//! to support input pointer changes in record mode, this task should be
//! dispatched to the device
auto&& cell_weights = get_all_cells<LstmCellWeight>(
dir_size, num_layers, input_size, hidden_size, param().bias,
flatten_weights);
auto&& cell_states_in = get_all_status<LstmStates>(
hx, cx, hidden_size, batch_size, num_layers, dir_size, hx.layout.dtype);
auto&& cell_states_out = get_all_status<LstmStates>(
hy, cy, hidden_size, batch_size, num_layers, dir_size, hy.layout.dtype);
auto&& inputs = split_tensor(
input, seq_len,
TensorLayout{{batch_size, input_size}, input.layout.dtype});
auto&& outputs = split_tensor(
output, seq_len,
TensorLayout{
{batch_size, dir_size * hidden_size}, output.layout.dtype});
auto workspace_bundle = get_workspace_bundle<LSTMCell>(
input.layout, output.layout, flatten_weights.layout, hidden_size,
dir_size, LstmStates::nr_states());
workspace_bundle.set(workspace.raw_ptr);
exec_kernel<LstmCellWeight, LSTMCell, LstmStates>(
cell_weights, inputs, cell_states_in, cell_states_out, outputs,
num_layers, dir_size, handle(), workspace_bundle);
}
MIDOUT_END();
}
size_t LSTMImpl::get_workspace_in_bytes(
const TensorLayout& input, const TensorLayout&, const TensorLayout&,
const TensorLayout& flatten_weights, const TensorLayout& output,
const TensorLayout&, const TensorLayout&, const TensorLayout&) {
MIDOUT_BEGIN(megdnn_arm_common_lstm, midout_iv(1)) {
size_t dir_size = param().bidirectional ? 2 : 1;
size_t hidden_size = param().hidden_size;
auto bundle = get_workspace_bundle<LSTMCell>(
input, output, flatten_weights, hidden_size, dir_size,
LstmStates::nr_states());
return bundle.total_size_in_bytes();
}
MIDOUT_END();
}
// vim: syntax=cpp.doxygen
/**
* \file dnn/src/arm_common/lstm/opr_impl.h
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#pragma once
#include "src/common/utils.h"
#include "src/naive/lstm/opr_impl.h"
namespace megdnn {
namespace arm_common {
class LSTMImpl : public naive::LSTMImpl {
public:
using naive::LSTMImpl::LSTMImpl;
void exec(
_megdnn_tensor_in input, _megdnn_tensor_in hx, _megdnn_tensor_in cx,
_megdnn_tensor_in flatten_weights, _megdnn_tensor_out output,
_megdnn_tensor_out hy, _megdnn_tensor_out cy,
_megdnn_tensor_out reserve_space, _megdnn_workspace workspace) override;
size_t get_workspace_in_bytes(
const TensorLayout& input, const TensorLayout& hx, const TensorLayout& cx,
const TensorLayout& flatten_weights, const TensorLayout& output,
const TensorLayout& hy, const TensorLayout& cy,
const TensorLayout& reserve_space) override;
//! arm_common only stores the output tensor; the other tensors are only
//! used when computing gradients, so they are ignored here
size_t get_reserve_size_in_bytes(const TensorLayout&) override { return 1; }
};
} // namespace arm_common
} // namespace megdnn
// vim: syntax=cpp.doxygen
/**
* \file dnn/src/arm_common/lstm_cell/cell_kernel.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#include "./cell_kernel.h"
#include "src/arm_common/lstm_cell/opr_impl.h"
#include "src/common/lstm_cell.h"
#include "src/common/opr_delegate.h"
#include "src/naive/handle.h"
#include "src/arm_common/elemwise_helper/kimpl/sigmoid.h"
#include "src/arm_common/elemwise_helper/kimpl/tanh.h"
using namespace megdnn;
using namespace arm_common;
namespace {
template <class Op, bool bias>
struct ElemwiseCompute {
static Op op;
static inline float32x4x2_t compute_8(
float* dst, float* tmp, float* ih, float* hh) {
float32x4_t dst0 = vld1q_f32(dst);
float32x4_t dst1 = vld1q_f32(dst + 4);
float32x4_t tmp0 = vld1q_f32(tmp);
float32x4_t tmp1 = vld1q_f32(tmp + 4);
auto mid0 = vaddq_f32(dst0, tmp0);
auto mid1 = vaddq_f32(dst1, tmp1);
float32x4_t out0, out1;
if (bias) {
float32x4_t ih0 = vld1q_f32(ih);
float32x4_t ih1 = vld1q_f32(ih + 4);
float32x4_t hh0 = vld1q_f32(hh);
float32x4_t hh1 = vld1q_f32(hh + 4);
auto midd0 = vaddq_f32(ih0, hh0);
auto midd1 = vaddq_f32(ih1, hh1);
out0 = vaddq_f32(mid0, midd0);
out1 = vaddq_f32(mid1, midd1);
} else {
out0 = mid0;
out1 = mid1;
}
return {{op(out0), op(out1)}};
}
static inline float32x4_t compute_4(float* dst, float* tmp, float* ih, float* hh) {
float32x4_t dst0 = vld1q_f32(dst);
float32x4_t tmp0 = vld1q_f32(tmp);
auto mid0 = vaddq_f32(dst0, tmp0);
float32x4_t out0;
if (bias) {
float32x4_t ih0 = vld1q_f32(ih);
float32x4_t hh0 = vld1q_f32(hh);
auto midd0 = vaddq_f32(ih0, hh0);
out0 = vaddq_f32(mid0, midd0);
} else {
out0 = mid0;
}
return op(out0);
}
static inline float compute_1(float* dst, float* tmp, float* ih, float* hh) {
float out;
if (bias) {
out = dst[0] + tmp[0] + ih[0] + hh[0];
} else {
out = dst[0] + tmp[0];
}
return op(out);
}
};
template <class Op, bool bias>
Op ElemwiseCompute<Op, bias>::op = Op();
template <bool bias>
void rnn_cell_elemwise_compute(
_megdnn_tensor_out dst, _megdnn_tensor_in tmp, _megdnn_tensor_in bias_ih,
_megdnn_tensor_in bias_hh, _megdnn_tensor_in cx, _megdnn_tensor_out h_new,
_megdnn_tensor_out c_new) {
size_t batch = dst.layout[0];
size_t batch_length = dst.layout.total_nr_elems() / batch;
size_t base_length = batch_length / 4;
float *ih_ptr_ = nullptr, *hh_ptr_ = nullptr;
float* dst_ptr_ = dst.ptr<float>();
float* tmp_ptr_ = tmp.ptr<float>();
if (bias) {
ih_ptr_ = bias_ih.ptr<float>();
hh_ptr_ = bias_hh.ptr<float>();
}
float* cx_ptr_ = cx.ptr<float>();
float* h_new_ptr_ = h_new.ptr<float>();
float* c_new_ptr_ = c_new.ptr<float>();
ElemwiseCompute<SigmoidOp<dt_float32>, bias> sigmoid_compute;
ElemwiseCompute<TanhOp<dt_float32>, bias> tanh_compute;
TanhOp<dt_float32> tanh_op;
for (size_t b = 0; b < batch; b++) {
float* dst_ptr = dst_ptr_ + b * batch_length;
float* tmp_ptr = tmp_ptr_ + b * batch_length;
float* ih_ptr = ih_ptr_;
float* hh_ptr = hh_ptr_;
float* cx_ptr = cx_ptr_ + b * base_length;
float* h_new_ptr = h_new_ptr_ + b * base_length;
float* c_new_ptr = c_new_ptr_ + b * base_length;
size_t index = 0;
for (; index + 7 < base_length; index += 8) {
auto out_i = sigmoid_compute.compute_8(dst_ptr, tmp_ptr, ih_ptr, hh_ptr);
auto out_f = sigmoid_compute.compute_8(
dst_ptr + base_length, tmp_ptr + base_length, ih_ptr + base_length,
hh_ptr + base_length);
auto out_g = tanh_compute.compute_8(
dst_ptr + 2 * base_length, tmp_ptr + 2 * base_length,
ih_ptr + 2 * base_length, hh_ptr + 2 * base_length);
auto out_o = sigmoid_compute.compute_8(
dst_ptr + 3 * base_length, tmp_ptr + 3 * base_length,
ih_ptr + 3 * base_length, hh_ptr + 3 * base_length);
float32x4_t cx_0 = vld1q_f32(cx_ptr);
float32x4_t cx_1 = vld1q_f32(cx_ptr + 4);
//! f * cx + i * g
auto c_new_0 = vaddq_f32(
vmulq_f32(out_f.val[0], cx_0),
vmulq_f32(out_i.val[0], out_g.val[0]));
auto c_new_1 = vaddq_f32(
vmulq_f32(out_f.val[1], cx_1),
vmulq_f32(out_i.val[1], out_g.val[1]));
vst1q_f32(c_new_ptr, c_new_0);
vst1q_f32(c_new_ptr + 4, c_new_1);
auto h_new_0 = vmulq_f32(tanh_op(c_new_0), out_o.val[0]);
auto h_new_1 = vmulq_f32(tanh_op(c_new_1), out_o.val[1]);
vst1q_f32(h_new_ptr, h_new_0);
vst1q_f32(h_new_ptr + 4, h_new_1);
dst_ptr += 8;
tmp_ptr += 8;
ih_ptr += 8;
hh_ptr += 8;
cx_ptr += 8;
c_new_ptr += 8;
h_new_ptr += 8;
}
for (; index + 3 < base_length; index += 4) {
auto out_i = sigmoid_compute.compute_4(dst_ptr, tmp_ptr, ih_ptr, hh_ptr);
auto out_f = sigmoid_compute.compute_4(
dst_ptr + base_length, tmp_ptr + base_length, ih_ptr + base_length,
hh_ptr + base_length);
auto out_g = tanh_compute.compute_4(
dst_ptr + 2 * base_length, tmp_ptr + 2 * base_length,
ih_ptr + 2 * base_length, hh_ptr + 2 * base_length);
auto out_o = sigmoid_compute.compute_4(
dst_ptr + 3 * base_length, tmp_ptr + 3 * base_length,
ih_ptr + 3 * base_length, hh_ptr + 3 * base_length);
float32x4_t cx_v = vld1q_f32(cx_ptr);
//! f * cx + i * g
auto c_new = vaddq_f32(vmulq_f32(out_f, cx_v), vmulq_f32(out_i, out_g));
vst1q_f32(c_new_ptr, c_new);
auto h_new = vmulq_f32(tanh_op(c_new), out_o);
vst1q_f32(h_new_ptr, h_new);
dst_ptr += 4;
tmp_ptr += 4;
ih_ptr += 4;
hh_ptr += 4;
cx_ptr += 4;
c_new_ptr += 4;
h_new_ptr += 4;
}
for (; index < base_length; index++) {
auto out_i = sigmoid_compute.compute_1(dst_ptr, tmp_ptr, ih_ptr, hh_ptr);
auto out_f = sigmoid_compute.compute_1(
dst_ptr + base_length, tmp_ptr + base_length, ih_ptr + base_length,
hh_ptr + base_length);
auto out_g = tanh_compute.compute_1(
dst_ptr + 2 * base_length, tmp_ptr + 2 * base_length,
ih_ptr + 2 * base_length, hh_ptr + 2 * base_length);
auto out_o = sigmoid_compute.compute_1(
dst_ptr + 3 * base_length, tmp_ptr + 3 * base_length,
ih_ptr + 3 * base_length, hh_ptr + 3 * base_length);
c_new_ptr[0] = out_f * cx_ptr[0] + out_i * out_g;
h_new_ptr[0] = tanh_op(c_new_ptr[0]) * out_o;
dst_ptr += 1;
tmp_ptr += 1;
ih_ptr += 1;
hh_ptr += 1;
cx_ptr += 1;
c_new_ptr += 1;
h_new_ptr += 1;
}
}
}
} // namespace
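The 8-, 4- and 1-wide loops above all implement the same per-element gate math. A scalar sketch (illustrative only; lstm_cell_ref is hypothetical, assuming the gate pre-activations dst + tmp [+ bias_ih + bias_hh] are already summed):

#include <cmath>

// Per-element LSTM cell math used by rnn_cell_elemwise_compute: the gates are
// laid out as four consecutive blocks [i, f, g, o] of length base_length.
static inline void lstm_cell_ref(
        float x_i, float x_f, float x_g, float x_o, float cx, float& h_new,
        float& c_new) {
    auto sigmoid = [](float v) { return 1.f / (1.f + std::exp(-v)); };
    float i = sigmoid(x_i), f = sigmoid(x_f), o = sigmoid(x_o);
    float g = std::tanh(x_g);
    c_new = f * cx + i * g;        // forget old cell state, add new candidate
    h_new = std::tanh(c_new) * o;  // output gate scales the squashed cell state
}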
void LstmCellCompute::run(
_megdnn_tensor_in input, _megdnn_tensor_in weight_ih, _megdnn_tensor_in bias_ih,
_megdnn_tensor_in hx, _megdnn_tensor_in weight_hh, _megdnn_tensor_in bias_hh,
_megdnn_tensor_in cx, _megdnn_tensor_out h_new, _megdnn_tensor_out c_new,
_megdnn_tensor_out gates, _megdnn_workspace workspace, Handle* handle) {
auto bundle = get_workspace_bundle(
input.layout, weight_ih.layout, bias_ih.layout, hx.layout, weight_hh.layout,
bias_hh.layout, cx.layout, h_new.layout, c_new.layout, gates.layout);
bundle.set(workspace.raw_ptr);
TensorND tmp{static_cast<void*>(bundle.get(0)), gates.layout};
auto matmul_workspace =
megdnn::Workspace{static_cast<dt_byte*>(bundle.get(1)), bundle.get_size(1)};
auto opr = handle->create_operator<MatrixMul>();
opr->param().transposeB = true;
//! the opr dispatches its compute task to the device, so record mode
//! performance is not affected
opr->exec(input, weight_ih, tmp, matmul_workspace);
opr->exec(hx, weight_hh, gates, matmul_workspace);
//! the optimized post compute: nonlinear(tmp + dst + bias_ih + bias_hh)
if (bias_ih.layout.ndim != 0 && bias_hh.layout.ndim != 0) {
MEGDNN_DISPATCH_CPU_KERN(
static_cast<naive::HandleImpl*>(handle),
rnn_cell_elemwise_compute<true>(
gates, tmp, bias_ih, bias_hh, cx, h_new, c_new));
} else {
megdnn_assert(bias_ih.layout.ndim == 0 && bias_hh.layout.ndim == 0);
MEGDNN_DISPATCH_CPU_KERN(
static_cast<naive::HandleImpl*>(handle),
rnn_cell_elemwise_compute<false>(
gates, tmp, bias_ih, bias_hh, cx, h_new, c_new));
}
}
WorkspaceBundle LstmCellCompute::get_workspace_bundle(
const TensorLayout& input, const TensorLayout& weight_ih, const TensorLayout&,
const TensorLayout& hx, const TensorLayout& weight_hh, const TensorLayout&,
const TensorLayout&, const TensorLayout&, const TensorLayout&,
const TensorLayout& gates) {
auto opr = inplace_cpu_handle()->create_operator<MatrixMul>();
opr->param().transposeB = true;
size_t matmul_workspace = std::max(
opr->get_workspace_in_bytes(input, weight_ih, gates),
opr->get_workspace_in_bytes(hx, weight_hh, gates));
return WorkspaceBundle{nullptr, {gates.span().dist_byte(), matmul_workspace}};
}
bool LstmCellCompute::is_optimized(
const TensorLayout& input, const TensorLayout&, const TensorLayout& bias_ih,
const TensorLayout&, const TensorLayout&, const TensorLayout& bias_hh,
const TensorLayout&, const TensorLayout&, const TensorLayout&,
const TensorLayout& gates) {
if (input.dtype.enumv() == DTypeEnum::Float32 && gates[1] == bias_ih[1] &&
bias_ih[0] == 1 && bias_ih.eq_layout(bias_hh)) {
return true;
} else {
return false;
}
}
// vim: syntax=cpp.doxygen
/**
* \file dnn/src/arm_common/lstm_cell/cell_kernel.h
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#pragma once
#include "src/common/utils.h"
#include "src/naive/handle.h"
#include "src/naive/lstm_cell/opr_impl.h"
namespace megdnn {
namespace arm_common {
struct LstmCellCompute {
static WorkspaceBundle get_workspace_bundle(
const TensorLayout& input, const TensorLayout& weight_ih,
const TensorLayout& bias_ih, const TensorLayout& hx,
const TensorLayout& weight_hh, const TensorLayout& bias_hh,
const TensorLayout& cx, const TensorLayout& h_new,
const TensorLayout& c_new, const TensorLayout& gates);
static bool is_optimized(
const TensorLayout& input, const TensorLayout& weight_ih,
const TensorLayout& bias_ih, const TensorLayout& hx,
const TensorLayout& weight_hh, const TensorLayout& bias_hh,
const TensorLayout& cx, const TensorLayout& h_new,
const TensorLayout& c_new, const TensorLayout& gates);
static void run(
_megdnn_tensor_in input, _megdnn_tensor_in weight_ih,
_megdnn_tensor_in bias_ih, _megdnn_tensor_in hx,
_megdnn_tensor_in weight_hh, _megdnn_tensor_in bias_hh,
_megdnn_tensor_in cx, _megdnn_tensor_out h_new, _megdnn_tensor_out c_new,
_megdnn_tensor_out gates, _megdnn_workspace workspace, Handle* handle);
};
} // namespace arm_common
} // namespace megdnn
// vim: syntax=cpp.doxygen
/**
* \file dnn/src/arm_common/lstm_cell/opr_impl.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#include "src/arm_common/lstm_cell/opr_impl.h"
#include "src/common/lstm_cell.h"
#include "src/naive/handle.h"
#include "./cell_kernel.h"
#include "midout.h"
MIDOUT_DECL(megdnn_arm_common_lstm_cell)
using namespace megdnn;
using namespace arm_common;
void LSTMCellImpl::exec(
_megdnn_tensor_in input, _megdnn_tensor_in weight_ih, _megdnn_tensor_in bias_ih,
_megdnn_tensor_in hx, _megdnn_tensor_in weight_hh, _megdnn_tensor_in bias_hh,
_megdnn_tensor_in cx, _megdnn_tensor_out h_new, _megdnn_tensor_out c_new,
_megdnn_tensor_out gates, _megdnn_workspace workspace) {
//! only float32 with bias of shape {1, xx} takes the optimized path
MIDOUT_BEGIN(megdnn_arm_common_lstm_cell, midout_iv(0)) {
if (!LstmCellCompute::is_optimized(
input.layout, weight_ih.layout, bias_ih.layout, hx.layout,
weight_hh.layout, bias_hh.layout, cx.layout, h_new.layout,
c_new.layout, gates.layout)) {
naive::LSTMCellImpl::exec(
input, weight_ih, bias_ih, hx, weight_hh, bias_hh, cx, h_new, c_new,
gates, workspace);
} else {
LstmCellCompute::run(
input, weight_ih, bias_ih, hx, weight_hh, bias_hh, cx, h_new, c_new,
gates, workspace, handle());
}
}
MIDOUT_END();
}
size_t LSTMCellImpl::get_workspace_in_bytes(
const TensorLayout& input, const TensorLayout& weight_ih,
const TensorLayout& bias_ih, const TensorLayout& hx,
const TensorLayout& weight_hh, const TensorLayout& bias_hh,
const TensorLayout& cx, const TensorLayout& h_new, const TensorLayout& c_new,
const TensorLayout& gates) {
MIDOUT_BEGIN(megdnn_arm_common_lstm_cell, midout_iv(1)) {
if (!LstmCellCompute::is_optimized(
input, weight_ih, bias_ih, hx, weight_hh, bias_hh, cx, h_new, c_new,
gates)) {
return naive::LSTMCellImpl::get_workspace_in_bytes(
input, weight_ih, bias_ih, hx, weight_hh, bias_hh, cx, h_new, c_new,
gates);
} else {
return LstmCellCompute::get_workspace_bundle(
input, weight_ih, bias_ih, hx, weight_hh, bias_hh, cx, h_new,
c_new, gates)
.total_size_in_bytes();
}
}
MIDOUT_END();
}
// vim: syntax=cpp.doxygen
/**
* \file dnn/src/arm_common/lstm_cell/opr_impl.h
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#pragma once
#include "src/common/utils.h"
#include "src/naive/lstm_cell/opr_impl.h"
namespace megdnn {
namespace arm_common {
class LSTMCellImpl : public naive::LSTMCellImpl {
public:
using naive::LSTMCellImpl::LSTMCellImpl;
void exec(
_megdnn_tensor_in input, _megdnn_tensor_in weight_ih,
_megdnn_tensor_in bias_ih, _megdnn_tensor_in hx,
_megdnn_tensor_in weight_hh, _megdnn_tensor_in bias_hh,
_megdnn_tensor_in cx, _megdnn_tensor_out h_new, _megdnn_tensor_out c_new,
_megdnn_tensor_out gates, _megdnn_workspace workspace) override;
size_t get_workspace_in_bytes(
const TensorLayout& input, const TensorLayout& weight_ih,
const TensorLayout& bias_ih, const TensorLayout& hx,
const TensorLayout& weight_hh, const TensorLayout& bias_hh,
const TensorLayout& cx, const TensorLayout& h_new,
const TensorLayout& c_new, const TensorLayout& gates) override;
};
} // namespace arm_common
} // namespace megdnn
// vim: syntax=cpp.doxygen
/**
* \file dnn/src/arm_common/rnn_cell/opr_impl.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#include "src/arm_common/rnn_cell/opr_impl.h"
#include "src/common/utils.h"
#include "src/naive/handle.h"
#include "src/arm_common/elemwise_helper/kimpl/none.h"
#include "src/arm_common/elemwise_helper/kimpl/relu.h"
#include "src/arm_common/elemwise_helper/kimpl/tanh.h"
#include "midout.h"
MIDOUT_DECL(megdnn_arm_common_rnn_cell)
using namespace megdnn;
using namespace arm_common;
namespace {
ElemwiseForward* get_elemwise_opr() {
static CpuOprDelegationStorage<1> storage;
return storage.get<ElemwiseForward>();
}
template <typename Op>
void elemwise_compute(
float* dst_ptr, float* tmp_ptr, float* ih_ptr, float* hh_ptr, size_t batch,
size_t length) {
const constexpr size_t SIMD_8 = 8;
const constexpr size_t SIMD_4 = 4;
Op op;
for (size_t b = 0; b < batch; b++) {
float* dst = dst_ptr + b * length;
float* tmp = tmp_ptr + b * length;
float* ih = ih_ptr;
float* hh = hh_ptr;
size_t index = 0;
for (; index + SIMD_8 - 1 < length; index += SIMD_8) {
float32x4_t dst0 = vld1q_f32(dst);
float32x4_t dst1 = vld1q_f32(dst + 4);
float32x4_t tmp0 = vld1q_f32(tmp);
float32x4_t tmp1 = vld1q_f32(tmp + 4);
float32x4_t ih0 = vld1q_f32(ih);
float32x4_t ih1 = vld1q_f32(ih + 4);
float32x4_t hh0 = vld1q_f32(hh);
float32x4_t hh1 = vld1q_f32(hh + 4);
auto mid0 = vaddq_f32(dst0, tmp0);
auto mid1 = vaddq_f32(dst1, tmp1);
auto midd0 = vaddq_f32(ih0, hh0);
auto midd1 = vaddq_f32(ih1, hh1);
auto out0 = vaddq_f32(mid0, midd0);
auto out1 = vaddq_f32(mid1, midd1);
vst1q_f32(dst, op(out0));
vst1q_f32(dst + 4, op(out1));
dst += SIMD_8;
tmp += SIMD_8;
ih += SIMD_8;
hh += SIMD_8;
}
for (; index + SIMD_4 - 1 < length; index += SIMD_4) {
float32x4_t dst0 = vld1q_f32(dst);
float32x4_t tmp0 = vld1q_f32(tmp);
float32x4_t ih0 = vld1q_f32(ih);
float32x4_t hh0 = vld1q_f32(hh);
auto mid0 = vaddq_f32(dst0, tmp0);
auto midd0 = vaddq_f32(ih0, hh0);
auto out0 = vaddq_f32(mid0, midd0);
vst1q_f32(dst, op(out0));
dst += SIMD_4;
tmp += SIMD_4;
ih += SIMD_4;
hh += SIMD_4;
}
for (; index < length; index++) {
auto out = dst[0] + tmp[0] + ih[0] + hh[0];
dst[0] = op(out);
dst++;
tmp++;
ih++;
hh++;
}
}
}
void rnn_cell_post_compute(
_megdnn_tensor_out dst, _megdnn_tensor_in tmp, _megdnn_tensor_in bias_ih,
_megdnn_tensor_in bias_hh, param::RNNCell::NonlineMode nonline_mode,
Handle* handle) {
using NonlineMode = param::RNNCell::NonlineMode;
megdnn_assert(
nonline_mode == NonlineMode::RELU || nonline_mode == NonlineMode::TANH ||
nonline_mode == NonlineMode::IDENTITY,
"Now arm only support nonlinear mode Relu, TANH, IDENTITY.");
if (dst.layout.dtype.enumv() == DTypeEnum::Float32 &&
dst.layout[1] == bias_ih.layout[1] && bias_ih.layout[0] == 1 &&
bias_ih.layout.eq_layout(bias_hh.layout)) {
auto run = [=]() {
size_t batch = dst.layout[0];
size_t length = bias_ih.layout.total_nr_elems();
float* dst_ptr = dst.ptr<float>();
float* tmp_ptr = tmp.ptr<float>();
float* ih_ptr = bias_ih.ptr<float>();
float* hh_ptr = bias_hh.ptr<float>();
if (nonline_mode == NonlineMode::RELU) {
elemwise_compute<ReluOp<dt_float32>>(
dst_ptr, tmp_ptr, ih_ptr, hh_ptr, batch, length);
} else if (nonline_mode == NonlineMode::TANH) {
elemwise_compute<TanhOp<dt_float32>>(
dst_ptr, tmp_ptr, ih_ptr, hh_ptr, batch, length);
} else {
elemwise_compute<NoneOp<dt_float32>>(
dst_ptr, tmp_ptr, ih_ptr, hh_ptr, batch, length);
}
};
MEGDNN_DISPATCH_CPU_KERN(static_cast<naive::HandleImpl*>(handle), run());
} else {
//! this opr must be created by inplace handle
auto elem_opr = get_elemwise_opr();
auto run = [=]() {
elem_opr->param().mode = Elemwise::Param::Mode::ADD;
elem_opr->exec({dst, tmp}, dst);
elem_opr->exec({dst, bias_ih}, dst);
elem_opr->exec({dst, bias_hh}, dst);
// activation
switch (nonline_mode) {
#define cb(_mode) \
case NonlineMode::_mode: { \
elem_opr->param().mode = Elemwise::Param::Mode::_mode; \
elem_opr->exec({dst}, dst); \
break; \
}
cb(RELU);
cb(TANH);
#undef cb
case NonlineMode::IDENTITY:
break;
default:
megdnn_throw("unsupport nonlinear mode.");
}
};
MEGDNN_DISPATCH_CPU_KERN(static_cast<naive::HandleImpl*>(handle), run());
}
}
} // namespace
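Both branches of rnn_cell_post_compute realize the same per-element formula. A scalar sketch (illustrative only; rnn_cell_post_ref is hypothetical):

#include <cmath>

// dst = act(dst + tmp + bias_ih + bias_hh), where act is RELU, TANH or IDENTITY.
static inline float rnn_cell_post_ref(
        float dst, float tmp, float bias_ih, float bias_hh, bool relu, bool tanh_) {
    float x = dst + tmp + bias_ih + bias_hh;
    if (relu)
        return x > 0.f ? x : 0.f;
    if (tanh_)
        return std::tanh(x);
    return x;  // IDENTITY
}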
WorkspaceBundle RNNCellImpl::get_workspace_bundle(
const TensorLayout& input, const TensorLayout& weight_ih, const TensorLayout&,
const TensorLayout& hx, const TensorLayout& weight_hh, const TensorLayout&,
const TensorLayout& dst) {
MIDOUT_BEGIN(megdnn_arm_common_rnn_cell, midout_iv(0)) {
auto opr = handle()->create_operator<MatrixMulForward>();
opr->param().transposeB = true;
auto matmul_workspace = std::max(
opr->get_workspace_in_bytes(input, weight_ih, dst),
opr->get_workspace_in_bytes(hx, weight_hh, dst));
auto tmp_workspace = dst.span().dist_byte();
return WorkspaceBundle{nullptr, {tmp_workspace, matmul_workspace}};
}
MIDOUT_END();
}
size_t RNNCellImpl::get_workspace_in_bytes(
const TensorLayout& input, const TensorLayout& weight_ih,
const TensorLayout& bias_ih, const TensorLayout& hx,
const TensorLayout& weight_hh, const TensorLayout& bias_hh,
const TensorLayout& dst) {
return get_workspace_bundle(input, weight_ih, bias_ih, hx, weight_hh, bias_hh, dst)
.total_size_in_bytes();
}
void RNNCellImpl::exec(
_megdnn_tensor_in input, _megdnn_tensor_in weight_ih, _megdnn_tensor_in bias_ih,
_megdnn_tensor_in hx, _megdnn_tensor_in weight_hh, _megdnn_tensor_in bias_hh,
_megdnn_tensor_out dst, _megdnn_workspace workspace) {
MIDOUT_BEGIN(megdnn_arm_common_rnn_cell, midout_iv(1)) {
auto bundle = get_workspace_bundle(
input.layout, weight_ih.layout, bias_ih.layout, hx.layout,
weight_hh.layout, bias_hh.layout, dst.layout);
bundle.set(workspace.raw_ptr);
auto nonline_mode = param().nonlineMode;
TensorND tmp{static_cast<void*>(bundle.get(0)), dst.layout};
auto new_workspace =
Workspace{static_cast<dt_byte*>(bundle.get(1)), bundle.get_size(1)};
//! this opr can't be created by inplace handle
auto opr = handle()->create_operator<MatrixMulForward>();
opr->param().transposeB = true;
//! the opr dispatches its compute task to the device, so record mode
//! performance is not affected
opr->exec(input, weight_ih, tmp, new_workspace);
opr->exec(hx, weight_hh, dst, new_workspace);
//! the optimized post compute: nonlinear(tmp + dst + bias_ih + bias_hh)
rnn_cell_post_compute(dst, tmp, bias_ih, bias_hh, nonline_mode, handle());
}
MIDOUT_END();
}
// vim: syntax=cpp.doxygen
/**
* \file dnn/src/arm_common/rnn_cell/opr_impl.h
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#pragma once
#include "src/common/opr_delegate.h"
#include "src/naive/rnn_cell/opr_impl.h"
namespace megdnn {
namespace arm_common {
class RNNCellImpl : public naive::RNNCellImpl {
public:
using naive::RNNCellImpl::RNNCellImpl;
void exec(
_megdnn_tensor_in input, _megdnn_tensor_in weight_ih,
_megdnn_tensor_in bias_ih, _megdnn_tensor_in hx,
_megdnn_tensor_in weight_hh, _megdnn_tensor_in bias_hh,
_megdnn_tensor_out dst, _megdnn_workspace workspace) override;
size_t get_workspace_in_bytes(
const TensorLayout& input, const TensorLayout& weight_ih,
const TensorLayout& bias_ih, const TensorLayout& hx,
const TensorLayout& weight_hh, const TensorLayout& bias_hh,
const TensorLayout& dst) override;
private:
WorkspaceBundle get_workspace_bundle(
const TensorLayout& input, const TensorLayout& weight_ih,
const TensorLayout& bias_ih, const TensorLayout& hx,
const TensorLayout& weight_hh, const TensorLayout& bias_hh,
const TensorLayout& dst);
};
} // namespace arm_common
} // namespace megdnn
// vim: syntax=cpp.doxygen
@@ -31,4 +31,6 @@ public:
};
} // namespace naive
} // namespace megdnn
\ No newline at end of file
} // namespace megdnn
// vim: syntax=cpp.doxygen
/**
* \file dnn/test/arm_common/lstm.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#include "test/arm_common/fixture.h"
#include "megdnn/oprs.h"
#include "megdnn/oprs/general.h"
#include "test/common/benchmarker.h"
#include "test/common/checker.h"
#include "test/common/task_record_check.h"
using namespace megdnn;
using namespace test;
namespace {
//! in arm_common the reserve tensor is not used
void output_canonizer(const CheckerHelper::TensorValueArray& arr) {
const TensorND& reserve = arr.back();
TensorND& modif_reserve = const_cast<TensorND&>(reserve);
modif_reserve.layout = TensorLayout();
}
} // namespace
TEST_F(ARM_COMMON, LSTMCell) {
Checker<LSTMCell> checker(handle());
checker.set_output_canonizer(output_canonizer);
checker.exec(
{{1, 10},
{40, 10},
{1, 40},
{1, 10},
{40, 10},
{1, 40},
{1, 10},
{},
{},
{}});
for (size_t batch : {2})
for (size_t n : {3, 4, 5, 23, 100})
for (size_t out : {3, 6, 25, 100}) {
checker.exec(
{{batch, n},
{out * 4, n},
{1, out * 4},
{batch, out},
{out * 4, out},
{1, out * 4},
{batch, out},
{},
{},
{}});
checker.exec(
{{batch, n},
{out * 4, n},
{batch, out * 4},
{batch, out},
{out * 4, out},
{batch, out * 4},
{batch, out},
{},
{},
{}});
}
}
TEST_F(ARM_COMMON, LSTMCellRecord) {
TaskRecordChecker<LSTMCell> checker(0);
checker.exec(
{{1, 10},
{40, 10},
{1, 40},
{1, 10},
{40, 10},
{1, 40},
{1, 10},
{},
{},
{}});
}
namespace {
void test_lstm(bool bias, bool direction, Handle* handle) {
Checker<LSTM> checker(handle, true);
//! because LSTM involves tanh and exp computations, the accumulated error
//! can exceed 1e-3 after more iterations
checker.set_epsilon(1e-2);
checker.set_output_canonizer(output_canonizer);
for (size_t input_size : {2, 8, 13})
for (size_t hidden_size : {1, 4, 17}) {
size_t dir_size = direction == false ? 1 : 2;
LSTM::Param param;
param.bidirectional = direction;
size_t gate_hidden_size = 4 * hidden_size;
param.bias = bias;
param.hidden_size = hidden_size;
for (size_t seq_len : {1, 3, 5})
for (size_t batch_size : {1, 2, 4})
for (size_t number_layer : {1, 2, 4, 5, 8}) {
size_t flatten_size = 0;
for (size_t layer = 0; layer < number_layer; layer++) {
for (size_t dir = 0; dir < dir_size; dir++) {
flatten_size += layer == 0
? input_size
: dir_size * hidden_size; // ih
flatten_size += hidden_size; // hh
}
}
if (bias) {
flatten_size += 2 * dir_size * number_layer;
}
param.num_layers = number_layer;
checker.set_param(param).exec(
{{seq_len, batch_size, input_size}, // input
{number_layer * dir_size, batch_size,
hidden_size}, // hx
{number_layer * dir_size, batch_size,
hidden_size}, // hy
{gate_hidden_size, flatten_size}, // flat weight
{},
{},
{},
{}});
}
}
}
} // namespace
TEST_F(ARM_COMMON, LSTM_FORWARD_NO_BIAS_NO_DIRECTION) {
test_lstm(false, false, handle());
}
TEST_F(ARM_COMMON, LSTM_FORWARD_BIAS_NO_DIRECTION) {
test_lstm(true, false, handle());
}
TEST_F(ARM_COMMON, LSTM_FORWARD_DIRECTION_NO_BIAS) {
test_lstm(false, true, handle());
}
TEST_F(ARM_COMMON, LSTM_FORWARD_DIRECTION_BIAS) {
test_lstm(true, true, handle());
}
TEST_F(ARM_COMMON, LSTM_FORWARD_RECORD) {
TaskRecordChecker<LSTM> checker(0);
size_t input_size = 2;
size_t hidden_size = 2;
size_t gate_hidden_size = 4 * hidden_size;
LSTM::Param param;
param.bidirectional = false;
param.bias = false;
param.hidden_size = hidden_size;
// checker.set_output_canonizer(output_canonizer);
for (size_t seq_len : {1, 3, 5})
for (size_t batch_size : {1, 2, 4})
for (size_t number_layer : {1, 2, 4, 5, 8}) {
param.num_layers = number_layer;
checker.set_param(param).exec(
{{seq_len, batch_size, input_size}, // input
{number_layer, batch_size, hidden_size}, // hx
{number_layer, batch_size, hidden_size}, // hy
{number_layer, gate_hidden_size,
input_size + hidden_size}, // flat weight
{},
{},
{},
{}});
}
}
#if MEGDNN_WITH_BENCHMARK
TEST_F(ARM_COMMON, BENCHMARK_LSTM_FORWARD) {
Benchmarker<LSTM> optimized_bench(handle());
constexpr size_t RUNS = 20;
auto run = [&](size_t hidden_size, size_t input_size) {
optimized_bench.set_times(20).set_display(true);
size_t gate_hidden_size = 4 * hidden_size;
for (bool direction : {false, true}) {
LSTM::Param param;
param.hidden_size = hidden_size;
param.bidirectional = direction;
param.bias = false;
size_t dir_size = direction == false ? 1 : 2;
for (size_t seq_len : {1, 5, 8})
for (size_t batch_size : {1, 8, 16})
for (size_t number_layer : {1}) {
param.num_layers = number_layer;
size_t flatten_size = 0;
for (size_t layer = 0; layer < number_layer; layer++) {
for (size_t dir = 0; dir < dir_size; dir++) {
flatten_size += layer == 0
? input_size
: dir_size * hidden_size; // ih
flatten_size += hidden_size; // hh
}
}
optimized_bench.set_param(param).exec(
{{seq_len, batch_size, input_size}, // input
{number_layer * dir_size, batch_size,
hidden_size}, // hx
{number_layer * dir_size, batch_size,
hidden_size}, // hy
{gate_hidden_size, flatten_size}, // flat weight
{},
{},
{},
{}});
}
}
};
run(512, 256);
}
#endif
// vim: syntax=cpp.doxygen
/**
* \file dnn/test/arm_common/rnn.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#include "test/arm_common/fixture.h"
#include "megdnn/oprs.h"
#include "test/common/benchmarker.h"
#include "test/common/checker.h"
#include "test/common/task_record_check.h"
using namespace megdnn;
using namespace test;
TEST_F(ARM_COMMON, RNNCell) {
Checker<RNNCell> checker(handle());
using NonlineMode = param::RNNCell::NonlineMode;
param::RNNCell param;
for (auto mode : {NonlineMode::IDENTITY, NonlineMode::RELU, NonlineMode::TANH})
for (size_t batch : {1, 4})
for (size_t n : {3, 4, 5, 23, 100})
for (size_t h : {5, 23, 100})
for (size_t out : {3, 6, 25, 100}) {
param.nonlineMode = mode;
checker.set_param(param);
checker.exec(
{{batch, n},
{out, n},
{1, out},
{batch, h},
{out, h},
{1, out},
{}});
checker.exec(
{{batch, n},
{out, n},
{batch, out},
{batch, h},
{out, h},
{batch, out},
{}});
}
}
TEST_F(ARM_COMMON, RNNCellRecord) {
TaskRecordChecker<RNNCell> checker(0);
using NonlineMode = param::RNNCell::NonlineMode;
param::RNNCell param;
for (auto mode : {NonlineMode::IDENTITY, NonlineMode::RELU, NonlineMode::TANH}) {
param.nonlineMode = mode;
checker.set_param(param);
checker.exec({{1, 100}, {10, 100}, {1, 10}, {1, 100}, {10, 100}, {1, 10}, {}});
checker.exec({{1, 34}, {15, 34}, {1, 15}, {1, 34}, {15, 34}, {1, 15}, {}});
checker.exec({{1, 73}, {25, 73}, {1, 25}, {1, 73}, {25, 73}, {1, 25}, {}});
}
}
#if MEGDNN_WITH_BENCHMARK
#endif
// vim: syntax=cpp.doxygen