diff --git a/paddle/fluid/lite/kernels/arm/fc_compute.cc b/paddle/fluid/lite/kernels/arm/fc_compute.cc index 2e6f46a0e07e422bb118834214fee3fc43ae1d61..24619ed9261c8a79b62061a821d12482b332f299 100644 --- a/paddle/fluid/lite/kernels/arm/fc_compute.cc +++ b/paddle/fluid/lite/kernels/arm/fc_compute.cc @@ -13,10 +13,10 @@ // limitations under the License. #include "paddle/fluid/lite/kernels/arm/fc_compute.h" +#include #include "paddle/fluid/lite/arm/math/funcs.h" #include "paddle/fluid/lite/core/op_registry.h" #include "paddle/fluid/lite/core/type_system.h" - namespace paddle { namespace lite { namespace kernels { @@ -53,6 +53,12 @@ void FcCompute::PrepareForRun() { } } } + + if (m_ > 1) { + int hblock = lite::arm::math::get_hblock(ctx.arch()); + int m_round = hblock * ((m_ + hblock - 1) / hblock); + ctx.ExtendWorkspace(DDimLite(std::vector({m_round * k_}))); + } } void FcCompute::Run() { diff --git a/paddle/fluid/lite/kernels/arm/mul_compute.cc b/paddle/fluid/lite/kernels/arm/mul_compute.cc index 57c28e63bbf3bfdacf861d60ba2ab25436b61b42..c721e8046e735d3b7b4d963e98687f824a2fb6ba 100644 --- a/paddle/fluid/lite/kernels/arm/mul_compute.cc +++ b/paddle/fluid/lite/kernels/arm/mul_compute.cc @@ -13,6 +13,7 @@ // limitations under the License. #include "paddle/fluid/lite/kernels/arm/mul_compute.h" +#include #include "paddle/fluid/lite/arm/math/funcs.h" #include "paddle/fluid/lite/core/op_registry.h" #include "paddle/fluid/lite/core/type_system.h" @@ -33,7 +34,7 @@ void MulCompute::Run() { const auto* y_data = param.y->data(); auto* o_data = param.output->mutable_data(); - int m = static_cast( + m_ = static_cast( param.x->dims().Slice(0, param.x_num_col_dims).production()); int x_w = static_cast(param.x->dims() @@ -41,26 +42,29 @@ void MulCompute::Run() { .production()); int y_h = static_cast( param.y->dims().Slice(0, param.y_num_col_dims).production()); - int n = - static_cast(param.y->dims() - .Slice(param.y_num_col_dims, param.y->dims().size()) - .production()); + n_ = static_cast(param.y->dims() + .Slice(param.y_num_col_dims, param.y->dims().size()) + .production()); CHECK_EQ(x_w, y_h) << "x_w must be equal with y_h"; - auto k = x_w; - if (n == 1) { - lite::arm::math::sgemv(x_data, y_data, o_data, false, m, k, false, nullptr, - false); + k_ = x_w; + + if (n_ == 1) { + lite::arm::math::sgemv(x_data, y_data, o_data, false, m_, k_, false, + nullptr, false); } else { constexpr bool is_tranposed_y = false; auto& ctx = this->ctx_->template As(); + int hblock = lite::arm::math::get_hblock(ctx.arch()); + int m_round = hblock * ((m_ + hblock - 1) / hblock); + ctx.ExtendWorkspace(DDimLite(std::vector({m_round * k_}))); float* packed_x = static_cast(ctx.workspace_data()) + ctx.l2_cache_size() / sizeof(float); - lite::arm::math::prepackA(packed_x, x_data, k, 0, m, 0, k, false, &ctx); - lite::arm::math::sgemm_prepack(packed_x, y_data, nullptr, o_data, m, n, k, - false, false, is_tranposed_y, &ctx); + lite::arm::math::prepackA(packed_x, x_data, k_, 0, m_, 0, k_, false, &ctx); + lite::arm::math::sgemm_prepack(packed_x, y_data, nullptr, o_data, m_, n_, + k_, false, false, is_tranposed_y, &ctx); } } diff --git a/paddle/fluid/lite/kernels/arm/mul_compute.h b/paddle/fluid/lite/kernels/arm/mul_compute.h index c18995e5a5c3cceb749465382b284c0a52c188a4..64c8f813d4e34384b3e6b79eac8aa879bffaeac4 100644 --- a/paddle/fluid/lite/kernels/arm/mul_compute.h +++ b/paddle/fluid/lite/kernels/arm/mul_compute.h @@ -31,6 +31,9 @@ class MulCompute : public KernelLite { void Run() override; virtual ~MulCompute() = default; + + private: + int m_, n_, k_; }; } // namespace arm diff --git a/paddle/fluid/lite/tools/build.sh b/paddle/fluid/lite/tools/build.sh index a88e7786c012e15a5b8b8aaf6ed66c453f3dbe48..49dff53bfa50b8a9c121bc8705b097c6dd011afe 100755 --- a/paddle/fluid/lite/tools/build.sh +++ b/paddle/fluid/lite/tools/build.sh @@ -246,14 +246,7 @@ function test_arm { echo "android do not need armv7hf" return 0 fi - - # TODO(yuanshuai): enable armv7 on android - if [[ ${abi} == "armv7" ]]; then - echo "skip android v7 test yet" - return 0 - fi - - + echo "test file: ${TESTS_FILE}" for _test in $(cat $TESTS_FILE); do test_arm_android $_test $port