提交 926ab88e 编写于 作者: T tensor-tang

init mul

上级 66fafb56
...@@ -22,6 +22,10 @@ namespace lite { ...@@ -22,6 +22,10 @@ namespace lite {
namespace kernels { namespace kernels {
namespace arm { namespace arm {
void FcCompute::PrepareForRun() {
// TODO(TJ): transpose weight
}
void FcCompute::Run() { void FcCompute::Run() {
auto& param = this->Param<operators::FcParam>(); auto& param = this->Param<operators::FcParam>();
auto x_dims = param.input->dims(); auto x_dims = param.input->dims();
...@@ -54,9 +58,8 @@ void FcCompute::Run() { ...@@ -54,9 +58,8 @@ void FcCompute::Run() {
lite::arm::math::fill_bias_fc(o_data, b_data, x_h, n); lite::arm::math::fill_bias_fc(o_data, b_data, x_h, n);
} }
} else { } else {
// use sgemmv lite::arm::math::sgemv(w_data, i_data, o_data, false, n, x_w,
// sgemv((const float*)weights, (const float*)din, (float*)dout, b_data != nullptr, b_data, false);
// false, n, x_w, _param->_flag_bias, (float*)bias, false);
} }
} }
......
...@@ -25,6 +25,8 @@ class FcCompute : public KernelLite<TARGET(kARM), PRECISION(kFloat)> { ...@@ -25,6 +25,8 @@ class FcCompute : public KernelLite<TARGET(kARM), PRECISION(kFloat)> {
public: public:
using param_t = operators::FcParam; using param_t = operators::FcParam;
void PrepareForRun() override;
void Run() override; void Run() override;
TargetType target() const override; TargetType target() const override;
......
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/fluid/lite/core/kernel.h"
#include "paddle/fluid/lite/core/op_registry.h"
#include "paddle/fluid/lite/core/types.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace arm {
template <typename T>
void mul_compute_eigen(const T* x, int x_h, int x_w, const T* y, int y_h,
int y_w, T* out) {
using matrix_t =
Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>;
Eigen::Map<const matrix_t> X(x, x_h, x_w);
Eigen::Map<const matrix_t> Y(y, y_h, y_w);
Eigen::Map<matrix_t> Out(out, x_h, y_w);
Out = X * Y;
}
class MulCompute : public KernelLite<TARGET(kARM), PRECISION(kFloat)> {
public:
using param_t = operators::MulParam;
void Run() override {
auto& param = Param<operators::MulParam>();
core::dim2 x_shape(
{static_cast<int>(
param.x->dims().Slice(0, param.x_num_col_dims).production()),
static_cast<int>(
param.x->dims()
.Slice(param.x_num_col_dims, param.x->dims().size())
.production())});
core::dim2 y_shape(
{static_cast<int>(
param.y->dims().Slice(0, param.y_num_col_dims).production()),
static_cast<int>(
param.y->dims()
.Slice(param.y_num_col_dims, param.y->dims().size())
.production())});
mul_compute_eigen(param.x->data<float>(), x_shape.x, x_shape.y, //
param.y->data<float>(), y_shape.x, y_shape.y, //
param.output->mutable_data<float>());
}
virtual ~MulCompute() = default;
};
} // namespace arm
} // namespace kernels
} // namespace lite
} // namespace paddle
REGISTER_LITE_KERNEL(mul, kARM, kFloat, kNCHW,
paddle::lite::kernels::arm::MulCompute, def)
.BindInput("X", {LiteType::GetTensorTy(TARGET(kARM))})
.BindInput("Y", {LiteType::GetTensorTy(TARGET(kARM))})
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM))})
.Finalize();
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <gtest/gtest.h>
#include <vector>
#include "paddle/fluid/lite/arm/math/funcs.h"
#include "paddle/fluid/lite/core/op_registry.h"
#include "paddle/fluid/lite/kernels/arm/fc_compute.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace arm {
TEST(fc_arm, retrive_op) {
auto fc =
KernelRegistry::Global().Create<TARGET(kARM), PRECISION(kFloat)>("fc");
ASSERT_FALSE(fc.empty());
ASSERT_TRUE(fc.front());
}
TEST(fc_arm, init) {
FcCompute fc;
ASSERT_EQ(fc.precision(), PRECISION(kFloat));
ASSERT_EQ(fc.target(), TARGET(kARM));
}
TEST(fc_arm, compare_test) {
lite::Tensor x, w, b, out, ref;
constexpr int batch_size = 2;
x.Resize({batch_size, 3});
w.Resize({3, 4});
b.Resize({1, 4});
out.Resize({batch_size, 4});
ref.Resize({batch_size, 4});
auto x_data = x.mutable_data<float>();
auto w_data = w.mutable_data<float>();
auto b_data = b.mutable_data<float>();
auto out_data = out.mutable_data<float>();
auto ref_data = ref.mutable_data<float>();
for (int64_t i = 0; i < x.dims().product(); i++) {
x_data[i] = static_cast<float>(i);
}
for (int64_t i = 0; i < w.dims().product(); i++) {
w_data[i] = static_cast<float>(i);
}
for (int64_t i = 0; i < b.dims().product(); i++) {
b_data[i] = static_cast<float>(i);
}
lite::arm::math::fc_compute_eigen(x_data, batch_size, 3, //
w_data, 3, 4, //
b_data, ref_data);
// fc compute kernel
FcCompute fc;
operators::FcParam param;
param.in_num_col_dims = 1;
param.input = &x;
param.w = &w;
param.bias = &b;
param.output = &out;
param.in_mat_dims = x.dims();
DeviceInfo::Init();
std::unique_ptr<KernelContext> ctx(new KernelContext);
ctx->As<ARMContext>();
fc.SetParam(param);
fc.SetContext(std::move(ctx));
fc.Run();
VLOG(3) << "output vs ref";
for (int i = 0; i < out.dims().product(); i++) {
VLOG(3) << out_data[i] << " vs " << ref_data[i];
}
for (int i = 0; i < out.dims().product(); ++i) {
EXPECT_NEAR(out_data[i], ref_data[i], 1e-5);
}
}
TEST(fc_arm, num_col_dims) {
FcCompute fc;
operators::FcParam param;
lite::Tensor x;
lite::Tensor w;
lite::Tensor bias;
lite::Tensor output;
x.Resize({1, 2, 3});
w.Resize({3, 4});
bias.Resize({1, 4});
output.Resize({2, 4});
auto* x_data = x.mutable_data<float>();
auto* w_data = w.mutable_data<float>();
auto* bias_data = bias.mutable_data<float>();
auto* output_data = output.mutable_data<float>();
for (int64_t i = 0; i < x.dims().product(); i++) {
x_data[i] = static_cast<float>(i);
}
for (int64_t i = 0; i < w.dims().product(); i++) {
w_data[i] = static_cast<float>(i);
}
for (int64_t i = 0; i < bias.dims().product(); i++) {
bias_data[i] = static_cast<float>(i);
}
for (int64_t i = 0; i < output.dims().product(); i++) {
output_data[i] = static_cast<float>(i);
}
param.in_num_col_dims = 2;
param.input = &x;
param.w = &w;
param.bias = &bias;
param.output = &output;
param.in_mat_dims = x.dims();
std::unique_ptr<KernelContext> ctx(new KernelContext);
ctx->As<ARMContext>();
DeviceInfo::Init();
fc.SetParam(param);
fc.SetContext(std::move(ctx));
fc.Run();
}
} // namespace arm
} // namespace kernels
} // namespace lite
} // namespace paddle
USE_LITE_KERNEL(fc, kARM, kFloat, kNCHW, def);
...@@ -57,6 +57,7 @@ struct FcParam { ...@@ -57,6 +57,7 @@ struct FcParam {
lite::Tensor* output{}; lite::Tensor* output{};
lite::DDim in_mat_dims; lite::DDim in_mat_dims;
int in_num_col_dims{1}; int in_num_col_dims{1};
bool weight_transposed{false};
}; };
struct ReluParam { struct ReluParam {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册