From cf027d49dfda6f0ef82799b653017a3fda303fae Mon Sep 17 00:00:00 2001 From: tensor-tang Date: Wed, 12 Jun 2019 12:59:13 +0000 Subject: [PATCH] use arm kernel --- paddle/fluid/lite/kernels/arm/CMakeLists.txt | 3 +- paddle/fluid/lite/kernels/arm/fc_compute.cc | 4 -- paddle/fluid/lite/kernels/arm/fc_compute.h | 3 - paddle/fluid/lite/kernels/arm/mul_compute.cc | 72 +++++++++---------- paddle/fluid/lite/kernels/arm/mul_compute.h | 42 +---------- .../lite/kernels/arm/mul_compute_test.cc | 46 ++++++------ 6 files changed, 62 insertions(+), 108 deletions(-) diff --git a/paddle/fluid/lite/kernels/arm/CMakeLists.txt b/paddle/fluid/lite/kernels/arm/CMakeLists.txt index c0fa480f094..0f87a2f3b3d 100644 --- a/paddle/fluid/lite/kernels/arm/CMakeLists.txt +++ b/paddle/fluid/lite/kernels/arm/CMakeLists.txt @@ -6,7 +6,7 @@ message(STATUS "compile with lite ARM kernels") cc_library(fc_compute_arm SRCS fc_compute.cc DEPS ${lite_kernel_deps} math_arm) cc_library(relu_compute_arm SRCS relu_compute.cc DEPS ${lite_kernel_deps}) -cc_library(mul_compute_arm SRCS mul_compute.cc DEPS ${lite_kernel_deps} eigen3) +cc_library(mul_compute_arm SRCS mul_compute.cc DEPS ${lite_kernel_deps} math_arm) cc_library(scale_compute_arm SRCS scale_compute.cc DEPS ${lite_kernel_deps} math_arm) cc_library(softmax_compute_arm SRCS softmax_compute.cc DEPS ${lite_kernel_deps} math_arm) cc_library(conv_compute_arm SRCS conv_compute.cc DEPS ${lite_kernel_deps} math_arm) @@ -19,6 +19,7 @@ lite_cc_test(test_softmax_compute_arm SRCS softmax_compute_test.cc DEPS softmax_ lite_cc_test(test_conv_compute_arm SRCS conv_compute_test.cc DEPS conv_compute_arm) lite_cc_test(test_elementwise_add_compute_arm SRCS elementwise_add_compute_test.cc DEPS elementwise_add_compute_arm) lite_cc_test(test_pool_compute_arm SRCS pool_compute_test.cc DEPS pool_compute_arm) +lite_cc_test(test_mul_compute_arm SRCS mul_compute_test.cc DEPS mul_compute_arm) set(arm_kernels fc_compute_arm diff --git a/paddle/fluid/lite/kernels/arm/fc_compute.cc b/paddle/fluid/lite/kernels/arm/fc_compute.cc index dcf3e4d81e6..5bf9faab9f1 100644 --- a/paddle/fluid/lite/kernels/arm/fc_compute.cc +++ b/paddle/fluid/lite/kernels/arm/fc_compute.cc @@ -63,10 +63,6 @@ void FcCompute::Run() { } } -TargetType FcCompute::target() const { return TARGET(kARM); } - -PrecisionType FcCompute::precision() const { return PRECISION(kFloat); } - } // namespace arm } // namespace kernels } // namespace lite diff --git a/paddle/fluid/lite/kernels/arm/fc_compute.h b/paddle/fluid/lite/kernels/arm/fc_compute.h index b72b24b4844..459d23194d8 100644 --- a/paddle/fluid/lite/kernels/arm/fc_compute.h +++ b/paddle/fluid/lite/kernels/arm/fc_compute.h @@ -29,9 +29,6 @@ class FcCompute : public KernelLite { void Run() override; - TargetType target() const override; - PrecisionType precision() const override; - virtual ~FcCompute() = default; }; diff --git a/paddle/fluid/lite/kernels/arm/mul_compute.cc b/paddle/fluid/lite/kernels/arm/mul_compute.cc index ff12b236031..4ca2c455e43 100644 --- a/paddle/fluid/lite/kernels/arm/mul_compute.cc +++ b/paddle/fluid/lite/kernels/arm/mul_compute.cc @@ -12,57 +12,53 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include -#include "paddle/fluid/lite/core/kernel.h" +#include "paddle/fluid/lite/kernels/arm/mul_compute.h" +#include "paddle/fluid/lite/arm/math/funcs.h" #include "paddle/fluid/lite/core/op_registry.h" -#include "paddle/fluid/lite/core/types.h" +#include "paddle/fluid/lite/core/type_system.h" namespace paddle { namespace lite { namespace kernels { namespace arm { -template -void mul_compute_eigen(const T* x, int x_h, int x_w, const T* y, int y_h, - int y_w, T* out) { - using matrix_t = - Eigen::Matrix; +void MulCompute::PrepareForRun() { + // TODO(TJ): transpose x or y if necessary +} - Eigen::Map X(x, x_h, x_w); - Eigen::Map Y(y, y_h, y_w); - Eigen::Map Out(out, x_h, y_w); +void MulCompute::Run() { + auto& param = Param(); - Out = X * Y; -} + const auto* x_data = param.x->data(); + const auto* y_data = param.y->data(); + auto* o_data = param.output->mutable_data(); -class MulCompute : public KernelLite { - public: - using param_t = operators::MulParam; + int x_h = static_cast( + param.x->dims().Slice(0, param.x_num_col_dims).production()); + int x_w = + static_cast(param.x->dims() + .Slice(param.x_num_col_dims, param.x->dims().size()) + .production()); + int y_h = static_cast( + param.y->dims().Slice(0, param.y_num_col_dims).production()); + int y_w = + static_cast(param.y->dims() + .Slice(param.y_num_col_dims, param.y->dims().size()) + .production()); - void Run() override { - auto& param = Param(); - core::dim2 x_shape( - {static_cast( - param.x->dims().Slice(0, param.x_num_col_dims).production()), - static_cast( - param.x->dims() - .Slice(param.x_num_col_dims, param.x->dims().size()) - .production())}); - core::dim2 y_shape( - {static_cast( - param.y->dims().Slice(0, param.y_num_col_dims).production()), - static_cast( - param.y->dims() - .Slice(param.y_num_col_dims, param.y->dims().size()) - .production())}); + CHECK_EQ(x_w, y_h) << "x_w must be equal with y_h"; + if (y_w == 1 || x_h == 1) { + lite::arm::math::sgemv(x_data, y_data, o_data, false, x_h, x_w, false, + nullptr, false); - mul_compute_eigen(param.x->data(), x_shape.x, x_shape.y, // - param.y->data(), y_shape.x, y_shape.y, // - param.output->mutable_data()); - } + } else { + constexpr bool is_tranposed_y = false; + auto& ctx = this->ctx_->template As(); - virtual ~MulCompute() = default; -}; + lite::arm::math::sgemm_prepack(x_data, y_data, nullptr, o_data, x_h, y_w, + x_w, false, false, is_tranposed_y, &ctx); + } +} } // namespace arm } // namespace kernels diff --git a/paddle/fluid/lite/kernels/arm/mul_compute.h b/paddle/fluid/lite/kernels/arm/mul_compute.h index 4d1abba94c2..c18995e5a5c 100644 --- a/paddle/fluid/lite/kernels/arm/mul_compute.h +++ b/paddle/fluid/lite/kernels/arm/mul_compute.h @@ -22,44 +22,13 @@ namespace lite { namespace kernels { namespace arm { -template -void mul_compute_eigen(const T* x, int x_h, int x_w, const T* y, int y_h, - int y_w, T* out) { - using matrix_t = - Eigen::Matrix; - - Eigen::Map X(x, x_h, x_w); - Eigen::Map Y(y, y_h, y_w); - Eigen::Map Out(out, x_h, y_w); - - Out = X * Y; -} - class MulCompute : public KernelLite { public: using param_t = operators::MulParam; - void Run() override { - auto& param = Param(); - core::dim2 x_shape( - {static_cast( - param.x->dims().Slice(0, param.x_num_col_dims).production()), - static_cast( - param.x->dims() - .Slice(param.x_num_col_dims, param.x->dims().size()) - .production())}); - core::dim2 y_shape( - {static_cast( - param.y->dims().Slice(0, param.y_num_col_dims).production()), - static_cast( - param.y->dims() - .Slice(param.y_num_col_dims, param.y->dims().size()) - .production())}); + void PrepareForRun() override; - mul_compute_eigen(param.x->data(), x_shape.x, x_shape.y, // - param.y->data(), y_shape.x, y_shape.y, // - param.output->mutable_data()); - } + void Run() override; virtual ~MulCompute() = default; }; @@ -68,10 +37,3 @@ class MulCompute : public KernelLite { } // namespace kernels } // namespace lite } // namespace paddle - -REGISTER_LITE_KERNEL(mul, kARM, kFloat, kNCHW, - paddle::lite::kernels::arm::MulCompute, def) - .BindInput("X", {LiteType::GetTensorTy(TARGET(kARM))}) - .BindInput("Y", {LiteType::GetTensorTy(TARGET(kARM))}) - .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM))}) - .Finalize(); diff --git a/paddle/fluid/lite/kernels/arm/mul_compute_test.cc b/paddle/fluid/lite/kernels/arm/mul_compute_test.cc index ee7c1b655fa..cef99b17607 100644 --- a/paddle/fluid/lite/kernels/arm/mul_compute_test.cc +++ b/paddle/fluid/lite/kernels/arm/mul_compute_test.cc @@ -12,31 +12,33 @@ // See the License for the specific language governing permissions and // limitations under the License. +#include "paddle/fluid/lite/kernels/arm/mul_compute.h" #include +#include +#include #include #include "paddle/fluid/lite/arm/math/funcs.h" #include "paddle/fluid/lite/core/op_registry.h" -#include "paddle/fluid/lite/kernels/arm/fc_compute.h" namespace paddle { namespace lite { namespace kernels { namespace arm { -TEST(fc_arm, retrive_op) { - auto fc = - KernelRegistry::Global().Create("fc"); - ASSERT_FALSE(fc.empty()); - ASSERT_TRUE(fc.front()); +TEST(mul_arm, retrive_op) { + auto mul = + KernelRegistry::Global().Create("mul"); + ASSERT_FALSE(mul.empty()); + ASSERT_TRUE(mul.front()); } -TEST(fc_arm, init) { - FcCompute fc; - ASSERT_EQ(fc.precision(), PRECISION(kFloat)); - ASSERT_EQ(fc.target(), TARGET(kARM)); +TEST(mul_arm, init) { + FcCompute mul; + ASSERT_EQ(mul.precision(), PRECISION(kFloat)); + ASSERT_EQ(mul.target(), TARGET(kARM)); } -TEST(fc_arm, compare_test) { +TEST(mul_arm, compare_test) { lite::Tensor x, w, b, out, ref; constexpr int batch_size = 2; x.Resize({batch_size, 3}); @@ -65,8 +67,8 @@ TEST(fc_arm, compare_test) { w_data, 3, 4, // b_data, ref_data); - // fc compute kernel - FcCompute fc; + // mul compute kernel + FcCompute mul; operators::FcParam param; param.in_num_col_dims = 1; @@ -79,9 +81,9 @@ TEST(fc_arm, compare_test) { DeviceInfo::Init(); std::unique_ptr ctx(new KernelContext); ctx->As(); - fc.SetParam(param); - fc.SetContext(std::move(ctx)); - fc.Run(); + mul.SetParam(param); + mul.SetContext(std::move(ctx)); + mul.Run(); VLOG(3) << "output vs ref"; for (int i = 0; i < out.dims().product(); i++) { @@ -93,8 +95,8 @@ TEST(fc_arm, compare_test) { } } -TEST(fc_arm, num_col_dims) { - FcCompute fc; +TEST(mul_arm, num_col_dims) { + FcCompute mul; operators::FcParam param; lite::Tensor x; @@ -136,9 +138,9 @@ TEST(fc_arm, num_col_dims) { ctx->As(); DeviceInfo::Init(); - fc.SetParam(param); - fc.SetContext(std::move(ctx)); - fc.Run(); + mul.SetParam(param); + mul.SetContext(std::move(ctx)); + mul.Run(); } } // namespace arm @@ -146,4 +148,4 @@ TEST(fc_arm, num_col_dims) { } // namespace lite } // namespace paddle -USE_LITE_KERNEL(fc, kARM, kFloat, kNCHW, def); +USE_LITE_KERNEL(mul, kARM, kFloat, kNCHW, def); -- GitLab