Commit 90a28f51 authored by xiaogang, committed by GitHub

lstm op (#3018)

* feat: add lstm op && kernel
      test=develop
Parent d3f627d2
@@ -123,5 +123,6 @@ if (NOT HAS_ARM_MATH_LIB_DIR)
anchor_generator.cc
split_merge_lod_tenosr.cc
reduce_prod.cc
lstm.cc
DEPS ${lite_kernel_deps} context tensor)
endif()
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/backends/arm/math/lstm.h"
#include "lite/backends/arm/math/funcs.h"
namespace paddle {
namespace lite {
namespace arm {
namespace math {
void add_bias_rowwise(Tensor* input,
const Tensor* bias,
int start_w,
int end_w) {
auto in_dim = input->dims();
int width = input->numel() / in_dim[0];
int w_adds = width < end_w ? width : end_w;
float* i_data = input->mutable_data<float>();
const float* b_data = bias->data<float>();
for (int i = 0; i < in_dim[0]; ++i) {
for (int w = start_w; w < w_adds; ++w) {
i_data[w] += b_data[w];
}
i_data += width;  // advance to the next row
}
}
void vector_dot(
float* out, const float* in, const float* v1, int size, const float* v2) {
int loop = size >> 2;
int remain = size & 3;
const float* in_ptr = in;
float* out_ptr = out;
const float* v1_ptr = v1;
const float* v2_ptr = v2;
for (int i = 0; i < loop; ++i) {
float32x4_t din = vld1q_f32(in_ptr);
float32x4_t data1 = vld1q_f32(v1_ptr);
if (!v2) {
// out = in * v1
float32x4_t dout = vmulq_f32(din, data1);
vst1q_f32(out_ptr, dout);
} else {
// out = in + v1 * v2
float32x4_t data2 = vld1q_f32(v2_ptr);
float32x4_t dout = vmlaq_f32(din, data1, data2);
vst1q_f32(out_ptr, dout);
v2_ptr += 4;
}
in_ptr += 4;
v1_ptr += 4;
out_ptr += 4;
}
// scalar tail for the remaining (size & 3) elements
for (int i = 0; i < remain; ++i) {
if (!v2) {
out_ptr[i] = in_ptr[i] * v1_ptr[i];
} else {
out_ptr[i] = in_ptr[i] + v1_ptr[i] * v2_ptr[i];
}
}
}
} // namespace math
} // namespace arm
} // namespace lite
} // namespace paddle
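Reviewer note: the NEON `vector_dot` above covers two fused element-wise patterns, `out = in * v1` and `out = in + v1 * v2`. A minimal scalar reference is handy for unit-testing it; the sketch below is hypothetical (not part of this commit) and assumes plain float buffers of equal length:

```cpp
#include <cassert>

// Scalar reference for vector_dot: v2 == nullptr selects the element-wise
// product, otherwise the fused multiply-add in + v1 * v2.
void vector_dot_ref(float* out, const float* in, const float* v1,
                    int size, const float* v2 = nullptr) {
  for (int i = 0; i < size; ++i) {
    out[i] = v2 ? in[i] + v1[i] * v2[i] : in[i] * v1[i];
  }
}

int main() {
  float in[5] = {1, 2, 3, 4, 5};
  float v1[5] = {2, 2, 2, 2, 2};
  float v2[5] = {3, 3, 3, 3, 3};
  float out[5];
  vector_dot_ref(out, in, v1, 5);      // out[i] = in[i] * v1[i]
  assert(out[4] == 10.0f);
  vector_dot_ref(out, in, v1, 5, v2);  // out[i] = in[i] + v1[i] * v2[i]
  assert(out[4] == 11.0f);
  return 0;
}
```

A size of 5 exercises both the vector-lane path (`size >> 2`) and the scalar tail (`size & 3`) when checked against the NEON version.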
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <arm_neon.h>
#include <string>
#include "lite/backends/arm/math/activation.h"
#include "lite/core/tensor.h"
#include "lite/utils/logging.h"
namespace paddle {
namespace lite {
namespace arm {
namespace math {
void add_bias_rowwise(Tensor* input,
const Tensor* bias,
int start_w,
int end_w);
inline float* row_offset(Tensor& input, int start) { // NOLINT
auto in_dim = input.dims();
int width = input.numel() / in_dim[0];
int offset = start < in_dim[0] ? start * width : input.numel();
return input.mutable_data<float>() + offset;
}
template <class T>
struct LstmMetaValue {
T* gate_value;
T* prev_state_value;
T* state_value;
T* state_active_value;
T* output_value;
T* check_ig;
T* check_fg;
T* check_og;
};
template <typename T>
void activation(
const T* din, T* dout, int size, std::string act_str, int threads) {
if (act_str == "sigmoid") {
act_sigmoid(din, dout, size, threads);
} else if (act_str == "tanh") {
act_tanh(din, dout, size, threads);
} else if (act_str == "relu") {
act_relu(din, dout, size, threads);
} else {
LOG(FATAL) << "unsupported activation: " << act_str;
}
}
void vector_dot(float* out,
const float* in,
const float* v1,
int size,
const float* v2 = nullptr);
template <typename T>
struct LstmUnitFunctor {
static void compute(LstmMetaValue<T> value,
int frame_size,
int batch_size,
T cell_clip,
std::string gate_act,
std::string cell_act,
std::string cand_act,
int threads) {
for (int b = 0; b < batch_size; ++b) {
const int temp_len = frame_size;
float zero_ptr[temp_len]; // NOLINT
memset(zero_ptr, 0, sizeof(float) * temp_len);
T* value_in = value.gate_value;
T* value_ig = value_in + frame_size;
T* value_fg = value_ig + frame_size;
T* value_og = value_fg + frame_size;
T* state = value.state_value;
T* state_act = value.state_active_value;
T* output = value.output_value;
T* check_i = value.check_ig ? value.check_ig : zero_ptr;
T* check_f = value.check_fg ? value.check_fg : zero_ptr;
T* check_o = value.check_og ? value.check_og : zero_ptr;
T* prev_state =
value.prev_state_value ? value.prev_state_value : zero_ptr;
activation(value_in, value_in, frame_size, gate_act, threads);
vector_dot(value_ig, value_ig, prev_state, frame_size, check_i);
vector_dot(value_fg, value_fg, prev_state, frame_size, check_f);
activation(value_ig, value_ig, frame_size, cell_act, threads);
activation(value_fg, value_fg, frame_size, cell_act, threads);
vector_dot(state, value_in, value_ig, frame_size);
vector_dot(state, state, prev_state, frame_size, value_fg);
for (int i = 0; i < frame_size; ++i) {
if (cell_clip > 0.0) {
if (state[i] < -1.0 * cell_clip) {
state[i] = -1.0 * cell_clip;
}
if (state[i] > cell_clip) {
state[i] = cell_clip;
}
}
}
vector_dot(value_og, value_og, state, frame_size, check_o);
activation(value_og, value_og, frame_size, cell_act, threads);
activation(state, state_act, frame_size, cand_act, threads);
vector_dot(value.output_value, value_og, state_act, frame_size);
value.gate_value += frame_size * 4;
value.state_value += frame_size;
value.state_active_value += frame_size;
value.output_value += frame_size;
if (value.prev_state_value) {
value.prev_state_value += frame_size;
}
}
}
};
} // namespace math
} // namespace arm
} // namespace lite
} // namespace paddle
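As an aside for reviewers, `LstmUnitFunctor::compute` implements the standard peephole LSTM cell update. Written with the usual defaults (sigmoid for the gate activation, tanh for the candidate and cell activations; all three are configurable here), where the tilde terms are the four `frame_size`-wide slices of `gate_value` after the GEMM and $w_{ic}, w_{fc}, w_{oc}$ are the peephole weights in `check_ig`/`check_fg`/`check_og`:

$$
\begin{aligned}
i_t &= \sigma(\tilde{i}_t + w_{ic} \odot c_{t-1}) \\
f_t &= \sigma(\tilde{f}_t + w_{fc} \odot c_{t-1}) \\
c_t &= f_t \odot c_{t-1} + i_t \odot \tanh(\tilde{c}_t) \\
o_t &= \sigma(\tilde{o}_t + w_{oc} \odot c_t) \\
h_t &= o_t \odot \tanh(c_t)
\end{aligned}
$$

When `cell_clip > 0`, $c_t$ is clamped to $[-\text{cell\_clip}, \text{cell\_clip}]$ before the output gate and $h_t$ are computed.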
@@ -101,6 +101,7 @@ add_kernel(beam_search_compute_arm ARM extra SRCS beam_search_compute.cc DEPS ${
add_kernel(fill_constant_compute_arm ARM basic SRCS fill_constant_compute.cc DEPS ${lite_kernel_deps} math_arm)
add_kernel(lod_reset_compute_arm ARM extra SRCS lod_reset_compute.cc DEPS ${lite_kernel_deps} math_arm)
add_kernel(is_empty_compute_arm ARM extra SRCS is_empty_compute.cc DEPS ${lite_kernel_deps} math_arm)
add_kernel(lstm_arm ARM extra SRCS lstm_compute.cc DEPS ${lite_kernel_deps} math_arm)
lite_cc_test(test_scale_compute_arm SRCS scale_compute_test.cc DEPS scale_compute_arm)
......
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/kernels/arm/lstm_compute.h"
#include <string>
#include <vector>
#include "lite/backends/arm/math/lstm.h"
#include "lite/backends/arm/math/sequence2batch.h"
#include "lite/backends/arm/math/sgemm.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace arm {
template <typename T>
void LstmCompute<T>::Run() {
auto& param = this->Param<operators::LstmParam>();
auto input = param.Input;
auto weight = param.Weight;
auto bias = param.Bias;
auto hidden_t0 = param.H0;
auto cell_t0 = param.C0;
auto batch_gate = param.BatchGate;
auto hidden_out = param.Hidden;
auto cell_out = param.Cell;
auto batch_cell_pre_act = param.BatchCellPreAct;
batch_gate->template mutable_data<T>();
hidden_out->template mutable_data<T>();
cell_out->template mutable_data<T>();
bool is_reverse = param.is_reverse;
lite::arm::math::LoDTensor2BatchFunctor<T> to_batch;
to_batch(*input, batch_gate, true, is_reverse);
auto in_dims = input->dims();
int frame_size = static_cast<int>(in_dims[1] / 4);
DDimLite dims(std::vector<int64_t>{in_dims[0], frame_size});
if (bias) {
// add the gate bias (the first 4 * frame_size bias columns) to every row
lite::arm::math::add_bias_rowwise(batch_gate, bias, 0, 4 * frame_size);
}
lite::arm::math::LstmMetaValue<T> lstm_value;
if (bias && param.use_peepholes) {
T* bias_data = const_cast<T*>(bias->template data<T>());
// the code style in LstmMetaValue will be updated later.
lstm_value.check_ig = bias_data + 4 * frame_size;
lstm_value.check_fg = lstm_value.check_ig + frame_size;
lstm_value.check_og = lstm_value.check_fg + frame_size;
} else {
lstm_value.check_ig = nullptr;
lstm_value.check_fg = nullptr;
lstm_value.check_og = nullptr;
}
lstm_value.prev_state_value = nullptr;
Tensor ordered_c0;
std::vector<uint64_t> order(batch_gate->lod()[2]);
if (cell_t0) {
// Batch computing for LSTM reorders the input sequences by length,
// so the initial cell state must be reordered the same way.
lite::arm::math::ReorderInitState<T>(*cell_t0, order, &ordered_c0, true);
lstm_value.prev_state_value = ordered_c0.mutable_data<T>();
}
// Use local tensors for the intermediate batch results.
Tensor batch_hidden, batch_cell;
batch_hidden.Resize(dims);
batch_cell.Resize(dims);
batch_cell_pre_act->Resize(dims);
batch_hidden.mutable_data<T>();
batch_cell.mutable_data<T>();
batch_cell_pre_act->template mutable_data<T>();
auto batch_starts = batch_gate->lod()[0];
size_t num_batch = batch_starts.size() - 1;
std::string gate_act = param.gate_activation;
std::string cell_act = param.cell_activation;
std::string cand_act = param.candidate_activation;
int matrix_width = batch_gate->numel() / in_dims[0];
auto& ctx = this->ctx_->template As<ARMContext>();
for (size_t n = 0; n < num_batch; n++) {
int bstart = static_cast<int>(batch_starts[n]);
int bend = static_cast<int>(batch_starts[n + 1]);
auto gate_t = lite::arm::math::row_offset(*batch_gate, bstart);
auto out_t = lite::arm::math::row_offset(batch_hidden, bstart);
auto cell_t = lite::arm::math::row_offset(batch_cell, bstart);
auto cell_pre_act_t =
lite::arm::math::row_offset(*batch_cell_pre_act, bstart);
int cur_batch_size = bend - bstart;
operators::ActivationParam act_param;
act_param.has_active = false;
if (n > 0) {
int pre_h_start = static_cast<int>(batch_starts[n - 1]);
int pre_h_end = pre_h_start + cur_batch_size;
auto pre_hidden_t =
lite::arm::math::row_offset(batch_hidden, pre_h_start);
int M = pre_h_end - pre_h_start;
int N = matrix_width;
int K = frame_size;
lite::arm::math::sgemm(false,
false,
M,
N,
K,
1,
pre_hidden_t,
K,
weight->template data<T>(),
N,
1,
gate_t,
N,
nullptr,
false,
act_param,
&ctx);
} else if (hidden_t0) {
// If n == 0 and there is no initialized hidden state, H0 is all zeros
// and the W_h * H0 computation can be skipped.
// If n == 0 and an initialized hidden state exists, compute W_h * H0.
// Batch computing for LSTM reorders the input sequences by length,
// so the initial hidden state must be reordered the same way.
Tensor ordered_h0;
lite::arm::math::ReorderInitState<T>(
*hidden_t0, order, &ordered_h0, true);
int M = ordered_h0.dims()[0];
int N = matrix_width;
int K = frame_size;
lite::arm::math::sgemm(false,
false,
M,
N,
K,
1,
ordered_h0.data<T>(),
K,
weight->template data<T>(),
N,
1,
gate_t,
N,
nullptr,
false,
act_param,
&ctx);
}
lstm_value.gate_value = gate_t;
lstm_value.output_value = out_t;
lstm_value.state_value = cell_t;
lstm_value.state_active_value = cell_pre_act_t;
T cell_clip = 0.0;
// Note the argument order: LstmUnitFunctor declares its activation
// parameters as (gate_act, cell_act, cand_act), and this call passes
// (cand_act, gate_act, cell_act) to match how the functor applies them.
lite::arm::math::LstmUnitFunctor<T>::compute(lstm_value,
frame_size,
cur_batch_size,
cell_clip,
cand_act,
gate_act,
cell_act,
ctx.threads());
lstm_value.prev_state_value = lstm_value.state_value;
}
lite::arm::math::Batch2LoDTensorFunctor<T> to_seq;
auto* lod_hidden = batch_hidden.mutable_lod();
*lod_hidden = batch_gate->lod();
to_seq(batch_hidden, hidden_out);
auto* lod_cell = batch_cell.mutable_lod();
*lod_cell = batch_gate->lod();
to_seq(batch_cell, cell_out);
}
} // namespace arm
} // namespace kernels
} // namespace lite
} // namespace paddle
REGISTER_LITE_KERNEL(lstm,
kARM,
kFloat,
kNCHW,
paddle::lite::kernels::arm::LstmCompute<float>,
def)
.BindInput("Input", {LiteType::GetTensorTy(TARGET(kARM))})
.BindInput("Weight", {LiteType::GetTensorTy(TARGET(kARM))})
.BindInput("Bias", {LiteType::GetTensorTy(TARGET(kARM))})
.BindOutput("Hidden", {LiteType::GetTensorTy(TARGET(kARM))})
.BindOutput("Cell", {LiteType::GetTensorTy(TARGET(kARM))})
.BindOutput("BatchGate", {LiteType::GetTensorTy(TARGET(kARM))})
.BindOutput("BatchCellPreAct", {LiteType::GetTensorTy(TARGET(kARM))})
.Finalize();
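The kernel's control flow relies on `LoDTensor2BatchFunctor`: variable-length sequences are sorted by descending length and regrouped so that time step `t` is a dense batch of the `t`-th elements of all sequences still alive, with `batch_gate->lod()[0]` (`batch_starts`) holding the row offset of each step. A standalone sketch of that offset computation (hypothetical helper, not the Paddle-Lite API):

```cpp
#include <algorithm>
#include <cstdio>
#include <functional>
#include <vector>

// batch_starts[t] = row offset of time step t's batch after sequences are
// sorted by descending length; the batch at step t contains one row per
// sequence whose length exceeds t.
std::vector<size_t> batch_starts_from_lengths(std::vector<size_t> lengths) {
  std::sort(lengths.begin(), lengths.end(), std::greater<size_t>());
  const size_t max_len = lengths.empty() ? 0 : lengths.front();
  std::vector<size_t> starts = {0};
  for (size_t t = 0; t < max_len; ++t) {
    size_t alive = std::count_if(lengths.begin(), lengths.end(),
                                 [t](size_t len) { return len > t; });
    starts.push_back(starts.back() + alive);
  }
  return starts;
}

int main() {
  // Three sequences of lengths 3, 1, 2 -> per-step batch sizes 3, 2, 1.
  for (size_t s : batch_starts_from_lengths({3, 1, 2})) {
    std::printf("%zu ", s);  // prints: 0 3 5 6
  }
  return 0;
}
```

This is why the main loop above runs `num_batch = batch_starts.size() - 1` times (once per time step) and feeds rows `[bstart, bend)` to the GEMM and the LSTM unit.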
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <algorithm>
#include "lite/core/kernel.h"
#include "lite/core/op_registry.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace arm {
template <typename T>
class LstmCompute : public KernelLite<TARGET(kARM), PRECISION(kFloat)> {
public:
void Run() override;
virtual ~LstmCompute() = default;
};
} // namespace arm
} // namespace kernels
} // namespace lite
} // namespace paddle
@@ -134,6 +134,7 @@ add_operator(search_aligned_mat_mul_op extra SRCS search_aligned_mat_mul_op.cc D
add_operator(search_seq_fc_op extra SRCS search_seq_fc_op.cc DEPS ${op_DEPS})
add_operator(sequence_topk_avg_pooling_op basic SRCS sequence_topk_avg_pooling_op.cc DEPS ${op_DEPS})
add_operator(search_fc_op basic SRCS search_fc_op.cc DEPS ${op_DEPS})
add_operator(lstm_op extra SRCS lstm_op.cc DEPS ${op_DEPS})
if (NOT LITE_WITH_X86)
lite_cc_test(test_fc_op SRCS fc_op_test.cc
......
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/operators/lstm_op.h"
#include "lite/core/op_registry.h"
namespace paddle {
namespace lite {
namespace operators {
bool LstmOp::CheckShape() const {
CHECK_OR_FALSE(param_.Input);
CHECK_OR_FALSE(param_.Weight);
CHECK_OR_FALSE(param_.Bias);
return true;
}
bool LstmOp::InferShape() const {
auto in_dims = param_.Input->dims();
if (param_.H0) {
CHECK(param_.C0) << "lstm must have both H0 and C0 at the same time";
auto h_dims = param_.H0->dims();
auto c_dims = param_.C0->dims();
CHECK_EQ(h_dims, c_dims) << "H0 and C0 dims must be the same";
}
int frame_size = in_dims[1] / 4;
auto w_dims = param_.Weight->dims();
CHECK_EQ(w_dims.size(), 2) << "Weight must have 2 dims";
CHECK_EQ(w_dims[0], frame_size) << "Weight's first dim must be "
<< frame_size;
CHECK_EQ(w_dims[1], 4 * frame_size) << "Weight's second dim must be "
<< 4 * frame_size;
auto b_dims = param_.Bias->dims();
CHECK_EQ(b_dims.size(), 2) << "Bias must have 2 dims";
CHECK_EQ(b_dims[0], 1) << "Bias's first dim must be 1";
if (param_.use_peepholes) {
CHECK_EQ(b_dims[1], 7 * frame_size) << "Bias's second dim must be "
<< 7 * frame_size;
} else {
CHECK_EQ(b_dims[1], 4 * frame_size) << "Bias's second dim must be "
<< 4 * frame_size;
}
}
DDimLite out_dims(std::vector<int64_t>{in_dims[0], frame_size});
param_.Hidden->Resize(out_dims);
param_.Cell->Resize(out_dims);
param_.BatchCellPreAct->Resize(out_dims);
param_.BatchGate->Resize(in_dims);
auto hidden_lod = param_.Hidden->mutable_lod();
*hidden_lod = param_.Input->lod();
auto cell_lod = param_.Cell->mutable_lod();
*cell_lod = param_.Input->lod();
return true;
}
bool LstmOp::AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) {
param_.Input =
scope->FindVar(opdesc.Input("Input").front())->GetMutable<lite::Tensor>();
param_.Weight = scope->FindVar(opdesc.Input("Weight").front())
->GetMutable<lite::Tensor>();
param_.Bias =
scope->FindVar(opdesc.Input("Bias").front())->GetMutable<lite::Tensor>();
param_.Hidden = scope->FindVar(opdesc.Output("Hidden").front())
->GetMutable<lite::Tensor>();
param_.Cell =
scope->FindVar(opdesc.Output("Cell").front())->GetMutable<lite::Tensor>();
param_.BatchGate = scope->FindVar(opdesc.Output("BatchGate").front())
->GetMutable<lite::Tensor>();
param_.BatchCellPreAct =
scope->FindVar(opdesc.Output("BatchCellPreAct").front())
->GetMutable<lite::Tensor>();
CHECK(param_.Input);
CHECK(param_.Weight);
CHECK(param_.Bias);
if (opdesc.Input("C0").size()) {
param_.C0 =
scope->FindVar(opdesc.Input("C0").front())->GetMutable<lite::Tensor>();
}
if (opdesc.Input("H0").size()) {
param_.H0 =
scope->FindVar(opdesc.Input("H0").front())->GetMutable<lite::Tensor>();
}
param_.use_peepholes = opdesc.GetAttr<bool>("use_peepholes");
param_.is_reverse = opdesc.GetAttr<bool>("is_reverse");
param_.gate_activation = opdesc.GetAttr<std::string>("gate_activation");
param_.cell_activation = opdesc.GetAttr<std::string>("cell_activation");
param_.candidate_activation =
opdesc.GetAttr<std::string>("candidate_activation");
return true;
}
} // namespace operators
} // namespace lite
} // namespace paddle
REGISTER_LITE_OP(lstm, paddle::lite::operators::LstmOp);
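To make `InferShape`'s contract concrete: for hidden size `D`, `Input` has width `4D` (the four pre-activation gate slices), `Weight` is `[D, 4D]` (the recurrent projection), and `Bias` is `[1, 7D]` with peepholes (`4D` gate bias plus the three `D`-wide peephole vectors sliced off in `LstmCompute::Run`) or `[1, 4D]` without. A throwaway arithmetic check (illustrative numbers only):

```cpp
#include <cstdio>

int main() {
  const int input_width = 128;             // in_dims[1]
  const int frame_size = input_width / 4;  // hidden size D = 32
  std::printf("Weight:              [%d, %d]\n", frame_size, 4 * frame_size);
  std::printf("Bias (peepholes):    [1, %d]\n", 7 * frame_size);  // 224
  std::printf("Bias (no peepholes): [1, %d]\n", 4 * frame_size);  // 128
  std::printf("Hidden/Cell width:   %d\n", frame_size);           // 32
  return 0;
}
```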
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <string>
#include <vector>
#include "lite/core/op_lite.h"
#include "lite/core/scope.h"
#include "lite/utils/all.h"
namespace paddle {
namespace lite {
namespace operators {
class LstmOp : public OpLite {
public:
LstmOp() {}
explicit LstmOp(const std::string &op_type) : OpLite(op_type) {}
bool CheckShape() const override;
bool InferShape() const override;
bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override;
void AttachKernel(KernelBase *kernel) override { kernel->SetParam(param_); }
std::string DebugString() const override { return "lstm"; }
private:
mutable LstmParam param_;
};
} // namespace operators
} // namespace lite
} // namespace paddle
@@ -1134,6 +1134,22 @@ struct GridSamplerParam {
lite::Tensor* out{};
lite::Tensor* grid{};
};
struct LstmParam {
lite::Tensor* Input{};
lite::Tensor* Weight{};
lite::Tensor* Bias{};
lite::Tensor* Hidden{};
lite::Tensor* Cell{};
lite::Tensor* BatchGate{};
lite::Tensor* BatchCellPreAct{};
lite::Tensor* H0{nullptr};
lite::Tensor* C0{nullptr};
bool use_peepholes{true};
bool is_reverse{false};
std::string gate_activation;
std::string cell_activation;
std::string candidate_activation;
};
} // namespace operators
} // namespace lite
......