diff --git a/lite/CMakeLists.txt b/lite/CMakeLists.txt index 21ed772fa8c9d5d7e99ecd0b8fa25ec5a8fb9d79..fa55e27255fcd82a72ac1489741e9e69db1fe933 100644 --- a/lite/CMakeLists.txt +++ b/lite/CMakeLists.txt @@ -45,6 +45,7 @@ if (WITH_TESTING) lite_download_and_uncompress(${LITE_MODEL_DIR} ${LITE_URL} "mobilenet_v2_relu.tar.gz") lite_download_and_uncompress(${LITE_MODEL_DIR} ${LITE_URL} "resnet50.tar.gz") lite_download_and_uncompress(${LITE_MODEL_DIR} ${LITE_URL} "inception_v4_simple.tar.gz") + lite_download_and_uncompress(${LITE_MODEL_DIR} ${LITE_URL} "step_rnn.tar.gz") endif() endif() diff --git a/lite/api/CMakeLists.txt b/lite/api/CMakeLists.txt index 3336b020a0efb0352ed35cc838dd26b6f6510dc4..bf930ed0e20bf0c1a2e313fd33ad7d87b734c42c 100644 --- a/lite/api/CMakeLists.txt +++ b/lite/api/CMakeLists.txt @@ -143,6 +143,11 @@ if(WITH_TESTING) ${ops} ${host_kernels} ${x86_kernels} ARGS --model_dir=${LITE_MODEL_DIR}/resnet50) add_dependencies(test_resnet50_lite_x86 extern_lite_download_resnet50_tar_gz) + lite_cc_test(test_step_rnn_lite_x86 SRCS test_step_rnn_lite_x86.cc + DEPS mir_passes lite_api_test_helper paddle_api_full paddle_api_light gflags utils + ${ops} ${host_kernels} ${x86_kernels} + ARGS --model_dir=${LITE_MODEL_DIR}/step_rnn) + add_dependencies(test_step_rnn_lite_x86 extern_lite_download_step_rnn_tar_gz) endif() endif() diff --git a/lite/api/cxx_api.h b/lite/api/cxx_api.h index e1d34172ba578824228e6369a8e37d60972336e9..502ce812e1f4a7f520e89e6eaff020c5853f5308 100644 --- a/lite/api/cxx_api.h +++ b/lite/api/cxx_api.h @@ -15,6 +15,7 @@ #pragma once #include #include +#include //NOLINT #include #include #include @@ -126,6 +127,8 @@ class CxxPaddleApiImpl : public lite_api::PaddlePredictor { void Run() override; + std::shared_ptr Clone() override; + std::string GetVersion() const override; // get inputs names and get outputs names @@ -146,6 +149,8 @@ class CxxPaddleApiImpl : public lite_api::PaddlePredictor { private: Predictor raw_predictor_; + lite_api::CxxConfig config_; + std::mutex mutex_; }; /* diff --git a/lite/api/cxx_api_impl.cc b/lite/api/cxx_api_impl.cc index db225fb78497d5c8f31f90e59c755232adc53222..1a4f02fd9ed8b7a10b7878ccedf88b3f266124cd 100644 --- a/lite/api/cxx_api_impl.cc +++ b/lite/api/cxx_api_impl.cc @@ -13,6 +13,8 @@ // limitations under the License. #include "lite/api/cxx_api.h" +#include +#include //NOLINT #include #include "lite/api/paddle_api.h" #include "lite/core/device_info.h" @@ -22,6 +24,7 @@ namespace paddle { namespace lite { void CxxPaddleApiImpl::Init(const lite_api::CxxConfig &config) { + config_ = config; #ifdef LITE_WITH_CUDA Env::Init(); #endif @@ -50,6 +53,13 @@ std::vector CxxPaddleApiImpl::GetOutputNames() { void CxxPaddleApiImpl::Run() { raw_predictor_.Run(); } +std::shared_ptr CxxPaddleApiImpl::Clone() { + std::lock_guard lock(mutex_); + auto predictor = std::make_shared(); + predictor->Init(config_); + return predictor; +} + std::string CxxPaddleApiImpl::GetVersion() const { return version(); } std::unique_ptr CxxPaddleApiImpl::GetTensor( diff --git a/lite/api/light_api.h b/lite/api/light_api.h index 13ef72d92cfd83954188516eb297ad23b31994df..3781bc4d674db5d2e8794edaf33f00627b9977bb 100644 --- a/lite/api/light_api.h +++ b/lite/api/light_api.h @@ -96,6 +96,8 @@ class LightPredictorImpl : public lite_api::PaddlePredictor { void Run() override; + std::shared_ptr Clone() override; + std::string GetVersion() const override; std::vector GetInputNames() override; std::vector GetOutputNames() override; diff --git a/lite/api/light_api_impl.cc b/lite/api/light_api_impl.cc index 90954187d2dbd211867232796dbe4ec556f9ba0c..59cf0af708d1186b30a0f582ba82a2edd3a234d5 100644 --- a/lite/api/light_api_impl.cc +++ b/lite/api/light_api_impl.cc @@ -44,6 +44,10 @@ std::unique_ptr LightPredictorImpl::GetOutput( void LightPredictorImpl::Run() { raw_predictor_->Run(); } +std::shared_ptr LightPredictorImpl::Clone() { + LOG(FATAL) << "The Clone API is not supported in LigthPredictor"; +} + std::string LightPredictorImpl::GetVersion() const { return lite::version(); } std::unique_ptr LightPredictorImpl::GetTensor( diff --git a/lite/api/paddle_api.cc b/lite/api/paddle_api.cc index db545de958743b89f99efcde131e1371fdb15409..f148096bb69a3a249521bcb847d5beae3f8297f9 100644 --- a/lite/api/paddle_api.cc +++ b/lite/api/paddle_api.cc @@ -46,6 +46,10 @@ template <> const int8_t *Tensor::data() const { return ctensor(raw_tensor_)->data(); } +template <> +const int64_t *Tensor::data() const { + return ctensor(raw_tensor_)->data(); +} template <> const int32_t *Tensor::data() const { @@ -64,6 +68,10 @@ template <> int8_t *Tensor::mutable_data(TargetType type) const { return tensor(raw_tensor_)->mutable_data(type); } +template <> +int64_t *Tensor::mutable_data(TargetType type) const { + return tensor(raw_tensor_)->mutable_data(type); +} template void Tensor::CopyFromCpu(const T *src_data) { diff --git a/lite/api/paddle_api.h b/lite/api/paddle_api.h index 3e911b62f785b2102685de94377804cf250f57e9..9b48ab4ab413ed57d73651137c6108cbb9a4e780 100644 --- a/lite/api/paddle_api.h +++ b/lite/api/paddle_api.h @@ -78,6 +78,7 @@ class LITE_API PaddlePredictor { virtual std::unique_ptr GetOutput(int i) const = 0; virtual void Run() = 0; + virtual std::shared_ptr Clone() = 0; virtual std::string GetVersion() const = 0; diff --git a/lite/api/paddle_place.h b/lite/api/paddle_place.h index 259887e2fb50cdee75f67d07db21ae73395fe700..07284be095c05e5dfa069b0973d5982cf1f07c8a 100644 --- a/lite/api/paddle_place.h +++ b/lite/api/paddle_place.h @@ -103,6 +103,8 @@ static size_t PrecisionTypeLength(PrecisionType type) { return 1; case PrecisionType::kInt32: return 4; + case PrecisionType::kInt64: + return 8; case PrecisionType::kFP16: return 2; default: diff --git a/lite/api/test_step_rnn_lite_x86.cc b/lite/api/test_step_rnn_lite_x86.cc new file mode 100644 index 0000000000000000000000000000000000000000..c483373dc745f6520d51ece3936448ada71990d3 --- /dev/null +++ b/lite/api/test_step_rnn_lite_x86.cc @@ -0,0 +1,111 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include +#include "lite/api/lite_api_test_helper.h" +#include "lite/api/paddle_api.h" +#include "lite/api/paddle_use_kernels.h" +#include "lite/api/paddle_use_ops.h" +#include "lite/api/paddle_use_passes.h" +#include "lite/api/test_helper.h" +#include "lite/utils/cp_logging.h" + +namespace paddle { +namespace lite { + +TEST(Step_rnn, test_step_rnn_lite_x86) { + std::string model_dir = FLAGS_model_dir; + lite_api::CxxConfig config; + config.set_model_dir(model_dir); + config.set_valid_places({lite_api::Place{TARGET(kX86), PRECISION(kInt64)}, + lite_api::Place{TARGET(kX86), PRECISION(kFloat)}, + lite_api::Place{TARGET(kHost), PRECISION(kFloat)}}); + auto predictor = lite_api::CreatePaddlePredictor(config); + + std::vector target_names = {"item_type_id", + "mthid_id", + "source_id_id", + "layout_id", + "mark_id", + "category_id", + "subcategory_id", + "score_segment_id", + "item_attention_id", + "queue_num_id", + "micro_video_id", + "vertical_type_id"}; + + for (int i = 0; i < target_names.size(); ++i) { + auto input_tensor = predictor->GetInput(i); + int size = 0; + if (i == 6 || i == 8) { + input_tensor->Resize(std::vector{5, 1}); + input_tensor->SetLoD({{0, 5}}); + size = 5; + } else { + input_tensor->Resize(std::vector{1, 1}); + input_tensor->SetLoD({{0, 1}}); + size = 1; + } + auto* data = input_tensor->mutable_data(); + for (int i = 0; i < size; i++) data[i] = 1; + } + + for (int i = 0; i < FLAGS_warmup; ++i) { + predictor->Run(); + } + + auto start = GetCurrentUS(); + for (int i = 0; i < FLAGS_repeats; ++i) { + predictor->Run(); + } + + // LOG(INFO) << "================== Speed Report ==================="; + LOG(INFO) << ", warmup: " << FLAGS_warmup << ", repeats: " << FLAGS_repeats + << ", spend " << (GetCurrentUS() - start) / FLAGS_repeats / 1000.0 + << " ms in average."; + + std::vector> results; + // i = 1 + results.emplace_back(std::vector({0.5030127, 0.496987})); + auto out = predictor->GetOutput(0); + + std::vector out_shape = out->shape(); + + for (int i = 0; i < results.size(); ++i) { + for (int j = 0; j < results[i].size(); ++j) { + EXPECT_NEAR( + out->data()[j + (out_shape[1] * i)], results[i][j], 1e-6); + } + } +} + +} // namespace lite +} // namespace paddle diff --git a/lite/core/mir/fusion/quant_dequant_fuse_pass.cc b/lite/core/mir/fusion/quant_dequant_fuse_pass.cc index 5498c28922836cc16c4b765df03f68a4a7716e05..8ec50b8112b6b853e83abf5c491163fa4475f2f4 100644 --- a/lite/core/mir/fusion/quant_dequant_fuse_pass.cc +++ b/lite/core/mir/fusion/quant_dequant_fuse_pass.cc @@ -54,4 +54,5 @@ void QuantDequantFusePass::Apply(const std::unique_ptr& graph) { REGISTER_MIR_PASS(lite_quant_dequant_fuse_pass, paddle::lite::mir::QuantDequantFusePass) - .BindTargets({TARGET(kAny)}); + .BindTargets({TARGET(kAny)}) + .BindKernel("calib"); diff --git a/lite/core/op_registry.cc b/lite/core/op_registry.cc index ad974a781c7c899428015907a4166d8d0c351c76..3b8b350ad82f2cc1ce296b1ad74a6e322abec8ff 100644 --- a/lite/core/op_registry.cc +++ b/lite/core/op_registry.cc @@ -54,6 +54,8 @@ std::list> KernelRegistry::Create( CREATE_KERNEL1(target__, kFP16); \ case PRECISION(kAny): \ CREATE_KERNEL1(target__, kAny); \ + case PRECISION(kInt64): \ + CREATE_KERNEL1(target__, kInt64); \ default: \ CHECK(false) << "not supported kernel precision " \ << PrecisionToStr(precision); \ @@ -126,6 +128,7 @@ KernelRegistry::KernelRegistry() INIT_FOR(kX86, kFloat, kNCHW); INIT_FOR(kX86, kAny, kNCHW); INIT_FOR(kX86, kAny, kAny); + INIT_FOR(kX86, kInt64, kNCHW); INIT_FOR(kARM, kFloat, kNCHW); INIT_FOR(kARM, kInt8, kNCHW); diff --git a/lite/fluid/eigen.h b/lite/fluid/eigen.h index f5d5e4b5e516315b16369be2e2dd9c46281fc3d0..eac5332b53c857b05aacbfa95ee2e4b9fcd98a93 100644 --- a/lite/fluid/eigen.h +++ b/lite/fluid/eigen.h @@ -32,7 +32,7 @@ struct EigenDim { static Type From(const lite::DDim& dims) { PADDLE_ENFORCE(dims.size() == D, "D must match DDim::size"); Type ret; - for (int64_t d = 0; d < dims.size(); d++) { + for (size_t d = 0; d < dims.size(); d++) { ret[d] = dims[d]; } return ret; @@ -118,7 +118,9 @@ struct EigenScalar { using ConstType = Eigen::TensorMap< Eigen::TensorFixedSize, MajorType, IndexType>>; - static Type From(Tensor& tensor) { return Type(tensor.data()); } // NOLINT + static Type From(Tensor* tensor) { + return Type(const_cast(tensor->data())); + } // NOLINT static ConstType From(const Tensor& tensor) { return ConstType(tensor.data()); diff --git a/lite/kernels/x86/CMakeLists.txt b/lite/kernels/x86/CMakeLists.txt index 6d47c880c8daf1ec8981dfb4083324b79c25cec1..da955e4fd5902373cd881f85a8bc715eef7cec94 100644 --- a/lite/kernels/x86/CMakeLists.txt +++ b/lite/kernels/x86/CMakeLists.txt @@ -36,6 +36,9 @@ add_kernel(sequence_pool_compute_x86 X86 basic SRCS sequence_pool_compute.cc DEP add_kernel(softmax_compute_x86 X86 basic SRCS softmax_compute.cc DEPS ${lite_kernel_deps} softmax) add_kernel(elementwise_compute_x86 X86 basic SRCS elementwise_compute.cc DEPS ${lite_kernel_deps}) add_kernel(batch_norm_compute_x86 X86 basic SRCS batch_norm_compute.cc DEPS ${lite_kernel_deps}) +add_kernel(reduce_sum_compute_x86 X86 basic SRCS reduce_compute.cc DEPS ${lite_kernel_deps}) +add_kernel(lookup_table_compute_x86 X86 basic SRCS lookup_table_compute.cc DEPS ${lite_kernel_deps}) +add_kernel(sequence_reshape_compute_x86 X86 basic SRCS sequence_reshape_compute.cc DEPS ${lite_kernel_deps}) if(NOT LITE_WITH_X86) return() diff --git a/lite/kernels/x86/activation_compute.cc b/lite/kernels/x86/activation_compute.cc index b4a053419c5c6f04b4b053d7bf902a57e9562518..f2f911dd7d037a3f4e0f28592cff07383c8a49b6 100644 --- a/lite/kernels/x86/activation_compute.cc +++ b/lite/kernels/x86/activation_compute.cc @@ -57,3 +57,13 @@ REGISTER_LITE_KERNEL(gelu, .BindInput("X", {LiteType::GetTensorTy(TARGET(kX86))}) .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kX86))}) .Finalize(); + +REGISTER_LITE_KERNEL(softsign, + kX86, + kFloat, + kNCHW, + paddle::lite::kernels::x86::SoftsignCompute, + def) + .BindInput("X", {LiteType::GetTensorTy(TARGET(kX86))}) + .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kX86))}) + .Finalize(); diff --git a/lite/kernels/x86/activation_compute.h b/lite/kernels/x86/activation_compute.h index 482684b0672c1ed7f0d571f852e134e92ddcaafa..14d0ffe000311c87dac513a65f731e9654042db2 100644 --- a/lite/kernels/x86/activation_compute.h +++ b/lite/kernels/x86/activation_compute.h @@ -187,6 +187,31 @@ class GeluCompute : public KernelLite { virtual ~GeluCompute() = default; }; +// softsign(x) = x / (1 + |x|) +template +struct SoftsignFunctor : public BaseActivationFunctor { + template + void operator()(Device d, X x, Out out) { + out.device(d) = x / (static_cast(1) + x.abs()); + } +}; + +template +class SoftsignCompute : public KernelLite { + public: + using param_t = operators::ActivationParam; + + void Run() override { + // auto& context = ctx_->As(); + auto& param = *param_.get_mutable(); + param.Out->template mutable_data(); + + Activate>(param.X, param.Out); + } + + virtual ~SoftsignCompute() = default; +}; + } // namespace x86 } // namespace kernels } // namespace lite diff --git a/lite/kernels/x86/concat_compute.h b/lite/kernels/x86/concat_compute.h index 674f06461f211d18925d31e3a8ce6ec62d8b88df..3fd1e9f233d2022a1fa0735bd1bc849923e64745 100644 --- a/lite/kernels/x86/concat_compute.h +++ b/lite/kernels/x86/concat_compute.h @@ -42,7 +42,10 @@ class ConcatCompute : public KernelLite { int64_t axis = static_cast(param.axis); auto x_dims = param.x[0]->dims(); auto out = param.output; - if (param.x.size() == 1) return; + if (param.x.size() == 1) { + param.output->ShareDataWith(*param.x[0]); + return; + } auto output_data = param.output->template mutable_data(); int offset_concat_axis = 0; diff --git a/lite/kernels/x86/lookup_table_compute.cc b/lite/kernels/x86/lookup_table_compute.cc new file mode 100644 index 0000000000000000000000000000000000000000..364593251e17453011bad5b2c1057fc25d54d7c8 --- /dev/null +++ b/lite/kernels/x86/lookup_table_compute.cc @@ -0,0 +1,34 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "lite/kernels/x86/lookup_table_compute.h" + +// REGISTER_LITE_KERNEL(lookup_table, kX86, kFloat, kNCHW, +// paddle::lite::kernels::x86::LookupTableCompute, +// def) +// .BindInput("W", {LiteType::GetTensorTy(TARGET(kX86))}) +// .BindInput("Ids", {LiteType::GetTensorTy(TARGET(kX86))}) +// .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kX86))}) +// .Finalize(); +//, +REGISTER_LITE_KERNEL(lookup_table, + kX86, + kInt64, + kNCHW, + paddle::lite::kernels::x86::LookupTableCompute, + def) + .BindInput("W", {LiteType::GetTensorTy(TARGET(kX86))}) + .BindInput("Ids", {LiteType::GetTensorTy(TARGET(kX86), PRECISION(kInt64))}) + .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kX86))}) + .Finalize(); diff --git a/lite/kernels/x86/lookup_table_compute.h b/lite/kernels/x86/lookup_table_compute.h new file mode 100644 index 0000000000000000000000000000000000000000..e0d7752ca77c810700f57722c4186b4e02d6411f --- /dev/null +++ b/lite/kernels/x86/lookup_table_compute.h @@ -0,0 +1,66 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +#pragma once + +#include +#include "lite/core/kernel.h" +#include "lite/core/op_registry.h" +#include "lite/fluid/eigen.h" + +namespace paddle { +namespace lite { +namespace kernels { +namespace x86 { + +template +class LookupTableCompute : public KernelLite { + public: + using param_t = operators::LookupTableParam; + + void Run() override { + auto ¶m = *param_.get_mutable(); + // auto& context = context_->As(); + auto *ids_t = param.Ids; + auto *output_t = param.Out; + int64_t padding_idx = param.padding_idx; + auto *ids = ids_t->data(); + int64_t ids_numel = ids_t->dims().production(); + + auto *table_t = param.W; + int64_t row_number = table_t->dims()[0]; + int64_t row_width = table_t->dims()[1]; + + auto *table = table_t->data(); + auto *output = output_t->mutable_data(); + memset(output, 0, output_t->dims().production() * sizeof(float)); + for (int64_t i = 0; i < ids_numel; ++i) { + if (padding_idx != -1 && ids[i] == padding_idx) { + memset(output + i * row_width, 0, row_width * sizeof(float)); + } else { + CHECK_LT(ids[i], row_number); + CHECK_GE(ids[i], 0); + memcpy(output + i * row_width, + table + ids[i] * row_width, + row_width * sizeof(float)); + } + } + } + + virtual ~LookupTableCompute() = default; +}; + +} // namespace x86 +} // namespace kernels +} // namespace lite +} // namespace paddle diff --git a/lite/kernels/x86/reduce_compute.cc b/lite/kernels/x86/reduce_compute.cc new file mode 100644 index 0000000000000000000000000000000000000000..f95f4cfb881fef329ea940ca8b9fa6b4fd6ff7b6 --- /dev/null +++ b/lite/kernels/x86/reduce_compute.cc @@ -0,0 +1,25 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "lite/kernels/x86/reduce_compute.h" + +REGISTER_LITE_KERNEL(reduce_sum, + kX86, + kFloat, + kNCHW, + paddle::lite::kernels::x86::ReduceSumCompute, + def) + .BindInput("X", {LiteType::GetTensorTy(TARGET(kX86))}) + .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kX86))}) + .Finalize(); diff --git a/lite/kernels/x86/reduce_compute.h b/lite/kernels/x86/reduce_compute.h new file mode 100644 index 0000000000000000000000000000000000000000..655f104ce65906f1904a7cf02d703069b0a7a2bf --- /dev/null +++ b/lite/kernels/x86/reduce_compute.h @@ -0,0 +1,83 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +#pragma once + +#include +#include "lite/core/kernel.h" +#include "lite/core/op_registry.h" +#include "lite/fluid/eigen.h" +#include "lite/kernels/x86/reduce_op_function.h" + +namespace paddle { +namespace lite { +namespace kernels { +namespace x86 { + +struct SumFunctor { + template + void operator()(X* x, Y* y, const Dim& dim) { + y->device(lite::fluid::EigenDeviceType()) = x->sum(dim); + } +}; + +#define HANDLE_DIM(NDIM, RDIM) \ + if (ndim == NDIM && rdim == RDIM) { \ + paddle::lite::kernels::x86:: \ + ReduceFunctor( \ + *input, output, dims, keep_dim); \ + } + +template +class ReduceSumCompute : public KernelLite { + public: + using param_t = operators::ReduceParam; + + void Run() override { + auto& param = *param_.get_mutable(); + // auto& context = ctx_->As(); + bool reduce_all = param.reduce_all; + auto* input = param.x; + auto* output = param.output; + param.output->mutable_data(); + + auto dims = param.dim; + bool keep_dim = param.keep_dim; + if (reduce_all) { + // Flatten and reduce 1-D tensor + auto x = lite::fluid::EigenVector::Flatten(*input); + auto out = lite::fluid::EigenScalar::From(output); + // auto& place = *platform::CPUDeviceContext().eigen_device(); + auto reduce_dim = Eigen::array({{0}}); + SumFunctor functor; + functor(&x, &out, reduce_dim); + } else { + int ndim = input->dims().size(); + int rdim = dims.size(); + HANDLE_DIM(4, 3); + HANDLE_DIM(4, 2); + HANDLE_DIM(4, 1); + HANDLE_DIM(3, 2); + HANDLE_DIM(3, 1); + HANDLE_DIM(2, 1); + HANDLE_DIM(1, 1); + } + } + + virtual ~ReduceSumCompute() = default; +}; + +} // namespace x86 +} // namespace kernels +} // namespace lite +} // namespace paddle diff --git a/lite/kernels/x86/reduce_op_function.h b/lite/kernels/x86/reduce_op_function.h new file mode 100644 index 0000000000000000000000000000000000000000..b3ddab64e4bf8dc72cec3b86398f42269c5a947c --- /dev/null +++ b/lite/kernels/x86/reduce_op_function.h @@ -0,0 +1,84 @@ +// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once +#include +#include "lite/core/op_registry.h" +#include "lite/fluid/eigen.h" + +namespace paddle { +namespace lite { +namespace kernels { +namespace x86 { + +template +using EigenTensor = lite::fluid::EigenTensor; +template +using EigenScalar = lite::fluid::EigenScalar; +template +using EigenVector = lite::fluid::EigenVector; + +template +// const lite::Context& context, +void ReduceFunctor(const lite::Tensor& input, + lite::Tensor* output, + const std::vector& dims, + bool keep_dim) { + auto x = EigenTensor::From(input); + auto x_rank = static_cast(x.dimensions().size()); + auto reduce_dim = Eigen::array(); + std::vector dims_ref = dims; + for (size_t i = 0; i < dims_ref.size(); ++i) { + if (dims_ref[i] < 0) dims_ref[i] = x_rank + dims_ref[i]; + reduce_dim[i] = dims_ref[i]; + } + // construct the squeezed output tensor + lite::DDim out_dims = output->dims(); + if (keep_dim && x_rank > 1) { + const int kDelFlag = -2; + auto dims_vector = out_dims.Vectorize(); + for (size_t i = 0; i < dims_ref.size(); ++i) { + dims_vector[dims_ref[i]] = kDelFlag; + } + dims_vector.erase(remove(dims_vector.begin(), dims_vector.end(), kDelFlag), + dims_vector.end()); + out_dims = lite::DDim(dims_vector); + } + // auto& place = *context.eigen_device(); + Functor functor; + + if (D == 1) { + auto out = EigenScalar::From(output); + functor(&x, &out, reduce_dim); + } else { + auto out = EigenTensor::From(*output, out_dims); + functor(&x, &out, reduce_dim); + } +} + +} // namespace x86 +} // namespace kernels +} // namespace lite +} // namespace paddle diff --git a/lite/kernels/x86/reshape_compute.cc b/lite/kernels/x86/reshape_compute.cc index abbb0f6af54106d06369c7a90cbdc2a554a42a3c..7afe4f6d8bc4740c00d3ed95fafc4e32f59b6d02 100644 --- a/lite/kernels/x86/reshape_compute.cc +++ b/lite/kernels/x86/reshape_compute.cc @@ -34,3 +34,14 @@ REGISTER_LITE_KERNEL(reshape2, .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kX86))}) .BindOutput("XShape", {LiteType::GetTensorTy(TARGET(kX86))}) .Finalize(); +REGISTER_LITE_KERNEL(reshape2, + kX86, + kInt64, + kNCHW, + paddle::lite::kernels::x86::Reshape2Compute, + def) + .BindInput("X", {LiteType::GetTensorTy(TARGET(kX86), PRECISION(kInt64))}) + .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kX86), PRECISION(kInt64))}) + .BindOutput("XShape", + {LiteType::GetTensorTy(TARGET(kX86), PRECISION(kInt64))}) + .Finalize(); diff --git a/lite/kernels/x86/sequence_reshape_compute.cc b/lite/kernels/x86/sequence_reshape_compute.cc new file mode 100644 index 0000000000000000000000000000000000000000..ccaeef27d7439b739b298f3b0756e2a2eddef2c1 --- /dev/null +++ b/lite/kernels/x86/sequence_reshape_compute.cc @@ -0,0 +1,26 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "lite/kernels/x86/sequence_reshape_compute.h" + +REGISTER_LITE_KERNEL( + sequence_reshape, + kX86, + kInt64, + kNCHW, + paddle::lite::kernels::x86::SequenceReshapeCompute, + def) + .BindInput("X", {LiteType::GetTensorTy(TARGET(kX86), PRECISION(kInt64))}) + .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kX86), PRECISION(kInt64))}) + .Finalize(); diff --git a/lite/kernels/x86/sequence_reshape_compute.h b/lite/kernels/x86/sequence_reshape_compute.h new file mode 100644 index 0000000000000000000000000000000000000000..68a573c2f674edcf0a09cccec730a8d7dbcea844 --- /dev/null +++ b/lite/kernels/x86/sequence_reshape_compute.h @@ -0,0 +1,81 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +#pragma once + +#include +#include "lite/core/kernel.h" +#include "lite/core/op_registry.h" +#include "lite/fluid/eigen.h" + +namespace paddle { +namespace lite { +namespace kernels { +namespace x86 { + +template +class SequenceReshapeCompute + : public KernelLite { + public: + using param_t = operators::SequenceReshapeParam; + + void Run() override { + auto& param = *param_.get_mutable(); + // auto& context = context_->As(); + auto* in = param.x; + auto* out = param.output; + int out_width = param.new_dim; + + auto in_dims = in->dims(); + int64_t in_width = in_dims[1]; + // LOG(INFO)<<"sequence_reshape in tensor:"<<*in; + auto& in_lod = in->lod(); + + CHECK_EQ(in_lod.size(), 1UL); + CHECK_EQ((uint64_t)in_dims[0], in_lod[0].back()); + + auto in_lod_l0 = in_lod[0]; + int seq_num = in_lod_l0.size() - 1; + + if (in_width == out_width) { + out->set_lod(in->lod()); + } else { + auto& out_lod = *out->mutable_lod(); + out_lod.resize(1); + out_lod[0].resize(seq_num + 1); + out_lod[0][0] = 0; + for (int i = 0; i < seq_num; ++i) { + size_t seq_len = in_lod_l0[i + 1] - in_lod_l0[i]; + size_t offset = 0; + offset = (seq_len * in_width) / out_width; + CHECK_EQ(offset * out_width, seq_len * in_width); + out_lod[0][i + 1] = out_lod[0][i] + offset; + } + } + + out->Resize(in_dims); + auto* dst_ptr = out->mutable_data(); + auto size = in->numel() * sizeof(T); + std::memcpy(dst_ptr, in->data(), size); + std::vector out_shape{static_cast(out->lod()[0].back()), + out_width}; + out->Resize(lite::DDim(out_shape)); + } + + virtual ~SequenceReshapeCompute() = default; +}; + +} // namespace x86 +} // namespace kernels +} // namespace lite +} // namespace paddle diff --git a/lite/operators/CMakeLists.txt b/lite/operators/CMakeLists.txt index 9d5d1952ae9e17826db24b5f37d77a1980232ffd..21b8ec278a6df16711bef3d1b3be34f77c52c9b3 100644 --- a/lite/operators/CMakeLists.txt +++ b/lite/operators/CMakeLists.txt @@ -76,6 +76,8 @@ add_operator(sequence_expand_as_op_lite extra SRCS sequence_expand_as_op.cc DEPS add_operator(range_op extra SRCS range_op.cc DEPS ${op_DEPS}) add_operator(assign_value_op extra SRCS assign_value_op.cc DEPS ${op_DEPS}) add_operator(fake_quantize_dequantize_moving_avg_abs_max_op extra SRCS fake_quantize_dequantize_moving_avg_max_abs.cc DEPS ${op_DEPS}) +add_operator(sequence_reshape_op_lite extra SRCS sequence_reshape_op.cc DEPS ${op_DEPS}) +add_operator(reduce_sum_op_lite extra SRCS reduce_ops.cc DEPS ${op_DEPS}) # for OCR specific add_operator(while_op extra SRCS while_op.cc DEPS ${op_DEPS}) diff --git a/lite/operators/activation_ops.cc b/lite/operators/activation_ops.cc index a7f2d28cc908a0bda5027c297eeb9f59e989540a..c3c5de311f41f88fbeed4b03f9bfd618cf51c3b3 100644 --- a/lite/operators/activation_ops.cc +++ b/lite/operators/activation_ops.cc @@ -118,6 +118,7 @@ REGISTER_LITE_OP(exp, paddle::lite::operators::ActivationOp); REGISTER_LITE_OP(floor, paddle::lite::operators::ActivationOp); REGISTER_LITE_OP(hard_sigmoid, paddle::lite::operators::ActivationOp); REGISTER_LITE_OP(rsqrt, paddle::lite::operators::ActivationOp); +REGISTER_LITE_OP(softsign, paddle::lite::operators::ActivationOp); #ifdef LITE_WITH_TRAIN REGISTER_LITE_OP(square_grad, paddle::lite::operators::ActivationGradOp); diff --git a/lite/operators/concat_op.cc b/lite/operators/concat_op.cc index f073faf6b9d98a92d195f2004fc98760a45af9ba..dfd95e4658ddbfe244659887e9c738722be439ec 100644 --- a/lite/operators/concat_op.cc +++ b/lite/operators/concat_op.cc @@ -21,7 +21,7 @@ namespace lite { namespace operators { bool ConcatOpLite::CheckShape() const { - CHECK_GT_OR_FALSE(param_.x.size(), 1UL); + CHECK_GE_OR_FALSE(param_.x.size(), 1UL); CHECK_OR_FALSE(param_.output); return true; } diff --git a/lite/operators/lookup_table_op.cc b/lite/operators/lookup_table_op.cc index 192de2ecf85d5dda9bbf42b4fb1dccd28d8b02d5..3d5a71cee96adb520aeafc83156e5f37638912ad 100644 --- a/lite/operators/lookup_table_op.cc +++ b/lite/operators/lookup_table_op.cc @@ -50,6 +50,7 @@ bool LookupTableOpLite::InferShape() const { } out_dims.push_back(table_dims[1]); param_.Out->Resize(lite::DDim{out_dims}); + param_.Out->set_lod(param_.Ids->lod()); return true; } diff --git a/lite/operators/op_params.h b/lite/operators/op_params.h index 097dd91163357d9fa43818c68687a48de06fe8aa..5ae22e6039bf55bb57f4e90a49b4eca835b879ea 100644 --- a/lite/operators/op_params.h +++ b/lite/operators/op_params.h @@ -721,6 +721,12 @@ struct SequencePoolParam { #endif }; +struct SequenceReshapeParam { + lite::Tensor* x{}; + lite::Tensor* output{}; + int new_dim; +}; + struct SequenceExpandParam { const lite::Tensor* X{}; const lite::Tensor* Y{}; @@ -753,6 +759,15 @@ struct IsEmptyParam { const lite::Tensor* X{}; lite::Tensor* Out{}; }; + +struct ReduceParam { + lite::Tensor* x{}; + lite::Tensor* output{}; + std::vector dim{0}; + bool keep_dim{false}; + bool reduce_all{false}; +}; + /// ----------------------- shape operators ---------------------- struct ShapeParam { const lite::Tensor* X{}; diff --git a/lite/operators/reduce_ops.cc b/lite/operators/reduce_ops.cc new file mode 100644 index 0000000000000000000000000000000000000000..e986b0ca5412f8380cccc9f981e5e4069ffcdabc --- /dev/null +++ b/lite/operators/reduce_ops.cc @@ -0,0 +1,89 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "lite/operators/reduce_ops.h" +#include +#include "lite/core/op_registry.h" +namespace paddle { +namespace lite { +namespace operators { + +bool ReduceOp::CheckShape() const { + CHECK_OR_FALSE(param_.x); + CHECK_OR_FALSE(param_.output); + auto x_dims = param_.x->dims(); + auto x_rank = x_dims.size(); + CHECK_LE(x_rank, 6UL) << "Tensors with rank at most 6 are supported."; + return true; +} + +bool ReduceOp::InferShape() const { + auto x_dims = param_.x->dims(); + auto x_rank = x_dims.size(); + auto dims = param_.dim; + for (size_t i = 0; i < dims.size(); ++i) { + if (dims[i] < 0) dims[i] = x_rank + dims[i]; + CHECK_LT(dims[i], x_rank) + << "The dim should be in the range [-rank(input), rank(input)."; + } + sort(dims.begin(), dims.end()); + bool reduce_all = param_.reduce_all; + bool keep_dim = param_.keep_dim; + + if (reduce_all) { + if (keep_dim) + param_.output->Resize(lite::DDim(std::vector(x_rank, 1))); + else + param_.output->Resize(lite::DDim(std::vector{1})); + } else { + auto dims_vector = x_dims.Vectorize(); + if (keep_dim) { + for (size_t i = 0; i < dims.size(); ++i) { + dims_vector[dims[i]] = 1; + } + } else { + const int kDelFlag = -2; + for (size_t i = 0; i < dims.size(); ++i) { + dims_vector[dims[i]] = kDelFlag; + } + dims_vector.erase( + remove(dims_vector.begin(), dims_vector.end(), kDelFlag), + dims_vector.end()); + } + auto out_dims = lite::DDim(dims_vector); + param_.output->Resize(out_dims); + if (dims[0] != 0) { + param_.output->set_lod(param_.x->lod()); + } + } + return true; +} + +bool ReduceOp::AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) { + param_.x = + scope->FindVar(opdesc.Input("X").front())->GetMutable(); + param_.output = + scope->FindVar(opdesc.Output("Out").front())->GetMutable(); + + param_.dim = opdesc.GetAttr>("dim"); + param_.reduce_all = opdesc.GetAttr("reduce_all"); + param_.keep_dim = opdesc.GetAttr("keep_dim"); + return true; +} + +} // namespace operators +} // namespace lite +} // namespace paddle + +REGISTER_LITE_OP(reduce_sum, paddle::lite::operators::ReduceOp); diff --git a/lite/operators/reduce_ops.h b/lite/operators/reduce_ops.h new file mode 100644 index 0000000000000000000000000000000000000000..0063aba1fa606c6228e7dcb1197bfb36f57aa33c --- /dev/null +++ b/lite/operators/reduce_ops.h @@ -0,0 +1,46 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once +#include +#include +#include "lite/core/op_lite.h" +#include "lite/core/scope.h" +#include "lite/utils/all.h" + +namespace paddle { +namespace lite { +namespace operators { + +class ReduceOp : public OpLite { + public: + ReduceOp() {} + explicit ReduceOp(const std::string &op_type) : OpLite(op_type) {} + + bool CheckShape() const override; + + bool InferShape() const override; + + bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override; + + void AttachKernel(KernelBase *kernel) override { kernel->SetParam(param_); } + std::string DebugString() const override { return "reduce"; } + + private: + mutable ReduceParam param_; +}; + +} // namespace operators +} // namespace lite +} // namespace paddle diff --git a/lite/operators/sequence_reshape_op.cc b/lite/operators/sequence_reshape_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..c7e86af65033205bcb389cecff8db14721507142 --- /dev/null +++ b/lite/operators/sequence_reshape_op.cc @@ -0,0 +1,54 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "lite/operators/sequence_reshape_op.h" +#include "lite/core/op_registry.h" + +namespace paddle { +namespace lite { +namespace operators { + +bool SequenceReshapeOp::CheckShape() const { + CHECK_OR_FALSE(param_.x); + CHECK_OR_FALSE(param_.output); + auto x_dims = param_.x->dims(); + CHECK_EQ_OR_FALSE(x_dims.size(), 2U); + return true; +} + +bool SequenceReshapeOp::InferShape() const { + int new_dim = param_.new_dim; + auto x_numel = param_.x->dims().production(); + std::vector out_shape{x_numel / new_dim, + static_cast(new_dim)}; + param_.output->Resize(lite::DDim(out_shape)); + return true; +} + +bool SequenceReshapeOp::AttachImpl(const cpp::OpDesc &opdesc, + lite::Scope *scope) { + param_.x = + scope->FindVar(opdesc.Input("X").front())->GetMutable(); + param_.output = + scope->FindVar(opdesc.Output("Out").front())->GetMutable(); + + param_.new_dim = opdesc.GetAttr("new_dim"); + return true; +} + +} // namespace operators +} // namespace lite +} // namespace paddle + +REGISTER_LITE_OP(sequence_reshape, paddle::lite::operators::SequenceReshapeOp); diff --git a/lite/operators/sequence_reshape_op.h b/lite/operators/sequence_reshape_op.h new file mode 100644 index 0000000000000000000000000000000000000000..c8378aebc44acf22017eee17f5b58d6ff4dd65bf --- /dev/null +++ b/lite/operators/sequence_reshape_op.h @@ -0,0 +1,47 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include +#include "lite/core/op_lite.h" +#include "lite/core/scope.h" +#include "lite/utils/all.h" + +namespace paddle { +namespace lite { +namespace operators { + +class SequenceReshapeOp : public OpLite { + public: + SequenceReshapeOp() {} + explicit SequenceReshapeOp(const std::string &op_type) : OpLite(op_type) {} + + bool CheckShape() const override; + + bool InferShape() const override; + + bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override; + + void AttachKernel(KernelBase *kernel) override { kernel->SetParam(param_); } + std::string DebugString() const override { return "sequence_reshape"; } + + private: + mutable SequenceReshapeParam param_; +}; + +} // namespace operators +} // namespace lite +} // namespace paddle