提交 49f67ba7 编写于 作者: L liu zhengxi 提交者: GitHub

Add fc op on lite x86 platform (#2568)

上级 833aa8a7
......@@ -20,14 +20,13 @@ add_kernel(stack_compute_x86 X86 basic SRCS stack_compute.cc DEPS ${lite_kernel_
add_kernel(dropout_compute_x86 X86 basic SRCS dropout_compute.cc DEPS ${lite_kernel_deps})
add_kernel(transpose_compute_x86 X86 basic SRCS transpose_compute.cc DEPS ${lite_kernel_deps} math_function)
add_kernel(layer_norm_compute_x86 X86 basic SRCS layer_norm_compute.cc DEPS ${lite_kernel_deps} jit_kernel_helper)
# add_kernel(fc_compute_x86 X86 basic SRCS fc_compute.cc DEPS ${lite_kernel_deps})
add_kernel(fc_compute_x86 X86 basic SRCS fc_compute.cc DEPS ${lite_kernel_deps} jit_kernel_helper)
# lite_cc_library(batch_norm_compute_x86 SRCS batch_norm_compute.cc DEPS ${lite_kernel_deps})
# lite_cc_library(uniform_random_compute_x86 SRCS uniform_random_compute.cc DEPS ${lite_kernel_deps} )
add_kernel(gru_compute_x86 X86 basic SRCS gru_compute.cc DEPS ${lite_kernel_deps} blas math_function sequence2batch gru_compute)
#add_kernel(gru_compute_x86 X86 basic SRCS gru_compute.cc DEPS ${lite_kernel_deps})
add_kernel(sequence_expand_as_compute_x86 X86 basic SRCS sequence_expand_as_compute.cc DEPS ${lite_kernel_deps})
# lite_cc_test(test_fc_compute_x86 SRCS fc_compute_test.cc DEPS fc_compute_x86)
# lite_cc_test(test_conv2d_compute_x86 SRCS conv_compute_test.cc DEPS conv_compute_x86)
add_kernel(gather_compute_x86 X86 basic SRCS gather_compute.cc DEPS ${lite_kernel_deps} fluid_data_type)
# lite_cc_test(test_scale_compute_x86 SRCS scale_compute_test.cc DEPS scale_compute_x86)
......@@ -100,3 +99,4 @@ lite_cc_test(test_sequence_concat_compute_x86 SRCS sequence_concat_compute_test.
lite_cc_test(test_var_conv_2d_compute_x86 SRCS var_conv_2d_compute_test.cc DEPS var_conv_2d_compute_x86)
#lite_cc_test(test_attention_padding_mask_compute_x86 SRCS attention_padding_mask_compute_test.cc DEPS attention_padding_mask_compute_x86)
lite_cc_test(test_sequence_arithmetic_compute_x86 SRCS sequence_arithmetic_compute_test.cc DEPS sequence_arithmetic_compute_x86)
lite_cc_test(test_fc_compute_x86 SRCS fc_compute_test.cc DEPS fc_compute_x86)
......@@ -13,66 +13,131 @@
// limitations under the License.
#pragma once
#include <Eigen/Core>
#include <vector>
#include "lite/backends/x86/jit/helper.h"
#include "lite/backends/x86/jit/kernel_base.h"
#include "lite/backends/x86/jit/kernels.h"
#include "lite/backends/x86/math/blas.h"
#include "lite/core/kernel.h"
#include "lite/core/op_lite.h"
#include "lite/core/op_registry.h"
#include "lite/core/type_system.h"
#include "lite/operators/fc_op.h"
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/operator.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace x86 {
template <typename T>
void fc_compute_eigen(const T* x,
int x_h,
int x_w, //
const T* w,
int w_h,
int w_w, //
const T* b, //
T* out) {
using matrix_t =
Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>;
Eigen::Map<const matrix_t> X(x, x_h, x_w);
Eigen::Map<const matrix_t> W(w, w_h, w_w);
Eigen::Map<matrix_t> Out(out, x_h, w_w);
Out = X * W;
inline void FCOutputSize(const lite::DDim& in_dims,
const lite::DDim& w_dims,
std::vector<int64_t>& out_dims, // NOLINT
int in_num_col_dims,
bool padding_weights) {
auto w_dims1 = padding_weights ? w_dims[1] - 4 : w_dims[1];
if (b) {
Eigen::Map<const Eigen::Matrix<T, Eigen::Dynamic, 1>> B(b, w_w);
Out = Out.array().rowwise() + B.transpose().array();
out_dims.reserve(static_cast<size_t>(in_num_col_dims + 1));
for (int i = 0; i < in_num_col_dims; ++i) {
out_dims.push_back(in_dims[i]);
}
out_dims.push_back(w_dims1);
}
template <typename T>
void fc_compute_naive(const T* x,
int x_h,
int x_w, //
const T* w,
int w_h,
int w_w, //
const T* b, //
T* out) {
CHECK_EQ(x_w, w_h);
// out shape: (x_h, w_w)
memset(out, 0, x_h * w_w * sizeof(T));
for (int i = 0; i < x_h; i++) {
for (int j = 0; j < w_w; j++) {
T tmp = static_cast<T>(0);
for (int k = 0; k < x_w; k++) {
tmp += x[i * x_w + k] * w[k * w_w + j];
template <lite::TargetType Target, typename T>
class FCFunctor {
public:
void operator()(const lite::X86Context& context,
const int M,
const int N,
const int K,
const T* X,
const T* W,
T* Y,
const T* B = nullptr,
bool relu = false,
bool padding_weights = false) {
auto blas = lite::x86::math::GetBlas<lite::TargetType::kX86, T>(context);
lite::Tensor Y1;
T* Y1_data = nullptr;
if (N % 128 == 0 && K % 128 == 0) {
const int NN = N + 4;
const int KK = K + 4;
lite::Tensor X1;
X1.Resize({M * KK});
Y1.Resize({M * (N + 4)});
T* X1_data = X1.mutable_data<T>();
Y1_data = Y1.mutable_data<T>();
#ifdef PADDLE_WITH_MKLML
#pragma omp parallel for
#endif
for (int i = 0; i < M; i++) {
memcpy(X1_data + i * KK, X + i * K, K * sizeof(X[0]));
}
lite::Tensor W1;
T* W1_data = nullptr;
if (!padding_weights) {
W1.Resize({(K + 4) * (N + 4)});
W1_data = W1.mutable_data<T>();
#ifdef PADDLE_WITH_MKLML
#pragma omp parallel for
#endif
for (int i = 0; i < K; i++) {
memcpy(W1_data + i * NN, W + i * N, N * sizeof(W[0]));
}
}
blas.GEMM(false,
false,
M,
N,
K,
static_cast<T>(1.0),
X1_data,
KK,
(padding_weights ? W : W1_data),
NN,
static_cast<T>(0.0),
Y1_data,
NN);
} else {
blas.MatMul(M, N, K, X, W, Y);
}
if (B == NULL) {
if (N % 128 == 0 && K % 128 == 0) {
#ifdef PADDLE_WITH_MKLML
#pragma omp parallel for
#endif
for (int i = 0; i < M; i++) {
memcpy(Y + i * N, Y1_data + i * (N + 4), N * sizeof(Y[0]));
}
}
return;
}
if (relu) {
auto compute =
paddle::lite::jit::KernelFuncs<paddle::lite::jit::VAddReluTuple<T>,
lite::fluid::CPUPlace>::Cache()
.At(N);
for (int i = 0; i < M; i++) {
T* dst = Y + i * N;
T* src = (N % 128 == 0 && K % 128 == 0) ? Y1_data + i * (N + 4) : dst;
compute(B, src, dst, N);
}
} else {
auto compute =
paddle::lite::jit::KernelFuncs<paddle::lite::jit::VAddTuple<T>,
lite::fluid::CPUPlace>::Cache()
.At(N);
#ifdef PADDLE_WITH_MKLML
#pragma omp parallel for
#endif
for (int i = 0; i < M; i++) {
T* dst = Y + i * N;
T* src = (N % 128 == 0 && K % 128 == 0) ? Y1_data + i * (N + 4) : dst;
compute(B, src, dst, N);
}
out[i * w_w + j] = tmp + b[j];
}
}
}
};
template <typename T>
class FcCompute : public KernelLite<TARGET(kX86), PRECISION(kFloat)> {
......@@ -81,20 +146,43 @@ class FcCompute : public KernelLite<TARGET(kX86), PRECISION(kFloat)> {
void Run() override {
auto& param = *param_.get_mutable<param_t>();
CHECK_GE(param.input->dims().size(), 2UL);
CHECK_EQ(param.output->dims().size(), 2UL);
auto* input = param.input;
auto* w = param.w;
auto* bias = param.bias;
auto* output = param.output;
int in_num_col_dims = param.in_num_col_dims;
bool with_relu = (param.activation_type == "relu") ? true : false;
auto w_dims = w->dims();
bool padding_weights = param.padding_weights;
std::vector<int64_t> output_dims;
FCOutputSize(
input->dims(), w_dims, output_dims, in_num_col_dims, padding_weights);
output->Resize(output_dims);
output->set_lod(input->lod());
auto out_dims = output->dims();
auto w_dims0 = padding_weights ? w_dims[0] - 4 : w_dims[0];
auto w_dims1 = padding_weights ? w_dims[1] - 4 : w_dims[1];
int M = out_dims.production() / w_dims1;
const T* input_data = input->data<T>();
const T* w_data = w->data<T>();
T* output_data = output->mutable_data<T>();
fc_compute_eigen(
param.input->data<T>(), // x
param.input->dims().Slice(0, param.in_num_col_dims).production(),
param.input->dims()
.Slice(param.in_num_col_dims, param.input->dims().size())
.production(),
param.w->data<T>(), // w
param.w->dims()[0], // w_h
param.w->dims()[1], // w_w
param.bias->data<T>(), // b
param.output->mutable_data<T>());
auto& context = ctx_->As<X86Context>();
FCFunctor<lite::TargetType::kX86, T> fc;
fc(context,
M,
w_dims1,
w_dims0,
input_data,
w_data,
output_data,
bias ? bias->data<T>() : NULL,
with_relu,
padding_weights);
}
virtual ~FcCompute() = default;
......
......@@ -11,8 +11,11 @@
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/kernels/x86/fc_compute.h"
#include <gtest/gtest.h>
#include <memory>
#include <utility>
#include <vector>
#include "lite/core/op_registry.h"
......@@ -43,7 +46,7 @@ TEST(fc_x86, run_test) {
w.Resize(lite::DDim(w_shape));
std::vector<int64_t> b_shape{1, 4};
b.Resize(lite::DDim(b_shape));
std::vector<int64_t> out_shape{1, 4};
std::vector<int64_t> out_shape{batch_size, 4};
out.Resize(lite::DDim(out_shape));
auto x_data = x.mutable_data<float>();
......@@ -55,16 +58,12 @@ TEST(fc_x86, run_test) {
x_data[i] = static_cast<float>(i);
}
for (int64_t i = 0; i < w.dims().production(); i++) {
w_data[i] = static_cast<float>(i);
w_data[i] = static_cast<float>(2);
}
for (int64_t i = 0; i < b.dims().production(); i++) {
b_data[i] = static_cast<float>(i);
b_data[i] = static_cast<float>(2);
}
/* lite::x86::math::fc_compute_eigen(x_data, batch_size, 3, //
w_data, 3, 4, //
b_data, ref_data); */
// FcCompute fc;
FcCompute<float> fc;
operators::FcParam param;
......@@ -75,21 +74,17 @@ TEST(fc_x86, run_test) {
param.bias = &b;
param.output = &out;
param.in_mat_dims = x.dims();
param.activation_type = "relu";
// std::unique_ptr<KernelContext> ctx(new KernelContext);
// ctx->As<X86Context>();
std::unique_ptr<KernelContext> ctx(new KernelContext);
ctx->As<X86Context>();
fc.SetParam(param);
// fc.SetContext(std::move(ctx));
fc.SetContext(std::move(ctx));
fc.Run();
VLOG(3) << "output vs ref";
std::vector<float> ref_data({8, 8, 8, 8, 26, 26, 26, 26});
for (int i = 0; i < out.dims().production(); i++) {
VLOG(3) << out_data[i];
EXPECT_NEAR(out_data[i], ref_data[i], 1e-5);
}
/* for (int i = 0; i < out.dims().production(); ++i) {
EXPECT_NEAR(out_data[i], ref_data[i], 1e-5);
}*/
}
} // namespace x86
......
......@@ -90,6 +90,9 @@ bool FcOpLite::AttachImpl(const cpp::OpDesc &op_desc, lite::Scope *scope) {
if (op_desc.HasAttr("activation_type")) {
param_.activation_type = op_desc.GetAttr<std::string>("activation_type");
}
if (op_desc.HasAttr("padding_weights")) {
param_.activation_type = op_desc.GetAttr<bool>("padding_weights");
}
// For Int8
if (op_desc.HasAttr("enable_int8")) {
......
......@@ -86,6 +86,7 @@ struct FcParam {
lite::DDim in_mat_dims;
int in_num_col_dims{1};
std::string activation_type{""};
bool padding_weights{false};
// for int8
WITH_INT8_CONFIG
};
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册