From 7014a76b0aac93cf2d463d978ba3d9a1f945f81f Mon Sep 17 00:00:00 2001 From: lijianshe02 <48898730+lijianshe02@users.noreply.github.com> Date: Sat, 7 Sep 2019 16:04:10 +0800 Subject: [PATCH] add lite x86 ops for ASR test=develop (#1981) * add lite x86 ops for ASR test=develop * add lite x86 ops for ASR test=develop * fix x86 ci run test problems test=develop * fix mkl path for CI test=develop --- lite/backends/x86/CMakeLists.txt | 2 +- lite/backends/x86/dynamic_loader.cc | 8 +- lite/backends/x86/math/CMakeLists.txt | 2 +- lite/backends/x86/math/blas_impl.h | 6 +- lite/core/tensor.h | 26 +++--- lite/kernels/x86/CMakeLists.txt | 11 ++- lite/kernels/x86/concat_compute.h | 82 ++++++----------- lite/kernels/x86/concat_compute_test.cc | 7 +- lite/kernels/x86/mul_compute.cc | 2 + lite/kernels/x86/mul_compute.h | 34 ++++--- lite/kernels/x86/mul_compute_test.cc | 8 +- lite/kernels/x86/sequence_pool_compute.cc | 25 ++++++ lite/kernels/x86/sequence_pool_compute.h | 59 +++++++++++++ .../kernels/x86/sequence_pool_compute_test.cc | 88 +++++++++++++++++++ lite/kernels/x86/shape_compute.cc | 25 ++++++ lite/kernels/x86/shape_compute.h | 45 ++++++++++ lite/kernels/x86/shape_compute_test.cc | 73 +++++++++++++++ lite/operators/CMakeLists.txt | 2 +- lite/operators/op_params.h | 5 +- lite/tools/ci_build.sh | 4 +- .../cmake_tools/parse_kernel_registry.py | 1 + 21 files changed, 413 insertions(+), 102 deletions(-) create mode 100644 lite/kernels/x86/sequence_pool_compute.cc create mode 100644 lite/kernels/x86/sequence_pool_compute.h create mode 100644 lite/kernels/x86/sequence_pool_compute_test.cc create mode 100644 lite/kernels/x86/shape_compute.cc create mode 100644 lite/kernels/x86/shape_compute.h create mode 100644 lite/kernels/x86/shape_compute_test.cc diff --git a/lite/backends/x86/CMakeLists.txt b/lite/backends/x86/CMakeLists.txt index 992bf5536e..34e0800130 100644 --- a/lite/backends/x86/CMakeLists.txt +++ b/lite/backends/x86/CMakeLists.txt @@ -6,7 +6,7 @@ configure_file(cupti_lib_path.h.in ${CMAKE_CURRENT_BINARY_DIR}/cupti_lib_path.h) configure_file(warpctc_lib_path.h.in ${CMAKE_CURRENT_BINARY_DIR}/warpctc_lib_path.h) lite_cc_library(dynamic_loader SRCS dynamic_loader.cc DEPS glog gflags) -#lite_cc_library(dynload_mklml SRCS mklml.cc DEPS dynamic_loader mklml) +lite_cc_library(dynload_mklml SRCS mklml.cc DEPS dynamic_loader mklml) lite_cc_library(target_wrapper_x86 SRCS target_wrapper.cc) lite_cc_library(x86_cpu_info SRCS cpu_info.cc DEPS xbyak) diff --git a/lite/backends/x86/dynamic_loader.cc b/lite/backends/x86/dynamic_loader.cc index 3a3e0e1dd4..0f27a19cf5 100644 --- a/lite/backends/x86/dynamic_loader.cc +++ b/lite/backends/x86/dynamic_loader.cc @@ -54,8 +54,8 @@ DEFINE_string( DEFINE_string(mklml_dir, "", "Specify path for loading libmklml_intel.so."); namespace paddle { -namespace platform { -namespace dynload { +namespace lite { +namespace x86 { static constexpr char cupti_lib_path[] = CUPTI_LIB_PATH; static constexpr char warpctc_lib_path[] = WARPCTC_LIB_PATH; @@ -258,6 +258,6 @@ void* GetMKLMLDsoHandle() { #endif } -} // namespace dynload -} // namespace platform +} // namespace x86 +} // namespace lite } // namespace paddle diff --git a/lite/backends/x86/math/CMakeLists.txt b/lite/backends/x86/math/CMakeLists.txt index 35208c2c7a..5f440947fe 100644 --- a/lite/backends/x86/math/CMakeLists.txt +++ b/lite/backends/x86/math/CMakeLists.txt @@ -16,7 +16,7 @@ function(math_library TARGET) endif() list(LENGTH cc_srcs cc_srcs_len) - lite_cc_library(${TARGET} SRCS ${cc_srcs} DEPS ${math_library_DEPS} ${math_common_deps} eigen3) + lite_cc_library(${TARGET} SRCS ${cc_srcs} DEPS ${math_library_DEPS} ${math_common_deps} eigen3 dynload_mklml) endfunction() # please add new math_library in alphabetical order diff --git a/lite/backends/x86/math/blas_impl.h b/lite/backends/x86/math/blas_impl.h index 36d76c783c..c4844a4df3 100644 --- a/lite/backends/x86/math/blas_impl.h +++ b/lite/backends/x86/math/blas_impl.h @@ -483,7 +483,7 @@ void Blas::MatMul(const lite::Tensor &mat_a, mat_a.data(), mat_b.data(), beta, - mat_out->data()); + mat_out->mutable_data()); } template <> @@ -759,7 +759,7 @@ void Blas::MatMul(const lite::Tensor &mat_a, mat_a.data(), mat_b.data(), beta, - mat_out->data()); + mat_out->mutable_data()); } else { PADDLE_ENFORCE(dim_a.batch_size_ == dim_b.batch_size_ || dim_a.batch_size_ == 0 || dim_b.batch_size_ == 0); @@ -773,7 +773,7 @@ void Blas::MatMul(const lite::Tensor &mat_a, mat_a.data(), mat_b.data(), beta, - mat_out->data(), + mat_out->mutable_data(), dim_a.batch_size_ == 0 ? dim_b.batch_size_ : dim_a.batch_size_, dim_a.stride_, dim_b.stride_); diff --git a/lite/core/tensor.h b/lite/core/tensor.h index 205e586ab3..aa4cb1b3c5 100644 --- a/lite/core/tensor.h +++ b/lite/core/tensor.h @@ -218,16 +218,22 @@ R *TensorLite::mutable_data(TargetType target) { template TensorLite TensorLite::Slice(int64_t begin, int64_t end) const { - int64_t base = numel() / dims_[0]; - - TensorLite dst; - dst.buffer_ = buffer_; - dst.target_ = target_; - auto dst_dims = dims_; - dst_dims[0] = end - begin; - dst.Resize(dst_dims); - dst.offset_ = offset_ + static_cast(begin * base) * sizeof(T); - return dst; + CHECK_GE(begin, 0); + CHECK_LE(end, dims_[0]); + CHECK_LT(begin, end); + if (dims_[0] == 1) { + return *this; + } else { + int64_t base = numel() / dims_[0]; + TensorLite dst; + dst.buffer_ = buffer_; + dst.target_ = target_; + auto dst_dims = dims_; + dst_dims[0] = end - begin; + dst.Resize(dst_dims); + dst.offset_ = offset_ + static_cast(begin * base) * sizeof(T); + return dst; + } } template diff --git a/lite/kernels/x86/CMakeLists.txt b/lite/kernels/x86/CMakeLists.txt index 7080cc8c55..810c753c6d 100644 --- a/lite/kernels/x86/CMakeLists.txt +++ b/lite/kernels/x86/CMakeLists.txt @@ -23,11 +23,18 @@ add_kernel(scale_compute_x86 X86 basic SRCS scale_compute.cc DEPS ${lite_kernel_ # lite_cc_test(test_fc_compute_x86 SRCS fc_compute_test.cc DEPS fc_compute_x86) # lite_cc_test(test_conv2d_compute_x86 SRCS conv_compute_test.cc DEPS conv_compute_x86) # lite_cc_test(test_pool2d_compute_x86 SRCS pool_compute_test.cc DEPS pool_compute_x86) -# lite_cc_test(test_concat_compute_x86 SRCS concat_compute_test.cc DEPS concat_compute_x86) # lite_cc_test(test_softmax_compute_x86 SRCS softmax_compute_test.cc DEPS softmax_compute_x86) # lite_cc_test(test_elementwise_compute_x86 SRCS elementwise_compute_test.cc DEPS elementwise_compute_x86) # lite_cc_test(test_relu_compute_x86 SRCS relu_compute_test.cc DEPS relu_compute_x86) -# lite_cc_test(test_mul_compute_x86 SRCS mul_compute_test.cc DEPS mul_compute_x86 operator) # lite_cc_test(test_scale_compute_x86 SRCS scale_compute_test.cc DEPS scale_compute_x86) # lite_cc_test(test_dropout_compute_x86 SRCS dropout_compute_test.cc DEPS dropout_compute_x86) # lite_cc_test(test_batch_norm_compute_x86 SRCS batch_norm_compute_test.cc DEPS batch_norm_compute_x86) +add_kernel(mul_compute_x86 X86 basic SRCS mul_compute.cc DEPS ${lite_kernel_deps} blas) +add_kernel(concat_compute_x86 X86 basic SRCS concat_compute.cc DEPS ${lite_kernel_deps}) +add_kernel(shape_compute_x86 X86 basic SRCS shape_compute.cc DEPS ${lite_kernel_deps}) +add_kernel(sequence_pool_compute_x86 X86 basic SRCS sequence_pool_compute.cc DEPS ${lite_kernel_deps} sequence_pooling) + +lite_cc_test(test_mul_compute_x86 SRCS mul_compute_test.cc DEPS mul_compute_x86) +lite_cc_test(test_concat_compute_x86 SRCS concat_compute_test.cc DEPS concat_compute_x86) +lite_cc_test(test_sequence_pool_compute_x86 SRCS sequence_pool_compute_test.cc DEPS sequence_pool_compute_x86) +lite_cc_test(test_shape_compute_x86 SRCS shape_compute_test.cc DEPS shape_compute_x86) diff --git a/lite/kernels/x86/concat_compute.h b/lite/kernels/x86/concat_compute.h index 280320867d..674f06461f 100644 --- a/lite/kernels/x86/concat_compute.h +++ b/lite/kernels/x86/concat_compute.h @@ -18,13 +18,20 @@ #include "lite/core/kernel.h" #include "lite/core/op_registry.h" #include "lite/core/types.h" -#include "paddle/fluid/operators/strided_memcpy.h" namespace paddle { namespace lite { namespace kernels { namespace x86 { +inline int count(int start_axis, int end_axis, const lite::DDim& dim) { + int count = 1; + for (int i = start_axis; i < end_axis; ++i) { + count *= dim[i]; + } + return count; +} + template class ConcatCompute : public KernelLite { public: @@ -33,67 +40,28 @@ class ConcatCompute : public KernelLite { void Run() override { auto& param = *param_.get_mutable(); int64_t axis = static_cast(param.axis); + auto x_dims = param.x[0]->dims(); auto out = param.output; + if (param.x.size() == 1) return; - if (axis == 0 && param.x.size() < 10) { - size_t output_offset = 0; - for (auto* in : param.x) { - if (!in || in->dims().production() == 0UL) { - continue; - } - auto in_stride = framework::stride_numel(in->dims().data()); - auto out_stride = framework::stride_numel(out->dims().data()); - paddle::operators::StridedNumelCopyWithAxis( - platform::CPUDeviceContext(), - axis, - out->mutable_data() + output_offset, - out_stride, - in->data(), - in_stride, - in_stride[axis]); - - output_offset += in_stride[axis]; - } - } else { - std::vector inputs; - for (size_t j = 0; j < param.x.size(); ++j) { - if (param.x[j] && param.x[j]->dims().production() > 0) { - inputs.push_back(*param.x[j]); - } else { - continue; - } - } - - int num = inputs.size(); - int rows = 1; - auto dim_0 = inputs[0].dims(); - for (int i = 0; i < axis; ++i) { - rows *= dim_0[i]; - } - int out_rows = rows, out_cols = 0; - - std::vector input_cols(inputs.size()); - for (int i = 0; i < num; ++i) { - int t_cols = inputs[i].dims().production() / rows; - out_cols += t_cols; - input_cols[i] = t_cols; - } - // computation - auto output_data = param.output->template mutable_data(); - int col_idx = 0; - for (int j = 0; j < num; ++j) { - int col_len = input_cols[j]; - auto input_data = inputs[j].data(); - for (int k = 0; k < out_rows; ++k) { - std::memcpy(output_data + k * out_cols + col_idx, - input_data + k * col_len, - sizeof(T) * col_len); - } - col_idx += col_len; + auto output_data = param.output->template mutable_data(); + int offset_concat_axis = 0; + int num_concat = count(0, axis, x_dims); + int concat_input_size = count(axis + 1, x_dims.size(), x_dims); + const int top_concat_axis = out->dims()[axis]; + for (size_t i = 0; i < param.x.size(); ++i) { + auto bottom_data = param.x[i]->data(); + const int64_t bottom_concat_axis = param.x[i]->dims()[axis]; + for (int n = 0; n < num_concat; ++n) { + std::memcpy( + output_data + + (n * top_concat_axis + offset_concat_axis) * concat_input_size, + bottom_data + n * bottom_concat_axis * concat_input_size, + (bottom_concat_axis * concat_input_size) * sizeof(T)); } + offset_concat_axis += bottom_concat_axis; } } - virtual ~ConcatCompute() = default; }; diff --git a/lite/kernels/x86/concat_compute_test.cc b/lite/kernels/x86/concat_compute_test.cc index 5a08903f82..468e942275 100644 --- a/lite/kernels/x86/concat_compute_test.cc +++ b/lite/kernels/x86/concat_compute_test.cc @@ -14,7 +14,6 @@ #include "lite/kernels/x86/concat_compute.h" #include -#include #include #include "lite/core/op_registry.h" @@ -68,11 +67,11 @@ TEST(concat_x86, run_test) { concat.SetParam(param); concat.Run(); - std::cout << "output: "; + std::vector ref_results = { + 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2}; for (int i = 0; i < out.dims().production(); i++) { - std::cout << out_data[i] << " "; + EXPECT_NEAR(out_data[i], ref_results[i], 1e-3); } - std::cout << std::endl; } } // namespace x86 diff --git a/lite/kernels/x86/mul_compute.cc b/lite/kernels/x86/mul_compute.cc index d021a73532..3e5fccfc3a 100644 --- a/lite/kernels/x86/mul_compute.cc +++ b/lite/kernels/x86/mul_compute.cc @@ -25,6 +25,7 @@ REGISTER_LITE_KERNEL(mul, .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kX86))}) .Finalize(); +#ifdef LITE_WITH_TRAIN REGISTER_LITE_KERNEL(mul_grad, kX86, kFloat, @@ -40,3 +41,4 @@ REGISTER_LITE_KERNEL(mul_grad, .BindOutput(paddle::framework::GradVarName("Y"), {LiteType::GetTensorTy(TARGET(kX86))}) .Finalize(); +#endif diff --git a/lite/kernels/x86/mul_compute.h b/lite/kernels/x86/mul_compute.h index ae47d4a59e..e204fc81f2 100644 --- a/lite/kernels/x86/mul_compute.h +++ b/lite/kernels/x86/mul_compute.h @@ -13,17 +13,26 @@ // limitations under the License. #pragma once +#include "lite/backends/x86/math/blas.h" #include "lite/core/kernel.h" #include "lite/core/op_registry.h" #include "lite/core/types.h" -#include "paddle/fluid/operators/math/blas.h" - namespace paddle { namespace lite { namespace kernels { namespace x86 { -using Tensor = framework::Tensor; +// using Tensor = framework::Tensor; +inline lite::Tensor ReshapeToMatrix(const lite::Tensor& src, int num_col_dims) { + int rank = src.dims().size(); + if (rank == 2) { + return src; + } + lite::Tensor res; + res.ShareDataWith(src); + res.Resize(src.dims().Flatten2D(num_col_dims)); + return res; +} template class MulCompute : public KernelLite { @@ -33,36 +42,35 @@ class MulCompute : public KernelLite { void Run() override { auto& context = ctx_->As(); auto& param = *param_.get_mutable(); - CHECK(context.x86_device_context()); + // CHECK(context.x86_device_context()); - param.output->template mutable_data(); + auto* z = param.output; - auto* x = ¶m.x->raw_tensor(); - auto* y = ¶m.y->raw_tensor(); + auto* x = param.x; + auto* y = param.y; Tensor x_matrix, y_matrix; if (x->dims().size() > 2) { - x_matrix = framework::ReshapeToMatrix(*x, param.x_num_col_dims); + x_matrix = ReshapeToMatrix(*x, param.x_num_col_dims); } else { x_matrix = *x; } if (y->dims().size() > 2) { - y_matrix = framework::ReshapeToMatrix(*y, param.y_num_col_dims); + y_matrix = ReshapeToMatrix(*y, param.y_num_col_dims); } else { y_matrix = *y; } - auto* z = ¶m.output->raw_tensor(); + z->mutable_data(); auto z_dim = z->dims(); if (z_dim.size() != 2) { z->Resize({x_matrix.dims()[0], y_matrix.dims()[1]}); } - auto blas = paddle::operators::math::GetBlas( - *context.x86_device_context()); + auto blas = lite::x86::math::GetBlas(context); blas.MatMul(x_matrix, y_matrix, z); if (z_dim.size() != 2) { @@ -73,6 +81,7 @@ class MulCompute : public KernelLite { virtual ~MulCompute() = default; }; +#ifdef LITE_WITH_TRAIN template class MulGradCompute : public KernelLite { public: @@ -142,6 +151,7 @@ class MulGradCompute : public KernelLite { virtual ~MulGradCompute() = default; }; +#endif } // namespace x86 } // namespace kernels diff --git a/lite/kernels/x86/mul_compute_test.cc b/lite/kernels/x86/mul_compute_test.cc index 6737b75041..32d82cbb77 100644 --- a/lite/kernels/x86/mul_compute_test.cc +++ b/lite/kernels/x86/mul_compute_test.cc @@ -19,7 +19,6 @@ #include #include #include "lite/core/op_registry.h" - namespace paddle { namespace lite { namespace kernels { @@ -33,7 +32,7 @@ TEST(mul_x86, retrive_op) { } TEST(mul_x86, init) { - MulCompute mul; + lite::kernels::x86::MulCompute mul; ASSERT_EQ(mul.precision(), PRECISION(kFloat)); ASSERT_EQ(mul.target(), TARGET(kX86)); } @@ -72,9 +71,10 @@ TEST(mul_x86, run_test) { mul.SetParam(param); mul.Run(); - LOG(INFO) << "output: "; + std::vector ref_result = {20, 23, 26, 29}; + for (int i = 0; i < out.dims().production(); i++) { - LOG(INFO) << out_data[i]; + EXPECT_NEAR(out_data[i], ref_result[i], 1e-3); } } diff --git a/lite/kernels/x86/sequence_pool_compute.cc b/lite/kernels/x86/sequence_pool_compute.cc new file mode 100644 index 0000000000..f158392556 --- /dev/null +++ b/lite/kernels/x86/sequence_pool_compute.cc @@ -0,0 +1,25 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "lite/kernels/x86/sequence_pool_compute.h" + +REGISTER_LITE_KERNEL(sequence_pool, + kX86, + kFloat, + kNCHW, + paddle::lite::kernels::x86::SequencePoolCompute, + def) + .BindInput("X", {LiteType::GetTensorTy(TARGET(kX86))}) + .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kX86))}) + .Finalize(); diff --git a/lite/kernels/x86/sequence_pool_compute.h b/lite/kernels/x86/sequence_pool_compute.h new file mode 100644 index 0000000000..329a76658d --- /dev/null +++ b/lite/kernels/x86/sequence_pool_compute.h @@ -0,0 +1,59 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +#pragma once + +#include +#include "lite/backends/x86/math/math_function.h" +#include "lite/backends/x86/math/sequence_pooling.h" +#include "lite/core/kernel.h" +#include "lite/core/op_registry.h" + +namespace paddle { +namespace lite { +namespace kernels { +namespace x86 { + +template +class SequencePoolCompute : public KernelLite { + public: + using param_t = operators::SequencePoolParam; + + void Run() override { + auto& param = *param_.get_mutable(); + auto& context = ctx_->As(); + auto* out = param.Out; + auto dims = param.X->dims(); + auto lod = param.X->lod(); + CHECK_EQ(lod.size(), 1UL); + CHECK_GE(dims[0], static_cast(lod[0].size() - 1)); + + dims[0] = lod[0].size() - 1; + out->Resize({dims}); + out->mutable_data(); + lite::Tensor* index = nullptr; + + const bool is_test = true; + float pad_value = 0.0; + + lite::x86::math::SequencePoolFunctor pool; + pool(context, param.pool_type, pad_value, *param.X, out, is_test, index); + } + + virtual ~SequencePoolCompute() = default; +}; + +} // namespace x86 +} // namespace kernels +} // namespace lite +} // namespace paddle diff --git a/lite/kernels/x86/sequence_pool_compute_test.cc b/lite/kernels/x86/sequence_pool_compute_test.cc new file mode 100644 index 0000000000..93cc122f7a --- /dev/null +++ b/lite/kernels/x86/sequence_pool_compute_test.cc @@ -0,0 +1,88 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "lite/kernels/x86/sequence_pool_compute.h" +#include +#include +#include +#include +#include "lite/core/op_registry.h" +namespace paddle { +namespace lite { +namespace kernels { +namespace x86 { + +TEST(sequence_pool_x86, retrive_op) { + auto sequence_pool = + KernelRegistry::Global().Create( + "sequence_pool"); + ASSERT_FALSE(sequence_pool.empty()); + ASSERT_TRUE(sequence_pool.front()); +} + +TEST(sequence_pool_x86, init) { + SequencePoolCompute sequence_pool; + ASSERT_EQ(sequence_pool.precision(), PRECISION(kFloat)); + ASSERT_EQ(sequence_pool.target(), TARGET(kX86)); +} + +TEST(sequence_pool_x86, run_test) { + lite::Tensor x, out; + lite::LoD lod; + lod.push_back(std::vector{0, 10}); + + x.set_lod(lod); + const size_t second_dim = 8u; + std::vector input_shape{static_cast(lod[0].back()), + static_cast(second_dim)}; + lite::DDim in_dims(input_shape); + x.Resize(in_dims); + + const size_t out_first_dim = lod[0].size() - 1; + std::vector output_shape{static_cast(out_first_dim), + static_cast(second_dim)}; + lite::DDim out_dims(output_shape); + out.Resize(out_dims); + + auto x_data = x.mutable_data(); + auto out_data = out.mutable_data(); + + for (int64_t i = 0; i < x.dims().production(); i++) { + x_data[i] = 1.1f * i; + } + + SequencePoolCompute sequence_pool; + operators::SequencePoolParam param; + param.X = &x; + param.Out = &out; + + std::unique_ptr ctx(new KernelContext); + ctx->As(); + sequence_pool.SetContext(std::move(ctx)); + sequence_pool.SetParam(param); + sequence_pool.Run(); + + std::vector ref_results = { + 39.6, 40.7, 41.8, 42.9, 44, 45.1, 46.2, 47.3}; + for (int i = 0; i < out.dims().production(); i++) { + EXPECT_NEAR(out_data[i], ref_results[i], 1e-3); + } +} + +} // namespace x86 +} // namespace kernels +} // namespace lite +} // namespace paddle + +USE_LITE_KERNEL(sequence_pool, kX86, kFloat, kNCHW, def); diff --git a/lite/kernels/x86/shape_compute.cc b/lite/kernels/x86/shape_compute.cc new file mode 100644 index 0000000000..565379eb06 --- /dev/null +++ b/lite/kernels/x86/shape_compute.cc @@ -0,0 +1,25 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "lite/kernels/x86/shape_compute.h" + +REGISTER_LITE_KERNEL(shape, + kX86, + kFloat, + kNCHW, + paddle::lite::kernels::x86::ShapeCompute, + def) + .BindInput("X", {LiteType::GetTensorTy(TARGET(kX86))}) + .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kX86))}) + .Finalize(); diff --git a/lite/kernels/x86/shape_compute.h b/lite/kernels/x86/shape_compute.h new file mode 100644 index 0000000000..ee3678a7f1 --- /dev/null +++ b/lite/kernels/x86/shape_compute.h @@ -0,0 +1,45 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +#pragma once + +#include +#include "lite/core/kernel.h" +#include "lite/core/op_registry.h" +namespace paddle { +namespace lite { +namespace kernels { +namespace x86 { + +template +class ShapeCompute : public KernelLite { + public: + using param_t = operators::ShapeParam; + + void Run() override { + auto& param = *param_.get_mutable(); + // auto& context = context_->As(); + auto out_data = param.Out->mutable_data(); + auto in_dims = param.X->dims(); + for (int i = 0; i < in_dims.size(); ++i) { + out_data[i] = in_dims[i]; + } + } + + virtual ~ShapeCompute() = default; +}; + +} // namespace x86 +} // namespace kernels +} // namespace lite +} // namespace paddle diff --git a/lite/kernels/x86/shape_compute_test.cc b/lite/kernels/x86/shape_compute_test.cc new file mode 100644 index 0000000000..88bd98f33f --- /dev/null +++ b/lite/kernels/x86/shape_compute_test.cc @@ -0,0 +1,73 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "lite/kernels/x86/shape_compute.h" +#include +#include +#include "lite/core/op_registry.h" + +namespace paddle { +namespace lite { +namespace kernels { +namespace x86 { + +TEST(shape_x86, retrive_op) { + auto shape = + KernelRegistry::Global().Create("shape"); + ASSERT_FALSE(shape.empty()); + ASSERT_TRUE(shape.front()); +} + +TEST(shape_x86, init) { + ShapeCompute shape; + ASSERT_EQ(shape.precision(), PRECISION(kFloat)); + ASSERT_EQ(shape.target(), TARGET(kX86)); +} + +TEST(shape_x86, run_test) { + lite::Tensor x, out; + constexpr int batch_size = 1; + std::vector x_shape{batch_size, 1, 3, 3}; + x.Resize(lite::DDim(x_shape)); + + std::vector out_shape{4}; + out.Resize(lite::DDim(out_shape)); + + auto x_data = x.mutable_data(); + auto out_data = out.mutable_data(); + + for (int64_t i = 0; i < x.dims().production(); i++) { + x_data[i] = 1; + } + + ShapeCompute shape; + operators::ShapeParam param; + param.X = &x; + param.Out = &out; + + shape.SetParam(param); + shape.Run(); + + std::vector ref_results = {1, 1, 3, 3}; + for (int i = 0; i < out.dims().production(); i++) { + EXPECT_NEAR(out_data[i], ref_results[i], 1e-3); + } +} + +} // namespace x86 +} // namespace kernels +} // namespace lite +} // namespace paddle + +USE_LITE_KERNEL(shape, kX86, kFloat, kNCHW, def); diff --git a/lite/operators/CMakeLists.txt b/lite/operators/CMakeLists.txt index 12f96121db..dd7b751cb9 100644 --- a/lite/operators/CMakeLists.txt +++ b/lite/operators/CMakeLists.txt @@ -88,7 +88,7 @@ add_operator(greater_than extra SRCS compare_op.cc DEPS ${op_DEPS}) add_operator(greater_equal extra SRCS compare_op.cc DEPS ${op_DEPS}) add_operator(read_from_array_op extra SRCS read_from_array_op.cc DEPS ${op_DEPS}) add_operator(beam_search_op extra SRCS beam_search_op.cc DEPS ${op_DEPS}) -add_operator(sequence_pool_op_lite extra SRCS sequence_pool_op.cc DEPS ${op_DEPS}) +add_operator(sequence_pool extra SRCS sequence_pool_op.cc DEPS ${op_DEPS}) add_operator(lod_reset_op extra SRCS lod_reset_op.cc DEPS ${op_DEPS}) add_operator(is_empty extra SRCS is_empty_op.cc DEPS ${op_DEPS}) add_operator(slice_op_lite extra SRCS slice_op.cc DEPS ${op_DEPS}) diff --git a/lite/operators/op_params.h b/lite/operators/op_params.h index 6885b05946..ae0f796a0f 100644 --- a/lite/operators/op_params.h +++ b/lite/operators/op_params.h @@ -666,7 +666,10 @@ struct BeamSearchParam { struct SequencePoolParam { const lite::Tensor* X{}; lite::Tensor* Out{}; - std::string pool_type; + std::string pool_type{"AVERAGE"}; +#ifdef LITE_WITH_X86 + float pad_value{0.0}; +#endif }; struct SequenceExpandParam { diff --git a/lite/tools/ci_build.sh b/lite/tools/ci_build.sh index eb91e15a6f..c04bbb7c62 100755 --- a/lite/tools/ci_build.sh +++ b/lite/tools/ci_build.sh @@ -151,7 +151,7 @@ function build_opencl { # This method is only called in CI. function cmake_x86_for_CI { prepare_workspace # fake an empty __generated_code__.cc to pass cmake. - cmake .. -DWITH_GPU=OFF -DWITH_MKLDNN=OFF -DLITE_WITH_X86=ON ${common_flags} -DLITE_WITH_PROFILE=ON -DWITH_MKL=OFF \ + cmake .. -DWITH_GPU=OFF -DWITH_MKLDNN=OFF -DLITE_WITH_X86=ON ${common_flags} -DLITE_WITH_PROFILE=ON -DWITH_MKL=ON \ -DLITE_BUILD_EXTRA=ON \ # Compile and execute the gen_code related test, so it will generate some code, and make the compilation reasonable. @@ -219,7 +219,7 @@ function test_server { function build_test_server { mkdir -p ./build cd ./build - export LD_LIBRARY_PATH="$LD_LIBRARY_PATH:/paddle/build/third_party/install/mklml/lib" + export LD_LIBRARY_PATH="$LD_LIBRARY_PATH:$PWD/third_party/install/mklml/lib" cmake_x86_for_CI build diff --git a/lite/tools/cmake_tools/parse_kernel_registry.py b/lite/tools/cmake_tools/parse_kernel_registry.py index a0a123898b..623d58190a 100644 --- a/lite/tools/cmake_tools/parse_kernel_registry.py +++ b/lite/tools/cmake_tools/parse_kernel_registry.py @@ -71,6 +71,7 @@ for line in lines: alias = fields[-1] key = "USE_LITE_KERNEL(%s, %s, %s, %s, %s);" % ( op, target, precision, layout, alias) + if "_grad" in key: continue out_lines.append(key) -- GitLab