From d5e58a988acde3a96a2e4dfaa3680fa5f7c2ae77 Mon Sep 17 00:00:00 2001 From: lijianshe02 Date: Tue, 22 Oct 2019 09:20:26 +0000 Subject: [PATCH] add asr related kernel test=develop --- lite/api/paddle_place.h | 2 + lite/core/op_registry.cc | 3 + lite/fluid/eigen.h | 6 +- lite/kernels/x86/CMakeLists.txt | 3 + lite/kernels/x86/activation_compute.cc | 10 ++ lite/kernels/x86/activation_compute.h | 25 +++++ lite/kernels/x86/lookup_table_compute.cc | 34 +++++++ lite/kernels/x86/lookup_table_compute.h | 97 ++++++++++++++++++++ lite/kernels/x86/reduce_compute.cc | 25 +++++ lite/kernels/x86/reduce_compute.h | 83 +++++++++++++++++ lite/kernels/x86/reduce_op_function.h | 84 +++++++++++++++++ lite/kernels/x86/sequence_reshape_compute.cc | 25 +++++ lite/kernels/x86/sequence_reshape_compute.h | 81 ++++++++++++++++ lite/operators/CMakeLists.txt | 2 + lite/operators/activation_ops.cc | 1 + lite/operators/concat_op.cc | 2 +- lite/operators/op_params.h | 15 +++ lite/operators/reduce_ops.cc | 89 ++++++++++++++++++ lite/operators/reduce_ops.h | 46 ++++++++++ lite/operators/sequence_reshape_op.cc | 54 +++++++++++ lite/operators/sequence_reshape_op.h | 46 ++++++++++ 21 files changed, 730 insertions(+), 3 deletions(-) create mode 100644 lite/kernels/x86/lookup_table_compute.cc create mode 100644 lite/kernels/x86/lookup_table_compute.h create mode 100644 lite/kernels/x86/reduce_compute.cc create mode 100644 lite/kernels/x86/reduce_compute.h create mode 100644 lite/kernels/x86/reduce_op_function.h create mode 100644 lite/kernels/x86/sequence_reshape_compute.cc create mode 100644 lite/kernels/x86/sequence_reshape_compute.h create mode 100644 lite/operators/reduce_ops.cc create mode 100644 lite/operators/reduce_ops.h create mode 100644 lite/operators/sequence_reshape_op.cc create mode 100644 lite/operators/sequence_reshape_op.h diff --git a/lite/api/paddle_place.h b/lite/api/paddle_place.h index 5e4f2ed21c..d8da20ed1e 100644 --- a/lite/api/paddle_place.h +++ b/lite/api/paddle_place.h @@ -101,6 +101,8 @@ static size_t PrecisionTypeLength(PrecisionType type) { return 1; case PrecisionType::kInt32: return 4; + case PrecisionType::kInt64: + return 8; case PrecisionType::kFP16: return 2; default: diff --git a/lite/core/op_registry.cc b/lite/core/op_registry.cc index 0fdce27e3b..80b1757da5 100644 --- a/lite/core/op_registry.cc +++ b/lite/core/op_registry.cc @@ -54,6 +54,8 @@ std::list> KernelRegistry::Create( CREATE_KERNEL1(target__, kFP16); \ case PRECISION(kAny): \ CREATE_KERNEL1(target__, kAny); \ + case PRECISION(kInt64): \ + CREATE_KERNEL1(target__, kInt64); \ default: \ CHECK(false) << "not supported kernel precision " \ << PrecisionToStr(precision); \ @@ -123,6 +125,7 @@ KernelRegistry::KernelRegistry() INIT_FOR(kX86, kFloat, kNCHW); INIT_FOR(kX86, kAny, kNCHW); INIT_FOR(kX86, kAny, kAny); + INIT_FOR(kX86, kInt64, kNCHW); INIT_FOR(kARM, kFloat, kNCHW); INIT_FOR(kARM, kInt8, kNCHW); diff --git a/lite/fluid/eigen.h b/lite/fluid/eigen.h index f5d5e4b5e5..4314a6c492 100644 --- a/lite/fluid/eigen.h +++ b/lite/fluid/eigen.h @@ -32,7 +32,7 @@ struct EigenDim { static Type From(const lite::DDim& dims) { PADDLE_ENFORCE(dims.size() == D, "D must match DDim::size"); Type ret; - for (int64_t d = 0; d < dims.size(); d++) { + for (size_t d = 0; d < dims.size(); d++) { ret[d] = dims[d]; } return ret; @@ -118,7 +118,9 @@ struct EigenScalar { using ConstType = Eigen::TensorMap< Eigen::TensorFixedSize, MajorType, IndexType>>; - static Type From(Tensor& tensor) { return Type(tensor.data()); } // NOLINT + static Type From(const Tensor& tensor) { + return Type(const_cast(tensor.data())); + } // NOLINT static ConstType From(const Tensor& tensor) { return ConstType(tensor.data()); diff --git a/lite/kernels/x86/CMakeLists.txt b/lite/kernels/x86/CMakeLists.txt index 7442a7be8b..2af3f46bac 100644 --- a/lite/kernels/x86/CMakeLists.txt +++ b/lite/kernels/x86/CMakeLists.txt @@ -35,6 +35,9 @@ add_kernel(sequence_pool_compute_x86 X86 basic SRCS sequence_pool_compute.cc DEP add_kernel(softmax_compute_x86 X86 basic SRCS softmax_compute.cc DEPS ${lite_kernel_deps} softmax) add_kernel(elementwise_compute_x86 X86 basic SRCS elementwise_compute.cc DEPS ${lite_kernel_deps}) add_kernel(batch_norm_compute_x86 X86 basic SRCS batch_norm_compute.cc DEPS ${lite_kernel_deps}) +add_kernel(reduce_sum_compute_x86 X86 basic SRCS reduce_compute.cc DEPS ${lite_kernel_deps}) +add_kernel(lookup_table_compute_x86 X86 basic SRCS lookup_table_compute.cc DEPS ${lite_kernel_deps}) +add_kernel(sequence_reshape_compute_x86 X86 basic SRCS sequence_reshape_compute.cc DEPS ${lite_kernel_deps}) if(NOT LITE_WITH_X86) return() diff --git a/lite/kernels/x86/activation_compute.cc b/lite/kernels/x86/activation_compute.cc index 0ed09c43a5..212f83a5f6 100644 --- a/lite/kernels/x86/activation_compute.cc +++ b/lite/kernels/x86/activation_compute.cc @@ -35,3 +35,13 @@ REGISTER_LITE_KERNEL(relu, .BindInput("X", {LiteType::GetTensorTy(TARGET(kX86))}) .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kX86))}) .Finalize(); + +REGISTER_LITE_KERNEL(softsign, + kX86, + kFloat, + kNCHW, + paddle::lite::kernels::x86::SoftsignCompute, + def) + .BindInput("X", {LiteType::GetTensorTy(TARGET(kX86))}) + .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kX86))}) + .Finalize(); diff --git a/lite/kernels/x86/activation_compute.h b/lite/kernels/x86/activation_compute.h index 2775240194..902c95422f 100644 --- a/lite/kernels/x86/activation_compute.h +++ b/lite/kernels/x86/activation_compute.h @@ -115,6 +115,31 @@ class ReluCompute : public KernelLite { virtual ~ReluCompute() = default; }; +// softsign(x) = x / (1 + |x|) +template +struct SoftsignFunctor : public BaseActivationFunctor { + template + void operator()(Device d, X x, Out out) { + out.device(d) = x / (static_cast(1) + x.abs()); + } +}; + +template +class SoftsignCompute : public KernelLite { + public: + using param_t = operators::ActivationParam; + + void Run() override { + // auto& context = ctx_->As(); + auto& param = *param_.get_mutable(); + param.Out->template mutable_data(); + + Activate>(param.X, param.Out); + } + + virtual ~SoftsignCompute() = default; +}; + } // namespace x86 } // namespace kernels } // namespace lite diff --git a/lite/kernels/x86/lookup_table_compute.cc b/lite/kernels/x86/lookup_table_compute.cc new file mode 100644 index 0000000000..cd61b6cc15 --- /dev/null +++ b/lite/kernels/x86/lookup_table_compute.cc @@ -0,0 +1,34 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "lite/kernels/x86/lookup_table_compute.h" + +// REGISTER_LITE_KERNEL(lookup_table, kX86, kFloat, kNCHW, +// paddle::lite::kernels::x86::LookupTableCompute, +// def) +// .BindInput("W", {LiteType::GetTensorTy(TARGET(kX86))}) +// .BindInput("Ids", {LiteType::GetTensorTy(TARGET(kX86))}) +// .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kX86))}) +// .Finalize(); +//, +REGISTER_LITE_KERNEL(lookup_table, + kX86, + kInt64, + kNCHW, + paddle::lite::kernels::x86::LookupTableCompute, + def) + .BindInput("W", {LiteType::GetTensorTy(TARGET(kX86))}) + .BindInput("Ids", {LiteType::GetTensorTy(TARGET(kX86), PRECISION(kInt64))}) + .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kX86))}) + .Finalize(); diff --git a/lite/kernels/x86/lookup_table_compute.h b/lite/kernels/x86/lookup_table_compute.h new file mode 100644 index 0000000000..d0fa299cd6 --- /dev/null +++ b/lite/kernels/x86/lookup_table_compute.h @@ -0,0 +1,97 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +#pragma once + +#include +#include +#include "lite/core/kernel.h" +#include "lite/core/op_registry.h" +#include "lite/fluid/eigen.h" + +namespace paddle { +namespace lite { +namespace kernels { +namespace x86 { + +/*struct LookupTableTimer { + std::chrono::time_point timer_{}; + uint64_t total_{}; + + void Start() { timer_ = std::chrono::high_resolution_clock::now(); } + void Stop() { + auto duration = std::chrono::duration_cast( + std::chrono::high_resolution_clock::now() - timer_); + Log(duration.count()); + } + void Log(uint32_t timespan) { total_ += timespan; } + ~LookupTableTimer() { + LOG(INFO) << "lookup table timer: [" << total_ << "us]"; + } +};*/ + +template +class LookupTableCompute : public KernelLite { + public: + using param_t = operators::LookupTableParam; + + void Run() override { + auto ¶m = *param_.get_mutable(); + // auto& context = context_->As(); + auto *ids_t = param.Ids; + auto *output_t = param.Out; + LOG(INFO) << "lookup_table input ids tensor address: " << ids_t; + int64_t padding_idx = param.padding_idx; + auto *ids = ids_t->data(); + LOG(INFO) << "ids data address: " << ids; + int64_t ids_numel = ids_t->dims().production(); + std::cout << "ids_numel: " << ids_numel << std::endl; + std::cout << "ids tensor info: ["; + for (size_t i = 0; i < ids_numel; ++i) { + std::cout << ids[i] << ","; + } + std::cout << "]" << std::endl; + + auto *table_t = param.W; + int64_t row_number = table_t->dims()[0]; + int64_t row_width = table_t->dims()[1]; + std::cout << "row_number: " << row_number << std::endl; + std::cout << "row_width: " << row_width << std::endl; + + auto *table = table_t->data(); + auto *output = output_t->mutable_data(); + memset(output, 0, output_t->dims().production() * sizeof(T)); + for (int64_t i = 0; i < ids_numel; ++i) { + if (padding_idx != -1 && ids[i] == padding_idx) { + memset(output + i * row_width, 0, row_width * sizeof(float)); + } else { + std::cout << "*************************" << std::endl; + std::cout << "ids[i]: " << ids[i] << std::endl; + std::cout << "row_number: " << row_number << std::endl; + std::cout << "*************************" << std::endl; + CHECK_LT(ids[i], row_number); + CHECK_GE(ids[i], 0); + memcpy(output + i * row_width, + table + ids[i] * row_width, + row_width * sizeof(float)); + } + } + } + + virtual ~LookupTableCompute() = default; +}; + +} // namespace x86 +} // namespace kernels +} // namespace lite +} // namespace paddle diff --git a/lite/kernels/x86/reduce_compute.cc b/lite/kernels/x86/reduce_compute.cc new file mode 100644 index 0000000000..f95f4cfb88 --- /dev/null +++ b/lite/kernels/x86/reduce_compute.cc @@ -0,0 +1,25 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "lite/kernels/x86/reduce_compute.h" + +REGISTER_LITE_KERNEL(reduce_sum, + kX86, + kFloat, + kNCHW, + paddle::lite::kernels::x86::ReduceSumCompute, + def) + .BindInput("X", {LiteType::GetTensorTy(TARGET(kX86))}) + .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kX86))}) + .Finalize(); diff --git a/lite/kernels/x86/reduce_compute.h b/lite/kernels/x86/reduce_compute.h new file mode 100644 index 0000000000..faace5e24e --- /dev/null +++ b/lite/kernels/x86/reduce_compute.h @@ -0,0 +1,83 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +#pragma once + +#include +#include "lite/core/kernel.h" +#include "lite/core/op_registry.h" +#include "lite/fluid/eigen.h" +#include "lite/kernels/x86/reduce_op_function.h" + +namespace paddle { +namespace lite { +namespace kernels { +namespace x86 { + +struct SumFunctor { + template + void operator()(X* x, Y* y, const Dim& dim) { + y->device(lite::fluid::EigenDeviceType()) = x->sum(dim); + } +}; + +#define HANDLE_DIM(NDIM, RDIM) \ + if (ndim == NDIM && rdim == RDIM) { \ + paddle::lite::kernels::x86:: \ + ReduceFunctor( \ + *input, output, dims, keep_dim); \ + } + +template +class ReduceSumCompute : public KernelLite { + public: + using param_t = operators::ReduceParam; + + void Run() override { + auto& param = *param_.get_mutable(); + // auto& context = ctx_->As(); + bool reduce_all = param.reduce_all; + auto* input = param.x; + auto* output = param.output; + param.output->mutable_data(); + + auto dims = param.dim; + bool keep_dim = param.keep_dim; + if (reduce_all) { + // Flatten and reduce 1-D tensor + auto x = lite::fluid::EigenVector::Flatten(*input); + auto out = lite::fluid::EigenScalar::From(*output); + // auto& place = *platform::CPUDeviceContext().eigen_device(); + auto reduce_dim = Eigen::array({{0}}); + SumFunctor functor; + functor(&x, &out, reduce_dim); + } else { + int ndim = input->dims().size(); + int rdim = dims.size(); + HANDLE_DIM(4, 3); + HANDLE_DIM(4, 2); + HANDLE_DIM(4, 1); + HANDLE_DIM(3, 2); + HANDLE_DIM(3, 1); + HANDLE_DIM(2, 1); + HANDLE_DIM(1, 1); + } + } + + virtual ~ReduceSumCompute() = default; +}; + +} // namespace x86 +} // namespace kernels +} // namespace lite +} // namespace paddle diff --git a/lite/kernels/x86/reduce_op_function.h b/lite/kernels/x86/reduce_op_function.h new file mode 100644 index 0000000000..be3ef6ed0d --- /dev/null +++ b/lite/kernels/x86/reduce_op_function.h @@ -0,0 +1,84 @@ +// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once +#include +#include "lite/core/op_registry.h" +#include "lite/fluid/eigen.h" + +namespace paddle { +namespace lite { +namespace kernels { +namespace x86 { + +template +using EigenTensor = lite::fluid::EigenTensor; +template +using EigenScalar = lite::fluid::EigenScalar; +template +using EigenVector = lite::fluid::EigenVector; + +template +// const lite::Context& context, +void ReduceFunctor(const lite::Tensor& input, + lite::Tensor* output, + const std::vector& dims, + bool keep_dim) { + auto x = EigenTensor::From(input); + auto x_rank = static_cast(x.dimensions().size()); + auto reduce_dim = Eigen::array(); + std::vector dims_ref = dims; + for (size_t i = 0; i < dims_ref.size(); ++i) { + if (dims_ref[i] < 0) dims_ref[i] = x_rank + dims_ref[i]; + reduce_dim[i] = dims_ref[i]; + } + // construct the squeezed output tensor + lite::DDim out_dims = output->dims(); + if (keep_dim && x_rank > 1) { + const int kDelFlag = -2; + auto dims_vector = out_dims.Vectorize(); + for (size_t i = 0; i < dims_ref.size(); ++i) { + dims_vector[dims_ref[i]] = kDelFlag; + } + dims_vector.erase(remove(dims_vector.begin(), dims_vector.end(), kDelFlag), + dims_vector.end()); + out_dims = lite::DDim(dims_vector); + } + // auto& place = *context.eigen_device(); + Functor functor; + + if (D == 1) { + auto out = EigenScalar::From(*output); + functor(&x, &out, reduce_dim); + } else { + auto out = EigenTensor::From(*output, out_dims); + functor(&x, &out, reduce_dim); + } +} + +} // namespace x86 +} // namespace kernels +} // namespace lite +} // namespace paddle diff --git a/lite/kernels/x86/sequence_reshape_compute.cc b/lite/kernels/x86/sequence_reshape_compute.cc new file mode 100644 index 0000000000..62cd35b4ee --- /dev/null +++ b/lite/kernels/x86/sequence_reshape_compute.cc @@ -0,0 +1,25 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "lite/kernels/x86/sequence_reshape_compute.h" + +REGISTER_LITE_KERNEL(sequence_reshape, + kX86, + kFloat, + kNCHW, + paddle::lite::kernels::x86::SequenceReshapeCompute, + def) + .BindInput("X", {LiteType::GetTensorTy(TARGET(kX86), PRECISION(kFloat))}) + .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kX86), PRECISION(kFloat))}) + .Finalize(); diff --git a/lite/kernels/x86/sequence_reshape_compute.h b/lite/kernels/x86/sequence_reshape_compute.h new file mode 100644 index 0000000000..3e2108c1d2 --- /dev/null +++ b/lite/kernels/x86/sequence_reshape_compute.h @@ -0,0 +1,81 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +#pragma once + +#include +#include "lite/core/kernel.h" +#include "lite/core/op_registry.h" +#include "lite/fluid/eigen.h" + +namespace paddle { +namespace lite { +namespace kernels { +namespace x86 { + +template +class SequenceReshapeCompute + : public KernelLite { + public: + using param_t = operators::SequenceReshapeParam; + + void Run() override { + auto& param = *param_.get_mutable(); + // auto& context = context_->As(); + auto* in = param.x; + auto* out = param.output; + int out_width = param.new_dim; + + auto in_dims = in->dims(); + int64_t in_width = in_dims[1]; + // LOG(INFO)<<"sequence_reshape in tensor:"<<*in; + auto& in_lod = in->lod(); + + CHECK_EQ(in_lod.size(), 1UL); + CHECK_EQ((uint64_t)in_dims[0], in_lod[0].back()); + + auto in_lod_l0 = in_lod[0]; + int seq_num = in_lod_l0.size() - 1; + + if (in_width == out_width) { + out->set_lod(in->lod()); + } else { + auto& out_lod = *out->mutable_lod(); + out_lod.resize(1); + out_lod[0].resize(seq_num + 1); + out_lod[0][0] = 0; + for (int i = 0; i < seq_num; ++i) { + size_t seq_len = in_lod_l0[i + 1] - in_lod_l0[i]; + size_t offset = 0; + offset = (seq_len * in_width) / out_width; + CHECK_EQ(offset * out_width, seq_len * in_width); + out_lod[0][i + 1] = out_lod[0][i] + offset; + } + } + + out->Resize(in_dims); + auto* dst_ptr = out->mutable_data(); + auto size = in->numel() * sizeof(T); + std::memcpy(dst_ptr, in->data(), size); + std::vector out_shape{static_cast(out->lod()[0].back()), + out_width}; + out->Resize(lite::DDim(out_shape)); + } + + virtual ~SequenceReshapeCompute() = default; +}; + +} // namespace x86 +} // namespace kernels +} // namespace lite +} // namespace paddle diff --git a/lite/operators/CMakeLists.txt b/lite/operators/CMakeLists.txt index 2d23d8bb06..7d6964345b 100644 --- a/lite/operators/CMakeLists.txt +++ b/lite/operators/CMakeLists.txt @@ -76,6 +76,8 @@ add_operator(sequence_expand_as_op_lite basic SRCS sequence_expand_as_op.cc DEPS add_operator(range_op basic SRCS range_op.cc DEPS ${op_DEPS}) add_operator(assign_value_op basic SRCS assign_value_op.cc DEPS ${op_DEPS}) add_operator(fake_quantize_dequantize_moving_avg_abs_max_op basic SRCS fake_quantize_dequantize_moving_avg_max_abs.cc DEPS ${op_DEPS}) +add_operator(sequence_reshape_op_lite basic SRCS sequence_reshape_op.cc DEPS ${op_DEPS}) +add_operator(reduce_sum_op_lite basic SRCS reduce_ops.cc DEPS ${op_DEPS}) # for OCR specific add_operator(while_op extra SRCS while_op.cc DEPS ${op_DEPS}) diff --git a/lite/operators/activation_ops.cc b/lite/operators/activation_ops.cc index a7f2d28cc9..c3c5de311f 100644 --- a/lite/operators/activation_ops.cc +++ b/lite/operators/activation_ops.cc @@ -118,6 +118,7 @@ REGISTER_LITE_OP(exp, paddle::lite::operators::ActivationOp); REGISTER_LITE_OP(floor, paddle::lite::operators::ActivationOp); REGISTER_LITE_OP(hard_sigmoid, paddle::lite::operators::ActivationOp); REGISTER_LITE_OP(rsqrt, paddle::lite::operators::ActivationOp); +REGISTER_LITE_OP(softsign, paddle::lite::operators::ActivationOp); #ifdef LITE_WITH_TRAIN REGISTER_LITE_OP(square_grad, paddle::lite::operators::ActivationGradOp); diff --git a/lite/operators/concat_op.cc b/lite/operators/concat_op.cc index f073faf6b9..dfd95e4658 100644 --- a/lite/operators/concat_op.cc +++ b/lite/operators/concat_op.cc @@ -21,7 +21,7 @@ namespace lite { namespace operators { bool ConcatOpLite::CheckShape() const { - CHECK_GT_OR_FALSE(param_.x.size(), 1UL); + CHECK_GE_OR_FALSE(param_.x.size(), 1UL); CHECK_OR_FALSE(param_.output); return true; } diff --git a/lite/operators/op_params.h b/lite/operators/op_params.h index 3071f6f907..5f23415e83 100644 --- a/lite/operators/op_params.h +++ b/lite/operators/op_params.h @@ -717,6 +717,12 @@ struct SequencePoolParam { #endif }; +struct SequenceReshapeParam { + lite::Tensor* x{}; + lite::Tensor* output{}; + int new_dim; +}; + struct SequenceExpandParam { const lite::Tensor* X{}; const lite::Tensor* Y{}; @@ -749,6 +755,15 @@ struct IsEmptyParam { const lite::Tensor* X{}; lite::Tensor* Out{}; }; + +struct ReduceParam { + lite::Tensor* x{}; + lite::Tensor* output{}; + std::vector dim{0}; + bool keep_dim{false}; + bool reduce_all{false}; +}; + /// ----------------------- shape operators ---------------------- struct ShapeParam { const lite::Tensor* X{}; diff --git a/lite/operators/reduce_ops.cc b/lite/operators/reduce_ops.cc new file mode 100644 index 0000000000..e986b0ca54 --- /dev/null +++ b/lite/operators/reduce_ops.cc @@ -0,0 +1,89 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "lite/operators/reduce_ops.h" +#include +#include "lite/core/op_registry.h" +namespace paddle { +namespace lite { +namespace operators { + +bool ReduceOp::CheckShape() const { + CHECK_OR_FALSE(param_.x); + CHECK_OR_FALSE(param_.output); + auto x_dims = param_.x->dims(); + auto x_rank = x_dims.size(); + CHECK_LE(x_rank, 6UL) << "Tensors with rank at most 6 are supported."; + return true; +} + +bool ReduceOp::InferShape() const { + auto x_dims = param_.x->dims(); + auto x_rank = x_dims.size(); + auto dims = param_.dim; + for (size_t i = 0; i < dims.size(); ++i) { + if (dims[i] < 0) dims[i] = x_rank + dims[i]; + CHECK_LT(dims[i], x_rank) + << "The dim should be in the range [-rank(input), rank(input)."; + } + sort(dims.begin(), dims.end()); + bool reduce_all = param_.reduce_all; + bool keep_dim = param_.keep_dim; + + if (reduce_all) { + if (keep_dim) + param_.output->Resize(lite::DDim(std::vector(x_rank, 1))); + else + param_.output->Resize(lite::DDim(std::vector{1})); + } else { + auto dims_vector = x_dims.Vectorize(); + if (keep_dim) { + for (size_t i = 0; i < dims.size(); ++i) { + dims_vector[dims[i]] = 1; + } + } else { + const int kDelFlag = -2; + for (size_t i = 0; i < dims.size(); ++i) { + dims_vector[dims[i]] = kDelFlag; + } + dims_vector.erase( + remove(dims_vector.begin(), dims_vector.end(), kDelFlag), + dims_vector.end()); + } + auto out_dims = lite::DDim(dims_vector); + param_.output->Resize(out_dims); + if (dims[0] != 0) { + param_.output->set_lod(param_.x->lod()); + } + } + return true; +} + +bool ReduceOp::AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) { + param_.x = + scope->FindVar(opdesc.Input("X").front())->GetMutable(); + param_.output = + scope->FindVar(opdesc.Output("Out").front())->GetMutable(); + + param_.dim = opdesc.GetAttr>("dim"); + param_.reduce_all = opdesc.GetAttr("reduce_all"); + param_.keep_dim = opdesc.GetAttr("keep_dim"); + return true; +} + +} // namespace operators +} // namespace lite +} // namespace paddle + +REGISTER_LITE_OP(reduce_sum, paddle::lite::operators::ReduceOp); diff --git a/lite/operators/reduce_ops.h b/lite/operators/reduce_ops.h new file mode 100644 index 0000000000..0063aba1fa --- /dev/null +++ b/lite/operators/reduce_ops.h @@ -0,0 +1,46 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once +#include +#include +#include "lite/core/op_lite.h" +#include "lite/core/scope.h" +#include "lite/utils/all.h" + +namespace paddle { +namespace lite { +namespace operators { + +class ReduceOp : public OpLite { + public: + ReduceOp() {} + explicit ReduceOp(const std::string &op_type) : OpLite(op_type) {} + + bool CheckShape() const override; + + bool InferShape() const override; + + bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override; + + void AttachKernel(KernelBase *kernel) override { kernel->SetParam(param_); } + std::string DebugString() const override { return "reduce"; } + + private: + mutable ReduceParam param_; +}; + +} // namespace operators +} // namespace lite +} // namespace paddle diff --git a/lite/operators/sequence_reshape_op.cc b/lite/operators/sequence_reshape_op.cc new file mode 100644 index 0000000000..c7e86af650 --- /dev/null +++ b/lite/operators/sequence_reshape_op.cc @@ -0,0 +1,54 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "lite/operators/sequence_reshape_op.h" +#include "lite/core/op_registry.h" + +namespace paddle { +namespace lite { +namespace operators { + +bool SequenceReshapeOp::CheckShape() const { + CHECK_OR_FALSE(param_.x); + CHECK_OR_FALSE(param_.output); + auto x_dims = param_.x->dims(); + CHECK_EQ_OR_FALSE(x_dims.size(), 2U); + return true; +} + +bool SequenceReshapeOp::InferShape() const { + int new_dim = param_.new_dim; + auto x_numel = param_.x->dims().production(); + std::vector out_shape{x_numel / new_dim, + static_cast(new_dim)}; + param_.output->Resize(lite::DDim(out_shape)); + return true; +} + +bool SequenceReshapeOp::AttachImpl(const cpp::OpDesc &opdesc, + lite::Scope *scope) { + param_.x = + scope->FindVar(opdesc.Input("X").front())->GetMutable(); + param_.output = + scope->FindVar(opdesc.Output("Out").front())->GetMutable(); + + param_.new_dim = opdesc.GetAttr("new_dim"); + return true; +} + +} // namespace operators +} // namespace lite +} // namespace paddle + +REGISTER_LITE_OP(sequence_reshape, paddle::lite::operators::SequenceReshapeOp); diff --git a/lite/operators/sequence_reshape_op.h b/lite/operators/sequence_reshape_op.h new file mode 100644 index 0000000000..6f06680bb0 --- /dev/null +++ b/lite/operators/sequence_reshape_op.h @@ -0,0 +1,46 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once +#include +#include +#include "lite/core/op_lite.h" +#include "lite/core/scope.h" +#include "lite/utils/all.h" + +namespace paddle { +namespace lite { +namespace operators { + +class SequenceReshapeOp : public OpLite { + public: + SequenceReshapeOp() {} + explicit SequenceReshapeOp(const std::string &op_type) : OpLite(op_type) {} + + bool CheckShape() const override; + + bool InferShape() const override; + + bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override; + + void AttachKernel(KernelBase *kernel) override { kernel->SetParam(param_); } + std::string DebugString() const override { return "sequence_reshape"; } + + private: + mutable SequenceReshapeParam param_; +}; + +} // namespace operators +} // namespace lite +} // namespace paddle -- GitLab