From b6e709e2dd441c65f05f900f725b132322543240 Mon Sep 17 00:00:00 2001 From: hjchen2 Date: Tue, 8 Jan 2019 18:56:44 +0800 Subject: [PATCH] Fix bugs and add sequence softmax op --- src/common/types.cpp | 4 +- src/common/types.h | 1 + .../kernel/arm/sequence_pool_kernel.cpp | 23 +- .../kernel/arm/sequence_softmax_kernel.cpp | 43 +++ src/operators/kernel/sequence_kernels.h | 4 + src/operators/math/softmax.cpp | 136 ++++---- src/operators/math/softmax.h | 10 +- src/operators/op_param.h | 10 +- .../sequence_ops/sequence_softmax_op.cpp | 39 +++ .../sequence_ops/sequence_softmax_op.h | 47 +++ test/CMakeLists.txt | 9 + test/operators/test_sequence_expand_op.cpp | 97 ++++++ test/operators/test_sequence_pool_op.cpp | 293 ++++++++++++++++++ test/operators/test_sequence_softmax_op.cpp | 100 ++++++ test/operators/test_softmax_op.cpp | 3 +- tools/op.cmake | 4 + 16 files changed, 750 insertions(+), 73 deletions(-) create mode 100644 src/operators/kernel/arm/sequence_softmax_kernel.cpp create mode 100644 src/operators/sequence_ops/sequence_softmax_op.cpp create mode 100644 src/operators/sequence_ops/sequence_softmax_op.h create mode 100644 test/operators/test_sequence_expand_op.cpp create mode 100644 test/operators/test_sequence_pool_op.cpp create mode 100644 test/operators/test_sequence_softmax_op.cpp diff --git a/src/common/types.cpp b/src/common/types.cpp index a7996e0960..c25c5db30c 100644 --- a/src/common/types.cpp +++ b/src/common/types.cpp @@ -91,6 +91,7 @@ const char *G_OP_TYPE_FUSION_DECONV_ADD_RELU = "fusion_deconv_add_relu"; const char *G_OP_TYPE_SEQUENCE_EXPAND = "sequence_expand"; const char *G_OP_TYPE_SEQUENCE_POOL = "sequence_pool"; +const char *G_OP_TYPE_SEQUENCE_SOFTMAX = "sequence_softmax"; std::unordered_map< std::string, std::pair, std::vector>> @@ -167,5 +168,6 @@ std::unordered_map< {G_OP_TYPE_FUSION_DECONV_ADD, {{"Input"}, {"Out"}}}, {G_OP_TYPE_FUSION_DECONV_ADD_RELU, {{"Input"}, {"Out"}}}, {G_OP_TYPE_SEQUENCE_EXPAND, {{"X", "Y"}, {"Out"}}}, - {G_OP_TYPE_SEQUENCE_POOL, {{"X"}, {"Out"}}}}; + {G_OP_TYPE_SEQUENCE_POOL, {{"X"}, {"Out"}}}, + {G_OP_TYPE_SEQUENCE_SOFTMAX, {{"X"}, {"Out"}}}}; } // namespace paddle_mobile diff --git a/src/common/types.h b/src/common/types.h index b1780cad31..114424fe04 100644 --- a/src/common/types.h +++ b/src/common/types.h @@ -173,6 +173,7 @@ extern const char *G_OP_TYPE_FUSION_DECONV_ADD_RELU; extern const char *G_OP_TYPE_SEQUENCE_EXPAND; extern const char *G_OP_TYPE_SEQUENCE_POOL; +extern const char *G_OP_TYPE_SEQUENCE_SOFTMAX; extern std::unordered_map< std::string, std::pair, std::vector>> diff --git a/src/operators/kernel/arm/sequence_pool_kernel.cpp b/src/operators/kernel/arm/sequence_pool_kernel.cpp index b76feb555a..f4e28a0ffb 100644 --- a/src/operators/kernel/arm/sequence_pool_kernel.cpp +++ b/src/operators/kernel/arm/sequence_pool_kernel.cpp @@ -15,6 +15,7 @@ limitations under the License. */ #ifdef SEQUENCE_POOL_OP #include +#include #include #include #include "common/types.h" @@ -41,7 +42,7 @@ void SequencePoolImpl(const framework::LoDTensor &input, float *out_ptr = output_ptr + i * width; int64_t height = static_cast(lod[i + 1] - lod[i]); if (width == 1) { - float val = 0.f; + float max = -std::numeric_limits::max(); int remain_h = height; #ifdef __ARM_NEON__ int loop = remain_h >> 2; @@ -53,19 +54,19 @@ void SequencePoolImpl(const framework::LoDTensor &input, in_ptr += 4; } float32x2_t __max2 = - vpadd_f32(vget_low_f32(__max4), vget_high_f32(__max4)); - __max2 = vpadd_f32(__max2, __max2); - val = std::max(val, vget_lane_f32(__max2, 0)); + vpmax_f32(vget_low_f32(__max4), vget_high_f32(__max4)); + __max2 = vpmax_f32(__max2, __max2); + max = std::max(max, vget_lane_f32(__max2, 0)); #endif // __ARM_NEON__ for (int h = 0; h < remain_h; ++h) { - val = std::max(val, in_ptr[h]); + max = std::max(max, in_ptr[h]); } - *out_ptr = val; + *out_ptr = max; } else { memcpy(out_ptr, in_ptr, width * sizeof(float)); + in_ptr += width; int remain_h = height - 1; #ifdef __ARM_NEON__ - int loop_w = width >> 2; int remain_w_start = width & 0xfffc; #endif // __ARM_NEON__ for (int h = 0; h < remain_h; ++h) { @@ -121,6 +122,7 @@ void SequencePoolImpl(const framework::LoDTensor &input, *out_ptr = sum; } else { memcpy(out_ptr, in_ptr, width * sizeof(float)); + in_ptr += width; int remain_h = height - 1; #ifdef __ARM_NEON__ int loop_w = width >> 2; @@ -128,7 +130,7 @@ void SequencePoolImpl(const framework::LoDTensor &input, #endif // __ARM_NEON__ for (int h = 0; h < remain_h; ++h) { #ifdef __ARM_NEON__ - for (int w = 0; w < width; w += 4) { + for (int w = 0; w < width - 3; w += 4) { float32x4_t __in = vld1q_f32(in_ptr + w); float32x4_t __out = vld1q_f32(out_ptr + w); __out = vaddq_f32(__out, __in); @@ -169,6 +171,7 @@ class SequencePoolKernel const framework::LoDTensor *input = param.input_; framework::LoDTensor *output = param.output_; output->mutable_data(); + const std::string pooling_type = param.pool_type_; if (param.pool_type_ == "MAX") { SequencePoolImpl(*input, output); @@ -176,6 +179,10 @@ class SequencePoolKernel SequencePoolImpl(*input, output); } else if (param.pool_type_ == "SUM") { SequencePoolImpl(*input, output); + } else { + PADDLE_MOBILE_THROW_EXCEPTION( + "pooling type `%s` has not been implemented.", + param.pool_type_.c_str()); } } }; diff --git a/src/operators/kernel/arm/sequence_softmax_kernel.cpp b/src/operators/kernel/arm/sequence_softmax_kernel.cpp new file mode 100644 index 0000000000..ecbc39c4cc --- /dev/null +++ b/src/operators/kernel/arm/sequence_softmax_kernel.cpp @@ -0,0 +1,43 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#ifdef SEQUENCE_SOFTMAX_OP + +#include "framework/lod_tensor.h" +#include "operators/kernel/sequence_kernels.h" +#include "operators/math/softmax.h" + +namespace paddle_mobile { +namespace operators { + +template +class SequenceSoftmaxKernel + : public framework::OpKernelBase> { + public: + bool Init(SoftmaxParam *param) { return true; } + + void Compute(const SoftmaxParam ¶m) { + const framework::LoDTensor *input = param.InputX(); + framework::LoDTensor *output = param.Out(); + math::SequenceSoftmaxFuntor sequence_softmax; + sequence_softmax(input, output); + } +}; + +template class SequenceSoftmaxKernel; + +} // namespace operators +} // namespace paddle_mobile + +#endif // SEQUENCE_SOFTMAX_OP diff --git a/src/operators/kernel/sequence_kernels.h b/src/operators/kernel/sequence_kernels.h index 423e89e515..7884d0d475 100644 --- a/src/operators/kernel/sequence_kernels.h +++ b/src/operators/kernel/sequence_kernels.h @@ -37,5 +37,9 @@ DECLARE_KERNEL(SequenceExpandKernel, SequenceExpandParam); DECLARE_KERNEL(SequencePoolKernel, SequencePoolParam); #endif // SEQUENCE_POOL_OP +#ifdef SEQUENCE_SOFTMAX_OP +DECLARE_KERNEL(SequenceSoftmaxKernel, SoftmaxParam); +#endif // SEQUENCE_SOFTMAX_OP + } // namespace operators } // namespace paddle_mobile diff --git a/src/operators/math/softmax.cpp b/src/operators/math/softmax.cpp index 48f8b35ea2..e905ff2564 100644 --- a/src/operators/math/softmax.cpp +++ b/src/operators/math/softmax.cpp @@ -60,6 +60,68 @@ float find_max(const float *input, const int num_classes) { return max; } +void SoftmaxBasic(const float *input, int num_classes, float *y) { + float *output = y; + // find max + float max = find_max(input, num_classes); + + // exp(x - max) + int remain = num_classes; +#if defined(__ARM_NEON) || defined(__ARM_NEON__) + int loop = num_classes >> 3; + remain = num_classes & 0x7; + float32x4_t __max = vdupq_n_f32(max); + for (int i = 0; i < loop; ++i, input += 8, output += 8) { + float32x4_t x0 = vld1q_f32(input); + float32x4_t x1 = vld1q_f32(input + 4); + x0 = vsubq_f32(x0, __max); + x1 = vsubq_f32(x1, __max); + x0 = exp_ps(x0); + x1 = exp_ps(x1); + vst1q_f32(output, x0); + vst1q_f32(output + 4, x1); + } +#endif // __ARM_NEON__ + for (int i = 0; i < remain; ++i) { + output[i] = expf(input[i] - max); + } + + // sum(exp(x - max)) + float sum = 0.f; + output = y; +#if defined(__ARM_NEON) || defined(__ARM_NEON__) + float32x4_t __sum = vdupq_n_f32(0.f); + for (int i = 0; i < loop; ++i, output += 8) { + float32x4_t x0 = vld1q_f32(output); + float32x4_t x1 = vld1q_f32(output + 4); + __sum = vaddq_f32(x0, __sum); + __sum = vaddq_f32(x1, __sum); + } + sum += vaddvq_f32(__sum); +#endif // __ARM_NEON__ + for (int i = 0; i < remain; ++i) { + sum += output[i]; + } + + // exp(x - max) / sum + float inv_sum = 1.f / sum; + output = y; +#if defined(__ARM_NEON) || defined(__ARM_NEON__) + float32x4_t __inv_sum = vdupq_n_f32(inv_sum); + for (int i = 0; i < loop; ++i, output += 8) { + float32x4_t x0 = vld1q_f32(output); + float32x4_t x1 = vld1q_f32(output + 4); + x0 = vmulq_f32(x0, __inv_sum); + x1 = vmulq_f32(x1, __inv_sum); + vst1q_f32(output, x0); + vst1q_f32(output + 4, x1); + } +#endif + for (int i = 0; i < remain; ++i) { + output[i] *= inv_sum; + } +} + template <> void SoftmaxFuntor::operator()(const framework::Tensor *X, framework::Tensor *Y) { @@ -76,65 +138,25 @@ void SoftmaxFuntor::operator()(const framework::Tensor *X, size_t offset = (batch * channels + channel) * num_classes; const float *input = x + offset; float *output = y + offset; - // find max - float max = find_max(input, num_classes); - - // exp(x - max) - int remain = num_classes; -#if defined(__ARM_NEON) || defined(__ARM_NEON__) - int loop = num_classes >> 3; - remain = num_classes & 0x7; - float32x4_t __max = vdupq_n_f32(max); - for (int i = 0; i < loop; ++i, input += 8, output += 8) { - float32x4_t x0 = vld1q_f32(input); - float32x4_t x1 = vld1q_f32(input + 4); - x0 = vsubq_f32(x0, __max); - x1 = vsubq_f32(x1, __max); - x0 = exp_ps(x0); - x1 = exp_ps(x1); - vst1q_f32(output, x0); - vst1q_f32(output + 4, x1); - } -#endif // __ARM_NEON__ - for (int i = 0; i < remain; ++i) { - output[i] = expf(input[i] - max); - } + SoftmaxBasic(input, num_classes, output); + } + } +} - // sum(exp(x - max)) - float sum = 0.f; - output = y + offset; -#if defined(__ARM_NEON) || defined(__ARM_NEON__) - float32x4_t __sum = vdupq_n_f32(0.f); - for (int i = 0; i < loop; ++i, output += 8) { - float32x4_t x0 = vld1q_f32(output); - float32x4_t x1 = vld1q_f32(output + 4); - __sum = vaddq_f32(x0, __sum); - __sum = vaddq_f32(x1, __sum); - } - sum += vaddvq_f32(__sum); -#endif // __ARM_NEON__ - for (int i = 0; i < remain; ++i) { - sum += output[i]; - } +template <> +void SequenceSoftmaxFuntor::operator()( + const framework::LoDTensor *X, framework::LoDTensor *Y) { + const float *x = X->data(); + const auto &lod = X->lod().back(); + float *y = Y->mutable_data(); - // exp(x - max) / sum - float inv_sum = 1.f / sum; - output = y + offset; -#if defined(__ARM_NEON) || defined(__ARM_NEON__) - float32x4_t __inv_sum = vdupq_n_f32(inv_sum); - for (int i = 0; i < loop; ++i, output += 8) { - float32x4_t x0 = vld1q_f32(output); - float32x4_t x1 = vld1q_f32(output + 4); - x0 = vmulq_f32(x0, __inv_sum); - x1 = vmulq_f32(x1, __inv_sum); - vst1q_f32(output, x0); - vst1q_f32(output + 4, x1); - } -#endif - for (int i = 0; i < remain; ++i) { - output[i] *= inv_sum; - } - } + #pragma omp parallel for + for (int batch = 0; batch < lod.size() - 1; ++batch) { + int num_classes = lod[batch + 1] - lod[batch]; + size_t offset = lod[batch]; + const float *input = x + offset; + float *output = y + offset; + SoftmaxBasic(input, num_classes, output); } } diff --git a/src/operators/math/softmax.h b/src/operators/math/softmax.h index 0de30a4eca..dff25b9d02 100644 --- a/src/operators/math/softmax.h +++ b/src/operators/math/softmax.h @@ -12,10 +12,11 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -#ifdef SOFTMAX_OP +#if defined(SOFTMAX_OP) || defined(SEQUENCE_SOFTMAX_OP) #pragma once +#include "framework/lod_tensor.h" #include "framework/tensor.h" namespace paddle_mobile { @@ -28,7 +29,14 @@ class SoftmaxFuntor { void operator()(const framework::Tensor *X, framework::Tensor *Y); }; +template +class SequenceSoftmaxFuntor { + public: + void operator()(const framework::LoDTensor *X, framework::LoDTensor *Y); +}; + } // namespace math } // namespace operators } // namespace paddle_mobile + #endif diff --git a/src/operators/op_param.h b/src/operators/op_param.h index 2f7c26b3b3..8976d8be8e 100644 --- a/src/operators/op_param.h +++ b/src/operators/op_param.h @@ -978,12 +978,12 @@ class SoftmaxParam : public OpParam { input_x_ = InputXFrom(inputs, scope); out_ = OutFrom(outputs, scope); } - const RType *InputX() const { return input_x_; } - RType *Out() const { return out_; } + const GType *InputX() const { return input_x_; } + GType *Out() const { return out_; } private: - RType *input_x_; - RType *out_; + GType *input_x_; + GType *out_; #ifdef PADDLE_MOBILE_FPGA @@ -2778,7 +2778,7 @@ class SequencePoolParam : public OpParam { output_ = OutFrom(outputs, scope); pool_type_ = "MAX"; if (OpParam::HasAttr("pooltype", attrs)) { - pool_type_ = OpParam::GetAttr("pooltype", attrs); + pool_type_ = OpParam::GetStringAttr("pooltype", attrs); } } diff --git a/src/operators/sequence_ops/sequence_softmax_op.cpp b/src/operators/sequence_ops/sequence_softmax_op.cpp new file mode 100644 index 0000000000..602e0d2975 --- /dev/null +++ b/src/operators/sequence_ops/sequence_softmax_op.cpp @@ -0,0 +1,39 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#ifdef SEQUENCE_SOFTMAX_OP + +#include "operators/sequence_ops/sequence_softmax_op.h" + +namespace paddle_mobile { +namespace operators { + +template +void SequenceSoftmaxOp::InferShape() const { + const auto *input_x = this->param_.InputX(); + const auto &x_lod = input_x->lod(); + + this->param_.Out()->Resize(input_x->dims()); + this->param_.Out()->set_lod(input_x->lod()); +} + +} // namespace operators +} // namespace paddle_mobile + +namespace ops = paddle_mobile::operators; +#ifdef PADDLE_MOBILE_CPU +REGISTER_OPERATOR_CPU(sequence_softmax, ops::SequenceSoftmaxOp); +#endif + +#endif // SEQUENCE_SOFTMAX_OP diff --git a/src/operators/sequence_ops/sequence_softmax_op.h b/src/operators/sequence_ops/sequence_softmax_op.h new file mode 100644 index 0000000000..92090ba802 --- /dev/null +++ b/src/operators/sequence_ops/sequence_softmax_op.h @@ -0,0 +1,47 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#ifdef SEQUENCE_SOFTMAX_OP + +#pragma once + +#include +#include "framework/operator.h" +#include "operators/kernel/sequence_kernels.h" +#include "operators/op_param.h" + +namespace paddle_mobile { +namespace operators { + +template +class SequenceSoftmaxOp : public framework::OperatorWithKernel< + DeviceType, SoftmaxParam, + operators::SequenceSoftmaxKernel> { + public: + SequenceSoftmaxOp(const std::string &type, const VariableNameMap &inputs, + const VariableNameMap &outputs, + const framework::AttributeMap &attrs, + std::shared_ptr scope) + : framework::OperatorWithKernel< + DeviceType, SoftmaxParam, + operators::SequenceSoftmaxKernel>( + type, inputs, outputs, attrs, scope) {} + // inference output shape + void InferShape() const override; +}; + +} // namespace operators +} // namespace paddle_mobile + +#endif // SEQUENCE_SOFTMAX_OP diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt index 5602d2d3ac..6cab082d98 100644 --- a/test/CMakeLists.txt +++ b/test/CMakeLists.txt @@ -385,4 +385,13 @@ if (NOT FOUND_MATCH) # gen test ADD_EXECUTABLE(test-ocr net/test_ocr.cpp test_helper.h test_include.h) target_link_libraries(test-ocr paddle-mobile) + + ADD_EXECUTABLE(test-sequence-expand operators/test_sequence_expand_op.cpp test_helper.h test_include.h) + target_link_libraries(test-sequence-expand paddle-mobile) + + ADD_EXECUTABLE(test-sequence-pool operators/test_sequence_pool_op.cpp test_helper.h test_include.h) + target_link_libraries(test-sequence-pool paddle-mobile) + + ADD_EXECUTABLE(test-sequence-softmax operators/test_sequence_softmax_op.cpp test_helper.h test_include.h) + target_link_libraries(test-sequence-softmax paddle-mobile) endif () diff --git a/test/operators/test_sequence_expand_op.cpp b/test/operators/test_sequence_expand_op.cpp new file mode 100644 index 0000000000..72e8954f93 --- /dev/null +++ b/test/operators/test_sequence_expand_op.cpp @@ -0,0 +1,97 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include +#include "../test_include.h" +#include "operators/sequence_ops/sequence_expand_op.h" + +namespace paddle_mobile { + +int TestSequenceExpandOp(const framework::LoDTensor &input_x, + const framework::LoDTensor &input_y, int ref_level, + framework::LoDTensor *output) { + VariableNameMap inputs; + VariableNameMap outputs; + auto scope = std::make_shared(); + inputs["X"] = std::vector({"input_x"}); + inputs["Y"] = std::vector({"input_y"}); + outputs["Out"] = std::vector({"output"}); + + auto input_x_var = scope.get()->Var("input_x"); + auto *x = input_x_var->template GetMutable(); + x->Resize(input_x.dims()); + x->ShareDataWith(input_x); + x->set_lod(input_x.lod()); + auto input_y_var = scope.get()->Var("input_y"); + auto *y = input_y_var->template GetMutable(); + y->Resize(framework::make_ddim({0})); + y->mutable_data(); + y->set_lod(input_y.lod()); + + auto output_var = scope.get()->Var("output"); + + framework::AttributeMap attrs; + attrs["ref_level"].Set(0); + + auto *op = new operators::SequenceExpandOp( + "sequence_expand", inputs, outputs, attrs, scope); + + op->InferShape(); + op->Init(); + op->Run(); + + auto *out = output_var->template Get(); + output->Resize(out->dims()); + output->ShareDataWith(*out); + output->set_lod(out->lod()); + delete op; + return 0; +} + +} // namespace paddle_mobile + +// namespace framework = paddle_mobile::framework; + +int main(int argc, char *argv[]) { + framework::LoDTensor input_x, input_y, output; + // case 1 + { + std::vector data{1, 2, 3, 4}; + input_x.Resize(framework::make_ddim({4, 1})); + float *in_data = input_x.mutable_data(); + for (int i = 0; i < 4; ++i) in_data[i] = data[i]; + input_x.set_lod({{0, 2, 4}}); + input_y.set_lod({{0, 2, 4}, {0, 3, 6, 7, 8}}); + + TestSequenceExpandOp(input_x, input_y, 0, &output); + std::vector expect_data{1, 2, 1, 2, 3, 4, 3, 4}; + std::vector expect_lod{0, 2, 4, 6, 8}; + for (int i = 0; i < 5; ++i) { + if (output.lod()[0][i] != expect_lod[i]) { + std::cerr << "output_lod[" << i << "]: " << output.lod()[0][i] + << " != expect_lod[" << i << "]: " << expect_lod[i] + << std::endl; + return 1; + } + } + for (int i = 0; i < 8; ++i) { + if (output.data()[i] != expect_data[i]) { + std::cerr << "output[" << i << "]: " << output.data()[i] + << " != expect[" << i << "]: " << expect_data[i] << std::endl; + return 1; + } + } + } + return 0; +} diff --git a/test/operators/test_sequence_pool_op.cpp b/test/operators/test_sequence_pool_op.cpp new file mode 100644 index 0000000000..a8518d630a --- /dev/null +++ b/test/operators/test_sequence_pool_op.cpp @@ -0,0 +1,293 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include +#include "../test_include.h" +#include "operators/sequence_ops/sequence_pool_op.h" + +namespace paddle_mobile { + +int TestSequencePoolOp(const framework::LoDTensor &input_x, + const std::string pool_type, + framework::LoDTensor *output) { + VariableNameMap inputs; + VariableNameMap outputs; + auto scope = std::make_shared(); + inputs["X"] = std::vector({"input_x"}); + outputs["Out"] = std::vector({"output"}); + + auto input_x_var = scope.get()->Var("input_x"); + auto *x = input_x_var->template GetMutable(); + x->Resize(input_x.dims()); + x->ShareDataWith(input_x); + x->set_lod(input_x.lod()); + + auto output_var = scope.get()->Var("output"); + + framework::AttributeMap attrs; + attrs["pooltype"].SetString(pool_type); + + auto *op = new operators::SequencePoolOp("sequence_pool", inputs, + outputs, attrs, scope); + + op->InferShape(); + op->Init(); + op->Run(); + + auto *out = output_var->template Get(); + output->Resize(out->dims()); + output->ShareDataWith(*out); + delete op; + return 0; +} + +} // namespace paddle_mobile + +// namespace framework = paddle_mobile::framework; + +int main(int argc, char *argv[]) { + framework::LoDTensor input_x, output; + // case 1 + std::cerr << "running max case 1" << std::endl; + { + std::vector data{1, 2, 3, 4}; + input_x.Resize(framework::make_ddim({4, 1})); + float *in_data = input_x.mutable_data(); + for (int i = 0; i < 4; ++i) in_data[i] = data[i]; + input_x.set_lod({{0, 2, 4}}); + + TestSequencePoolOp(input_x, "MAX", &output); + std::vector expect_data{2, 4}; + for (int i = 0; i < 2; ++i) { + if (output.data()[i] != expect_data[i]) { + std::cerr << "output[" << i << "]: " << output.data()[i] + << " != expect[" << i << "]: " << expect_data[i] << std::endl; + return 1; + } + } + } + // case 2 + std::cerr << "running max case 2" << std::endl; + { + std::vector data{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}; + input_x.Resize(framework::make_ddim({data.size(), 1})); + float *in_data = input_x.mutable_data(); + for (int i = 0; i < data.size(); ++i) in_data[i] = data[i]; + input_x.set_lod({{0, 3, 10}}); + + TestSequencePoolOp(input_x, "MAX", &output); + std::vector expect_data{3, 10}; + for (int i = 0; i < 2; ++i) { + if (output.data()[i] != expect_data[i]) { + std::cerr << "output[" << i << "]: " << output.data()[i] + << " != expect[" << i << "]: " << expect_data[i] << std::endl; + return 1; + } + } + } + std::cerr << "running max case 3" << std::endl; + // case 3 + { + std::vector data{1, 2, 3, 4, 5, 6, 7, 8}; + input_x.Resize(framework::make_ddim({4, 2})); + float *in_data = input_x.mutable_data(); + for (int i = 0; i < data.size(); ++i) in_data[i] = data[i]; + input_x.set_lod({{0, 2, 4}}); + + TestSequencePoolOp(input_x, "MAX", &output); + std::vector expect_data{3, 4, 7, 8}; + for (int i = 0; i < 4; ++i) { + if (output.data()[i] != expect_data[i]) { + std::cerr << "output[" << i << "]: " << output.data()[i] + << " != expect[" << i << "]: " << expect_data[i] << std::endl; + return 1; + } + } + } + // case 4 + std::cerr << "running max case 4" << std::endl; + { + std::vector data{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, + 11, 12, 13, 14, 15, 16, 17, 18, 19, 20}; + input_x.Resize(framework::make_ddim({4, 5})); + float *in_data = input_x.mutable_data(); + for (int i = 0; i < data.size(); ++i) in_data[i] = data[i]; + input_x.set_lod({{0, 2, 4}}); + + TestSequencePoolOp(input_x, "MAX", &output); + std::vector expect_data{6, 7, 8, 9, 10, 16, 17, 18, 19, 20}; + for (int i = 0; i < 10; ++i) { + if (output.data()[i] != expect_data[i]) { + std::cerr << "output[" << i << "]: " << output.data()[i] + << " != expect[" << i << "]: " << expect_data[i] << std::endl; + return 1; + } + } + } + // case 1 + std::cerr << "running sum case 1" << std::endl; + { + std::vector data{1, 2, 3, 4}; + input_x.Resize(framework::make_ddim({4, 1})); + float *in_data = input_x.mutable_data(); + for (int i = 0; i < 4; ++i) in_data[i] = data[i]; + input_x.set_lod({{0, 2, 4}}); + + TestSequencePoolOp(input_x, "SUM", &output); + std::vector expect_data{3, 7}; + for (int i = 0; i < 2; ++i) { + if (output.data()[i] != expect_data[i]) { + std::cerr << "output[" << i << "]: " << output.data()[i] + << " != expect[" << i << "]: " << expect_data[i] << std::endl; + return 1; + } + } + } + // case 2 + std::cerr << "running sum case 2" << std::endl; + { + std::vector data{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}; + input_x.Resize(framework::make_ddim({data.size(), 1})); + float *in_data = input_x.mutable_data(); + for (int i = 0; i < data.size(); ++i) in_data[i] = data[i]; + input_x.set_lod({{0, 3, 10}}); + + TestSequencePoolOp(input_x, "SUM", &output); + std::vector expect_data{6, 49}; + for (int i = 0; i < 2; ++i) { + if (output.data()[i] != expect_data[i]) { + std::cerr << "output[" << i << "]: " << output.data()[i] + << " != expect[" << i << "]: " << expect_data[i] << std::endl; + return 1; + } + } + } + // case 3 + std::cerr << "running sum case 3" << std::endl; + { + std::vector data{1, 2, 3, 4, 5, 6, 7, 8}; + input_x.Resize(framework::make_ddim({4, 2})); + float *in_data = input_x.mutable_data(); + for (int i = 0; i < data.size(); ++i) in_data[i] = data[i]; + input_x.set_lod({{0, 2, 4}}); + + TestSequencePoolOp(input_x, "SUM", &output); + std::vector expect_data{4, 6, 12, 14}; + for (int i = 0; i < 4; ++i) { + if (output.data()[i] != expect_data[i]) { + std::cerr << "output[" << i << "]: " << output.data()[i] + << " != expect[" << i << "]: " << expect_data[i] << std::endl; + return 1; + } + } + } + // case 4 + std::cerr << "running sum case 4" << std::endl; + { + std::vector data{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, + 11, 12, 13, 14, 15, 16, 17, 18, 19, 20}; + input_x.Resize(framework::make_ddim({4, 5})); + float *in_data = input_x.mutable_data(); + for (int i = 0; i < data.size(); ++i) in_data[i] = data[i]; + input_x.set_lod({{0, 2, 4}}); + + TestSequencePoolOp(input_x, "SUM", &output); + std::vector expect_data{7, 9, 11, 13, 15, 27, 29, 31, 33, 35}; + for (int i = 0; i < 10; ++i) { + if (output.data()[i] != expect_data[i]) { + std::cerr << "output[" << i << "]: " << output.data()[i] + << " != expect[" << i << "]: " << expect_data[i] << std::endl; + return 1; + } + } + } + // case 1 + std::cerr << "running first case 1" << std::endl; + { + std::vector data{1, 2, 3, 4}; + input_x.Resize(framework::make_ddim({4, 1})); + float *in_data = input_x.mutable_data(); + for (int i = 0; i < 4; ++i) in_data[i] = data[i]; + input_x.set_lod({{0, 2, 4}}); + + TestSequencePoolOp(input_x, "FIRST", &output); + std::vector expect_data{1, 3}; + for (int i = 0; i < 2; ++i) { + if (output.data()[i] != expect_data[i]) { + std::cerr << "output[" << i << "]: " << output.data()[i] + << " != expect[" << i << "]: " << expect_data[i] << std::endl; + return 1; + } + } + } + // case 2 + std::cerr << "running first case 2" << std::endl; + { + std::vector data{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}; + input_x.Resize(framework::make_ddim({data.size(), 1})); + float *in_data = input_x.mutable_data(); + for (int i = 0; i < data.size(); ++i) in_data[i] = data[i]; + input_x.set_lod({{0, 3, 10}}); + + TestSequencePoolOp(input_x, "FIRST", &output); + std::vector expect_data{1, 4}; + for (int i = 0; i < 2; ++i) { + if (output.data()[i] != expect_data[i]) { + std::cerr << "output[" << i << "]: " << output.data()[i] + << " != expect[" << i << "]: " << expect_data[i] << std::endl; + return 1; + } + } + } + // case 3 + std::cerr << "running first case 3" << std::endl; + { + std::vector data{1, 2, 3, 4, 5, 6, 7, 8}; + input_x.Resize(framework::make_ddim({4, 2})); + float *in_data = input_x.mutable_data(); + for (int i = 0; i < data.size(); ++i) in_data[i] = data[i]; + input_x.set_lod({{0, 2, 4}}); + + TestSequencePoolOp(input_x, "FIRST", &output); + std::vector expect_data{1, 2, 5, 6}; + for (int i = 0; i < 4; ++i) { + if (output.data()[i] != expect_data[i]) { + std::cerr << "output[" << i << "]: " << output.data()[i] + << " != expect[" << i << "]: " << expect_data[i] << std::endl; + return 1; + } + } + } + // case 4 + std::cerr << "running first case 4" << std::endl; + { + std::vector data{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, + 11, 12, 13, 14, 15, 16, 17, 18, 19, 20}; + input_x.Resize(framework::make_ddim({4, 5})); + float *in_data = input_x.mutable_data(); + for (int i = 0; i < data.size(); ++i) in_data[i] = data[i]; + input_x.set_lod({{0, 2, 4}}); + + TestSequencePoolOp(input_x, "FIRST", &output); + std::vector expect_data{1, 2, 3, 4, 5, 11, 12, 13, 14, 15}; + for (int i = 0; i < 10; ++i) { + if (output.data()[i] != expect_data[i]) { + std::cerr << "output[" << i << "]: " << output.data()[i] + << " != expect[" << i << "]: " << expect_data[i] << std::endl; + return 1; + } + } + } + return 0; +} diff --git a/test/operators/test_sequence_softmax_op.cpp b/test/operators/test_sequence_softmax_op.cpp new file mode 100644 index 0000000000..9e698d933a --- /dev/null +++ b/test/operators/test_sequence_softmax_op.cpp @@ -0,0 +1,100 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include +#include +#include "../test_include.h" +#include "operators/sequence_ops/sequence_softmax_op.h" + +namespace paddle_mobile { + +void SequenceSoftmax(const framework::LoDTensor *X, framework::LoDTensor *Y) { + const float *x = X->data(); + const auto &lod = X->lod().back(); + float *y = Y->mutable_data(); + for (int batch = 0; batch < lod.size() - 1; ++batch) { + int num_classes = lod[batch + 1] - lod[batch]; + size_t offset = lod[batch]; + const float *input = x + offset; + float *output = y + offset; + float max = -std::numeric_limits::max(); + for (int j = 0; j < num_classes; ++j) { + max = (input[j] > max) ? input[j] : max; + } + float sum = 0.f; + for (int j = 0; j < num_classes; ++j) { + float tmp = std::expf(input[j] - max); + sum += tmp; + output[j] = tmp; + } + for (int j = 0; j < num_classes; ++j) { + output[j] /= sum; + } + } + Y->set_lod(X->lod()); +} + +int TestSequenceSoftmaxOp(const std::vector &input_shape, + const std::vector &input_lod) { + framework::DDim dims = framework::make_ddim(input_shape); + VariableNameMap inputs; + VariableNameMap outputs; + auto scope = std::make_shared(); + inputs["X"] = std::vector({"input"}); + outputs["Out"] = std::vector({"output"}); + + auto input_var = scope.get()->Var("input"); + auto input = input_var->template GetMutable(); + SetupTensor(input, dims, -100.0, 100.0); + input->set_lod({input_lod}); + + auto output_var = scope.get()->Var("output"); + + framework::AttributeMap attrs; + auto *op = new operators::SequenceSoftmaxOp( + "sequence_softmax", inputs, outputs, attrs, scope); + + op->InferShape(); + op->Init(); + op->Run(); + + auto output = output_var->template Get(); + + framework::LoDTensor output_cmp; + float *output_cmp_data = output_cmp.mutable_data(output->dims()); + SequenceSoftmax(input, &output_cmp); + + const float *output_data = output->data(); + for (int i = 0; i < output->numel(); ++i) { + float gap = output_data[i] - output_cmp_data[i]; + if (std::abs(gap / (output_data[i] + 1e-5)) > 1e-3) { + LOG(kLOG_INFO) << "output_data[" << i << "] = " << output_data[i] + << ", output_cmp_data[" << i + << "] = " << output_cmp_data[i]; + delete op; + exit(1); + } + } + delete op; + return 0; +} + +} // namespace paddle_mobile + +int main(int argc, char *argv[]) { + TestSequenceSoftmaxOp({2, 1}, {0, 2}); + TestSequenceSoftmaxOp({100, 1}, {0, 3, 100}); + TestSequenceSoftmaxOp({100, 1}, {0, 50, 100}); + return 0; +} diff --git a/test/operators/test_softmax_op.cpp b/test/operators/test_softmax_op.cpp index d65cf4fea2..e94933eaa9 100644 --- a/test/operators/test_softmax_op.cpp +++ b/test/operators/test_softmax_op.cpp @@ -62,7 +62,6 @@ int TestSoftmaxOp(const std::vector input_shape) { SetupTensor(input, dims, -100.0, 100.0); auto output_var = scope.get()->Var("output"); - auto output = output_var->template Get(); framework::AttributeMap attrs; auto *op = new operators::SoftmaxOp("softmax", inputs, outputs, @@ -71,6 +70,8 @@ int TestSoftmaxOp(const std::vector input_shape) { op->Init(); op->Run(); + auto output = output_var->template Get(); + framework::Tensor output_cmp; float *output_cmp_data = output_cmp.mutable_data(output->dims()); Softmax(input, &output_cmp); diff --git a/tools/op.cmake b/tools/op.cmake index 60fdbacc2c..3563834e77 100644 --- a/tools/op.cmake +++ b/tools/op.cmake @@ -274,6 +274,7 @@ if(NOT FOUND_MATCH) set(FUSION_DEQUANT_ADD_BN_RELU_QUANT_OP ON) set(SEQUENCE_EXPAND_OP ON) set(SEQUENCE_POOL_OP ON) + set(SEQUENCE_SOFTMAX_OP ON) endif() # option(BATCHNORM_OP "" ON) @@ -504,6 +505,9 @@ endif() if (SEQUENCE_POOL_OP) add_definitions(-DSEQUENCE_POOL_OP) endif() +if (SEQUENCE_SOFTMAX_OP) + add_definitions(-DSEQUENCE_SOFTMAX_OP) +endif() if (TANH_OP) add_definitions(-DTANH_OP) -- GitLab