未验证 提交 79f8f42d 编写于 作者: L liu zhengxi 提交者: GitHub

Add stack op on Lite x86 platform and fix extra cmake error (#2458)

* add stack op and its unit tests, test=develop
上级 7635d699
......@@ -16,6 +16,7 @@ add_kernel(conv_compute_x86 X86 basic SRCS conv_compute.cc DEPS ${lite_kernel_de
# lite_cc_library(dropout_compute_x86 SRCS dropout_compute.cc DEPS ${lite_kernel_deps} )
# lite_cc_library(conv_compute_x86 SRCS conv_compute.cc DEPS ${lite_kernel_deps} blas im2col vol2col)
add_kernel(pool_compute_x86 X86 basic SRCS pool_compute.cc DEPS ${lite_kernel_deps} pooling)
add_kernel(stack_compute_x86 X86 basic SRCS stack_compute.cc DEPS ${lite_kernel_deps})
add_kernel(dropout_compute_x86 X86 basic SRCS dropout_compute.cc DEPS ${lite_kernel_deps})
add_kernel(transpose_compute_x86 X86 basic SRCS transpose_compute.cc DEPS ${lite_kernel_deps} math_function)
# add_kernel(fc_compute_x86 X86 basic SRCS fc_compute.cc DEPS ${lite_kernel_deps})
......@@ -34,13 +35,13 @@ add_kernel(mul_compute_x86 X86 basic SRCS mul_compute.cc DEPS ${lite_kernel_deps
add_kernel(concat_compute_x86 X86 basic SRCS concat_compute.cc DEPS ${lite_kernel_deps})
add_kernel(shape_compute_x86 X86 basic SRCS shape_compute.cc DEPS ${lite_kernel_deps})
add_kernel(sequence_pool_compute_x86 X86 basic SRCS sequence_pool_compute.cc DEPS ${lite_kernel_deps} sequence_pooling)
add_kernel(search_group_padding_compute_x86 X86 extra SRCS search_group_padding_compute.cc DEPS ${lite_kernel_deps})
add_kernel(search_group_padding_compute_x86 X86 basic SRCS search_group_padding_compute.cc DEPS ${lite_kernel_deps})
add_kernel(sequence_reverse_compute_x86 X86 basic SRCS sequence_reverse_compute.cc DEPS ${lite_kernel_deps})
add_kernel(softmax_compute_x86 X86 basic SRCS softmax_compute.cc DEPS ${lite_kernel_deps} softmax)
add_kernel(elementwise_compute_x86 X86 basic SRCS elementwise_compute.cc DEPS ${lite_kernel_deps})
add_kernel(batch_norm_compute_x86 X86 basic SRCS batch_norm_compute.cc DEPS ${lite_kernel_deps})
add_kernel(reduce_sum_compute_x86 X86 basic SRCS reduce_compute.cc DEPS ${lite_kernel_deps})
add_kernel(lookup_table_compute_x86 X86 extra SRCS lookup_table_compute.cc DEPS ${lite_kernel_deps})
add_kernel(lookup_table_compute_x86 X86 basic SRCS lookup_table_compute.cc DEPS ${lite_kernel_deps})
add_kernel(sequence_reshape_compute_x86 X86 basic SRCS sequence_reshape_compute.cc DEPS ${lite_kernel_deps})
add_kernel(match_matrix_tensor_compute_x86 X86 basic SRCS match_matrix_tensor_compute.cc DEPS ${lite_kernel_deps} blas math_function)
add_kernel(search_seq_depadding_compute_x86 X86 basic SRCS search_seq_depadding_compute.cc DEPS ${lite_kernel_deps})
......@@ -75,7 +76,6 @@ lite_cc_test(test_batch_norm_compute_x86 SRCS batch_norm_compute_test.cc DEPS ba
lite_cc_test(test_softmax_compute_x86 SRCS softmax_compute_test.cc DEPS softmax_compute_x86)
lite_cc_test(test_elementwise_compute_x86 SRCS elementwise_compute_test.cc DEPS elementwise_compute_x86)
lite_cc_test(test_relu_compute_x86 SRCS relu_compute_test.cc DEPS activation_compute_x86)
lite_cc_test(test_search_group_padding_compute_x86 SRCS search_group_padding_compute_test.cc DEPS search_group_padding_compute_x86)
lite_cc_test(test_tanh_compute_x86 SRCS tanh_compute_test.cc DEPS activation_compute_x86)
lite_cc_test(test_gelu_compute_x86 SRCS gelu_compute_test.cc DEPS activation_compute_x86)
lite_cc_test(test_sequence_expand_as_compute_x86 SRCS sequence_expand_as_compute_test.cc DEPS sequence_expand_as_compute_x86)
......@@ -88,9 +88,9 @@ lite_cc_test(test_transpose_compute_x86 SRCS transpose_compute_test.cc DEPS tran
lite_cc_test(test_search_fc_compute_x86 SRCS search_fc_compute_test.cc DEPS search_fc_compute_x86)
lite_cc_test(test_search_seq_depadding_compute_x86 SRCS search_seq_depadding_compute_test.cc DEPS search_seq_depadding_compute_x86)
if(LITE_BUILD_EXTRA)
lite_cc_test(test_lookup_table_compute_x86 SRCS lookup_table_compute_test.cc DEPS lookup_table_compute_x86)
endif()
lite_cc_test(test_lookup_table_compute_x86 SRCS lookup_table_compute_test.cc DEPS lookup_table_compute_x86)
lite_cc_test(test_stack_compute_x86 SRCS stack_compute_test.cc DEPS stack_compute_x86)
lite_cc_test(test_search_group_padding_compute_x86 SRCS search_group_padding_compute_test.cc DEPS search_group_padding_compute_x86)
lite_cc_test(test_sequence_concat_compute_x86 SRCS sequence_concat_compute_test.cc DEPS sequence_concat_compute_x86)
lite_cc_test(test_match_matrix_compute_x86 SRCS match_matrix_tensor_compute_test.cc DEPS match_matrix_tensor_compute_x86)
lite_cc_test(test_var_conv_2d_compute_x86 SRCS var_conv_2d_compute_test.cc DEPS var_conv_2d_compute_x86)
......
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/kernels/x86/stack_compute.h"
// Registers StackCompute<float> as the kernel for the "stack" op on the X86
// target (float precision, NCHW layout), under the variant name "def".
// The op's "X" inputs and "Out" output are bound as plain X86 tensors.
REGISTER_LITE_KERNEL(stack,
                     kX86,
                     kFloat,
                     kNCHW,
                     paddle::lite::kernels::x86::StackCompute<float>,
                     def)
    .BindInput("X", {LiteType::GetTensorTy(TARGET(kX86))})
    .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kX86))})
    .Finalize();
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <cstring>
#include <vector>
#include "lite/core/kernel.h"
#include "lite/core/op_lite.h"
#include "lite/core/op_registry.h"
#include "lite/core/type_system.h"
#include "lite/operators/stack_op.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace x86 {
template <typename T>
class StackCompute : public KernelLite<TARGET(kX86), PRECISION(kFloat)> {
 public:
  using param_t = operators::StackParam;

  // Stacks the N input tensors in param.X (all assumed to share one shape)
  // along `param.axis`, writing the result to param.Out. The output shape is
  // the input shape with an extra dimension of size N inserted at `axis`.
  void Run() override {
    auto& param = *param_.get_mutable<param_t>();
    auto x = param.X;
    auto y = param.Out;
    int axis = param.axis;
    // A negative axis counts back from the *output* rank (input rank + 1).
    if (axis < 0) axis += static_cast<int>(x[0]->dims().size()) + 1;
    int n = static_cast<int>(x.size());
    auto y_data = y->mutable_data<T>();
    std::vector<const T*> x_datas(n);
    for (int i = 0; i < n; ++i) x_datas[i] = x[i]->data<T>();

    // pre  = product of dims before `axis` (number of leading slices)
    // post = product of dims from `axis` on (elements per slice).
    // Use int64_t: dims are int64_t and the products can exceed int range.
    int64_t pre = 1, post = 1;
    auto dim = x[0]->dims();
    for (int i = 0; i < axis; ++i) pre *= dim[i];
    for (size_t i = static_cast<size_t>(axis); i < dim.size(); ++i) {
      post *= dim[i];
    }

    // For each leading slice, copy the matching `post`-sized chunk from every
    // input in order; this realizes the interleaving that stacking requires.
    auto x_data_arr = x_datas.data();
    size_t x_offset = 0;
    size_t y_offset = 0;
    for (int64_t i = 0; i < pre; i++) {
      for (int j = 0; j < n; j++) {
        std::memcpy(
            y_data + y_offset, x_data_arr[j] + x_offset, post * sizeof(T));
        y_offset += post;
      }
      x_offset += post;
    }
  }

  virtual ~StackCompute() = default;
};
} // namespace x86
} // namespace kernels
} // namespace lite
} // namespace paddle
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/kernels/x86/stack_compute.h"
#include <gtest/gtest.h>
#include <memory>
#include <utility>
#include <vector>
#include "lite/core/op_registry.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace x86 {
// stack
// The "stack" kernel must be discoverable in the global kernel registry for
// the X86/float target combination.
TEST(stack_x86, retrive_op) {
  auto creators =
      KernelRegistry::Global().Create<TARGET(kX86), PRECISION(kFloat)>("stack");
  ASSERT_FALSE(creators.empty());
  ASSERT_TRUE(creators.front());
}
// A default-constructed kernel must report the precision and target it was
// compiled for.
TEST(stack_x86, init) {
  lite::kernels::x86::StackCompute<float> kernel;
  ASSERT_EQ(kernel.precision(), PRECISION(kFloat));
  ASSERT_EQ(kernel.target(), TARGET(kX86));
}
// End-to-end check: stack five aliases of one (10, 20, 10) tensor along
// axis 0 and verify the output is five consecutive copies of the flattened
// input. The expected-value wrap length is derived from the tensor instead
// of the previous hard-coded 2000 (which silently duplicated 10*20*10).
TEST(stack_x86, run_test) {
  lite::Tensor x;
  lite::Tensor out;
  const int num_input = 5;
  std::vector<int64_t> x_shape({10, 20, 10});
  x.Resize(lite::DDim(x_shape));
  std::vector<int64_t> out_shape({5, 10, 20, 10});
  out.Resize(lite::DDim(out_shape));

  auto x_data = x.mutable_data<float>();
  auto out_data = out.mutable_data<float>();
  // Fill the input with 0, 1, 2, ... so every output position has a unique
  // predictable value.
  for (int64_t i = 0; i < x.dims().production(); ++i) {
    x_data[i] = static_cast<float>(i);
  }
  // All stacked inputs alias the same tensor, so each axis-0 slice of the
  // output must equal the flattened input.
  std::vector<lite::Tensor*> input;
  for (int i = 0; i < num_input; ++i) {
    input.emplace_back(&x);
  }

  StackCompute<float> stack;
  operators::StackParam param;
  param.X = input;
  param.Out = &out;
  param.axis = 0;

  std::unique_ptr<KernelContext> ctx(new KernelContext);
  ctx->As<X86Context>();
  stack.SetContext(std::move(ctx));
  stack.SetParam(param);
  stack.Run();

  // Expected pattern: 0 .. input_size-1, repeated num_input times.
  const int64_t input_size = x.dims().production();
  int64_t ref_data = 0;
  for (int64_t j = 0; j < out.dims().production(); ++j) {
    EXPECT_NEAR(out_data[j], static_cast<float>(ref_data), 1e-5);
    if (++ref_data >= input_size) ref_data = 0;
  }
}
} // namespace x86
} // namespace kernels
} // namespace lite
} // namespace paddle
// Forces the linker to keep the X86 float "stack" kernel registration so the
// registry lookup in the retrive_op test above can succeed.
USE_LITE_KERNEL(stack, kX86, kFloat, kNCHW, def);
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册