提交 c2be5da3 编写于 作者: 开心的小妮's avatar 开心的小妮

[LITE][CI] Enable android armv7 CI, Fix mul, fc bug

上级 57e66daa
...@@ -13,10 +13,10 @@ ...@@ -13,10 +13,10 @@
// limitations under the License. // limitations under the License.
#include "paddle/fluid/lite/kernels/arm/fc_compute.h" #include "paddle/fluid/lite/kernels/arm/fc_compute.h"
#include <vector>
#include "paddle/fluid/lite/arm/math/funcs.h" #include "paddle/fluid/lite/arm/math/funcs.h"
#include "paddle/fluid/lite/core/op_registry.h" #include "paddle/fluid/lite/core/op_registry.h"
#include "paddle/fluid/lite/core/type_system.h" #include "paddle/fluid/lite/core/type_system.h"
namespace paddle { namespace paddle {
namespace lite { namespace lite {
namespace kernels { namespace kernels {
...@@ -53,6 +53,12 @@ void FcCompute::PrepareForRun() { ...@@ -53,6 +53,12 @@ void FcCompute::PrepareForRun() {
} }
} }
} }
if (m_ > 1) {
int hblock = lite::arm::math::get_hblock(ctx.arch());
int m_round = hblock * ((m_ + hblock - 1) / hblock);
ctx.ExtendWorkspace(DDimLite(std::vector<int64_t>({m_round * k_})));
}
} }
void FcCompute::Run() { void FcCompute::Run() {
......
...@@ -13,6 +13,7 @@ ...@@ -13,6 +13,7 @@
// limitations under the License. // limitations under the License.
#include "paddle/fluid/lite/kernels/arm/mul_compute.h" #include "paddle/fluid/lite/kernels/arm/mul_compute.h"
#include <vector>
#include "paddle/fluid/lite/arm/math/funcs.h" #include "paddle/fluid/lite/arm/math/funcs.h"
#include "paddle/fluid/lite/core/op_registry.h" #include "paddle/fluid/lite/core/op_registry.h"
#include "paddle/fluid/lite/core/type_system.h" #include "paddle/fluid/lite/core/type_system.h"
...@@ -33,7 +34,7 @@ void MulCompute::Run() { ...@@ -33,7 +34,7 @@ void MulCompute::Run() {
const auto* y_data = param.y->data<float>(); const auto* y_data = param.y->data<float>();
auto* o_data = param.output->mutable_data<float>(); auto* o_data = param.output->mutable_data<float>();
int m = static_cast<int>( m_ = static_cast<int>(
param.x->dims().Slice(0, param.x_num_col_dims).production()); param.x->dims().Slice(0, param.x_num_col_dims).production());
int x_w = int x_w =
static_cast<int>(param.x->dims() static_cast<int>(param.x->dims()
...@@ -41,26 +42,29 @@ void MulCompute::Run() { ...@@ -41,26 +42,29 @@ void MulCompute::Run() {
.production()); .production());
int y_h = static_cast<int>( int y_h = static_cast<int>(
param.y->dims().Slice(0, param.y_num_col_dims).production()); param.y->dims().Slice(0, param.y_num_col_dims).production());
int n = n_ = static_cast<int>(param.y->dims()
static_cast<int>(param.y->dims() .Slice(param.y_num_col_dims, param.y->dims().size())
.Slice(param.y_num_col_dims, param.y->dims().size()) .production());
.production());
CHECK_EQ(x_w, y_h) << "x_w must be equal with y_h"; CHECK_EQ(x_w, y_h) << "x_w must be equal with y_h";
auto k = x_w; k_ = x_w;
if (n == 1) {
lite::arm::math::sgemv(x_data, y_data, o_data, false, m, k, false, nullptr, if (n_ == 1) {
false); lite::arm::math::sgemv(x_data, y_data, o_data, false, m_, k_, false,
nullptr, false);
} else { } else {
constexpr bool is_tranposed_y = false; constexpr bool is_tranposed_y = false;
auto& ctx = this->ctx_->template As<ARMContext>(); auto& ctx = this->ctx_->template As<ARMContext>();
int hblock = lite::arm::math::get_hblock(ctx.arch());
int m_round = hblock * ((m_ + hblock - 1) / hblock);
ctx.ExtendWorkspace(DDimLite(std::vector<int64_t>({m_round * k_})));
float* packed_x = static_cast<float*>(ctx.workspace_data<float>()) + float* packed_x = static_cast<float*>(ctx.workspace_data<float>()) +
ctx.l2_cache_size() / sizeof(float); ctx.l2_cache_size() / sizeof(float);
lite::arm::math::prepackA(packed_x, x_data, k, 0, m, 0, k, false, &ctx); lite::arm::math::prepackA(packed_x, x_data, k_, 0, m_, 0, k_, false, &ctx);
lite::arm::math::sgemm_prepack(packed_x, y_data, nullptr, o_data, m, n, k, lite::arm::math::sgemm_prepack(packed_x, y_data, nullptr, o_data, m_, n_,
false, false, is_tranposed_y, &ctx); k_, false, false, is_tranposed_y, &ctx);
} }
} }
......
...@@ -31,6 +31,9 @@ class MulCompute : public KernelLite<TARGET(kARM), PRECISION(kFloat)> { ...@@ -31,6 +31,9 @@ class MulCompute : public KernelLite<TARGET(kARM), PRECISION(kFloat)> {
void Run() override; void Run() override;
virtual ~MulCompute() = default; virtual ~MulCompute() = default;
private:
int m_, n_, k_;
}; };
} // namespace arm } // namespace arm
......
...@@ -246,14 +246,7 @@ function test_arm { ...@@ -246,14 +246,7 @@ function test_arm {
echo "android do not need armv7hf" echo "android do not need armv7hf"
return 0 return 0
fi fi
# TODO(yuanshuai): enable armv7 on android
if [[ ${abi} == "armv7" ]]; then
echo "skip android v7 test yet"
return 0
fi
echo "test file: ${TESTS_FILE}" echo "test file: ${TESTS_FILE}"
for _test in $(cat $TESTS_FILE); do for _test in $(cat $TESTS_FILE); do
test_arm_android $_test $port test_arm_android $_test $port
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册