diff --git a/lite/backends/x86/math/CMakeLists.txt b/lite/backends/x86/math/CMakeLists.txt index 2dea4364d5ee2d11d6d266935fad2a1180954369..a89107632341cf063ac3166aa9890ff383e3383f 100644 --- a/lite/backends/x86/math/CMakeLists.txt +++ b/lite/backends/x86/math/CMakeLists.txt @@ -50,7 +50,8 @@ math_library(unpooling) math_library(vol2col) ## math_library(prelu) math_library(tree2col DEPS math_function) - +math_library(sequence_topk_avg_pooling) +math_library(search_fc DEPS blas dynload_mklml) # cc_test(math_function_test SRCS math_function_test.cc DEPS math_function) # cc_test(selected_rows_functor_test SRCS selected_rows_functor_test.cc DEPS selected_rows_functor) # cc_test(im2col_test SRCS im2col_test.cc DEPS im2col) diff --git a/lite/backends/x86/math/search_fc.cc b/lite/backends/x86/math/search_fc.cc new file mode 100644 index 0000000000000000000000000000000000000000..56fc363cb48ec5c58f4a7ee3e62a2e6bd7355021 --- /dev/null +++ b/lite/backends/x86/math/search_fc.cc @@ -0,0 +1,79 @@ +/* Copyright (c) 2018 paddlepaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "lite/backends/x86/math/search_fc.h" +#include +#include + +namespace paddle { +namespace lite { +namespace x86 { +namespace math { + +/* + * All tensors' dimension should be the same and the values of + * each dimension must be the same, except the axis dimension. + */ +template +class SearchFcFunctor { + public: + void operator()(const lite::X86Context& context, + const lite::Tensor& bottom, + const lite::Tensor& w, + const lite::Tensor& b, + lite::Tensor* top, + int out_size) { + int batch = bottom.dims()[0]; + + int _out = w.dims()[0]; // 100 + int _in = w.dims()[1]; // 228 + + lite::DDim dims(std::vector({bottom.dims()[0], out_size})); + + const auto bottom_data = bottom.data(); + auto top_data = top->mutable_data(lite::TargetType::kX86); + const auto weights = w.data(); + auto blas = math::GetBlas(context); + call_gemm(blas, + CblasNoTrans, + CblasTrans, + batch, + _out, + _in, + 1.0f, + bottom_data, + weights, + 0.0f, + top_data); + if (true) { + const auto* bias_data = b.data(); + for (int i = 0; i < batch; ++i) { + // add bias here + sse_eltadd(top_data + i * _out, bias_data, top_data + i * _out, _out); + } + } + } + + // private: +}; + +#define DEFINE_FUNCTOR(type) \ + template class SearchFcFunctor; + +FOR_ALL_TYPES(DEFINE_FUNCTOR); + +} // namespace math +} // namespace x86 +} // namespace lite +} // namespace paddle diff --git a/lite/backends/x86/math/search_fc.h b/lite/backends/x86/math/search_fc.h new file mode 100644 index 0000000000000000000000000000000000000000..e415c396023dbc10358992012197f4cfebac554f --- /dev/null +++ b/lite/backends/x86/math/search_fc.h @@ -0,0 +1,184 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once +#include +#include "lite/backends/x86/math/blas.h" +#include "lite/backends/x86/mklml.h" +#include "lite/core/context.h" +#include "lite/core/tensor.h" +#include "lite/fluid/data_type.h" + +namespace paddle { +namespace lite { +namespace x86 { +namespace math { + +template +void call_gemm(const BlasT blas, + const CBLAS_TRANSPOSE TransA, + const CBLAS_TRANSPOSE TransB, + const int M, + const int N, + const int K, + const T alpha, + const T* A, + const T* B, + const T beta, + T* C) { +#ifndef __NAIVE_GEMM__ + int lda = (TransA == CblasNoTrans) ? K : M; + int ldb = (TransB == CblasNoTrans) ? N : K; + blas.GEMM(TransA, TransB, M, N, K, alpha, A, lda, B, ldb, beta, C, N); +#else + naive::gemm((TransA == CblasTrans), + (TransB == CblasTrans), + M, + N, + K, + alpha, + A, + B, + beta, + C); +#endif // !__NAIVE_GEMM__ +} + +// To align with Lego +#ifndef LEGO_USE_FLOAT +#define LEGO_USE_FLOAT +#endif +#ifndef LEGO_SSE +#define LEGO_SSE +#endif + +#if defined(LEGO_USE_FLOAT) + +#define __m256x __m256 +#define __m128x __m128 + +static const unsigned int AVX_STEP_SIZE = 8; +static const unsigned int SSE_STEP_SIZE = 4; +static const unsigned int AVX_CUT_LEN_MASK = 7U; +static const unsigned int SSE_CUT_LEN_MASK = 3U; + +#define _mm256_setzero_px _mm256_setzero_ps +#define _mm256_mul_px _mm256_mul_ps +#define _mm256_add_px _mm256_add_ps +#define _mm256_load_px _mm256_loadu_ps +#define _mm256_hadd_px _mm256_hadd_ps +#define _mm256_permute2f128_px _mm256_permute2f128_ps +#define _mm256_store_px _mm256_storeu_ps +#define _mm256_broadcast_sx _mm256_broadcast_ss +#define _mm256_castpx256_px128 _mm256_castps256_ps128 +#define _mm256_max_px _mm256_max_ps +#define _mm256_sub_px _mm256_sub_ps +#define _mm256_set1_px _mm256_set1_ps +#define _mm256_sqrt_px _mm256_sqrt_ps +#define _mm256_div_px _mm256_div_ps +#define _mm_setzero_px _mm_setzero_ps +#define _mm_add_px _mm_add_ps +#define _mm_mul_px _mm_mul_ps +#define _mm_load_px _mm_loadu_ps +#define _mm_hadd_px _mm_hadd_ps +#define _mm_store_sx _mm_store_ss +#define _mm_store_px _mm_storeu_ps +#define _mm_load1_px _mm_load1_ps +#define _mm_max_px _mm_max_ps +#define _mm_sub_px _mm_sub_ps +#define _mm_set1_px _mm_set1_ps +#define _mm_sqrt_px _mm_sqrt_ps +#define _mm_div_px _mm_div_ps + +#elif defined(LEGO_USE_DOUBLE) + +#define __m256x __m256d +#define __m128x __m128d + +static const unsigned int AVX_STEP_SIZE = 4; +static const unsigned int SSE_STEP_SIZE = 2; +static const unsigned int AVX_CUT_LEN_MASK = 3U; +static const unsigned int SSE_CUT_LEN_MASK = 1U; + +#define _mm256_setzero_px _mm256_setzero_pd +#define _mm256_mul_px _mm256_mul_pd +#define _mm256_add_px _mm256_add_pd +#define _mm256_load_px _mm256_loadu_pd +#define _mm256_hadd_px _mm256_hadd_pd +#define _mm256_permute2f128_px _mm256_permute2f128_pd +#define _mm256_store_px _mm256_storeu_pd +#define _mm256_broadcast_sx _mm256_broadcast_sd +#define _mm256_castpx256_px128 _mm256_castpd256_pd128 +#define _mm256_max_px _mm256_max_pd +#define _mm256_sub_px _mm256_sub_pd +#define _mm256_set1_px _mm256_set1_pd +#define _mm256_sqrt_px _mm256_sqrt_pd +#define _mm256_div_px _mm256_div_pd +#define _mm_setzero_px _mm_setzero_pd +#define _mm_add_px _mm_add_pd +#define _mm_mul_px _mm_mul_pd +#define _mm_load_px _mm_loadu_pd +#define _mm_hadd_px _mm_hadd_pd +#define _mm_store_sx _mm_store_sd +#define _mm_store_px _mm_storeu_pd +#define _mm_load1_px _mm_load1_pd +#define _mm_max_px _mm_max_pd +#define _mm_sub_px _mm_sub_pd +#define _mm_set1_px _mm_set1_pd +#define _mm_sqrt_px _mm_sqrt_pd +#define _mm_div_px _mm_div_pd +#endif + +template +inline void sse_eltadd(const T* x, const T* y, T* z, size_t len) { + unsigned int jjj, lll; + jjj = lll = 0; + +#if defined(LEGO_AVX) + lll = len & ~AVX_CUT_LEN_MASK; + for (jjj = 0; jjj < lll; jjj += AVX_STEP_SIZE) { + _mm256_store_px( + z + jjj, + _mm256_add_px(_mm256_load_px(x + jjj), _mm256_load_px(y + jjj))); + } +#elif defined(LEGO_SSE) + lll = len & ~SSE_CUT_LEN_MASK; + + for (jjj = 0; jjj < lll; jjj += SSE_STEP_SIZE) { + _mm_store_px(z + jjj, + _mm_add_px(_mm_load_px(x + jjj), _mm_load_px(y + jjj))); + } +#endif + for (; jjj < len; jjj++) { + z[jjj] = x[jjj] + y[jjj]; + } +} + +template +class SearchFcFunctor { + public: + void operator()(const lite::Context& context, + const lite::Tensor& X, + const lite::Tensor& W, + const lite::Tensor& b, + lite::Tensor* Out, + int out_size); +}; + +} // namespace math +} // namespace x86 +} // namespace lite +} // namespace paddle + +#define FOR_ALL_TYPES(macro) macro(float); diff --git a/lite/backends/x86/math/sequence_topk_avg_pooling.cc b/lite/backends/x86/math/sequence_topk_avg_pooling.cc new file mode 100644 index 0000000000000000000000000000000000000000..035a7923c70f91cf27f1d845f68110f8f33cb73d --- /dev/null +++ b/lite/backends/x86/math/sequence_topk_avg_pooling.cc @@ -0,0 +1,151 @@ +/* Copyright (c) 2018 paddlepaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "lite/backends/x86/math/sequence_topk_avg_pooling.h" +#include +#include + +namespace paddle { +namespace lite { +namespace x86 { +namespace math { + +template +void get_topk_pos(const T* data, int length, int k, int* pos, bool debug) { + size_t real_k = k < length ? k : length; + + std::vector v(data, data + length); + + std::vector topk_pos; + T min_val = -10000000.0; + while (topk_pos.size() < real_k) { + T max_val = min_val; + int max_pos = -1; + for (int i = 0; i < length; ++i) { + if (v[i] > max_val) { + max_pos = i; + max_val = v[i]; + } + } + + assert(max_pos >= 0); + + topk_pos.push_back(max_pos); + v[max_pos] = min_val; + } + + assert(topk_pos.size() > 0); + while (topk_pos.size() < (size_t)k) { + topk_pos.push_back(-1); + } + + for (size_t i = 0; i < topk_pos.size(); ++i) { + pos[i] = topk_pos[i]; + } +} + +/* + * All tensors' dimension should be the same and the values of + * each dimension must be the same, except the axis dimension. + */ +template +class SequenceTopkAvgPoolingFunctor { + public: + void operator()(const lite::Tensor& in, + const lite::Tensor& row, + const lite::Tensor& col, + lite::Tensor* out, + lite::Tensor* pos, + int channel_num, + std::vector topks) { + auto k_num = topks.size(); + auto max_k = topks[topks.size() - 1]; + std::vector vec_pos_shape; + auto in_lod = in.lod()[0]; + auto row_lod = row.lod()[0]; + auto col_lod = col.lod()[0]; + int batch_size = row_lod.size() - 1; + int pos_total_size = row_lod[batch_size] * channel_num * max_k; + vec_pos_shape.push_back(pos_total_size); + lite::DDim dims(vec_pos_shape); + pos->Resize(dims); + auto pos_data = pos->mutable_data(lite::TargetType::kX86); + + int offset = 0; + std::vector vec_out_lod; + vec_out_lod.reserve(batch_size + 1); + for (int i = 0; i <= batch_size; ++i) { + offset = row_lod[i]; + vec_out_lod.push_back(offset); + } + + lite::LoD lod_temp; + lod_temp.push_back(vec_out_lod); + out->set_lod(lod_temp); + + auto in_data = in.data(); + auto out_data = out->mutable_data(lite::TargetType::kX86); + + T* sum_data = new T[max_k]; + for (int i = 0; i < batch_size; ++i) { + int total_size = in_lod[i + 1] - in_lod[i]; + int row_size = row_lod[i + 1] - row_lod[i]; + int col_size = col_lod[i + 1] - col_lod[i]; + + CHECK_EQ(total_size, channel_num * row_size * col_size) + << "size wrong in sequence_topk_avg_pooling_op!"; + + int feature_num = row_size * col_size; + for (int j = 0; j < channel_num; ++j) { + auto input_offset_feature_data = in_data + in_lod[i] + j * feature_num; + + for (int r = 0; r < row_size; ++r) { + auto row_data = input_offset_feature_data + r * col_size; + auto pos_slice_data = pos_data + row_lod[i] * channel_num * max_k + + r * channel_num * max_k + j * max_k; + auto out_slice_data = out_data + row_lod[i] * channel_num * k_num + + r * channel_num * k_num + j * k_num; + + get_topk_pos(row_data, col_size, max_k, pos_slice_data); + if (pos_slice_data[0] == -1) { + sum_data[0] = 0.0; + } else { + sum_data[0] = row_data[pos_slice_data[0]]; + } + for (int k = 1; k < max_k; ++k) { + if (pos_slice_data[k] == -1) { + sum_data[k] = sum_data[k - 1]; + } else { + sum_data[k] = sum_data[k - 1] + row_data[pos_slice_data[k]]; + } + } + for (size_t k = 0; k < k_num; ++k) { + out_slice_data[k] = sum_data[topks[k] - 1] / topks[k]; + } + } + } + } + delete[] sum_data; + } +}; + +#define DEFINE_FUNCTOR(type) \ + template class SequenceTopkAvgPoolingFunctor; + +FOR_ALL_TYPES(DEFINE_FUNCTOR); + +} // namespace math +} // namespace x86 +} // namespace lite +} // namespace paddle diff --git a/lite/backends/x86/math/sequence_topk_avg_pooling.h b/lite/backends/x86/math/sequence_topk_avg_pooling.h new file mode 100644 index 0000000000000000000000000000000000000000..78d458c4d8fe0bf5a117cb5ad23d44bf0b7f3471 --- /dev/null +++ b/lite/backends/x86/math/sequence_topk_avg_pooling.h @@ -0,0 +1,46 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once +#include +#include "lite/core/context.h" +#include "lite/core/tensor.h" +#include "lite/fluid/data_type.h" + +namespace paddle { +namespace lite { +namespace x86 { +namespace math { +template +void get_topk_pos( + const T* data, int length, int k, int* pos, bool debug = false); + +template +class SequenceTopkAvgPoolingFunctor { + public: + void operator()(const lite::Tensor& X, + const lite::Tensor& ROW, + const lite::Tensor& COLUMN, + lite::Tensor* Out, + lite::Tensor* pos, + int channel_num, + std::vector topks); +}; + +} // namespace math +} // namespace x86 +} // namespace lite +} // namespace paddle + +#define FOR_ALL_TYPES(macro) macro(float); diff --git a/lite/kernels/x86/CMakeLists.txt b/lite/kernels/x86/CMakeLists.txt index c21081d4b75b110e45e925691c7703ef085088c0..6976a64c171c35b6c443b7d4382bba810ad5109e 100644 --- a/lite/kernels/x86/CMakeLists.txt +++ b/lite/kernels/x86/CMakeLists.txt @@ -53,6 +53,8 @@ add_kernel(sequence_arithmetic_compute_x86 X86 basic SRCS sequence_arithmetic_co # for content-dnn specific add_kernel(search_aligned_mat_mul_compute_x86 X86 extra SRCS search_aligned_mat_mul_compute.cc DEPS ${lite_kernel_deps} blas) add_kernel(search_seq_fc_compute_x86 X86 extra SRCS search_seq_fc_compute.cc DEPS ${lite_kernel_deps} blas) +add_kernel(sequence_topk_avg_pooling_compute_x86 X86 basic SRCS sequence_topk_avg_pooling_compute.cc DEPS ${lite_kernel_deps} sequence_topk_avg_pooling) +add_kernel(search_fc_compute_x86 X86 basic SRCS search_fc_compute.cc DEPS ${lite_kernel_deps} search_fc) if(NOT LITE_WITH_X86) return() @@ -83,6 +85,7 @@ lite_cc_test(test_cast_compute_x86 SRCS cast_compute_test.cc DEPS cast_compute_x lite_cc_test(test_pool2d_compute_x86 SRCS pool_compute_test.cc DEPS pool_compute_x86) lite_cc_test(test_dropout_compute_x86 SRCS dropout_compute_test.cc DEPS dropout_compute_x86) lite_cc_test(test_transpose_compute_x86 SRCS transpose_compute_test.cc DEPS transpose_compute_x86) +lite_cc_test(test_search_fc_compute_x86 SRCS search_fc_compute_test.cc DEPS search_fc_compute_x86) lite_cc_test(test_search_seq_depadding_compute_x86 SRCS search_seq_depadding_compute_test.cc DEPS search_seq_depadding_compute_x86) if(LITE_BUILD_EXTRA) diff --git a/lite/kernels/x86/search_fc_compute.cc b/lite/kernels/x86/search_fc_compute.cc new file mode 100644 index 0000000000000000000000000000000000000000..cf76113e01d81e899250a60203680cd984746f19 --- /dev/null +++ b/lite/kernels/x86/search_fc_compute.cc @@ -0,0 +1,27 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "lite/kernels/x86/search_fc_compute.h" + +REGISTER_LITE_KERNEL(search_fc, + kX86, + kFloat, + kNCHW, + paddle::lite::kernels::x86::SearchFcCompute, + def) + .BindInput("X", {LiteType::GetTensorTy(TARGET(kX86))}) + .BindInput("W", {LiteType::GetTensorTy(TARGET(kX86))}) + .BindInput("b", {LiteType::GetTensorTy(TARGET(kX86))}) + .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kX86))}) + .Finalize(); diff --git a/lite/kernels/x86/search_fc_compute.h b/lite/kernels/x86/search_fc_compute.h new file mode 100644 index 0000000000000000000000000000000000000000..0e61924151dd9a67ea23dbbd9d35187b458ec638 --- /dev/null +++ b/lite/kernels/x86/search_fc_compute.h @@ -0,0 +1,43 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +#pragma once + +#include "lite/backends/x86/math/search_fc.h" +#include "lite/core/kernel.h" +#include "lite/core/op_registry.h" +#include "lite/core/types.h" + +namespace paddle { +namespace lite { +namespace kernels { +namespace x86 { + +template +class SearchFcCompute : public KernelLite { + public: + using param_t = operators::SearchFcParam; + void Run() override { + auto& context = ctx_->As(); + auto& param = *param_.get_mutable(); + + lite::x86::math::SearchFcFunctor search_fc; + search_fc(context, *param.X, *param.W, *param.b, param.Out, param.out_size); + } + virtual ~SearchFcCompute() = default; +}; + +} // namespace x86 +} // namespace kernels +} // namespace lite +} // namespace paddle diff --git a/lite/kernels/x86/search_fc_compute_test.cc b/lite/kernels/x86/search_fc_compute_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..425df2a0f0544d7345923cb2efdce96074845311 --- /dev/null +++ b/lite/kernels/x86/search_fc_compute_test.cc @@ -0,0 +1,122 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "lite/kernels/x86/search_fc_compute.h" +#include +#include +#include +#include +#include +#include "lite/core/op_registry.h" + +namespace paddle { +namespace lite { +namespace kernels { +namespace x86 { + +void fc_cpu_base(const lite::Tensor* X, + const lite::Tensor* W, + const lite::Tensor* b, + int out_size, + lite::Tensor* Out) { + const float* data_in = X->data(); + const float* bias = b->data(); + const float* weights = W->data(); + float* data_out = Out->mutable_data(); + int out_rows = X->dims()[0]; + int in_cols = X->numel() / out_rows; + int out_cols = W->numel() / in_cols; + int index_out; + + for (int i = 0; i < out_rows; i++) { + for (int j = 0; j < out_cols; j++) { + index_out = i * out_cols + j; + data_out[index_out] = bias ? bias[j] : 0; + + for (int k = 0; k < in_cols; k++) { + data_out[index_out] += + data_in[i * in_cols + k] * weights[j * in_cols + k]; + } + } + } +} + +TEST(search_fc_x86, retrive_op) { + auto search_fc = + KernelRegistry::Global().Create( + "search_fc"); + ASSERT_FALSE(search_fc.empty()); + ASSERT_TRUE(search_fc.front()); +} + +TEST(search_fc_x86, init) { + SearchFcCompute search_fc; + ASSERT_EQ(search_fc.precision(), PRECISION(kFloat)); + ASSERT_EQ(search_fc.target(), TARGET(kX86)); +} + +TEST(search_fc_x86, run_test) { + lite::Tensor x, w, b, out; + lite::Tensor out_ref; + std::unique_ptr ctx(new KernelContext); + ctx->As(); + std::vector x_shape{1, 4}; + x.Resize(lite::DDim(x_shape)); + std::vector w_shape{3, 4}; + w.Resize(lite::DDim(w_shape)); + std::vector b_shape{3}; + b.Resize(lite::DDim(b_shape)); + std::vector out_shape{1, 4}; + out.Resize(lite::DDim(out_shape)); + out_ref.Resize(lite::DDim(out_shape)); + auto x_data = x.mutable_data(); + auto w_data = w.mutable_data(); + auto b_data = b.mutable_data(); + auto out_data = out.mutable_data(); + auto out_data_ref = out_ref.mutable_data(); + for (int64_t i = 0; i < x.dims().production(); i++) { + x_data[i] = static_cast(i); + } + for (int64_t i = 0; i < w.dims().production(); i++) { + w_data[i] = static_cast(i); + } + for (int64_t i = 0; i < b.dims().production(); i++) { + b_data[i] = static_cast(i); + } + + fc_cpu_base(&x, &w, &b, 4, &out_ref); + + SearchFcCompute fc; + operators::SearchFcParam param; + param.X = &x; + param.W = &w; + param.b = &b; + param.Out = &out; + param.out_size = 4; + fc.SetParam(param); + fc.SetContext(std::move(ctx)); + fc.Run(); + + VLOG(3) << "output vs ref"; + for (int i = 0; i < out.dims().production(); i++) { + EXPECT_NEAR(out_data[i], out_data_ref[i], 1e-5); + } +} + +} // namespace x86 +} // namespace kernels +} // namespace lite +} // namespace paddle + +USE_LITE_KERNEL(search_fc, kX86, kFloat, kNCHW, def); diff --git a/lite/kernels/x86/sequence_topk_avg_pooling_compute.cc b/lite/kernels/x86/sequence_topk_avg_pooling_compute.cc new file mode 100644 index 0000000000000000000000000000000000000000..9bd8b287507426798e0ec24f8854e812016b0054 --- /dev/null +++ b/lite/kernels/x86/sequence_topk_avg_pooling_compute.cc @@ -0,0 +1,29 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "lite/kernels/x86/sequence_topk_avg_pooling_compute.h" + +REGISTER_LITE_KERNEL( + sequence_topk_avg_pooling, + kX86, + kFloat, + kNCHW, + paddle::lite::kernels::x86::SequenceTopkAvgPoolingCompute, + def) + .BindInput("X", {LiteType::GetTensorTy(TARGET(kX86))}) + .BindInput("ROW", {LiteType::GetTensorTy(TARGET(kX86))}) + .BindInput("COLUMN", {LiteType::GetTensorTy(TARGET(kX86))}) + .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kX86))}) + .BindOutput("pos", {LiteType::GetTensorTy(TARGET(kX86))}) + .Finalize(); diff --git a/lite/kernels/x86/sequence_topk_avg_pooling_compute.h b/lite/kernels/x86/sequence_topk_avg_pooling_compute.h new file mode 100644 index 0000000000000000000000000000000000000000..724415288a72932392d5726778830095c8810e15 --- /dev/null +++ b/lite/kernels/x86/sequence_topk_avg_pooling_compute.h @@ -0,0 +1,50 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +#pragma once + +#include "lite/backends/x86/math/sequence_topk_avg_pooling.h" +#include "lite/core/kernel.h" +#include "lite/core/op_registry.h" +#include "lite/core/types.h" + +namespace paddle { +namespace lite { +namespace kernels { +namespace x86 { + +template +class SequenceTopkAvgPoolingCompute + : public KernelLite { + public: + using param_t = operators::SequenceTopkAvgPoolingParam; + + void Run() override { + auto& param = *param_.get_mutable(); + lite::x86::math::SequenceTopkAvgPoolingFunctor + sequence_topk_avg_pooling; + sequence_topk_avg_pooling(*param.X, + *param.ROW, + *param.COLUMN, + param.Out, + param.pos, + param.channel_num, + param.topks); + }; + virtual ~SequenceTopkAvgPoolingCompute() = default; +}; + +} // namespace x86 +} // namespace kernels +} // namespace lite +} // namespace paddle diff --git a/lite/operators/CMakeLists.txt b/lite/operators/CMakeLists.txt index 1fe7ec726611b8bbbec55c8d87e0af2ff6e8285c..4f4b13f9318060660e14e587559db5a679e93cb9 100644 --- a/lite/operators/CMakeLists.txt +++ b/lite/operators/CMakeLists.txt @@ -116,10 +116,11 @@ add_operator(topk_op extra SRCS topk_op.cc DEPS ${op_DEPS}) add_operator(increment_op extra SRCS increment_op.cc DEPS ${op_DEPS}) add_operator(layer_norm_op extra SRCS layer_norm_op.cc DEPS ${op_DEPS}) add_operator(sequence_softmax_op extra SRCS sequence_softmax_op.cc DEPS ${op_DEPS}) - # for content-dnn specific add_operator(search_aligned_mat_mul_op extra SRCS search_aligned_mat_mul_op.cc DEPS ${op_DEPS}) add_operator(search_seq_fc_op extra SRCS search_seq_fc_op.cc DEPS ${op_DEPS}) +add_operator(sequence_topk_avg_pooling_op basic SRCS sequence_topk_avg_pooling_op.cc DEPS ${op_DEPS}) +add_operator(search_fc_op basic SRCS search_fc_op.cc DEPS ${op_DEPS}) if (NOT LITE_WITH_X86) lite_cc_test(test_fc_op SRCS fc_op_test.cc diff --git a/lite/operators/op_params.h b/lite/operators/op_params.h index 73e65e42033c9ed35bd8d9969250ef48ed6b36be..e29bc5921697e9af7c9c495d471231c4e9aee0c6 100644 --- a/lite/operators/op_params.h +++ b/lite/operators/op_params.h @@ -983,6 +983,25 @@ struct AssignValueParam { lite::Tensor* Out{}; }; +/// --------------- sequence_topk_avg_pooling operators ------------------ +struct SequenceTopkAvgPoolingParam { + const lite::Tensor* X{}; + const lite::Tensor* ROW{}; + const lite::Tensor* COLUMN{}; + lite::Tensor* Out{}; + lite::Tensor* pos{}; + int channel_num{}; + std::vector topks{}; +}; + +/// --------------- search_fc operators ------------------ +struct SearchFcParam { + const lite::Tensor* X{}; + const lite::Tensor* W{}; + const lite::Tensor* b{}; + lite::Tensor* Out{}; + int out_size{}; +}; /// --------------------- match_matrix_tensor operators -------------------- struct MatchMatrixTensorParam { const lite::Tensor* x{}; diff --git a/lite/operators/search_fc_op.cc b/lite/operators/search_fc_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..50d09f602b1e42366ad598c3805c9d5726d2ab78 --- /dev/null +++ b/lite/operators/search_fc_op.cc @@ -0,0 +1,80 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "lite/operators/search_fc_op.h" +#include "lite/core/op_lite.h" +#include "lite/core/op_registry.h" + +namespace paddle { +namespace lite { +namespace operators { + +bool SearchFcOpLite::CheckShape() const { + CHECK_OR_FALSE(param_.X); + CHECK_OR_FALSE(param_.W); + CHECK_OR_FALSE(param_.b); + CHECK_OR_FALSE(param_.Out); + + auto x_dims = param_.X->dims(); + CHECK_EQ(x_dims.size(), 2) << "The rank of X(Input) should be 2."; + auto w_dims = param_.W->dims(); + CHECK_EQ(w_dims.size(), 2) << "W should be 2-D tensor."; + auto b_dims = param_.b->dims(); + CHECK_EQ(b_dims.size(), 1) << "b should be 1-D tensor."; + CHECK_EQ(w_dims[1], x_dims[1]) << "wrong shape: w_dims[1] != x_dims[1]"; + return true; +} + +bool SearchFcOpLite::InferShape() const { + auto out_size = param_.out_size; + lite::DDim dims(std::vector({-1, out_size})); + param_.Out->Resize(dims); + return true; +} + +bool SearchFcOpLite::AttachImpl(const cpp::OpDesc &op_desc, + lite::Scope *scope) { + auto X = op_desc.Input("X").front(); + auto W = op_desc.Input("W").front(); + auto b = op_desc.Input("b").front(); + auto Out = op_desc.Output("Out").front(); + + param_.X = scope->FindVar(X)->GetMutable(); + param_.W = scope->FindVar(W)->GetMutable(); + param_.b = scope->FindVar(b)->GetMutable(); + param_.Out = scope->FindVar(Out)->GetMutable(); + param_.out_size = op_desc.GetAttr("out_size"); + + return true; +} + +} // namespace operators +} // namespace lite +} // namespace paddle + +REGISTER_LITE_OP(SearchFc, paddle::lite::operators::SearchFcOpLite); diff --git a/lite/operators/search_fc_op.h b/lite/operators/search_fc_op.h new file mode 100644 index 0000000000000000000000000000000000000000..a871cadd33b4f7d4b6130a0b8ac2974a738ac0c3 --- /dev/null +++ b/lite/operators/search_fc_op.h @@ -0,0 +1,46 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once +#include +#include +#include "lite/core/op_lite.h" +#include "lite/core/scope.h" +#include "lite/utils/all.h" + +namespace paddle { +namespace lite { +namespace operators { + +class SearchFcOpLite : public OpLite { + public: + SearchFcOpLite() {} + explicit SearchFcOpLite(const std::string &op_type) : OpLite(op_type) {} + + bool CheckShape() const override; + + bool InferShape() const override; + + bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override; + + void AttachKernel(KernelBase *kernel) override { kernel->SetParam(param_); } + std::string DebugString() const override { return "search_fc"; } + + private: + mutable SearchFcParam param_; +}; + +} // namespace operators +} // namespace lite +} // namespace paddle diff --git a/lite/operators/sequence_topk_avg_pooling_op.cc b/lite/operators/sequence_topk_avg_pooling_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..384d13711285566bf99fcc43b81e5e81d86dc35e --- /dev/null +++ b/lite/operators/sequence_topk_avg_pooling_op.cc @@ -0,0 +1,86 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "lite/operators/sequence_topk_avg_pooling_op.h" +#include "lite/core/op_lite.h" +#include "lite/core/op_registry.h" + +namespace paddle { +namespace lite { +namespace operators { + +bool SequenceTopkAvgPoolingOpLite::CheckShape() const { + CHECK_OR_FALSE(param_.X); + CHECK_OR_FALSE(param_.ROW); + CHECK_OR_FALSE(param_.COLUMN); + CHECK_OR_FALSE(param_.Out); + CHECK_OR_FALSE(param_.pos); + return true; +} + +bool SequenceTopkAvgPoolingOpLite::InferShape() const { + int channel_num = param_.channel_num; + std::vector topks = param_.topks; + auto row_dim = param_.ROW->dims(); + auto num_k = topks.size(); + auto row_shape_0 = row_dim[0]; + std::vector vec_out_shape; + vec_out_shape.push_back(row_shape_0); + vec_out_shape.push_back(channel_num * num_k); + + param_.Out->Resize(lite::DDim(vec_out_shape)); + auto out_lod = param_.Out->mutable_lod(); + *out_lod = param_.X->lod(); + return true; +} + +bool SequenceTopkAvgPoolingOpLite::AttachImpl(const cpp::OpDesc &op_desc, + lite::Scope *scope) { + auto X = op_desc.Input("X").front(); + auto ROW = op_desc.Input("ROW").front(); + auto COLUMN = op_desc.Input("COLUMN").front(); + auto Out = op_desc.Output("Out").front(); + auto pos = op_desc.Output("pos").front(); + + param_.X = scope->FindVar(X)->GetMutable(); + param_.ROW = scope->FindVar(ROW)->GetMutable(); + param_.COLUMN = scope->FindVar(COLUMN)->GetMutable(); + param_.Out = scope->FindVar(Out)->GetMutable(); + param_.pos = scope->FindVar(pos)->GetMutable(); + param_.channel_num = op_desc.GetAttr("channel_num"); + param_.topks = op_desc.GetAttr>("topks"); + + return true; +} + +} // namespace operators +} // namespace lite +} // namespace paddle + +REGISTER_LITE_OP(SequenceTopkAvgPooling, + paddle::lite::operators::SequenceTopkAvgPoolingOpLite); diff --git a/lite/operators/sequence_topk_avg_pooling_op.h b/lite/operators/sequence_topk_avg_pooling_op.h new file mode 100644 index 0000000000000000000000000000000000000000..1c1cfe3a9c7bc82c3e79fc372b98293183509dca --- /dev/null +++ b/lite/operators/sequence_topk_avg_pooling_op.h @@ -0,0 +1,49 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once +#include +#include +#include "lite/core/op_lite.h" +#include "lite/core/scope.h" +#include "lite/utils/all.h" + +namespace paddle { +namespace lite { +namespace operators { + +class SequenceTopkAvgPoolingOpLite : public OpLite { + public: + SequenceTopkAvgPoolingOpLite() {} + explicit SequenceTopkAvgPoolingOpLite(const std::string &op_type) + : OpLite(op_type) {} + + bool CheckShape() const override; + + bool InferShape() const override; + + bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override; + + void AttachKernel(KernelBase *kernel) override { kernel->SetParam(param_); } + std::string DebugString() const override { + return "sequence_topk_avg_pooling"; + } + + private: + mutable SequenceTopkAvgPoolingParam param_; +}; + +} // namespace operators +} // namespace lite +} // namespace paddle