Commit 6e5a7c6b authored by tensor-tang

refine fc prepare and test

Parent 5f833603
--- a/paddle/fluid/lite/kernels/arm/fc_compute.cc
+++ b/paddle/fluid/lite/kernels/arm/fc_compute.cc
@@ -23,10 +23,6 @@ namespace kernels {
 namespace arm {
 
 void FcCompute::PrepareForRun() {
-  // TODO(TJ): transpose weight
-}
-
-void FcCompute::Run() {
   auto& param = this->Param<operators::FcParam>();
   auto x_dims = param.input->dims();
   auto w_dims = param.w->dims();
@@ -35,29 +31,52 @@ void FcCompute::Run() {
   CHECK_EQ(w_dims.size(), 2UL);
   CHECK_EQ(param.output->dims().size(), 2UL);
 
+  m_ = x_dims.Slice(0, param.in_num_col_dims).production();
+  k_ = x_dims.Slice(param.in_num_col_dims, x_dims.size()).production();
+  n_ = w_dims[1];
+  CHECK_EQ(k_, static_cast<int>(w_dims[0]));
+
+  if (m_ == 1) {
+    if (!transed_weight_) {
+      transed_weight_ = new Tensor;
+    }
+    transed_weight_->Resize({n_, k_});
+    const auto* w_data = param.w->data<float>();
+    auto* t_data = transed_weight_->mutable_data<float>();
+    int i = 0;
+    for (int nn = 0; nn < n_; ++nn) {
+      for (int kk = 0; kk < k_; ++kk) {
+        t_data[i++] = w_data[kk * n_ + nn];
+      }
+    }
+  }
+}
+
+void FcCompute::Run() {
+  auto& param = this->Param<operators::FcParam>();
+
   const auto* i_data = param.input->data<float>();
   const auto* w_data = param.w->data<float>();
   const auto* b_data = param.bias ? param.bias->data<float>() : nullptr;
   auto* o_data = param.output->mutable_data<float>();
 
-  int x_h = x_dims.Slice(0, param.in_num_col_dims).production();
-  int x_w = x_dims.Slice(param.in_num_col_dims, x_dims.size()).production();
-  int n = w_dims[1];
-  CHECK_EQ(x_w, static_cast<int>(w_dims[0]));
-
   auto& ctx = this->ctx_->template As<ARMContext>();
-  if (x_h > 1) {
+  if (m_ > 1) {
     float* packed_in = static_cast<float*>(ctx.workspace_data<float>()) +
                        ctx.l2_cache_size() / sizeof(float);
-    lite::arm::math::prepackA(packed_in, i_data, x_w, 0, x_h, 0, x_w, false,
-                              &ctx);
-    lite::arm::math::sgemm_prepack(packed_in, w_data, b_data, o_data, x_h, n,
-                                   x_w, false, false, false, &ctx);
+    lite::arm::math::prepackA(packed_in, i_data, k_, 0, m_, 0, k_, false, &ctx);
+    lite::arm::math::sgemm_prepack(packed_in, w_data, b_data, o_data, m_, n_,
+                                   k_, false, false, false, &ctx);
     if (param.bias) {
-      CHECK_EQ(param.bias->numel(), n);
-      lite::arm::math::fill_bias_fc(o_data, b_data, x_h, n);
+      CHECK_EQ(param.bias->numel(), n_);
+      lite::arm::math::fill_bias_fc(o_data, b_data, m_, n_);
     }
   } else {
-    lite::arm::math::sgemv(w_data, i_data, o_data, false, n, x_w,
+    CHECK(transed_weight_);
+    const auto* t_data = transed_weight_->data<float>();
+    lite::arm::math::sgemv(t_data, i_data, o_data, false, n_, k_,
                            b_data != nullptr, b_data, false);
   }
 }
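Why PrepareForRun() now owns the shape math and the transpose: PrepareForRun() executes once per kernel instance, while Run() executes once per inference, so m_/k_/n_ and the one-off weight transpose are hoisted out of the hot path. When m_ == 1 the FC collapses to a matrix-vector product, and storing W (k x n, row-major) transposed as (n x k) turns each output element into a contiguous dot product, which is the access pattern a row-major GEMV kernel wants. Below is a minimal standalone sketch of that layout trick in plain C++; `naive_gemv_transposed` is a hypothetical illustration, not the Paddle-Lite `sgemv`.

```cpp
#include <vector>

// out[j] = bias[j] + sum_i x[i] * W[i][j], with W stored row-major (k x n).
// Reading W column j directly strides by n; transposing once to wt (n x k)
// makes each out[j] a contiguous dot product over wt[j * k .. j * k + k).
void naive_gemv_transposed(const std::vector<float>& w,  // k x n, row-major
                           const std::vector<float>& x,  // length k
                           const float* bias,            // length n, or nullptr
                           std::vector<float>& out,      // length n
                           int k, int n) {
  // Transpose once; in the kernel this lives in PrepareForRun() and is reused.
  std::vector<float> wt(static_cast<size_t>(n) * k);
  for (int j = 0; j < n; ++j) {
    for (int i = 0; i < k; ++i) {
      wt[j * k + i] = w[i * n + j];  // same indexing as t_data/w_data above
    }
  }
  for (int j = 0; j < n; ++j) {
    float acc = bias ? bias[j] : 0.f;
    const float* row = &wt[j * k];  // contiguous row of W^T
    for (int i = 0; i < k; ++i) {
      acc += row[i] * x[i];
    }
    out[j] = acc;
  }
}
```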
--- a/paddle/fluid/lite/kernels/arm/fc_compute.h
+++ b/paddle/fluid/lite/kernels/arm/fc_compute.h
@@ -29,7 +29,15 @@ class FcCompute : public KernelLite<TARGET(kARM), PRECISION(kFloat)> {
 
   void Run() override;
 
-  virtual ~FcCompute() = default;
+  ~FcCompute() override {
+    if (transed_weight_) {
+      delete transed_weight_;
+    }
+  }
+
+ private:
+  lite::Tensor* transed_weight_{nullptr};
+  int m_, n_, k_;
 };
 
 }  // namespace arm
--- a/paddle/fluid/lite/kernels/arm/fc_compute_test.cc
+++ b/paddle/fluid/lite/kernels/arm/fc_compute_test.cc
@@ -14,6 +14,11 @@
 #include "paddle/fluid/lite/kernels/arm/fc_compute.h"
 #include <gtest/gtest.h>
+#include <algorithm>
+#include <iostream>
+#include <memory>
+#include <random>
+#include <utility>
 #include <vector>
 #include "paddle/fluid/lite/arm/math/funcs.h"
 #include "paddle/fluid/lite/core/op_registry.h"
@@ -23,6 +28,17 @@ namespace lite {
 namespace kernels {
 namespace arm {
 
+template <typename T>
+void FillData(T* a, const int n, const T lower = static_cast<T>(-2.f),
+              const T upper = static_cast<T>(2.f)) {
+  static unsigned int seed = 100;
+  std::mt19937 rng(seed++);
+  std::uniform_real_distribution<double> uniform_dist(0, 1);
+  for (int i = 0; i < n; ++i) {
+    a[i] = static_cast<T>(uniform_dist(rng) * (upper - lower) + lower);
+  }
+}
+
 TEST(fc_arm, retrive_op) {
   auto fc =
       KernelRegistry::Global().Create<TARGET(kARM), PRECISION(kFloat)>("fc");
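The FillData helper added above draws each element uniformly from [lower, upper) through a freshly constructed std::mt19937 (the static seed increments per call, so successive tensors get different values). Passing lower == upper == 0 makes every element exactly zero, which is how the tests below reset the out and ref buffers. A small usage sketch, assuming FillData as defined in the hunk above:

```cpp
#include <vector>

int main() {
  std::vector<float> buf(12);
  FillData<float>(buf.data(), 12);            // uniform values in [-2, 2)
  FillData<float>(buf.data(), 12, 0.f, 0.f);  // lower == upper == 0 -> zeros
  return 0;
}
```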
@@ -37,43 +53,46 @@ TEST(fc_arm, init) {
 }
 
 TEST(fc_arm, compare_test) {
-  lite::Tensor x, w, b, out, ref;
-  constexpr int batch_size = 2;
-  x.Resize({batch_size, 3});
-  w.Resize({3, 4});
-  b.Resize({1, 4});
-  out.Resize({batch_size, 4});
-  ref.Resize({batch_size, 4});
-
-  auto x_data = x.mutable_data<float>();
-  auto w_data = w.mutable_data<float>();
-  auto b_data = b.mutable_data<float>();
-  auto out_data = out.mutable_data<float>();
-  auto ref_data = ref.mutable_data<float>();
-
-  for (int64_t i = 0; i < x.dims().product(); i++) {
-    x_data[i] = static_cast<float>(i);
-  }
-  for (int64_t i = 0; i < w.dims().product(); i++) {
-    w_data[i] = static_cast<float>(i);
-  }
-  for (int64_t i = 0; i < b.dims().product(); i++) {
-    b_data[i] = static_cast<float>(i);
-  }
-
-  lite::arm::math::fc_compute_eigen(x_data, batch_size, 3,  //
-                                    w_data, 3, 4,           //
-                                    b_data, ref_data);
+  using T = float;
+
+  for (int m : {1, 2, 3, 4}) {
+    for (int n : {1, 2, 3, 4}) {
+      for (int k : {1, 2, 3, 4}) {
+        for (bool with_bias : {true, false}) {
+          VLOG(3) << "m: " << m << ", n: " << n << ", k: " << k
+                  << (with_bias ? ", with bias" : "");
+          lite::Tensor x, w, b, out, ref;
+
+          x.Resize({m, k});
+          w.Resize({k, n});
+          b.Resize({1, n});
+          out.Resize({m, n});
+          ref.Resize({m, n});
+
+          auto* x_data = x.mutable_data<T>();
+          auto* w_data = w.mutable_data<T>();
+          auto* b_data = with_bias ? b.mutable_data<T>() : nullptr;
+          auto* out_data = out.mutable_data<T>();
+          auto* ref_data = ref.mutable_data<T>();
+
+          FillData<T>(x_data, x.dims().production());
+          FillData<T>(w_data, w.dims().production());
+          FillData<T>(out_data, out.dims().production(), 0, 0);
+          FillData<T>(ref_data, ref.dims().production(), 0, 0);
+          if (with_bias) {
+            FillData<T>(b_data, b.dims().production());
+          }
 
           // fc compute kernel
           FcCompute fc;
           operators::FcParam param;
-
-  param.in_num_col_dims = 1;
           param.input = &x;
           param.w = &w;
-  param.bias = &b;
+          param.bias = with_bias ? &b : nullptr;
           param.output = &out;
+          param.in_num_col_dims = 1;
           param.in_mat_dims = x.dims();
 
           DeviceInfo::Init();
@@ -81,55 +100,53 @@ TEST(fc_arm, compare_test) {
           ctx->As<ARMContext>();
           fc.SetParam(param);
           fc.SetContext(std::move(ctx));
+          fc.PrepareForRun();
           fc.Run();
 
-  VLOG(3) << "output vs ref";
-  for (int i = 0; i < out.dims().product(); i++) {
-    VLOG(3) << out_data[i] << " vs " << ref_data[i];
-  }
-
-  for (int i = 0; i < out.dims().product(); ++i) {
-    EXPECT_NEAR(out_data[i], ref_data[i], 1e-5);
-  }
+          lite::arm::math::fc_compute_eigen(x_data, m, k, w_data, k, n, b_data,
+                                            ref_data);
+          for (int i = 0; i < out.dims().production(); i++) {
+            EXPECT_NEAR(out_data[i], ref_data[i], 1e-3);
+          }
+        }
+      }
+    }
+  }
 }
 
 TEST(fc_arm, num_col_dims) {
-  FcCompute fc;
-  operators::FcParam param;
-
-  lite::Tensor x;
-  lite::Tensor w;
-  lite::Tensor bias;
-  lite::Tensor output;
-
-  x.Resize({1, 2, 3});
-  w.Resize({3, 4});
-  bias.Resize({1, 4});
-  output.Resize({2, 4});
-
-  auto* x_data = x.mutable_data<float>();
-  auto* w_data = w.mutable_data<float>();
-  auto* bias_data = bias.mutable_data<float>();
-  auto* output_data = output.mutable_data<float>();
-
-  for (int64_t i = 0; i < x.dims().product(); i++) {
-    x_data[i] = static_cast<float>(i);
-  }
-  for (int64_t i = 0; i < w.dims().product(); i++) {
-    w_data[i] = static_cast<float>(i);
-  }
-  for (int64_t i = 0; i < bias.dims().product(); i++) {
-    bias_data[i] = static_cast<float>(i);
-  }
-  for (int64_t i = 0; i < output.dims().product(); i++) {
-    output_data[i] = static_cast<float>(i);
-  }
-
-  param.in_num_col_dims = 2;
-  param.input = &x;
-  param.w = &w;
-  param.bias = &bias;
-  param.output = &output;
+  using T = float;
+
+  for (bool with_bias : {true, false}) {
+    lite::Tensor x, w, b, out, ref;
+
+    x.Resize({1, 2, 3});
+    w.Resize({3, 4});
+    b.Resize({1, 4});
+    out.Resize({2, 4});
+    ref.Resize({2, 4});
+
+    auto* x_data = x.mutable_data<T>();
+    auto* w_data = w.mutable_data<T>();
+    auto* b_data = with_bias ? b.mutable_data<T>() : nullptr;
+    auto* out_data = out.mutable_data<T>();
+    auto* ref_data = ref.mutable_data<T>();
+
+    FillData<T>(x_data, x.dims().production());
+    FillData<T>(w_data, w.dims().production());
+    FillData<T>(out_data, out.dims().production(), 0, 0);
+    FillData<T>(ref_data, ref.dims().production(), 0, 0);
+    if (with_bias) {
+      FillData<T>(b_data, b.dims().production());
+    }
+
+    FcCompute fc;
+    operators::FcParam param;
+    param.input = &x;
+    param.w = &w;
+    param.bias = with_bias ? &b : nullptr;
+    param.output = &out;
+    param.in_num_col_dims = 2;
     param.in_mat_dims = x.dims();
 
     std::unique_ptr<KernelContext> ctx(new KernelContext);
@@ -138,7 +155,15 @@ TEST(fc_arm, num_col_dims) {
     fc.SetParam(param);
     fc.SetContext(std::move(ctx));
+    fc.PrepareForRun();
     fc.Run();
+
+    lite::arm::math::fc_compute_eigen(x_data, 2, 3, w_data, 3, 4, b_data,
+                                      ref_data);
+    for (int i = 0; i < out.dims().production(); i++) {
+      EXPECT_NEAR(out_data[i], ref_data[i], 1e-3);
+    }
+  }
 }
 
 }  // namespace arm
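A worked example of the flattening the num_col_dims test exercises: with x.Resize({1, 2, 3}) and in_num_col_dims = 2, the kernel collapses the first two axes into rows and the rest into columns, so m_ = 1 * 2 = 2 and k_ = 3, which is why a {3, 4} weight yields a {2, 4} output. The sketch below mirrors the arithmetic of the kernel's dims().Slice(...).production() calls; `FlattenTo2D` is a hypothetical helper, not Paddle-Lite API.

```cpp
#include <cassert>
#include <cstdint>
#include <utility>
#include <vector>

// Flatten an N-D shape into a 2-D matrix at `in_num_col_dims`:
// rows = product of dims[0 .. in_num_col_dims),
// cols = product of dims[in_num_col_dims .. N).
std::pair<int64_t, int64_t> FlattenTo2D(const std::vector<int64_t>& dims,
                                        int in_num_col_dims) {
  int64_t rows = 1, cols = 1;
  for (int i = 0; i < in_num_col_dims; ++i) rows *= dims[i];
  for (int i = in_num_col_dims; i < static_cast<int>(dims.size()); ++i)
    cols *= dims[i];
  return {rows, cols};
}

int main() {
  // x: {1, 2, 3} with in_num_col_dims = 2  ->  m = 2, k = 3.
  auto mk = FlattenTo2D({1, 2, 3}, 2);
  assert(mk.first == 2 && mk.second == 3);
  return 0;
}
```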