提交 cf027d49 编写于 作者: T tensor-tang

use arm kernel

上级 926ab88e
...@@ -6,7 +6,7 @@ message(STATUS "compile with lite ARM kernels") ...@@ -6,7 +6,7 @@ message(STATUS "compile with lite ARM kernels")
cc_library(fc_compute_arm SRCS fc_compute.cc DEPS ${lite_kernel_deps} math_arm) cc_library(fc_compute_arm SRCS fc_compute.cc DEPS ${lite_kernel_deps} math_arm)
cc_library(relu_compute_arm SRCS relu_compute.cc DEPS ${lite_kernel_deps}) cc_library(relu_compute_arm SRCS relu_compute.cc DEPS ${lite_kernel_deps})
cc_library(mul_compute_arm SRCS mul_compute.cc DEPS ${lite_kernel_deps} eigen3) cc_library(mul_compute_arm SRCS mul_compute.cc DEPS ${lite_kernel_deps} math_arm)
cc_library(scale_compute_arm SRCS scale_compute.cc DEPS ${lite_kernel_deps} math_arm) cc_library(scale_compute_arm SRCS scale_compute.cc DEPS ${lite_kernel_deps} math_arm)
cc_library(softmax_compute_arm SRCS softmax_compute.cc DEPS ${lite_kernel_deps} math_arm) cc_library(softmax_compute_arm SRCS softmax_compute.cc DEPS ${lite_kernel_deps} math_arm)
cc_library(conv_compute_arm SRCS conv_compute.cc DEPS ${lite_kernel_deps} math_arm) cc_library(conv_compute_arm SRCS conv_compute.cc DEPS ${lite_kernel_deps} math_arm)
...@@ -19,6 +19,7 @@ lite_cc_test(test_softmax_compute_arm SRCS softmax_compute_test.cc DEPS softmax_ ...@@ -19,6 +19,7 @@ lite_cc_test(test_softmax_compute_arm SRCS softmax_compute_test.cc DEPS softmax_
lite_cc_test(test_conv_compute_arm SRCS conv_compute_test.cc DEPS conv_compute_arm) lite_cc_test(test_conv_compute_arm SRCS conv_compute_test.cc DEPS conv_compute_arm)
lite_cc_test(test_elementwise_add_compute_arm SRCS elementwise_add_compute_test.cc DEPS elementwise_add_compute_arm) lite_cc_test(test_elementwise_add_compute_arm SRCS elementwise_add_compute_test.cc DEPS elementwise_add_compute_arm)
lite_cc_test(test_pool_compute_arm SRCS pool_compute_test.cc DEPS pool_compute_arm) lite_cc_test(test_pool_compute_arm SRCS pool_compute_test.cc DEPS pool_compute_arm)
lite_cc_test(test_mul_compute_arm SRCS mul_compute_test.cc DEPS mul_compute_arm)
set(arm_kernels set(arm_kernels
fc_compute_arm fc_compute_arm
......
...@@ -63,10 +63,6 @@ void FcCompute::Run() { ...@@ -63,10 +63,6 @@ void FcCompute::Run() {
} }
} }
TargetType FcCompute::target() const { return TARGET(kARM); }
PrecisionType FcCompute::precision() const { return PRECISION(kFloat); }
} // namespace arm } // namespace arm
} // namespace kernels } // namespace kernels
} // namespace lite } // namespace lite
......
...@@ -29,9 +29,6 @@ class FcCompute : public KernelLite<TARGET(kARM), PRECISION(kFloat)> { ...@@ -29,9 +29,6 @@ class FcCompute : public KernelLite<TARGET(kARM), PRECISION(kFloat)> {
void Run() override; void Run() override;
TargetType target() const override;
PrecisionType precision() const override;
virtual ~FcCompute() = default; virtual ~FcCompute() = default;
}; };
......
...@@ -12,57 +12,53 @@ ...@@ -12,57 +12,53 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#include <Eigen/Core> #include "paddle/fluid/lite/kernels/arm/mul_compute.h"
#include "paddle/fluid/lite/core/kernel.h" #include "paddle/fluid/lite/arm/math/funcs.h"
#include "paddle/fluid/lite/core/op_registry.h" #include "paddle/fluid/lite/core/op_registry.h"
#include "paddle/fluid/lite/core/types.h" #include "paddle/fluid/lite/core/type_system.h"
namespace paddle { namespace paddle {
namespace lite { namespace lite {
namespace kernels { namespace kernels {
namespace arm { namespace arm {
template <typename T> void MulCompute::PrepareForRun() {
void mul_compute_eigen(const T* x, int x_h, int x_w, const T* y, int y_h, // TODO(TJ): transpose x or y if necessary
int y_w, T* out) {
using matrix_t =
Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>;
Eigen::Map<const matrix_t> X(x, x_h, x_w);
Eigen::Map<const matrix_t> Y(y, y_h, y_w);
Eigen::Map<matrix_t> Out(out, x_h, y_w);
Out = X * Y;
} }
class MulCompute : public KernelLite<TARGET(kARM), PRECISION(kFloat)> { void MulCompute::Run() {
public: auto& param = Param<param_t>();
using param_t = operators::MulParam;
void Run() override { const auto* x_data = param.x->data<float>();
auto& param = Param<operators::MulParam>(); const auto* y_data = param.y->data<float>();
core::dim2 x_shape( auto* o_data = param.output->mutable_data<float>();
{static_cast<int>(
param.x->dims().Slice(0, param.x_num_col_dims).production()), int x_h = static_cast<int>(
static_cast<int>( param.x->dims().Slice(0, param.x_num_col_dims).production());
param.x->dims() int x_w =
static_cast<int>(param.x->dims()
.Slice(param.x_num_col_dims, param.x->dims().size()) .Slice(param.x_num_col_dims, param.x->dims().size())
.production())}); .production());
core::dim2 y_shape( int y_h = static_cast<int>(
{static_cast<int>( param.y->dims().Slice(0, param.y_num_col_dims).production());
param.y->dims().Slice(0, param.y_num_col_dims).production()), int y_w =
static_cast<int>( static_cast<int>(param.y->dims()
param.y->dims()
.Slice(param.y_num_col_dims, param.y->dims().size()) .Slice(param.y_num_col_dims, param.y->dims().size())
.production())}); .production());
mul_compute_eigen(param.x->data<float>(), x_shape.x, x_shape.y, // CHECK_EQ(x_w, y_h) << "x_w must be equal with y_h";
param.y->data<float>(), y_shape.x, y_shape.y, // if (y_w == 1 || x_h == 1) {
param.output->mutable_data<float>()); lite::arm::math::sgemv(x_data, y_data, o_data, false, x_h, x_w, false,
} nullptr, false);
virtual ~MulCompute() = default; } else {
}; constexpr bool is_tranposed_y = false;
auto& ctx = this->ctx_->template As<ARMContext>();
lite::arm::math::sgemm_prepack(x_data, y_data, nullptr, o_data, x_h, y_w,
x_w, false, false, is_tranposed_y, &ctx);
}
}
} // namespace arm } // namespace arm
} // namespace kernels } // namespace kernels
......
...@@ -22,44 +22,13 @@ namespace lite { ...@@ -22,44 +22,13 @@ namespace lite {
namespace kernels { namespace kernels {
namespace arm { namespace arm {
template <typename T>
void mul_compute_eigen(const T* x, int x_h, int x_w, const T* y, int y_h,
int y_w, T* out) {
using matrix_t =
Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>;
Eigen::Map<const matrix_t> X(x, x_h, x_w);
Eigen::Map<const matrix_t> Y(y, y_h, y_w);
Eigen::Map<matrix_t> Out(out, x_h, y_w);
Out = X * Y;
}
class MulCompute : public KernelLite<TARGET(kARM), PRECISION(kFloat)> { class MulCompute : public KernelLite<TARGET(kARM), PRECISION(kFloat)> {
public: public:
using param_t = operators::MulParam; using param_t = operators::MulParam;
void Run() override { void PrepareForRun() override;
auto& param = Param<operators::MulParam>();
core::dim2 x_shape(
{static_cast<int>(
param.x->dims().Slice(0, param.x_num_col_dims).production()),
static_cast<int>(
param.x->dims()
.Slice(param.x_num_col_dims, param.x->dims().size())
.production())});
core::dim2 y_shape(
{static_cast<int>(
param.y->dims().Slice(0, param.y_num_col_dims).production()),
static_cast<int>(
param.y->dims()
.Slice(param.y_num_col_dims, param.y->dims().size())
.production())});
mul_compute_eigen(param.x->data<float>(), x_shape.x, x_shape.y, // void Run() override;
param.y->data<float>(), y_shape.x, y_shape.y, //
param.output->mutable_data<float>());
}
virtual ~MulCompute() = default; virtual ~MulCompute() = default;
}; };
...@@ -68,10 +37,3 @@ class MulCompute : public KernelLite<TARGET(kARM), PRECISION(kFloat)> { ...@@ -68,10 +37,3 @@ class MulCompute : public KernelLite<TARGET(kARM), PRECISION(kFloat)> {
} // namespace kernels } // namespace kernels
} // namespace lite } // namespace lite
} // namespace paddle } // namespace paddle
REGISTER_LITE_KERNEL(mul, kARM, kFloat, kNCHW,
paddle::lite::kernels::arm::MulCompute, def)
.BindInput("X", {LiteType::GetTensorTy(TARGET(kARM))})
.BindInput("Y", {LiteType::GetTensorTy(TARGET(kARM))})
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM))})
.Finalize();
...@@ -12,31 +12,33 @@ ...@@ -12,31 +12,33 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#include "paddle/fluid/lite/kernels/arm/mul_compute.h"
#include <gtest/gtest.h> #include <gtest/gtest.h>
#include <memory>
#include <utility>
#include <vector> #include <vector>
#include "paddle/fluid/lite/arm/math/funcs.h" #include "paddle/fluid/lite/arm/math/funcs.h"
#include "paddle/fluid/lite/core/op_registry.h" #include "paddle/fluid/lite/core/op_registry.h"
#include "paddle/fluid/lite/kernels/arm/fc_compute.h"
namespace paddle { namespace paddle {
namespace lite { namespace lite {
namespace kernels { namespace kernels {
namespace arm { namespace arm {
TEST(fc_arm, retrive_op) { TEST(mul_arm, retrive_op) {
auto fc = auto mul =
KernelRegistry::Global().Create<TARGET(kARM), PRECISION(kFloat)>("fc"); KernelRegistry::Global().Create<TARGET(kARM), PRECISION(kFloat)>("mul");
ASSERT_FALSE(fc.empty()); ASSERT_FALSE(mul.empty());
ASSERT_TRUE(fc.front()); ASSERT_TRUE(mul.front());
} }
TEST(fc_arm, init) { TEST(mul_arm, init) {
FcCompute fc; FcCompute mul;
ASSERT_EQ(fc.precision(), PRECISION(kFloat)); ASSERT_EQ(mul.precision(), PRECISION(kFloat));
ASSERT_EQ(fc.target(), TARGET(kARM)); ASSERT_EQ(mul.target(), TARGET(kARM));
} }
TEST(fc_arm, compare_test) { TEST(mul_arm, compare_test) {
lite::Tensor x, w, b, out, ref; lite::Tensor x, w, b, out, ref;
constexpr int batch_size = 2; constexpr int batch_size = 2;
x.Resize({batch_size, 3}); x.Resize({batch_size, 3});
...@@ -65,8 +67,8 @@ TEST(fc_arm, compare_test) { ...@@ -65,8 +67,8 @@ TEST(fc_arm, compare_test) {
w_data, 3, 4, // w_data, 3, 4, //
b_data, ref_data); b_data, ref_data);
// fc compute kernel // mul compute kernel
FcCompute fc; FcCompute mul;
operators::FcParam param; operators::FcParam param;
param.in_num_col_dims = 1; param.in_num_col_dims = 1;
...@@ -79,9 +81,9 @@ TEST(fc_arm, compare_test) { ...@@ -79,9 +81,9 @@ TEST(fc_arm, compare_test) {
DeviceInfo::Init(); DeviceInfo::Init();
std::unique_ptr<KernelContext> ctx(new KernelContext); std::unique_ptr<KernelContext> ctx(new KernelContext);
ctx->As<ARMContext>(); ctx->As<ARMContext>();
fc.SetParam(param); mul.SetParam(param);
fc.SetContext(std::move(ctx)); mul.SetContext(std::move(ctx));
fc.Run(); mul.Run();
VLOG(3) << "output vs ref"; VLOG(3) << "output vs ref";
for (int i = 0; i < out.dims().product(); i++) { for (int i = 0; i < out.dims().product(); i++) {
...@@ -93,8 +95,8 @@ TEST(fc_arm, compare_test) { ...@@ -93,8 +95,8 @@ TEST(fc_arm, compare_test) {
} }
} }
TEST(fc_arm, num_col_dims) { TEST(mul_arm, num_col_dims) {
FcCompute fc; FcCompute mul;
operators::FcParam param; operators::FcParam param;
lite::Tensor x; lite::Tensor x;
...@@ -136,9 +138,9 @@ TEST(fc_arm, num_col_dims) { ...@@ -136,9 +138,9 @@ TEST(fc_arm, num_col_dims) {
ctx->As<ARMContext>(); ctx->As<ARMContext>();
DeviceInfo::Init(); DeviceInfo::Init();
fc.SetParam(param); mul.SetParam(param);
fc.SetContext(std::move(ctx)); mul.SetContext(std::move(ctx));
fc.Run(); mul.Run();
} }
} // namespace arm } // namespace arm
...@@ -146,4 +148,4 @@ TEST(fc_arm, num_col_dims) { ...@@ -146,4 +148,4 @@ TEST(fc_arm, num_col_dims) {
} // namespace lite } // namespace lite
} // namespace paddle } // namespace paddle
USE_LITE_KERNEL(fc, kARM, kFloat, kNCHW, def); USE_LITE_KERNEL(mul, kARM, kFloat, kNCHW, def);
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册