提交 c2be5da3 编写于 作者: 开心的小妮's avatar 开心的小妮

[LITE][CI] Enable android armv7 CI, Fix mul, fc bug

上级 57e66daa
......@@ -13,10 +13,10 @@
// limitations under the License.
#include "paddle/fluid/lite/kernels/arm/fc_compute.h"
#include <vector>
#include "paddle/fluid/lite/arm/math/funcs.h"
#include "paddle/fluid/lite/core/op_registry.h"
#include "paddle/fluid/lite/core/type_system.h"
namespace paddle {
namespace lite {
namespace kernels {
......@@ -53,6 +53,12 @@ void FcCompute::PrepareForRun() {
}
}
}
if (m_ > 1) {
int hblock = lite::arm::math::get_hblock(ctx.arch());
int m_round = hblock * ((m_ + hblock - 1) / hblock);
ctx.ExtendWorkspace(DDimLite(std::vector<int64_t>({m_round * k_})));
}
}
void FcCompute::Run() {
......
......@@ -13,6 +13,7 @@
// limitations under the License.
#include "paddle/fluid/lite/kernels/arm/mul_compute.h"
#include <vector>
#include "paddle/fluid/lite/arm/math/funcs.h"
#include "paddle/fluid/lite/core/op_registry.h"
#include "paddle/fluid/lite/core/type_system.h"
......@@ -33,7 +34,7 @@ void MulCompute::Run() {
const auto* y_data = param.y->data<float>();
auto* o_data = param.output->mutable_data<float>();
int m = static_cast<int>(
m_ = static_cast<int>(
param.x->dims().Slice(0, param.x_num_col_dims).production());
int x_w =
static_cast<int>(param.x->dims()
......@@ -41,26 +42,29 @@ void MulCompute::Run() {
.production());
int y_h = static_cast<int>(
param.y->dims().Slice(0, param.y_num_col_dims).production());
int n =
static_cast<int>(param.y->dims()
.Slice(param.y_num_col_dims, param.y->dims().size())
.production());
n_ = static_cast<int>(param.y->dims()
.Slice(param.y_num_col_dims, param.y->dims().size())
.production());
CHECK_EQ(x_w, y_h) << "x_w must be equal with y_h";
auto k = x_w;
if (n == 1) {
lite::arm::math::sgemv(x_data, y_data, o_data, false, m, k, false, nullptr,
false);
k_ = x_w;
if (n_ == 1) {
lite::arm::math::sgemv(x_data, y_data, o_data, false, m_, k_, false,
nullptr, false);
} else {
constexpr bool is_tranposed_y = false;
auto& ctx = this->ctx_->template As<ARMContext>();
int hblock = lite::arm::math::get_hblock(ctx.arch());
int m_round = hblock * ((m_ + hblock - 1) / hblock);
ctx.ExtendWorkspace(DDimLite(std::vector<int64_t>({m_round * k_})));
float* packed_x = static_cast<float*>(ctx.workspace_data<float>()) +
ctx.l2_cache_size() / sizeof(float);
lite::arm::math::prepackA(packed_x, x_data, k, 0, m, 0, k, false, &ctx);
lite::arm::math::sgemm_prepack(packed_x, y_data, nullptr, o_data, m, n, k,
false, false, is_tranposed_y, &ctx);
lite::arm::math::prepackA(packed_x, x_data, k_, 0, m_, 0, k_, false, &ctx);
lite::arm::math::sgemm_prepack(packed_x, y_data, nullptr, o_data, m_, n_,
k_, false, false, is_tranposed_y, &ctx);
}
}
......
......@@ -31,6 +31,9 @@ class MulCompute : public KernelLite<TARGET(kARM), PRECISION(kFloat)> {
void Run() override;
virtual ~MulCompute() = default;
private:
int m_, n_, k_;
};
} // namespace arm
......
......@@ -246,14 +246,7 @@ function test_arm {
echo "android do not need armv7hf"
return 0
fi
# TODO(yuanshuai): enable armv7 on android
if [[ ${abi} == "armv7" ]]; then
echo "skip android v7 test yet"
return 0
fi
echo "test file: ${TESTS_FILE}"
for _test in $(cat $TESTS_FILE); do
test_arm_android $_test $port
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册