Unverified commit 85f23656 authored by S smilejames, committed by GitHub

Merge branch 'develop' into develop

......@@ -76,3 +76,4 @@ demo/ios/PaddleMobileDemo/PaddleMobileDemo/googlenet_combine/
demo/ios/PaddleMobileDemo/PaddleMobileDemo/*.jpg
demo/ios/PaddleMobileDemo/PaddleMobileDemo/PaddleMobile/*.a
*.xcuserstate
/tools/quantification/quantify
......@@ -29,15 +29,15 @@ limitations under the License. */
#include "fpga/api/fpga_api.h"
namespace paddle {
namespace mobile {
namespace paddle_mobile {
namespace fpga {
namespace api {
static int fd = -1;
static const char *device_path = "/dev/fpgadrv0";
static inline int do_ioctl(int req, void *arg) { return ioctl(req, arg); }
static inline int do_ioctl(int req, void *arg) {
return ioctl(req, (uint64_t)arg);
}
int open_device() {
if (fd == -1) {
......@@ -48,8 +48,8 @@ int open_device() {
// memory management;
void *fpga_malloc(size_t size) {
return reinterpret_cast<(void *)> mmap64(NULL, size, PROT_READ | PROT_WRITE,
MAP_SHARED, fd, 0);
return reinterpret_cast<void *>(
mmap64(NULL, size, PROT_READ | PROT_WRITE, MAP_SHARED, fd, 0));
}
void fpga_free(void *ptr) { munmap(ptr, 0); }
......@@ -58,11 +58,13 @@ void fpga_copy(void *dest, const void *src, size_t num) {
memcpy(dest, src, num);
}
int ComputeFpgaConv(struct FpgaConvArgs) {}
int ComputeFpgaPool(struct FpgaPoolArgs) {}
int ComputeFpgaEWAdd(struct FpgaEWAddArgs) {}
int ComputeFpgaConv(const struct ConvArgs &args) { return do_ioctl(21, &args); }
int ComputeFpgaPool(const struct PoolingArgs &args) {
return do_ioctl(22, &args);
}
int ComputeFpgaEWAdd(const struct EWAddArgs &args) {
return do_ioctl(23, &args);
}
} // namespace api
} // namespace fpga
} // namespace mobile
} // namespace paddle
} // namespace paddle_mobile
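
A minimal call-site sketch of the API above (every size and unfilled field here is hypothetical; per the implementation, ComputeFpgaConv simply forwards the struct to the driver as ioctl command 21):

// Sketch only, assuming "fpga/api/fpga_api.h" is included and the conv
// argument struct is fully populated (fields are declared in the header below).
paddle_mobile::fpga::open_device();                    // lazily opens /dev/fpgadrv0
void *buf = paddle_mobile::fpga::fpga_malloc(4096);    // 4096: illustrative size
paddle_mobile::fpga::ConvArgs args = {};               // zero-init, then fill
args.relu_enabled = true;
// ... fill args.image / args.kernel / args.output / args.sb_address here ...
int ret = paddle_mobile::fpga::ComputeFpgaConv(args);  // do_ioctl(21, &args)
paddle_mobile::fpga::fpga_free(buf);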
......@@ -31,90 +31,150 @@ void* fpga_malloc(size_t size);
void fpga_free(void* ptr);
void fpga_copy(void* dst, const void* src, size_t num);
struct FpgaVersionArgs {
void* buf;
enum DataConvertType {
DATA_NO_CONVERT = 0,
DATA_FP32_TO_FP16 = 1,
DATA_FP16_TO_FP32 = 2,
};
struct MemoryToPhysicalArgs {
const void* src;
uint64_t physical;
enum LayoutConvertType {
LAYOUT_NO_CONVERT = 0,
LAYOUT_CHW_TO_HWC = 1,
LAYOUT_HWC_TO_CHW = 2,
};
struct VersionArgs {
void* buffer;
};
struct MemoryCopyArgs {
void* src;
void* dst;
void* dest;
size_t size;
};
struct FpgaQuantArgs {
float scale;
};
struct FpgaBNArgs {
bool enabled = false;
void* bias_addr;
void* scale_addr;
struct BNArgs {
bool enabled;
void* bias_address;
void* scale_address;
};
struct FpgaKernelArgs {
/**
Conv and Pooling kernel
*/
struct KernelArgs {
uint32_t width;
uint32_t height;
uint32_t stride_h;
uint32_t stride_w;
uint32_t stride_h;
};
struct FpgaImageArgs {
uint32_t width;
uint32_t height;
struct ImageInputArgs {
void* address; // input featuremap virtual address
float* scale_address; // input scale address;
uint32_t channels;
uint32_t pad_h;
uint32_t pad_w;
uint32_t width; // featuremap width
uint32_t height;
uint32_t pad_width; // padding width;
uint32_t pad_height;
};
struct FpgaConvArgs {
struct ImageOutputArgs {
void* address; // output result address;
float* scale_address; // output scale address;
};
struct ConvArgs {
bool relu_enabled;
struct FpgaBNArgs BNargs;
void* image_addr;
void* filter_addr;
void* bias_addr;
void* output_addr;
float quant_scale;
struct FpgaImageArgs image;
void* bias_address;
void* filter_address;
uint32_t filter_num;
uint32_t group_num;
struct FpgaKernelArgs kernel;
void* sb_address; // scale and bias are interlaced;
struct KernelArgs kernel;
struct ImageInputArgs image; // input image;
struct ImageOutputArgs output;
};
struct FpgaPoolArgs {
void* image_addr;
void* output_addr;
struct FpgaImageArgs image;
struct FpgaKernelArgs kernel;
struct PoolingArgs {
struct KernelArgs kernel;
struct ImageInputArgs image; // input image;
struct ImageOutputArgs output;
};
struct FpgaEWAddArgs {
// elementwise add arguments
struct EWAddArgs {
bool relu_enabled;
void* image0_addr;
void* image1_addr;
void* result_addr;
uint32_t const0;
uint32_t const1;
uint32_t data_len; // aligned element count
float const0; // output0 = const0 x input0 + const1 x input1;
float const1;
struct ImageInputArgs image0;
struct ImageInputArgs image1;
struct ImageOutputArgs output;
};
int ComputeFpgaConv(struct FpgaConvArgs args);
int ComputeFpgaPool(struct FpgaPoolArgs args);
int ComputeFpgaEWAdd(struct FpgaEWAddArgs args);
struct BypassArgs {
enum DataConvertType convert_type;
struct ImageInputArgs image;
struct ImageOutputArgs output;
};
struct FpgaRegWriteArgs {
uint64_t address;
uint64_t value;
};
struct FpgaRegReadArgs {
uint64_t address;
uint64_t value;
};
#define IOCTL_FPGA_MAGIC 'FPGA'
#define IOCTL_VERSION _IOW(IOCTL_FPGA_MAGIC, 01, struct VersionArgs)
#define IOCTL_SEPARATOR_0 10
#define IOCTL_FPGA_MAGIC 'CNN'
#define IOCTL_VERSION _IOW(IOCTL_FPGA_MAGIC, 1, struct FpgaVersionArgs)
#define IOCTL_GET_QUANT _IOW(IOCTL_FPGA_MAGIC, 2, struct FpgaQuantArgs)
#define IOCTL_SET_QUANT _IOW(IOCTL_FPGA_MAGIC, 3, struct FpgaQuantArgs)
#define IOCTL_MEM_COPY _IOW(IOCTL_FPGA_MAGIC, 11, struct MemoryCopyArgs)
#define IOCTL_CONFIG_CONV _IOW(IOCTL_FPGA_MAGIC, 21, struct FpgaConvArgs)
#define IOCTL_CONFIG_POOLING _IOW(IOCTL_FPGA_MAGIC, 22, struct FpgaPoolArgs)
#define IOCTL_CONFIG_EW _IOW(IOCTL_FPGA_MAGIC, 23, struct FpgaEWAddArgs)
#define IOCTL_SEPARATOR_1 20
#define IOCTL_CONFIG_CONV _IOW(IOCTL_FPGA_MAGIC, 21, struct ConvArgs)
#define IOCTL_CONFIG_POOLING _IOW(IOCTL_FPGA_MAGIC, 22, struct PoolingArgs)
#define IOCTL_CONFIG_EW _IOW(IOCTL_FPGA_MAGIC, 23, struct EWAddArgs)
#define IOCTL_FPGA_REG_READ _IOW(IOCTL_FPGA_MAGIC, 28, struct FpgaRegReadArgs)
#define IOCTL_FPGA_REG_WRITE _IOW(IOCTL_FPGA_MAGIC, 29, struct FpgaRegWriteArgs)
enum FPGA_ERR_TYPE {
ERR_IOCTL_CMD = -1,
ERR_TIMEOUT = -2,
ERR_COMPLETION_TIMEOUT = -3,
ERR_INVALID_FPGA_ADDR = -4,
ERR_NOMEM = -5,
ERR_NO_RESERVE_MEM = -6,
ERR_COPY_FROM_USER = -7,
ERR_COPY_TO_USER = -8,
ERR_DEL_TIMER = -9,
ERR_ENABLE_MSI = -10,
ERR_REGISTER_IRQ = -11,
ERR_PCIE_REGISTER = -12,
ERR_PCIE_PROBE = -13,
ERR_REGISTER_BLOCK = -14,
ERR_ALLOC_GENDISK = -15,
ERR_INIT_QUEUE = -16,
ERR_WAIT = -17,
ERR_ECC_ERROR = -31,
ERR_FPGA_FAIL_STOP = -64,
ERR_FPGA_DEBUG_STOP = -113,
DEV_TMP_UNAVAILABLE = -128
};
//============================== API =============================
int ComputeFpgaConv(const struct ConvArgs& args);
int ComputeFpgaPool(const struct PoolingArgs& args);
int ComputeFpgaEWAdd(const struct EWAddArgs& args);
} // namespace fpga
} // namespace paddle_mobile
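
As a concrete (purely illustrative) use of the structs above, a 2x2, stride-2 pooling over a 112x112x64 HWC feature map might be configured like this; all pointers and sizes are hypothetical:

paddle_mobile::fpga::PoolingArgs pool = {};
pool.kernel.width = 2;
pool.kernel.height = 2;
pool.kernel.stride_w = 2;
pool.kernel.stride_h = 2;
pool.image.address = in_buf;            // hypothetical buffer from fpga_malloc()
pool.image.scale_address = &in_scale;   // hypothetical float
pool.image.channels = 64;
pool.image.width = 112;
pool.image.height = 112;
pool.image.pad_width = 0;
pool.image.pad_height = 0;
pool.output.address = out_buf;          // hypothetical output buffer
pool.output.scale_address = &out_scale;
int ret = paddle_mobile::fpga::ComputeFpgaPool(pool);  // ioctl command 22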
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <string>
#include "common/types.h"
#include "framework/lod_tensor.h"
#include "framework/operator.h"
#include "framework/scope.h"
#include "framework/tensor.h"
namespace paddle_mobile {
bool is_conv(std::string type) {
if (type.compare(G_OP_TYPE_CONV) == 0) {
return true;
}
if (type.compare(G_OP_TYPE_FUSION_CONV_ADD) == 0) {
return true;
}
if (type.compare(G_OP_TYPE_FUSION_CONV_ADD_RELU) == 0) {
return true;
}
if (type.compare(G_OP_TYPE_FUSION_CONV_BN_RELU) == 0) {
return true;
}
if (type.compare(G_OP_TYPE_FUSION_CONV_ADD_BN) == 0) {
return true;
}
return false;
}
template <typename Dtype>
void quantilize_op(std::shared_ptr<framework::OperatorBase<Dtype>> op,
std::shared_ptr<framework::Scope> scope) {
if (!is_conv(op.get()->Type())) {
return;
}
framework::Tensor* filter = nullptr;
auto var_vec = op.get()->Inputs().at("Filter");
if (!var_vec.empty()) {
auto var = scope.get()->FindVar(var_vec[0]);
filter = var->template GetMutable<framework::LoDTensor>();
}
float scale = 0;
// 32bit filter -> 8bit filter;
if (filter->type() == typeid(float)) {
framework::Tensor* originalFilter = filter;
framework::Tensor* quantFilter = new framework::Tensor();
float* floatData = originalFilter->data<float>();
int8_t* intData = quantFilter->mutable_data<int8_t>();
}
}
} // namespace paddle_mobile
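
The quantilize_op body above ends right after obtaining the two data pointers. A plausible completion, assuming the usual symmetric max-abs int8 scheme (a sketch under that assumption, not part of this commit; needs <algorithm> and <cmath>):

// Hypothetical continuation inside the `if (filter->type() == typeid(float))`
// branch: map the float range [-max_abs, max_abs] onto int8 [-127, 127].
float max_abs = 0;
for (int i = 0; i < originalFilter->numel(); ++i) {
  max_abs = std::max(max_abs, std::abs(floatData[i]));
}
scale = 127 / max_abs;
for (int i = 0; i < originalFilter->numel(); ++i) {
  intData[i] = static_cast<int8_t>(floatData[i] * scale);
}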
......@@ -253,6 +253,18 @@ class Tensor {
"Tensor's dims_ is out of bound. ");
}
#ifdef PADDLE_MOBILE_FPGA
struct FPGAArgs {
float scale;
inline float *scale_pointer() { return &scale; }
};
struct FPGAArgs &fpga_args() {
return fpgaArgs_;
}
#endif
private:
/**
* @note Placeholder hides type T, so it doesn't appear as a
......@@ -319,6 +331,10 @@ class Tensor {
* begins.
*/
size_t offset_;
#ifdef PADDLE_MOBILE_FPGA
FPGAArgs fpgaArgs_;
#endif
};
#ifdef PADDLE_MOBILE_DEBUG
......
......@@ -32,6 +32,10 @@ limitations under the License. */
#include "common/threadpool.h"
#endif
#ifdef PADDLE_MOBILE_FPGA
#include "fpga/fpga_quantilization.h"
#endif
namespace paddle_mobile {
using framework::Variable;
......@@ -96,6 +100,11 @@ Executor<Dtype, P>::Executor(const framework::Program<Dtype> p, int batch_size,
for (const auto &op : ops) {
op->Init();
}
#ifdef PADDLE_MOBILE_FPGA
for (const auto &op : ops) {
quantilize_op(op, program_.scope);
}
#endif
}
template <typename Dtype, Precision P>
......@@ -420,6 +429,6 @@ std::vector<typename Executor<Dtype, P>::Ptype> Executor<Dtype, P>::Predict(
template class Executor<CPU, Precision::FP32>;
template class Executor<GPU_MALI, Precision::FP32>;
template class Executor<FPGA, Precision::FP16>;
template class Executor<FPGA, Precision::FP32>;
} // namespace paddle_mobile
......@@ -56,7 +56,8 @@ template <typename Dtype, Precision P>
const framework::Program<Dtype, P> Loader<Dtype, P>::Load(
const std::string &model_path, const std::string &para_path, bool optimize,
bool quantification) {
auto program = this->LoadProgram(model_path, optimize);
auto program = this->LoadProgram(model_path, optimize, quantification);
program.para_path = para_path;
program.combined = true;
program.quantification = quantification;
......
......@@ -61,6 +61,15 @@ JNIEXPORT jboolean JNICALL Java_com_baidu_paddle_PML_load(JNIEnv *env,
optimize);
}
JNIEXPORT jboolean JNICALL Java_com_baidu_paddle_PML_loadQualified(
JNIEnv *env, jclass thiz, jstring modelPath) {
ANDROIDLOGI("loadQualified invoked");
bool optimize = true;
bool qualified = true;
return getPaddleMobileInstance()->Load(jstring2cppstring(env, modelPath),
optimize, qualified);
}
JNIEXPORT jboolean JNICALL Java_com_baidu_paddle_PML_loadCombined(
JNIEnv *env, jclass thiz, jstring modelPath, jstring paramPath) {
ANDROIDLOGI("loadCombined invoked");
......@@ -70,6 +79,16 @@ JNIEXPORT jboolean JNICALL Java_com_baidu_paddle_PML_loadCombined(
optimize);
}
JNIEXPORT jboolean JNICALL Java_com_baidu_paddle_PML_loadCombinedQualified(
JNIEnv *env, jclass thiz, jstring modelPath, jstring paramPath) {
ANDROIDLOGI("loadCombinedQualified invoked");
bool optimize = true;
bool qualified = true;
return getPaddleMobileInstance()->Load(jstring2cppstring(env, modelPath),
jstring2cppstring(env, paramPath),
optimize, qualified);
}
JNIEXPORT jfloatArray JNICALL Java_com_baidu_paddle_PML_predictImage(
JNIEnv *env, jclass thiz, jfloatArray buf, jintArray ddims) {
ANDROIDLOGI("predictImage invoked");
......
......@@ -27,12 +27,24 @@ namespace jni {
JNIEXPORT jboolean JNICALL Java_com_baidu_paddle_PML_load(JNIEnv *env,
jclass thiz,
jstring modelPath);
/**
* load separated qualified model for android
*/
JNIEXPORT jboolean JNICALL Java_com_baidu_paddle_PML_loadQualified(
JNIEnv *env, jclass thiz, jstring modelPath);
/**
* load combined model for android
*/
JNIEXPORT jboolean JNICALL Java_com_baidu_paddle_PML_loadCombined(
JNIEnv *env, jclass thiz, jstring modelPath, jstring paramPath);
/**
* load combined qualified model for android
*/
JNIEXPORT jboolean JNICALL Java_com_baidu_paddle_PML_loadCombinedQualified(
JNIEnv *env, jclass thiz, jstring modelPath, jstring paramPath);
/**
* object detection for android
*/
......
......@@ -27,17 +27,17 @@ namespace memory {
const int MALLOC_ALIGN = 64;
#ifdef PADDLE_MOBILE_FPGA
namespace api = paddle::mobile::fpga::api;
namespace fpga = paddle_mobile::fpga;
void Copy(void *dst, const void *src, size_t num) {
std::memcpy(dst, src, num);
}
void *Alloc(size_t size) { return api::malloc(size); }
void *Alloc(size_t size) { return fpga::fpga_malloc(size); }
void Free(void *ptr) {
if (ptr) {
api::fpga_free(ptr);
fpga::fpga_free(ptr);
}
}
......
......@@ -29,7 +29,7 @@ class FeedOp : public framework::OperatorBase<DeviceType> {
std::shared_ptr<framework::Scope> scope)
: framework::OperatorBase<DeviceType>(type, inputs, outputs, attrs,
scope),
param_(inputs, outputs, attrs, *scope) {}
param_(inputs, outputs, attrs, scope.get()) {}
void RunImpl() const { param_.Out()->ShareDataWith(*param_.InputX()); }
void Init() {}
......
......@@ -14,8 +14,6 @@ limitations under the License. */
#ifdef DROPOUT_OP
#pragma once
#include "operators/kernel/dropout_kernel.h"
#include <operators/math/transform.h>
......
......@@ -17,7 +17,7 @@ limitations under the License. */
#include "framework/operator.h"
#include "operators/op_param.h"
#pragma once;
#pragma once
namespace paddle_mobile {
namespace operators {
......
......@@ -39,7 +39,7 @@ void ConcatKernel<FPGA, half>::Compute(const ConcatParam &param) const {
for (int i = 0; i < inputs.size(); ++i) {
auto input = inputs[i];
auto channels = input[3];
auto channels = input->dims()[3];
out_offset += channels;
auto src = input->data<half>();
for (int j = 0; j < pixels; ++j) {
......
......@@ -20,13 +20,11 @@ limitations under the License. */
#include "operators/math/vol2col.h"
#include "operators/op_param.h"
#pragma once;
#pragma once
namespace paddle_mobile {
namespace operators {
using namespace framework;
template <typename DeviceType, typename T>
class Im2SequenceKernel
: public framework::OpKernelBase<DeviceType, Im2SequenceParam> {
......
......@@ -14,8 +14,6 @@ limitations under the License. */
#ifdef FUSION_FC_OP
#pragma once
#include "operators/kernel/fusion_fc_kernel.h"
namespace paddle_mobile {
......
......@@ -15,7 +15,7 @@ limitations under the License. */
#include "framework/operator.h"
#include "operators/op_param.h"
#pragma once;
#pragma once
namespace paddle_mobile {
namespace operators {
......
......@@ -15,7 +15,7 @@ limitations under the License. */
#include "framework/operator.h"
#include "operators/op_param.h"
#pragma once;
#pragma once
namespace paddle_mobile {
namespace operators {
......
......@@ -15,7 +15,7 @@ limitations under the License. */
#include "framework/operator.h"
#include "operators/op_param.h"
#pragma once;
#pragma once
namespace paddle_mobile {
namespace operators {
......
......@@ -107,20 +107,22 @@ void PackMatrixA_4r(int m, int k, int m_tail, const float *A, int lda,
*buffer++ = *a3++;
}
}
int i = m - m_tail;
a0 = &A(i, 0);
if (m_tail != 0) {
a0 = &A(m - m_tail, 0);
a1 = a0 + lda;
a2 = a0 + 2 * lda;
a3 = a0 + 3 * lda;
if (m_tail != 0) {
if (m_tail <= 3) {
a3 = zero;
}
if (m_tail <= 2) {
a2 = zero;
}
if (m_tail <= 1) {
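// Deliberate fallthrough below: a tail of m_tail rows leaves rows m_tail..3 pointing at the zero buffer, so the packed tile stays full.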
switch (m_tail) {
case 1:
a1 = zero;
case 2:
a2 = zero;
case 3:
a3 = zero;
break;
default:
break;
}
for (int j = 0; j < k; ++j) {
*buffer++ = *a0++;
......@@ -150,28 +152,89 @@ void PackMatrixA_6r(int m, int k, int m_tail, const float *A, int lda,
*buffer++ = *a5++;
}
}
int i = m - m_tail;
a0 = &A(i, 0);
if (m_tail != 0) {
a0 = &A(m - m_tail, 0);
a1 = a0 + lda;
a2 = a0 + 2 * lda;
a3 = a0 + 3 * lda;
a4 = a0 + 4 * lda;
a5 = a0 + 5 * lda;
if (m_tail != 0) {
if (m_tail <= 5) {
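// Same deliberate-fallthrough pattern: rows m_tail..5 are redirected to the zero buffer.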
switch (m_tail) {
case 1:
a1 = zero;
case 2:
a2 = zero;
case 3:
a3 = zero;
case 4:
a4 = zero;
case 5:
a5 = zero;
break;
default:
break;
}
if (m_tail <= 4) {
a4 = zero;
for (int j = 0; j < k; ++j) {
*buffer++ = *a0++;
*buffer++ = *a1++;
*buffer++ = *a2++;
*buffer++ = *a3++;
*buffer++ = *a4++;
*buffer++ = *a5++;
}
if (m_tail <= 3) {
a3 = zero;
}
if (m_tail <= 2) {
a2 = zero;
}
void PackMatrixA_8r(int m, int k, int m_tail, const float *A, int lda,
float *buffer) {
const float *a0, *a1, *a2, *a3, *a4, *a5, *a6, *a7;
for (int i = 0; i < m - m_tail; i += MR) {
a0 = A + i * lda;
a1 = A + (i + 1) * lda;
a2 = A + (i + 2) * lda;
a3 = A + (i + 3) * lda;
a4 = A + (i + 4) * lda;
a5 = A + (i + 5) * lda;
a6 = A + (i + 6) * lda;
a7 = A + (i + 7) * lda;
for (int j = 0; j < k; ++j) {
*buffer++ = *a0++;
*buffer++ = *a1++;
*buffer++ = *a2++;
*buffer++ = *a3++;
*buffer++ = *a4++;
*buffer++ = *a5++;
*buffer++ = *a6++;
*buffer++ = *a7++;
}
}
if (m_tail <= 1) {
if (m_tail != 0) {
a0 = &A(m - m_tail, 0);
a1 = a0 + lda;
a2 = a0 + 2 * lda;
a3 = a0 + 3 * lda;
a4 = a0 + 4 * lda;
a5 = a0 + 5 * lda;
a6 = a0 + 6 * lda;
a7 = a0 + 7 * lda;
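// Deliberate fallthrough: rows m_tail..7 read from the zero buffer so the packed 8-row tile stays full.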
switch (m_tail) {
case 1:
a1 = zero;
case 2:
a2 = zero;
case 3:
a3 = zero;
case 4:
a4 = zero;
case 5:
a5 = zero;
case 6:
a6 = zero;
case 7:
a7 = zero;
break;
default:
break;
}
for (int j = 0; j < k; ++j) {
*buffer++ = *a0++;
......@@ -180,6 +243,8 @@ void PackMatrixA_6r(int m, int k, int m_tail, const float *A, int lda,
*buffer++ = *a3++;
*buffer++ = *a4++;
*buffer++ = *a5++;
*buffer++ = *a6++;
*buffer++ = *a7++;
}
}
}
......@@ -234,15 +299,78 @@ void PackMatrixB_8c(int k, int n, int n_tail, const float *B, int ldb,
}
}
#if __aarch64__
void PackMatrixB_12c(int k, int n, int n_tail, const float *B, int ldb,
float *buffer) {
const float *b0;
for (int j = 0; j < n - n_tail; j += NR) {
for (int i = 0; i < k; ++i) {
b0 = &B(i, j);
asm volatile(
"prfm pldl2keep, [%[b0], #64] \n\t"
"ld1 {v0.4s, v1.4s, v2.4s}, [%[b0]] \n\t"
"st1 {v0.4s, v1.4s, v2.4s}, [%[buffer]], #48 \n\t"
: [buffer] "+r"(buffer)
: [b0] "r"(b0)
: "memory", "v0", "v1", "v2");
}
}
if (n_tail != 0) {
for (int i = 0; i < k; ++i) {
b0 = &B(i, n - n_tail);
for (int j = n - n_tail; j < n; ++j) {
*buffer++ = *b0++;
}
for (int j = n; j < n + (NR - n_tail); ++j) {
*buffer++ = 0;
}
}
}
}
void PackMatrixB_16c(int k, int n, int n_tail, const float *B, int ldb,
float *buffer) {
const float *b0;
for (int j = 0; j < n - n_tail; j += NR) {
for (int i = 0; i < k; ++i) {
b0 = &B(i, j);
asm volatile(
"prfm pldl2keep, [%[b0], #64] \n\t"
"ld1 {v0.4s, v1.4s, v2.4s, v3.4s}, [%[b0]] \n\t"
"st1 {v0.4s, v1.4s, v2.4s, v3.4s}, [%[buffer]], #64 \n\t"
: [buffer] "+r"(buffer)
: [b0] "r"(b0)
: "memory", "v0", "v1", "v2", "v3");
}
}
if (n_tail != 0) {
for (int i = 0; i < k; ++i) {
b0 = &B(i, n - n_tail);
for (int j = n - n_tail; j < n; ++j) {
*buffer++ = *b0++;
}
for (int j = n; j < n + (NR - n_tail); ++j) {
*buffer++ = 0;
}
}
}
}
#endif // __aarch64__
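
For readers less fluent in aarch64 assembly: the ld1/st1 pairs above copy 12 or 16 consecutive floats per row (the prfm is only a prefetch hint). A portable scalar stand-in for the 16-wide inner body would be:

// Scalar equivalent of the PackMatrixB_16c inner loop body (NR == 16).
for (int l = 0; l < 16; ++l) {
  *buffer++ = b0[l];
}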
// Blocked matrix multiplication
void InnerKernel(int mc, int nc, float alpha, const float *a, const float *b,
float beta, float *c, float *C, int ldc, bool relu) {
#pragma omp parallel for
for (int j = 0; j < nc; j += NR) {
for (int i = 0; i < mc; i += MR) {
#if __aarch64__
// AddDot8x12(KC, a + i * KC, b + j * KC, c + i * NC + j, NC);
AddDot6x16(KC, a + i * KC, b + j * KC, c + i * NC + j, NC);
#else
// AddDot4x4(KC, a + i * KC, b + j * KC, c + i * NC + j, NC);
// AddDot4x8(KC, a + i * KC, b + j * KC, c + i * NC + j, NC);
AddDot6x8(KC, a + i * KC, b + j * KC, c + i * NC + j, NC);
#endif
}
}
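
In both inner kernels, each AddDot call computes one MR x NR micro-tile of the output over the full KC depth from the packed panels. A scalar sketch of what AddDot6x16 computes (MR = 6, NR = 16, output row stride NC; accumulators start at zero and overwrite the tile, as in the asm):

// Reference semantics of one micro-tile (not the optimized kernel):
for (int mr = 0; mr < MR; ++mr) {
  for (int nr = 0; nr < NR; ++nr) {
    float acc = 0;
    for (int p = 0; p < KC; ++p) {
      acc += a[p * MR + mr] * b[p * NR + nr];
    }
    c[mr * NC + nr] = acc;
  }
}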
......@@ -271,9 +399,14 @@ void InnerKernelWithBn(int mc, int nc, float alpha, const float *a,
#pragma omp parallel for
for (int j = 0; j < nc; j += NR) {
for (int i = 0; i < mc; i += MR) {
#if __aarch64__
// AddDot8x12(KC, a + i * KC, b + j * KC, c + i * NC + j, NC);
AddDot6x16(KC, a + i * KC, b + j * KC, c + i * NC + j, NC);
#else
// AddDot4x4(KC, a + i * KC, b + j * KC, c + i * NC + j, NC);
// AddDot4x8(KC, a + i * KC, b + j * KC, c + i * NC + j, NC);
AddDot6x8(KC, a + i * KC, b + j * KC, c + i * NC + j, NC);
#endif
}
}
......@@ -1956,10 +2089,20 @@ void Sgemm(int m, int n, int k, float alpha, const float *A, int lda,
int mc, nc;
for (int j = 0; j < n; j += NC) {
nc = s_min(n - j, NC);
#if __aarch64__
// PackMatrixB_12c(KC, nc, nc % NR, &B(0, j), ldb, packedB);
PackMatrixB_16c(KC, nc, nc % NR, &B(0, j), ldb, packedB);
#else
PackMatrixB_8c(KC, nc, nc % NR, &B(0, j), ldb, packedB);
#endif
for (int i = 0; i < m; i += MC) {
mc = s_min(m - i, MC);
#if __aarch64__
PackMatrixA_6r(mc, KC, mc % MR, &A(i, 0), lda, packedA);
// PackMatrixA_8r(mc, KC, mc % MR, &A(i, 0), lda, packedA);
#else
PackMatrixA_6r(mc, KC, mc % MR, &A(i, 0), lda, packedA);
#endif
InnerKernel(mc, nc, alpha, packedA, packedB, beta, packedC, &C(i, j), ldc,
relu);
}
......@@ -2009,10 +2152,20 @@ void SgemmWithBn(int m, int n, int k, float alpha, const float *A, int lda,
int mc, nc;
for (int j = 0; j < n; j += NC) {
nc = s_min(n - j, NC);
#if __aarch64__
// PackMatrixB_12c(KC, nc, nc % NR, &B(0, j), ldb, packedB);
PackMatrixB_16c(KC, nc, nc % NR, &B(0, j), ldb, packedB);
#else
PackMatrixB_8c(KC, nc, nc % NR, &B(0, j), ldb, packedB);
#endif
for (int i = 0; i < m; i += MC) {
mc = s_min(m - i, MC);
#if __aarch64__
PackMatrixA_6r(mc, KC, mc % MR, &A(i, 0), lda, packedA);
// PackMatrixA_8r(mc, KC, mc % MR, &A(i, 0), lda, packedA);
#else
PackMatrixA_6r(mc, KC, mc % MR, &A(i, 0), lda, packedA);
#endif
InnerKernelWithBn(mc, nc, alpha, packedA, packedB, beta, packedC,
&C(i, j), ldc, relu, new_scale + i, new_bias + i);
}
......@@ -2239,6 +2392,192 @@ void AddDot6x8(int k, const float *a, const float *b, float *c, int ldc) {
#endif // __ARM_NEON
}
#if __aarch64__
void AddDot8x12(int k, const float *a, const float *b, float *c, int ldc) {
const float *a_ptr, *b_ptr;
a_ptr = a;
b_ptr = b;
int kc1 = k;
int step = 4 * ldc;
asm volatile(
"dup v5.4s, wzr \n\t"
"dup v6.4s, wzr \n\t"
"dup v7.4s, wzr \n\t"
"dup v8.4s, wzr \n\t"
"dup v9.4s, wzr \n\t"
"dup v10.4s, wzr \n\t"
"dup v11.4s, wzr \n\t"
"dup v12.4s, wzr \n\t"
"dup v13.4s, wzr \n\t"
"dup v14.4s, wzr \n\t"
"dup v15.4s, wzr \n\t"
"dup v16.4s, wzr \n\t"
"dup v17.4s, wzr \n\t"
"dup v18.4s, wzr \n\t"
"dup v19.4s, wzr \n\t"
"dup v20.4s, wzr \n\t"
"dup v21.4s, wzr \n\t"
"dup v22.4s, wzr \n\t"
"dup v23.4s, wzr \n\t"
"dup v24.4s, wzr \n\t"
"dup v25.4s, wzr \n\t"
"dup v26.4s, wzr \n\t"
"dup v27.4s, wzr \n\t"
"dup v28.4s, wzr \n\t"
"subs %[kc1], %[kc1], #1 \n\t"
"blt end_kc1_%= \n\t"
"loop_kc1_%=: \n\t"
"prfm pldl1keep, [%[a_ptr], #32] \n\t"
"prfm pldl1keep, [%[b_ptr], #48] \n\t"
"ld1 {v0.4s, v1.4s}, [%[a_ptr]], #32 \n\t"
"ld1 {v2.4s, v3.4s, v4.4s}, [%[b_ptr]], #48 \n\t"
"fmla v5.4s, v2.4s, v0.s[0] \n\t"
"fmla v6.4s, v3.4s, v0.s[0] \n\t"
"fmla v7.4s, v4.4s, v0.s[0] \n\t"
"fmla v8.4s, v2.4s, v0.s[1] \n\t"
"fmla v9.4s, v3.4s, v0.s[1] \n\t"
"fmla v10.4s, v4.4s, v0.s[1] \n\t"
"fmla v11.4s, v2.4s, v0.s[2] \n\t"
"fmla v12.4s, v3.4s, v0.s[2] \n\t"
"fmla v13.4s, v4.4s, v0.s[2] \n\t"
"fmla v14.4s, v2.4s, v0.s[3] \n\t"
"fmla v15.4s, v3.4s, v0.s[3] \n\t"
"fmla v16.4s, v4.4s, v0.s[3] \n\t"
"fmla v17.4s, v2.4s, v1.s[0] \n\t"
"fmla v18.4s, v3.4s, v1.s[0] \n\t"
"fmla v19.4s, v4.4s, v1.s[0] \n\t"
"fmla v20.4s, v2.4s, v1.s[1] \n\t"
"fmla v21.4s, v3.4s, v1.s[1] \n\t"
"fmla v22.4s, v4.4s, v1.s[1] \n\t"
"fmla v23.4s, v2.4s, v1.s[2] \n\t"
"fmla v24.4s, v3.4s, v1.s[2] \n\t"
"fmla v25.4s, v4.4s, v1.s[2] \n\t"
"fmla v26.4s, v2.4s, v1.s[3] \n\t"
"fmla v27.4s, v3.4s, v1.s[3] \n\t"
"fmla v28.4s, v4.4s, v1.s[3] \n\t"
"subs %[kc1], %[kc1], #1 \n\t"
"bge loop_kc1_%= \n\t"
"end_kc1_%=: \n\t"
"st1 {v5.4s, v6.4s, v7.4s}, [%[c]], %[step] \n\t"
"st1 {v8.4s, v9.4s, v10.4s}, [%[c]], %[step] \n\t"
"st1 {v11.4s, v12.4s, v13.4s}, [%[c]], %[step] \n\t"
"st1 {v14.4s, v15.4s, v16.4s}, [%[c]], %[step] \n\t"
"st1 {v17.4s, v18.4s, v19.4s}, [%[c]], %[step] \n\t"
"st1 {v20.4s, v21.4s, v22.4s}, [%[c]], %[step] \n\t"
"st1 {v23.4s, v24.4s, v25.4s}, [%[c]], %[step] \n\t"
"st1 {v26.4s, v27.4s, v28.4s}, [%[c]], %[step] \n\t"
:
: [a_ptr] "r"(a_ptr), [b_ptr] "r"(b_ptr), [c] "r"(c), [kc1] "r"(kc1),
[step] "r"(step)
: "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9",
"v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19",
"v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28");
}
void AddDot6x16(int k, const float *a, const float *b, float *c, int ldc) {
const float *a_ptr, *b_ptr;
a_ptr = a;
b_ptr = b;
int kc1 = k;
int step = 4 * ldc;
int step1 = 4 * 6;
asm volatile(
"dup v6.4s, wzr \n\t"
"dup v7.4s, wzr \n\t"
"dup v8.4s, wzr \n\t"
"dup v9.4s, wzr \n\t"
"dup v10.4s, wzr \n\t"
"dup v11.4s, wzr \n\t"
"dup v12.4s, wzr \n\t"
"dup v13.4s, wzr \n\t"
"dup v14.4s, wzr \n\t"
"dup v15.4s, wzr \n\t"
"dup v16.4s, wzr \n\t"
"dup v17.4s, wzr \n\t"
"dup v18.4s, wzr \n\t"
"dup v19.4s, wzr \n\t"
"dup v20.4s, wzr \n\t"
"dup v21.4s, wzr \n\t"
"dup v22.4s, wzr \n\t"
"dup v23.4s, wzr \n\t"
"dup v24.4s, wzr \n\t"
"dup v25.4s, wzr \n\t"
"dup v26.4s, wzr \n\t"
"dup v27.4s, wzr \n\t"
"dup v28.4s, wzr \n\t"
"dup v29.4s, wzr \n\t"
"subs %[kc1], %[kc1], #1 \n\t"
"blt end_kc1_%= \n\t"
"loop_kc1_%=: \n\t"
"prfm pldl1keep, [%[a_ptr], #24] \n\t"
"prfm pldl1keep, [%[b_ptr], #64] \n\t"
"ld1 {v0.4s, v1.4s}, [%[a_ptr]], %[step1] \n\t"
"ld1 {v2.4s, v3.4s, v4.4s, v5.4s}, [%[b_ptr]], #64 \n\t"
"fmla v6.4s, v2.4s, v0.s[0] \n\t"
"fmla v7.4s, v3.4s, v0.s[0] \n\t"
"fmla v8.4s, v4.4s, v0.s[0] \n\t"
"fmla v9.4s, v5.4s, v0.s[0] \n\t"
"fmla v10.4s, v2.4s, v0.s[1] \n\t"
"fmla v11.4s, v3.4s, v0.s[1] \n\t"
"fmla v12.4s, v4.4s, v0.s[1] \n\t"
"fmla v13.4s, v5.4s, v0.s[1] \n\t"
"fmla v14.4s, v2.4s, v0.s[2] \n\t"
"fmla v15.4s, v3.4s, v0.s[2] \n\t"
"fmla v16.4s, v4.4s, v0.s[2] \n\t"
"fmla v17.4s, v5.4s, v0.s[2] \n\t"
"fmla v18.4s, v2.4s, v0.s[3] \n\t"
"fmla v19.4s, v3.4s, v0.s[3] \n\t"
"fmla v20.4s, v4.4s, v0.s[3] \n\t"
"fmla v21.4s, v5.4s, v0.s[3] \n\t"
"fmla v22.4s, v2.4s, v1.s[0] \n\t"
"fmla v23.4s, v3.4s, v1.s[0] \n\t"
"fmla v24.4s, v4.4s, v1.s[0] \n\t"
"fmla v25.4s, v5.4s, v1.s[0] \n\t"
"fmla v26.4s, v2.4s, v1.s[1] \n\t"
"fmla v27.4s, v3.4s, v1.s[1] \n\t"
"fmla v28.4s, v4.4s, v1.s[1] \n\t"
"fmla v29.4s, v5.4s, v1.s[1] \n\t"
"subs %[kc1], %[kc1], #1 \n\t"
"bge loop_kc1_%= \n\t"
"end_kc1_%=: \n\t"
"st1 {v6.4s, v7.4s, v8.4s, v9.4s}, [%[c]], %[step] \n\t"
"st1 {v10.4s, v11.4s, v12.4s, v13.4s}, [%[c]], %[step] \n\t"
"st1 {v14.4s, v15.4s, v16.4s, v17.4s}, [%[c]], %[step] \n\t"
"st1 {v18.4s, v19.4s, v20.4s, v21.4s}, [%[c]], %[step] \n\t"
"st1 {v22.4s, v23.4s, v24.4s, v25.4s}, [%[c]], %[step] \n\t"
"st1 {v26.4s, v27.4s, v28.4s, v29.4s}, [%[c]], %[step] \n\t"
:
: [a_ptr] "r"(a_ptr), [b_ptr] "r"(b_ptr), [c] "r"(c), [kc1] "r"(kc1),
[step] "r"(step), [step1] "r"(step1)
: "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9",
"v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19",
"v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29");
}
#endif // __aarch64__
} // namespace math
} // namespace operators
} // namespace paddle_mobile
......@@ -19,8 +19,13 @@ limitations under the License. */
#define B(i, j) B[(i)*ldb + (j)]
#define C(i, j) C[(i)*ldc + (j)]
#if __aarch64__
#define MR 6
#define NR 16
#else
#define MR 6
#define NR 8
#endif
#define s_min(i, j) ((i) < (j) ? (i) : (j))
......@@ -43,10 +48,16 @@ void PackMatrixA_4r(int m, int k, int m_tail, const float *A, int lda,
float *buffer);
void PackMatrixA_6r(int m, int k, int m_tail, const float *A, int lda,
float *buffer);
void PackMatrixA_8r(int m, int k, int m_tail, const float *A, int lda,
float *buffer);
// Pack blocks of matrix B into contiguous memory (RowMajor)
void PackMatrixB_8c(int k, int n, int n_tail, const float *B, int ldb,
float *buffer);
void PackMatrixB_12c(int k, int n, int n_tail, const float *B, int ldb,
float *buffer);
void PackMatrixB_16c(int k, int n, int n_tail, const float *B, int ldb,
float *buffer);
// Blocked matrix multiplication
void InnerKernel(int mc, int nc, float alpha, const float *a, const float *b,
......@@ -70,6 +81,8 @@ void VectorKernelWithBn(int m, int n, int k, float alpha, const float *A,
void AddDot4x4(int k, const float *a, const float *b, float *c, int ldc);
void AddDot4x8(int k, const float *a, const float *b, float *c, int ldc);
void AddDot6x8(int k, const float *a, const float *b, float *c, int ldc);
void AddDot8x12(int k, const float *a, const float *b, float *c, int ldc);
void AddDot6x16(int k, const float *a, const float *b, float *c, int ldc);
// Write the blocked matrix multiplication result back
// C = A * B
......@@ -114,10 +127,6 @@ void SgemmWithBn(int m, int n, int k, float alpha, const float *A, int lda,
const float *B, int ldb, float beta, float *C, int ldc,
bool relu, float *new_scale, float *new_bias);
// 64-bit double matrix multiplication
void dgemm(int m, int n, int k, float alpha, const double *A, int lda,
const double *B, int ldb, float beta, double *C, int ldc);
} // namespace math
} // namespace operators
} // namespace paddle_mobile
......@@ -262,11 +262,11 @@ class ElementwiseAddParam : OpParam {
#ifdef PADDLE_MOBILE_FPGA
private:
fpga::FpgaEWAddArgs fpga_EW_add_args;
fpga::EWAddArgs fpga_EW_add_args;
public:
const fpga::FpgaEWAddArgs &FpgaArgs() const { return fpga_EW_add_args; }
void SetFpgaArgs(const fpga::FpgaEWAddArgs &args) { fpga_EW_add_args = args; }
const fpga::EWAddArgs &FpgaArgs() const { return fpga_EW_add_args; }
void SetFpgaArgs(const fpga::EWAddArgs &args) { fpga_EW_add_args = args; }
#endif
};
......@@ -465,11 +465,11 @@ class PoolParam : public OpParam {
#ifdef PADDLE_MOBILE_FPGA
private:
fpga::FpgaPoolArgs fpga_pool_args;
fpga::PoolingArgs fpga_pool_args;
public:
const fpga::FpgaPoolArgs &FpgaArgs() const { return fpga_pool_args; }
void SetFpgaArgs(const fpga::FpgaPoolArgs &args) { fpga_pool_args = args; }
const fpga::PoolingArgs &FpgaArgs() const { return fpga_pool_args; }
void SetFpgaArgs(const fpga::PoolingArgs &args) { fpga_pool_args = args; }
#endif
};
#endif
......@@ -651,10 +651,10 @@ class MultiClassNMSParam : public OpParam {
class FeedParam : public OpParam {
public:
FeedParam(const VariableNameMap &inputs, const VariableNameMap &outputs,
const AttributeMap &attrs, Scope const &scope) {
input_x_ = InputXFrom<LoDTensor>(inputs, scope);
out_ = OutFrom<LoDTensor>(outputs, scope);
auto var = scope.Var("batch_size");
const AttributeMap &attrs, Scope *scope) {
input_x_ = InputXFrom<LoDTensor>(inputs, *scope);
out_ = OutFrom<LoDTensor>(outputs, *scope);
auto var = scope->Var("batch_size");
batch_size = var->GetValue<int>();
}
const Tensor *InputX() const { return input_x_; }
......@@ -933,11 +933,11 @@ class FusionFcParam : public OpParam {
#ifdef PADDLE_MOBILE_FPGA
private:
fpga::FpgaConvArgs fpga_conv_args;
fpga::ConvArgs fpga_conv_args;
public:
const fpga::FpgaConvArgs &FpgaArgs() const { return fpga_conv_args; }
void SetFpgaArgs(const fpga::FpgaConvArgs &args) { fpga_conv_args = args; }
const fpga::ConvArgs &FpgaArgs() const { return fpga_conv_args; }
void SetFpgaArgs(const fpga::ConvArgs &args) { fpga_conv_args = args; }
#endif
};
......@@ -991,11 +991,11 @@ class FusionConvAddParam : public OpParam {
#ifdef PADDLE_MOBILE_FPGA
private:
fpga::FpgaConvArgs fpga_conv_args;
fpga::ConvArgs fpga_conv_args;
public:
const fpga::FpgaConvArgs &FpgaArgs() const { return fpga_conv_args; }
void SetFpgaArgs(const fpga::FpgaConvArgs &args) { fpga_conv_args = args; }
const fpga::ConvArgs &FpgaArgs() const { return fpga_conv_args; }
void SetFpgaArgs(const fpga::ConvArgs &args) { fpga_conv_args = args; }
#endif
};
......@@ -1096,11 +1096,11 @@ class FusionConvAddBNReluParam : public OpParam {
#ifdef PADDLE_MOBILE_FPGA
private:
fpga::FpgaConvArgs fpga_conv_args;
fpga::ConvArgs fpga_conv_args;
public:
const fpga::FpgaConvArgs &FpgaArgs() const { return fpga_conv_args; }
void SetFpgaArgs(const fpga::FpgaConvArgs &args) { fpga_conv_args = args; }
const fpga::ConvArgs &FpgaArgs() const { return fpga_conv_args; }
void SetFpgaArgs(const fpga::ConvArgs &args) { fpga_conv_args = args; }
#endif
};
#endif
......@@ -1190,11 +1190,11 @@ class FusionConvAddBNParam : public OpParam {
#ifdef PADDLE_MOBILE_FPGA
private:
fpga::FpgaConvArgs fpga_conv_args;
fpga::ConvArgs fpga_conv_args;
public:
const fpga::FpgaConvArgs &FpgaArgs() const { return fpga_conv_args; }
void SetFpgaArgs(const fpga::FpgaConvArgs &args) { fpga_conv_args = args; }
const fpga::ConvArgs &FpgaArgs() const { return fpga_conv_args; }
void SetFpgaArgs(const fpga::ConvArgs &args) { fpga_conv_args = args; }
#endif
};
#endif
......
......@@ -114,8 +114,12 @@ else ()
target_link_libraries(test-softmax paddle-mobile)
# gen test
ADD_EXECUTABLE(test-gemm common/test_gemm.cpp)
target_link_libraries(test-gemm paddle-mobile)
ADD_EXECUTABLE(test-gemm-accuracy common/test_gemm_accuracy.cpp)
target_link_libraries(test-gemm-accuracy paddle-mobile)
# gen test
ADD_EXECUTABLE(test-gemm-perf common/test_gemm_perf.cpp)
target_link_libraries(test-gemm-perf paddle-mobile)
# gen test
ADD_EXECUTABLE(test-enforce common/test_enforce.cpp)
......
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <iostream>
#include "../test_helper.h"
#include "operators/math/gemm.h"
#include "operators/math/math_function.h"
#define a(i, j) a[(i)*lda + (j)]
#define b(i, j) b[(i)*ldb + (j)]
#define c1(i, j) c1[(i)*ldc + (j)]
#define m 1024
#define n 1024
#define k 1024
int main() {
Tensor aa, bb, cc, scale, bias;
auto aaptr = aa.mutable_data<float>({m, k});
auto bbptr = bb.mutable_data<float>({k, n});
auto ccptr = cc.mutable_data<float>({m, n});
auto scaleptr = scale.mutable_data<float>({m});
auto biasptr = bias.mutable_data<float>({m});
for (int i = 0; i < m * k; ++i) {
aaptr[i] = 2;
}
for (int i = 0; i < k * n; ++i) {
bbptr[i] = 2;
}
for (int i = 0; i < m * n; ++i) {
ccptr[i] = 2;
}
for (int i = 0; i < m; ++i) {
scaleptr[i] = 1;
biasptr[i] = 0;
}
auto time1 = time();
for (int j = 0; j < 10; ++j) {
paddle_mobile::operators::math::matmul<float>(aa, false, bb, false,
static_cast<float>(1), &cc,
static_cast<float>(0), false);
// paddle_mobile::operators::math::matmulWithBn<float>(
// aa, false, bb, false, static_cast<float>(1), &cc,
// static_cast<float>(0), true, &scale, &bias, 0);
}
auto time2 = time();
std::cout << "gemm cost :" << time_diff(time1, time2) / 10 << "ms\n";
return 0;
}
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <fstream>
#include "../test_helper.h"
#include "../test_include.h"
int main() {
paddle_mobile::PaddleMobile<paddle_mobile::CPU> paddle_mobile;
bool optimize = false;
auto time1 = time();
if (paddle_mobile.Load(g_googlenet, optimize)) {
auto time2 = time();
DLOG << "load cost: " << time_diff(time1, time2) << "ms";
std::vector<float> input;
std::vector<int64_t> dims{1, 3, 224, 224};
GetInput<float>(g_test_image_1x3x224x224, &input, dims);
auto time3 = time();
auto vec_result = paddle_mobile.Predict(input, dims);
auto time4 = time();
DLOG << "predict cost :" << time_diff(time3, time4) << "ms";
}
return 0;
}
......@@ -12,6 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <string>
#include "../test_helper.h"
#include "io/loader.h"
......@@ -20,12 +22,10 @@ int main() {
// ../../../test/models/googlenet
// ../../../test/models/mobilenet
// auto program = loader.Load(g_googlenet, true);
// auto program = loader.Load(g_mobilenet_ssd, true);
auto program = loader.Load(g_mobilenet_ssd, true);
// auto program = loader.Load(g_googlenet_combine + "/model",
// g_googlenet_combine +
// "/params", true);
auto program = loader.Load(std::string(g_ocr) + "/model",
std::string(g_ocr) + "/params", false);
// program.originProgram->Description("program desc: ");
return 0;
}
......@@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <fstream>
#include <iostream>
#include "../test_helper.h"
#include "../test_include.h"
......@@ -23,15 +23,20 @@ int main() {
auto time1 = time();
if (paddle_mobile.Load(g_googlenet, optimize)) {
auto time2 = time();
DLOG << "load cost: " << time_diff(time1, time1) << "ms";
std::cout << "load cost :" << time_diff(time1, time2) << "ms" << std::endl;
std::vector<float> input;
std::vector<int64_t> dims{1, 3, 224, 224};
GetInput<float>(g_test_image_1x3x224x224, &input, dims);
// Warm up once
auto vec_result = paddle_mobile.Predict(input, dims);
auto time3 = time();
for (int i = 0; i < 10; ++i) {
auto vec_result = paddle_mobile.Predict(input, dims);
}
auto time4 = time();
DLOG << "predict cost :" << time_diff(time3, time4) << "ms";
std::cout << "predict cost :" << time_diff(time3, time4) / 10 << "ms"
<< std::endl;
}
return 0;
}
......@@ -32,10 +32,14 @@ int main() {
std::vector<int64_t> dims{1, 3, 300, 300};
GetInput<float>(g_hand, &input, dims);
// Warm up once
auto output = paddle_mobile.Predict(input, dims);
auto time3 = time();
for (int i = 0; i < 10; ++i) {
auto output = paddle_mobile.Predict(input, dims);
}
auto time4 = time();
std::cout << "predict cost :" << time_diff(time3, time4) << "ms"
std::cout << "predict cost :" << time_diff(time3, time4) / 10 << "ms"
<< std::endl;
}
return 0;
......
......@@ -26,19 +26,22 @@ int main() {
std::vector<float> input;
std::vector<int64_t> dims{1, 3, 224, 224};
GetInput<float>(g_test_image_1x3x224x224, &input, dims);
GetInput<float>(g_test_image_1x3x224x224_banana, &input, dims);
for (int i = 0; i < 10; ++i) {
auto time3 = time();
// Warm up once
auto vec_result = paddle_mobile.Predict(input, dims);
auto time4 = time();
std::vector<float>::iterator biggest =
std::max_element(std::begin(vec_result), std::end(vec_result));
std::cout << " Max element is " << *biggest << " at position "
<< std::distance(std::begin(vec_result), biggest) << std::endl;
std::cout << "predict cost :" << time_diff(time3, time4) << "ms"
<< std::endl;
auto time3 = time();
for (int i = 0; i < 10; ++i) {
auto vec_result = paddle_mobile.Predict(input, dims);
}
auto time4 = time();
std::cout << "predict cost :" << time_diff(time3, time4) / 10 << "ms"
<< std::endl;
}
return 0;
......
......@@ -24,6 +24,7 @@ limitations under the License. */
#include "framework/ddim.h"
#include "framework/tensor.h"
static const char *g_ocr = "../models/ocr";
static const char *g_mobilenet_ssd = "../models/mobilenet+ssd";
static const char *g_mobilenet_ssd_gesture = "../models/mobilenet+ssd_gesture";
static const char *g_squeezenet = "../models/squeezenet";
......
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <iostream>
#include <cstdlib>
#include <ctime>
#include "../test_helper.h"
#include "common/log.h"
#include "memory/t_malloc.h"
#include "operators/math/gemm.h"
#define a(i, j) a[(i)*lda + (j)]
#define b(i, j) b[(i)*ldb + (j)]
#define c(i, j) c[(i)*ldc + (j)]
#define c1(i, j) c1[(i)*ldc + (j)]
void print_matirx(int m, int n, int ldc, float *c) {
for (int i = 0; i < m; ++i) {
std::cout << c(i, 0);
for (int j = 1; j < n; ++j) {
std::cout << " | " << c(i, j);
}
std::cout << std::endl;
}
std::cout << std::endl;
}
int do_sgemm(int m, int n, int k, bool relu, int t1, int t2, int pr) {
int lda = k;
int ldb = n;
int ldc = n;
float *a = static_cast<float *>(paddle_mobile::memory::Alloc(sizeof(float) * m * k));
float *b = static_cast<float *>(paddle_mobile::memory::Alloc(sizeof(float) * k * n));
float *c = static_cast<float *>(paddle_mobile::memory::Alloc(sizeof(float) * m * n));
float *c1 = static_cast<float *>(paddle_mobile::memory::Alloc(sizeof(float) * m * n));
float* scale = static_cast<float *>(paddle_mobile::memory::Alloc(sizeof(float) * m));
float* bias = static_cast<float *>(paddle_mobile::memory::Alloc(sizeof(float) * m));
srand(unsigned(time(0)));
for (int i = 0; i < m * k; ++i) {
a[i] = t1 + rand() % t2;
}
for (int i = 0; i < k * n; ++i) {
b[i] = t1 + rand() % t2;
}
for (int i = 0; i < m; ++i) {
scale[i] = t1 + rand() % t2;
}
for (int i = 0; i < m; ++i) {
bias[i] = t1 + rand() % t2;
}
for (int i = 0; i < m; ++i) {
for (int j = 0; j < n; ++j) {
float r = 0;
for (int p = 0; p < k; p++) {
r += a(i, p) * b(p, j);
}
r *= scale[i];
r += bias[i];
if (relu && (r < 0)) {
r = 0;
}
c1(i, j) = r;
}
}
paddle_mobile::operators::math::SgemmWithBn(m, n, k, 0.9, a, lda,
b, ldb, 0.3, c, ldc, relu, scale, bias);
int eq = 0;
int neq = 0;
for (int i = 0; i < m * n; ++i) {
if (static_cast<int>(c[i]) == static_cast<int>(c1[i])) {
++eq;
} else {
++neq;
}
}
if (pr > 0) {
std::cout << "A:" << std::endl;
print_matirx(m, k, lda, a);
std::cout << "B:" << std::endl;
print_matirx(k, n, ldb, b);
std::cout << "C:" << std::endl;
print_matirx(m, n, ldc, c);
std::cout << "C1:" << std::endl;
print_matirx(m, n, ldc, c1);
}
std::cout << "mnk=" << m << " " << n << " " << k <<
" relu=" << relu <<
" eq=" << eq << " neq=" << neq << std::endl;
paddle_mobile::memory::Free(a);
paddle_mobile::memory::Free(b);
paddle_mobile::memory::Free(c);
paddle_mobile::memory::Free(c1);
paddle_mobile::memory::Free(scale);
paddle_mobile::memory::Free(bias);
return 0;
}
int main() {
do_sgemm(9, 9, 9, true, 10, 10, 10);
do_sgemm(10, 6, 12, false, 10, 10, 0);
do_sgemm(512, 256, 384, false, 10, 10, 0);
do_sgemm(1366, 768, 256, false, 10, 10, 0);
do_sgemm(1255, 755, 333, false, 10, 10, 0);
do_sgemm(555, 777, 999, false, 10, 10, 0);
do_sgemm(10, 6, 12, true, -4, 10, 0);
do_sgemm(512, 256, 384, true, -4, 10, 0);
do_sgemm(1366, 768, 256, true, -4, 10, 0);
do_sgemm(1255, 755, 333, true, -4, 10, 0);
do_sgemm(555, 777, 999, true, -4, 10, 0);
return 0;
}
......@@ -40,8 +40,8 @@ build_for_android() {
fi
if [ -z "$PLATFORM" ]; then
# PLATFORM="arm-v7a" # Users could choose "arm-v8a" platform.
PLATFORM="arm-v8a"
PLATFORM="arm-v7a" # Users could choose "arm-v8a" platform.
# PLATFORM="arm-v8a"
fi
if [ "${PLATFORM}" = "arm-v7a" ]; then
......
......@@ -3,8 +3,8 @@
#include "src/enforce.h"
#include "src/var_desc.h"
#include "src/program_desc.h"
#include <cstring>
#include <cstdlib>
#include <string>
#include <cmath>
#include <iostream>
#include <utility>
......@@ -13,7 +13,7 @@
#include "src/protobuf-c.h"
#include <fstream>
#include <iostream>
#include <limits>
const size_t kSize64 = sizeof(uint64_t);
const size_t kSize32 = sizeof(uint32_t);
......@@ -68,60 +68,60 @@ std::shared_ptr<ProgramDesc> loadParams(const std::string &model_path) {
}
void LoadWithDump(const paddle_mobile::framework::VarDesc &var_desc, char *dataP, FILE *out_file) {
void LoadWithDump(const paddle_mobile::framework::VarDesc &var_desc, char **dataP, FILE *out_file) {
// 1. version
uint32_t version = *reinterpret_cast<uint32_t *>(dataP);
uint32_t version = *reinterpret_cast<uint32_t *>(*dataP);
// write version
fwrite(&version, kSize32, 1, out_file);
dataP += kSize32;
*dataP += kSize32;
// 2 Lod information
auto *lod_level_ptr = new uint64_t();
memcpy(lod_level_ptr, dataP, kSize64);
memcpy(lod_level_ptr, *dataP, kSize64);
uint64_t lod_level = 0;
// write lod Information
fwrite(&lod_level, kSize64, 1, out_file);
delete lod_level_ptr;
dataP += kSize64;
*dataP += kSize64;
for (uint64_t i = 0; i < lod_level; ++i) {
uint64_t size = *reinterpret_cast<uint64_t *>(dataP);
uint64_t size = *reinterpret_cast<uint64_t *>(*dataP);
// write lod size
fwrite(&size, kSize64, 1, out_file);
(dataP) += kSize64;
(*dataP) += kSize64;
std::vector<size_t> tmp(size / sizeof(size_t));
for (unsigned long &k : tmp) {
k = *reinterpret_cast<size_t *>(dataP);
(dataP) += sizeof(size_t);
k = *reinterpret_cast<size_t *>(*dataP);
(*dataP) += sizeof(size_t);
}
// write lod size vector
fwrite(tmp.data(), sizeof(size_t), tmp.size(), out_file);
}
// 3. tensor version
uint32_t tensor_version = *reinterpret_cast<uint32_t *>(dataP);
uint32_t tensor_version = *reinterpret_cast<uint32_t *>(*dataP);
// write tensor version
fwrite(&tensor_version, kSize32, 1, out_file);
(dataP) += kSize32;
(*dataP) += kSize32;
// 4. tensor desc
int32_t size = *reinterpret_cast<int32_t *>(dataP);
int32_t size = *reinterpret_cast<int32_t *>(*dataP);
// write tensor desc
fwrite(&size, sizeof(int32_t), 1, out_file);
(dataP) += sizeof(int32_t);
(*dataP) += sizeof(int32_t);
std::unique_ptr<char[]> buf(new char[size]);
for (int m = 0; m < size; ++m) {
buf.get()[m] = (dataP)[m];
buf.get()[m] = (*dataP)[m];
}
fwrite(buf.get(), sizeof(char), static_cast<size_t>(size), out_file);
(dataP) += (sizeof(char) * size);
(*dataP) += (sizeof(char) * size);
const paddle_mobile::framework::TensorDesc &desc = var_desc.Tensor_desc();
int memory_size = 1;
......@@ -158,9 +158,9 @@ void LoadWithDump(const paddle_mobile::framework::VarDesc &var_desc, char *dataP
memory = new char[tensorSize];
for (int n = 0; n < tensorSize; ++n) {
static_cast<char *>(memory)[n] = (dataP)[n];
static_cast<char *>(memory)[n] = (*dataP)[n];
}
dataP += tensorSize;
*dataP += tensorSize;
// for float 32
float min_value = std::numeric_limits<float>::max();
......@@ -194,7 +194,7 @@ quantificate_combined(const std::string &model_path, const std::string &param_pa
if (var_desc->Name() == "feed" || var_desc->Name() == "fetch") {
continue;
}
LoadWithDump(*var_desc, data, out_file);
LoadWithDump(*var_desc, &data, out_file);
}
}
}
......@@ -220,7 +220,7 @@ void quantificate_seperated(const std::string model_dir, const std::string param
FILE *out_file = fopen(file_name.c_str(), "wb");
char *origin_data = Get_binary_data(model_dir + "/" + var_desc->Name());
char *data = origin_data;
LoadWithDump(*var_desc, data, out_file);
LoadWithDump(*var_desc, &data, out_file);
delete origin_data;
fclose(out_file);
}
......
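
The signature change from char *dataP to char **dataP is the substance of this hunk: quantificate_combined calls LoadWithDump once per variable against a single parameter buffer, so the read cursor has to survive across calls; with a by-value pointer, every variable was read from the start of the buffer. A minimal illustration of the difference (hypothetical helpers):

// By-value pointer: only the callee's local copy advances.
void consume_by_value(char *p) { p += 4; }
// Pointer-to-pointer: the caller's cursor advances too.
void consume_by_ref(char **p) { *p += 4; }

char buf[16] = {};
char *cursor = buf;
consume_by_value(cursor);  // cursor still points at buf
consume_by_ref(&cursor);   // cursor now points at buf + 4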
......@@ -19,6 +19,7 @@ limitations under the License. */
#ifndef TOOLS_QUANTIFICATION_SRC_BLOCK_DESC_LOCAL_H_
#define TOOLS_QUANTIFICATION_SRC_BLOCK_DESC_LOCAL_H_
#include <memory>
#include <vector>
#include "src/var_desc.h"
......