提交 f9930fc1 编写于 作者: Z zhupengyang 提交者: 石晓伟

add search_seq_softmax op; regist search_seq_softmax x86 kernel and cuda kernel (#2445)

test=develop
上级 68fe5b5c
......@@ -244,3 +244,18 @@ REGISTER_LITE_KERNEL(softmax,
PRECISION(kFloat),
DATALAYOUT(kNCHW))})
.Finalize();
REGISTER_LITE_KERNEL(search_seq_softmax,
kCUDA,
kFloat,
kNCHW,
paddle::lite::kernels::cuda::SoftmaxCompute,
def)
.BindInput("X",
{LiteType::GetTensorTy(TARGET(kCUDA),
PRECISION(kFloat),
DATALAYOUT(kNCHW))})
.BindOutput("Out",
{LiteType::GetTensorTy(TARGET(kCUDA),
PRECISION(kFloat),
DATALAYOUT(kNCHW))})
.Finalize();
......@@ -23,3 +23,12 @@ REGISTER_LITE_KERNEL(softmax,
.BindInput("X", {LiteType::GetTensorTy(TARGET(kX86))})
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kX86))})
.Finalize();
REGISTER_LITE_KERNEL(search_seq_softmax,
kX86,
kFloat,
kNCHW,
paddle::lite::kernels::x86::SoftmaxCompute<float>,
def)
.BindInput("X", {LiteType::GetTensorTy(TARGET(kX86))})
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kX86))})
.Finalize();
......@@ -84,6 +84,7 @@ add_operator(reduce_sum_op_lite extra SRCS reduce_ops.cc DEPS ${op_DEPS})
add_operator(match_matrix_tensor_op_lite extra SRCS match_matrix_tensor_op.cc DEPS ${op_DEPS})
add_operator(search_seq_depadding_op_lite extra SRCS search_seq_depadding_op.cc DEPS ${op_DEPS})
add_operator(search_grnn_op_lite extra SRCS search_grnn_op.cc DEPS ${op_DEPS})
add_operator(search_seq_softmax_op_lite extra SRCS search_seq_softmax_op.cc DEPS ${op_DEPS})
add_operator(sequence_concat_op_lite extra SRCS sequence_concat_op.cc DEPS ${op_DEPS})
add_operator(var_conv_2d_op_lite extra SRCS var_conv_2d_op.cc DEPS ${op_DEPS})
add_operator(attention_padding_mask_op_lite extra SRCS attention_padding_mask_op.cc DEPS ${op_DEPS})
......
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/operators/search_seq_softmax_op.h"
#include "lite/core/op_registry.h"
namespace paddle {
namespace lite {
namespace operators {
bool SearchSeqSoftmaxOp::CheckShape() const {
CHECK_OR_FALSE(param_.x);
CHECK_OR_FALSE(param_.output);
return true;
}
bool SearchSeqSoftmaxOp::InferShape() const {
param_.output->Resize(param_.x->dims());
param_.output->set_lod(param_.x->lod());
return true;
}
bool SearchSeqSoftmaxOp::AttachImpl(const cpp::OpDesc &opdesc,
lite::Scope *scope) {
param_.x = const_cast<lite::Tensor *>(
&scope->FindVar(opdesc.Input("X").front())->Get<lite::Tensor>());
param_.output =
scope->FindVar(opdesc.Output("Out").front())->GetMutable<lite::Tensor>();
param_.axis = 1;
CHECK(param_.x);
CHECK(param_.output);
return true;
}
} // namespace operators
} // namespace lite
} // namespace paddle
REGISTER_LITE_OP(search_seq_softmax,
paddle::lite::operators::SearchSeqSoftmaxOp);
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <string>
#include <vector>
#include "lite/core/op_lite.h"
#include "lite/core/scope.h"
#include "lite/utils/all.h"
namespace paddle {
namespace lite {
namespace operators {
class SearchSeqSoftmaxOp : public OpLite {
public:
SearchSeqSoftmaxOp() {}
explicit SearchSeqSoftmaxOp(const std::string &op_type) : OpLite(op_type) {}
bool CheckShape() const override;
bool InferShape() const override;
bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override;
void AttachKernel(KernelBase *kernel) override { kernel->SetParam(param_); }
std::string DebugString() const override { return "search_seq_softmax_op"; }
private:
mutable SoftmaxParam param_;
};
} // namespace operators
} // namespace lite
} // namespace paddle
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册