提交 7c6a9495 编写于 作者: H huzhiqiang 提交者: GitHub

add x86 kernels: search_fc and sequence_topk_ave_pooling (#2443)

* add x86 op and kernel : search_fc and sequence_topk_avg_pooling   for content-dnn model test=develop
上级 4048c261
......@@ -50,7 +50,8 @@ math_library(unpooling)
math_library(vol2col)
## math_library(prelu)
math_library(tree2col DEPS math_function)
math_library(sequence_topk_avg_pooling)
math_library(search_fc DEPS blas dynload_mklml)
# cc_test(math_function_test SRCS math_function_test.cc DEPS math_function)
# cc_test(selected_rows_functor_test SRCS selected_rows_functor_test.cc DEPS selected_rows_functor)
# cc_test(im2col_test SRCS im2col_test.cc DEPS im2col)
......
/* Copyright (c) 2018 paddlepaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "lite/backends/x86/math/search_fc.h"
#include <algorithm>
#include <vector>
namespace paddle {
namespace lite {
namespace x86 {
namespace math {
/*
* All tensors' dimension should be the same and the values of
* each dimension must be the same, except the axis dimension.
*/
template <typename T>
class SearchFcFunctor<lite::TargetType::kX86, T> {
public:
void operator()(const lite::X86Context& context,
const lite::Tensor& bottom,
const lite::Tensor& w,
const lite::Tensor& b,
lite::Tensor* top,
int out_size) {
int batch = bottom.dims()[0];
int _out = w.dims()[0]; // 100
int _in = w.dims()[1]; // 228
lite::DDim dims(std::vector<int64_t>({bottom.dims()[0], out_size}));
const auto bottom_data = bottom.data<T>();
auto top_data = top->mutable_data<T>(lite::TargetType::kX86);
const auto weights = w.data<T>();
auto blas = math::GetBlas<lite::TargetType::kX86, T>(context);
call_gemm<lite::X86Context, T>(blas,
CblasNoTrans,
CblasTrans,
batch,
_out,
_in,
1.0f,
bottom_data,
weights,
0.0f,
top_data);
if (true) {
const auto* bias_data = b.data<T>();
for (int i = 0; i < batch; ++i) {
// add bias here
sse_eltadd(top_data + i * _out, bias_data, top_data + i * _out, _out);
}
}
}
// private:
};
#define DEFINE_FUNCTOR(type) \
template class SearchFcFunctor<lite::TargetType::kX86, type>;
FOR_ALL_TYPES(DEFINE_FUNCTOR);
} // namespace math
} // namespace x86
} // namespace lite
} // namespace paddle
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <vector>
#include "lite/backends/x86/math/blas.h"
#include "lite/backends/x86/mklml.h"
#include "lite/core/context.h"
#include "lite/core/tensor.h"
#include "lite/fluid/data_type.h"
namespace paddle {
namespace lite {
namespace x86 {
namespace math {
template <typename DeviceContext, typename T>
void call_gemm(const BlasT<lite::TargetType::kX86, T> blas,
const CBLAS_TRANSPOSE TransA,
const CBLAS_TRANSPOSE TransB,
const int M,
const int N,
const int K,
const T alpha,
const T* A,
const T* B,
const T beta,
T* C) {
#ifndef __NAIVE_GEMM__
int lda = (TransA == CblasNoTrans) ? K : M;
int ldb = (TransB == CblasNoTrans) ? N : K;
blas.GEMM(TransA, TransB, M, N, K, alpha, A, lda, B, ldb, beta, C, N);
#else
naive::gemm((TransA == CblasTrans),
(TransB == CblasTrans),
M,
N,
K,
alpha,
A,
B,
beta,
C);
#endif // !__NAIVE_GEMM__
}
// To align with Lego
#ifndef LEGO_USE_FLOAT
#define LEGO_USE_FLOAT
#endif
#ifndef LEGO_SSE
#define LEGO_SSE
#endif
#if defined(LEGO_USE_FLOAT)
#define __m256x __m256
#define __m128x __m128
static const unsigned int AVX_STEP_SIZE = 8;
static const unsigned int SSE_STEP_SIZE = 4;
static const unsigned int AVX_CUT_LEN_MASK = 7U;
static const unsigned int SSE_CUT_LEN_MASK = 3U;
#define _mm256_setzero_px _mm256_setzero_ps
#define _mm256_mul_px _mm256_mul_ps
#define _mm256_add_px _mm256_add_ps
#define _mm256_load_px _mm256_loadu_ps
#define _mm256_hadd_px _mm256_hadd_ps
#define _mm256_permute2f128_px _mm256_permute2f128_ps
#define _mm256_store_px _mm256_storeu_ps
#define _mm256_broadcast_sx _mm256_broadcast_ss
#define _mm256_castpx256_px128 _mm256_castps256_ps128
#define _mm256_max_px _mm256_max_ps
#define _mm256_sub_px _mm256_sub_ps
#define _mm256_set1_px _mm256_set1_ps
#define _mm256_sqrt_px _mm256_sqrt_ps
#define _mm256_div_px _mm256_div_ps
#define _mm_setzero_px _mm_setzero_ps
#define _mm_add_px _mm_add_ps
#define _mm_mul_px _mm_mul_ps
#define _mm_load_px _mm_loadu_ps
#define _mm_hadd_px _mm_hadd_ps
#define _mm_store_sx _mm_store_ss
#define _mm_store_px _mm_storeu_ps
#define _mm_load1_px _mm_load1_ps
#define _mm_max_px _mm_max_ps
#define _mm_sub_px _mm_sub_ps
#define _mm_set1_px _mm_set1_ps
#define _mm_sqrt_px _mm_sqrt_ps
#define _mm_div_px _mm_div_ps
#elif defined(LEGO_USE_DOUBLE)
#define __m256x __m256d
#define __m128x __m128d
static const unsigned int AVX_STEP_SIZE = 4;
static const unsigned int SSE_STEP_SIZE = 2;
static const unsigned int AVX_CUT_LEN_MASK = 3U;
static const unsigned int SSE_CUT_LEN_MASK = 1U;
#define _mm256_setzero_px _mm256_setzero_pd
#define _mm256_mul_px _mm256_mul_pd
#define _mm256_add_px _mm256_add_pd
#define _mm256_load_px _mm256_loadu_pd
#define _mm256_hadd_px _mm256_hadd_pd
#define _mm256_permute2f128_px _mm256_permute2f128_pd
#define _mm256_store_px _mm256_storeu_pd
#define _mm256_broadcast_sx _mm256_broadcast_sd
#define _mm256_castpx256_px128 _mm256_castpd256_pd128
#define _mm256_max_px _mm256_max_pd
#define _mm256_sub_px _mm256_sub_pd
#define _mm256_set1_px _mm256_set1_pd
#define _mm256_sqrt_px _mm256_sqrt_pd
#define _mm256_div_px _mm256_div_pd
#define _mm_setzero_px _mm_setzero_pd
#define _mm_add_px _mm_add_pd
#define _mm_mul_px _mm_mul_pd
#define _mm_load_px _mm_loadu_pd
#define _mm_hadd_px _mm_hadd_pd
#define _mm_store_sx _mm_store_sd
#define _mm_store_px _mm_storeu_pd
#define _mm_load1_px _mm_load1_pd
#define _mm_max_px _mm_max_pd
#define _mm_sub_px _mm_sub_pd
#define _mm_set1_px _mm_set1_pd
#define _mm_sqrt_px _mm_sqrt_pd
#define _mm_div_px _mm_div_pd
#endif
template <typename T>
inline void sse_eltadd(const T* x, const T* y, T* z, size_t len) {
unsigned int jjj, lll;
jjj = lll = 0;
#if defined(LEGO_AVX)
lll = len & ~AVX_CUT_LEN_MASK;
for (jjj = 0; jjj < lll; jjj += AVX_STEP_SIZE) {
_mm256_store_px(
z + jjj,
_mm256_add_px(_mm256_load_px(x + jjj), _mm256_load_px(y + jjj)));
}
#elif defined(LEGO_SSE)
lll = len & ~SSE_CUT_LEN_MASK;
for (jjj = 0; jjj < lll; jjj += SSE_STEP_SIZE) {
_mm_store_px(z + jjj,
_mm_add_px(_mm_load_px(x + jjj), _mm_load_px(y + jjj)));
}
#endif
for (; jjj < len; jjj++) {
z[jjj] = x[jjj] + y[jjj];
}
}
template <lite::TargetType Target, typename T>
class SearchFcFunctor {
public:
void operator()(const lite::Context<Target>& context,
const lite::Tensor& X,
const lite::Tensor& W,
const lite::Tensor& b,
lite::Tensor* Out,
int out_size);
};
} // namespace math
} // namespace x86
} // namespace lite
} // namespace paddle
#define FOR_ALL_TYPES(macro) macro(float);
/* Copyright (c) 2018 paddlepaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "lite/backends/x86/math/sequence_topk_avg_pooling.h"
#include <algorithm>
#include <vector>
namespace paddle {
namespace lite {
namespace x86 {
namespace math {
template <typename T>
void get_topk_pos(const T* data, int length, int k, int* pos, bool debug) {
size_t real_k = k < length ? k : length;
std::vector<T> v(data, data + length);
std::vector<int> topk_pos;
T min_val = -10000000.0;
while (topk_pos.size() < real_k) {
T max_val = min_val;
int max_pos = -1;
for (int i = 0; i < length; ++i) {
if (v[i] > max_val) {
max_pos = i;
max_val = v[i];
}
}
assert(max_pos >= 0);
topk_pos.push_back(max_pos);
v[max_pos] = min_val;
}
assert(topk_pos.size() > 0);
while (topk_pos.size() < (size_t)k) {
topk_pos.push_back(-1);
}
for (size_t i = 0; i < topk_pos.size(); ++i) {
pos[i] = topk_pos[i];
}
}
/*
* All tensors' dimension should be the same and the values of
* each dimension must be the same, except the axis dimension.
*/
template <typename T>
class SequenceTopkAvgPoolingFunctor<lite::TargetType::kX86, T> {
public:
void operator()(const lite::Tensor& in,
const lite::Tensor& row,
const lite::Tensor& col,
lite::Tensor* out,
lite::Tensor* pos,
int channel_num,
std::vector<int> topks) {
auto k_num = topks.size();
auto max_k = topks[topks.size() - 1];
std::vector<int64_t> vec_pos_shape;
auto in_lod = in.lod()[0];
auto row_lod = row.lod()[0];
auto col_lod = col.lod()[0];
int batch_size = row_lod.size() - 1;
int pos_total_size = row_lod[batch_size] * channel_num * max_k;
vec_pos_shape.push_back(pos_total_size);
lite::DDim dims(vec_pos_shape);
pos->Resize(dims);
auto pos_data = pos->mutable_data<int>(lite::TargetType::kX86);
int offset = 0;
std::vector<size_t> vec_out_lod;
vec_out_lod.reserve(batch_size + 1);
for (int i = 0; i <= batch_size; ++i) {
offset = row_lod[i];
vec_out_lod.push_back(offset);
}
lite::LoD lod_temp;
lod_temp.push_back(vec_out_lod);
out->set_lod(lod_temp);
auto in_data = in.data<T>();
auto out_data = out->mutable_data<T>(lite::TargetType::kX86);
T* sum_data = new T[max_k];
for (int i = 0; i < batch_size; ++i) {
int total_size = in_lod[i + 1] - in_lod[i];
int row_size = row_lod[i + 1] - row_lod[i];
int col_size = col_lod[i + 1] - col_lod[i];
CHECK_EQ(total_size, channel_num * row_size * col_size)
<< "size wrong in sequence_topk_avg_pooling_op!";
int feature_num = row_size * col_size;
for (int j = 0; j < channel_num; ++j) {
auto input_offset_feature_data = in_data + in_lod[i] + j * feature_num;
for (int r = 0; r < row_size; ++r) {
auto row_data = input_offset_feature_data + r * col_size;
auto pos_slice_data = pos_data + row_lod[i] * channel_num * max_k +
r * channel_num * max_k + j * max_k;
auto out_slice_data = out_data + row_lod[i] * channel_num * k_num +
r * channel_num * k_num + j * k_num;
get_topk_pos<T>(row_data, col_size, max_k, pos_slice_data);
if (pos_slice_data[0] == -1) {
sum_data[0] = 0.0;
} else {
sum_data[0] = row_data[pos_slice_data[0]];
}
for (int k = 1; k < max_k; ++k) {
if (pos_slice_data[k] == -1) {
sum_data[k] = sum_data[k - 1];
} else {
sum_data[k] = sum_data[k - 1] + row_data[pos_slice_data[k]];
}
}
for (size_t k = 0; k < k_num; ++k) {
out_slice_data[k] = sum_data[topks[k] - 1] / topks[k];
}
}
}
}
delete[] sum_data;
}
};
#define DEFINE_FUNCTOR(type) \
template class SequenceTopkAvgPoolingFunctor<lite::TargetType::kX86, type>;
FOR_ALL_TYPES(DEFINE_FUNCTOR);
} // namespace math
} // namespace x86
} // namespace lite
} // namespace paddle
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <vector>
#include "lite/core/context.h"
#include "lite/core/tensor.h"
#include "lite/fluid/data_type.h"
namespace paddle {
namespace lite {
namespace x86 {
namespace math {
template <typename T>
void get_topk_pos(
const T* data, int length, int k, int* pos, bool debug = false);
template <lite::TargetType Target, typename T>
class SequenceTopkAvgPoolingFunctor {
public:
void operator()(const lite::Tensor& X,
const lite::Tensor& ROW,
const lite::Tensor& COLUMN,
lite::Tensor* Out,
lite::Tensor* pos,
int channel_num,
std::vector<int> topks);
};
} // namespace math
} // namespace x86
} // namespace lite
} // namespace paddle
#define FOR_ALL_TYPES(macro) macro(float);
......@@ -53,6 +53,8 @@ add_kernel(sequence_arithmetic_compute_x86 X86 basic SRCS sequence_arithmetic_co
# for content-dnn specific
add_kernel(search_aligned_mat_mul_compute_x86 X86 extra SRCS search_aligned_mat_mul_compute.cc DEPS ${lite_kernel_deps} blas)
add_kernel(search_seq_fc_compute_x86 X86 extra SRCS search_seq_fc_compute.cc DEPS ${lite_kernel_deps} blas)
add_kernel(sequence_topk_avg_pooling_compute_x86 X86 basic SRCS sequence_topk_avg_pooling_compute.cc DEPS ${lite_kernel_deps} sequence_topk_avg_pooling)
add_kernel(search_fc_compute_x86 X86 basic SRCS search_fc_compute.cc DEPS ${lite_kernel_deps} search_fc)
if(NOT LITE_WITH_X86)
return()
......@@ -83,6 +85,7 @@ lite_cc_test(test_cast_compute_x86 SRCS cast_compute_test.cc DEPS cast_compute_x
lite_cc_test(test_pool2d_compute_x86 SRCS pool_compute_test.cc DEPS pool_compute_x86)
lite_cc_test(test_dropout_compute_x86 SRCS dropout_compute_test.cc DEPS dropout_compute_x86)
lite_cc_test(test_transpose_compute_x86 SRCS transpose_compute_test.cc DEPS transpose_compute_x86)
lite_cc_test(test_search_fc_compute_x86 SRCS search_fc_compute_test.cc DEPS search_fc_compute_x86)
lite_cc_test(test_search_seq_depadding_compute_x86 SRCS search_seq_depadding_compute_test.cc DEPS search_seq_depadding_compute_x86)
if(LITE_BUILD_EXTRA)
......
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/kernels/x86/search_fc_compute.h"
REGISTER_LITE_KERNEL(search_fc,
kX86,
kFloat,
kNCHW,
paddle::lite::kernels::x86::SearchFcCompute<float>,
def)
.BindInput("X", {LiteType::GetTensorTy(TARGET(kX86))})
.BindInput("W", {LiteType::GetTensorTy(TARGET(kX86))})
.BindInput("b", {LiteType::GetTensorTy(TARGET(kX86))})
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kX86))})
.Finalize();
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "lite/backends/x86/math/search_fc.h"
#include "lite/core/kernel.h"
#include "lite/core/op_registry.h"
#include "lite/core/types.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace x86 {
template <typename T>
class SearchFcCompute : public KernelLite<TARGET(kX86), PRECISION(kFloat)> {
public:
using param_t = operators::SearchFcParam;
void Run() override {
auto& context = ctx_->As<X86Context>();
auto& param = *param_.get_mutable<param_t>();
lite::x86::math::SearchFcFunctor<lite::TargetType::kX86, T> search_fc;
search_fc(context, *param.X, *param.W, *param.b, param.Out, param.out_size);
}
virtual ~SearchFcCompute() = default;
};
} // namespace x86
} // namespace kernels
} // namespace lite
} // namespace paddle
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/kernels/x86/search_fc_compute.h"
#include <gtest/gtest.h>
#include <iostream>
#include <memory>
#include <utility>
#include <vector>
#include "lite/core/op_registry.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace x86 {
void fc_cpu_base(const lite::Tensor* X,
const lite::Tensor* W,
const lite::Tensor* b,
int out_size,
lite::Tensor* Out) {
const float* data_in = X->data<float>();
const float* bias = b->data<float>();
const float* weights = W->data<float>();
float* data_out = Out->mutable_data<float>();
int out_rows = X->dims()[0];
int in_cols = X->numel() / out_rows;
int out_cols = W->numel() / in_cols;
int index_out;
for (int i = 0; i < out_rows; i++) {
for (int j = 0; j < out_cols; j++) {
index_out = i * out_cols + j;
data_out[index_out] = bias ? bias[j] : 0;
for (int k = 0; k < in_cols; k++) {
data_out[index_out] +=
data_in[i * in_cols + k] * weights[j * in_cols + k];
}
}
}
}
TEST(search_fc_x86, retrive_op) {
auto search_fc =
KernelRegistry::Global().Create<TARGET(kX86), PRECISION(kFloat)>(
"search_fc");
ASSERT_FALSE(search_fc.empty());
ASSERT_TRUE(search_fc.front());
}
TEST(search_fc_x86, init) {
SearchFcCompute<float> search_fc;
ASSERT_EQ(search_fc.precision(), PRECISION(kFloat));
ASSERT_EQ(search_fc.target(), TARGET(kX86));
}
TEST(search_fc_x86, run_test) {
lite::Tensor x, w, b, out;
lite::Tensor out_ref;
std::unique_ptr<KernelContext> ctx(new KernelContext);
ctx->As<X86Context>();
std::vector<int64_t> x_shape{1, 4};
x.Resize(lite::DDim(x_shape));
std::vector<int64_t> w_shape{3, 4};
w.Resize(lite::DDim(w_shape));
std::vector<int64_t> b_shape{3};
b.Resize(lite::DDim(b_shape));
std::vector<int64_t> out_shape{1, 4};
out.Resize(lite::DDim(out_shape));
out_ref.Resize(lite::DDim(out_shape));
auto x_data = x.mutable_data<float>();
auto w_data = w.mutable_data<float>();
auto b_data = b.mutable_data<float>();
auto out_data = out.mutable_data<float>();
auto out_data_ref = out_ref.mutable_data<float>();
for (int64_t i = 0; i < x.dims().production(); i++) {
x_data[i] = static_cast<float>(i);
}
for (int64_t i = 0; i < w.dims().production(); i++) {
w_data[i] = static_cast<float>(i);
}
for (int64_t i = 0; i < b.dims().production(); i++) {
b_data[i] = static_cast<float>(i);
}
fc_cpu_base(&x, &w, &b, 4, &out_ref);
SearchFcCompute<float> fc;
operators::SearchFcParam param;
param.X = &x;
param.W = &w;
param.b = &b;
param.Out = &out;
param.out_size = 4;
fc.SetParam(param);
fc.SetContext(std::move(ctx));
fc.Run();
VLOG(3) << "output vs ref";
for (int i = 0; i < out.dims().production(); i++) {
EXPECT_NEAR(out_data[i], out_data_ref[i], 1e-5);
}
}
} // namespace x86
} // namespace kernels
} // namespace lite
} // namespace paddle
USE_LITE_KERNEL(search_fc, kX86, kFloat, kNCHW, def);
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/kernels/x86/sequence_topk_avg_pooling_compute.h"
REGISTER_LITE_KERNEL(
sequence_topk_avg_pooling,
kX86,
kFloat,
kNCHW,
paddle::lite::kernels::x86::SequenceTopkAvgPoolingCompute<float>,
def)
.BindInput("X", {LiteType::GetTensorTy(TARGET(kX86))})
.BindInput("ROW", {LiteType::GetTensorTy(TARGET(kX86))})
.BindInput("COLUMN", {LiteType::GetTensorTy(TARGET(kX86))})
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kX86))})
.BindOutput("pos", {LiteType::GetTensorTy(TARGET(kX86))})
.Finalize();
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "lite/backends/x86/math/sequence_topk_avg_pooling.h"
#include "lite/core/kernel.h"
#include "lite/core/op_registry.h"
#include "lite/core/types.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace x86 {
template <typename T>
class SequenceTopkAvgPoolingCompute
: public KernelLite<TARGET(kX86), PRECISION(kFloat)> {
public:
using param_t = operators::SequenceTopkAvgPoolingParam;
void Run() override {
auto& param = *param_.get_mutable<param_t>();
lite::x86::math::SequenceTopkAvgPoolingFunctor<lite::TargetType::kX86, T>
sequence_topk_avg_pooling;
sequence_topk_avg_pooling(*param.X,
*param.ROW,
*param.COLUMN,
param.Out,
param.pos,
param.channel_num,
param.topks);
};
virtual ~SequenceTopkAvgPoolingCompute() = default;
};
} // namespace x86
} // namespace kernels
} // namespace lite
} // namespace paddle
......@@ -116,10 +116,11 @@ add_operator(topk_op extra SRCS topk_op.cc DEPS ${op_DEPS})
add_operator(increment_op extra SRCS increment_op.cc DEPS ${op_DEPS})
add_operator(layer_norm_op extra SRCS layer_norm_op.cc DEPS ${op_DEPS})
add_operator(sequence_softmax_op extra SRCS sequence_softmax_op.cc DEPS ${op_DEPS})
# for content-dnn specific
add_operator(search_aligned_mat_mul_op extra SRCS search_aligned_mat_mul_op.cc DEPS ${op_DEPS})
add_operator(search_seq_fc_op extra SRCS search_seq_fc_op.cc DEPS ${op_DEPS})
add_operator(sequence_topk_avg_pooling_op basic SRCS sequence_topk_avg_pooling_op.cc DEPS ${op_DEPS})
add_operator(search_fc_op basic SRCS search_fc_op.cc DEPS ${op_DEPS})
if (NOT LITE_WITH_X86)
lite_cc_test(test_fc_op SRCS fc_op_test.cc
......
......@@ -983,6 +983,25 @@ struct AssignValueParam {
lite::Tensor* Out{};
};
/// --------------- sequence_topk_avg_pooling operators ------------------
struct SequenceTopkAvgPoolingParam {
const lite::Tensor* X{};
const lite::Tensor* ROW{};
const lite::Tensor* COLUMN{};
lite::Tensor* Out{};
lite::Tensor* pos{};
int channel_num{};
std::vector<int> topks{};
};
/// --------------- search_fc operators ------------------
struct SearchFcParam {
const lite::Tensor* X{};
const lite::Tensor* W{};
const lite::Tensor* b{};
lite::Tensor* Out{};
int out_size{};
};
/// --------------------- match_matrix_tensor operators --------------------
struct MatchMatrixTensorParam {
const lite::Tensor* x{};
......
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/operators/search_fc_op.h"
#include "lite/core/op_lite.h"
#include "lite/core/op_registry.h"
namespace paddle {
namespace lite {
namespace operators {
bool SearchFcOpLite::CheckShape() const {
CHECK_OR_FALSE(param_.X);
CHECK_OR_FALSE(param_.W);
CHECK_OR_FALSE(param_.b);
CHECK_OR_FALSE(param_.Out);
auto x_dims = param_.X->dims();
CHECK_EQ(x_dims.size(), 2) << "The rank of X(Input) should be 2.";
auto w_dims = param_.W->dims();
CHECK_EQ(w_dims.size(), 2) << "W should be 2-D tensor.";
auto b_dims = param_.b->dims();
CHECK_EQ(b_dims.size(), 1) << "b should be 1-D tensor.";
CHECK_EQ(w_dims[1], x_dims[1]) << "wrong shape: w_dims[1] != x_dims[1]";
return true;
}
bool SearchFcOpLite::InferShape() const {
auto out_size = param_.out_size;
lite::DDim dims(std::vector<int64_t>({-1, out_size}));
param_.Out->Resize(dims);
return true;
}
bool SearchFcOpLite::AttachImpl(const cpp::OpDesc &op_desc,
lite::Scope *scope) {
auto X = op_desc.Input("X").front();
auto W = op_desc.Input("W").front();
auto b = op_desc.Input("b").front();
auto Out = op_desc.Output("Out").front();
param_.X = scope->FindVar(X)->GetMutable<lite::Tensor>();
param_.W = scope->FindVar(W)->GetMutable<lite::Tensor>();
param_.b = scope->FindVar(b)->GetMutable<lite::Tensor>();
param_.Out = scope->FindVar(Out)->GetMutable<lite::Tensor>();
param_.out_size = op_desc.GetAttr<int>("out_size");
return true;
}
} // namespace operators
} // namespace lite
} // namespace paddle
REGISTER_LITE_OP(SearchFc, paddle::lite::operators::SearchFcOpLite);
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <string>
#include <vector>
#include "lite/core/op_lite.h"
#include "lite/core/scope.h"
#include "lite/utils/all.h"
namespace paddle {
namespace lite {
namespace operators {
class SearchFcOpLite : public OpLite {
public:
SearchFcOpLite() {}
explicit SearchFcOpLite(const std::string &op_type) : OpLite(op_type) {}
bool CheckShape() const override;
bool InferShape() const override;
bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override;
void AttachKernel(KernelBase *kernel) override { kernel->SetParam(param_); }
std::string DebugString() const override { return "search_fc"; }
private:
mutable SearchFcParam param_;
};
} // namespace operators
} // namespace lite
} // namespace paddle
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/operators/sequence_topk_avg_pooling_op.h"
#include "lite/core/op_lite.h"
#include "lite/core/op_registry.h"
namespace paddle {
namespace lite {
namespace operators {
bool SequenceTopkAvgPoolingOpLite::CheckShape() const {
CHECK_OR_FALSE(param_.X);
CHECK_OR_FALSE(param_.ROW);
CHECK_OR_FALSE(param_.COLUMN);
CHECK_OR_FALSE(param_.Out);
CHECK_OR_FALSE(param_.pos);
return true;
}
bool SequenceTopkAvgPoolingOpLite::InferShape() const {
int channel_num = param_.channel_num;
std::vector<int> topks = param_.topks;
auto row_dim = param_.ROW->dims();
auto num_k = topks.size();
auto row_shape_0 = row_dim[0];
std::vector<int64_t> vec_out_shape;
vec_out_shape.push_back(row_shape_0);
vec_out_shape.push_back(channel_num * num_k);
param_.Out->Resize(lite::DDim(vec_out_shape));
auto out_lod = param_.Out->mutable_lod();
*out_lod = param_.X->lod();
return true;
}
bool SequenceTopkAvgPoolingOpLite::AttachImpl(const cpp::OpDesc &op_desc,
lite::Scope *scope) {
auto X = op_desc.Input("X").front();
auto ROW = op_desc.Input("ROW").front();
auto COLUMN = op_desc.Input("COLUMN").front();
auto Out = op_desc.Output("Out").front();
auto pos = op_desc.Output("pos").front();
param_.X = scope->FindVar(X)->GetMutable<lite::Tensor>();
param_.ROW = scope->FindVar(ROW)->GetMutable<lite::Tensor>();
param_.COLUMN = scope->FindVar(COLUMN)->GetMutable<lite::Tensor>();
param_.Out = scope->FindVar(Out)->GetMutable<lite::Tensor>();
param_.pos = scope->FindVar(pos)->GetMutable<lite::Tensor>();
param_.channel_num = op_desc.GetAttr<int>("channel_num");
param_.topks = op_desc.GetAttr<std::vector<int>>("topks");
return true;
}
} // namespace operators
} // namespace lite
} // namespace paddle
REGISTER_LITE_OP(SequenceTopkAvgPooling,
paddle::lite::operators::SequenceTopkAvgPoolingOpLite);
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <string>
#include <vector>
#include "lite/core/op_lite.h"
#include "lite/core/scope.h"
#include "lite/utils/all.h"
namespace paddle {
namespace lite {
namespace operators {
class SequenceTopkAvgPoolingOpLite : public OpLite {
public:
SequenceTopkAvgPoolingOpLite() {}
explicit SequenceTopkAvgPoolingOpLite(const std::string &op_type)
: OpLite(op_type) {}
bool CheckShape() const override;
bool InferShape() const override;
bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override;
void AttachKernel(KernelBase *kernel) override { kernel->SetParam(param_); }
std::string DebugString() const override {
return "sequence_topk_avg_pooling";
}
private:
mutable SequenceTopkAvgPoolingParam param_;
};
} // namespace operators
} // namespace lite
} // namespace paddle
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册