From f9930fc191f5289ad963fbd8b21ae0d711418fcd Mon Sep 17 00:00:00 2001 From: zhupengyang Date: Tue, 19 Nov 2019 13:20:06 +0800 Subject: [PATCH] add search_seq_softmax op; regist search_seq_softmax x86 kernel and cuda kernel (#2445) test=develop --- lite/kernels/cuda/softmax_compute.cu | 15 +++++++ lite/kernels/x86/softmax_compute.cc | 9 +++++ lite/operators/CMakeLists.txt | 1 + lite/operators/search_seq_softmax_op.cc | 52 +++++++++++++++++++++++++ lite/operators/search_seq_softmax_op.h | 47 ++++++++++++++++++++++ 5 files changed, 124 insertions(+) create mode 100644 lite/operators/search_seq_softmax_op.cc create mode 100644 lite/operators/search_seq_softmax_op.h diff --git a/lite/kernels/cuda/softmax_compute.cu b/lite/kernels/cuda/softmax_compute.cu index d8d2987524..14ed391f7f 100644 --- a/lite/kernels/cuda/softmax_compute.cu +++ b/lite/kernels/cuda/softmax_compute.cu @@ -244,3 +244,18 @@ REGISTER_LITE_KERNEL(softmax, PRECISION(kFloat), DATALAYOUT(kNCHW))}) .Finalize(); +REGISTER_LITE_KERNEL(search_seq_softmax, + kCUDA, + kFloat, + kNCHW, + paddle::lite::kernels::cuda::SoftmaxCompute, + def) + .BindInput("X", + {LiteType::GetTensorTy(TARGET(kCUDA), + PRECISION(kFloat), + DATALAYOUT(kNCHW))}) + .BindOutput("Out", + {LiteType::GetTensorTy(TARGET(kCUDA), + PRECISION(kFloat), + DATALAYOUT(kNCHW))}) + .Finalize(); diff --git a/lite/kernels/x86/softmax_compute.cc b/lite/kernels/x86/softmax_compute.cc index a00aa6d566..3fe7b162a3 100644 --- a/lite/kernels/x86/softmax_compute.cc +++ b/lite/kernels/x86/softmax_compute.cc @@ -23,3 +23,12 @@ REGISTER_LITE_KERNEL(softmax, .BindInput("X", {LiteType::GetTensorTy(TARGET(kX86))}) .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kX86))}) .Finalize(); +REGISTER_LITE_KERNEL(search_seq_softmax, + kX86, + kFloat, + kNCHW, + paddle::lite::kernels::x86::SoftmaxCompute, + def) + .BindInput("X", {LiteType::GetTensorTy(TARGET(kX86))}) + .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kX86))}) + .Finalize(); diff --git a/lite/operators/CMakeLists.txt b/lite/operators/CMakeLists.txt index 4f4b13f931..30637c00b5 100644 --- a/lite/operators/CMakeLists.txt +++ b/lite/operators/CMakeLists.txt @@ -84,6 +84,7 @@ add_operator(reduce_sum_op_lite extra SRCS reduce_ops.cc DEPS ${op_DEPS}) add_operator(match_matrix_tensor_op_lite extra SRCS match_matrix_tensor_op.cc DEPS ${op_DEPS}) add_operator(search_seq_depadding_op_lite extra SRCS search_seq_depadding_op.cc DEPS ${op_DEPS}) add_operator(search_grnn_op_lite extra SRCS search_grnn_op.cc DEPS ${op_DEPS}) +add_operator(search_seq_softmax_op_lite extra SRCS search_seq_softmax_op.cc DEPS ${op_DEPS}) add_operator(sequence_concat_op_lite extra SRCS sequence_concat_op.cc DEPS ${op_DEPS}) add_operator(var_conv_2d_op_lite extra SRCS var_conv_2d_op.cc DEPS ${op_DEPS}) add_operator(attention_padding_mask_op_lite extra SRCS attention_padding_mask_op.cc DEPS ${op_DEPS}) diff --git a/lite/operators/search_seq_softmax_op.cc b/lite/operators/search_seq_softmax_op.cc new file mode 100644 index 0000000000..973ffa04c4 --- /dev/null +++ b/lite/operators/search_seq_softmax_op.cc @@ -0,0 +1,52 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "lite/operators/search_seq_softmax_op.h" +#include "lite/core/op_registry.h" + +namespace paddle { +namespace lite { +namespace operators { + +bool SearchSeqSoftmaxOp::CheckShape() const { + CHECK_OR_FALSE(param_.x); + CHECK_OR_FALSE(param_.output); + return true; +} + +bool SearchSeqSoftmaxOp::InferShape() const { + param_.output->Resize(param_.x->dims()); + param_.output->set_lod(param_.x->lod()); + return true; +} + +bool SearchSeqSoftmaxOp::AttachImpl(const cpp::OpDesc &opdesc, + lite::Scope *scope) { + param_.x = const_cast( + &scope->FindVar(opdesc.Input("X").front())->Get()); + param_.output = + scope->FindVar(opdesc.Output("Out").front())->GetMutable(); + param_.axis = 1; + + CHECK(param_.x); + CHECK(param_.output); + return true; +} + +} // namespace operators +} // namespace lite +} // namespace paddle + +REGISTER_LITE_OP(search_seq_softmax, + paddle::lite::operators::SearchSeqSoftmaxOp); diff --git a/lite/operators/search_seq_softmax_op.h b/lite/operators/search_seq_softmax_op.h new file mode 100644 index 0000000000..f97e8ddd3a --- /dev/null +++ b/lite/operators/search_seq_softmax_op.h @@ -0,0 +1,47 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include +#include "lite/core/op_lite.h" +#include "lite/core/scope.h" +#include "lite/utils/all.h" + +namespace paddle { +namespace lite { +namespace operators { + +class SearchSeqSoftmaxOp : public OpLite { + public: + SearchSeqSoftmaxOp() {} + explicit SearchSeqSoftmaxOp(const std::string &op_type) : OpLite(op_type) {} + + bool CheckShape() const override; + + bool InferShape() const override; + + bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override; + + void AttachKernel(KernelBase *kernel) override { kernel->SetParam(param_); } + std::string DebugString() const override { return "search_seq_softmax_op"; } + + private: + mutable SoftmaxParam param_; +}; + +} // namespace operators +} // namespace lite +} // namespace paddle -- GitLab