提交 9ebaaa1b 编写于 作者: M mapingshuo 提交者: GitHub

add lookup_dequant_op (#3108)

* add lookup_dequant_op
上级 0e77cd63
......@@ -88,6 +88,7 @@ add_kernel(gru_unit_compute_arm ARM extra SRCS gru_unit_compute.cc DEPS ${lite_k
add_kernel(gru_compute_arm ARM extra SRCS gru_compute.cc DEPS ${lite_kernel_deps} math_arm)
add_kernel(beam_search_decode_compute_arm ARM extra SRCS beam_search_decode_compute.cc DEPS ${lite_kernel_deps} math_arm)
add_kernel(lookup_table_compute_arm ARM extra SRCS lookup_table_compute.cc DEPS ${lite_kernel_deps} math_arm)
add_kernel(lookup_table_dequant_compute_arm ARM extra SRCS lookup_table_dequant_compute.cc DEPS ${lite_kernel_deps} math_arm)
add_kernel(logical_compute_arm ARM extra SRCS logical_compute.cc DEPS ${lite_kernel_deps} math_arm)
add_kernel(sequence_softmax_compute_arm ARM extra SRCS sequence_softmax_compute.cc DEPS ${lite_kernel_deps} math_arm)
add_kernel(less_than_arm ARM extra SRCS compare_compute.cc DEPS ${lite_kernel_deps} math_arm)
......
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/kernels/arm/lookup_table_dequant_compute.h"
#include <string>
#include <vector>
#include "lite/api/paddle_place.h"
#include "lite/backends/arm/math/funcs.h"
#include "lite/core/op_registry.h"
#include "lite/core/tensor.h"
#include "lite/core/type_system.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace arm {
// Expands `emb_size` unsigned 8-bit codes read from `in` into floats written
// to `out`, using a linear mapping: out[k] = in[k] * step + min, where
// step = (max - min) / pow_2_bits.
//
// in:         source buffer holding at least `emb_size` codes.
// out:        destination buffer with room for at least `emb_size` floats.
// min, max:   range endpoints used to derive the step.
// emb_size:   number of codes to convert; nothing is written when <= 0.
// pow_2_bits: divisor applied to (max - min); callers in this file pass 256.
void dequant(const unsigned char *in,
             float *out,
             float min,
             float max,
             int emb_size,
             int pow_2_bits) {
  const float step = (max - min) / pow_2_bits;
  const unsigned char *src = in;
  float *dst = out;
  for (int remaining = emb_size; remaining > 0; --remaining) {
    *dst = step * static_cast<int>(*src) + min;
    ++src;
    ++dst;
  }
}
// Gathers the rows of the quantized table W selected by Ids and writes them
// to Out as dequantized floats.
//
// W is read through data<float>() as a [row_number, quant_number] table. For
// each row, element 0 is used as the range minimum, element 1 as the range
// maximum, and the remaining (quant_number - 2) elements are reinterpreted as
// a byte buffer of 8-bit codes. Every code becomes one output float, so a row
// expands to (quant_number - 2) * 4 floats (the factor 4 assumes a 4-byte
// float).
//
// An id equal to param.padding_idx (when padding_idx != -1) produces a row of
// zeros. Any other id must satisfy 0 <= id < row_number; this is enforced
// with CHECK_LT / CHECK_GE. Finally the LoD of Ids is copied to Out.
void LookupTableDequantCompute::Run() {
  auto &param = this->Param<param_t>();
  // inputs
  auto w = param.W;
  auto ids = param.Ids;
  // outputs
  auto out = param.Out;

  auto table_dim = w->dims();
  const int64_t ids_numel = ids->numel();
  auto ids_data = ids->data<int64_t>();
  const int64_t row_number = table_dim[0];
  const int64_t quant_number = table_dim[1];
  const int64_t row_width = (quant_number - 2) * 4;
  auto table_data = w->data<float>();
  auto dout = out->mutable_data<float>();
  // Number of quantization levels for 8-bit codes (2^8), as a constant
  // instead of a runtime pow() call.
  const int pow_2_bits = 1 << 8;

  for (int64_t i = 0; i < ids_numel; ++i) {
    const int64_t id = ids_data[i];
    float *out_row = dout + i * row_width;
    if (param.padding_idx != -1 && id == param.padding_idx) {
      memset(out_row, 0, row_width * sizeof(float));
    } else {
      CHECK_LT(id, row_number)
          << "look uptable ids[i] < row_number check failed";
      CHECK_GE(id, 0) << "lookuptable ids[i] >= 0 check failed";
      // Keep the row offset in 64-bit pointer arithmetic: storing
      // id * quant_number in an int could wrap for large tables.
      const float *table_row = table_data + id * quant_number;
      const float min = table_row[0];
      const float max = table_row[1];
      const unsigned char *tensor_buf =
          reinterpret_cast<const unsigned char *>(table_row + 2);
      dequant(tensor_buf,
              out_row,
              min,
              max,
              static_cast<int>(row_width),
              pow_2_bits);
    }
  }
  *(out->mutable_lod()) = ids->lod();
}
} // namespace arm
} // namespace kernels
} // namespace lite
} // namespace paddle
// Registers LookupTableDequantCompute under the op name
// "lookup_table_dequant" for target kARM, precision kAny, layout kNCHW, with
// the alias "def". "W" and "Out" are bound as ARM tensors with the default
// tensor type; "Ids" is bound as an ARM tensor of int64 precision.
REGISTER_LITE_KERNEL(lookup_table_dequant,
kARM,
kAny,
kNCHW,
paddle::lite::kernels::arm::LookupTableDequantCompute,
def)
.BindInput("W", {LiteType::GetTensorTy(TARGET(kARM))})
.BindInput("Ids", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt64))})
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM))})
.Finalize();
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <algorithm>
#include "lite/core/kernel.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace arm {
// ARM kernel for the lookup_table_dequant op. It is declared with precision
// kAny; the work is done in Run(), which is defined in the matching .cc file.
class LookupTableDequantCompute
: public KernelLite<TARGET(kARM), PRECISION(kAny)> {
public:
// Parameter bundle this kernel reads via Param<param_t>().
using param_t = operators::LookupTableDequantParam;
LookupTableDequantCompute() = default;
// Executes the kernel on the currently attached parameters.
void Run() override;
virtual ~LookupTableDequantCompute() = default;
};
} // namespace arm
} // namespace kernels
} // namespace lite
} // namespace paddle
......@@ -109,6 +109,7 @@ add_operator(distribute_fpn_proposals_op_lite extra SRCS distribute_fpn_proposal
# for OCR specific
add_operator(while_op extra SRCS while_op.cc DEPS ${op_DEPS})
add_operator(lookup_table_op extra SRCS lookup_table_op.cc DEPS ${op_DEPS})
add_operator(lookup_table_dequant_op extra SRCS lookup_table_dequant_op.cc DEPS ${op_DEPS})
add_operator(lookup_table_v2_op extra SRCS lookup_table_v2_op.cc DEPS ${op_DEPS})
add_operator(beam_search_decode_op extra SRCS beam_search_decode_op.cc DEPS ${op_DEPS})
add_operator(logical_xor extra SRCS logical_op.cc DEPS ${op_DEPS})
......
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/operators/lookup_table_dequant_op.h"
#include "lite/core/op_lite.h"
#include "lite/core/op_registry.h"
namespace paddle {
namespace lite {
namespace operators {
// Validates the attached tensors before shape inference.
//
// Requirements checked, in order:
//   * W, Ids and Out are all non-null;
//   * Ids has at least one dimension (guards the last-dimension read below);
//   * W is 2-D;
//   * the last dimension of Ids is 1;
//   * W has more than 2 columns, i.e. at least one column of packed codes
//     after the two leading min/max columns.
// Returns true when every check passes; the *_OR_FALSE macros are expected to
// return false from this function on the first failing check.
bool LookupTableDequantOpLite::CheckShape() const {
  CHECK_OR_FALSE(param_.W)
  CHECK_OR_FALSE(param_.Ids)
  CHECK_OR_FALSE(param_.Out)
  const auto& table_dims = param_.W->dims();
  const auto& ids_dims = param_.Ids->dims();
  int ids_rank = ids_dims.size();
  // Without this guard a rank-0 Ids tensor would index ids_dims[-1].
  CHECK_GT_OR_FALSE(ids_rank, 0);
  CHECK_EQ_OR_FALSE(table_dims.size(), 2);
  CHECK_EQ_OR_FALSE(ids_dims[ids_rank - 1], 1);
  CHECK_GT_OR_FALSE(table_dims[1], 2);
  return true;
}
// Derives the shape of Out from Ids and W: Out keeps every dimension of Ids
// except the last, which is replaced by (W.dims()[1] - 2) * 4. The LoD of Ids
// is forwarded to Out. Always returns true.
bool LookupTableDequantOpLite::InferShape() const {
  const auto& w_shape = param_.W->dims();
  const auto& ids_shape = param_.Ids->dims();

  // Start from a copy of the Ids shape and overwrite only its last axis.
  auto result_shape = ids_shape;
  const int last_axis = static_cast<int>(ids_shape.size()) - 1;
  result_shape[last_axis] = (w_shape[1] - 2) * 4;

  param_.Out->Resize(result_shape);
  param_.Out->set_lod(param_.Ids->lod());
  return true;
}
// Binds the operator parameters from the op description: resolves the first
// variable named for inputs "W" and "Ids" and for output "Out" in `scope`,
// stores their tensors in param_, and reads the int64 attribute
// "padding_idx". Always returns true.
bool LookupTableDequantOpLite::AttachImpl(const cpp::OpDesc& op_desc,
                                          lite::Scope* scope) {
  const auto w_name = op_desc.Input("W").front();
  const auto ids_name = op_desc.Input("Ids").front();
  const auto out_name = op_desc.Output("Out").front();

  // Looks a variable up by name and exposes it as a mutable tensor.
  auto tensor_of = [scope](const std::string& var_name) {
    return scope->FindVar(var_name)->GetMutable<lite::Tensor>();
  };

  param_.W = tensor_of(w_name);
  param_.Ids = tensor_of(ids_name);
  param_.Out = tensor_of(out_name);
  param_.padding_idx = op_desc.GetAttr<int64_t>("padding_idx");
  return true;
}
} // namespace operators
} // namespace lite
} // namespace paddle
// Registers LookupTableDequantOpLite under the op type name
// "lookup_table_dequant".
REGISTER_LITE_OP(lookup_table_dequant,
paddle::lite::operators::LookupTableDequantOpLite)
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <string>
#include <vector>
#include "lite/core/op_lite.h"
#include "lite/core/scope.h"
#include "lite/utils/all.h"
namespace paddle {
namespace lite {
namespace operators {
// Operator definition for lookup_table_dequant. It owns a
// LookupTableDequantParam that AttachImpl fills from the op description and
// that AttachKernel hands to the selected kernel.
class LookupTableDequantOpLite : public OpLite {
public:
LookupTableDequantOpLite() {}
explicit LookupTableDequantOpLite(const std::string &op_type)
: OpLite(op_type) {}
// Validates the attached input/output tensors.
bool CheckShape() const override;
// Computes the shape of the output tensor from the inputs.
bool InferShape() const override;
// Resolves inputs, outputs and attributes from `opdesc` within `scope`.
bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override;
// Passes this operator's parameters to `kernel`.
void AttachKernel(KernelBase *kernel) override { kernel->SetParam(param_); }
std::string DebugString() const override { return "LookupTableDequant"; }
private:
// mutable so the const methods CheckShape/InferShape can use it.
mutable LookupTableDequantParam param_;
};
} // namespace operators
} // namespace lite
} // namespace paddle
......@@ -661,6 +661,13 @@ struct LookupTableParam {
int64_t padding_idx{-1};
};
// Parameters of the lookup_table_dequant op, filled in by the operator and
// consumed by its kernels.
struct LookupTableDequantParam {
// Quantized lookup table (input "W").
lite::Tensor* W{nullptr};
// Row indices to look up (input "Ids").
lite::Tensor* Ids{nullptr};
// Dequantized result (output "Out").
lite::Tensor* Out{nullptr};
// Value of the "padding_idx" attribute; defaults to -1.
int64_t padding_idx{-1};
};
struct Im2SequenceParam {
const lite::Tensor* X{};
const lite::Tensor* Y{};
......
......@@ -58,6 +58,7 @@ if(LITE_BUILD_EXTRA)
lite_cc_test(test_kernel_search_aligned_mat_mul_compute SRCS search_aligned_mat_mul_compute_test.cc DEPS arena_framework ${xpu_kernels} ${npu_kernels} ${x86_kernels} ${bm_kernels} ${cuda_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
lite_cc_test(test_kernel_search_seq_fc_compute SRCS search_seq_fc_compute_test.cc DEPS arena_framework ${xpu_kernels} ${npu_kernels} ${x86_kernels} ${cuda_kernels} ${bm_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
lite_cc_test(test_kernel_lookup_table_compute SRCS lookup_table_compute_test.cc DEPS arena_framework ${xpu_kernels} ${npu_kernels} ${bm_kernels} ${x86_kernels} ${cuda_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
lite_cc_test(test_kernel_lookup_table_dequant_compute SRCS lookup_table_dequant_compute_test.cc DEPS arena_framework ${xpu_kernels} ${npu_kernels} ${bm_kernels} ${x86_kernels} ${cuda_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
lite_cc_test(test_kernel_gather_compute SRCS gather_compute_test.cc DEPS arena_framework ${xpu_kernels} ${npu_kernels} ${x86_kernels} ${bm_kernels} ${cuda_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
endif()
lite_cc_test(test_kernel_pad2d_compute SRCS pad2d_compute_test.cc DEPS arena_framework ${xpu_kernels} ${npu_kernels} ${bm_kernels} ${x86_kernels} ${cuda_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
......
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <gtest/gtest.h>
#include "lite/api/paddle_use_kernels.h"
#include "lite/api/paddle_use_ops.h"
#include "lite/core/arena/framework.h"
#include "lite/tests/utils/fill_data.h"
namespace paddle {
namespace lite {
// Reference dequantizer used by the baseline below: converts `emb_size`
// unsigned 8-bit codes from `in` into floats in `out` with
// out[k] = in[k] * (max - min) / pow_2_bits + min.
// Nothing is written when `emb_size` is not positive.
void dequant(const unsigned char* in,
             float* out,
             float min,
             float max,
             int emb_size,
             int pow_2_bits) {
  const float step = (max - min) / pow_2_bits;
  int k = 0;
  while (k < emb_size) {
    const int code = static_cast<int>(in[k]);
    out[k] = step * code + min;
    ++k;
  }
}
// Arena test case for the lookup_table_dequant op. It feeds random ids and a
// random float table to the op and compares the result with the reference
// implementation in RunBaseline.
class LookupTableDequantComputeTest : public arena::TestCase {
 protected:
  // common attributes for this op.
  std::string op_type_ = "lookup_table_dequant";
  std::string ids_ = "ids";
  std::string w_ = "w";
  std::string out_ = "out";
  DDim ids_dims_{{2, 1}};
  DDim w_dims_{{8, 4}};
  int64_t padding_idx_ = -1;

 public:
  LookupTableDequantComputeTest(const Place& place,
                                const std::string& alias,
                                const DDim& ids_dims,
                                const DDim& w_dims,
                                int64_t padding_idx)
      : TestCase(place, alias),
        ids_dims_(ids_dims),
        w_dims_(w_dims),
        padding_idx_(padding_idx) {}

  // Reference computation: for every id, either zero-fills the output row
  // (id equals a non-negative-one padding_idx_) or dequantizes the table row,
  // whose first two floats are min and max and whose remaining floats are
  // reinterpreted as packed 8-bit codes.
  void RunBaseline(Scope* scope) override {
    auto ids_tensor = scope->FindTensor(ids_);
    auto table_tensor = scope->FindTensor(w_);
    auto ids_shape = ids_tensor->dims();
    auto table_shape = table_tensor->dims();
    auto result = scope->NewTensor(out_);
    CHECK(result);

    int ids_rank = ids_shape.size();
    CHECK_EQ(ids_shape[ids_rank - 1], 1);
    CHECK_EQ(table_shape.size(), 2);

    // Output shape: leading dims of ids followed by the dequantized width.
    std::vector<int64_t> result_shape;
    for (int axis = 0; axis + 1 < ids_rank; ++axis) {
      result_shape.push_back(ids_shape[axis]);
    }
    result_shape.push_back((table_shape[1] - 2) * 4);
    result->Resize(result_shape);
    result->set_lod(ids_tensor->lod());

    auto ids_ptr = ids_tensor->data<int64_t>();
    auto ids_count = ids_shape.production();
    auto table_ptr = table_tensor->data<float>();
    auto table_rows = table_shape[0];
    auto quant_number = table_shape[1];
    auto row_width = (quant_number - 2) * 4;
    auto result_ptr = result->mutable_data<float>();
    int pow_2_bits = 1 << 8;

    for (int64_t n = 0; n < ids_count; n++) {
      auto id = ids_ptr[n];
      auto dst_row = result_ptr + n * row_width;
      const bool is_padding = (padding_idx_ != -1) && (id == padding_idx_);
      if (is_padding) {
        memset(dst_row, 0, row_width * sizeof(float));
        continue;
      }
      CHECK_LT(id, table_rows) << "lookup_table ids[i] expected < "
                               << table_rows << " but got " << id;
      CHECK_GE(id, 0) << "lookup_table ids[i] expected >= 0 but got " << id;
      float min = table_ptr[id * quant_number];
      float max = table_ptr[id * quant_number + 1];
      int code_offset = id * quant_number + 2;
      const unsigned char* codes =
          reinterpret_cast<const unsigned char*>(table_ptr + code_offset);
      dequant(codes, dst_row, min, max, row_width, pow_2_bits);
    }
  }

  // Describes the op under test: its type, input/output names and the
  // padding_idx attribute.
  void PrepareOpDesc(cpp::OpDesc* op_desc) {
    op_desc->SetType(op_type_);
    op_desc->SetInput("Ids", {ids_});
    op_desc->SetInput("W", {w_});
    op_desc->SetOutput("Out", {out_});
    op_desc->SetAttr<int64_t>("padding_idx", padding_idx_);
  }

  // Generates random ids via fill_data_rand with bounds 0 and rows - 1, and
  // random table floats with bounds -1 and 1, then registers both as common
  // tensors.
  void PrepareData() override {
    const auto ids_count = ids_dims_.production();
    std::vector<int64_t> ids_buf(ids_count);
    fill_data_rand<int64_t>(ids_buf.data(), 0, w_dims_[0] - 1, ids_count);

    const auto table_count = w_dims_.production();
    std::vector<float> table_buf(table_count);
    fill_data_rand(table_buf.data(), -1.f, 1.f, table_count);

    SetCommonTensor(ids_, ids_dims_, ids_buf.data());
    SetCommonTensor(w_, w_dims_, table_buf.data());
  }
};
// Runs the precision check for every combination of the ids shapes, table
// shapes and padding indices listed below. The body is compiled only when
// LITE_WITH_ARM is defined; otherwise the test is empty.
TEST(LookupTableDequant, precision) {
#ifdef LITE_WITH_ARM
  float abs_error = 2e-5;
  Place place = {TARGET(kARM), PRECISION(kAny)};

  const std::vector<std::vector<int64_t>> ids_shapes{
      {5, 2, 3, 1}, {2, 3, 1}, {3, 1}};
  const std::vector<std::vector<int64_t>> table_shapes{
      {4, 3}, {6, 8}, {12, 15}};
  const std::vector<int64_t> padding_indices{-1};

  for (const auto& ids_shape : ids_shapes) {
    for (const auto& table_shape : table_shapes) {
      for (const auto padding_idx : padding_indices) {
        std::unique_ptr<arena::TestCase> tester(
            new LookupTableDequantComputeTest(place,
                                              "def",
                                              DDim(ids_shape),
                                              DDim(table_shape),
                                              padding_idx));
        arena::Arena runner(std::move(tester), place, abs_error);
        runner.TestPrecision();
      }
    }
  }
#endif
}
} // namespace lite
} // namespace paddle
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册