Unverified · Commit 28481458 · Authored by: Yiqun Liu · Committed by: GitHub

[X86] Polish the implementation of fc and improve the unittest (#2656)

* Remove GEMM padding in fc_compute.
test=develop

* Write a common ParallelFor function to run the for loop in parallel.

* Add the codes of padding GEMM back in fc.

* Refine the code of fc when padding_weight is false to avoid the definition of temporary Tensor.

* Refine the unit test of fc and add testing case of padding and parallel.
test=develop

* Enable more test cases in common fc unittest, including padding and parallel for x86 target.

* Remove the fc test under kernels/x86.
test=develop

* Disable relu in test of fc for non-x86 target.
test=develop

* Change the eps of arm.
test=develop
Parent d345a7fc
@@ -49,7 +49,7 @@ TEST(Step_rnn, test_step_rnn_lite_x86) {
       "micro_video_id",
       "vertical_type_id"};
-  for (int i = 0; i < target_names.size(); ++i) {
+  for (size_t i = 0; i < target_names.size(); ++i) {
     auto input_tensor = predictor->GetInput(i);
     int size = 0;
     if (i == 6 || i == 8) {
@@ -74,8 +74,7 @@ TEST(Step_rnn, test_step_rnn_lite_x86) {
     predictor->Run();
   }
-  // LOG(INFO) << "================== Speed Report ===================";
-  LOG(INFO) << ", warmup: " << FLAGS_warmup << ", repeats: " << FLAGS_repeats
+  LOG(INFO) << "warmup: " << FLAGS_warmup << ", repeats: " << FLAGS_repeats
             << ", spend " << (GetCurrentUS() - start) / FLAGS_repeats / 1000.0
             << " ms in average.";
@@ -86,8 +85,8 @@ TEST(Step_rnn, test_step_rnn_lite_x86) {
   std::vector<int64_t> out_shape = out->shape();
-  for (int i = 0; i < results.size(); ++i) {
-    for (int j = 0; j < results[i].size(); ++j) {
+  for (size_t i = 0; i < results.size(); ++i) {
+    for (size_t j = 0; j < results[i].size(); ++j) {
       EXPECT_NEAR(
           out->data<float>()[j + (out_shape[1] * i)], results[i][j], 1e-6);
     }
...
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <algorithm>
#ifdef PADDLE_WITH_MKLML
#include <omp.h>
#include "lite/backends/x86/mklml.h"
#endif
namespace paddle {
namespace lite {
namespace x86 {
static void SetNumThreads(int num_threads) {
#ifdef PADDLE_WITH_MKLML
int real_num_threads = std::max(num_threads, 1);
x86::MKL_Set_Num_Threads(real_num_threads);
omp_set_num_threads(real_num_threads);
#endif
}
static inline int64_t GetMaxThreads() {
int64_t num_threads = 1;
#ifdef PADDLE_WITH_MKLML
// Nested omp parallel regions are not supported.
num_threads = omp_in_parallel() ? 1 : omp_get_max_threads();
#endif
return std::max(num_threads, 1L);
}
using ThreadHandler =
std::function<void(const int64_t begin, const int64_t end)>;
static inline void RunParallelFor(const int64_t begin,
const int64_t end,
const ThreadHandler& f) {
if (begin >= end) {
return;
}
#ifdef PADDLE_WITH_MKLML
int64_t num_threads = std::min(GetMaxThreads(), end - begin);
if (num_threads > 1) {
#pragma omp parallel num_threads(num_threads)
{
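// Split [begin, end) into one contiguous chunk per thread (ceiling
// division); each thread runs f on its own sub-range, clamped to end.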
int64_t tid = omp_get_thread_num();
int64_t chunk_size = (end - begin + num_threads - 1) / num_threads;
int64_t begin_tid = begin + tid * chunk_size;
f(begin_tid, std::min(end, chunk_size + begin_tid));
}
return;
}
#endif
f(begin, end);
}
} // namespace x86
} // namespace lite
} // namespace paddle
@@ -100,4 +100,3 @@ lite_cc_test(test_sequence_concat_compute_x86 SRCS sequence_concat_compute_test.
 lite_cc_test(test_var_conv_2d_compute_x86 SRCS var_conv_2d_compute_test.cc DEPS var_conv_2d_compute_x86)
 #lite_cc_test(test_attention_padding_mask_compute_x86 SRCS attention_padding_mask_compute_test.cc DEPS attention_padding_mask_compute_x86)
 lite_cc_test(test_sequence_arithmetic_compute_x86 SRCS sequence_arithmetic_compute_test.cc DEPS sequence_arithmetic_compute_x86)
-lite_cc_test(test_fc_compute_x86 SRCS fc_compute_test.cc DEPS fc_compute_x86)
@@ -11,6 +11,7 @@
 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 // See the License for the specific language governing permissions and
 // limitations under the License.
 #pragma once
 #include <vector>
@@ -18,6 +19,7 @@
 #include "lite/backends/x86/jit/kernel_base.h"
 #include "lite/backends/x86/jit/kernels.h"
 #include "lite/backends/x86/math/blas.h"
+#include "lite/backends/x86/parallel.h"
 #include "lite/core/kernel.h"
 #include "lite/core/op_lite.h"
 #include "lite/core/op_registry.h"
@@ -57,34 +59,45 @@ class FCFunctor {
                   bool relu = false,
                   bool padding_weights = false) {
     auto blas = lite::x86::math::GetBlas<lite::TargetType::kX86, T>(context);
+    lite::Tensor Y1;
     T* Y1_data = nullptr;
-    if (N % 128 == 0 && K % 128 == 0) {
+    auto compute =
+        relu
+            ? jit::KernelFuncs<jit::VAddReluTuple<T>, fluid::CPUPlace>::Cache()
+                  .At(N)
+            : jit::KernelFuncs<jit::VAddTuple<T>, fluid::CPUPlace>::Cache().At(
+                  N);
+    auto parallel_compute = [&](int64_t begin, int64_t end) {
+      for (int64_t i = begin; i < end; i++) {
+        T* dst = Y + i * N;
+        T* src = Y1_data ? Y1_data + i * (N + 4) : dst;
+        compute(B, src, dst, N);
+      }
+    };
+    // Because of the overhead of memcpy, we only do padding for GEMM
+    // when weights is already padded in fc_fuse_pass.
+    if (padding_weights) {
       const int NN = N + 4;
       const int KK = K + 4;
+      // NOTE: here need to mutable_data for temporary Tensor X1 and Y1,
+      // the overhead is unmeasured.
       lite::Tensor X1;
       X1.Resize({M * KK});
+      Y1.Resize({M * (N + 4)});
       T* X1_data = X1.mutable_data<T>();
-      lite::Tensor Y1;
-      Y1.Resize({M * (N + 4)});
       Y1_data = Y1.mutable_data<T>();
-#ifdef PADDLE_WITH_MKLML
-#pragma omp parallel for
-#endif
-      for (int i = 0; i < M; i++) {
-        memcpy(X1_data + i * KK, X + i * K, K * sizeof(X[0]));
-      }
-      lite::Tensor W1;
-      T* W1_data = nullptr;
-      if (!padding_weights) {
-        W1.Resize({(K + 4) * (N + 4)});
-        W1_data = W1.mutable_data<T>();
-#ifdef PADDLE_WITH_MKLML
-#pragma omp parallel for
-#endif
-        for (int i = 0; i < K; i++) {
-          memcpy(W1_data + i * NN, W + i * N, N * sizeof(W[0]));
-        }
-      }
+      auto parallel_memcpy_x = [&](int64_t begin, int64_t end) {
+        for (int64_t i = begin; i < end; i++) {
+          memcpy(X1_data + i * KK, X + i * K, K * sizeof(T));
+        }
+      };
+      lite::x86::RunParallelFor(0, M, parallel_memcpy_x);
       blas.GEMM(false,
                 false,
                 M,
@@ -93,48 +106,30 @@ class FCFunctor {
                 static_cast<T>(1.0),
                 X1_data,
                 KK,
-                (padding_weights ? W : W1_data),
+                W,
                 NN,
                 static_cast<T>(0.0),
                 Y1_data,
                 NN);
-    } else {
-      blas.MatMul(M, N, K, X, W, Y);
-    }
-    if (B == NULL) {
-      if (N % 128 == 0 && K % 128 == 0) {
-#ifdef PADDLE_WITH_MKLML
-#pragma omp parallel for
-#endif
-        for (int i = 0; i < M; i++) {
-          memcpy(Y + i * N, Y1_data + i * (N + 4), N * sizeof(Y[0]));
-        }
-      }
-      return;
-    }
-    if (relu) {
-      auto compute =
-          paddle::lite::jit::KernelFuncs<paddle::lite::jit::VAddReluTuple<T>,
-                                         lite::fluid::CPUPlace>::Cache()
-              .At(N);
-      for (int i = 0; i < M; i++) {
-        T* dst = Y + i * N;
-        T* src = (N % 128 == 0 && K % 128 == 0) ? Y1_data + i * (N + 4) : dst;
-        compute(B, src, dst, N);
-      }
+      if (!B) {
+        auto parallel_memcpy_y = [&](int64_t begin, int64_t end) {
+          for (int64_t i = begin; i < end; i++) {
+            memcpy(Y + i * N, Y1_data + i * (N + 4), N * sizeof(T));
+          }
+        };
+        lite::x86::RunParallelFor(0, M, parallel_memcpy_y);
+        return;
+      }
+      lite::x86::RunParallelFor(0, M, parallel_compute);
     } else {
-      auto compute =
-          paddle::lite::jit::KernelFuncs<paddle::lite::jit::VAddTuple<T>,
-                                         lite::fluid::CPUPlace>::Cache()
-              .At(N);
-#ifdef PADDLE_WITH_MKLML
-#pragma omp parallel for
-#endif
-      for (int i = 0; i < M; i++) {
-        T* dst = Y + i * N;
-        T* src = (N % 128 == 0 && K % 128 == 0) ? Y1_data + i * (N + 4) : dst;
-        compute(B, src, dst, N);
-      }
+      blas.MatMul(M, N, K, X, W, Y);
+      if (!B) {
+        return;
+      }
+      lite::x86::RunParallelFor(0, M, parallel_compute);
     }
   }
 };
...
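As background for the padding branch above: fc_fuse_pass stores the weight as a (K + 4) x (N + 4) buffer, so the kernel repacks the input so that each row of width K sits at stride KK = K + 4, runs GEMM with leading dimensions KK/NN, and finally copies (or bias/ReLU-fuses) the (N + 4)-stride output rows back to stride N. A standalone sketch of that row repacking, with hypothetical names, not code from the patch:

```cpp
#include <cstring>
#include <vector>

// Copy a row-major M x K matrix into a buffer whose row stride is K + pad,
// mirroring what parallel_memcpy_x does for X before the padded GEMM.
std::vector<float> PadRows(const float* src, int M, int K, int pad) {
  const int KK = K + pad;
  std::vector<float> dst(static_cast<size_t>(M) * KK, 0.0f);
  for (int i = 0; i < M; ++i) {
    std::memcpy(dst.data() + i * KK, src + i * K, K * sizeof(float));
  }
  return dst;
}
```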
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/kernels/x86/fc_compute.h"
#include <gtest/gtest.h>
#include <memory>
#include <utility>
#include <vector>
#include "lite/core/op_registry.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace x86 {
TEST(fc_x86, retrive_op) {
auto fc =
KernelRegistry::Global().Create<TARGET(kX86), PRECISION(kFloat)>("fc");
ASSERT_FALSE(fc.empty());
ASSERT_TRUE(fc.front());
}
TEST(fc_x86, init) {
FcCompute<float> fc;
ASSERT_EQ(fc.precision(), PRECISION(kFloat));
ASSERT_EQ(fc.target(), TARGET(kX86));
}
TEST(fc_x86, run_test) {
lite::Tensor x, w, b, out;
constexpr int batch_size = 2;
std::vector<int64_t> x_shape{batch_size, 3};
x.Resize(lite::DDim(x_shape));
std::vector<int64_t> w_shape{3, 4};
w.Resize(lite::DDim(w_shape));
std::vector<int64_t> b_shape{1, 4};
b.Resize(lite::DDim(b_shape));
std::vector<int64_t> out_shape{batch_size, 4};
out.Resize(lite::DDim(out_shape));
auto x_data = x.mutable_data<float>();
auto w_data = w.mutable_data<float>();
auto b_data = b.mutable_data<float>();
auto out_data = out.mutable_data<float>();
for (int64_t i = 0; i < x.dims().production(); i++) {
x_data[i] = static_cast<float>(i);
}
for (int64_t i = 0; i < w.dims().production(); i++) {
w_data[i] = static_cast<float>(2);
}
for (int64_t i = 0; i < b.dims().production(); i++) {
b_data[i] = static_cast<float>(2);
}
// FcCompute fc;
FcCompute<float> fc;
operators::FcParam param;
param.in_num_col_dims = 1;
param.input = &x;
param.w = &w;
param.bias = &b;
param.output = &out;
param.in_mat_dims = x.dims();
param.activation_type = "relu";
std::unique_ptr<KernelContext> ctx(new KernelContext);
ctx->As<X86Context>();
fc.SetParam(param);
fc.SetContext(std::move(ctx));
fc.Run();
std::vector<float> ref_data({8, 8, 8, 8, 26, 26, 26, 26});
for (int i = 0; i < out.dims().production(); i++) {
EXPECT_NEAR(out_data[i], ref_data[i], 1e-5);
}
}
} // namespace x86
} // namespace kernels
} // namespace lite
} // namespace paddle
USE_LITE_KERNEL(fc, kX86, kFloat, kNCHW, def);
@@ -27,21 +27,21 @@ bool FcOpLite::CheckShape() const {
   const auto input_dims = param_.input->dims();
   const auto w_dims = param_.w->dims();
+  CHECK_EQ_OR_FALSE(w_dims.size(), 2UL);
+  int64_t w_dims_1 = param_.padding_weights ? w_dims[1] - 4 : w_dims[1];
   if (param_.bias) {
     const auto bias_dims = param_.bias->dims();
     if (bias_dims.size() == 2) {
       CHECK_EQ_OR_FALSE(bias_dims[0], 1);
-      CHECK_EQ_OR_FALSE(bias_dims[1], w_dims[1]);
+      CHECK_EQ_OR_FALSE(bias_dims[1], w_dims_1);
     } else if (bias_dims.size() == 1) {
-      CHECK_EQ_OR_FALSE(bias_dims[0], w_dims[1]);
+      CHECK_EQ_OR_FALSE(bias_dims[0], w_dims_1);
     }
   }
-  CHECK_EQ_OR_FALSE(w_dims.size(), 2UL);
   CHECK_GT_OR_FALSE(input_dims.size(),
                     static_cast<size_t>(param_.in_num_col_dims));
   param_.in_mat_dims = input_dims.Flatten2D(param_.in_num_col_dims);
   // CHECK_EQ_OR_FALSE(param_.in_mat_dims[1], w_dims[0]);
@@ -91,7 +91,9 @@ bool FcOpLite::AttachImpl(const cpp::OpDesc &op_desc, lite::Scope *scope) {
     param_.activation_type = op_desc.GetAttr<std::string>("activation_type");
   }
   if (op_desc.HasAttr("padding_weights")) {
-    param_.activation_type = op_desc.GetAttr<bool>("padding_weights");
+    param_.padding_weights = op_desc.GetAttr<bool>("padding_weights");
+  } else {
+    param_.padding_weights = false;
   }
   // For Int8
...
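The CheckShape change compares the bias length against the logical output width rather than the stored weight width, since a padded weight carries 4 extra columns. A minimal sketch of that rule (a hypothetical helper, not code from the operator):

```cpp
#include <cstdint>

// Logical N of an fc weight: a weight padded by fc_fuse_pass is stored as
// (K + 4) x (N + 4), so its last 4 columns are not real outputs.
inline int64_t LogicalOutputWidth(int64_t w_dim1, bool padding_weights) {
  return padding_weights ? w_dim1 - 4 : w_dim1;
}
// A bias of shape {n} or {1, n} must then satisfy
// n == LogicalOutputWidth(w_dims[1], param_.padding_weights).
```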
@@ -18,11 +18,14 @@
 #include "lite/core/arena/framework.h"
 #include "lite/tests/utils/fill_data.h"
 #include "lite/tests/utils/naive_math_impl.h"
+#ifdef LITE_WITH_X86
+#include "lite/backends/x86/parallel.h"
+#endif
 namespace paddle {
 namespace lite {
-void fill_bias_fc(float* out, const float* bias, int num, int channel) {
+void AddBias(float* out, const float* bias, int num, int channel) {
   int remain = channel;
   for (int j = 0; j < num; ++j) {
     const float* ptr_bias = bias;
@@ -33,7 +36,15 @@ void fill_bias_fc(float* out, const float* bias, int num, int channel) {
   }
 }
-DDim compute_out_dim(const DDim& dim_in, const DDim& wdim, int in_num_col_dim) {
+void Relu(float* out, int num, int channel) {
+  for (int i = 0; i < num * channel; ++i) {
+    if (out[i] < 0) {
+      out[i] = 0;
+    }
+  }
+}
+DDim ComputeOutDim(const DDim& dim_in, const DDim& wdim, int in_num_col_dim) {
   std::vector<int64_t> out_dim;
   out_dim.resize(in_num_col_dim + 1);
   auto in_mat_dims = dim_in.Flatten2D(in_num_col_dim);
@@ -49,12 +60,16 @@ class FcOPTest : public arena::TestCase {
   // common attributes for this op.
   std::string input_ = "x";
   std::string weight_ = "w";
+  std::string weight_padding_ = "w_padding";
   std::string bias_ = "b";
   std::string out_ = "out";
   DDim dims_{{1, 128}};
   DDim wdims_{{128, 4}};
+  DDim wdims_padding_;
   DDim bdims_{{4}};
   int in_num_col_dims_{1};
+  bool with_relu_{false};
+  bool padding_weights_{false};
  public:
   FcOPTest(const Place& place,
@@ -62,12 +77,22 @@ class FcOPTest : public arena::TestCase {
            DDim dim_in,
            DDim dim_w,
            DDim dim_b,
-           int in_num_col_dims)
+           int in_num_col_dims,
+           bool with_relu,
+           bool padding)
       : TestCase(place, alias),
         dims_(std::move(dim_in)),
         wdims_(std::move(dim_w)),
         bdims_(dim_b),
-        in_num_col_dims_(in_num_col_dims) {}
+        in_num_col_dims_(in_num_col_dims),
+        with_relu_(with_relu) {
+#ifdef LITE_WITH_X86
+    if (padding && wdims_[0] % 128 == 0 && wdims_[1] % 128 == 0) {
+      padding_weights_ = true;
+      wdims_padding_ = DDim({wdims_[0] + 4, wdims_[1] + 4});
+    }
+#endif
+  }
   void RunBaseline(Scope* scope) override {
     auto x = scope->FindTensor(input_);
@@ -76,11 +101,9 @@ class FcOPTest : public arena::TestCase {
     bool flag_bias = b;
     auto out = scope->NewTensor(out_);
     CHECK(out);
-    DDim out_dim = compute_out_dim(x->dims(), w->dims(), in_num_col_dims_);
+    DDim out_dim = ComputeOutDim(x->dims(), w->dims(), in_num_col_dims_);
     out->Resize(out_dim);
-    LOG(INFO) << "out dims: " << out_dim;
     auto x_data = x->data<float>();
     auto w_data = w->data<float>();
     const float* b_data = nullptr;
@@ -94,7 +117,9 @@ class FcOPTest : public arena::TestCase {
     int k = wdims_[0];
     int n = wdims_[1];
-    LOG(INFO) << "m: " << m << ", n: " << n << ", k: " << k;
+    LOG(INFO) << "M=" << m << ", N=" << n << ", K=" << k
+              << ", bias=" << flag_bias << ", with_relu=" << with_relu_
+              << ", padding_weights=" << padding_weights_;
     if (m == 1) {
       basic_gemv(n,
@@ -126,20 +151,34 @@ class FcOPTest : public arena::TestCase {
                  false,
                  false);
       if (flag_bias) {
-        fill_bias_fc(out_data, b_data, m, n);
+        AddBias(out_data, b_data, m, n);
       }
     }
+#ifdef LITE_WITH_X86
+    if (flag_bias && with_relu_) {
+      Relu(out_data, m, n);
+    }
+#endif
   }
   void PrepareOpDesc(cpp::OpDesc* op_desc) {
     op_desc->SetType("fc");
     op_desc->SetInput("Input", {input_});
-    op_desc->SetInput("W", {weight_});
+    if (padding_weights_) {
+      op_desc->SetInput("W", {weight_padding_});
+    } else {
+      op_desc->SetInput("W", {weight_});
+    }
     if (bdims_.production() > 0) {
       op_desc->SetInput("Bias", {bias_});
     }
     op_desc->SetOutput("Out", {out_});
     op_desc->SetAttr<int>("in_num_col_dims", in_num_col_dims_);
+#ifdef LITE_WITH_X86
+    std::string activation_type = with_relu_ ? "relu" : "";
+    op_desc->SetAttr<std::string>("activation_type", activation_type);
+    op_desc->SetAttr<bool>("padding_weights", padding_weights_);
+#endif
   }
   void PrepareData() override {
@@ -155,22 +194,37 @@ class FcOPTest : public arena::TestCase {
     SetCommonTensor(input_, dims_, din.data());
     SetCommonTensor(weight_, wdims_, win.data());
+    if (padding_weights_) {
+      std::vector<float> win_padding(wdims_padding_.production());
+      for (int64_t i = 0; i < wdims_[0]; ++i) {
+        memcpy(&(win_padding[i * wdims_padding_[1]]),
+               &(win[i * wdims_[1]]),
+               wdims_[1] * sizeof(float));
+      }
+      SetCommonTensor(weight_padding_, wdims_padding_, win_padding.data());
+    }
     if (flag_bias) {
       SetCommonTensor(bias_, bdims_, bin.data());
     }
   }
 };
-void test_fc(Place place, float abs_error) {
+void TestFCMain(Place place,
+                float abs_error,
+                bool with_relu = false,
+                bool padding = false) {
   for (auto& m : {1, 3, 16}) {
     for (auto& n : {1, 4, 16, 128, 256, 1024}) {
       for (auto& k : {1, 16, 128, 1024}) {
         for (auto& bflag : {false, true}) {
+          if (!bflag && with_relu) {
+            continue;
+          }
           DDim dim_in{{m, k}};
           DDim wdim{{k, n}};
           DDim bdim{{bflag ? n : 0}};
-          std::unique_ptr<arena::TestCase> tester(
-              new FcOPTest(place, "def", dim_in, wdim, bdim, 1));
+          std::unique_ptr<arena::TestCase> tester(new FcOPTest(
+              place, "def", dim_in, wdim, bdim, 1, with_relu, padding));
 #ifdef LITE_WITH_ARM
           if (place == TARGET(kARM)) {
             auto& ctx = tester->context()->As<ARMContext>();
@@ -195,13 +249,25 @@ TEST(FcOP, precision) {
 #if defined(LITE_WITH_NPU)
   place = TARGET(kNPU);
   abs_error = 2e-1;  // Using fp16 in NPU
+#elif defined(LITE_WITH_X86)
+  place = TARGET(kX86);
+  abs_error = 1e-4;
 #elif defined(LITE_WITH_ARM)
   place = TARGET(kARM);
 #else
   return;
 #endif
-  test_fc(place, abs_error);
+  TestFCMain(place, abs_error);
 }
+#ifdef LITE_WITH_X86
+TEST(FcOP, padding_and_parallel) {
+  Place place(TARGET(kX86));
+  float abs_error = 1e-4;
+  x86::SetNumThreads(4);
+  TestFCMain(place, abs_error, true, true);
+}
+#endif
 }  // namespace lite
 }  // namespace paddle