Commit 0180d460 authored by zhaojiaying01

add gru_unit op
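
For reference, a hedged sketch of the single GRU step this op is meant to compute. The gate order in memory is [update, reset, candidate] (matching the frame_size * 3 gate width and the 2 * frame_size * frame_size weight split in the kernel), and the final interpolation is assumed to follow fluid's gru_unit convention of the time — worth double-checking against math::GRUUnitFunctor:

\begin{aligned}
u_t &= \mathrm{act\_gate}(x_u + h_{t-1} W_u + b_u) \\
r_t &= \mathrm{act\_gate}(x_r + h_{t-1} W_r + b_r) \\
\tilde{h}_t &= \mathrm{act\_node}(x_c + (r_t \odot h_{t-1}) W_c + b_c) \\
h_t &= (1 - u_t) \odot h_{t-1} + u_t \odot \tilde{h}_t
\end{aligned}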

Parent d460e482
@@ -66,6 +66,7 @@ const char *G_OP_TYPE_CONV_TRANSPOSE = "conv2d_transpose";
const char *G_OP_TYPE_PRELU = "prelu";
const char *G_OP_TYPE_LOOKUP_TABLE = "lookup_table";
const char *G_OP_TYPE_GRU = "gru";
const char *G_OP_TYPE_GRU_UNIT = "gru_unit";
const char *G_OP_TYPE_CRF = "crf_decoding";
const char *G_OP_TYPE_BILINEAR_INTERP = "bilinear_interp";
const char *G_OP_TYPE_FLATTEN = "flatten";
@@ -149,6 +150,9 @@ std::unordered_map<
{G_OP_TYPE_GRU,
{{"Input", "H0", "Weight", "Bias"},
{"BatchGate", "BatchResetHiddenPrev", "BatchHidden", "Hidden"}}},
{G_OP_TYPE_GRU_UNIT,
{{"Input", "HiddenPrev", "Weight", "Bias"},
{"Gate", "ResetHiddenPrev", "Hidden"}}},
{G_OP_TYPE_CRF, {{"Emission", "Transition", "Label"}, {"ViterbiPath"}}},
{G_OP_TYPE_BILINEAR_INTERP, {{"OutSize", "X"}, {"Out"}}},
{G_OP_TYPE_FLATTEN, {{"X"}, {"Out"}}},
......
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifdef GRU_UNIT_OP
#include "operators/gru_unit_op.h"
namespace paddle_mobile {
namespace operators {
template <typename DeviceType, typename T>
void GruUnitOp<DeviceType, T>::InferShape() const {
auto input_dims = this->param_.InputInput()->dims();
auto hidden_prev_dims = this->param_.InputHiddenPrev()->dims();
auto weight_dims = this->param_.InputWeight()->dims();
int batch_size = input_dims[0];
int input_size = input_dims[1];
int frame_size = hidden_prev_dims[1];
int weight_height = weight_dims[0];
int weight_width = weight_dims[1];
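  // Shape contract enforced below:
  //   Input:      [batch_size, frame_size * 3]
  //   HiddenPrev: [batch_size, frame_size]
  //   Weight:     [frame_size, frame_size * 3]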
  PADDLE_MOBILE_ENFORCE(
      (input_size == frame_size * 3),
      "The input_size must be 3 times the frame_size in GRUUnitOp.");
PADDLE_MOBILE_ENFORCE(
(weight_height == frame_size),
"The shape of Weight matrix must be [frame_size, frame_size * 3].");
PADDLE_MOBILE_ENFORCE(
(weight_width == frame_size * 3),
"The shape of Weight matrix must be [frame_size, frame_size * 3].");
if (this->param_.InputBias()) {
auto bias_dims = this->param_.InputBias()->dims();
int bias_height = bias_dims[0];
int bias_width = bias_dims[1];
PADDLE_MOBILE_ENFORCE((bias_height == 1),
"The shape of Bias must be [1, frame_size * 3].");
PADDLE_MOBILE_ENFORCE((bias_width == frame_size * 3),
"The shape of Bias must be [1, frame_size * 3].");
}
this->param_.OutGate()->Resize({batch_size, frame_size * 3});
this->param_.OutResetHiddenPrev()->Resize({batch_size, frame_size});
this->param_.OutHidden()->Resize({batch_size, frame_size});
}
} // namespace operators
} // namespace paddle_mobile
namespace ops = paddle_mobile::operators;
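// CPU-only registration for now; the MALI_GPU / FPGA / CL branches below are
// intentionally left empty.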
#ifdef PADDLE_MOBILE_CPU
REGISTER_OPERATOR_CPU(gru_unit, ops::GruUnitOp);
#endif
#ifdef PADDLE_MOBILE_MALI_GPU
#endif
#ifdef PADDLE_MOBILE_FPGA
#endif
#ifdef PADDLE_MOBILE_CL
#endif
#endif
\ No newline at end of file
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifdef GRU_UNIT_OP
#pragma once
#include "framework/operator.h"
#include "operators/kernel/gru_unit_kernel.h"
#include "operators/op_param.h"
namespace paddle_mobile {
namespace operators {
template <typename DeviceType, typename T>
class GruUnitOp : public framework::OperatorWithKernel<
DeviceType, GruUnitParam<DeviceType>,
operators::GruUnitKernel<DeviceType, T>> {
public:
GruUnitOp(const std::string &type, const VariableNameMap &inputs,
const VariableNameMap &outputs, const AttributeMap &attrs,
std::shared_ptr<Scope> scope)
      : framework::OperatorWithKernel<DeviceType, GruUnitParam<DeviceType>,
                                      operators::GruUnitKernel<DeviceType, T>>(
            type, inputs, outputs, attrs, scope) {}
void InferShape() const override;
};
} // namespace operators
} // namespace paddle_mobile
#endif
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifdef GRU_UNIT_OP
#include "operators/kernel/gru_unit_kernel.h"
#include "operators/kernel/central-arm-func/gru_unit_arm_func.h"
namespace paddle_mobile {
namespace operators {
template <>
bool GruUnitKernel<CPU, float>::Init(GruUnitParam<CPU> *param) {
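  // No weight transformation or buffer pre-allocation is needed up front.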
return true;
}
template <>
void GruUnitKernel<CPU, float>::Compute(const GruUnitParam<CPU> &param) {
  // P is not deducible from the argument, so it must be given explicitly.
  GruUnitCompute<float>(param);
}
template class GruUnitKernel<CPU, float>;
} // namespace operators
} // namespace paddle_mobile
#endif
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifdef GRU_UNIT_OP
#pragma once
#include <cstring>

#include "operators/kernel/activation_kernel.h"
#include "operators/math/gemm.h"
#include "operators/math/gru_compute.h"
#include "operators/math/math_function.h"
#include "operators/op_param.h"
namespace paddle_mobile {
namespace operators {
template <typename P>
void GruUnitCompute(const GruUnitParam<CPU>& param) {
  auto* input = param.InputInput();
  auto* hidden_prev = param.InputHiddenPrev();
  auto* weight = param.InputWeight();
  auto* bias = param.InputBias();
  auto* gate = param.OutGate();
  auto* reset_hidden_prev = param.OutResetHiddenPrev();
  auto* hidden = param.OutHidden();

  int batch_size = input->dims()[0];
  int frame_size = hidden_prev->dims()[1];

  // The gate buffer starts as the projected input (shape
  // [batch_size, frame_size * 3]); GRUUnitFunctor adds the recurrent term.
  P* gate_data = gate->mutable_data<P>();
  memcpy(gate_data, input->data<P>(), input->numel() * sizeof(P));
  if (bias) {
    math::RowwiseAdd<CPU, float> add_bias;
    add_bias(*gate, *bias, gate);
  }

  // Weight layout: the first frame_size * frame_size * 2 elements are the
  // update/reset gate weights, the remainder the candidate-state weights.
  const P* weight_data = weight->data<P>();
  math::GRUMetaValue<P> gru_value;
  gru_value.gate_weight = const_cast<P*>(weight_data);
  gru_value.state_weight =
      const_cast<P*>(weight_data + 2 * frame_size * frame_size);
  gru_value.gate_value = gate_data;
  gru_value.reset_output_value = reset_hidden_prev->mutable_data<P>();
  gru_value.output_value = hidden->mutable_data<P>();
  // The previous hidden state is an input here, not the freshly written output.
  gru_value.prev_out_value = const_cast<P*>(hidden_prev->data<P>());

  auto active_node = math::GetActivationType(param.Activation());
  auto active_gate = math::GetActivationType(param.GateActivation());
  math::GRUUnitFunctor<CPU, float>::compute(gru_value, frame_size, batch_size,
                                            active_node, active_gate);
}
} // namespace operators
} // namespace paddle_mobile
#endif
\ No newline at end of file
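
For intuition, a minimal self-contained sketch of what GRUUnitFunctor is expected to do for one batch row. This is hypothetical reference code, not the paddle-mobile implementation; it assumes sigmoid gates, a tanh candidate, the [update | reset | candidate] gate layout, and the weight split used above:

#include <cmath>
#include <vector>

// One GRU step for a single row. `gate` holds the projected input
// [u | r | c] (frame entries each) on entry, the activated gates on exit.
void GruUnitRowReference(std::vector<float>& gate,
                         const std::vector<float>& hprev,        // frame
                         const std::vector<float>& weight,       // see below
                         std::vector<float>& reset_hprev,        // frame
                         std::vector<float>& hidden,             // frame
                         int frame) {
  auto sigmoid = [](float v) { return 1.0f / (1.0f + std::exp(-v)); };
  // gate_weight [frame, 2*frame] row-major, then state_weight [frame, frame].
  const float* gate_w = weight.data();
  const float* state_w = weight.data() + 2 * frame * frame;
  // Update and reset gates: g[j] = sigmoid(x[j] + sum_k hprev[k] * W[k][j]).
  for (int j = 0; j < 2 * frame; ++j) {
    float acc = gate[j];
    for (int k = 0; k < frame; ++k) acc += hprev[k] * gate_w[k * 2 * frame + j];
    gate[j] = sigmoid(acc);
  }
  // ResetHiddenPrev output: r ⊙ h_prev.
  for (int k = 0; k < frame; ++k) reset_hprev[k] = gate[frame + k] * hprev[k];
  // Candidate state: tanh(x_c + (r ⊙ h_prev) · W_c).
  for (int j = 0; j < frame; ++j) {
    float acc = gate[2 * frame + j];
    for (int k = 0; k < frame; ++k)
      acc += reset_hprev[k] * state_w[k * frame + j];
    gate[2 * frame + j] = std::tanh(acc);
  }
  // Final interpolation (assumed convention): h = (1-u) ⊙ h_prev + u ⊙ c.
  for (int j = 0; j < frame; ++j)
    hidden[j] = (1.0f - gate[j]) * hprev[j] + gate[j] * gate[2 * frame + j];
}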
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifdef GRU_UNIT_OP
#pragma once
#include "framework/operator.h"
#include "operators/op_param.h"
namespace paddle_mobile {
namespace operators {
template <typename DeviceType, typename T>
class GruUnitKernel
: public framework::OpKernelBase<DeviceType, GruUnitParam<DeviceType>> {
public:
void Compute(const GruUnitParam<DeviceType>& param);
bool Init(GruUnitParam<DeviceType>* param);
};
} // namespace operators
} // namespace paddle_mobile
#endif
\ No newline at end of file
@@ -45,6 +45,19 @@ inline ActivationType GetActivationType(const std::string &type) {
PADDLE_MOBILE_THROW_EXCEPTION("Unsupported activation type.");
}
inline ActivationType GetActivationType(const int type) {
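  // Integer overload for attribute-encoded activations (gru_unit passes
  // "activation" / "gate_activation" as ints).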
if (type == 0) {
return ActivationType::IDENTITY;
} else if (type == 1) {
return ActivationType::SIGMOID;
} else if (type == 2) {
return ActivationType::TANH;
} else if (type == 3) {
return ActivationType::RELU;
}
  PADDLE_MOBILE_THROW_EXCEPTION("Unsupported activation type.");
}
#if defined(__ARM_NEON__) || defined(__ARM_NEON)
template <ActivationType Act = IDENTITY>
inline float32x4_t vActiveq_f32(const float32x4_t &x) {
......
@@ -74,6 +74,13 @@ class OpParam {
static T *InputH0From(const VariableNameMap &inputs, const Scope &scope) {
return GetVarValue<T>("H0", inputs, scope);
}
template <typename T>
static T *InputHiddenPrevFrom(const VariableNameMap &inputs,
const Scope &scope) {
return GetVarValue<T>("HiddenPrev", inputs, scope);
}
template <typename T>
static T *InputAlphaFrom(const VariableNameMap &inputs, const Scope &scope) {
return GetVarValue<T>("Alpha", inputs, scope);
@@ -214,6 +221,11 @@ class OpParam {
return GetVarValue<T>("BatchGate", outputs, scope);
}
template <typename T>
static T *OutputGateFrom(const VariableNameMap &outputs, const Scope &scope) {
return GetVarValue<T>("Gate", outputs, scope);
}
template <typename T>
static T *OutputViterbiPathFrom(const VariableNameMap &outputs,
const Scope &scope) {
@@ -225,6 +237,12 @@ class OpParam {
return GetVarValue<T>("BatchResetHiddenPrev", outputs, scope);
}
template <typename T>
static T *OutputResetHiddenPrevFrom(const VariableNameMap &outputs,
const Scope &scope) {
return GetVarValue<T>("ResetHiddenPrev", outputs, scope);
}
template <typename T>
static T *OutputBatchHiddenFrom(const VariableNameMap &outputs,
const Scope &scope) {
@@ -2444,6 +2462,51 @@ class GruParam : public OpParam {
};
#endif
#ifdef GRU_UNIT_OP
template <typename Dtype>
class GruUnitParam : public OpParam {
typedef typename DtypeTensorTrait<Dtype>::gtype GType;
public:
GruUnitParam(const VariableNameMap &inputs, const VariableNameMap &outputs,
const AttributeMap &attrs, const Scope &scope) {
input_input_ = InputFrom<GType>(inputs, scope);
input_hidden_prev_ = InputHiddenPrevFrom<GType>(inputs, scope);
input_bias_ = InputBiasFrom<GType>(inputs, scope);
input_weight_ = InputWeightFrom<GType>(inputs, scope);
output_gate_ = OutputGateFrom<GType>(outputs, scope);
output_reset_hidden_prev_ =
OutputResetHiddenPrevFrom<GType>(outputs, scope);
output_hidden_ = OutputHiddenFrom<GType>(outputs, scope);
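    // Activation codes are ints decoded later by math::GetActivationType:
    // 0 identity, 1 sigmoid, 2 tanh, 3 relu.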
activation_ = GetAttr<int>("activation", attrs);
gate_activation_ = GetAttr<int>("gate_activation", attrs);
}
const GType *InputInput() const { return input_input_; }
const GType *InputWeight() const { return input_weight_; }
const GType *InputHiddenPrev() const { return input_hidden_prev_; }
const GType *InputBias() const { return input_bias_; }
const int &Activation() const { return activation_; }
const int &GateActivation() const { return gate_activation_; }
GType *OutGate() const { return output_gate_; }
GType *OutResetHiddenPrev() const { return output_reset_hidden_prev_; }
GType *OutHidden() const { return output_hidden_; }
private:
GType *input_input_;
GType *input_hidden_prev_;
GType *input_bias_;
GType *input_weight_;
GType *output_gate_;
GType *output_reset_hidden_prev_;
GType *output_hidden_;
int activation_;
int gate_activation_;
};
#endif
#ifdef FLATTEN_OP
template <typename Dtype>
class FlattenParam : public OpParam {
......
@@ -256,6 +256,7 @@ if(NOT FOUND_MATCH)
set(IM2SEQUENCE_OP ON)
set(LOOKUP_OP ON)
set(GRU_OP ON)
set(GRU_UNIT_OP ON)
set(CRF_OP ON)
set(BILINEAR_INTERP_OP ON)
set(SPLIT_OP ON)
@@ -450,6 +451,10 @@ if (GRU_OP)
add_definitions(-DGRU_OP)
endif()
if (GRU_UNIT_OP)
add_definitions(-DGRU_UNIT_OP)
endif()
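# -DGRU_UNIT_OP compiles in the "#ifdef GRU_UNIT_OP" blocks added in this commit.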
if (CRF_OP)
add_definitions(-DCRF_OP)
endif()
......