diff --git a/src/common/types.cpp b/src/common/types.cpp index cd5e66517f159f5f9db118313b78ccd2a8c216a8..d1a1a55a89f69a8d6f195e548b864af8d5bd4e64 100644 --- a/src/common/types.cpp +++ b/src/common/types.cpp @@ -66,6 +66,7 @@ const char *G_OP_TYPE_CONV_TRANSPOSE = "conv2d_transpose"; const char *G_OP_TYPE_PRELU = "prelu"; const char *G_OP_TYPE_LOOKUP_TABLE = "lookup_table"; const char *G_OP_TYPE_GRU = "gru"; +const char *G_OP_TYPE_GRU_UNIT = "gru_unit"; const char *G_OP_TYPE_CRF = "crf_decoding"; const char *G_OP_TYPE_BILINEAR_INTERP = "bilinear_interp"; const char *G_OP_TYPE_FLATTEN = "flatten"; @@ -149,6 +150,9 @@ std::unordered_map< {G_OP_TYPE_GRU, {{"Input", "H0", "Weight", "Bias"}, {"BatchGate", "BatchResetHiddenPrev", "BatchHidden", "Hidden"}}}, + {G_OP_TYPE_GRU_UNIT, + {{"Input", "HiddenPrev", "Weight", "Bias"}, + {"Gate", "ResetHiddenPrev", "Hidden"}}}, {G_OP_TYPE_CRF, {{"Emission", "Transition", "Label"}, {"ViterbiPath"}}}, {G_OP_TYPE_BILINEAR_INTERP, {{"OutSize", "X"}, {"Out"}}}, {G_OP_TYPE_FLATTEN, {{"X"}, {"Out"}}}, diff --git a/src/operators/gru_unit_op.cpp b/src/operators/gru_unit_op.cpp new file mode 100644 index 0000000000000000000000000000000000000000..056026c5b9fb359de378ae2d63d13c0fac16e367 --- /dev/null +++ b/src/operators/gru_unit_op.cpp @@ -0,0 +1,70 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#ifdef GRU_UNIT_OP + +#include "operators/gru_unit_op.h" + +namespace paddle_mobile { +namespace operators { + +template +void GruUnitOp::InferShape() const { + auto input_dims = this->param_.InputInput()->dims(); + auto hidden_prev_dims = this->param_.InputHiddenPrev()->dims(); + auto weight_dims = this->param_.InputWeight()->dims(); + int batch_size = input_dims[0]; + int input_size = input_dims[1]; + int frame_size = hidden_prev_dims[1]; + int weight_height = weight_dims[0]; + int weight_width = weight_dims[1]; + PADDLE_MOBILE_ENFORCE( + (input_size == frame_size * 3), + "The input_size must be 3 times of frame_size in GRUUnitOp."); + PADDLE_MOBILE_ENFORCE( + (weight_height == frame_size), + "The shape of Weight matrix must be [frame_size, frame_size * 3]."); + PADDLE_MOBILE_ENFORCE( + (weight_width == frame_size * 3), + "The shape of Weight matrix must be [frame_size, frame_size * 3]."); + if (this->param_.InputBias()) { + auto bias_dims = this->param_.InputBias()->dims(); + int bias_height = bias_dims[0]; + int bias_width = bias_dims[1]; + PADDLE_MOBILE_ENFORCE((bias_height == 1), + "The shape of Bias must be [1, frame_size * 3]."); + PADDLE_MOBILE_ENFORCE((bias_width == frame_size * 3), + "The shape of Bias must be [1, frame_size * 3]."); + } + this->param_.OutGate()->Resize({batch_size, frame_size * 3}); + this->param_.OutResetHiddenPrev()->Resize({batch_size, frame_size}); + this->param_.OutHidden()->Resize({batch_size, frame_size}); +} + +} // namespace operators +} // namespace paddle_mobile + +namespace ops = paddle_mobile::operators; +#ifdef PADDLE_MOBILE_CPU +REGISTER_OPERATOR_CPU(gru_unit, ops::GruUnitOp); +#endif +#ifdef PADDLE_MOBILE_MALI_GPU +#endif +#ifdef PADDLE_MOBILE_FPGA +#endif + +#ifdef PADDLE_MOBILE_CL +#endif + +#endif \ No newline at end of file diff --git a/src/operators/gru_unit_op.h b/src/operators/gru_unit_op.h new file mode 100644 index 0000000000000000000000000000000000000000..4188662d05e79a97fa2f0dba62303391ae8e0d70 --- /dev/null +++ b/src/operators/gru_unit_op.h @@ -0,0 +1,43 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#ifdef GRU_UNIT_OP + +#pragma once + +#include "framework/operator.h" +#include "operators/kernel/gru_unit_kernel.h" +#include "operators/op_param.h" + +namespace paddle_mobile { +namespace operators { + +template +class GruUnitOp : public framework::OperatorWithKernel< + DeviceType, GruUnitParam, + operators::GruUnitKernel> { + public: + GruUnitOp(const std::string &type, const VariableNameMap &inputs, + const VariableNameMap &outputs, const AttributeMap &attrs, + std::shared_ptr scope) + : framework::OperatorWithKernel, + operators::GruUnitKernel>( + type, inputs, outputs, attrs, scope){}; + void InferShape() const override; +}; + +} // namespace operators +} // namespace paddle_mobile + +#endif diff --git a/src/operators/kernel/arm/gru_unit_kernel.cpp b/src/operators/kernel/arm/gru_unit_kernel.cpp new file mode 100644 index 0000000000000000000000000000000000000000..28b6da0c4d1ad1e5c0cf39e043abf7852a3cc5cf --- /dev/null +++ b/src/operators/kernel/arm/gru_unit_kernel.cpp @@ -0,0 +1,38 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#ifdef GRU_UNIT_OP + +#include "operators/kernel/gru_unit_kernel.h" +#include "operators/kernel/central-arm-func/gru_unit_arm_func.h" + +namespace paddle_mobile { +namespace operators { + +template <> +bool GruUnitKernel::Init(GruUnitParam *param) { + return true; +} + +template <> +void GruUnitKernel::Compute(const GruUnitParam ¶m) { + GruUnitCompute(param); +} + +template class GruUnitKernel; + +} // namespace operators +} // namespace paddle_mobile + +#endif diff --git a/src/operators/kernel/central-arm-func/gru_unit_arm_func.h b/src/operators/kernel/central-arm-func/gru_unit_arm_func.h new file mode 100644 index 0000000000000000000000000000000000000000..599b9b46ceb50c3d3d4bd3e9be666823882efb5b --- /dev/null +++ b/src/operators/kernel/central-arm-func/gru_unit_arm_func.h @@ -0,0 +1,62 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#ifdef GRU_UNIT_OP + +#pragma once + +#include +#include "operators/math/math_function.h" +#include "operators/kernel/activation_kernel.h" +#include "operators/math/gemm.h" +#include "operators/op_param.h" +namespace paddle_mobile { +namespace operators { + +template +void GruUnitCompute(const GruUnitParam& param) { + auto* input = param.InputInput(); + auto* hidden_prev = param.InputHiddenPrev(); + auto* weight = param.InputWeight(); + auto* bias = param.InputBias(); + auto* gate = param.OutGate(); + auto* reset_hidden_prev = param.OutResetHiddenPrev(); + auto* hidden = param.OutHidden(); + + if (bias) { + math::RowwiseAdd add_bias; + add_bias(*gate, *bias, gate); + } + + int batch_size = input->dims()[0]; + int frame_size = hidden_prev->dims()[1]; + const P* weight_data = weight->data

(); + math::GRUMetaValue

gru_value; + gru_value.gate_weight = const_cast(weight_data); + gru_value.state_weight = + const_cast(weight_data + 2 * frame_size * frame_size); + gru_value.output_value = hidden->data

(); + gru_value.prev_out_value = gru_value.output_value; + gru_value.gate_value = gate->data

(); + gru_value.reset_output_value = reset_hidden_prev->data

(); + auto active_node = math::GetActivationType(param.Activation()); + auto active_gate = math::GetActivationType(param.GateActivation()); + math::GRUUnitFunctor::compute(gru_value, frame_size, batch_size, + active_node, active_gate); +} + +} // namespace operators +} // namespace paddle_mobile + +#endif \ No newline at end of file diff --git a/src/operators/kernel/gru_unit_kernel.h b/src/operators/kernel/gru_unit_kernel.h new file mode 100644 index 0000000000000000000000000000000000000000..53dfa234a79e4fe53e22c034c509ddc59130ce98 --- /dev/null +++ b/src/operators/kernel/gru_unit_kernel.h @@ -0,0 +1,35 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#ifdef GRU_UNIT_OP + +#pragma once + +#include "framework/operator.h" +#include "operators/op_param.h" + +namespace paddle_mobile { +namespace operators { + +template +class GruUnitKernel + : public framework::OpKernelBase> { + public: + void Compute(const GruUnitParam& param); + bool Init(GruUnitParam* param); +}; +} // namespace operators +} // namespace paddle_mobile + +#endif \ No newline at end of file diff --git a/src/operators/math/activation.h b/src/operators/math/activation.h index 08ba4a8f2a7442860a0516f0f6e2726f5d09ec6d..90b9ab4c3a558a994370ea80693e1d31687bb44e 100644 --- a/src/operators/math/activation.h +++ b/src/operators/math/activation.h @@ -45,6 +45,19 @@ inline ActivationType GetActivationType(const std::string &type) { PADDLE_MOBILE_THROW_EXCEPTION("Not support activation type."); } +inline ActivationType GetActivationType(const int type) { + if (type == 0) { + return ActivationType::IDENTITY; + } else if (type == 1) { + return ActivationType::SIGMOID; + } else if (type == 2) { + return ActivationType::TANH; + } else if (type == 3) { + return ActivationType::RELU; + } + PADDLE_MOBILE_THROW_EXCEPTION("Not support activation type."); +} + #if defined(__ARM_NEON__) || defined(__ARM_NEON) template inline float32x4_t vActiveq_f32(const float32x4_t &x) { diff --git a/src/operators/op_param.h b/src/operators/op_param.h index d20075f89195bbbd3a35577b19f84b9bb91b2a1c..ed036ab7c2a191a0924e3c2d6c6ad61e6de79bd4 100644 --- a/src/operators/op_param.h +++ b/src/operators/op_param.h @@ -74,6 +74,13 @@ class OpParam { static T *InputH0From(const VariableNameMap &inputs, const Scope &scope) { return GetVarValue("H0", inputs, scope); } + + template + static T *InputHiddenPrevFrom(const VariableNameMap &inputs, + const Scope &scope) { + return GetVarValue("HiddenPrev", inputs, scope); + } + template static T *InputAlphaFrom(const VariableNameMap &inputs, const Scope &scope) { return GetVarValue("Alpha", inputs, scope); @@ -214,6 +221,11 @@ class OpParam { return GetVarValue("BatchGate", outputs, scope); } + template + static T *OutputGateFrom(const VariableNameMap &outputs, const Scope &scope) { + return GetVarValue("Gate", outputs, scope); + } + template static T *OutputViterbiPathFrom(const VariableNameMap &outputs, const Scope &scope) { @@ -225,6 +237,12 @@ class OpParam { return GetVarValue("BatchResetHiddenPrev", outputs, scope); } + template + static T *OutputResetHiddenPrevFrom(const VariableNameMap &outputs, + const Scope &scope) { + return GetVarValue("ResetHiddenPrev", outputs, scope); + } + template static T *OutputBatchHiddenFrom(const VariableNameMap &outputs, const Scope &scope) { @@ -2444,6 +2462,51 @@ class GruParam : public OpParam { }; #endif +#ifdef GRU_UNIT_OP +template +class GruUnitParam : public OpParam { + typedef typename DtypeTensorTrait::gtype GType; + + public: + GruUnitParam(const VariableNameMap &inputs, const VariableNameMap &outputs, + const AttributeMap &attrs, const Scope &scope) { + input_input_ = InputFrom(inputs, scope); + input_hidden_prev_ = InputHiddenPrevFrom(inputs, scope); + input_bias_ = InputBiasFrom(inputs, scope); + input_weight_ = InputWeightFrom(inputs, scope); + + output_gate_ = OutputGateFrom(outputs, scope); + output_reset_hidden_prev_ = + OutputResetHiddenPrevFrom(outputs, scope); + output_hidden_ = OutputHiddenFrom(outputs, scope); + activation_ = GetAttr("activation", attrs); + gate_activation_ = GetAttr("gate_activation", attrs); + } + const GType *InputInput() const { return input_input_; } + const GType *InputWeight() const { return input_weight_; } + const GType *InputHiddenPrev() const { return input_hidden_prev_; } + const GType *InputBias() const { return input_bias_; } + const int &Activation() const { return activation_; } + const int &GateActivation() const { return gate_activation_; } + + GType *OutGate() const { return output_gate_; } + GType *OutResetHiddenPrev() const { return output_reset_hidden_prev_; } + GType *OutHidden() const { return output_hidden_; } + + private: + GType *input_input_; + GType *input_hidden_prev_; + GType *input_bias_; + GType *input_weight_; + + GType *output_gate_; + GType *output_reset_hidden_prev_; + GType *output_hidden_; + int activation_; + int gate_activation_; +}; +#endif + #ifdef FLATTEN_OP template class FlattenParam : public OpParam { diff --git a/tools/op.cmake b/tools/op.cmake index c29d6eb0f4d1324be75a69b0c29aac56b76ed421..b14dfdacf6051d6edc57934cc25841f346e9d0df 100644 --- a/tools/op.cmake +++ b/tools/op.cmake @@ -256,6 +256,7 @@ if(NOT FOUND_MATCH) set(IM2SEQUENCE_OP ON) set(LOOKUP_OP ON) set(GRU_OP ON) + set(GRU_UNIT_OP ON) set(CRF_OP ON) set(BILINEAR_INTERP_OP ON) set(SPLIT_OP ON) @@ -450,6 +451,10 @@ if (GRU_OP) add_definitions(-DGRU_OP) endif() +if (GRU_UNIT_OP) + add_definitions(-DGRU_UNIT_OP) +endif() + if (CRF_OP) add_definitions(-DCRF_OP) endif()