/**
 * \file dnn/src/common/lstm_cell.cpp
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 */
#include "src/common/lstm_cell.h"
#include "megdnn/oprs.h"
#include "src/common/utils.h"

namespace megdnn {

void LSTMCell::deduce_layout(
        const TensorLayout& input, const TensorLayout& weight_ih,
        const TensorLayout& bias_ih, const TensorLayout& hx,
        const TensorLayout& weight_hh, const TensorLayout& bias_hh,
        const TensorLayout& cx, TensorLayout& h_new, TensorLayout& c_new,
        TensorLayout& gates) {
    // The new hidden/cell states keep the shape and dtype of the incoming
    // states.
    h_new = TensorLayout(hx, hx.dtype);
    c_new = TensorLayout(cx, cx.dtype);
    // The pre-activation gates tensor has the same layout as the output of an
    // RNNCell with identity non-linearity, i.e.
    //   gates = input * W_ih^T + b_ih + hx * W_hh^T + b_hh
    // (one (batch, 4 * hidden_size) matrix holding i, f, g, o side by side).
    auto opr = handle()->create_operator<RNNCellForward>();
    opr->param().nonlineMode = param::RNNCell::NonlineMode::IDENTITY;
    opr->deduce_layout(input, weight_ih, bias_ih, hx, weight_hh, bias_hh, gates);
}

void LSTMCell::check_exec(
        const TensorLayout& input, const TensorLayout& weight_ih,
        const TensorLayout& bias_ih, const TensorLayout& hx,
        const TensorLayout& weight_hh, const TensorLayout& bias_hh,
        const TensorLayout& cx, const TensorLayout& h_new,
        const TensorLayout& c_new, const TensorLayout& gates,
        size_t workspace_in_bytes) {
    TensorLayout h_new_expected, c_new_expected, gates_expected;
    // Build a human-readable dump of every input layout for assertion
    // messages.
    auto errmsg = [&]() {
        std::string msg;
        msg.append("input=");
        msg.append(input.to_string());
        msg.append(", weight_ih=");
        msg.append(weight_ih.to_string());
        msg.append(", bias_ih=");
        msg.append(bias_ih.to_string());
        msg.append(", hx=");
        msg.append(hx.to_string());
        msg.append(", weight_hh=");
        msg.append(weight_hh.to_string());
        msg.append(", bias_hh=");
        msg.append(bias_hh.to_string());
        msg.append(", cx=");
        msg.append(cx.to_string());
        return msg;
    };
#define ASSERT_BRIEF(_content) megdnn_assert(_content, "%s", errmsg().c_str());
    ASSERT_BRIEF(input.ndim == 2)
    ASSERT_BRIEF(input.shape[1] == weight_ih.shape[1])
    ASSERT_BRIEF(weight_ih.shape[0] == weight_hh.shape[0])
    // The stacked weight holds the four gates: 4 * hidden_size rows.
    ASSERT_BRIEF(weight_hh.shape[0] == 4 * weight_hh.shape[1])
    ASSERT_BRIEF(bias_ih.shape[0] == bias_hh.shape[0])
    ASSERT_BRIEF(hx.ndim == 2)
    ASSERT_BRIEF(hx.shape[0] == input.shape[0])   // batch_size
    ASSERT_BRIEF(hx.shape[1] == cx.shape[1])      // hidden_size
    ASSERT_BRIEF(cx.ndim == 2)
    ASSERT_BRIEF(cx.shape[0] == input.shape[0])   // batch_size
    ASSERT_BRIEF(cx.shape[1] == weight_hh.shape[1])
#undef ASSERT_BRIEF
    // The output layouts must match exactly what deduce_layout would produce.
    deduce_layout(
            input, weight_ih, bias_ih, hx, weight_hh, bias_hh, cx, h_new_expected,
            c_new_expected, gates_expected);
    megdnn_assert_eq_layout(h_new_expected, h_new);
    megdnn_assert_eq_layout(c_new_expected, c_new);
    megdnn_assert_eq_layout(gates_expected, gates);

    auto required_workspace_in_bytes = get_workspace_in_bytes(
            input, weight_ih, bias_ih, hx, weight_hh, bias_hh, cx, h_new, c_new,
            gates);
    megdnn_assert(workspace_in_bytes >= required_workspace_in_bytes);
}
}  // namespace megdnn

namespace megdnn {
namespace lstm_cell {

size_t get_workspace_in_bytes(
        const TensorLayout& input, const TensorLayout& weight_ih,
        const TensorLayout& bias_ih, const TensorLayout& hx,
        const TensorLayout& weight_hh, const TensorLayout& bias_hh,
        const TensorLayout& /*cx*/, const TensorLayout& /*h_new*/,
        const TensorLayout& /*c_new*/, const TensorLayout& gates, Handle* handle) {
    // The workspace is reused for two purposes: first by the inner RNNCell
    // matmul, then by exec() as two gates-sized scratch buffers (rearranged
    // gates + activated gates); hence the max of the two requirements.
    TensorLayout tmp_layout;
    auto opr = handle->create_operator<RNNCellForward>();
    opr->param().nonlineMode = param::RNNCell::NonlineMode::IDENTITY;
    opr->deduce_layout(input, weight_ih, bias_ih, hx, weight_hh, bias_hh, tmp_layout);
    size_t rnn_cell_need = opr->get_workspace_in_bytes(
            input, weight_ih, bias_ih, hx, weight_hh, bias_hh, gates);
    size_t lstm_cell_need = 2 * tmp_layout.span().dist_byte();
    return rnn_cell_need > lstm_cell_need ? rnn_cell_need : lstm_cell_need;
}

void exec(
        _megdnn_tensor_in input, _megdnn_tensor_in weight_ih,
        _megdnn_tensor_in bias_ih, _megdnn_tensor_in hx,
        _megdnn_tensor_in weight_hh, _megdnn_tensor_in bias_hh,
        _megdnn_tensor_in cx, _megdnn_tensor_out h_new, _megdnn_tensor_out c_new,
        _megdnn_tensor_out gates, _megdnn_workspace workspace, Handle* handle) {
    // Step 1: compute the raw (pre-activation) gates with an identity RNNCell:
    //   gates = input * W_ih^T + b_ih + hx * W_hh^T + b_hh
    auto opr = handle->create_operator<RNNCellForward>();
    opr->param().nonlineMode = param::RNNCell::NonlineMode::IDENTITY;
    opr->exec(input, weight_ih, bias_ih, hx, weight_hh, bias_hh, gates, workspace);

    // Step 2: rearrange gates from batch-major (batch, [i|f|g|o]) to
    // gate-major ([i|f|g|o], batch, hidden) so each gate occupies one
    // contiguous chunk that can be activated with a single elemwise call.
    size_t batch_size = hx.layout.shape[0];
    size_t hidden_size = hx.layout.shape[1];
    // NOTE(review): reconstructed template argument — TypeCvt between
    // identical dtypes acts as a plain element copy here; confirm against
    // upstream.
    auto copy_opr = handle->create_operator<TypeCvtForward>();
    TensorND copy_gates{static_cast<void*>(workspace.raw_ptr), gates.layout};
    TensorLayout hidden_layout{TensorShape{hidden_size}, hx.layout.dtype};
    TensorLayout gateinfo_layout{
            TensorShape{batch_size, hidden_size}, hx.layout.dtype};
    for (size_t i = 0; i < batch_size; i++) {
        for (size_t j = 0; j < 4; j++) {
            // source: row i of gates, j-th hidden_size chunk
            TensorND half_step_states{
                    static_cast<uint8_t*>(gates.raw_ptr()) +
                            (4 * i + j) * hidden_layout.span().dist_byte(),
                    hidden_layout};
            // destination: gate j, batch entry i in the workspace copy
            TensorND half_step_output{
                    static_cast<uint8_t*>(copy_gates.raw_ptr()) +
                            j * gateinfo_layout.span().dist_byte() +
                            i * hidden_layout.span().dist_byte(),
                    hidden_layout};
            copy_opr->exec(half_step_states, half_step_output);
        }
    }
    // Remaining workspace (after the rearranged copy) holds the activated
    // gates.
    void* workspace_ptr = workspace.raw_ptr + copy_gates.layout.span().dist_byte();
    // Write the gate-major arrangement back into the user-visible gates
    // tensor.
    copy_opr->exec(copy_gates, gates);

    // Step 3: activations, written into `tmp`.
    // sigmoid over the first two chunks: input gate i and forget gate f
    TensorND tmp{workspace_ptr, copy_gates.layout};
    TensorLayout gates_ifo_layout{
            TensorShape({batch_size, hidden_size * 2}), copy_gates.layout.dtype};
    TensorND gates_ifo_origin{copy_gates.raw_ptr(), gates_ifo_layout};
    TensorND gates_ifo{tmp.raw_ptr(), gates_ifo_layout};
    auto sigmoid = handle->create_operator<ElemwiseForward>();
    sigmoid->param().mode = Elemwise::Param::Mode::SIGMOID;
    sigmoid->exec({gates_ifo_origin}, gates_ifo);
    // tanh over the third chunk: cell candidate g
    TensorLayout g_layout{
            TensorShape({batch_size, hidden_size}), copy_gates.layout.dtype};
    TensorND g_origin{
            static_cast<uint8_t*>(copy_gates.raw_ptr()) +
                    gates_ifo_layout.span().dist_byte(),
            g_layout};
    TensorND g{
            static_cast<uint8_t*>(tmp.raw_ptr()) +
                    gates_ifo_layout.span().dist_byte(),
            g_layout};
    auto tanh = handle->create_operator<ElemwiseForward>();
    tanh->param().mode = Elemwise::Param::Mode::TANH;
    tanh->exec({g_origin}, g);
    // sigmoid over the last chunk: output gate o
    TensorLayout three_gates_ifo_layout{
            TensorShape({batch_size, hidden_size * 3}), copy_gates.layout.dtype};
    TensorLayout o_layout{
            TensorShape({batch_size, hidden_size}), copy_gates.layout.dtype};
    TensorND o_origin{
            static_cast<uint8_t*>(copy_gates.raw_ptr()) +
                    three_gates_ifo_layout.span().dist_byte(),
            o_layout};
    TensorND o{
            static_cast<uint8_t*>(tmp.raw_ptr()) +
                    three_gates_ifo_layout.span().dist_byte(),
            o_layout};
    sigmoid->exec({o_origin}, o);

    // Views over the activated i and f chunks in tmp.
    TensorND i{tmp.raw_ptr(), g_layout};
    TensorND f{
            static_cast<uint8_t*>(tmp.raw_ptr()) + g_layout.span().dist_byte(),
            g_layout};

    // Step 4: new cell state  c_new = f * cx + i * g  (single fused elemwise).
    auto elewise_mul_add = handle->create_operator<ElemwiseForward>();
    elewise_mul_add->param().mode = Elemwise::Param::Mode::FUSE_MUL_ADD4;
    elewise_mul_add->exec({f, cx, i, g}, c_new);
    // Step 5: new hidden state  h_new = o * tanh(c_new).
    tanh->exec({c_new}, h_new);
    auto elewise_mul = handle->create_operator<ElemwiseForward>();
    elewise_mul->param().mode = Elemwise::Param::Mode::MUL;
    elewise_mul->exec({o, h_new}, h_new);
}
}  // namespace lstm_cell
}  // namespace megdnn