Unverified commit 47869a59, authored by cc, committed by GitHub

Add hard_swish, ctc_align and reciprocal ops (#3354)

* Add hard_swish, ctc_align and reciprocal ops, test=develop
* Move some activation ops to extra, test=develop
Parent 99deb7d9
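
For reference, the element-wise definitions implemented by this patch, distilled from the ARM kernels below (a minimal sketch; the defaults are the ones this patch adds to ActivationParam in op_params.h):

#include <algorithm>

// hard_swish(x) = x * min(max(x + offset, 0), threshold) / scale
// With the defaults threshold = 6, scale = 6, offset = 3 this is the usual
// x * relu6(x + 3) / 6 form.
inline float hard_swish(float x, float threshold, float scale, float offset) {
  return std::min(std::max(0.f, x + offset), threshold) * x / scale;
}

// reciprocal(x) = 1 / x (no guard against x == 0, matching the kernel).
inline float reciprocal(float x) { return 1.f / x; }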
......@@ -100,7 +100,9 @@ enum class ActivationType : int {
kSwish = 7,
kExp = 8,
kAbs = 9,
-NUM = 10,
+kHardSwish = 10,
+kReciprocal = 11,
+NUM = 12,
};
static size_t PrecisionTypeLength(PrecisionType type) {
......
......@@ -13,6 +13,7 @@
// limitations under the License.
#include "lite/backends/arm/math/activation.h"
#include <algorithm>
#include <string>
#include "lite/backends/arm/math/funcs.h"
......@@ -711,6 +712,38 @@ void act_square<float>(const float* din, float* dout, int size, int threads) {
}
}
template <>
void act_hard_swish<float>(const float* din,
float* dout,
int size,
float threshold,
float scale,
float offset,
int threads) {
const float* ptr_in = din;
float* ptr_out = dout;
for (int i = 0; i < size; ++i) {
ptr_out[0] = std::min(std::max(0.f, ptr_in[0] + offset), threshold) *
ptr_in[0] / scale;
ptr_in++;
ptr_out++;
}
}
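// Note: the scalar loop above leaves `threads` unused; a NEON-vectorized
// variant would be a natural follow-up (hypothetical sketch, not part of
// this patch; requires <arm_neon.h>):
//   float32x4_t vzero = vdupq_n_f32(0.f);
//   float32x4_t vthr = vdupq_n_f32(threshold);
//   float32x4_t voff = vdupq_n_f32(offset);
//   float32x4_t vscale_inv = vdupq_n_f32(1.f / scale);
//   for (int i = 0; i + 3 < size; i += 4) {
//     float32x4_t vx = vld1q_f32(din + i);
//     float32x4_t vt = vminq_f32(vmaxq_f32(vaddq_f32(vx, voff), vzero), vthr);
//     vst1q_f32(dout + i, vmulq_f32(vmulq_f32(vt, vx), vscale_inv));
//   }
//   // tail elements then handled by a scalar loop as above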
template <>
void act_reciprocal<float>(const float* din,
float* dout,
int size,
int threads) {
const float* ptr_in = din;
float* ptr_out = dout;
for (int i = 0; i < size; ++i) {
ptr_out[0] = 1.0 / ptr_in[0];
ptr_in++;
ptr_out++;
}
}
#ifdef LITE_WITH_TRAIN
template <>
void act_square_grad(const float* din,
......
......@@ -72,6 +72,17 @@ void act_rsqrt(const T* din, T* dout, int size, int threads);
template <typename T>
void act_square(const T* din, T* dout, int size, int threads);
template <typename T>
void act_hard_swish(const T* din,
T* dout,
int size,
float threshold,
float scale,
float offset,
int threads);
template <typename T>
void act_reciprocal(const T* din, T* dout, int size, int threads);
#ifdef LITE_WITH_TRAIN
template <typename T>
void act_square_grad(
......
......@@ -157,5 +157,33 @@ Tensor *OpLite::GetMutableTensor(lite::Scope *scope,
return var->GetMutable<lite::Tensor>();
}
void OpLite::AttachInput(const cpp::OpDesc &op_desc,
lite::Scope *scope,
const std::string &input_name,
bool is_dispensable,
lite::Tensor **input_var) {
bool is_have_input =
op_desc.HasInput(input_name) && op_desc.Input(input_name).size() > 0;
CHECK(is_dispensable || is_have_input);
if (is_have_input) {
std::string input_var_name = op_desc.Input(input_name).front();
*input_var = scope->FindVar(input_var_name)->GetMutable<lite::Tensor>();
}
}
void OpLite::AttachOutput(const cpp::OpDesc &op_desc,
lite::Scope *scope,
const std::string &output_name,
bool is_dispensable,
lite::Tensor **output_var) {
bool is_have_output =
op_desc.HasOutput(output_name) && op_desc.Output(output_name).size() > 0;
CHECK(is_dispensable || is_have_output);
if (is_have_output) {
std::string output_var_name = op_desc.Output(output_name).front();
*output_var = scope->FindVar(output_var_name)->GetMutable<lite::Tensor>();
}
}
} // namespace lite
} // namespace paddle
......@@ -105,6 +105,20 @@ class OpLite : public Registry {
return kernel_.get();
}
// Attach input variable from scope by op_desc and input name
void AttachInput(const cpp::OpDesc &op_desc,
lite::Scope *scope,
const std::string &input_name,
bool is_dispensable,
lite::Tensor **input_var);
// Attach output variable from scope by op_desc and output name
void AttachOutput(const cpp::OpDesc &op_desc,
lite::Scope *scope,
const std::string &output_name,
bool is_dispensable,
lite::Tensor **output_var);
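// Typical usage (see CtcAlignOpLite::AttachImpl later in this patch);
// is_dispensable = true means the input/output may be absent from op_desc:
//   AttachInput(op_desc, scope, "Input", false, &param_.input);
//   AttachInput(op_desc, scope, "InputLength", true, &param_.input_length);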
virtual ~OpLite() = default;
protected:
......
......@@ -152,6 +152,8 @@ KernelRegistry::KernelRegistry()
INIT_FOR(kMLU, kInt16, kNCHW);
INIT_FOR(kHost, kFloat, kNCHW);
INIT_FOR(kHost, kInt32, kNCHW);
INIT_FOR(kHost, kInt64, kNCHW);
INIT_FOR(kHost, kAny, kNCHW);
INIT_FOR(kHost, kFloat, kNHWC);
INIT_FOR(kHost, kFloat, kAny);
......
......@@ -135,6 +135,12 @@ class KernelRegistry final {
KernelRegistryForTarget<TARGET(kHost),
PRECISION(kAny),
DATALAYOUT(kAny)> *, //
KernelRegistryForTarget<TARGET(kHost),
PRECISION(kInt32),
DATALAYOUT(kNCHW)> *, //
KernelRegistryForTarget<TARGET(kHost),
PRECISION(kInt64),
DATALAYOUT(kNCHW)> *, //
KernelRegistryForTarget<TARGET(kCUDA),
PRECISION(kAny),
DATALAYOUT(kAny)> *, //
......
......@@ -179,6 +179,34 @@ void SquareCompute::Run() {
x_data, output_data, x_dims.production(), ctx.threads());
}
void HardSwishCompute::Run() {
auto& param = this->Param<param_t>();
auto& ctx = this->ctx_->template As<ARMContext>();
auto x_dims = param.X->dims();
auto x_data = param.X->data<float>();
auto output_data = param.Out->mutable_data<float>();
float threshold = param.hard_swish_threshold;
float scale = param.hard_swish_scale;
float offset = param.hard_swish_offset;
lite::arm::math::act_hard_swish<float>(x_data,
output_data,
x_dims.production(),
threshold,
scale,
offset,
ctx.threads());
}
void ReciprocalCompute::Run() {
auto& param = this->Param<param_t>();
auto& ctx = this->ctx_->template As<ARMContext>();
auto x_dims = param.X->dims();
auto x_data = param.X->data<float>();
auto output_data = param.Out->mutable_data<float>();
lite::arm::math::act_reciprocal<float>(
x_data, output_data, x_dims.production(), ctx.threads());
}
} // namespace arm
} // namespace kernels
} // namespace lite
......@@ -275,3 +303,21 @@ REGISTER_LITE_KERNEL(
.BindInput("X", {LiteType::GetTensorTy(TARGET(kARM))})
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM))})
.Finalize();
REGISTER_LITE_KERNEL(hard_swish,
kARM,
kFloat,
kNCHW,
paddle::lite::kernels::arm::HardSwishCompute,
def)
.BindInput("X", {LiteType::GetTensorTy(TARGET(kARM))})
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM))})
.Finalize();
REGISTER_LITE_KERNEL(reciprocal,
kARM,
kFloat,
kNCHW,
paddle::lite::kernels::arm::ReciprocalCompute,
def)
.BindInput("X", {LiteType::GetTensorTy(TARGET(kARM))})
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM))})
.Finalize();
......@@ -148,6 +148,24 @@ class SquareCompute : public KernelLite<TARGET(kARM), PRECISION(kFloat)> {
virtual ~SquareCompute() = default;
};
class HardSwishCompute : public KernelLite<TARGET(kARM), PRECISION(kFloat)> {
public:
using param_t = operators::ActivationParam;
void Run() override;
virtual ~HardSwishCompute() = default;
};
class ReciprocalCompute : public KernelLite<TARGET(kARM), PRECISION(kFloat)> {
public:
using param_t = operators::ActivationParam;
void Run() override;
virtual ~ReciprocalCompute() = default;
};
} // namespace arm
} // namespace kernels
} // namespace lite
......
......@@ -5,3 +5,4 @@ add_kernel(fetch_compute_host Host basic SRCS fetch_compute.cc DEPS ${lite_kerne
add_kernel(reshape_compute_host Host basic SRCS reshape_compute.cc DEPS ${lite_kernel_deps} reshape_op)
add_kernel(multiclass_nms_compute_host Host basic SRCS multiclass_nms_compute.cc DEPS ${lite_kernel_deps})
add_kernel(crf_decoding_compute_host Host extra SRCS crf_decoding_compute.cc DEPS ${lite_kernel_deps})
add_kernel(ctc_align_compute_host Host extra SRCS ctc_align_compute.cc DEPS ${lite_kernel_deps})
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/kernels/host/ctc_align_compute.h"
#include <algorithm>
#include <cstring>
#include <map>
#include <utility>
#include <vector>
namespace paddle {
namespace lite {
namespace kernels {
namespace host {
LoD ToAbs(const LoD& in) {
if (in.empty()) return in;
LoD result;
for (auto& src : in) {
std::vector<uint64_t> dest(src.size() + 1, 0);
for (size_t i = 0; i < src.size(); i++) {
dest[i + 1] = dest[i] + src[i];
}
result.emplace_back(dest);
}
return result;
}
LoD ToNorm(const LoD& in) {
if (in.empty()) return in;
LoD result;
for (auto& src : in) {
std::vector<uint64_t> dest(src.size() - 1, 0);
for (size_t i = 0; i < dest.size(); i++) {
dest[i] = src[i + 1] - src[i];
}
result.emplace_back(dest);
}
return result;
}
LoD ToAbsOffset(const LoD& in) {
// the lowest level stores relative offsets
if (in.empty() || in.size() == 1) return in;
LoD result = in;
for (auto level = static_cast<int>(in.size() - 2); level >= 0; level--) {
for (size_t i = 0; i < in[level].size(); ++i) {
size_t index = in[level][i];
result[level][i] = result[level + 1][index];
}
}
return result;
}
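// A worked example for the three helpers above, using the LoD from the
// ctc_align test later in this patch: ToAbs turns per-sequence lengths into
// cumulative offsets, ToNorm inverts that, and ToAbsOffset resolves higher
// levels of a multi-level LoD to absolute offsets (a no-op for one level):
//   ToAbs({{11, 7}})      -> {{0, 11, 18}}
//   ToNorm({{0, 11, 18}}) -> {{11, 7}}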
template <typename T, PrecisionType PT>
void CtcAlignCompute<T, PT>::Run() {
auto& param = this->template Param<operators::CtcAlignParam>();
auto* input = param.input;
auto* output = param.output;
size_t blank = static_cast<size_t>(param.blank);
bool merge_repeated = param.merge_repeated;
size_t padding_value = static_cast<size_t>(param.padding_value);
const auto* input_data = input->template data<T>();
auto input_dims = input->dims();
auto* output_data = output->template mutable_data<T>();
if (input->lod().empty()) {
auto* input_length = param.input_length;
auto* output_length = param.output_length;
CHECK(input_length != nullptr);
CHECK(output_length != nullptr);
const auto* input_length_data = input_length->template data<T>();
auto* output_length_data = output_length->template mutable_data<T>();
for (size_t batch_id = 0; batch_id < (unsigned)input_dims[0]; batch_id++) {
T prev_token = -1;
size_t output_idx = 0;
for (size_t i = 0; i < (unsigned)input_length_data[batch_id]; i++) {
size_t input_ind = batch_id * input_dims[1] + i;
if ((unsigned)input_data[input_ind] != blank &&
!(merge_repeated && input_data[input_ind] == prev_token)) {
output_data[batch_id * input_dims[1] + output_idx] =
input_data[input_ind];
++output_idx;
}
prev_token = input_data[input_ind];
}
output_length_data[batch_id] = output_idx;
for (size_t j = output_idx; j < (unsigned)input_dims[1]; j++)
output_data[batch_id * input_dims[1] + j] = padding_value;
}
} else {
const size_t level = 0;
auto input_lod = input->lod();
input_lod = ToAbs(input->lod());
input_lod = ToAbsOffset(input_lod);
CHECK_EQ(input_dims[0], static_cast<int64_t>(input_lod[level].back()));
const size_t num_sequences = input_lod[level].size() - 1;
// merge repeated tokens and delete blank
size_t output_idx = 0;
std::vector<uint64_t> output_lod0(1, 0);
for (size_t seq_idx = 0; seq_idx < num_sequences; ++seq_idx) {
T prev_token = -1;
for (size_t i = input_lod[level][seq_idx];
i < input_lod[level][seq_idx + 1];
++i) {
if ((unsigned)input_data[i] != blank &&
!(merge_repeated && input_data[i] == prev_token)) {
output_data[output_idx] = input_data[i];
++output_idx;
}
prev_token = input_data[i];
}
output_lod0.push_back(static_cast<uint64_t>(output_idx));
}
LoD output_lod;
output_lod.push_back(output_lod0);
output_lod = ToNorm(output_lod);
output->set_lod(output_lod);
output->Resize({static_cast<int64_t>(output_lod0.back()), 1});
if (output_lod0.back() == 0) {
output->Resize({1, 1});
output_data = output->template mutable_data<T>();
output_data[0] = -1;
}
}
}
} // namespace host
} // namespace kernels
} // namespace lite
} // namespace paddle
using ctc_align_int64 =
paddle::lite::kernels::host::CtcAlignCompute<int64_t, PRECISION(kInt64)>;
REGISTER_LITE_KERNEL(ctc_align, kHost, kInt64, kNCHW, ctc_align_int64, def)
.BindInput("Input",
{LiteType::GetTensorTy(TARGET(kHost), PRECISION(kInt64))})
.BindInput("InputLength",
{LiteType::GetTensorTy(TARGET(kHost), PRECISION(kInt64))})
.BindOutput("Output",
{LiteType::GetTensorTy(TARGET(kHost), PRECISION(kInt64))})
.BindOutput("OutputLength",
{LiteType::GetTensorTy(TARGET(kHost), PRECISION(kInt64))})
.Finalize();
using ctc_align_int32 =
paddle::lite::kernels::host::CtcAlignCompute<int32_t, PRECISION(kInt32)>;
REGISTER_LITE_KERNEL(ctc_align, kHost, kInt32, kNCHW, ctc_align_int32, def)
.BindInput("Input",
{LiteType::GetTensorTy(TARGET(kHost), PRECISION(kInt32))})
.BindInput("InputLength",
{LiteType::GetTensorTy(TARGET(kHost), PRECISION(kInt32))})
.BindOutput("Output",
{LiteType::GetTensorTy(TARGET(kHost), PRECISION(kInt32))})
.BindOutput("OutputLength",
{LiteType::GetTensorTy(TARGET(kHost), PRECISION(kInt32))})
.Finalize();
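A quick sanity check of the padded-path semantics, taken from the second test case later in this patch: with blank = 0, merge_repeated = true and padding_value = 0, the input row [0, 1, 2, 2, 0, 4] (valid length 6) keeps only [1, 2, 4], so the output row is [1, 2, 4, 0, 0, 0] with output length 3.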
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "lite/core/kernel.h"
#include "lite/core/op_registry.h"
#include "lite/core/tensor.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace host {
template <typename T, PrecisionType PT>
class CtcAlignCompute : public KernelLite<TARGET(kHost), PT> {
public:
void Run() override;
virtual ~CtcAlignCompute() = default;
};
} // namespace host
} // namespace kernels
} // namespace lite
} // namespace paddle
......@@ -2,7 +2,7 @@ if(NOT LITE_WITH_X86)
return()
endif()
-add_kernel(activation_compute_x86 X86 basic SRCS activation_compute.cc DEPS ${lite_kernel_deps} activation_ops math_function)
+add_kernel(activation_compute_x86 X86 basic SRCS activation_compute.cc DEPS ${lite_kernel_deps} math_function)
# lite_cc_library(mean_compute_x86 SRCS mean_compute.cc DEPS ${lite_kernel_deps})
# lite_cc_library(fill_constant_compute_x86 SRCS fill_constant_compute.cc DEPS ${lite_kernel_deps})
# lite_cc_library(sgd_compute_x86 SRCS sgd_compute.cc DEPS ${lite_kernel_deps})
......
......@@ -21,7 +21,7 @@
#include "lite/core/op_lite.h"
#include "lite/core/op_registry.h"
#include "lite/fluid/eigen.h"
#include "lite/operators/activation_ops.h"
#include "lite/operators/op_params.h"
namespace paddle {
namespace lite {
......
......@@ -14,7 +14,7 @@ add_operator(reshape_op basic SRCS reshape_op.cc DEPS ${op_DEPS} )
add_operator(batch_norm_op basic SRCS batch_norm_op.cc DEPS ${op_DEPS})
add_operator(feed_op basic SRCS feed_op.cc DEPS ${op_DEPS})
add_operator(fetch_op basic SRCS fetch_op.cc DEPS ${op_DEPS})
-add_operator(activation_ops basic SRCS activation_ops.cc DEPS ${op_DEPS})
+add_operator(activation_basic_ops basic SRCS activation_ops.cc DEPS ${op_DEPS})
add_operator(elementwise_ops basic SRCS elementwise_ops.cc DEPS ${op_DEPS})
add_operator(box_coder_op_lite basic SRCS box_coder_op.cc DEPS ${op_DEPS})
add_operator(multiclass_nms_op_lite basic SRCS multiclass_nms_op.cc DEPS ${op_DEPS})
......@@ -60,6 +60,7 @@ add_operator(power_op extra SRCS power_op.cc DEPS ${op_DEPS})
add_operator(norm_op extra SRCS norm_op.cc DEPS ${op_DEPS})
# 3.extra ops
add_operator(activation_extra_ops extra SRCS activation_extra_ops.cc DEPS ${op_DEPS})
add_operator(search_group_padding extra SRCS search_group_padding_op.cc DEPS ${op_DEPS})
add_operator(lrn_op_lite extra SRCS lrn_op.cc DEPS ${op_DEPS})
add_operator(decode_bboxes_op_lite extra SRCS decode_bboxes_op.cc DEPS ${op_DEPS})
......@@ -106,6 +107,7 @@ add_operator(conditional_block_op_lite extra SRCS conditional_block_op.cc DEPS $
add_operator(collect_fpn_proposals_op_lite extra SRCS collect_fpn_proposals_op.cc DEPS ${op_DEPS})
add_operator(distribute_fpn_proposals_op_lite extra SRCS distribute_fpn_proposals_op.cc DEPS ${op_DEPS})
add_operator(crf_decoding_op_lite extra SRCS crf_decoding_op.cc DEPS ${op_DEPS})
add_operator(ctc_align_op_lite extra SRCS ctc_align_op.cc DEPS ${op_DEPS})
# for OCR specific
add_operator(while_op extra SRCS while_op.cc DEPS ${op_DEPS})
......
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/core/op_registry.h"
#include "lite/operators/activation_ops.h"
// Extra activation ops
REGISTER_LITE_OP(square, paddle::lite::operators::ActivationOp);
REGISTER_LITE_OP(relu_clipped, paddle::lite::operators::ActivationOp);
REGISTER_LITE_OP(swish, paddle::lite::operators::ActivationOp);
REGISTER_LITE_OP(log, paddle::lite::operators::ActivationOp);
REGISTER_LITE_OP(exp, paddle::lite::operators::ActivationOp);
REGISTER_LITE_OP(abs, paddle::lite::operators::ActivationOp);
REGISTER_LITE_OP(floor, paddle::lite::operators::ActivationOp);
REGISTER_LITE_OP(hard_sigmoid, paddle::lite::operators::ActivationOp);
REGISTER_LITE_OP(sqrt, paddle::lite::operators::ActivationOp);
REGISTER_LITE_OP(rsqrt, paddle::lite::operators::ActivationOp);
REGISTER_LITE_OP(softsign, paddle::lite::operators::ActivationOp);
REGISTER_LITE_OP(gelu, paddle::lite::operators::ActivationOp);
REGISTER_LITE_OP(hard_swish, paddle::lite::operators::ActivationOp);
REGISTER_LITE_OP(reciprocal, paddle::lite::operators::ActivationOp);
......@@ -74,6 +74,14 @@ bool ActivationOp::AttachImpl(const cpp::OpDesc& opdesc, lite::Scope* scope) {
} else if (opdesc.Type() == "abs") {
// abs
param_.active_type = lite_api::ActivationType::kAbs;
} else if (opdesc.Type() == "hard_swish") {
// hard_swish
param_.active_type = lite_api::ActivationType::kHardSwish;
param_.hard_swish_threshold = opdesc.GetAttr<float>("threshold");
param_.hard_swish_scale = opdesc.GetAttr<float>("scale");
param_.hard_swish_offset = opdesc.GetAttr<float>("offset");
} else if (opdesc.Type() == "reciprocal") {
param_.active_type = lite_api::ActivationType::kReciprocal;
}
VLOG(4) << "opdesc.Type():" << opdesc.Type();
......@@ -84,21 +92,11 @@ bool ActivationOp::AttachImpl(const cpp::OpDesc& opdesc, lite::Scope* scope) {
} // namespace operators
} // namespace lite
} // namespace paddle
-REGISTER_LITE_OP(square, paddle::lite::operators::ActivationOp);
-REGISTER_LITE_OP(relu, paddle::lite::operators::ActivationOp);
-REGISTER_LITE_OP(leaky_relu, paddle::lite::operators::ActivationOp);
-REGISTER_LITE_OP(relu_clipped, paddle::lite::operators::ActivationOp);
-REGISTER_LITE_OP(prelu, paddle::lite::operators::ActivationOp);
+// Basic activation ops
REGISTER_LITE_OP(sigmoid, paddle::lite::operators::ActivationOp);
REGISTER_LITE_OP(tanh, paddle::lite::operators::ActivationOp);
-REGISTER_LITE_OP(swish, paddle::lite::operators::ActivationOp);
+REGISTER_LITE_OP(relu, paddle::lite::operators::ActivationOp);
+REGISTER_LITE_OP(leaky_relu, paddle::lite::operators::ActivationOp);
REGISTER_LITE_OP(relu6, paddle::lite::operators::ActivationOp);
-REGISTER_LITE_OP(log, paddle::lite::operators::ActivationOp);
-REGISTER_LITE_OP(exp, paddle::lite::operators::ActivationOp);
-REGISTER_LITE_OP(abs, paddle::lite::operators::ActivationOp);
-REGISTER_LITE_OP(floor, paddle::lite::operators::ActivationOp);
-REGISTER_LITE_OP(hard_sigmoid, paddle::lite::operators::ActivationOp);
-REGISTER_LITE_OP(sqrt, paddle::lite::operators::ActivationOp);
-REGISTER_LITE_OP(rsqrt, paddle::lite::operators::ActivationOp);
-REGISTER_LITE_OP(softsign, paddle::lite::operators::ActivationOp);
-REGISTER_LITE_OP(gelu, paddle::lite::operators::ActivationOp);
+REGISTER_LITE_OP(prelu, paddle::lite::operators::ActivationOp);
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/operators/ctc_align_op.h"
#include <vector>
#include "lite/core/op_lite.h"
#include "lite/core/op_registry.h"
namespace paddle {
namespace lite {
namespace operators {
bool CtcAlignOpLite::CheckShape() const {
CHECK_OR_FALSE(param_.input != nullptr);
CHECK_OR_FALSE(param_.output != nullptr);
auto* input = param_.input;
auto* input_length = param_.input_length;
auto input_lod = input->lod();
CHECK_OR_FALSE(!input_lod.empty() || input_length != nullptr);
return true;
}
bool CtcAlignOpLite::InferShapeImpl() const {
auto input_dims = param_.input->dims();
// The shape set here is provisional; the kernel resizes the output at run time.
param_.output->Resize(input_dims);
if (param_.input_length != nullptr && param_.output_length != nullptr) {
param_.output_length->Resize({input_dims[0], 1});
}
return true;
}
bool CtcAlignOpLite::AttachImpl(const cpp::OpDesc& op_desc,
lite::Scope* scope) {
AttachInput(op_desc, scope, "Input", false, &param_.input);
AttachInput(op_desc, scope, "InputLength", true, &param_.input_length);
AttachOutput(op_desc, scope, "Output", false, &param_.output);
AttachOutput(op_desc, scope, "OutputLength", true, &param_.output_length);
param_.blank = op_desc.GetAttr<int>("blank");
param_.merge_repeated = op_desc.GetAttr<bool>("merge_repeated");
param_.padding_value = op_desc.GetAttr<int>("padding_value");
return true;
}
} // namespace operators
} // namespace lite
} // namespace paddle
REGISTER_LITE_OP(ctc_align, paddle::lite::operators::CtcAlignOpLite);
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <string>
#include "lite/core/op_lite.h"
#include "lite/core/scope.h"
#include "lite/operators/op_params.h"
#include "lite/utils/all.h"
namespace paddle {
namespace lite {
namespace operators {
class CtcAlignOpLite : public OpLite {
public:
CtcAlignOpLite() {}
explicit CtcAlignOpLite(const std::string &op_type) : OpLite(op_type) {}
bool CheckShape() const override;
bool InferShapeImpl() const override;
bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override;
void AttachKernel(KernelBase *kernel) override { kernel->SetParam(param_); }
std::string DebugString() const override { return "ctc_align"; }
private:
mutable CtcAlignParam param_;
};
} // namespace operators
} // namespace lite
} // namespace paddle
......@@ -336,17 +336,22 @@ struct ConcatParam : ParamBase {
/// ----------------------- activation operators ----------------------
struct ActivationParam : ParamBase {
const lite::Tensor* X{};
+lite::Tensor* Out{};
+lite_api::ActivationType active_type;
+bool has_active{false};
float Leaky_relu_alpha{0}; // leaky_relu param
float Relu_clipped_coef{6}; // relu_clipped param
std::string Prelu_mode{
"channel"}; // prelu param, can be "all", "channel" or "element"
lite::Tensor* Prelu_alpha{}; // prelu param
float Swish_beta; // swish param
// hard_sigmoid param
float hard_sigmoid_slope{0.2};
float hard_sigmoid_offset{0.5};
-lite::Tensor* Out{};
-bool has_active{false};
-lite_api::ActivationType active_type;
+// hard_swish param
+float hard_swish_threshold{6.0};
+float hard_swish_scale{6.0};
+float hard_swish_offset{3.0};
};
struct ActivationGradParam : ParamBase {
......@@ -1444,6 +1449,16 @@ struct CrfDecodingParam : ParamBase {
lite::Tensor* viterbi_path{};
};
struct CtcAlignParam : ParamBase {
lite::Tensor* input{};
lite::Tensor* input_length{};
lite::Tensor* output{};
lite::Tensor* output_length{};
int blank{0};
bool merge_repeated{true};
int padding_value{0};
};
struct XPUResNet50Param : ParamBase {
lite::Tensor* input{};
std::vector<lite::Tensor*> filter;
......
......@@ -61,6 +61,7 @@ if(LITE_BUILD_EXTRA)
lite_cc_test(test_kernel_lookup_table_compute SRCS lookup_table_compute_test.cc DEPS arena_framework ${xpu_kernels} ${npu_kernels} ${bm_kernels} ${x86_kernels} ${cuda_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
lite_cc_test(test_kernel_lookup_table_dequant_compute SRCS lookup_table_dequant_compute_test.cc DEPS arena_framework ${xpu_kernels} ${npu_kernels} ${bm_kernels} ${x86_kernels} ${cuda_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
lite_cc_test(test_kernel_gather_compute SRCS gather_compute_test.cc DEPS arena_framework ${xpu_kernels} ${npu_kernels} ${x86_kernels} ${bm_kernels} ${cuda_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
lite_cc_test(test_kernel_ctc_align_compute SRCS ctc_align_compute_test.cc DEPS arena_framework ${xpu_kernels} ${npu_kernels} ${x86_kernels} ${bm_kernels} ${cuda_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
# for training kernel
if (LITE_WITH_TRAIN)
......
......@@ -36,7 +36,9 @@ enum activation_type_test {
FLOOR,
RSQRT,
GELU,
SQUARE
SQUARE,
HARD_SWISH,
RECIPROCAL
};
class ActivationComputeTester : public arena::TestCase {
......@@ -49,6 +51,9 @@ class ActivationComputeTester : public arena::TestCase {
float relu_clipped_coef_ = 6.;
std::string prelu_mode_ = "";
float swish_beta_ = 0.;
float hard_swish_threshold = 6.0;
float hard_swish_scale = 6.0;
float hard_swish_offset = 3.0;
DDim dims_{{1}};
std::string type_ = "";
activation_type_test act_type_ = RELU;
......@@ -199,6 +204,20 @@ class ActivationComputeTester : public arena::TestCase {
}
break;
}
case HARD_SWISH: {
for (int i = 0; i < dims_.production(); i++) {
float max_value = std::max(0.f, x_data[i] + hard_swish_offset);
float min_value = std::min(max_value, hard_swish_threshold);
output_data[i] = min_value * x_data[i] / hard_swish_scale;
}
break;
}
case RECIPROCAL: {
for (int i = 0; i < dims_.production(); i++) {
output_data[i] = 1.0 / x_data[i];
}
break;
}
default:
LOG(INFO) << "the type of activation is unknow.";
}
......@@ -221,6 +240,11 @@ class ActivationComputeTester : public arena::TestCase {
if (act_type_ == SWISH) {
op_desc->SetAttr("beta", swish_beta_);
}
if (act_type_ == HARD_SWISH) {
op_desc->SetAttr("threshold", hard_swish_threshold);
op_desc->SetAttr("scale", hard_swish_scale);
op_desc->SetAttr("offset", hard_swish_offset);
}
}
void PrepareData() override {
......@@ -552,5 +576,61 @@ TEST(Activation_gelu, precision) {
}
}
TEST(activation_hard_swish, precision) {
LOG(INFO) << "test hard_swish op";
Place place;
float abs_error = 2e-5;
#if defined(LITE_WITH_ARM)
place = TARGET(kARM);
#else
return;
#endif
for (auto dims : std::vector<std::vector<int64_t>>{
{1, 3, 2, 4}, {2, 3, 4}, {5, 4}, {8}}) {
std::unique_ptr<arena::TestCase> tester(
new ActivationComputeTester(place,
"def",
0.01,
6.,
"all",
0.,
DDim(dims),
"hard_swish",
HARD_SWISH));
arena::Arena arena(std::move(tester), place, abs_error);
arena.TestPrecision();
}
}
TEST(activation_reciprocal, precision) {
LOG(INFO) << "test reciprocal op";
Place place;
float abs_error = 2e-5;
#if defined(LITE_WITH_ARM)
place = TARGET(kARM);
#else
return;
#endif
for (auto dims : std::vector<std::vector<int64_t>>{
{1, 3, 2, 4}, {2, 3, 4}, {5, 4}, {8}}) {
std::unique_ptr<arena::TestCase> tester(
new ActivationComputeTester(place,
"def",
0.01,
6.,
"all",
0.,
DDim(dims),
"reciprocal",
RECIPROCAL));
arena::Arena arena(std::move(tester), place, abs_error);
arena.TestPrecision();
}
}
} // namespace lite
} // namespace paddle
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <gtest/gtest.h>
#include "lite/api/paddle_use_kernels.h"
#include "lite/api/paddle_use_ops.h"
#include "lite/core/arena/framework.h"
namespace paddle {
namespace lite {
class CtcAlignComputeTester : public arena::TestCase {
protected:
// common attributes for this op.
std::string input_ = "input";
std::string input_length_ = "input_length";
std::string output_ = "output";
std::string output_length_ = "output_length";
std::vector<int> input_data_;
std::vector<int64_t> input_shape_;
std::vector<std::vector<uint64_t>> input_lod_;
std::vector<int> input_length_data_;
std::vector<int64_t> input_length_shape_;
std::vector<int> output_data_;
std::vector<int64_t> output_shape_;
std::vector<std::vector<uint64_t>> output_lod_;
std::vector<int> output_length_data_;
std::vector<int64_t> output_length_shape_;
int blank_;
bool merge_repeated_;
int padding_value_;
public:
CtcAlignComputeTester(const Place& place,
const std::string& alias,
const std::vector<int>& input_data,
const std::vector<int64_t> input_shape,
const std::vector<std::vector<uint64_t>>& input_lod,
const std::vector<int>& input_length_data,
const std::vector<int64_t> input_length_shape,
const int blank,
const bool merge_repeated,
const int padding_value,
const std::vector<int>& output_data,
const std::vector<int64_t>& output_shape,
const std::vector<std::vector<uint64_t>>& output_lod,
const std::vector<int>& output_length_data,
const std::vector<int64_t>& output_length_shape)
: TestCase(place, alias) {
input_data_ = input_data;
input_shape_ = input_shape;
input_lod_ = input_lod;
input_length_data_ = input_length_data;
input_length_shape_ = input_length_shape;
blank_ = blank;
merge_repeated_ = merge_repeated;
padding_value_ = padding_value;
output_data_ = output_data;
output_shape_ = output_shape;
output_lod_ = output_lod;
output_length_data_ = output_length_data;
output_length_shape_ = output_length_shape;
}
void RunBaseline(Scope* scope) override {
auto* output_tensor = scope->NewTensor(output_);
output_tensor->Resize(output_shape_);
if (!output_lod_.empty()) {
output_tensor->set_lod(output_lod_);
}
auto* output_data = output_tensor->mutable_data<int>();
int64_t output_num = 1;
for (auto e : output_shape_) {
output_num *= e;
}
for (int i = 0; i < output_num; i++) {
output_data[i] = output_data_[i];
}
if (!input_length_data_.empty() && !output_length_data_.empty()) {
auto* output_length_tensor = scope->NewTensor(output_length_);
output_length_tensor->Resize(output_length_shape_);
auto* output_length_data = output_length_tensor->mutable_data<int>();
int64_t num = 1;
for (auto e : output_length_shape_) {
num *= e;
}
for (int i = 0; i < num; i++) {
output_length_data[i] = output_length_data_[i];
}
}
}
void PrepareOpDesc(cpp::OpDesc* op_desc) override {
op_desc->SetType("ctc_align");
op_desc->SetInput("Input", {input_});
op_desc->SetOutput("Output", {output_});
if (!input_length_data_.empty()) {
op_desc->SetInput("InputLength", {input_length_});
op_desc->SetOutput("OutputLength", {output_length_});
}
op_desc->SetAttr("blank", blank_);
op_desc->SetAttr("merge_repeated", merge_repeated_);
op_desc->SetAttr("padding_value", padding_value_);
}
void PrepareData() override {
SetCommonTensor(input_, DDim(input_shape_), input_data_.data(), input_lod_);
if (!input_length_data_.empty()) {
SetCommonTensor(
input_length_, DDim(input_length_shape_), input_length_data_.data());
}
}
};
TEST(CtcAlign1, precision) {
LOG(INFO) << "test ctc_align op";
#ifdef LITE_WITH_ARM
// Define variables
const std::vector<int>& input_data = {
0, 1, 2, 2, 0, 4, 0, 4, 5, 0, 6, 6, 0, 0, 7, 7, 7, 0};
const std::vector<int64_t> input_shape = {18, 1};
const std::vector<std::vector<uint64_t>> input_lod = {{11, 7}};
const std::vector<int> input_length_data = {};
const std::vector<int64_t> input_length_shape = {};
const int blank = 0;
const bool merge_repeated = false;
const int padding_value = 0;
const std::vector<int> output_data = {1, 2, 2, 4, 4, 5, 6, 6, 7, 7, 7};
const std::vector<int64_t> output_shape = {11, 1};
const std::vector<std::vector<uint64_t>> output_lod = {{7, 4}};
const std::vector<int> output_length_data = {};
const std::vector<int64_t> output_length_shape = {};
// Test
Place place(TARGET(kHost), PRECISION(kInt32));
std::unique_ptr<arena::TestCase> tester(
new CtcAlignComputeTester(place,
"def",
input_data,
input_shape,
input_lod,
input_length_data,
input_length_shape,
blank,
merge_repeated,
padding_value,
output_data,
output_shape,
output_lod,
output_length_data,
output_length_shape));
arena::Arena arena(std::move(tester), place, 2e-5);
arena.TestPrecision();
#endif
}
TEST(CtcAlign2, precision) {
LOG(INFO) << "test ctc_align op";
#ifdef LITE_WITH_ARM
// Define variables
const std::vector<int>& input_data = {
0, 1, 2, 2, 0, 4, 0, 4, 5, 0, 6, 0, 0, 7, 7, 7, 0, 0};
const std::vector<int64_t> input_shape = {3, 6};
const std::vector<std::vector<uint64_t>> input_lod = {};
const std::vector<int> input_length_data = {6, 5, 4};
const std::vector<int64_t> input_length_shape = {3, 1};
const int blank = 0;
const bool merge_repeated = true;
const int padding_value = 0;
const std::vector<int> output_data = {
1, 2, 4, 0, 0, 0, 4, 5, 6, 0, 0, 0, 7, 0, 0, 0, 0, 0};
const std::vector<int64_t> output_shape = {3, 6};
const std::vector<std::vector<uint64_t>> output_lod = {};
const std::vector<int> output_length_data = {3, 3, 1};
const std::vector<int64_t> output_length_shape = {3, 1};
// Test
Place place(TARGET(kHost), PRECISION(kInt32));
std::unique_ptr<arena::TestCase> tester(
new CtcAlignComputeTester(place,
"def",
input_data,
input_shape,
input_lod,
input_length_data,
input_length_shape,
blank,
merge_repeated,
padding_value,
output_data,
output_shape,
output_lod,
output_length_data,
output_length_shape));
arena::Arena arena(std::move(tester), place, 2e-5);
arena.TestPrecision();
#endif
}
TEST(CtcAlign3, precision) {
LOG(INFO) << "test ctc_align op";
#ifdef LITE_WITH_ARM
// Define variables
const std::vector<int>& input_data = {
0, 1, 2, 2, 0, 4, 0, 4, 5, 0, 6, 0, 0, 7, 7, 7, 0, 0};
const std::vector<int64_t> input_shape = {3, 6};
const std::vector<std::vector<uint64_t>> input_lod = {};
const std::vector<int> input_length_data = {6, 5, 4};
const std::vector<int64_t> input_length_shape = {3, 1};
const int blank = 0;
const bool merge_repeated = false;
const int padding_value = 0;
const std::vector<int> output_data = {
1, 2, 2, 4, 0, 0, 4, 5, 6, 0, 0, 0, 7, 7, 7, 0, 0, 0};
const std::vector<int64_t> output_shape = {3, 6};
const std::vector<std::vector<uint64_t>> output_lod = {};
const std::vector<int> output_length_data = {4, 3, 3};
const std::vector<int64_t> output_length_shape = {3, 1};
// Test
Place place(TARGET(kHost), PRECISION(kInt32));
std::unique_ptr<arena::TestCase> tester(
new CtcAlignComputeTester(place,
"def",
input_data,
input_shape,
input_lod,
input_length_data,
input_length_shape,
blank,
merge_repeated,
padding_value,
output_data,
output_shape,
output_lod,
output_length_data,
output_length_shape));
arena::Arena arena(std::move(tester), place, 2e-5);
arena.TestPrecision();
#endif
}
} // namespace lite
} // namespace paddle