未验证 提交 1e88d1e8 编写于 作者: P Pei Yang 提交者: GitHub

add search_group_padding op and x86 kernel, test=develop (#2440)

add search_group_padding op and x86 kernel
上级 8599c042
...@@ -34,6 +34,7 @@ add_kernel(mul_compute_x86 X86 basic SRCS mul_compute.cc DEPS ${lite_kernel_deps ...@@ -34,6 +34,7 @@ add_kernel(mul_compute_x86 X86 basic SRCS mul_compute.cc DEPS ${lite_kernel_deps
add_kernel(concat_compute_x86 X86 basic SRCS concat_compute.cc DEPS ${lite_kernel_deps}) add_kernel(concat_compute_x86 X86 basic SRCS concat_compute.cc DEPS ${lite_kernel_deps})
add_kernel(shape_compute_x86 X86 basic SRCS shape_compute.cc DEPS ${lite_kernel_deps}) add_kernel(shape_compute_x86 X86 basic SRCS shape_compute.cc DEPS ${lite_kernel_deps})
add_kernel(sequence_pool_compute_x86 X86 basic SRCS sequence_pool_compute.cc DEPS ${lite_kernel_deps} sequence_pooling) add_kernel(sequence_pool_compute_x86 X86 basic SRCS sequence_pool_compute.cc DEPS ${lite_kernel_deps} sequence_pooling)
add_kernel(search_group_padding_compute_x86 X86 extra SRCS search_group_padding_compute.cc DEPS ${lite_kernel_deps})
add_kernel(sequence_reverse_compute_x86 X86 basic SRCS sequence_reverse_compute.cc DEPS ${lite_kernel_deps}) add_kernel(sequence_reverse_compute_x86 X86 basic SRCS sequence_reverse_compute.cc DEPS ${lite_kernel_deps})
add_kernel(softmax_compute_x86 X86 basic SRCS softmax_compute.cc DEPS ${lite_kernel_deps} softmax) add_kernel(softmax_compute_x86 X86 basic SRCS softmax_compute.cc DEPS ${lite_kernel_deps} softmax)
add_kernel(elementwise_compute_x86 X86 basic SRCS elementwise_compute.cc DEPS ${lite_kernel_deps}) add_kernel(elementwise_compute_x86 X86 basic SRCS elementwise_compute.cc DEPS ${lite_kernel_deps})
...@@ -71,6 +72,7 @@ lite_cc_test(test_batch_norm_compute_x86 SRCS batch_norm_compute_test.cc DEPS ba ...@@ -71,6 +72,7 @@ lite_cc_test(test_batch_norm_compute_x86 SRCS batch_norm_compute_test.cc DEPS ba
lite_cc_test(test_softmax_compute_x86 SRCS softmax_compute_test.cc DEPS softmax_compute_x86) lite_cc_test(test_softmax_compute_x86 SRCS softmax_compute_test.cc DEPS softmax_compute_x86)
lite_cc_test(test_elementwise_compute_x86 SRCS elementwise_compute_test.cc DEPS elementwise_compute_x86) lite_cc_test(test_elementwise_compute_x86 SRCS elementwise_compute_test.cc DEPS elementwise_compute_x86)
lite_cc_test(test_relu_compute_x86 SRCS relu_compute_test.cc DEPS activation_compute_x86) lite_cc_test(test_relu_compute_x86 SRCS relu_compute_test.cc DEPS activation_compute_x86)
lite_cc_test(test_search_group_padding_compute_x86 SRCS search_group_padding_compute_test.cc DEPS search_group_padding_compute_x86)
lite_cc_test(test_tanh_compute_x86 SRCS tanh_compute_test.cc DEPS activation_compute_x86) lite_cc_test(test_tanh_compute_x86 SRCS tanh_compute_test.cc DEPS activation_compute_x86)
lite_cc_test(test_gelu_compute_x86 SRCS gelu_compute_test.cc DEPS activation_compute_x86) lite_cc_test(test_gelu_compute_x86 SRCS gelu_compute_test.cc DEPS activation_compute_x86)
lite_cc_test(test_sequence_expand_as_compute_x86 SRCS sequence_expand_as_compute_test.cc DEPS sequence_expand_as_compute_x86) lite_cc_test(test_sequence_expand_as_compute_x86 SRCS sequence_expand_as_compute_test.cc DEPS sequence_expand_as_compute_x86)
......
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/kernels/x86/search_group_padding_compute.h"
REGISTER_LITE_KERNEL(
search_group_padding,
kX86,
kFloat,
kNCHW,
paddle::lite::kernels::x86::SearchGroupPaddingCompute<float>,
def)
.BindInput("X", {LiteType::GetTensorTy(TARGET(kX86))})
.BindOutput("Out_emb_padding", {LiteType::GetTensorTy(TARGET(kX86))})
.BindOutput("Out_new", {LiteType::GetTensorTy(TARGET(kX86))})
.BindOutput("Out_padding", {LiteType::GetTensorTy(TARGET(kX86))})
.Finalize();
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <vector>
#include "lite/core/kernel.h"
#include "lite/core/op_registry.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace x86 {
template <typename T>
class SearchGroupPaddingCompute
: public KernelLite<TARGET(kX86), PRECISION(kFloat)> {
public:
using param_t = operators::SearchGroupPaddingParam;
void Run() override {
auto& param = *param_.get_mutable<operators::SearchGroupPaddingParam>();
auto* bottom0 = param.x;
auto* top0 = param.out_emb_padding;
auto* top1 = param.out_new;
auto* top2 = param.out_padding;
int _pad_id = param.pad_id;
int batch = bottom0->lod()[0].size() - 1;
int dim0 = bottom0->dims()[0];
int dim1 = bottom0->dims()[1];
const auto offset = bottom0->lod()[0];
int max_seq = 0;
for (int i = 0; i < batch; ++i) {
if (offset[i + 1] - offset[i] > max_seq) {
max_seq = offset[i + 1] - offset[i];
}
}
std::vector<size_t> new_offset;
new_offset.resize(batch + 1);
for (int i = 0; i < batch + 1; ++i) {
new_offset[i] = i * max_seq;
}
// for padding data
lite::LoD top0_lod;
top0_lod.push_back(new_offset);
top0->set_lod(top0_lod);
top0->Resize({batch * max_seq, dim1});
// for origin input id
// already set by ShareLoD in InferShape
lite::LoD top1_lod;
top1_lod.push_back(offset);
top1->set_lod(top1_lod);
top1->Resize({dim0, 1});
memset(top1->mutable_data<T>(),
0,
top1->dims()[0] * top1->dims()[1] * sizeof(T));
// for padding input id
lite::LoD top2_lod;
top2_lod.push_back(new_offset);
top2->set_lod(top2_lod);
top2->Resize({batch * max_seq, 1});
// copy data
const auto* bottom_data = bottom0->data<T>();
auto* top_data = top0->mutable_data<T>();
auto* top_padding_input_data = top2->mutable_data<T>();
for (int i = 0; i < batch; i++) {
const int copy_step = offset[i + 1] - offset[i];
const int start = i * max_seq;
memcpy(top_data + start * dim1,
bottom_data + offset[i] * dim1,
copy_step * dim1 * sizeof(T));
memset(top_data + (start + copy_step) * dim1,
0,
(max_seq - copy_step) * dim1 * sizeof(T));
// for padding input id
memset(top_padding_input_data + start, 0, copy_step * sizeof(T));
for (int j = start + copy_step; j < start + max_seq; j++) {
top_padding_input_data[j] = static_cast<T>(_pad_id);
}
}
}
virtual ~SearchGroupPaddingCompute() = default;
};
} // namespace x86
} // namespace kernels
} // namespace lite
} // namespace paddle
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/kernels/x86/search_group_padding_compute.h"
#include <gtest/gtest.h>
#include <memory>
#include <utility>
#include <vector>
#include "lite/core/op_registry.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace x86 {
TEST(search_group_padding_x86, retrieve_op) {
auto search_group_padding =
KernelRegistry::Global().Create<TARGET(kX86), PRECISION(kFloat)>(
"search_group_padding");
ASSERT_FALSE(search_group_padding.empty());
ASSERT_TRUE(search_group_padding.front());
}
TEST(search_group_padding_x86, init) {
SearchGroupPaddingCompute<float> search_group_padding;
ASSERT_EQ(search_group_padding.precision(), PRECISION(kFloat));
ASSERT_EQ(search_group_padding.target(), TARGET(kX86));
}
TEST(search_group_padding_x86, run_test) {
lite::Tensor x, out_emb_padding, out_new, out_padding;
x.Resize({2, 3});
out_emb_padding.Resize({-1, 3});
out_new.Resize({2, 1});
out_padding.Resize({-1, 1});
LoD x_lod{};
x_lod.push_back({0, 1});
x.set_lod(x_lod);
auto* x_data = x.mutable_data<float>();
for (int64_t i = 0; i < x.dims().production(); i++) {
x_data[i] = static_cast<float>(i);
}
SearchGroupPaddingCompute<float> sgp_kernel;
operators::SearchGroupPaddingParam param;
std::unique_ptr<KernelContext> ctx(new KernelContext);
ctx->As<X86Context>();
sgp_kernel.SetContext(std::move(ctx));
param.x = &x;
param.out_emb_padding = &out_emb_padding;
param.out_new = &out_new;
param.out_padding = &out_padding;
sgp_kernel.SetParam(param);
sgp_kernel.Run();
std::vector<float> out_emb_padding_ref = {0, 1, 2};
std::vector<float> out_new_ref = {0, 0};
std::vector<float> out_padding_ref = {0};
auto* out_emb_padding_data = out_emb_padding.mutable_data<float>();
auto* out_new_data = out_new.mutable_data<float>();
auto* out_padding_data = out_padding.mutable_data<float>();
for (int i = 0; i < out_emb_padding.dims().production(); i++) {
EXPECT_NEAR(out_emb_padding_data[i], out_emb_padding_ref[i], 1e-5);
}
for (int i = 0; i < out_new.dims().production(); i++) {
EXPECT_NEAR(out_new_data[i], out_new_ref[i], 1e-5);
}
for (int i = 0; i < out_padding.dims().production(); i++) {
EXPECT_NEAR(out_padding_data[i], out_padding_ref[i], 1e-5);
}
}
} // namespace x86
} // namespace kernels
} // namespace lite
} // namespace paddle
USE_LITE_KERNEL(search_group_padding, kX86, kFloat, kNCHW, def);
...@@ -7,6 +7,7 @@ add_operator(pool_op basic SRCS pool_op.cc DEPS ${op_DEPS}) ...@@ -7,6 +7,7 @@ add_operator(pool_op basic SRCS pool_op.cc DEPS ${op_DEPS})
add_operator(fc_op basic SRCS fc_op.cc DEPS ${op_DEPS}) add_operator(fc_op basic SRCS fc_op.cc DEPS ${op_DEPS})
add_operator(assign_op extra SRCS assign_op.cc DEPS ${op_DEPS}) add_operator(assign_op extra SRCS assign_op.cc DEPS ${op_DEPS})
add_operator(relu_op basic SRCS relu_op.cc DEPS ${op_DEPS}) add_operator(relu_op basic SRCS relu_op.cc DEPS ${op_DEPS})
add_operator(search_group_padding extra SRCS search_group_padding_op.cc DEPS ${op_DEPS})
add_operator(mul_op basic SRCS mul_op.cc DEPS ${op_DEPS}) add_operator(mul_op basic SRCS mul_op.cc DEPS ${op_DEPS})
add_operator(matmul_op basic SRCS matmul_op.cc DEPS ${op_DEPS}) add_operator(matmul_op basic SRCS matmul_op.cc DEPS ${op_DEPS})
add_operator(scale_op basic SRCS scale_op.cc DEPS ${op_DEPS}) add_operator(scale_op basic SRCS scale_op.cc DEPS ${op_DEPS})
......
...@@ -740,6 +740,14 @@ struct SequencePoolParam { ...@@ -740,6 +740,14 @@ struct SequencePoolParam {
#endif #endif
}; };
struct SearchGroupPaddingParam {
lite::Tensor* x{};
lite::Tensor* out_emb_padding{};
lite::Tensor* out_new{};
lite::Tensor* out_padding{};
int pad_id;
};
struct SequenceReshapeParam { struct SequenceReshapeParam {
lite::Tensor* x{}; lite::Tensor* x{};
lite::Tensor* output{}; lite::Tensor* output{};
......
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/operators/search_group_padding_op.h"
#include "lite/core/op_registry.h"
namespace paddle {
namespace lite {
namespace operators {
bool SearchGroupPaddingOp::CheckShape() const {
CHECK_EQ(param_.x->dims().size(), 2) << "The rank of X(Input) should be 2.";
CHECK_EQ(param_.x->lod().empty(), false)
<< "Input Tensor of X does not contain LoD information.";
CHECK_GE(param_.x->lod()[0].size(), 2)
<< "The Input(X)'s lod info is corrupted.";
CHECK_EQ(param_.x->dims()[0], static_cast<int64_t>(param_.x->lod()[0].back()))
<< "The Input(X)'s lod info mismatches the actual tensor shape.";
return true;
}
bool SearchGroupPaddingOp::InferShape() const {
std::vector<int64_t> x_dims = param_.x->dims().Vectorize();
param_.out_emb_padding->Resize({-1, x_dims[1]});
param_.out_new->Resize({x_dims[0], 1});
param_.out_padding->Resize({-1, 1});
return true;
}
bool SearchGroupPaddingOp::AttachImpl(const cpp::OpDesc &op_desc,
lite::Scope *scope) {
auto x = op_desc.Input("X").front();
auto out_emb_padding = op_desc.Input("Out_emb_padding").front();
auto out_new = op_desc.Input("Out_new").front();
auto out_padding = op_desc.Input("Out_padding").front();
param_.x = scope->FindVar(x)->GetMutable<lite::Tensor>();
param_.out_emb_padding =
scope->FindVar(out_emb_padding)->GetMutable<lite::Tensor>();
param_.out_new = scope->FindVar(out_new)->GetMutable<lite::Tensor>();
param_.out_padding = scope->FindVar(out_padding)->GetMutable<lite::Tensor>();
param_.pad_id = op_desc.GetAttr<int>("pad_id");
CHECK(param_.out_emb_padding)
<< "Output(Out_emb_padding) of SearchGroupPadding Op should not be null.";
return true;
}
} // namespace operators
} // namespace lite
} // namespace paddle
REGISTER_LITE_OP(search_group_padding,
paddle::lite::operators::SearchGroupPaddingOp);
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <string>
#include <vector>
#include "lite/core/op_lite.h"
#include "lite/core/scope.h"
namespace paddle {
namespace lite {
namespace operators {
class SearchGroupPaddingOp : public OpLite {
public:
SearchGroupPaddingOp() {}
explicit SearchGroupPaddingOp(const std::string &op_type) : OpLite(op_type) {}
bool CheckShape() const override;
bool InferShape() const override;
bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override;
void AttachKernel(KernelBase *kernel) override { kernel->SetParam(param_); }
std::string DebugString() const override { return "search_group_padding"; }
private:
mutable SearchGroupPaddingParam param_;
};
} // namespace operators
} // namespace lite
} // namespace paddle
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册