未验证 提交 98292030 编写于 作者: X xiaogang 提交者: GitHub

[cherry-pick]Release/v2.3 (#3310)

* add lookup_dequant_op (#3108)

* add lookup_dequant_op

* add sequence_conv op and arm kernel (#3016)

* add sequence_conv op and arm kernel

* add test, test=develop

* test=develop

* test=develop

* test=develop

* test=develop

* test=develop

* test=develop

* test=develop

* test=develop

* test=develop

* test=develop

* test=develop

* test=develop

* test=develop

* test=develop

* modify code style. test=develop

* fix ut, test=develop

* delete unused code, test=develop

* fix sgemm bug, test=develop (#3053)
Co-authored-by: Nmapingshuo <mps2012@yeah.net>
上级 7fa2a776
......@@ -573,6 +573,22 @@ template void conv_im2col_gemm_int8<float>(const int8_t* i_data,
ARMContext* ctx,
const float* scale);
template void im2col<float>(const float* data_im,
int channels,
int height,
int width,
int kernel_h,
int kernel_w,
int pad_top,
int pad_bottom,
int pad_left,
int pad_right,
int stride_h,
int stride_w,
int dilation_h,
int dilation_w,
float* data_col);
void conv_depthwise_3x3_fp32(const void* din,
void* dout,
int num,
......
......@@ -359,6 +359,24 @@ void conv_compute_2x2_3x3_small(const float* input,
const float* bias,
const operators::ConvParam& param,
ARMContext* ctx);
template <typename Dtype>
void im2col(const Dtype* data_im,
int channels,
int height,
int width,
int kernel_h,
int kernel_w,
int pad_top,
int pad_bottom,
int pad_left,
int pad_right,
int stride_h,
int stride_w,
int dilation_h,
int dilation_w,
Dtype* data_col);
} // namespace math
} // namespace arm
} // namespace lite
......
......@@ -68,6 +68,7 @@ add_kernel(reduce_max_compute_arm ARM extra SRCS reduce_max_compute.cc DEPS ${li
add_kernel(sequence_expand_compute_arm ARM extra SRCS sequence_expand_compute.cc DEPS ${lite_kernel_deps} math_arm)
add_kernel(im2sequence_compute_arm ARM extra SRCS im2sequence_compute.cc DEPS ${lite_kernel_deps} math_arm)
add_kernel(sequence_pool_compute_arm ARM extra SRCS sequence_pool_compute.cc DEPS ${lite_kernel_deps} math_arm)
add_kernel(sequence_conv_compute_arm ARM extra SRCS sequence_conv_compute.cc DEPS ${lite_kernel_deps} math_arm)
add_kernel(layer_norm_compute_arm ARM extra SRCS layer_norm_compute.cc DEPS ${lite_kernel_deps} math_arm)
add_kernel(gather_compute_arm ARM extra SRCS gather_compute.cc DEPS ${lite_kernel_deps} math_arm)
add_kernel(reduce_prod_compute_arm ARM extra SRCS reduce_prod_compute.cc DEPS ${lite_kernel_deps} math_arm)
......@@ -88,6 +89,7 @@ add_kernel(gru_unit_compute_arm ARM extra SRCS gru_unit_compute.cc DEPS ${lite_k
add_kernel(gru_compute_arm ARM extra SRCS gru_compute.cc DEPS ${lite_kernel_deps} math_arm)
add_kernel(beam_search_decode_compute_arm ARM extra SRCS beam_search_decode_compute.cc DEPS ${lite_kernel_deps} math_arm)
add_kernel(lookup_table_compute_arm ARM extra SRCS lookup_table_compute.cc DEPS ${lite_kernel_deps} math_arm)
add_kernel(lookup_table_dequant_compute_arm ARM extra SRCS lookup_table_dequant_compute.cc DEPS ${lite_kernel_deps} math_arm)
add_kernel(logical_compute_arm ARM extra SRCS logical_compute.cc DEPS ${lite_kernel_deps} math_arm)
add_kernel(sequence_softmax_compute_arm ARM extra SRCS sequence_softmax_compute.cc DEPS ${lite_kernel_deps} math_arm)
add_kernel(less_than_arm ARM extra SRCS compare_compute.cc DEPS ${lite_kernel_deps} math_arm)
......
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/kernels/arm/lookup_table_dequant_compute.h"
#include <string>
#include <vector>
#include "lite/api/paddle_place.h"
#include "lite/backends/arm/math/funcs.h"
#include "lite/core/op_registry.h"
#include "lite/core/tensor.h"
#include "lite/core/type_system.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace arm {
void dequant(const unsigned char *in,
float *out,
float min,
float max,
int emb_size,
int pow_2_bits) {
float scale = (max - min) / pow_2_bits;
for (int i = 0; i < emb_size; ++i) {
float x = scale * static_cast<int>(in[i]) + min;
out[i] = x;
}
}
void LookupTableDequantCompute::Run() {
auto &param = this->Param<param_t>();
// inputs
auto w = param.W;
auto ids = param.Ids;
// outputs
auto out = param.Out;
auto table_dim = w->dims();
int64_t ids_numel = ids->numel();
auto ids_data = ids->data<int64_t>();
int64_t row_number = table_dim[0];
int64_t quant_number = table_dim[1];
int64_t row_width = (quant_number - 2) * 4;
auto table_data = w->data<float>();
auto dout = out->mutable_data<float>();
int pow_2_bits = static_cast<int>(pow(2, 8));
for (int64_t i = 0; i < ids_numel; ++i) {
int ids_int = ids_data[i];
if (param.padding_idx != -1 && ids_data[i] == param.padding_idx) {
memset(dout + i * row_width, 0, row_width * sizeof(float));
} else {
CHECK_LT(ids_data[i], row_number)
<< "look uptable ids[i] < row_number check failed";
CHECK_GE(ids_data[i], 0) << "lookuptable ids[i] >= 0 check failed";
float min = *(table_data + ids_data[i] * quant_number);
float max = *(table_data + ids_data[i] * quant_number + 1);
int offset = ids_data[i] * quant_number + 2;
const unsigned char *tensor_buf =
reinterpret_cast<const unsigned char *>(table_data + offset);
dequant(
tensor_buf, dout + i * row_width, min, max, row_width, pow_2_bits);
// memcpy(dout + i * row_width,
// table_data + ids_int * row_width,
// row_width * sizeof(float));
}
}
*(out->mutable_lod()) = ids->lod();
}
} // namespace arm
} // namespace kernels
} // namespace lite
} // namespace paddle
REGISTER_LITE_KERNEL(lookup_table_dequant,
kARM,
kAny,
kNCHW,
paddle::lite::kernels::arm::LookupTableDequantCompute,
def)
.BindInput("W", {LiteType::GetTensorTy(TARGET(kARM))})
.BindInput("Ids", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt64))})
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM))})
.Finalize();
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <algorithm>
#include "lite/core/kernel.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace arm {
class LookupTableDequantCompute
: public KernelLite<TARGET(kARM), PRECISION(kAny)> {
public:
using param_t = operators::LookupTableDequantParam;
LookupTableDequantCompute() = default;
void Run() override;
virtual ~LookupTableDequantCompute() = default;
};
} // namespace arm
} // namespace kernels
} // namespace lite
} // namespace paddle
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "lite/kernels/arm/sequence_conv_compute.h"
#include <algorithm>
#include <cstddef>
#include <string>
#include <vector>
#include "lite/backends/arm/math/conv_impl.h"
#include "lite/backends/arm/math/sgemm.h"
#include "lite/core/op_registry.h"
#include "lite/core/tensor.h"
#include "lite/core/type_system.h"
#include "lite/operators/op_params.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace arm {
template <typename Dtype>
void local_naive_transpose(const Dtype* din, Dtype* dout, int m, int n) {
int k = 0;
for (int i = 0; i < n; ++i) {
for (int j = 0; j < m; ++j) {
dout[k++] = din[j * n + i];
}
}
}
void SequenceConvCompute::PrepareForRun() {}
void SequenceConvCompute::Run() {
// param.X is in shape: [sequence_len, hidden_dim];
// param.Filter is in shape: [kernel_size * hidden_dim, kernel_num]
// param.contextLength : kernel_size
// param.contextStart: for padding idx
// param.Out is in shape [new_sequence_len, kernel_num]
auto& param = this->Param<operators::SequenceConvParam>();
auto& ctx = this->ctx_->template As<ARMContext>();
const auto* in_data = param.X->data<float>();
const auto* filter_data = param.Filter->data<float>();
float* out_data = param.Out->mutable_data<float>();
int pad_start = param.contextStart;
int kernel_size = param.contextLength;
int kernel_num = param.Filter->dims()[1];
int up_pad = std::max(0, -pad_start);
int down_pad = std::max(0, pad_start + kernel_size - 1);
auto hidden_dim = static_cast<int64_t>(param.X->dims()[1]);
auto sequence_len = static_cast<int64_t>(param.X->dims()[0]);
auto lod = param.X->lod();
// Im2Col
lite::Tensor col;
lite::Tensor tmp;
col.Resize({sequence_len, kernel_size * hidden_dim});
auto* col_data = col.mutable_data<float>();
auto lod_level_0 = lod[0];
int input_row_begin, input_row_end;
for (int i = 0; i < static_cast<int>(lod_level_0.size()) - 1; i++) {
if (lod_level_0[i] == lod_level_0[i + 1]) continue;
input_row_begin = (pad_start > 0)
? static_cast<int>(lod_level_0[i]) + pad_start
: static_cast<int>(lod_level_0[i]);
input_row_end = static_cast<int>(lod_level_0[i + 1]);
if (input_row_begin < input_row_end) {
// do im2col
auto* sub_in_data = in_data + input_row_begin * hidden_dim;
auto* sub_col_data =
col_data + input_row_begin * kernel_size * hidden_dim;
tmp.Resize({kernel_size * hidden_dim, input_row_end - input_row_begin});
auto* tmp_data = tmp.mutable_data<float>();
// Image Col: [input_channels, filter_height, filter_width, output_height,
// output_width]
// sequence Col: [1, kernel_size, hidden_dim, sequence_len, 1]
paddle::lite::arm::math::im2col(
sub_in_data,
1,
sequence_len,
hidden_dim, // C H W -> 1, seq_len, hidden_dim
kernel_size,
hidden_dim, // kernel_h, kernel_w
up_pad,
down_pad,
0,
0, // pad_top, pad_bottom, pad_left, pad_right
1,
1,
1,
1, // stride_h, stride_w, dilation_h, dilation_w
tmp_data);
local_naive_transpose(tmp_data,
sub_col_data,
kernel_size * hidden_dim,
input_row_end - input_row_begin);
}
}
// SGDMM C := alpha * A * B + beta * C
// matmul: col * filter_data
// [sequence_len, kernel_size * hidden_dim] * [kernel_size * hidden_dim,
// kernel_num]
// = [sequence_len, kernel_num]
paddle::lite::operators::ActivationParam act_param;
paddle::lite::arm::math::sgemm(false,
false, // is_transB,
sequence_len, // M
kernel_num, // N
kernel_size * hidden_dim, // K
1.0f, // alpha
col_data, // A
kernel_size * hidden_dim, // lda: k
filter_data, // B
kernel_num, // ldb: n
0.f, // beta
out_data, // C
kernel_num, // ldc: n
NULL, // bias
false, // is_bias
act_param, // act_param
&ctx); // ctx
}
} // namespace arm
} // namespace kernels
} // namespace lite
} // namespace paddle
REGISTER_LITE_KERNEL(sequence_conv,
kARM,
kFloat,
kNCHW,
paddle::lite::kernels::arm::SequenceConvCompute,
def)
.BindInput("X", {LiteType::GetTensorTy(TARGET(kARM))})
.BindInput("Filter", {LiteType::GetTensorTy(TARGET(kARM))})
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM))})
.Finalize();
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <stdint.h>
#include "lite/core/kernel.h"
#include "lite/core/op_registry.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace arm {
class SequenceConvCompute : public KernelLite<TARGET(kARM), PRECISION(kFloat)> {
public:
void PrepareForRun() override;
void Run() override;
virtual ~SequenceConvCompute() = default;
private:
};
} // namespace arm
} // namespace kernels
} // namespace lite
} // namespace paddle
......@@ -90,6 +90,7 @@ add_operator(reduce_prod_op_lite extra SRCS reduce_prod_op.cc DEPS ${op_DEPS})
add_operator(sequence_reshape_op_lite extra SRCS sequence_reshape_op.cc DEPS ${op_DEPS})
add_operator(sequence_reverse_op_lite extra SRCS sequence_reverse_op.cc DEPS ${op_DEPS})
add_operator(sequence_pool extra SRCS sequence_pool_op.cc DEPS ${op_DEPS})
add_operator(sequence_conv extra SRCS sequence_conv_op.cc DEPS ${op_DEPS})
add_operator(sequence_pool_concat extra SRCS sequence_pool_concat_op.cc DEPS ${op_DEPS})
add_operator(reduce_sum_op_lite extra SRCS reduce_ops.cc DEPS ${op_DEPS})
add_operator(match_matrix_tensor_op_lite extra SRCS match_matrix_tensor_op.cc DEPS ${op_DEPS})
......@@ -107,6 +108,7 @@ add_operator(distribute_fpn_proposals_op_lite extra SRCS distribute_fpn_proposal
# for OCR specific
add_operator(while_op extra SRCS while_op.cc DEPS ${op_DEPS})
add_operator(lookup_table_op extra SRCS lookup_table_op.cc DEPS ${op_DEPS})
add_operator(lookup_table_dequant_op extra SRCS lookup_table_dequant_op.cc DEPS ${op_DEPS})
add_operator(lookup_table_v2_op extra SRCS lookup_table_v2_op.cc DEPS ${op_DEPS})
add_operator(beam_search_decode_op extra SRCS beam_search_decode_op.cc DEPS ${op_DEPS})
add_operator(logical_xor extra SRCS logical_op.cc DEPS ${op_DEPS})
......
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/operators/lookup_table_dequant_op.h"
#include "lite/core/op_lite.h"
#include "lite/core/op_registry.h"
namespace paddle {
namespace lite {
namespace operators {
bool LookupTableDequantOpLite::CheckShape() const {
CHECK_OR_FALSE(param_.W)
CHECK_OR_FALSE(param_.Ids)
CHECK_OR_FALSE(param_.Out)
const auto& table_dims = param_.W->dims();
const auto& ids_dims = param_.Ids->dims();
int ids_rank = ids_dims.size();
CHECK_EQ_OR_FALSE(table_dims.size(), 2);
CHECK_EQ_OR_FALSE(ids_dims[ids_rank - 1], 1);
CHECK_GT_OR_FALSE(table_dims[1], 2);
return true;
}
bool LookupTableDequantOpLite::InferShape() const {
const auto& table_dims = param_.W->dims();
const auto& ids_dims = param_.Ids->dims();
auto out_dims = ids_dims;
int ids_rank = ids_dims.size();
out_dims[ids_rank - 1] = (table_dims[1] - 2) * 4;
param_.Out->Resize(out_dims);
param_.Out->set_lod(param_.Ids->lod());
return true;
}
bool LookupTableDequantOpLite::AttachImpl(const cpp::OpDesc& op_desc,
lite::Scope* scope) {
auto input = op_desc.Input("W").front();
auto ids = op_desc.Input("Ids").front();
auto out = op_desc.Output("Out").front();
param_.W = scope->FindVar(input)->GetMutable<lite::Tensor>();
param_.Ids = scope->FindVar(ids)->GetMutable<lite::Tensor>();
param_.Out = scope->FindVar(out)->GetMutable<lite::Tensor>();
param_.padding_idx = op_desc.GetAttr<int64_t>("padding_idx");
return true;
}
} // namespace operators
} // namespace lite
} // namespace paddle
REGISTER_LITE_OP(lookup_table_dequant,
paddle::lite::operators::LookupTableDequantOpLite)
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <string>
#include <vector>
#include "lite/core/op_lite.h"
#include "lite/core/scope.h"
#include "lite/utils/all.h"
namespace paddle {
namespace lite {
namespace operators {
class LookupTableDequantOpLite : public OpLite {
public:
LookupTableDequantOpLite() {}
explicit LookupTableDequantOpLite(const std::string &op_type)
: OpLite(op_type) {}
bool CheckShape() const override;
bool InferShape() const override;
bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override;
void AttachKernel(KernelBase *kernel) override { kernel->SetParam(param_); }
std::string DebugString() const override { return "LookupTableDequant"; }
private:
mutable LookupTableDequantParam param_;
};
} // namespace operators
} // namespace lite
} // namespace paddle
......@@ -668,6 +668,13 @@ struct LookupTableParam {
int64_t padding_idx{-1};
};
struct LookupTableDequantParam {
lite::Tensor* W{nullptr};
lite::Tensor* Ids{nullptr};
lite::Tensor* Out{nullptr};
int64_t padding_idx{-1};
};
struct Im2SequenceParam {
const lite::Tensor* X{};
const lite::Tensor* Y{};
......@@ -772,6 +779,15 @@ struct SequencePoolParam {
#endif
};
struct SequenceConvParam {
const lite::Tensor* X{};
const lite::Tensor* Filter{};
lite::Tensor* Out{};
int contextStart{0};
int contextStride{1};
int contextLength;
};
struct SequencePoolConcatParam {
std::vector<lite::Tensor*> X{};
lite::Tensor* Out{};
......
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/operators/sequence_conv_op.h"
#include "lite/core/op_registry.h"
namespace paddle {
namespace lite {
namespace operators {
bool SequenceConvOp::CheckShape() const {
CHECK_OR_FALSE(param_.X);
CHECK_OR_FALSE(param_.Filter);
CHECK_OR_FALSE(param_.Out);
// currently we only support the case that
// the contextStride is equal to 1
int context_length = param_.contextLength;
int context_start = param_.contextStart;
CHECK_EQ_OR_FALSE(param_.contextStride, 1UL);
CHECK_GT_OR_FALSE(context_start, -context_length);
CHECK_GE_OR_FALSE(0, context_start);
const auto *filter = param_.Filter;
auto lod = param_.X->lod();
auto filter_dims = filter->dims();
auto in_dims = param_.X->dims();
CHECK_EQ_OR_FALSE(in_dims.size(), 2UL);
CHECK_EQ_OR_FALSE(filter_dims.size(), 2UL);
CHECK_EQ_OR_FALSE(lod.size(), 1UL);
CHECK_EQ_OR_FALSE(filter_dims[0], context_length * in_dims[1]);
CHECK_GE_OR_FALSE(in_dims[0], (static_cast<int64_t>(lod[0].size()) - 1));
return true;
}
bool SequenceConvOp::InferShape() const {
const auto *input = param_.X;
const auto *filter = param_.Filter;
auto in_dims = input->dims();
auto filter_dims = filter->dims();
auto out_dims = in_dims;
out_dims[1] = filter_dims[1];
param_.Out->Resize(out_dims);
param_.Out->set_lod(param_.X->lod());
return true;
}
bool SequenceConvOp::AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) {
// required params
param_.X = const_cast<lite::Tensor *>(
&scope->FindVar(opdesc.Input("X").front())->Get<lite::Tensor>());
param_.Filter = const_cast<lite::Tensor *>(
&scope->FindVar(opdesc.Input("Filter").front())->Get<lite::Tensor>());
param_.Out =
scope->FindVar(opdesc.Output("Out").front())->GetMutable<lite::Tensor>();
param_.contextStart = opdesc.GetAttr<int>("contextStart");
param_.contextStride = opdesc.GetAttr<int>("contextStride");
param_.contextLength = opdesc.GetAttr<int>("contextLength");
// PaddingData is not supported for now
std::vector<std::string> input_arg_names = opdesc.InputArgumentNames();
if (std::find(input_arg_names.begin(),
input_arg_names.end(),
"PaddingData") != input_arg_names.end()) {
auto padding_data_arguments = opdesc.Input("PaddingData");
CHECK_EQ_OR_FALSE(padding_data_arguments.size(), 0);
}
// paddingTrainable == True is not supported for now.
if (opdesc.HasAttr("paddingTrainable")) {
CHECK_OR_FALSE(!opdesc.GetAttr<bool>("paddingTrainable"));
}
CHECK(param_.X);
CHECK(param_.Filter);
CHECK(param_.Out);
return true;
}
} // namespace operators
} // namespace lite
} // namespace paddle
REGISTER_LITE_OP(sequence_conv, paddle::lite::operators::SequenceConvOp);
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <string>
#include <vector>
#include "lite/core/op_lite.h"
#include "lite/core/scope.h"
#include "lite/utils/all.h"
namespace paddle {
namespace lite {
namespace operators {
class SequenceConvOp : public OpLite {
public:
SequenceConvOp() {}
explicit SequenceConvOp(const std::string &op_type) : OpLite(op_type) {}
bool CheckShape() const override;
bool InferShape() const override;
bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override;
void AttachKernel(KernelBase *kernel) override { kernel->SetParam(param_); }
std::string DebugString() const override { return "sequence_conv"; }
private:
mutable SequenceConvParam param_;
};
} // namespace operators
} // namespace lite
} // namespace paddle
......@@ -38,6 +38,7 @@ if((NOT LITE_WITH_OPENCL AND NOT LITE_WITH_FPGA AND NOT LITE_WITH_BM) AND (LITE_
if(LITE_BUILD_EXTRA)
lite_cc_test(test_gru_unit SRCS gru_unit_test.cc DEPS arena_framework ${xpu_kernels} ${npu_kernels} ${x86_kernels} ${bm_kernels} ${cuda_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
lite_cc_test(test_kernel_sequence_pool_compute SRCS sequence_pool_compute_test.cc DEPS ${bm_kernels} arena_framework ${xpu_kernels} ${npu_kernels} ${x86_kernels} ${cuda_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
lite_cc_test(test_kernel_sequence_conv_compute SRCS sequence_conv_compute_test.cc DEPS ${bm_kernels} arena_framework ${xpu_kernels} ${npu_kernels} ${x86_kernels} ${cuda_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
lite_cc_test(test_kernel_reduce_max_compute SRCS reduce_max_compute_test.cc DEPS arena_framework ${bm_kernels} ${xpu_kernels} ${npu_kernels} ${x86_kernels} ${cuda_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
lite_cc_test(test_kernel_unsqueeze_compute SRCS unsqueeze_compute_test.cc DEPS arena_framework ${bm_kernels} ${xpu_kernels} ${npu_kernels} ${x86_kernels} ${cuda_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
lite_cc_test(test_kernel_assign_compute SRCS assign_compute_test.cc DEPS arena_framework ${xpu_kernels} ${npu_kernels} ${x86_kernels} ${cuda_kernels} ${bm_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
......@@ -54,6 +55,7 @@ if(LITE_BUILD_EXTRA)
lite_cc_test(test_kernel_search_aligned_mat_mul_compute SRCS search_aligned_mat_mul_compute_test.cc DEPS arena_framework ${xpu_kernels} ${npu_kernels} ${x86_kernels} ${bm_kernels} ${cuda_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
lite_cc_test(test_kernel_search_seq_fc_compute SRCS search_seq_fc_compute_test.cc DEPS arena_framework ${xpu_kernels} ${npu_kernels} ${x86_kernels} ${cuda_kernels} ${bm_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
lite_cc_test(test_kernel_lookup_table_compute SRCS lookup_table_compute_test.cc DEPS arena_framework ${xpu_kernels} ${npu_kernels} ${bm_kernels} ${x86_kernels} ${cuda_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
lite_cc_test(test_kernel_lookup_table_dequant_compute SRCS lookup_table_dequant_compute_test.cc DEPS arena_framework ${xpu_kernels} ${npu_kernels} ${bm_kernels} ${x86_kernels} ${cuda_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
lite_cc_test(test_kernel_gather_compute SRCS gather_compute_test.cc DEPS arena_framework ${xpu_kernels} ${npu_kernels} ${x86_kernels} ${bm_kernels} ${cuda_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
endif()
lite_cc_test(test_kernel_pad2d_compute SRCS pad2d_compute_test.cc DEPS arena_framework ${xpu_kernels} ${npu_kernels} ${bm_kernels} ${x86_kernels} ${cuda_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
......
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <gtest/gtest.h>
#include "lite/api/paddle_use_kernels.h"
#include "lite/api/paddle_use_ops.h"
#include "lite/core/arena/framework.h"
#include "lite/tests/utils/fill_data.h"
namespace paddle {
namespace lite {
void dequant(const unsigned char* in,
float* out,
float min,
float max,
int emb_size,
int pow_2_bits) {
float scale = (max - min) / pow_2_bits;
for (int i = 0; i < emb_size; ++i) {
float x = scale * static_cast<int>(in[i]) + min;
out[i] = x;
}
}
class LookupTableDequantComputeTest : public arena::TestCase {
protected:
// common attributes for this op.
std::string op_type_ = "lookup_table_dequant";
std::string ids_ = "ids";
std::string w_ = "w";
std::string out_ = "out";
DDim ids_dims_{{2, 1}};
DDim w_dims_{{8, 4}};
int64_t padding_idx_ = -1;
public:
LookupTableDequantComputeTest(const Place& place,
const std::string& alias,
const DDim& ids_dims,
const DDim& w_dims,
int64_t padding_idx)
: TestCase(place, alias),
ids_dims_(ids_dims),
w_dims_(w_dims),
padding_idx_(padding_idx) {}
void RunBaseline(Scope* scope) override {
auto ids = scope->FindTensor(ids_);
auto w = scope->FindTensor(w_);
auto ids_dims = ids->dims();
auto w_dims = w->dims();
auto out = scope->NewTensor(out_);
CHECK(out);
int ids_rank = ids_dims.size();
CHECK_EQ(ids_dims[ids_rank - 1], 1);
CHECK_EQ(w_dims.size(), 2);
std::vector<int64_t> out_dims;
for (int i = 0; i < ids_rank - 1; ++i) {
out_dims.push_back(ids_dims[i]);
}
out_dims.push_back((w_dims[1] - 2) * 4);
out->Resize(out_dims);
out->set_lod(ids->lod());
auto ids_data = ids->data<int64_t>();
auto ids_size = ids_dims.production();
auto w_data = w->data<float>();
auto w_rows = w_dims[0];
auto quant_number = w_dims[1];
auto w_cols = (quant_number - 2) * 4;
auto out_data = out->mutable_data<float>();
int pow_2_bits = static_cast<int>(pow(2, 8));
for (int64_t i = 0; i < ids_size; i++) {
auto id = ids_data[i];
if (padding_idx_ != -1 && id == padding_idx_) {
memset(out_data + i * w_cols, 0, w_cols * sizeof(float));
} else {
CHECK_LT(id, w_rows) << "lookup_table ids[i] expected < " << w_rows
<< " but got " << id;
CHECK_GE(id, 0) << "lookup_table ids[i] expected >= 0 but got " << id;
float min = *(w_data + ids_data[i] * quant_number);
float max = *(w_data + ids_data[i] * quant_number + 1);
int offset = ids_data[i] * quant_number + 2;
const unsigned char* tensor_buf =
reinterpret_cast<const unsigned char*>(w_data + offset);
dequant(
tensor_buf, out_data + i * w_cols, min, max, w_cols, pow_2_bits);
}
}
}
void PrepareOpDesc(cpp::OpDesc* op_desc) {
op_desc->SetType(op_type_);
op_desc->SetInput("Ids", {ids_});
op_desc->SetInput("W", {w_});
op_desc->SetOutput("Out", {out_});
op_desc->SetAttr<int64_t>("padding_idx", padding_idx_);
}
void PrepareData() override {
std::vector<int64_t> ids(ids_dims_.production());
fill_data_rand<int64_t>(
ids.data(), 0, w_dims_[0] - 1, ids_dims_.production());
std::vector<float> w(w_dims_.production());
fill_data_rand(w.data(), -1.f, 1.f, w_dims_.production());
SetCommonTensor(ids_, ids_dims_, ids.data());
SetCommonTensor(w_, w_dims_, w.data());
}
};
TEST(LookupTableDequant, precision) {
#ifdef LITE_WITH_ARM
float abs_error = 2e-5;
Place place = {TARGET(kARM), PRECISION(kAny)};
for (auto ids_dims :
std::vector<std::vector<int64_t>>{{5, 2, 3, 1}, {2, 3, 1}, {3, 1}}) {
for (auto w_dims :
std::vector<std::vector<int64_t>>{{4, 3}, {6, 8}, {12, 15}}) {
for (auto padding_idx : std::vector<int64_t>{-1}) {
std::unique_ptr<arena::TestCase> tester(
new LookupTableDequantComputeTest(
place, "def", DDim(ids_dims), DDim(w_dims), padding_idx));
arena::Arena arena(std::move(tester), place, abs_error);
arena.TestPrecision();
}
}
}
#endif
}
} // namespace lite
} // namespace paddle
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <gtest/gtest.h>
#include <stdio.h>
#include "lite/api/paddle_use_kernels.h"
#include "lite/api/paddle_use_ops.h"
#include "lite/core/arena/framework.h"
namespace paddle {
namespace lite {
class SequenceConvComputeTester : public arena::TestCase {
public:
SequenceConvComputeTester(const Place& place,
const std::string& alias,
LoD lod,
DDim dims,
const int& contextStart,
const int& contextStride,
const int& contextLength,
const int& kernel_num)
: TestCase(place, alias),
lod_(lod),
dims_(dims),
contextStart_(contextStart),
contextStride_(contextStride),
contextLength_(contextLength),
kernel_num_(kernel_num) {}
void PrepareOpDesc(cpp::OpDesc* op_desc) {
op_desc->SetType("sequence_conv");
op_desc->SetInput("X", {input_name_});
op_desc->SetInput("Filter", {filter_name_});
op_desc->SetOutput("Out", {output_name_});
op_desc->SetAttr("contextStart", contextStart_);
op_desc->SetAttr("contextStride", contextStride_);
op_desc->SetAttr("contextLength", contextLength_);
}
void PrepareData() override {
DDim filter_dims(
std::vector<int64_t>{contextLength_ * dims_[1], kernel_num_});
std::vector<float> din(dims_.production());
for (int i = 0; i < dims_[0]; i++) {
for (int j = 0; j < dims_[1]; j++) {
din[i * dims_[1] + j] =
(2.0 * i + 3.0 * j) / (2.0 * dims_[0] + 3.0 * dims_[1]) - 0.5;
}
}
SetCommonTensor(input_name_, dims_, din.data(), lod_);
std::vector<float> dfilter(filter_dims.production());
for (int i = 0; i < filter_dims[0]; i++) {
for (int j = 0; j < filter_dims[1]; j++) {
dfilter[i * filter_dims[1] + j] =
(1.5 * i + 2.0 * j) /
(1.5 * filter_dims[0] + 2.0 * filter_dims[1]) -
0.5;
}
}
SetCommonTensor(filter_name_, filter_dims, dfilter.data(), lod_);
}
void RunBaseline(Scope* scope) override {
// calculate res the output in this scope
// to compare with the Paddle-Lite calculated one
auto* output = scope->NewTensor(output_name_);
CHECK(output);
std::vector<int64_t> output_shape({4, 3});
output->Resize(DDim(output_shape));
auto output_dims = output->dims();
auto output_data = output->mutable_data<float>();
std::vector<std::vector<float>> res;
if (contextStart_ == -2) {
res = {{-0.08867277, -0.17257819, -0.2564836},
{0.194508, 0.05720823, -0.08009153},
{0.73512584, 0.5749428, 0.41475973},
{0.5635012, 0.49485126, 0.42620137}};
} else if (contextStart_ == -1) {
res = {{0.194508, 0.05720823, -0.08009153},
{0.73512584, 0.5749428, 0.41475973},
{0.5635012, 0.49485126, 0.42620137},
{0.2517162, 0.23646072, 0.22120519}};
} else if (contextStart_ == 0) {
res = {{0.73512584, 0.5749428, 0.41475973},
{0.5635012, 0.49485126, 0.42620137},
{0.2517162, 0.23646072, 0.22120519},
{0.02574372, 0.03337148, 0.04099924}};
} else {
fprintf(stderr, "not supported contextStart_\n");
exit(-1);
}
for (int i = 0; i < output_shape[0]; i++) {
for (int j = 0; j < output_shape[1]; j++) {
output_data[i * output_shape[1] + j] = res[i][j];
}
}
(output->mutable_lod())->push_back(lod_[0]);
}
protected:
std::string input_name_ = "x";
std::string filter_name_ = "filter";
std::string output_name_ = "out";
LoD lod_;
DDim dims_;
int contextStart_;
int contextStride_;
int contextLength_;
int kernel_num_;
};
void TestNormalCase(Place place, float abs_error = 2e-5) {
std::vector<std::vector<uint64_t>> lod{{0, 4}};
std::vector<int64_t> dims{4, 5};
std::vector<int> candidate_pad_idx{-2, -1, 0};
for (int pad_idx : candidate_pad_idx) {
std::unique_ptr<arena::TestCase> tester(new SequenceConvComputeTester(
place, "def", lod, DDim(dims), pad_idx, 1, 3, 3));
arena::Arena arena(std::move(tester), place, abs_error);
arena.TestPrecision();
}
}
TEST(sequence_conv, precision) {
#ifdef LITE_WITH_ARM
float abs_error = 2e-5;
Place place(TARGET(kARM));
TestNormalCase(place, abs_error);
#endif
}
} // namespace lite
} // namespace paddle
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册