提交 cf027d49 编写于 作者: T tensor-tang

use arm kernel

上级 926ab88e
......@@ -6,7 +6,7 @@ message(STATUS "compile with lite ARM kernels")
cc_library(fc_compute_arm SRCS fc_compute.cc DEPS ${lite_kernel_deps} math_arm)
cc_library(relu_compute_arm SRCS relu_compute.cc DEPS ${lite_kernel_deps})
cc_library(mul_compute_arm SRCS mul_compute.cc DEPS ${lite_kernel_deps} eigen3)
cc_library(mul_compute_arm SRCS mul_compute.cc DEPS ${lite_kernel_deps} math_arm)
cc_library(scale_compute_arm SRCS scale_compute.cc DEPS ${lite_kernel_deps} math_arm)
cc_library(softmax_compute_arm SRCS softmax_compute.cc DEPS ${lite_kernel_deps} math_arm)
cc_library(conv_compute_arm SRCS conv_compute.cc DEPS ${lite_kernel_deps} math_arm)
......@@ -19,6 +19,7 @@ lite_cc_test(test_softmax_compute_arm SRCS softmax_compute_test.cc DEPS softmax_
lite_cc_test(test_conv_compute_arm SRCS conv_compute_test.cc DEPS conv_compute_arm)
lite_cc_test(test_elementwise_add_compute_arm SRCS elementwise_add_compute_test.cc DEPS elementwise_add_compute_arm)
lite_cc_test(test_pool_compute_arm SRCS pool_compute_test.cc DEPS pool_compute_arm)
lite_cc_test(test_mul_compute_arm SRCS mul_compute_test.cc DEPS mul_compute_arm)
set(arm_kernels
fc_compute_arm
......
......@@ -63,10 +63,6 @@ void FcCompute::Run() {
}
}
TargetType FcCompute::target() const { return TARGET(kARM); }
PrecisionType FcCompute::precision() const { return PRECISION(kFloat); }
} // namespace arm
} // namespace kernels
} // namespace lite
......
......@@ -29,9 +29,6 @@ class FcCompute : public KernelLite<TARGET(kARM), PRECISION(kFloat)> {
void Run() override;
TargetType target() const override;
PrecisionType precision() const override;
virtual ~FcCompute() = default;
};
......
......@@ -12,57 +12,53 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include <Eigen/Core>
#include "paddle/fluid/lite/core/kernel.h"
#include "paddle/fluid/lite/kernels/arm/mul_compute.h"
#include "paddle/fluid/lite/arm/math/funcs.h"
#include "paddle/fluid/lite/core/op_registry.h"
#include "paddle/fluid/lite/core/types.h"
#include "paddle/fluid/lite/core/type_system.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace arm {
template <typename T>
void mul_compute_eigen(const T* x, int x_h, int x_w, const T* y, int y_h,
int y_w, T* out) {
using matrix_t =
Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>;
void MulCompute::PrepareForRun() {
// TODO(TJ): transpose x or y if necessary
}
Eigen::Map<const matrix_t> X(x, x_h, x_w);
Eigen::Map<const matrix_t> Y(y, y_h, y_w);
Eigen::Map<matrix_t> Out(out, x_h, y_w);
void MulCompute::Run() {
auto& param = Param<param_t>();
Out = X * Y;
}
const auto* x_data = param.x->data<float>();
const auto* y_data = param.y->data<float>();
auto* o_data = param.output->mutable_data<float>();
class MulCompute : public KernelLite<TARGET(kARM), PRECISION(kFloat)> {
public:
using param_t = operators::MulParam;
int x_h = static_cast<int>(
param.x->dims().Slice(0, param.x_num_col_dims).production());
int x_w =
static_cast<int>(param.x->dims()
.Slice(param.x_num_col_dims, param.x->dims().size())
.production());
int y_h = static_cast<int>(
param.y->dims().Slice(0, param.y_num_col_dims).production());
int y_w =
static_cast<int>(param.y->dims()
.Slice(param.y_num_col_dims, param.y->dims().size())
.production());
void Run() override {
auto& param = Param<operators::MulParam>();
core::dim2 x_shape(
{static_cast<int>(
param.x->dims().Slice(0, param.x_num_col_dims).production()),
static_cast<int>(
param.x->dims()
.Slice(param.x_num_col_dims, param.x->dims().size())
.production())});
core::dim2 y_shape(
{static_cast<int>(
param.y->dims().Slice(0, param.y_num_col_dims).production()),
static_cast<int>(
param.y->dims()
.Slice(param.y_num_col_dims, param.y->dims().size())
.production())});
CHECK_EQ(x_w, y_h) << "x_w must be equal with y_h";
if (y_w == 1 || x_h == 1) {
lite::arm::math::sgemv(x_data, y_data, o_data, false, x_h, x_w, false,
nullptr, false);
mul_compute_eigen(param.x->data<float>(), x_shape.x, x_shape.y, //
param.y->data<float>(), y_shape.x, y_shape.y, //
param.output->mutable_data<float>());
}
} else {
constexpr bool is_tranposed_y = false;
auto& ctx = this->ctx_->template As<ARMContext>();
virtual ~MulCompute() = default;
};
lite::arm::math::sgemm_prepack(x_data, y_data, nullptr, o_data, x_h, y_w,
x_w, false, false, is_tranposed_y, &ctx);
}
}
} // namespace arm
} // namespace kernels
......
......@@ -22,44 +22,13 @@ namespace lite {
namespace kernels {
namespace arm {
template <typename T>
void mul_compute_eigen(const T* x, int x_h, int x_w, const T* y, int y_h,
int y_w, T* out) {
using matrix_t =
Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>;
Eigen::Map<const matrix_t> X(x, x_h, x_w);
Eigen::Map<const matrix_t> Y(y, y_h, y_w);
Eigen::Map<matrix_t> Out(out, x_h, y_w);
Out = X * Y;
}
class MulCompute : public KernelLite<TARGET(kARM), PRECISION(kFloat)> {
public:
using param_t = operators::MulParam;
void Run() override {
auto& param = Param<operators::MulParam>();
core::dim2 x_shape(
{static_cast<int>(
param.x->dims().Slice(0, param.x_num_col_dims).production()),
static_cast<int>(
param.x->dims()
.Slice(param.x_num_col_dims, param.x->dims().size())
.production())});
core::dim2 y_shape(
{static_cast<int>(
param.y->dims().Slice(0, param.y_num_col_dims).production()),
static_cast<int>(
param.y->dims()
.Slice(param.y_num_col_dims, param.y->dims().size())
.production())});
void PrepareForRun() override;
mul_compute_eigen(param.x->data<float>(), x_shape.x, x_shape.y, //
param.y->data<float>(), y_shape.x, y_shape.y, //
param.output->mutable_data<float>());
}
void Run() override;
virtual ~MulCompute() = default;
};
......@@ -68,10 +37,3 @@ class MulCompute : public KernelLite<TARGET(kARM), PRECISION(kFloat)> {
} // namespace kernels
} // namespace lite
} // namespace paddle
REGISTER_LITE_KERNEL(mul, kARM, kFloat, kNCHW,
paddle::lite::kernels::arm::MulCompute, def)
.BindInput("X", {LiteType::GetTensorTy(TARGET(kARM))})
.BindInput("Y", {LiteType::GetTensorTy(TARGET(kARM))})
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM))})
.Finalize();
......@@ -12,31 +12,33 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/lite/kernels/arm/mul_compute.h"
#include <gtest/gtest.h>
#include <memory>
#include <utility>
#include <vector>
#include "paddle/fluid/lite/arm/math/funcs.h"
#include "paddle/fluid/lite/core/op_registry.h"
#include "paddle/fluid/lite/kernels/arm/fc_compute.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace arm {
TEST(fc_arm, retrive_op) {
auto fc =
KernelRegistry::Global().Create<TARGET(kARM), PRECISION(kFloat)>("fc");
ASSERT_FALSE(fc.empty());
ASSERT_TRUE(fc.front());
TEST(mul_arm, retrive_op) {
auto mul =
KernelRegistry::Global().Create<TARGET(kARM), PRECISION(kFloat)>("mul");
ASSERT_FALSE(mul.empty());
ASSERT_TRUE(mul.front());
}
TEST(fc_arm, init) {
FcCompute fc;
ASSERT_EQ(fc.precision(), PRECISION(kFloat));
ASSERT_EQ(fc.target(), TARGET(kARM));
TEST(mul_arm, init) {
FcCompute mul;
ASSERT_EQ(mul.precision(), PRECISION(kFloat));
ASSERT_EQ(mul.target(), TARGET(kARM));
}
TEST(fc_arm, compare_test) {
TEST(mul_arm, compare_test) {
lite::Tensor x, w, b, out, ref;
constexpr int batch_size = 2;
x.Resize({batch_size, 3});
......@@ -65,8 +67,8 @@ TEST(fc_arm, compare_test) {
w_data, 3, 4, //
b_data, ref_data);
// fc compute kernel
FcCompute fc;
// mul compute kernel
FcCompute mul;
operators::FcParam param;
param.in_num_col_dims = 1;
......@@ -79,9 +81,9 @@ TEST(fc_arm, compare_test) {
DeviceInfo::Init();
std::unique_ptr<KernelContext> ctx(new KernelContext);
ctx->As<ARMContext>();
fc.SetParam(param);
fc.SetContext(std::move(ctx));
fc.Run();
mul.SetParam(param);
mul.SetContext(std::move(ctx));
mul.Run();
VLOG(3) << "output vs ref";
for (int i = 0; i < out.dims().product(); i++) {
......@@ -93,8 +95,8 @@ TEST(fc_arm, compare_test) {
}
}
TEST(fc_arm, num_col_dims) {
FcCompute fc;
TEST(mul_arm, num_col_dims) {
FcCompute mul;
operators::FcParam param;
lite::Tensor x;
......@@ -136,9 +138,9 @@ TEST(fc_arm, num_col_dims) {
ctx->As<ARMContext>();
DeviceInfo::Init();
fc.SetParam(param);
fc.SetContext(std::move(ctx));
fc.Run();
mul.SetParam(param);
mul.SetContext(std::move(ctx));
mul.Run();
}
} // namespace arm
......@@ -146,4 +148,4 @@ TEST(fc_arm, num_col_dims) {
} // namespace lite
} // namespace paddle
USE_LITE_KERNEL(fc, kARM, kFloat, kNCHW, def);
USE_LITE_KERNEL(mul, kARM, kFloat, kNCHW, def);
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册