diff --git a/src/common/types.cpp b/src/common/types.cpp index 420c789e3f2136e6b71c977c8a4e2bbf745d9143..c25c5db30c7183b6685db03386ca9a9355ca6958 100644 --- a/src/common/types.cpp +++ b/src/common/types.cpp @@ -89,6 +89,10 @@ const char *G_OP_TYPE_FUSION_DECONV_RELU = "fusion_deconv_relu"; const char *G_OP_TYPE_FUSION_DECONV_ADD = "fusion_deconv_add"; const char *G_OP_TYPE_FUSION_DECONV_ADD_RELU = "fusion_deconv_add_relu"; +const char *G_OP_TYPE_SEQUENCE_EXPAND = "sequence_expand"; +const char *G_OP_TYPE_SEQUENCE_POOL = "sequence_pool"; +const char *G_OP_TYPE_SEQUENCE_SOFTMAX = "sequence_softmax"; + std::unordered_map< std::string, std::pair, std::vector>> op_input_output_key = { @@ -162,5 +166,8 @@ std::unordered_map< {G_OP_TYPE_TANH, {{"X"}, {"Out"}}}, {G_OP_TYPE_FUSION_DECONV_RELU, {{"Input"}, {"Out"}}}, {G_OP_TYPE_FUSION_DECONV_ADD, {{"Input"}, {"Out"}}}, - {G_OP_TYPE_FUSION_DECONV_ADD_RELU, {{"Input"}, {"Out"}}}}; + {G_OP_TYPE_FUSION_DECONV_ADD_RELU, {{"Input"}, {"Out"}}}, + {G_OP_TYPE_SEQUENCE_EXPAND, {{"X", "Y"}, {"Out"}}}, + {G_OP_TYPE_SEQUENCE_POOL, {{"X"}, {"Out"}}}, + {G_OP_TYPE_SEQUENCE_SOFTMAX, {{"X"}, {"Out"}}}}; } // namespace paddle_mobile diff --git a/src/common/types.h b/src/common/types.h index c12e5b6a268f66f7fcf53d55d1f40a15093474e3..114424fe04add874affb42fe9fca8f0d86bcdd82 100644 --- a/src/common/types.h +++ b/src/common/types.h @@ -105,6 +105,8 @@ enum ActivationType { enum PoolingType { MAX = 0, AVG = 1, + SUM = 2, + FIRST = 3, }; extern const char *G_OP_TYPE_CONV; @@ -169,6 +171,10 @@ extern const char *G_OP_TYPE_FUSION_DECONV_RELU; extern const char *G_OP_TYPE_FUSION_DECONV_ADD; extern const char *G_OP_TYPE_FUSION_DECONV_ADD_RELU; +extern const char *G_OP_TYPE_SEQUENCE_EXPAND; +extern const char *G_OP_TYPE_SEQUENCE_POOL; +extern const char *G_OP_TYPE_SEQUENCE_SOFTMAX; + extern std::unordered_map< std::string, std::pair, std::vector>> op_input_output_key; diff --git a/src/framework/executor.cpp b/src/framework/executor.cpp index e82006be05e430fa46bd2ea8c372237ab9630f38..1d7933f2d69735c66b67ac49cc0922d7143edc5d 100644 --- a/src/framework/executor.cpp +++ b/src/framework/executor.cpp @@ -90,28 +90,28 @@ Executor::Executor(const Program &program, int batch_size, } } -template +template static void LoadMemInternal(void **data, LoDTensor *tensor, bool quant_uint8 = false) { char **data_buf = reinterpret_cast(data); int64_t size = tensor->numel(); - Device *tensor_data = tensor->mutable_data(); + T *tensor_data = tensor->mutable_data(); if (quant_uint8) { // should be moved into operator init function float min_value; float max_value; - memory::Copy(&min_value, data_buf, sizeof(float)); - memory::Copy(&max_value, data_buf + sizeof(float), sizeof(float)); - data_buf += 2 * sizeof(float); + memory::Copy(&min_value, *data_buf, sizeof(float)); + memory::Copy(&max_value, *data_buf + sizeof(float), sizeof(float)); + *data_buf += 2 * sizeof(float); const float factor = (max_value - min_value) / 255.0; - const uint8_t *uint8_data = reinterpret_cast(data_buf); + const uint8_t *uint8_data = reinterpret_cast(*data_buf); for (int k = 0; k < size; ++k) { tensor_data[k] = uint8_data[k] * factor + min_value; } - data_buf += size * sizeof(uint8_t); + *data_buf += size * sizeof(uint8_t); } else { - memory::Copy(tensor_data, *data_buf, size * sizeof(Device)); - *data_buf += size * sizeof(Device); + memory::Copy(tensor_data, *data_buf, size * sizeof(T)); + *data_buf += size * sizeof(T); } } diff --git a/src/framework/load_ops.h b/src/framework/load_ops.h index 72bc18cc521f96abf648306bfb455327b0f7cfa4..88fb360fcc4a135d39fee1e117ba1279d66acae5 100644 --- a/src/framework/load_ops.h +++ b/src/framework/load_ops.h @@ -264,3 +264,9 @@ LOAD_FUSION_MATCHER(fusion_dequant_add_bn_quant); LOAD_OP1(fusion_dequant_add_bn_relu_quant, CPU); LOAD_FUSION_MATCHER(fusion_dequant_add_bn_relu_quant); #endif +#ifdef SEQUENCE_EXPAND_OP +LOAD_OP1(sequence_expand, CPU); +#endif +#ifdef SEQUENCE_POOL_OP +LOAD_OP1(sequence_pool, CPU); +#endif diff --git a/src/operators/kernel/activation_kernel.h b/src/operators/kernel/activation_kernel.h new file mode 100644 index 0000000000000000000000000000000000000000..9dc8f307c0c988355575160a5f3bb4926537a679 --- /dev/null +++ b/src/operators/kernel/activation_kernel.h @@ -0,0 +1,46 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include "framework/operator.h" +#include "operators/op_param.h" + +namespace paddle_mobile { +namespace operators { + +#define DECLARE_KERNEL(KernelClass, KernelParam) \ + template \ + class KernelClass \ + : public framework::OpKernelBase> { \ + public: \ + bool Init(KernelParam *param); \ + void Compute(const KernelParam ¶m); \ + }; + +#ifdef RELU_OP +DECLARE_KERNEL(ReluKernel, ReluParam); +DECLARE_KERNEL(Relu6Kernel, ReluParam); +#endif + +#ifdef SIGMOID_OP +DECLARE_KERNEL(SigmoidKernel, SigmoidParam); +#endif + +#ifdef TANH_OP +DECLARE_KERNEL(TanhKernel, TanhParam); +#endif + +} // namespace operators +} // namespace paddle_mobile diff --git a/src/operators/kernel/arm/relu_kernel.cpp b/src/operators/kernel/arm/activation_kernel.cpp similarity index 74% rename from src/operators/kernel/arm/relu_kernel.cpp rename to src/operators/kernel/arm/activation_kernel.cpp index 0333e9db4445aa68498671ed6472a2f8ff113e1c..5050d3b160addd8686849fa70a82cbf3274d5ff6 100644 --- a/src/operators/kernel/arm/relu_kernel.cpp +++ b/src/operators/kernel/arm/activation_kernel.cpp @@ -12,9 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -#ifdef RELU_OP - -#include "operators/kernel/relu_kernel.h" +#include "operators/kernel/activation_kernel.h" #include "common/types.h" #include "operators/math/activation.h" #if defined(__ARM_NEON__) || defined(__ARM_NEON) @@ -25,12 +23,12 @@ namespace paddle_mobile { namespace operators { template -struct ReluCompute { +struct ActivationCompute { void operator()(const Tensor *input, Tensor *output) {} }; template -struct ReluCompute { +struct ActivationCompute { void operator()(const Tensor *input, Tensor *output) { const float *x = input->data(); float *y = output->mutable_data(); @@ -65,6 +63,7 @@ struct ReluCompute { } }; +#ifdef RELU_OP template <> bool ReluKernel::Init(ReluParam *param) { return true; @@ -74,7 +73,7 @@ template <> void ReluKernel::Compute(const ReluParam ¶m) { const Tensor *input = param.InputX(); Tensor *output = param.Out(); - ReluCompute()(input, output); + ActivationCompute()(input, output); } template <> @@ -86,10 +85,37 @@ template <> void Relu6Kernel::Compute(const ReluParam ¶m) { const Tensor *input = param.InputX(); Tensor *output = param.Out(); - ReluCompute()(input, output); + ActivationCompute()(input, output); } +#endif -} // namespace operators -} // namespace paddle_mobile +#ifdef SIGMOID_OP +template <> +bool SigmoidKernel::Init(SigmoidParam *param) { + return true; +} + +template <> +void SigmoidKernel::Compute(const SigmoidParam ¶m) { + const Tensor *input = param.InputX(); + Tensor *output = param.Out(); + ActivationCompute()(input, output); +} +#endif +#ifdef TANH_OP +template <> +void TanhKernel::Init(TanhParam *param) { + return true; +} + +template <> +void TanhKernel::Compute(const TanhParam ¶m) { + const Tensor *input = param.InputX(); + Tensor *output = param.Out(); + ActivationCompute()(input, output); +} #endif + +} // namespace operators +} // namespace paddle_mobile diff --git a/src/operators/kernel/arm/sequence_expand_kernel.cpp b/src/operators/kernel/arm/sequence_expand_kernel.cpp new file mode 100644 index 0000000000000000000000000000000000000000..f3fb01eed8fbee0af3caff152e7c46749694a43e --- /dev/null +++ b/src/operators/kernel/arm/sequence_expand_kernel.cpp @@ -0,0 +1,115 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#ifdef SEQUENCE_EXPAND_OP + +#include +#include "operators/kernel/sequence_kernels.h" + +namespace paddle_mobile { +namespace operators { + +typedef int (*LoDElementFunctor)(const std::vector &x_lod, int index); + +int element_with_lod(const std::vector &x_lod, int index) { + return x_lod[index]; +} + +int element_without_lod(const std::vector &x_lod, int index) { + return index; +} + +template +inline void SequenceExpandImpl(const framework::LoDTensor &x, + const std::vector &ref_lod, + framework::LoDTensor *output) { + const T *x_data = x.data(); + auto &x_lod = x.lod(); + LoDElementFunctor lod_element = element_without_lod; + if (x_lod.size() == 1) lod_element = element_with_lod; + + T *output_data = output->mutable_data(); + int x_item_length = x.numel() / x.dims()[0]; + int out_offset = 0; + + for (size_t i = 1; i < ref_lod.size(); ++i) { + int repeat_num = ref_lod[i] - ref_lod[i - 1]; + int x_start = lod_element(x_lod[0], i - 1); + int x_end = lod_element(x_lod[0], i); + int x_seq_len = x_end - x_start; + if (repeat_num > 0) { + int out_start = out_offset; + if (output->lod().size() == 1) { + out_start = output->lod()[0][out_offset]; + } + for (int j = 0; j < repeat_num; j++) { + for (int k = 0; k < x_seq_len; k++) { + memcpy(output_data + (out_start + j * x_seq_len + k) * x_item_length, + x_data + (x_start + k) * x_item_length, + x_item_length * sizeof(T)); + } + } + } + out_offset += repeat_num; + } +} + +template +class SequenceExpandKernel + : public framework::OpKernelBase> { + public: + bool Init(SequenceExpandParam *param) { return true; } + + void Compute(const SequenceExpandParam ¶m) { + const framework::LoDTensor *input_x = param.input_x_; + const framework::LoDTensor *input_y = param.input_y_; + framework::LoDTensor *output = param.output_; + output->mutable_data(); + + const auto &x_lod = input_x->lod(); + const auto &y_lod = input_y->lod(); + int ref_level = param.ref_level_; + if (ref_level == -1) ref_level = y_lod.size() - 1; + + if (y_lod[ref_level].size() <= 1) { + framework::TensorCopy(*input_x, output); + output->set_lod(input_x->lod()); + return; + } + + std::vector out_lod; + if (x_lod.size() == 1) { + out_lod.push_back(0); + for (size_t i = 1; i < y_lod[ref_level].size(); ++i) { + int repeat_num = y_lod[ref_level][i] - y_lod[ref_level][i - 1]; + int x_start = x_lod[0][i - 1]; + int x_end = x_lod[0][i]; + int x_seq_len = x_end - x_start; + for (int j = 0; j < repeat_num; ++j) { + out_lod.push_back(out_lod.back() + x_seq_len); + } + } + } + output->set_lod({out_lod}); + SequenceExpandImpl(*input_x, y_lod[ref_level], output); + } +}; + +template class SequenceExpandKernel; +// template class SequenceExpandKernel; + +} // namespace operators +} // namespace paddle_mobile + +#endif // SEQUENCE_EXPAND_OP diff --git a/src/operators/kernel/arm/sequence_pool_kernel.cpp b/src/operators/kernel/arm/sequence_pool_kernel.cpp new file mode 100644 index 0000000000000000000000000000000000000000..f4e28a0ffbbc13428bd8b4643aaae915f14539bc --- /dev/null +++ b/src/operators/kernel/arm/sequence_pool_kernel.cpp @@ -0,0 +1,195 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#ifdef SEQUENCE_POOL_OP + +#include +#include +#include +#include +#include "common/types.h" +#include "operators/kernel/sequence_kernels.h" +#include "operators/math/pooling.h" +#ifdef __ARM_NEON__ +#include +#endif // __ARM_NEON__ + +namespace paddle_mobile { +namespace operators { + +template +void SequencePoolImpl(const framework::LoDTensor &input, + framework::LoDTensor *output) { + const float *input_ptr = input.data(); + float *output_ptr = output->mutable_data(); + const auto &lod = input.lod()[0]; + int64_t width = input.numel() / input.dims()[0]; + + #pragma omp parallel for + for (int i = 0; i < static_cast(lod.size()) - 1; ++i) { + const float *in_ptr = input_ptr + lod[i] * width; + float *out_ptr = output_ptr + i * width; + int64_t height = static_cast(lod[i + 1] - lod[i]); + if (width == 1) { + float max = -std::numeric_limits::max(); + int remain_h = height; +#ifdef __ARM_NEON__ + int loop = remain_h >> 2; + remain_h = remain_h & 0x3; + float32x4_t __max4 = math::vPoolInitq_f32(); + for (int h = 0; h < loop; ++h) { + float32x4_t r0 = vld1q_f32(in_ptr); + __max4 = vmaxq_f32(__max4, r0); + in_ptr += 4; + } + float32x2_t __max2 = + vpmax_f32(vget_low_f32(__max4), vget_high_f32(__max4)); + __max2 = vpmax_f32(__max2, __max2); + max = std::max(max, vget_lane_f32(__max2, 0)); +#endif // __ARM_NEON__ + for (int h = 0; h < remain_h; ++h) { + max = std::max(max, in_ptr[h]); + } + *out_ptr = max; + } else { + memcpy(out_ptr, in_ptr, width * sizeof(float)); + in_ptr += width; + int remain_h = height - 1; +#ifdef __ARM_NEON__ + int remain_w_start = width & 0xfffc; +#endif // __ARM_NEON__ + for (int h = 0; h < remain_h; ++h) { +#ifdef __ARM_NEON__ + for (int w = 0; w < width; w += 4) { + float32x4_t __in = vld1q_f32(in_ptr + w); + float32x4_t __out = vld1q_f32(out_ptr + w); + __out = vmaxq_f32(__out, __in); + vst1q_f32(out_ptr + w, __out); + } +#endif // __ARM_NEON__ + for (int w = remain_w_start; w < width; ++w) { + out_ptr[w] = std::max(out_ptr[w], in_ptr[w]); + } + in_ptr += width; + } + } + } +} + +template <> +void SequencePoolImpl(const framework::LoDTensor &input, + framework::LoDTensor *output) { + const float *input_ptr = input.data(); + float *output_ptr = output->mutable_data(); + const auto &lod = input.lod()[0]; + int64_t width = input.numel() / input.dims()[0]; + + #pragma omp parallel for + for (int i = 0; i < static_cast(lod.size()) - 1; ++i) { + const float *in_ptr = input_ptr + lod[i] * width; + float *out_ptr = output_ptr + i * width; + int64_t height = static_cast(lod[i + 1] - lod[i]); + if (width == 1) { + float sum = 0.f; + int remain_h = height; +#ifdef __ARM_NEON__ + int loop = remain_h >> 2; + remain_h = remain_h & 0x3; + float32x4_t __sum4 = vdupq_n_f32(0.f); + for (int h = 0; h < loop; ++h) { + float32x4_t r0 = vld1q_f32(in_ptr); + __sum4 = vaddq_f32(__sum4, r0); + in_ptr += 4; + } + float32x2_t __sum2 = + vpadd_f32(vget_low_f32(__sum4), vget_high_f32(__sum4)); + sum += vget_lane_f32(__sum2, 0) + vget_lane_f32(__sum2, 1); +#endif // __ARM_NEON__ + for (int h = 0; h < remain_h; ++h) { + sum += in_ptr[h]; + } + *out_ptr = sum; + } else { + memcpy(out_ptr, in_ptr, width * sizeof(float)); + in_ptr += width; + int remain_h = height - 1; +#ifdef __ARM_NEON__ + int loop_w = width >> 2; + int remain_w_start = width & 0xfffc; +#endif // __ARM_NEON__ + for (int h = 0; h < remain_h; ++h) { +#ifdef __ARM_NEON__ + for (int w = 0; w < width - 3; w += 4) { + float32x4_t __in = vld1q_f32(in_ptr + w); + float32x4_t __out = vld1q_f32(out_ptr + w); + __out = vaddq_f32(__out, __in); + vst1q_f32(out_ptr + w, __out); + } +#endif // __ARM_NEON__ + for (int w = remain_w_start; w < width; ++w) { + out_ptr[w] += in_ptr[w]; + } + in_ptr += width; + } + } + } +} + +template <> +void SequencePoolImpl(const framework::LoDTensor &input, + framework::LoDTensor *output) { + const float *input_ptr = input.data(); + float *output_ptr = output->mutable_data(); + const auto &lod = input.lod()[0]; + int64_t width = input.numel() / input.dims()[0]; + + for (int i = 0; i < static_cast(lod.size()) - 1; ++i) { + const float *in_ptr = input_ptr + lod[i] * width; + float *out_ptr = output_ptr + i * width; + memcpy(out_ptr, in_ptr, width * sizeof(float)); + } +} + +template +class SequencePoolKernel + : public framework::OpKernelBase> { + public: + bool Init(SequencePoolParam *param) { return true; } + + void Compute(const SequencePoolParam ¶m) { + const framework::LoDTensor *input = param.input_; + framework::LoDTensor *output = param.output_; + output->mutable_data(); + const std::string pooling_type = param.pool_type_; + + if (param.pool_type_ == "MAX") { + SequencePoolImpl(*input, output); + } else if (param.pool_type_ == "FIRST") { + SequencePoolImpl(*input, output); + } else if (param.pool_type_ == "SUM") { + SequencePoolImpl(*input, output); + } else { + PADDLE_MOBILE_THROW_EXCEPTION( + "pooling type `%s` has not been implemented.", + param.pool_type_.c_str()); + } + } +}; + +template class SequencePoolKernel; + +} // namespace operators +} // namespace paddle_mobile + +#endif // SEQUENCE_POOL_OP diff --git a/src/operators/kernel/arm/sigmoid_kernel.cpp b/src/operators/kernel/arm/sequence_softmax_kernel.cpp similarity index 51% rename from src/operators/kernel/arm/sigmoid_kernel.cpp rename to src/operators/kernel/arm/sequence_softmax_kernel.cpp index 3d6e14ffea80169172431229e34309cde331d588..ecbc39c4ccf4592308dc07d994535273d4a636f1 100644 --- a/src/operators/kernel/arm/sigmoid_kernel.cpp +++ b/src/operators/kernel/arm/sequence_softmax_kernel.cpp @@ -12,32 +12,32 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -#ifdef SIGMOID_OP - -#include "../sigmoid_kernel.h" -#include "../central-arm-func/sigmoid_arm_func.h" -#ifdef __ARM_NEON -#include "../../math/math_func_neon.h" -#endif -#include +#ifdef SEQUENCE_SOFTMAX_OP + +#include "framework/lod_tensor.h" +#include "operators/kernel/sequence_kernels.h" +#include "operators/math/softmax.h" + namespace paddle_mobile { namespace operators { -using framework::DDim; -using framework::Tensor; +template +class SequenceSoftmaxKernel + : public framework::OpKernelBase> { + public: + bool Init(SoftmaxParam *param) { return true; } -template <> -bool SigmoidKernel::Init(SigmoidParam *param) { - return true; -} + void Compute(const SoftmaxParam ¶m) { + const framework::LoDTensor *input = param.InputX(); + framework::LoDTensor *output = param.Out(); + math::SequenceSoftmaxFuntor sequence_softmax; + sequence_softmax(input, output); + } +}; -template <> -void SigmoidKernel::Compute(const SigmoidParam ¶m) { - SigmoidCompute(param); -} +template class SequenceSoftmaxKernel; -template class SigmoidKernel; } // namespace operators } // namespace paddle_mobile -#endif +#endif // SEQUENCE_SOFTMAX_OP diff --git a/src/operators/kernel/central-arm-func/sigmoid_arm_func.h b/src/operators/kernel/central-arm-func/sigmoid_arm_func.h deleted file mode 100644 index c782171e59ca7077ebb5622ad550dd0906d9f441..0000000000000000000000000000000000000000 --- a/src/operators/kernel/central-arm-func/sigmoid_arm_func.h +++ /dev/null @@ -1,87 +0,0 @@ -/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ -#ifdef SIGMOID_OP -#pragma once - -#include - -#include "operators/op_param.h" -#ifdef __ARM_NEON -#include -#include "operators/math/math_func_neon.h" -#endif - -namespace paddle_mobile { -namespace operators { - -using framework::DDim; - -void sigmoid(const Tensor *X, Tensor *Y) { -#ifdef __ARM_NEON - const float *input = X->data(); - float *output = Y->mutable_data(); - const DDim &dDim = X->dims(); - int axis_index = 1; - if (dDim.size() < 4) { - axis_index = 0; - } - DDim outer_ddim = - paddle_mobile::framework::slice_ddim(dDim, 0, axis_index + 1); - DDim inner_ddim = - paddle_mobile::framework::slice_ddim(dDim, axis_index + 1, dDim.size()); - int out_size = paddle_mobile::framework::product(outer_ddim); - int inner_size = paddle_mobile::framework::product(inner_ddim); - - DLOG << "outsize=" << out_size; - DLOG << "innersize=" << inner_size; - #pragma omp parallel for - for (int i = 0; i < out_size; ++i) { - const float *input_outer_ptr = input + i * inner_size; - float *output_outer_ptr = output + i * inner_size; - int nn = inner_size >> 2; - int remain = inner_size - (nn << 2); - float32x4_t _one = vdupq_n_f32(1.f); - for (; nn > 0; nn--) { - float32x4_t data = vld1q_f32(input_outer_ptr); - data = vnegq_f32(data); - data = exp_ps(data); - data = vaddq_f32(data, _one); - float32x4_t out_data = vrecpeq_f32(data); - out_data = vmulq_f32(vrecpsq_f32(data, out_data), out_data); - vst1q_f32(output_outer_ptr, out_data); - - input_outer_ptr += 4; - output_outer_ptr += 4; - } - for (; remain > 0; remain--) { - *output_outer_ptr = 1.f / (1.f + exp(-*input_outer_ptr)); - output_outer_ptr++; - input_outer_ptr++; - } - } -#else -#endif -} - -template -void SigmoidCompute(const SigmoidParam ¶m) { - const Tensor *in_x = param.InputX(); - Tensor *out = param.Out(); - auto x_dims = in_x->dims(); - out->Resize(x_dims); - sigmoid(in_x, out); -} -} // namespace operators -} // namespace paddle_mobile -#endif diff --git a/src/operators/kernel/sequence_kernels.h b/src/operators/kernel/sequence_kernels.h new file mode 100644 index 0000000000000000000000000000000000000000..7884d0d475949c8a54b0ecc08fb578807ca2e2d2 --- /dev/null +++ b/src/operators/kernel/sequence_kernels.h @@ -0,0 +1,45 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include "framework/operator.h" +#include "operators/op_param.h" + +namespace paddle_mobile { +namespace operators { + +#define DECLARE_KERNEL(KernelClass, KernelParam) \ + template \ + class KernelClass \ + : public framework::OpKernelBase> { \ + public: \ + bool Init(KernelParam *param); \ + void Compute(const KernelParam ¶m); \ + }; + +#ifdef SEQUENCE_EXPAND_OP +DECLARE_KERNEL(SequenceExpandKernel, SequenceExpandParam); +#endif // SEQUENCE_EXPAND_OP + +#ifdef SEQUENCE_POOL_OP +DECLARE_KERNEL(SequencePoolKernel, SequencePoolParam); +#endif // SEQUENCE_POOL_OP + +#ifdef SEQUENCE_SOFTMAX_OP +DECLARE_KERNEL(SequenceSoftmaxKernel, SoftmaxParam); +#endif // SEQUENCE_SOFTMAX_OP + +} // namespace operators +} // namespace paddle_mobile diff --git a/src/operators/math/softmax.cpp b/src/operators/math/softmax.cpp index 48f8b35ea22f30ceb7e9bf8bcefa815a63b6e9dc..6b34f522ff6caf32c20971d9cf38f93730fdb727 100644 --- a/src/operators/math/softmax.cpp +++ b/src/operators/math/softmax.cpp @@ -60,6 +60,58 @@ float find_max(const float *input, const int num_classes) { return max; } +void SoftmaxBasic(const float *input, int num_classes, float *y) { + float *output = y; + // find max + float max = find_max(input, num_classes); + + // exp(x - max) and sum(exp(x - max)) + int remain = num_classes; + float sum = 0.f; +#if defined(__ARM_NEON) || defined(__ARM_NEON__) + int loop = num_classes >> 3; + remain = num_classes & 0x7; + float32x4_t __max = vdupq_n_f32(max); + float32x4_t __sum = vdupq_n_f32(0.f); + for (int i = 0; i < loop; ++i, input += 8, output += 8) { + float32x4_t x0 = vld1q_f32(input); + float32x4_t x1 = vld1q_f32(input + 4); + x0 = vsubq_f32(x0, __max); + x1 = vsubq_f32(x1, __max); + x0 = exp_ps(x0); + x1 = exp_ps(x1); + __sum = vaddq_f32(x0, __sum); + __sum = vaddq_f32(x1, __sum); + vst1q_f32(output, x0); + vst1q_f32(output + 4, x1); + } + sum += vaddvq_f32(__sum); +#endif // __ARM_NEON__ + for (int i = 0; i < remain; ++i) { + float out = expf(input[i] - max); + sum += out; + output[i] = out; + } + + // exp(x - max) / sum + float inv_sum = 1.f / sum; + output = y; +#if defined(__ARM_NEON) || defined(__ARM_NEON__) + float32x4_t __inv_sum = vdupq_n_f32(inv_sum); + for (int i = 0; i < loop; ++i, output += 8) { + float32x4_t x0 = vld1q_f32(output); + float32x4_t x1 = vld1q_f32(output + 4); + x0 = vmulq_f32(x0, __inv_sum); + x1 = vmulq_f32(x1, __inv_sum); + vst1q_f32(output, x0); + vst1q_f32(output + 4, x1); + } +#endif + for (int i = 0; i < remain; ++i) { + output[i] *= inv_sum; + } +} + template <> void SoftmaxFuntor::operator()(const framework::Tensor *X, framework::Tensor *Y) { @@ -76,65 +128,25 @@ void SoftmaxFuntor::operator()(const framework::Tensor *X, size_t offset = (batch * channels + channel) * num_classes; const float *input = x + offset; float *output = y + offset; - // find max - float max = find_max(input, num_classes); - - // exp(x - max) - int remain = num_classes; -#if defined(__ARM_NEON) || defined(__ARM_NEON__) - int loop = num_classes >> 3; - remain = num_classes & 0x7; - float32x4_t __max = vdupq_n_f32(max); - for (int i = 0; i < loop; ++i, input += 8, output += 8) { - float32x4_t x0 = vld1q_f32(input); - float32x4_t x1 = vld1q_f32(input + 4); - x0 = vsubq_f32(x0, __max); - x1 = vsubq_f32(x1, __max); - x0 = exp_ps(x0); - x1 = exp_ps(x1); - vst1q_f32(output, x0); - vst1q_f32(output + 4, x1); - } -#endif // __ARM_NEON__ - for (int i = 0; i < remain; ++i) { - output[i] = expf(input[i] - max); - } + SoftmaxBasic(input, num_classes, output); + } + } +} - // sum(exp(x - max)) - float sum = 0.f; - output = y + offset; -#if defined(__ARM_NEON) || defined(__ARM_NEON__) - float32x4_t __sum = vdupq_n_f32(0.f); - for (int i = 0; i < loop; ++i, output += 8) { - float32x4_t x0 = vld1q_f32(output); - float32x4_t x1 = vld1q_f32(output + 4); - __sum = vaddq_f32(x0, __sum); - __sum = vaddq_f32(x1, __sum); - } - sum += vaddvq_f32(__sum); -#endif // __ARM_NEON__ - for (int i = 0; i < remain; ++i) { - sum += output[i]; - } +template <> +void SequenceSoftmaxFuntor::operator()( + const framework::LoDTensor *X, framework::LoDTensor *Y) { + const float *x = X->data(); + const auto &lod = X->lod().back(); + float *y = Y->mutable_data(); - // exp(x - max) / sum - float inv_sum = 1.f / sum; - output = y + offset; -#if defined(__ARM_NEON) || defined(__ARM_NEON__) - float32x4_t __inv_sum = vdupq_n_f32(inv_sum); - for (int i = 0; i < loop; ++i, output += 8) { - float32x4_t x0 = vld1q_f32(output); - float32x4_t x1 = vld1q_f32(output + 4); - x0 = vmulq_f32(x0, __inv_sum); - x1 = vmulq_f32(x1, __inv_sum); - vst1q_f32(output, x0); - vst1q_f32(output + 4, x1); - } -#endif - for (int i = 0; i < remain; ++i) { - output[i] *= inv_sum; - } - } + #pragma omp parallel for + for (int batch = 0; batch < lod.size() - 1; ++batch) { + int num_classes = lod[batch + 1] - lod[batch]; + size_t offset = lod[batch]; + const float *input = x + offset; + float *output = y + offset; + SoftmaxBasic(input, num_classes, output); } } diff --git a/src/operators/math/softmax.h b/src/operators/math/softmax.h index 0de30a4ecaa2d58a4180203b7a27b23dc35446b5..dff25b9d0271db9f6d5704adaaf147629be56a32 100644 --- a/src/operators/math/softmax.h +++ b/src/operators/math/softmax.h @@ -12,10 +12,11 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -#ifdef SOFTMAX_OP +#if defined(SOFTMAX_OP) || defined(SEQUENCE_SOFTMAX_OP) #pragma once +#include "framework/lod_tensor.h" #include "framework/tensor.h" namespace paddle_mobile { @@ -28,7 +29,14 @@ class SoftmaxFuntor { void operator()(const framework::Tensor *X, framework::Tensor *Y); }; +template +class SequenceSoftmaxFuntor { + public: + void operator()(const framework::LoDTensor *X, framework::LoDTensor *Y); +}; + } // namespace math } // namespace operators } // namespace paddle_mobile + #endif diff --git a/src/operators/op_param.h b/src/operators/op_param.h index 9d7c213afa8277c421c0e6cce6cdaefa5ef58dd9..8976d8be8e0722fe6915af4786d55809a9f8ca7c 100644 --- a/src/operators/op_param.h +++ b/src/operators/op_param.h @@ -978,12 +978,12 @@ class SoftmaxParam : public OpParam { input_x_ = InputXFrom(inputs, scope); out_ = OutFrom(outputs, scope); } - const RType *InputX() const { return input_x_; } - RType *Out() const { return out_; } + const GType *InputX() const { return input_x_; } + GType *Out() const { return out_; } private: - RType *input_x_; - RType *out_; + GType *input_x_; + GType *out_; #ifdef PADDLE_MOBILE_FPGA @@ -2737,5 +2737,57 @@ class FusionDequantAddBNQuantParam : public FusionDequantAddBNParam { }; #endif +#ifdef SEQUENCE_EXPAND_OP +template +class SequenceExpandParam : public OpParam { + typedef typename DtypeTensorTrait::gtype GType; + typedef typename DtypeTensorTrait::rtype RType; + + public: + SequenceExpandParam(const VariableNameMap &inputs, + const VariableNameMap &outputs, const AttributeMap &attrs, + const Scope &scope) { + input_x_ = InputXFrom(inputs, scope); + input_y_ = InputYFrom(inputs, scope); + output_ = OutFrom(outputs, scope); + ref_level_ = -1; + if (OpParam::HasAttr("ref_level", attrs)) { + ref_level_ = OpParam::GetAttr("ref_level", attrs); + } + } + + public: + GType *input_x_; + GType *input_y_; + GType *output_; + int ref_level_; +}; +#endif // SEQUENCE_EXPAND_OP + +#ifdef SEQUENCE_POOL_OP +template +class SequencePoolParam : public OpParam { + typedef typename DtypeTensorTrait::gtype GType; + typedef typename DtypeTensorTrait::rtype RType; + + public: + SequencePoolParam(const VariableNameMap &inputs, + const VariableNameMap &outputs, const AttributeMap &attrs, + const Scope &scope) { + input_ = InputXFrom(inputs, scope); + output_ = OutFrom(outputs, scope); + pool_type_ = "MAX"; + if (OpParam::HasAttr("pooltype", attrs)) { + pool_type_ = OpParam::GetStringAttr("pooltype", attrs); + } + } + + public: + GType *input_; + GType *output_; + std::string pool_type_; +}; +#endif // SEQUENCE_EXPAND_OP + } // namespace operators } // namespace paddle_mobile diff --git a/src/operators/relu_op.cpp b/src/operators/relu_op.cpp index 7ceaa815cfb554be9fd2feccb2cc05c6bfa1aa33..560b63058646b379a50184e98ac4b4d5dd43f9fa 100644 --- a/src/operators/relu_op.cpp +++ b/src/operators/relu_op.cpp @@ -15,6 +15,7 @@ limitations under the License. */ #ifdef RELU_OP #include "operators/relu_op.h" + namespace paddle_mobile { namespace operators { @@ -47,4 +48,4 @@ REGISTER_OPERATOR_MALI_GPU(relu, ops::ReluOp); REGISTER_OPERATOR_CL(relu, ops::ReluOp); #endif -#endif +#endif // RELU_OP diff --git a/src/operators/relu_op.h b/src/operators/relu_op.h index 4bb67933db6ac1c174e267259df52b8eb79dbb35..9e3f109e74876b9f680dbac8aa2e67dd0bb83709 100644 --- a/src/operators/relu_op.h +++ b/src/operators/relu_op.h @@ -19,7 +19,7 @@ limitations under the License. */ #include #include "framework/operator.h" -#include "operators/kernel/relu_kernel.h" +#include "operators/kernel/activation_kernel.h" #include "operators/op_param.h" namespace paddle_mobile { diff --git a/src/operators/sequence_ops/sequence_expand_op.cpp b/src/operators/sequence_ops/sequence_expand_op.cpp new file mode 100644 index 0000000000000000000000000000000000000000..a1ff83981328e2ffa8013e60a2019f1f87fc24ab --- /dev/null +++ b/src/operators/sequence_ops/sequence_expand_op.cpp @@ -0,0 +1,56 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#ifdef SEQUENCE_EXPAND_OP + +#include "operators/sequence_ops/sequence_expand_op.h" + +namespace paddle_mobile { +namespace operators { + +template +void SequenceExpandOp::InferShape() const { + const auto *input_x = this->param_.input_x_; + const auto *input_y = this->param_.input_y_; + const auto &x_lod = input_x->lod(); + const auto &y_lod = input_y->lod(); + int ref_level = this->param_.ref_level_; + if (ref_level == -1) ref_level = y_lod.size() - 1; + + auto out_dims = input_x->dims(); + int64_t out_first_dim = 0; + + if (y_lod[ref_level].size() > 1) { + for (size_t i = 1; i < y_lod[ref_level].size(); ++i) { + int x_seq_len = 1; + if (x_lod.size() == 1) { + x_seq_len = x_lod[0][i] - x_lod[0][i - 1]; + } + out_first_dim += + (y_lod[ref_level][i] - y_lod[ref_level][i - 1]) * x_seq_len; + } + out_dims[0] = out_first_dim; + } + this->param_.output_->Resize(out_dims); +} + +} // namespace operators +} // namespace paddle_mobile + +namespace ops = paddle_mobile::operators; +#ifdef PADDLE_MOBILE_CPU +REGISTER_OPERATOR_CPU(sequence_expand, ops::SequenceExpandOp); +#endif + +#endif // SEQUENCE_EXPAND_OP diff --git a/src/operators/sequence_ops/sequence_expand_op.h b/src/operators/sequence_ops/sequence_expand_op.h new file mode 100644 index 0000000000000000000000000000000000000000..cd62bbefc703ed2642e076913e2538c1621c1082 --- /dev/null +++ b/src/operators/sequence_ops/sequence_expand_op.h @@ -0,0 +1,47 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#ifdef SEQUENCE_EXPAND_OP + +#pragma once + +#include +#include "framework/operator.h" +#include "operators/kernel/sequence_kernels.h" +#include "operators/op_param.h" + +namespace paddle_mobile { +namespace operators { + +template +class SequenceExpandOp : public framework::OperatorWithKernel< + DeviceType, SequenceExpandParam, + operators::SequenceExpandKernel> { + public: + SequenceExpandOp(const std::string &type, const VariableNameMap &inputs, + const VariableNameMap &outputs, + const framework::AttributeMap &attrs, + std::shared_ptr scope) + : framework::OperatorWithKernel< + DeviceType, SequenceExpandParam, + operators::SequenceExpandKernel>( + type, inputs, outputs, attrs, scope) {} + // inference output shape + void InferShape() const override; +}; + +} // namespace operators +} // namespace paddle_mobile + +#endif // SEQUENCE_EXPAND_OP diff --git a/src/operators/kernel/sigmoid_kernel.h b/src/operators/sequence_ops/sequence_pool_op.cpp similarity index 62% rename from src/operators/kernel/sigmoid_kernel.h rename to src/operators/sequence_ops/sequence_pool_op.cpp index db9fc3dd3cb1e6c0eb56cd5a14a173f5a031263c..4165d8ef60f5eb649c3acc9648d6cebe8e7f8d2c 100644 --- a/src/operators/kernel/sigmoid_kernel.h +++ b/src/operators/sequence_ops/sequence_pool_op.cpp @@ -12,27 +12,27 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -#pragma once +#ifdef SEQUENCE_POOL_OP -#ifdef SIGMOID_OP - -#include "framework/operator.h" -#include "operators/op_param.h" +#include "operators/sequence_ops/sequence_pool_op.h" namespace paddle_mobile { namespace operators { -using framework::OpKernelBase; - template -class SigmoidKernel - : public OpKernelBase> { - public: - void Compute(const SigmoidParam& param); - bool Init(SigmoidParam* param); -}; +void SequencePoolOp::InferShape() const { + const auto *input = this->param_.input_; + auto out_dims = input->dims(); + out_dims[0] = input->lod()[0].size() - 1; + this->param_.output_->Resize(out_dims); +} } // namespace operators } // namespace paddle_mobile +namespace ops = paddle_mobile::operators; +#ifdef PADDLE_MOBILE_CPU +REGISTER_OPERATOR_CPU(sequence_pool, ops::SequencePoolOp); #endif + +#endif // SEQUENCE_POOL_OP diff --git a/src/operators/sequence_ops/sequence_pool_op.h b/src/operators/sequence_ops/sequence_pool_op.h new file mode 100644 index 0000000000000000000000000000000000000000..724572936643abe071147edbdbee0053a29f4c20 --- /dev/null +++ b/src/operators/sequence_ops/sequence_pool_op.h @@ -0,0 +1,47 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#ifdef SEQUENCE_POOL_OP + +#pragma once + +#include +#include "framework/operator.h" +#include "operators/kernel/sequence_kernels.h" +#include "operators/op_param.h" + +namespace paddle_mobile { +namespace operators { + +template +class SequencePoolOp : public framework::OperatorWithKernel< + DeviceType, SequencePoolParam, + operators::SequencePoolKernel> { + public: + SequencePoolOp(const std::string &type, const VariableNameMap &inputs, + const VariableNameMap &outputs, + const framework::AttributeMap &attrs, + std::shared_ptr scope) + : framework::OperatorWithKernel< + DeviceType, SequencePoolParam, + operators::SequencePoolKernel>(type, inputs, outputs, + attrs, scope) {} + // inference output shape + void InferShape() const override; +}; + +} // namespace operators +} // namespace paddle_mobile + +#endif // SEQUENCE_POOL_OP diff --git a/src/operators/kernel/relu_kernel.h b/src/operators/sequence_ops/sequence_softmax_op.cpp similarity index 58% rename from src/operators/kernel/relu_kernel.h rename to src/operators/sequence_ops/sequence_softmax_op.cpp index e9473ee63bdc297d0789c15f2fcad79fb29c143f..602e0d2975adcdc7ff6c49dc4a6ed4de1de38d64 100644 --- a/src/operators/kernel/relu_kernel.h +++ b/src/operators/sequence_ops/sequence_softmax_op.cpp @@ -12,33 +12,28 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -#ifdef RELU_OP +#ifdef SEQUENCE_SOFTMAX_OP -#pragma once - -#include "framework/operator.h" -#include "operators/op_param.h" +#include "operators/sequence_ops/sequence_softmax_op.h" namespace paddle_mobile { namespace operators { template -class ReluKernel - : public framework::OpKernelBase> { - public: - void Compute(const ReluParam& param); - bool Init(ReluParam* param); -}; +void SequenceSoftmaxOp::InferShape() const { + const auto *input_x = this->param_.InputX(); + const auto &x_lod = input_x->lod(); -template -class Relu6Kernel - : public framework::OpKernelBase> { - public: - void Compute(const ReluParam& param); - bool Init(ReluParam* param); -}; + this->param_.Out()->Resize(input_x->dims()); + this->param_.Out()->set_lod(input_x->lod()); +} } // namespace operators } // namespace paddle_mobile +namespace ops = paddle_mobile::operators; +#ifdef PADDLE_MOBILE_CPU +REGISTER_OPERATOR_CPU(sequence_softmax, ops::SequenceSoftmaxOp); #endif + +#endif // SEQUENCE_SOFTMAX_OP diff --git a/src/operators/sequence_ops/sequence_softmax_op.h b/src/operators/sequence_ops/sequence_softmax_op.h new file mode 100644 index 0000000000000000000000000000000000000000..92090ba802cc8ea97bc87f0fe9567b319c3d4948 --- /dev/null +++ b/src/operators/sequence_ops/sequence_softmax_op.h @@ -0,0 +1,47 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#ifdef SEQUENCE_SOFTMAX_OP + +#pragma once + +#include +#include "framework/operator.h" +#include "operators/kernel/sequence_kernels.h" +#include "operators/op_param.h" + +namespace paddle_mobile { +namespace operators { + +template +class SequenceSoftmaxOp : public framework::OperatorWithKernel< + DeviceType, SoftmaxParam, + operators::SequenceSoftmaxKernel> { + public: + SequenceSoftmaxOp(const std::string &type, const VariableNameMap &inputs, + const VariableNameMap &outputs, + const framework::AttributeMap &attrs, + std::shared_ptr scope) + : framework::OperatorWithKernel< + DeviceType, SoftmaxParam, + operators::SequenceSoftmaxKernel>( + type, inputs, outputs, attrs, scope) {} + // inference output shape + void InferShape() const override; +}; + +} // namespace operators +} // namespace paddle_mobile + +#endif // SEQUENCE_SOFTMAX_OP diff --git a/src/operators/sigmoid_op.h b/src/operators/sigmoid_op.h index 7150a8a473e4cb1dba7230d63799bd263ef19812..f918b1a86f806bcba78f67900fe6bb2b56cd6a0f 100644 --- a/src/operators/sigmoid_op.h +++ b/src/operators/sigmoid_op.h @@ -18,7 +18,7 @@ limitations under the License. */ #include #include "framework/operator.h" -#include "operators/kernel/sigmoid_kernel.h" +#include "operators/kernel/activation_kernel.h" #include "operators/op_param.h" namespace paddle_mobile { diff --git a/src/operators/tanh_op.cpp b/src/operators/tanh_op.cpp index dd6f9083afd6919cfa3320e5e20275a785adf092..77cc980f468e2ec2a192a987496484d935690008 100644 --- a/src/operators/tanh_op.cpp +++ b/src/operators/tanh_op.cpp @@ -28,6 +28,9 @@ void TanhOp::InferShape() const { } // namespace paddle_mobile namespace ops = paddle_mobile::operators; +#ifdef PADDLE_MOBILE_CPU +REGISTER_OPERATOR_CPU(tanh, ops::TanhOp); +#endif #ifdef PADDLE_MOBILE_FPGA REGISTER_OPERATOR_FPGA(tanh, ops::TanhOp); #endif diff --git a/src/operators/tanh_op.h b/src/operators/tanh_op.h index 82b0e4e9a07ae4fd3e4885790d5832065ed3eb49..0e8226a1eb9074c00b6be87c3192a5f9f10f79bf 100644 --- a/src/operators/tanh_op.h +++ b/src/operators/tanh_op.h @@ -18,7 +18,7 @@ limitations under the License. */ #include #include "framework/operator.h" -#include "operators/kernel/tanh_kernel.h" +#include "operators/kernel/activation_kernel.h" #include "operators/op_param.h" namespace paddle_mobile { diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt index 5602d2d3ac38a7f9dfb175a0b03b1260960fc32e..6cab082d98f6247e5254b8bfa4c6a208c50fb42c 100644 --- a/test/CMakeLists.txt +++ b/test/CMakeLists.txt @@ -385,4 +385,13 @@ if (NOT FOUND_MATCH) # gen test ADD_EXECUTABLE(test-ocr net/test_ocr.cpp test_helper.h test_include.h) target_link_libraries(test-ocr paddle-mobile) + + ADD_EXECUTABLE(test-sequence-expand operators/test_sequence_expand_op.cpp test_helper.h test_include.h) + target_link_libraries(test-sequence-expand paddle-mobile) + + ADD_EXECUTABLE(test-sequence-pool operators/test_sequence_pool_op.cpp test_helper.h test_include.h) + target_link_libraries(test-sequence-pool paddle-mobile) + + ADD_EXECUTABLE(test-sequence-softmax operators/test_sequence_softmax_op.cpp test_helper.h test_include.h) + target_link_libraries(test-sequence-softmax paddle-mobile) endif () diff --git a/test/operators/test_sequence_expand_op.cpp b/test/operators/test_sequence_expand_op.cpp new file mode 100644 index 0000000000000000000000000000000000000000..72e8954f93f4e48524c6b78804237ece427dbae3 --- /dev/null +++ b/test/operators/test_sequence_expand_op.cpp @@ -0,0 +1,97 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include +#include "../test_include.h" +#include "operators/sequence_ops/sequence_expand_op.h" + +namespace paddle_mobile { + +int TestSequenceExpandOp(const framework::LoDTensor &input_x, + const framework::LoDTensor &input_y, int ref_level, + framework::LoDTensor *output) { + VariableNameMap inputs; + VariableNameMap outputs; + auto scope = std::make_shared(); + inputs["X"] = std::vector({"input_x"}); + inputs["Y"] = std::vector({"input_y"}); + outputs["Out"] = std::vector({"output"}); + + auto input_x_var = scope.get()->Var("input_x"); + auto *x = input_x_var->template GetMutable(); + x->Resize(input_x.dims()); + x->ShareDataWith(input_x); + x->set_lod(input_x.lod()); + auto input_y_var = scope.get()->Var("input_y"); + auto *y = input_y_var->template GetMutable(); + y->Resize(framework::make_ddim({0})); + y->mutable_data(); + y->set_lod(input_y.lod()); + + auto output_var = scope.get()->Var("output"); + + framework::AttributeMap attrs; + attrs["ref_level"].Set(0); + + auto *op = new operators::SequenceExpandOp( + "sequence_expand", inputs, outputs, attrs, scope); + + op->InferShape(); + op->Init(); + op->Run(); + + auto *out = output_var->template Get(); + output->Resize(out->dims()); + output->ShareDataWith(*out); + output->set_lod(out->lod()); + delete op; + return 0; +} + +} // namespace paddle_mobile + +// namespace framework = paddle_mobile::framework; + +int main(int argc, char *argv[]) { + framework::LoDTensor input_x, input_y, output; + // case 1 + { + std::vector data{1, 2, 3, 4}; + input_x.Resize(framework::make_ddim({4, 1})); + float *in_data = input_x.mutable_data(); + for (int i = 0; i < 4; ++i) in_data[i] = data[i]; + input_x.set_lod({{0, 2, 4}}); + input_y.set_lod({{0, 2, 4}, {0, 3, 6, 7, 8}}); + + TestSequenceExpandOp(input_x, input_y, 0, &output); + std::vector expect_data{1, 2, 1, 2, 3, 4, 3, 4}; + std::vector expect_lod{0, 2, 4, 6, 8}; + for (int i = 0; i < 5; ++i) { + if (output.lod()[0][i] != expect_lod[i]) { + std::cerr << "output_lod[" << i << "]: " << output.lod()[0][i] + << " != expect_lod[" << i << "]: " << expect_lod[i] + << std::endl; + return 1; + } + } + for (int i = 0; i < 8; ++i) { + if (output.data()[i] != expect_data[i]) { + std::cerr << "output[" << i << "]: " << output.data()[i] + << " != expect[" << i << "]: " << expect_data[i] << std::endl; + return 1; + } + } + } + return 0; +} diff --git a/test/operators/test_sequence_pool_op.cpp b/test/operators/test_sequence_pool_op.cpp new file mode 100644 index 0000000000000000000000000000000000000000..a8518d630a6008c7cd1fa99d2b0df1d27ebfba32 --- /dev/null +++ b/test/operators/test_sequence_pool_op.cpp @@ -0,0 +1,293 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include +#include "../test_include.h" +#include "operators/sequence_ops/sequence_pool_op.h" + +namespace paddle_mobile { + +int TestSequencePoolOp(const framework::LoDTensor &input_x, + const std::string pool_type, + framework::LoDTensor *output) { + VariableNameMap inputs; + VariableNameMap outputs; + auto scope = std::make_shared(); + inputs["X"] = std::vector({"input_x"}); + outputs["Out"] = std::vector({"output"}); + + auto input_x_var = scope.get()->Var("input_x"); + auto *x = input_x_var->template GetMutable(); + x->Resize(input_x.dims()); + x->ShareDataWith(input_x); + x->set_lod(input_x.lod()); + + auto output_var = scope.get()->Var("output"); + + framework::AttributeMap attrs; + attrs["pooltype"].SetString(pool_type); + + auto *op = new operators::SequencePoolOp("sequence_pool", inputs, + outputs, attrs, scope); + + op->InferShape(); + op->Init(); + op->Run(); + + auto *out = output_var->template Get(); + output->Resize(out->dims()); + output->ShareDataWith(*out); + delete op; + return 0; +} + +} // namespace paddle_mobile + +// namespace framework = paddle_mobile::framework; + +int main(int argc, char *argv[]) { + framework::LoDTensor input_x, output; + // case 1 + std::cerr << "running max case 1" << std::endl; + { + std::vector data{1, 2, 3, 4}; + input_x.Resize(framework::make_ddim({4, 1})); + float *in_data = input_x.mutable_data(); + for (int i = 0; i < 4; ++i) in_data[i] = data[i]; + input_x.set_lod({{0, 2, 4}}); + + TestSequencePoolOp(input_x, "MAX", &output); + std::vector expect_data{2, 4}; + for (int i = 0; i < 2; ++i) { + if (output.data()[i] != expect_data[i]) { + std::cerr << "output[" << i << "]: " << output.data()[i] + << " != expect[" << i << "]: " << expect_data[i] << std::endl; + return 1; + } + } + } + // case 2 + std::cerr << "running max case 2" << std::endl; + { + std::vector data{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}; + input_x.Resize(framework::make_ddim({data.size(), 1})); + float *in_data = input_x.mutable_data(); + for (int i = 0; i < data.size(); ++i) in_data[i] = data[i]; + input_x.set_lod({{0, 3, 10}}); + + TestSequencePoolOp(input_x, "MAX", &output); + std::vector expect_data{3, 10}; + for (int i = 0; i < 2; ++i) { + if (output.data()[i] != expect_data[i]) { + std::cerr << "output[" << i << "]: " << output.data()[i] + << " != expect[" << i << "]: " << expect_data[i] << std::endl; + return 1; + } + } + } + std::cerr << "running max case 3" << std::endl; + // case 3 + { + std::vector data{1, 2, 3, 4, 5, 6, 7, 8}; + input_x.Resize(framework::make_ddim({4, 2})); + float *in_data = input_x.mutable_data(); + for (int i = 0; i < data.size(); ++i) in_data[i] = data[i]; + input_x.set_lod({{0, 2, 4}}); + + TestSequencePoolOp(input_x, "MAX", &output); + std::vector expect_data{3, 4, 7, 8}; + for (int i = 0; i < 4; ++i) { + if (output.data()[i] != expect_data[i]) { + std::cerr << "output[" << i << "]: " << output.data()[i] + << " != expect[" << i << "]: " << expect_data[i] << std::endl; + return 1; + } + } + } + // case 4 + std::cerr << "running max case 4" << std::endl; + { + std::vector data{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, + 11, 12, 13, 14, 15, 16, 17, 18, 19, 20}; + input_x.Resize(framework::make_ddim({4, 5})); + float *in_data = input_x.mutable_data(); + for (int i = 0; i < data.size(); ++i) in_data[i] = data[i]; + input_x.set_lod({{0, 2, 4}}); + + TestSequencePoolOp(input_x, "MAX", &output); + std::vector expect_data{6, 7, 8, 9, 10, 16, 17, 18, 19, 20}; + for (int i = 0; i < 10; ++i) { + if (output.data()[i] != expect_data[i]) { + std::cerr << "output[" << i << "]: " << output.data()[i] + << " != expect[" << i << "]: " << expect_data[i] << std::endl; + return 1; + } + } + } + // case 1 + std::cerr << "running sum case 1" << std::endl; + { + std::vector data{1, 2, 3, 4}; + input_x.Resize(framework::make_ddim({4, 1})); + float *in_data = input_x.mutable_data(); + for (int i = 0; i < 4; ++i) in_data[i] = data[i]; + input_x.set_lod({{0, 2, 4}}); + + TestSequencePoolOp(input_x, "SUM", &output); + std::vector expect_data{3, 7}; + for (int i = 0; i < 2; ++i) { + if (output.data()[i] != expect_data[i]) { + std::cerr << "output[" << i << "]: " << output.data()[i] + << " != expect[" << i << "]: " << expect_data[i] << std::endl; + return 1; + } + } + } + // case 2 + std::cerr << "running sum case 2" << std::endl; + { + std::vector data{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}; + input_x.Resize(framework::make_ddim({data.size(), 1})); + float *in_data = input_x.mutable_data(); + for (int i = 0; i < data.size(); ++i) in_data[i] = data[i]; + input_x.set_lod({{0, 3, 10}}); + + TestSequencePoolOp(input_x, "SUM", &output); + std::vector expect_data{6, 49}; + for (int i = 0; i < 2; ++i) { + if (output.data()[i] != expect_data[i]) { + std::cerr << "output[" << i << "]: " << output.data()[i] + << " != expect[" << i << "]: " << expect_data[i] << std::endl; + return 1; + } + } + } + // case 3 + std::cerr << "running sum case 3" << std::endl; + { + std::vector data{1, 2, 3, 4, 5, 6, 7, 8}; + input_x.Resize(framework::make_ddim({4, 2})); + float *in_data = input_x.mutable_data(); + for (int i = 0; i < data.size(); ++i) in_data[i] = data[i]; + input_x.set_lod({{0, 2, 4}}); + + TestSequencePoolOp(input_x, "SUM", &output); + std::vector expect_data{4, 6, 12, 14}; + for (int i = 0; i < 4; ++i) { + if (output.data()[i] != expect_data[i]) { + std::cerr << "output[" << i << "]: " << output.data()[i] + << " != expect[" << i << "]: " << expect_data[i] << std::endl; + return 1; + } + } + } + // case 4 + std::cerr << "running sum case 4" << std::endl; + { + std::vector data{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, + 11, 12, 13, 14, 15, 16, 17, 18, 19, 20}; + input_x.Resize(framework::make_ddim({4, 5})); + float *in_data = input_x.mutable_data(); + for (int i = 0; i < data.size(); ++i) in_data[i] = data[i]; + input_x.set_lod({{0, 2, 4}}); + + TestSequencePoolOp(input_x, "SUM", &output); + std::vector expect_data{7, 9, 11, 13, 15, 27, 29, 31, 33, 35}; + for (int i = 0; i < 10; ++i) { + if (output.data()[i] != expect_data[i]) { + std::cerr << "output[" << i << "]: " << output.data()[i] + << " != expect[" << i << "]: " << expect_data[i] << std::endl; + return 1; + } + } + } + // case 1 + std::cerr << "running first case 1" << std::endl; + { + std::vector data{1, 2, 3, 4}; + input_x.Resize(framework::make_ddim({4, 1})); + float *in_data = input_x.mutable_data(); + for (int i = 0; i < 4; ++i) in_data[i] = data[i]; + input_x.set_lod({{0, 2, 4}}); + + TestSequencePoolOp(input_x, "FIRST", &output); + std::vector expect_data{1, 3}; + for (int i = 0; i < 2; ++i) { + if (output.data()[i] != expect_data[i]) { + std::cerr << "output[" << i << "]: " << output.data()[i] + << " != expect[" << i << "]: " << expect_data[i] << std::endl; + return 1; + } + } + } + // case 2 + std::cerr << "running first case 2" << std::endl; + { + std::vector data{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}; + input_x.Resize(framework::make_ddim({data.size(), 1})); + float *in_data = input_x.mutable_data(); + for (int i = 0; i < data.size(); ++i) in_data[i] = data[i]; + input_x.set_lod({{0, 3, 10}}); + + TestSequencePoolOp(input_x, "FIRST", &output); + std::vector expect_data{1, 4}; + for (int i = 0; i < 2; ++i) { + if (output.data()[i] != expect_data[i]) { + std::cerr << "output[" << i << "]: " << output.data()[i] + << " != expect[" << i << "]: " << expect_data[i] << std::endl; + return 1; + } + } + } + // case 3 + std::cerr << "running first case 3" << std::endl; + { + std::vector data{1, 2, 3, 4, 5, 6, 7, 8}; + input_x.Resize(framework::make_ddim({4, 2})); + float *in_data = input_x.mutable_data(); + for (int i = 0; i < data.size(); ++i) in_data[i] = data[i]; + input_x.set_lod({{0, 2, 4}}); + + TestSequencePoolOp(input_x, "FIRST", &output); + std::vector expect_data{1, 2, 5, 6}; + for (int i = 0; i < 4; ++i) { + if (output.data()[i] != expect_data[i]) { + std::cerr << "output[" << i << "]: " << output.data()[i] + << " != expect[" << i << "]: " << expect_data[i] << std::endl; + return 1; + } + } + } + // case 4 + std::cerr << "running first case 4" << std::endl; + { + std::vector data{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, + 11, 12, 13, 14, 15, 16, 17, 18, 19, 20}; + input_x.Resize(framework::make_ddim({4, 5})); + float *in_data = input_x.mutable_data(); + for (int i = 0; i < data.size(); ++i) in_data[i] = data[i]; + input_x.set_lod({{0, 2, 4}}); + + TestSequencePoolOp(input_x, "FIRST", &output); + std::vector expect_data{1, 2, 3, 4, 5, 11, 12, 13, 14, 15}; + for (int i = 0; i < 10; ++i) { + if (output.data()[i] != expect_data[i]) { + std::cerr << "output[" << i << "]: " << output.data()[i] + << " != expect[" << i << "]: " << expect_data[i] << std::endl; + return 1; + } + } + } + return 0; +} diff --git a/test/operators/test_sequence_softmax_op.cpp b/test/operators/test_sequence_softmax_op.cpp new file mode 100644 index 0000000000000000000000000000000000000000..9e698d933a97525dac32948466e2600a9b217033 --- /dev/null +++ b/test/operators/test_sequence_softmax_op.cpp @@ -0,0 +1,100 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include +#include +#include "../test_include.h" +#include "operators/sequence_ops/sequence_softmax_op.h" + +namespace paddle_mobile { + +void SequenceSoftmax(const framework::LoDTensor *X, framework::LoDTensor *Y) { + const float *x = X->data(); + const auto &lod = X->lod().back(); + float *y = Y->mutable_data(); + for (int batch = 0; batch < lod.size() - 1; ++batch) { + int num_classes = lod[batch + 1] - lod[batch]; + size_t offset = lod[batch]; + const float *input = x + offset; + float *output = y + offset; + float max = -std::numeric_limits::max(); + for (int j = 0; j < num_classes; ++j) { + max = (input[j] > max) ? input[j] : max; + } + float sum = 0.f; + for (int j = 0; j < num_classes; ++j) { + float tmp = std::expf(input[j] - max); + sum += tmp; + output[j] = tmp; + } + for (int j = 0; j < num_classes; ++j) { + output[j] /= sum; + } + } + Y->set_lod(X->lod()); +} + +int TestSequenceSoftmaxOp(const std::vector &input_shape, + const std::vector &input_lod) { + framework::DDim dims = framework::make_ddim(input_shape); + VariableNameMap inputs; + VariableNameMap outputs; + auto scope = std::make_shared(); + inputs["X"] = std::vector({"input"}); + outputs["Out"] = std::vector({"output"}); + + auto input_var = scope.get()->Var("input"); + auto input = input_var->template GetMutable(); + SetupTensor(input, dims, -100.0, 100.0); + input->set_lod({input_lod}); + + auto output_var = scope.get()->Var("output"); + + framework::AttributeMap attrs; + auto *op = new operators::SequenceSoftmaxOp( + "sequence_softmax", inputs, outputs, attrs, scope); + + op->InferShape(); + op->Init(); + op->Run(); + + auto output = output_var->template Get(); + + framework::LoDTensor output_cmp; + float *output_cmp_data = output_cmp.mutable_data(output->dims()); + SequenceSoftmax(input, &output_cmp); + + const float *output_data = output->data(); + for (int i = 0; i < output->numel(); ++i) { + float gap = output_data[i] - output_cmp_data[i]; + if (std::abs(gap / (output_data[i] + 1e-5)) > 1e-3) { + LOG(kLOG_INFO) << "output_data[" << i << "] = " << output_data[i] + << ", output_cmp_data[" << i + << "] = " << output_cmp_data[i]; + delete op; + exit(1); + } + } + delete op; + return 0; +} + +} // namespace paddle_mobile + +int main(int argc, char *argv[]) { + TestSequenceSoftmaxOp({2, 1}, {0, 2}); + TestSequenceSoftmaxOp({100, 1}, {0, 3, 100}); + TestSequenceSoftmaxOp({100, 1}, {0, 50, 100}); + return 0; +} diff --git a/test/operators/test_sigmoid_op.cpp b/test/operators/test_sigmoid_op.cpp index df93da1529ae1e03561643ebeef4cb821f10d211..55ea43e1ea8e196c7397b7739cd83183ed7e8852 100644 --- a/test/operators/test_sigmoid_op.cpp +++ b/test/operators/test_sigmoid_op.cpp @@ -12,10 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -#include "../../src/operators/kernel/central-arm-func/sigmoid_arm_func.h" -#include "../../src/operators/kernel/sigmoid_kernel.h" #include "../test_helper.h" -#include "framework/executor.h" int main() { paddle_mobile::framework::Tensor input; @@ -25,11 +22,5 @@ int main() { auto out_ddim = paddle_mobile::framework::make_ddim({1, 4, 60, 60}); output.Resize(out_ddim); - paddle_mobile::operators::sigmoid(&input, &output); - auto *output_ptr = output.data(); - for (int j = 0; j < output.numel(); ++j) { - DLOG << " value of output: " << output_ptr[j]; - } - DLOG << 5; return 0; } diff --git a/test/operators/test_softmax_op.cpp b/test/operators/test_softmax_op.cpp index d65cf4fea27343343d6c2a2a720a0e0ec7d45076..e94933eaa90f96982f90f713e04f05a999424697 100644 --- a/test/operators/test_softmax_op.cpp +++ b/test/operators/test_softmax_op.cpp @@ -62,7 +62,6 @@ int TestSoftmaxOp(const std::vector input_shape) { SetupTensor(input, dims, -100.0, 100.0); auto output_var = scope.get()->Var("output"); - auto output = output_var->template Get(); framework::AttributeMap attrs; auto *op = new operators::SoftmaxOp("softmax", inputs, outputs, @@ -71,6 +70,8 @@ int TestSoftmaxOp(const std::vector input_shape) { op->Init(); op->Run(); + auto output = output_var->template Get(); + framework::Tensor output_cmp; float *output_cmp_data = output_cmp.mutable_data(output->dims()); Softmax(input, &output_cmp); diff --git a/tools/op.cmake b/tools/op.cmake index 40bc1075baa1e7d81823a08efa56beaf58f3ae12..3563834e77fa5eba363fc64ba337dd30c95a3820 100644 --- a/tools/op.cmake +++ b/tools/op.cmake @@ -272,6 +272,9 @@ if(NOT FOUND_MATCH) set(FUSION_DEQUANT_ADD_BN_RELU_OP ON) set(FUSION_DEQUANT_ADD_BN_QUANT_OP ON) set(FUSION_DEQUANT_ADD_BN_RELU_QUANT_OP ON) + set(SEQUENCE_EXPAND_OP ON) + set(SEQUENCE_POOL_OP ON) + set(SEQUENCE_SOFTMAX_OP ON) endif() # option(BATCHNORM_OP "" ON) @@ -496,6 +499,15 @@ endif() if (FUSION_DEQUANT_ADD_BN_RELU_QUANT_OP) # add_definitions(-DFUSION_DEQUANT_ADD_BN_RELU_QUANT_OP) endif() +if (SEQUENCE_EXPAND_OP) + add_definitions(-DSEQUENCE_EXPAND_OP) +endif() +if (SEQUENCE_POOL_OP) + add_definitions(-DSEQUENCE_POOL_OP) +endif() +if (SEQUENCE_SOFTMAX_OP) + add_definitions(-DSEQUENCE_SOFTMAX_OP) +endif() if (TANH_OP) add_definitions(-DTANH_OP)