未验证 提交 78f76834 编写于 作者: H hong19860320 提交者: GitHub

[LITE][X86] Add search_aligned_mat_mul and search_seq_fc op for X86 (#2428)

上级 603b810f
...@@ -47,6 +47,10 @@ add_kernel(search_grnn_compute_x86 X86 basic SRCS search_grnn_compute.cc DEPS ${ ...@@ -47,6 +47,10 @@ add_kernel(search_grnn_compute_x86 X86 basic SRCS search_grnn_compute.cc DEPS ${
add_kernel(sequence_concat_compute_x86 X86 basic SRCS sequence_concat_compute.cc DEPS ${lite_kernel_deps}) add_kernel(sequence_concat_compute_x86 X86 basic SRCS sequence_concat_compute.cc DEPS ${lite_kernel_deps})
add_kernel(var_conv_2d_compute_x86 X86 basic SRCS var_conv_2d_compute.cc DEPS ${lite_kernel_deps} blas fluid_data_type) add_kernel(var_conv_2d_compute_x86 X86 basic SRCS var_conv_2d_compute.cc DEPS ${lite_kernel_deps} blas fluid_data_type)
# for content-dnn specific
add_kernel(search_aligned_mat_mul_compute_x86 X86 extra SRCS search_aligned_mat_mul_compute.cc DEPS ${lite_kernel_deps} blas)
add_kernel(search_seq_fc_compute_x86 X86 extra SRCS search_seq_fc_compute.cc DEPS ${lite_kernel_deps} blas)
if(NOT LITE_WITH_X86) if(NOT LITE_WITH_X86)
return() return()
endif() endif()
......
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/kernels/x86/search_aligned_mat_mul_compute.h"
REGISTER_LITE_KERNEL(
search_aligned_mat_mul,
kX86,
kFloat,
kNCHW,
paddle::lite::kernels::x86::SearchAlignedMatMulCompute<float>,
def)
.BindInput("X", {LiteType::GetTensorTy(TARGET(kX86))})
.BindInput("Y", {LiteType::GetTensorTy(TARGET(kX86))})
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kX86))})
.Finalize();
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "lite/backends/x86/math/blas.h"
#include "lite/core/kernel.h"
#include "lite/core/op_registry.h"
#include "lite/core/types.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace x86 {
template <typename T>
class SearchAlignedMatMulCompute
: public KernelLite<TARGET(kX86), PRECISION(kFloat)> {
public:
using param_t = operators::MatMulParam;
void Run() override {
auto& context = ctx_->As<X86Context>();
auto& param = *param_.get_mutable<operators::MatMulParam>();
auto x = param.X;
auto y = param.Y;
auto out = param.Out;
bool x_transpose = param.transpose_X;
bool y_transpose = param.transpose_Y;
float alpha = param.alpha;
const auto x_dims = x->dims();
const auto y_dims = y->dims();
const auto& x_lod = x->lod();
const auto& y_lod = y->lod();
const auto& x_lod_0 = x_lod[0];
const auto& y_lod_0 = y_lod[0];
int seq_num = x_lod_0.size() - 1;
int x_inner_size = x_dims[1];
int y_inner_size = y_dims[1];
int x_batch_size = x_lod_0[1];
int y_batch_size = y_lod_0[1];
int M = x_transpose ? x_inner_size : x_batch_size;
int N = y_transpose ? y_batch_size : y_inner_size;
int X_K = x_transpose ? x_batch_size : x_inner_size;
int Y_K = y_transpose ? y_inner_size : y_batch_size;
CHECK_EQ(X_K, Y_K) << "K of Input(X) and Input(Y) is not equal";
int K = X_K;
lite::x86::math::MatDescriptor mat_dim_a;
mat_dim_a.height_ = M;
mat_dim_a.width_ = K;
mat_dim_a.stride_ = x_batch_size * x_inner_size;
mat_dim_a.batch_size_ = seq_num;
mat_dim_a.trans_ = x_transpose;
lite::x86::math::MatDescriptor mat_dim_b;
mat_dim_b.height_ = K;
mat_dim_b.width_ = N;
mat_dim_b.stride_ = y_batch_size * y_inner_size;
mat_dim_b.batch_size_ = seq_num;
mat_dim_b.trans_ = y_transpose;
auto blas = lite::x86::math::GetBlas<lite::TargetType::kX86, T>(context);
blas.MatMul(*x, mat_dim_a, *y, mat_dim_b, static_cast<T>(alpha), out, T(0));
}
virtual ~SearchAlignedMatMulCompute() = default;
};
} // namespace x86
} // namespace kernels
} // namespace lite
} // namespace paddle
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/kernels/x86/search_seq_fc_compute.h"
REGISTER_LITE_KERNEL(search_seq_fc,
kX86,
kFloat,
kNCHW,
paddle::lite::kernels::x86::SearchSeqFcCompute<float>,
def)
.BindInput("X", {LiteType::GetTensorTy(TARGET(kX86))})
.BindInput("W", {LiteType::GetTensorTy(TARGET(kX86))})
.BindInput("b", {LiteType::GetTensorTy(TARGET(kX86))})
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kX86))})
.Finalize();
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "lite/backends/x86/math/blas.h"
#include "lite/core/kernel.h"
#include "lite/core/op_registry.h"
#include "lite/core/types.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace x86 {
template <typename T>
class SearchSeqFcCompute : public KernelLite<TARGET(kX86), PRECISION(kFloat)> {
public:
using param_t = operators::SearchSeqFcParam;
void Run() override {
auto& context = ctx_->As<X86Context>();
auto& param = *param_.get_mutable<operators::SearchSeqFcParam>();
auto x = param.x;
auto w = param.w;
auto b = param.b;
auto out = param.out;
auto out_size = param.out_size;
const auto x_dims = x->dims();
const auto w_dims = w->dims();
const auto out_dims = out->dims();
CHECK_EQ(x_dims.size(), 2) << "The Input(X) should be 2-D tensor.";
CHECK_EQ(w_dims.size(), 2) << "W should be 2-D tensor.";
CHECK_EQ(out_dims.size(), 2) << "The Output(Out) should be 2-D tensor.";
CHECK_EQ(x_dims[1], w_dims[1]) << "Wrong shape: x_dims[1] != w_dims[1]";
CHECK_EQ(w_dims[0], out_size) << "Wrong shape: w_dims[0] != out_size";
CHECK_EQ(out_dims[0], x_dims[0]) << "Wrong shape: out_dims[0] != x_dims[0]";
CHECK_EQ(out_dims[1], out_size) << "Wrong shape: out_dims[1] != out_size";
auto blas = lite::x86::math::GetBlas<lite::TargetType::kX86, T>(context);
blas.MatMul(*x, false, *w, true, out);
if (b != nullptr) {
auto b_dims = b->dims();
CHECK_EQ(b_dims.size(), 1) << "b should be 1-D tensor.";
CHECK_EQ(b_dims[0], w_dims[0]) << "Wrong shape: b_dims[0] != w_dims[0]";
int M = x_dims[0];
int N = w_dims[0];
for (int i = 0; i < M; i++) {
blas.AXPY(
N, static_cast<T>(1), b->data<T>(), out->mutable_data<T>() + i * N);
}
}
}
virtual ~SearchSeqFcCompute() = default;
};
} // namespace x86
} // namespace kernels
} // namespace lite
} // namespace paddle
...@@ -114,6 +114,9 @@ add_operator(increment_op extra SRCS increment_op.cc DEPS ${op_DEPS}) ...@@ -114,6 +114,9 @@ add_operator(increment_op extra SRCS increment_op.cc DEPS ${op_DEPS})
add_operator(layer_norm_op extra SRCS layer_norm_op.cc DEPS ${op_DEPS}) add_operator(layer_norm_op extra SRCS layer_norm_op.cc DEPS ${op_DEPS})
add_operator(sequence_softmax_op extra SRCS sequence_softmax_op.cc DEPS ${op_DEPS}) add_operator(sequence_softmax_op extra SRCS sequence_softmax_op.cc DEPS ${op_DEPS})
# for content-dnn specific
add_operator(search_aligned_mat_mul_op extra SRCS search_aligned_mat_mul_op.cc DEPS ${op_DEPS})
add_operator(search_seq_fc_op extra SRCS search_seq_fc_op.cc DEPS ${op_DEPS})
if (NOT LITE_WITH_X86) if (NOT LITE_WITH_X86)
lite_cc_test(test_fc_op SRCS fc_op_test.cc lite_cc_test(test_fc_op SRCS fc_op_test.cc
......
...@@ -89,6 +89,14 @@ struct FcParam { ...@@ -89,6 +89,14 @@ struct FcParam {
WITH_INT8_CONFIG WITH_INT8_CONFIG
}; };
struct SearchSeqFcParam {
lite::Tensor* x{nullptr};
lite::Tensor* w{nullptr};
lite::Tensor* b{nullptr};
lite::Tensor* out{nullptr};
int out_size;
};
// For Interpolate Op // For Interpolate Op
struct InterpolateParam { struct InterpolateParam {
lite::Tensor* X{}; lite::Tensor* X{};
......
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/operators/search_aligned_mat_mul_op.h"
#include "lite/core/op_registry.h"
namespace paddle {
namespace lite {
namespace operators {
bool SearchAlignedMatMulOpLite::CheckShape() const {
CHECK_OR_FALSE(param_.X);
CHECK_OR_FALSE(param_.Y);
CHECK_OR_FALSE(param_.Out);
return true;
}
bool SearchAlignedMatMulOpLite::InferShape() const {
const auto x_dims = param_.X->dims();
const auto y_dims = param_.Y->dims();
const auto& x_lod = param_.X->lod();
const auto& y_lod = param_.Y->lod();
bool x_transpose = param_.transpose_X;
bool y_transpose = param_.transpose_Y;
CHECK_EQ(x_dims.size(), 2) << "X should be 2-D tensor";
CHECK_EQ(y_dims.size(), 2) << "Y should be 2-D tensor";
CHECK(!x_lod.empty()) << "The Input(X) must hold lod info.";
CHECK(!y_lod.empty()) << "The Input(Y) must hold lod info.";
const auto& x_lod_0 = x_lod[0];
const auto& y_lod_0 = y_lod[0];
CHECK_GE(x_lod_0.size(), 2) << "The Input(X)'s lod info is corrupted.";
CHECK_GE(y_lod_0.size(), 2) << "The Input(Y)'s lod info is corrupted.";
CHECK_EQ(x_dims[0], static_cast<int64_t>(x_lod_0.back()))
<< "The Input(X)'s lod info mismatches the actual tensor shape.";
CHECK_EQ(y_dims[0], static_cast<int64_t>(y_lod_0.back()))
<< "The Input(Y)'s lod info mismatches the actual tensor shape.";
CHECK_EQ(x_lod_0.size(), y_lod_0.size())
<< "The Length of X and Y must be equal.";
int seq_num = x_lod_0.size() - 1;
int x_inner_size = x_dims[1];
int y_inner_size = y_dims[1];
int x_batch_size = x_lod_0[1];
int y_batch_size = y_lod_0[1];
int M = x_transpose ? x_inner_size : x_batch_size;
int N = y_transpose ? y_batch_size : y_inner_size;
int X_K = x_transpose ? x_batch_size : x_inner_size;
int Y_K = y_transpose ? y_inner_size : y_batch_size;
CHECK_EQ(X_K, Y_K) << "K of Input(X) and Input(Y) is not equal";
LoD out_lod;
std::vector<uint64_t> out_lod_0(seq_num + 1);
out_lod_0[0] = 0;
for (int i = 0; i < seq_num; i++) {
out_lod_0[i + 1] = out_lod_0[i] + M;
}
out_lod.push_back(out_lod_0);
DDim out_dims(
{static_cast<int64_t>(out_lod_0.back()), static_cast<int64_t>(N)});
param_.Out->set_lod(out_lod);
param_.Out->Resize(out_dims);
return true;
}
bool SearchAlignedMatMulOpLite::AttachImpl(const cpp::OpDesc& op_desc,
lite::Scope* scope) {
CHECK(!op_desc.Input("X").empty());
CHECK(!op_desc.Input("Y").empty());
CHECK(!op_desc.Output("Out").empty());
auto X = op_desc.Input("X").front();
auto Y = op_desc.Input("Y").front();
auto Out = op_desc.Output("Out").front();
param_.X = GetVar<lite::Tensor>(scope, X);
param_.Y = GetVar<lite::Tensor>(scope, Y);
param_.Out = GetMutableVar<lite::Tensor>(scope, Out);
param_.transpose_X = op_desc.GetAttr<bool>("transpose_X");
param_.transpose_Y = op_desc.GetAttr<bool>("transpose_Y");
param_.alpha = op_desc.GetAttr<float>("alpha");
return true;
}
} // namespace operators
} // namespace lite
} // namespace paddle
REGISTER_LITE_OP(search_aligned_mat_mul,
paddle::lite::operators::SearchAlignedMatMulOpLite);
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <string>
#include <vector>
#include "lite/core/op_lite.h"
#include "lite/core/scope.h"
#include "lite/utils/all.h"
namespace paddle {
namespace lite {
namespace operators {
class SearchAlignedMatMulOpLite : public OpLite {
public:
SearchAlignedMatMulOpLite() {}
explicit SearchAlignedMatMulOpLite(const std::string &type) : OpLite(type) {}
bool CheckShape() const override;
bool InferShape() const override;
void AttachKernel(KernelBase *kernel) override { kernel->SetParam(param_); }
bool AttachImpl(const cpp::OpDesc &op_desc, lite::Scope *scope) override;
std::string DebugString() const override { return "search_aligned_mat_mul"; }
private:
mutable MatMulParam param_;
};
} // namespace operators
} // namespace lite
} // namespace paddle
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/operators/search_seq_fc_op.h"
#include "lite/core/op_registry.h"
namespace paddle {
namespace lite {
namespace operators {
bool SearchSeqFcOpLite::CheckShape() const {
CHECK_OR_FALSE(param_.x);
CHECK_OR_FALSE(param_.w);
CHECK_OR_FALSE(param_.out);
return true;
}
bool SearchSeqFcOpLite::InferShape() const {
const auto x_dims = param_.x->dims();
const auto w_dims = param_.w->dims();
const auto& x_lod = param_.x->lod();
auto out_size = param_.out_size;
CHECK_EQ(x_dims.size(), 2) << "The Input(X) should be 2-D tensor.";
CHECK(!x_lod.empty()) << "The Input(X) must hold lod info.";
const auto& x_lod_0 = x_lod[0];
CHECK_GE(x_lod_0.size(), 2) << "The Input(X)'s lod info is corrupted.";
CHECK_EQ(x_dims[0], static_cast<int64_t>(x_lod_0.back()))
<< "The Input(X)'s lod info mismatches the actual tensor shape.";
CHECK_EQ(w_dims.size(), 2) << "W should be 2-D tensor.";
CHECK_EQ(x_dims[1], w_dims[1]) << "Wrong shape: x_dims[1] != w_dims[1]";
CHECK_EQ(w_dims[0], out_size) << "Wrong shape: w_dims[0] != out_size";
if (param_.b != nullptr) {
const auto b_dims = param_.b->dims();
CHECK_EQ(b_dims.size(), 1) << "b should be 1-D tensor.";
CHECK_EQ(b_dims[0], w_dims[0]) << "Wrong shape: b_dims[0] != w_dims[0]";
}
param_.out->set_lod(x_lod);
param_.out->Resize({x_dims[0], w_dims[0]});
return true;
}
bool SearchSeqFcOpLite::AttachImpl(const cpp::OpDesc& op_desc,
lite::Scope* scope) {
CHECK(!op_desc.Input("X").empty());
CHECK(!op_desc.Input("W").empty());
CHECK(!op_desc.Output("Out").empty());
auto x = op_desc.Input("X").front();
auto w = op_desc.Input("W").front();
auto out = op_desc.Output("Out").front();
param_.x = scope->FindVar(x)->GetMutable<lite::Tensor>();
param_.w = scope->FindVar(w)->GetMutable<lite::Tensor>();
param_.out = scope->FindVar(out)->GetMutable<lite::Tensor>();
param_.out_size = op_desc.GetAttr<int>("out_size");
bool has_bias = op_desc.GetAttr<bool>("has_bias");
if (has_bias) {
CHECK(!op_desc.Input("b").empty());
auto b = op_desc.Input("b").front();
param_.b = scope->FindVar(b)->GetMutable<lite::Tensor>();
}
return true;
}
} // namespace operators
} // namespace lite
} // namespace paddle
REGISTER_LITE_OP(search_seq_fc, paddle::lite::operators::SearchSeqFcOpLite);
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <string>
#include <vector>
#include "lite/core/op_lite.h"
#include "lite/core/scope.h"
#include "lite/utils/all.h"
namespace paddle {
namespace lite {
namespace operators {
class SearchSeqFcOpLite : public OpLite {
public:
SearchSeqFcOpLite() {}
explicit SearchSeqFcOpLite(const std::string &type) : OpLite(type) {}
bool CheckShape() const override;
bool InferShape() const override;
void AttachKernel(KernelBase *kernel) override { kernel->SetParam(param_); }
bool AttachImpl(const cpp::OpDesc &op_desc, lite::Scope *scope) override;
std::string DebugString() const override { return "search_seq_fc"; }
private:
mutable SearchSeqFcParam param_;
};
} // namespace operators
} // namespace lite
} // namespace paddle
...@@ -39,6 +39,8 @@ if(LITE_BUILD_EXTRA) ...@@ -39,6 +39,8 @@ if(LITE_BUILD_EXTRA)
lite_cc_test(test_kernel_anchor_generator_compute SRCS anchor_generator_compute_test.cc DEPS arena_framework ${x86_kernels} ${arm_kernels} ${lite_ops} ${host_kernels}) lite_cc_test(test_kernel_anchor_generator_compute SRCS anchor_generator_compute_test.cc DEPS arena_framework ${x86_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
#lite_cc_test(test_kernel_generate_proposals_compute SRCS generate_proposals_compute_test.cc DEPS arena_framework ${x86_kernels} ${arm_kernels} ${lite_ops} ${host_kernels}) #lite_cc_test(test_kernel_generate_proposals_compute SRCS generate_proposals_compute_test.cc DEPS arena_framework ${x86_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
#lite_cc_test(test_kernel_roi_align_compute SRCS roi_align_compute_test.cc DEPS arena_framework ${x86_kernels} ${arm_kernels} ${lite_ops} ${host_kernels}) #lite_cc_test(test_kernel_roi_align_compute SRCS roi_align_compute_test.cc DEPS arena_framework ${x86_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
lite_cc_test(test_kernel_search_aligned_mat_mul_compute SRCS search_aligned_mat_mul_compute_test.cc DEPS arena_framework ${x86_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
lite_cc_test(test_kernel_search_seq_fc_compute SRCS search_seq_fc_compute_test.cc DEPS arena_framework ${x86_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
endif() endif()
lite_cc_test(test_kernel_pad2d_compute SRCS pad2d_compute_test.cc DEPS arena_framework ${x86_kernels} ${arm_kernels} ${lite_ops} ${host_kernels}) lite_cc_test(test_kernel_pad2d_compute SRCS pad2d_compute_test.cc DEPS arena_framework ${x86_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
lite_cc_test(test_kernel_prior_box_compute SRCS prior_box_compute_test.cc DEPS arena_framework ${x86_kernels} ${arm_kernels} ${lite_ops} ${host_kernels}) lite_cc_test(test_kernel_prior_box_compute SRCS prior_box_compute_test.cc DEPS arena_framework ${x86_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
......
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <gtest/gtest.h>
#include "lite/api/paddle_use_kernels.h"
#include "lite/api/paddle_use_ops.h"
#include "lite/core/arena/framework.h"
#include "lite/tests/utils/fill_data.h"
#include "lite/tests/utils/naive_math_impl.h"
namespace paddle {
namespace lite {
class SearchAlignedMatMulComputeTester : public arena::TestCase {
protected:
// common attributes for this op.
std::string x_ = "X";
std::string y_ = "Y";
bool x_transpose_;
bool y_transpose_;
float alpha_;
std::string out_ = "Out";
DDim x_dims_;
DDim y_dims_;
LoD x_lod_;
LoD y_lod_;
public:
SearchAlignedMatMulComputeTester(const Place& place,
const std::string& alias,
bool x_transpose,
bool y_transpose,
float alpha,
const DDim& x_dims,
const DDim& y_dims,
const LoD& x_lod,
const LoD& y_lod)
: TestCase(place, alias),
x_transpose_(x_transpose),
y_transpose_(y_transpose),
alpha_(alpha),
x_dims_(x_dims),
y_dims_(y_dims),
x_lod_(x_lod),
y_lod_(y_lod) {}
void RunBaseline(Scope* scope) override {
auto x = scope->FindTensor(x_);
auto y = scope->FindTensor(y_);
CHECK(x);
CHECK(y);
const auto x_data = x->data<float>();
const auto y_data = y->data<float>();
auto out = scope->NewTensor(out_);
CHECK(out);
const auto x_dims = x->dims();
const auto y_dims = y->dims();
const auto& x_lod = x->lod();
const auto& y_lod = y->lod();
const auto& x_lod_0 = x_lod[0];
const auto& y_lod_0 = y_lod[0];
int seq_num = x_lod_0.size() - 1;
int x_inner_size = x_dims[1];
int y_inner_size = y_dims[1];
int x_batch_size = x_lod_0[1];
int y_batch_size = y_lod_0[1];
int M = x_transpose_ ? x_inner_size : x_batch_size;
int N = y_transpose_ ? y_batch_size : y_inner_size;
int X_K = x_transpose_ ? x_batch_size : x_inner_size;
int Y_K = y_transpose_ ? y_inner_size : y_batch_size;
CHECK_EQ(X_K, Y_K) << "K of Input(X) and Input(Y) is not equal";
int K = X_K;
int x_stride = x_batch_size * x_inner_size;
int y_stride = y_batch_size * y_inner_size;
int out_stride = M * N;
int lda = x_transpose_ ? M : K;
int ldb = y_transpose_ ? K : N;
int ldc = N;
LoD out_lod;
std::vector<uint64_t> out_lod_0(seq_num + 1);
out_lod_0[0] = 0;
for (int i = 0; i < seq_num; i++) {
out_lod_0[i + 1] = out_lod_0[i] + M;
}
out_lod.push_back(out_lod_0);
DDim out_dims(
{static_cast<int64_t>(out_lod_0.back()), static_cast<int64_t>(N)});
out->set_lod(out_lod);
out->Resize(out_dims);
auto out_data = out->mutable_data<float>();
for (int i = 0; i < seq_num; i++) {
basic_gemm<float, float>(x_transpose_,
y_transpose_,
M,
N,
K,
alpha_,
x_data + i * x_stride,
lda,
y_data + i * y_stride,
ldb,
0,
out_data + i * out_stride,
ldc,
nullptr,
false,
false);
}
}
void PrepareOpDesc(cpp::OpDesc* op_desc) {
op_desc->SetType("search_aligned_mat_mul");
op_desc->SetInput("X", {x_});
op_desc->SetInput("Y", {y_});
op_desc->SetOutput("Out", {out_});
op_desc->SetAttr("transpose_X", x_transpose_);
op_desc->SetAttr("transpose_Y", y_transpose_);
op_desc->SetAttr("alpha", alpha_);
}
void PrepareData() override {
std::vector<float> x_data(x_dims_.production());
std::vector<float> y_data(y_dims_.production());
fill_data_rand(x_data.data(), -1.f, 1.f, x_dims_.production());
fill_data_rand(y_data.data(), -1.f, 1.f, y_dims_.production());
SetCommonTensor(x_, x_dims_, x_data.data(), x_lod_);
SetCommonTensor(y_, y_dims_, y_data.data(), y_lod_);
}
};
void test_search_aligned_mat_mul(Place place) {
for (int seq_num : {1, 2}) {
for (int x_batch_size : {1, 3}) {
for (int x_inner_size : {1, 5}) {
for (int out_inner_size : {1, 4}) {
for (bool x_transpose : {true, false}) {
for (bool y_transpose : {true, false}) {
for (float alpha : {1., 2.}) {
// infer x_dims and y_dims
int y_batch_size;
int y_inner_size;
if (x_transpose) {
if (y_transpose) {
y_batch_size = out_inner_size;
y_inner_size = x_batch_size;
} else {
y_batch_size = x_batch_size;
y_inner_size = out_inner_size;
}
} else {
if (y_transpose) {
y_batch_size = out_inner_size;
y_inner_size = x_inner_size;
} else {
y_batch_size = x_inner_size;
y_inner_size = out_inner_size;
}
}
std::vector<uint64_t> x_lod_0(seq_num + 1);
std::vector<uint64_t> y_lod_0(seq_num + 1);
x_lod_0[0] = 0;
y_lod_0[0] = 0;
for (int i = 0; i < seq_num; i++) {
x_lod_0[i + 1] = x_lod_0[i] + x_batch_size;
y_lod_0[i + 1] = y_lod_0[i] + y_batch_size;
}
LoD x_lod;
LoD y_lod;
x_lod.push_back(x_lod_0);
y_lod.push_back(y_lod_0);
DDim x_dims({static_cast<int64_t>(x_lod_0.back()),
static_cast<int64_t>(x_inner_size)});
DDim y_dims({static_cast<int64_t>(y_lod_0.back()),
static_cast<int64_t>(y_inner_size)});
std::unique_ptr<arena::TestCase> tester(
new SearchAlignedMatMulComputeTester(place,
"def",
x_transpose,
y_transpose,
alpha,
x_dims,
y_dims,
x_lod,
y_lod));
arena::Arena arena(std::move(tester), place, 5e-4);
arena.TestPrecision();
}
}
}
}
}
}
}
}
TEST(SearchAlignedMatMul, precision) {
#ifdef LITE_WITH_X86
Place place(TARGET(kX86));
test_search_aligned_mat_mul(place);
#endif
}
} // namespace lite
} // namespace paddle
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <gtest/gtest.h>
#include "lite/api/paddle_use_kernels.h"
#include "lite/api/paddle_use_ops.h"
#include "lite/core/arena/framework.h"
#include "lite/tests/utils/fill_data.h"
#include "lite/tests/utils/naive_math_impl.h"
namespace paddle {
namespace lite {
class SearchSeqFcOPTest : public arena::TestCase {
protected:
// common attributes for this op.
std::string x_ = "x";
std::string w_ = "w";
std::string b_ = "b";
std::string out_ = "out";
DDim x_dims_;
DDim w_dims_;
DDim b_dims_;
LoD x_lod_;
bool has_bias_;
int out_size_;
public:
SearchSeqFcOPTest(const Place& place,
const std::string& alias,
DDim x_dims,
DDim w_dims,
DDim b_dims,
LoD x_lod,
bool has_bias,
int out_size)
: TestCase(place, alias),
x_dims_(x_dims),
w_dims_(w_dims),
b_dims_(b_dims),
x_lod_(x_lod),
has_bias_(has_bias),
out_size_(out_size) {}
void RunBaseline(Scope* scope) override {
auto x = scope->FindTensor(x_);
auto w = scope->FindTensor(w_);
CHECK(x);
CHECK(w);
auto out = scope->NewTensor(out_);
CHECK(out);
const auto x_data = x->data<float>();
const auto w_data = w->data<float>();
const auto x_dims = x->dims();
const auto w_dims = w->dims();
const auto& x_lod = x->lod();
CHECK_EQ(x_dims.size(), 2) << "The Input(X) should be 2-D tensor.";
CHECK(!x_lod.empty()) << "The Input(X) must hold lod info.";
const auto& x_lod_0 = x_lod[0];
CHECK_GE(x_lod_0.size(), 2) << "The Input(X)'s lod info is corrupted.";
CHECK_EQ(x_dims[0], static_cast<int64_t>(x_lod_0.back()))
<< "The Input(X)'s lod info mismatches the actual tensor shape.";
CHECK_EQ(w_dims.size(), 2) << "W should be 2-D tensor.";
CHECK_EQ(x_dims[1], w_dims[1]) << "Wrong shape: x_dims[1] != w_dims[1]";
CHECK_EQ(w_dims[0], out_size_) << "Wrong shape: w_dims[0] != out_size";
const float* b_data = nullptr;
if (has_bias_) {
auto b = scope->FindTensor(b_);
CHECK(b);
auto b_dims = b->dims();
CHECK_EQ(b_dims.size(), 1) << "b should be 1-D tensor.";
CHECK_EQ(b_dims[0], w_dims[0]) << "Wrong shape: b_dims[0] != w_dims[0]";
b_data = b->data<float>();
}
out->set_lod(x_lod);
out->Resize({x_dims[0], w_dims[0]});
int M = x_dims[0];
int K = x_dims[1];
int N = w_dims[0];
auto out_data = out->mutable_data<float>();
basic_gemm<float, float>(false,
true,
M,
N,
K,
1.f,
x_data,
K,
w_data,
K,
0,
out_data,
N,
nullptr,
false,
false);
if (b_data != nullptr) {
for (int i = 0; i < M; i++) {
for (int j = 0; j < N; j++) {
out_data[i * N + j] += b_data[j];
}
}
}
}
void PrepareOpDesc(cpp::OpDesc* op_desc) {
op_desc->SetType("search_seq_fc");
op_desc->SetInput("X", {x_});
op_desc->SetInput("W", {w_});
if (has_bias_) {
op_desc->SetInput("b", {b_});
}
op_desc->SetAttr<bool>("has_bias", has_bias_);
op_desc->SetAttr<int>("out_size", out_size_);
op_desc->SetOutput("Out", {out_});
}
void PrepareData() override {
std::vector<float> x_data(x_dims_.production());
std::vector<float> w_data(w_dims_.production());
fill_data_rand(x_data.data(), -1.f, 1.f, x_dims_.production());
fill_data_rand(w_data.data(), -1.f, 1.f, w_dims_.production());
SetCommonTensor(x_, x_dims_, x_data.data(), x_lod_);
SetCommonTensor(w_, w_dims_, w_data.data());
if (has_bias_) {
std::vector<float> b_data(b_dims_.production());
fill_data_rand(b_data.data(), -1.f, 1.f, b_dims_.production());
SetCommonTensor(b_, b_dims_, b_data.data());
}
}
};
void test_search_seq_fc(Place place) {
for (auto x_lod_0 : {std::vector<uint64_t>({0, 1, 3}),
std::vector<uint64_t>({0, 3, 4, 5})}) {
for (auto feature_size : {2, 9}) {
for (auto out_size : {3, 5}) {
for (auto has_bias : {true, false}) {
DDim x_dims({static_cast<int64_t>(x_lod_0.back()), feature_size});
DDim w_dims({out_size, feature_size});
DDim b_dims({has_bias ? out_size : 0});
LoD x_lod;
x_lod.push_back(x_lod_0);
std::unique_ptr<arena::TestCase> tester(new SearchSeqFcOPTest(
place, "def", x_dims, w_dims, b_dims, x_lod, has_bias, out_size));
arena::Arena arena(std::move(tester), place, 6e-5);
arena.TestPrecision();
}
}
}
}
}
TEST(SearchSeqFcOP, precision) {
#ifdef LITE_WITH_X86
Place place(TARGET(kX86));
test_search_seq_fc(place);
#endif
}
} // namespace lite
} // namespace paddle
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册