Commit 7c0bbbc5 authored by xiebaiyuan

Merge remote-tracking branch 'upstream/develop' into develop

@@ -26,22 +26,6 @@ Paddle-Mobile is a project under the PaddlePaddle organization, dedicated to embedded platforms
 - **ARM CPU**
-
-|mobilenet arm v7|1 thread|2 threads|4 threads|
-|------------|----|-----|-----|
-|Kirin 960 (ms)|110.586|63.285|38.215|
-
-|mobilenetssd arm v7|1 thread|2 threads|4 threads|
-|------------|----|-----|-----|
-|Kirin 960 (ms)|220.248|128.473|79.334|
-
-|googlenet (v1) arm v7|1 thread|2 threads|4 threads|
-|------------|----|-----|-----|
-|Kirin 960 (ms)|341.965|228.724|161.531|
-
-|squeezenet arm v7|1 thread|2 threads|4 threads|
-|------------|----|-----|-----|
-|Kirin 960 (ms)|84.080|55.641|37.182|
-
-|yolo arm v7|1 thread|2 threads|4 threads|
-|------------|----|-----|-----|
-|Kirin 960 (ms)|129.445|80.627|50.936|
 The ARM CPU is paddle-mobile's primary target, and the CPU's generality has always been its advantage. Embedded deep learning needs a great deal of hand-written CPU assembly, and we are coding flat out to wring every bit of acceleration out of the hardware.
 ARM CPU optimization is still in progress; only conventional CPU optimizations are in place so far. On an ARM A73, paddle-mobile arm-v7 currently runs one MobileNet 1.0 inference in 110+ ms on a single core. That is clearly not our final target: we are rewriting the hot paths in assembly, and there is still large headroom. Only armv7 is supported at the moment; armv8 support will follow.
......
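The "conventional CPU optimization" mentioned above is, in practice, NEON vectorization of the hot loops; hand-written assembly then buys further gains mainly through register allocation and instruction scheduling. A minimal illustration of that style of optimization (not code from this commit; `scale_add_neon` is a hypothetical name, and an ARM toolchain with `<arm_neon.h>` is assumed):

```cpp
#include <arm_neon.h>
#include <cstddef>

// Illustrative only: y[i] += alpha * x[i], four floats per NEON instruction.
// This is the kind of inner loop a mobilenet convolution boils down to.
void scale_add_neon(const float *x, float *y, float alpha, size_t n) {
  float32x4_t valpha = vdupq_n_f32(alpha);   // broadcast alpha to 4 lanes
  size_t i = 0;
  for (; i + 4 <= n; i += 4) {
    float32x4_t vx = vld1q_f32(x + i);       // load 4 inputs
    float32x4_t vy = vld1q_f32(y + i);       // load 4 accumulators
    vy = vmlaq_f32(vy, vx, valpha);          // vy += vx * alpha, per lane
    vst1q_f32(y + i, vy);                    // store 4 results
  }
  for (; i < n; ++i) y[i] += alpha * x[i];   // scalar tail
}
```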
@@ -27,7 +27,12 @@ limitations under the License. */
 #include <cstdio>
 #include <cstring>
-#include "fpga/api/fpga_api.h"
+#include "api.h"
+
+#define FPGA_TEST_MODE
+#ifdef FPGA_TEST_MODE
+#include "common/log.h"
+#endif
 namespace paddle_mobile {
 namespace fpga {
@@ -36,7 +41,11 @@ static int fd = -1;
 static const char *device_path = "/dev/fpgadrv0";
 static inline int do_ioctl(int req, const void *arg) {
+#ifdef PADDLE_MOBILE_OS_LINUX
   return ioctl(req, (unsigned int64_t)arg);
+#else
+  return -1;
+#endif
 }
 int open_device() {
@@ -48,26 +57,110 @@ int open_device() {
 // memory management;
 void *fpga_malloc(size_t size) {
+#ifdef PADDLE_MOBILE_OS_LINUX
   return reinterpret_cast<void *>(
       mmap64(NULL, size, PROT_READ | PROT_WRITE, MAP_SHARED, fd, 0));
+#else
+  return malloc(size);
+#endif
 }
-void fpga_free(void *ptr) { munmap(ptr, 0); }
+void fpga_free(void *ptr) {
+#ifdef PADDLE_MOBILE_OS_LINUX
+  munmap(ptr, 0);
+#else
+  free(ptr);
+#endif
+}
 void fpga_copy(void *dest, const void *src, size_t num) {
   memcpy(dest, src, num);
 }
 int ComputeFpgaConv(const struct ConvArgs &args) {
+#ifdef FPGA_TEST_MODE
+  DLOG << " relu_enabled:" << args.relu_enabled
+       << " sb_address:" << args.sb_address
+       << " filter_address:" << args.filter_address
+       << " filter_num:" << args.filter_num
+       << " group_num:" << args.group_num;
+  DLOG << " image_address:" << args.image.address
+       << " image_scale_address:" << args.image.scale_address
+       << " image_channels:" << args.image.channels
+       << " image_height:" << args.image.height
+       << " image_width:" << args.image.width
+       << " pad_height:" << args.image.pad_height
+       << " pad_width:" << args.image.pad_width;
+  DLOG << " kernel_height:" << args.kernel.height
+       << " kernel_width:" << args.kernel.width
+       << " stride_h:" << args.kernel.stride_h
+       << " stride_w:" << args.kernel.stride_w;
+  DLOG << " out_address:" << args.output.address
+       << " out_scale_address:" << args.output.scale_address;
+#endif
   return do_ioctl(IOCTL_CONFIG_CONV, &args);
 }
 int ComputeFpgaPool(const struct PoolingArgs &args) {
+#ifdef FPGA_TEST_MODE
+  DLOG << " image_address:" << args.image.address
+       << " image_scale_address:" << args.image.scale_address
+       << " image_channels:" << args.image.channels
+       << " image_height:" << args.image.height
+       << " image_width:" << args.image.width
+       << " pad_height:" << args.image.pad_height
+       << " pad_width:" << args.image.pad_width;
+  DLOG << " kernel_height:" << args.kernel.height
+       << " kernel_width:" << args.kernel.width
+       << " stride_h:" << args.kernel.stride_h
+       << " stride_w:" << args.kernel.stride_w;
+  DLOG << " out_address:" << args.output.address
+       << " out_scale_address:" << args.output.scale_address;
+#endif
   return do_ioctl(IOCTL_CONFIG_POOLING, &args);
 }
 int ComputeFpgaEWAdd(const struct EWAddArgs &args) {
+#ifdef FPGA_TEST_MODE
+  DLOG << " relu_enabled:" << args.relu_enabled << " const0:" << args.const0
+       << " const1:" << args.const1;
+  DLOG << " image0_address:" << args.image0.address
+       << " image0_scale_address:" << args.image0.scale_address
+       << " image0_channels:" << args.image0.channels
+       << " image0_height:" << args.image0.height
+       << " image0_width:" << args.image0.width
+       << " pad0_height:" << args.image0.pad_height
+       << " pad0_width:" << args.image0.pad_width;
+  DLOG << " image1_address:" << args.image1.address
+       << " image1_scale_address:" << args.image1.scale_address
+       << " image1_channels:" << args.image1.channels
+       << " image1_height:" << args.image1.height
+       << " image1_width:" << args.image1.width
+       << " pad1_height:" << args.image1.pad_height
+       << " pad1_width:" << args.image1.pad_width;
+  DLOG << " out_address:" << args.output.address
+       << " out_scale_address:" << args.output.scale_address;
+#endif
   return do_ioctl(IOCTL_CONFIG_EW, &args);
 }
 int PerformBypass(const struct BypassArgs &args) {
+#ifdef FPGA_TEST_MODE
+  DLOG << " layout_type:" << args.layout_type
+       << " convert_type:" << args.convert_type;
+  DLOG << " image_address:" << args.image.address
+       << " image_scale_address:" << args.image.scale_address
+       << " image_channels:" << args.image.channels
+       << " image_height:" << args.image.height
+       << " image_width:" << args.image.width
+       << " pad_height:" << args.image.pad_height
+       << " pad_width:" << args.image.pad_width;
+  DLOG << " out_address:" << args.output.address
+       << " out_scale_address:" << args.output.scale_address;
+#endif
   return do_ioctl(IOCTL_CONFIG_BYPASS, &args);
 }
......
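Every function in this file follows the same compile-time fallback: with `PADDLE_MOBILE_OS_LINUX` defined, the entry points talk to the driver (`ioctl`/`mmap64` on `/dev/fpgadrv0`); otherwise they degrade to plain host implementations so the library still links and can be exercised off-device. A self-contained sketch of that pattern, with hypothetical names (`device_malloc`/`device_free` are not project API):

```cpp
#include <cstdio>
#include <cstdlib>

// Sketch of the compile-time fallback used throughout fpga_api above.
void *device_malloc(size_t size) {
#ifdef PADDLE_MOBILE_OS_LINUX
  // Real build: mmap64(NULL, size, PROT_READ | PROT_WRITE, MAP_SHARED, fd, 0)
  (void)size;
  return nullptr;  // elided: needs the opened driver fd
#else
  return std::malloc(size);  // host fallback
#endif
}

void device_free(void *ptr) {
#ifdef PADDLE_MOBILE_OS_LINUX
  // Real build: munmap(ptr, length)
  (void)ptr;
#else
  std::free(ptr);  // host fallback
#endif
}

int main() {
  void *p = device_malloc(64);
  std::printf("device_malloc -> %p\n", p);
  device_free(p);
  return 0;
}
```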
@@ -12,22 +12,20 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License. */
-#include "fpga/fpga_quantilization.h"
+#include "fpga/quantization.h"
 #include <algorithm>
 namespace paddle_mobile {
 namespace fpga {
 template <typename Dtype>
-static void chw_to_hwc(Dtype* data_in, Dtype* data_out, int num, int channel,
-                       int height, int width) {
-  int offset_height = 0;
+static void chw_to_hwc(Dtype* data_in, Dtype* data_out, int64_t num,
+                       int64_t channel, int64_t height, int64_t width) {
   for (int n = 0; n < num; n++) {
-    int amount_per_row = width * channel;
+    int64_t amount_per_row = width * channel;
     for (int c = 0; c < channel; c++) {
       for (int h = 0; h < height; h++) {
-        int offset_height = h * amount_per_row;
+        int64_t offset_height = h * amount_per_row;
         for (int w = 0; w < width; w++) {
           *(data_out + offset_height + w * channel + c) = *(data_in++);
         }
@@ -38,53 +36,56 @@ static void chw_to_hwc(Dtype* data_in, Dtype* data_out, int num, int channel,
 }
 template <typename Dtype>
-static Dtype find_max(Dtype* data, int num) {
+static Dtype find_max(Dtype* data, int64_t num) {
   Dtype max = 0;
   for (int i = 0; i < num; ++i) {
-    max = std::max(max, data[i]);
+    Dtype value = data[i];
+    Dtype abs = value > 0 ? value : -value;
+    max = std::max(max, abs);
   }
   return max;
 }
 // template <typename Dtype>
-void quantify_filter(framework::Tensor* filter) {
-  DLOG << "quantilize_filter........";
+void quantize_filter(framework::Tensor* filter) {
+  DLOG << "quantilize_filter........" << filter->dims();
   float scale = 0;
-  float fix_range = static_cast<float>((1 << (8 - 1)) - 1);
+  auto fix_range = static_cast<float>(std::pow(2, 8 - 1) - 1);
-  const int batch_size = filter->dims()[0];
-  const int channel = filter->dims()[1];
-  const int height = filter->dims()[2];
-  const int width = filter->dims()[3];
-  int8_t* int_data = nullptr;
-  int8_t* tmp_data = new int8_t[filter->numel()];
+  auto* tmp_data = new int8_t[filter->numel()];
   // 32bit filter -> 8bit filter;
   if (filter->type() == typeid(float)) {
-    float* float_data = filter->data<float>();
-    float max = find_max<float>(float_data, filter->numel());
-    scale = (max / fix_range);
+    auto* float_data = filter->data<float>();
+    auto max = find_max<float>(float_data, filter->numel());
+    scale = (fix_range / max);
+    DLOG << "scale:" << scale;
     for (int i = 0; i < filter->numel(); ++i) {
-      tmp_data[i] = (int8_t)float_data[i] * scale;
+      tmp_data[i] = (int8_t)(float_data[i] * scale);
     }
-    int_data = filter->mutable_data<int8_t>();
   } else {
-    int8_t max = find_max<int8_t>(filter->data<int8_t>(), filter->numel());
-    scale = (max / fix_range);
-    for (int i = 0; i < filter->numel(); ++i) {
-      tmp_data[i] = filter->data<int8_t>()[i];
-    }
-    int_data = filter->mutable_data<int8_t>();
+    auto max = find_max<int8_t>(filter->data<int8_t>(), filter->numel());
+    scale = (fix_range / max);
+    std::memcpy(tmp_data, filter->data<int8_t>(), (size_t)filter->numel());
   }
+  if (filter->dims().size() == 4) {
+    const auto batch_size = filter->dims()[0];
+    const auto channel = filter->dims()[1];
+    const auto height = filter->dims()[2];
+    const auto width = filter->dims()[3];
+    chw_to_hwc<int8_t>(tmp_data, filter->mutable_data<int8_t>(), batch_size,
+                       channel, height, width);
+  } else if (filter->dims().size() == 2) {
+    std::memcpy(filter->mutable_data<int8_t>(), tmp_data,
+                (size_t)filter->numel());
+  }
-  // NCHW -> NHWC;
-  chw_to_hwc<int8_t>(tmp_data, int_data, batch_size, channel, height, width);
   delete tmp_data;
-  *(filter->fpga_args().scale_pointer()) = scale;
+  filter->SetFpgaScale(scale);
 }
 }  // namespace fpga
......
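The substantive fixes in this file: `find_max` now compares absolute values, the scale is computed as `fix_range / max` (so `q = w * scale` maps max|w| to 127) instead of the inverted `max / fix_range`, and the cast is parenthesized as `(int8_t)(float_data[i] * scale)` rather than `(int8_t)float_data[i] * scale`, which truncated every weight in (-1, 1) to zero before multiplying. A minimal sketch of the corrected symmetric int8 scheme; `quantize_sketch` is a hypothetical name, and `std::round` is used for clarity where the committed code truncates:

```cpp
#include <algorithm>
#include <cmath>
#include <cstdint>
#include <cstdio>
#include <vector>

// Symmetric int8 quantization: q = round(w * scale), scale = 127 / max|w|.
std::vector<int8_t> quantize_sketch(const std::vector<float>& w,
                                    float* scale_out) {
  float max_abs = 0.f;
  for (float v : w) max_abs = std::max(max_abs, std::fabs(v));
  if (max_abs == 0.f) max_abs = 1.f;  // guard for an all-zero filter (assumption)
  float scale = 127.f / max_abs;      // fix_range / max, as in the fixed code
  std::vector<int8_t> q(w.size());
  for (size_t i = 0; i < w.size(); ++i) {
    // Parenthesization matters: (int8_t)(w * scale), not (int8_t)w * scale.
    q[i] = static_cast<int8_t>(std::round(w[i] * scale));
  }
  *scale_out = scale;
  return q;
}

int main() {
  float scale = 0.f;
  auto q = quantize_sketch({0.5f, -0.25f, 1.0f, -1.0f}, &scale);
  for (size_t i = 0; i < q.size(); ++i)
    std::printf("q=%4d  dequantized=%6.3f\n", q[i], q[i] / scale);
  return 0;
}
```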
@@ -21,11 +21,10 @@ namespace paddle_mobile {
 namespace fpga {
 template <typename Dtype>
-static void chw_to_hwc(Dtype* data_in, Dtype* data_out, int num, int channel,
-                       int height, int width);
+static void chw_to_hwc(Dtype* data_in, Dtype* data_out, int64_t num,
+                       int64_t channel, int64_t height, int64_t width);
-// template <typename Dtype>
-void quantify_filter(framework::Tensor* filter);
+void quantize_filter(framework::Tensor* filter);
 }  // namespace fpga
 }  // namespace paddle_mobile
@@ -64,7 +64,8 @@ struct SizeOfTypeFunctor<HEAD, TAIL...> {
 };
 static inline size_t SizeOfType(std::type_index type) {
-  SizeOfTypeFunctor<int, half, float, double, int16_t, int64_t, bool, size_t>
+  SizeOfTypeFunctor<int8_t, int, half, float, double, int16_t, int64_t, bool,
+                    size_t>
       functor;
   size_t size = functor(type);
@@ -115,8 +116,8 @@ class Tensor {
   PADDLE_MOBILE_ENFORCE(
       (std::is_same<T, void>::value ||
        holder_->type().hash_code() == typeid(T).hash_code()),
-      "Tensor holds the wrong type, it holds %s",
-      this->holder_->type().name());
+      "Tensor holds the wrong type, it holds %s ,requested:%s",
+      this->holder_->type().name(), typeid(T).name());
   return reinterpret_cast<const T *>(
       reinterpret_cast<uintptr_t>(holder_->ptr()) + offset_);
@@ -255,14 +256,26 @@ class Tensor {
 #ifdef PADDLE_MOBILE_FPGA
   struct FPGAArgs {
-    float scale;
-    inline float *scale_pointer() { return &scale; }
+    friend class Tensor;
+    inline float *scale_pointer() { return scale_; }
+    inline float scale() { return *scale_; }
+
+   private:
+    float *scale_;
   };
   struct FPGAArgs fpga_args() const {
-    return fpgaArgs_;
+    FPGAArgs args;
+    args.scale_ = scale.get();
+    return args;
   }
+
+  void SetFpgaScale(float s) { *(scale.get()) = s; }
+
+ private:
+  std::shared_ptr<float> scale = std::make_shared<float>(0);
 #endif
 private:
@@ -331,10 +344,6 @@ class Tensor {
    * begins.
    */
   size_t offset_;
-#ifdef PADDLE_MOBILE_FPGA
-  FPGAArgs fpgaArgs_;
-#endif
 };
 #ifdef PADDLE_MOBILE_DEBUG
@@ -342,9 +351,12 @@ inline Print &operator<<(Print &printer, const Tensor &tensor) {
   printer << " dims: " << tensor.dims() << "\n";
   int stride = tensor.numel() / 20;
   stride = stride > 0 ? stride : 1;
+#ifndef PADDLE_MOBILE_FPGA
   for (int i = 0; i < tensor.numel(); i += stride) {
     printer << tensor.data<float>()[i] << " ";
   }
+#endif
   return printer;
 }
......
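The `Tensor` change above replaces the by-value `FPGAArgs fpgaArgs_` member with scale storage held in a `std::shared_ptr<float>`; `fpga_args()` now returns a view whose `scale_pointer()` aliases that shared storage, so a scale written later (for example by `SetFpgaScale` from `quantize_filter`) stays visible through pointers captured earlier. A stripped-down sketch of that aliasing behavior, with illustrative class names:

```cpp
#include <cstdio>
#include <memory>

// Illustrative reduction of the Tensor/FPGAArgs change: the scale lives in
// shared storage; every args view aliases it rather than copying it.
class TensorSketch {
 public:
  struct ArgsView {
    float *scale_ptr;  // aliases the tensor's shared storage
  };
  ArgsView args() const { return ArgsView{scale_.get()}; }
  void SetScale(float s) { *scale_ = s; }

 private:
  std::shared_ptr<float> scale_ = std::make_shared<float>(0.f);
};

int main() {
  TensorSketch t;
  auto view = t.args();  // captured before the scale is known
  t.SetScale(0.5f);      // e.g. set later during filter quantization
  std::printf("%f\n", *view.scale_ptr);  // prints 0.500000, not 0
  return 0;
}
```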
@@ -18,7 +18,7 @@ limitations under the License. */
 #ifdef PADDLE_MOBILE_FPGA
-#include "fpga/api/fpga_api.h"
+#include "fpga/api.h"
 #endif
@@ -26,7 +26,7 @@ namespace paddle_mobile {
 namespace memory {
 const int MALLOC_ALIGN = 64;
-#ifdef PADDLE_MOBILE_FPGA__VV
+#ifdef PADDLE_MOBILE_FPGA
 namespace fpga = paddle_mobile::fpga;
 void Copy(void *dst, const void *src, size_t num) {
......
@@ -38,10 +38,15 @@ class FeedOp : public framework::OperatorBase<DeviceType> {
   }
 #ifdef PADDLE_MOBILE_FPGA
-  void RunImpl() const { fpga::PerformBypass(param_.FpgaArgs()); }
   void Init() {
+    Tensor *output = param_.Out();
+    output->mutable_data<half>();
+  }
+
+  void RunImpl() const {
     const Tensor *input = param_.InputX();
-    auto input_ptr = (const_cast<Tensor *>(input))->mutable_data<float>();
+    auto input_ptr = input->data<float>();
     Tensor *output = param_.Out();
     auto output_ptr = output->mutable_data<half>();
     fpga::BypassArgs args;
@@ -52,12 +57,12 @@ class FeedOp : public framework::OperatorBase<DeviceType> {
     args.image.height = input->dims()[2];
     args.image.width = input->dims()[3];
     args.output.address = output_ptr;
-    param_.SetFpgaArgs(args);
+    fpga::PerformBypass(args);
   }
 #else
-  void RunImpl() const { param_.Out()->ShareDataWith(*param_.InputX()); }
   void Init() {}
+  void RunImpl() const { param_.Out()->ShareDataWith(*param_.InputX()); }
 #endif
 protected:
......
@@ -15,8 +15,8 @@ limitations under the License. */
 #ifdef FUSION_CONVADDBN_OP
 #include "operators/kernel/conv_add_bn_kernel.h"
-#include "fpga/api/fpga_api.h"
-#include "fpga/fpga_quantilization.h"
+#include "fpga/api.h"
+#include "fpga/quantization.h"
 namespace paddle_mobile {
 namespace operators {
@@ -37,11 +37,11 @@ bool ConvAddBNKernel<FPGA, float>::Init(FusionConvAddBNParam *param) {
   auto bn_scale_ptr = param->InputScale()->data<float>();
   auto bn_bias_ptr = param->InputBias()->data<float>();
   const float epsilon = param->Epsilon();
-  PADDLE_MOBILE_ENFORCE(input->dims()[1] == bias->dims()[0] &&
+  PADDLE_MOBILE_ENFORCE(out->dims()[1] == bias->dims()[0] &&
                             bias->dims()[0] == param->InputBias()->dims()[0],
-                        "Image channel should be equal to bias number");
-  const int channel = input->dims()[1];
+                        "Output channel should be equal to bias number");
+  const int channel = out->dims()[1];
   float *bs_ptr =
       reinterpret_cast<float *>(fpga::fpga_malloc(2 * channel * sizeof(float)));
   Tensor *new_scale = new Tensor();
@@ -60,8 +60,8 @@ bool ConvAddBNKernel<FPGA, float>::Init(FusionConvAddBNParam *param) {
   param->SetNewScale(new_scale);
   param->SetNewBias(new_bias);
-  fpga::quantify_filter(filter);
-  auto filter_ptr = filter->data<float>();
+  fpga::quantize_filter(filter);
+  auto filter_ptr = filter->data<int8_t>();
   fpga::ConvArgs convArgs;
   convArgs.relu_enabled = relu_enabled;
......
@@ -15,7 +15,7 @@ limitations under the License. */
 #ifdef FUSION_CONVADDBNRELU_OP
 #include "operators/kernel/conv_add_bn_relu_kernel.h"
-#include "fpga/fpga_quantilization.h"
+#include "fpga/quantization.h"
 namespace paddle_mobile {
 namespace operators {
@@ -35,11 +35,11 @@ bool ConvAddBNReluKernel<FPGA, float>::Init(FusionConvAddBNReluParam *param) {
   auto bn_scale_ptr = param->InputScale()->data<float>();
   auto bn_bias_ptr = param->InputBias()->data<float>();
   const float epsilon = param->Epsilon();
-  PADDLE_MOBILE_ENFORCE(input->dims()[1] == bias->dims()[0] &&
+  PADDLE_MOBILE_ENFORCE(out->dims()[1] == bias->dims()[0] &&
                             bias->dims()[0] == param->InputBias()->dims()[0],
-                        "Image channel should be equal to bias number");
-  const int channel = input->dims()[1];
+                        "Output channel should be equal to bias number");
+  const int channel = out->dims()[1];
   float *bs_ptr = (float *)fpga::fpga_malloc(2 * channel * sizeof(float));
   Tensor *new_scale = new Tensor();
   Tensor *new_bias = new Tensor();
@@ -56,8 +56,8 @@ bool ConvAddBNReluKernel<FPGA, float>::Init(FusionConvAddBNReluParam *param) {
   }
   param->SetNewScale(new_scale);
   param->SetNewBias(new_bias);
-  fpga::quantify_filter(filter);
-  auto filter_ptr = filter->data<float>();
+  fpga::quantize_filter(filter);
+  auto filter_ptr = filter->data<int8_t>();
   fpga::ConvArgs convArgs;
   convArgs.relu_enabled = relu_enabled;
......
@@ -15,7 +15,7 @@ limitations under the License. */
 #ifdef FUSION_CONVADDRELU_OP
 #include "operators/kernel/conv_add_relu_kernel.h"
-#include "fpga/fpga_quantilization.h"
+#include "fpga/quantization.h"
 namespace paddle_mobile {
 namespace operators {
@@ -31,17 +31,17 @@ bool ConvAddReluKernel<FPGA, float>::Init(FusionConvAddReluParam *param) {
   Tensor *out = param->Output();
   auto out_ptr = out->mutable_data<half>();
-  PADDLE_MOBILE_ENFORCE(input->dims()[1] == bias->dims()[0],
-                        "Image channel should be equal to bias number");
-  int channel = input->dims()[1];
+  PADDLE_MOBILE_ENFORCE(out->dims()[1] == bias->dims()[0],
+                        "Output channel should be equal to bias number");
+  int channel = out->dims()[1];
   float *bs_ptr = (float *)fpga::fpga_malloc(2 * channel * sizeof(float));
   for (int i = 0; i < channel; i++) {
     bs_ptr[i * 2] = 1;
     bs_ptr[i * 2 + 1] = bias_ptr[i];
   }
-  fpga::quantify_filter(filter);
-  auto filter_ptr = filter->data<float>();
+  fpga::quantize_filter(filter);
+  auto filter_ptr = filter->data<int8_t>();
   fpga::ConvArgs convArgs;
   convArgs.relu_enabled = relu_enabled;
......
@@ -15,8 +15,8 @@ limitations under the License. */
 #ifdef FUSION_CONVBN_OP
 #include "operators/kernel/conv_bn_kernel.h"
-#include "fpga/api/fpga_api.h"
-#include "fpga/fpga_quantilization.h"
+#include "fpga/api.h"
+#include "fpga/quantization.h"
 namespace paddle_mobile {
 namespace operators {
@@ -35,10 +35,10 @@ bool ConvBNKernel<FPGA, float>::Init(FusionConvBNParam *param) {
   auto bn_scale_ptr = param->InputScale()->data<float>();
   auto bn_bias_ptr = param->InputBias()->data<float>();
   const float epsilon = param->Epsilon();
-  PADDLE_MOBILE_ENFORCE(input->dims()[1] == param->InputBias()->dims()[0],
-                        "Image channel should be equal to bias number");
-  const int channel = input->dims()[1];
+  PADDLE_MOBILE_ENFORCE(out->dims()[1] == param->InputBias()->dims()[0],
+                        "Output channel should be equal to bias number");
+  const int channel = out->dims()[1];
   float *bs_ptr =
       reinterpret_cast<float *>(fpga::fpga_malloc(2 * channel * sizeof(float)));
   Tensor *new_scale = new Tensor();
@@ -55,8 +55,8 @@ bool ConvBNKernel<FPGA, float>::Init(FusionConvBNParam *param) {
   }
   param->SetNewScale(new_scale);
   param->SetNewBias(new_bias);
-  fpga::quantify_filter(filter);
-  auto filter_ptr = filter->data<float>();
+  fpga::quantize_filter(filter);
+  auto filter_ptr = filter->data<int8_t>();
   fpga::ConvArgs convArgs;
   convArgs.relu_enabled = relu_enabled;
......
@@ -15,7 +15,7 @@ limitations under the License. */
 #ifdef FUSION_CONVBNRELU_OP
 #include "operators/kernel/conv_bn_relu_kernel.h"
-#include "fpga/fpga_quantilization.h"
+#include "fpga/quantization.h"
 namespace paddle_mobile {
 namespace operators {
@@ -33,10 +33,10 @@ bool ConvBNReluKernel<FPGA, float>::Init(FusionConvBNReluParam *param) {
   auto bn_scale_ptr = param->InputScale()->data<float>();
   auto bn_bias_ptr = param->InputBias()->data<float>();
   const float epsilon = param->Epsilon();
-  PADDLE_MOBILE_ENFORCE(input->dims()[1] == param->InputBias()->dims()[0],
-                        "Image channel should be equal to bias number");
-  const int channel = input->dims()[1];
+  PADDLE_MOBILE_ENFORCE(out->dims()[1] == param->InputBias()->dims()[0],
+                        "Output channel should be equal to bias number");
+  const int channel = out->dims()[1];
   float *bs_ptr = (float *)fpga::fpga_malloc(2 * channel * sizeof(float));
   Tensor *new_scale = new Tensor();
   Tensor *new_bias = new Tensor();
@@ -52,8 +52,8 @@ bool ConvBNReluKernel<FPGA, float>::Init(FusionConvBNReluParam *param) {
   }
   param->SetNewScale(new_scale);
   param->SetNewBias(new_bias);
-  fpga::quantify_filter(filter);
-  auto filter_ptr = filter->data<float>();
+  fpga::quantize_filter(filter);
+  auto filter_ptr = filter->data<int8_t>();
   fpga::ConvArgs convArgs;
   convArgs.relu_enabled = relu_enabled;
......
@@ -13,7 +13,9 @@ See the License for the specific language governing permissions and
 limitations under the License. */
 #ifdef FUSION_FCRELU_OP
 #include "operators/kernel/fc_relu_kernel.h"
-#include "fpga/api/fpga_api.h"
+
+#include "fpga/api.h"
+#include "fpga/quantization.h"
 namespace paddle_mobile {
 namespace operators {
@@ -23,8 +25,7 @@ bool FusionFcReluKernel<FPGA, float>::Init(FusionFcReluParam *param) {
   bool relu_enabled = true;
   const Tensor *input_x = param->InputX();
   auto input_x_ptr = input_x->data<half>();
-  const Tensor *input_y = param->InputY();
-  auto input_y_ptr = input_y->data<float>();
+  Tensor *input_y = param->InputY();
   const Tensor *input_z = param->InputZ();
   auto input_z_ptr = input_z->data<float>();
   Tensor *out = param->Out();
@@ -32,13 +33,16 @@ bool FusionFcReluKernel<FPGA, float>::Init(FusionFcReluParam *param) {
   PADDLE_MOBILE_ENFORCE(input_x->dims()[1] == input_y->dims()[0],
                         "Image channel should be equal to weight number");
-  int channel = input_x->dims()[1];
+  int channel = out->dims()[1];
   float *bs_ptr = (float *)fpga::fpga_malloc(2 * channel * sizeof(float));
   for (int i = 0; i < channel; i++) {
     bs_ptr[i * 2] = 1;
     bs_ptr[i * 2 + 1] = input_z_ptr[i];
   }
+
+  fpga::quantize_filter(input_y);
+  auto input_y_ptr = input_y->data<int8_t>();
+
   fpga::ConvArgs convArgs;
   convArgs.relu_enabled = relu_enabled;
   convArgs.filter_address = (void *)input_y_ptr;
......
@@ -14,6 +14,7 @@ limitations under the License. */
 #ifdef FUSION_FC_OP
 #include "operators/kernel/fusion_fc_kernel.h"
+#include "fpga/quantization.h"
 namespace paddle_mobile {
 namespace operators {
@@ -23,8 +24,7 @@ bool FusionFcKernel<FPGA, float>::Init(FusionFcParam *param) {
   bool relu_enabled = false;
   const Tensor *input_x = param->InputX();
   auto input_x_ptr = input_x->data<half>();
-  const Tensor *input_y = param->InputY();
-  auto input_y_ptr = input_y->data<float>();
+  Tensor *input_y = param->InputY();
   const Tensor *input_z = param->InputZ();
   auto input_z_ptr = input_z->data<float>();
   Tensor *out = param->Out();
@@ -32,13 +32,16 @@ bool FusionFcKernel<FPGA, float>::Init(FusionFcParam *param) {
   PADDLE_MOBILE_ENFORCE(input_x->dims()[1] == input_y->dims()[0],
                         "Image channel should be equal to weight number");
-  int channel = input_x->dims()[1];
+  int channel = out->dims()[1];
   float *bs_ptr = (float *)fpga::fpga_malloc(2 * channel * sizeof(float));
   for (int i = 0; i < channel; i++) {
     bs_ptr[i * 2] = 1;
     bs_ptr[i * 2 + 1] = input_z_ptr[i];
   }
+
+  fpga::quantize_filter(input_y);
+  auto input_y_ptr = input_y->data<int8_t>();
+
   fpga::ConvArgs convArgs;
   convArgs.relu_enabled = relu_enabled;
   convArgs.filter_address = (void *)input_y_ptr;
@@ -55,11 +58,9 @@ bool FusionFcKernel<FPGA, float>::Init(FusionFcParam *param) {
   convArgs.image.width = input_x->dims()[3];
   convArgs.image.pad_height = 0;
   convArgs.image.pad_width = 0;
-  convArgs.image.scale_address =
-      input_x->fpga_args().scale_pointer();  // fc input has scale attribute??
+  convArgs.image.scale_address = input_x->fpga_args().scale_pointer();
   convArgs.output.address = (void *)out_ptr;
-  convArgs.output.scale_address =
-      out->fpga_args().scale_pointer();  // fc output has scale attribute??
+  convArgs.output.scale_address = out->fpga_args().scale_pointer();
   param->SetFpgaArgs(convArgs);
   return true;
 }
......
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifdef SOFTMAX_OP
#include "../softmax_kernel.h"
#include "../central-arm-func/softmax_arm_func.h"
#include "common/types.h"
#include "fpga/api.h"
#include "operators/math/softmax.h"
namespace paddle_mobile {
namespace operators {
template <>
bool SoftmaxKernel<FPGA, float>::Init(SoftmaxParam *param) {
const Tensor *input = param->InputX();
if (input->type() == typeid(half)) {
auto input_ptr = input->data<half>();
auto output_ptr = param->Out();
fpga::BypassArgs args;
args.convert_type = fpga::DATA_FP16_TO_FP32;
args.layout_type = fpga::LAYOUT_HWC_TO_CHW;
args.image.address = (void *)(input_ptr);
args.image.height = input->dims()[0];
args.image.width = input->dims()[1];
args.image.channels = 1;
args.output.address = output_ptr;
param->SetFpgaArgs(args);
}
return true;
}
template <>
void SoftmaxKernel<FPGA, float>::Compute(const SoftmaxParam &param) const {
// SoftmaxCompute<float>(param);
}
template class SoftmaxKernel<FPGA, float>;
} // namespace operators
} // namespace paddle_mobile
#endif
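For reference, the CPU-side computation this kernel defers to (`SoftmaxCompute<float>`, currently commented out in `Compute` above) is a standard numerically stable softmax. A minimal sketch of that computation for a single row, assuming row-wise operation; this is not the paddle-mobile implementation itself:

```cpp
#include <algorithm>
#include <cmath>
#include <cstdio>
#include <vector>

// Numerically stable softmax: subtract the row max before exponentiating so
// large logits cannot overflow. Sketches what SoftmaxCompute<float> would do.
std::vector<float> softmax_sketch(const std::vector<float>& x) {
  float max_v = *std::max_element(x.begin(), x.end());
  std::vector<float> y(x.size());
  float sum = 0.f;
  for (size_t i = 0; i < x.size(); ++i) {
    y[i] = std::exp(x[i] - max_v);
    sum += y[i];
  }
  for (float& v : y) v /= sum;  // normalize to a probability distribution
  return y;
}

int main() {
  auto y = softmax_sketch({1.f, 2.f, 3.f});
  std::printf("%.4f %.4f %.4f\n", y[0], y[1], y[2]);  // ~0.0900 0.2447 0.6652
  return 0;
}
```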
@@ -23,7 +23,7 @@ limitations under the License. */
 #include "framework/tensor.h"
 #include "framework/variable.h"
 #ifdef PADDLE_MOBILE_FPGA
-#include "fpga/api/fpga_api.h"
+#include "fpga/api.h"
 #endif
 namespace paddle_mobile {
@@ -585,6 +585,21 @@ class SoftmaxParam : public OpParam {
  private:
  Tensor *input_x_;
  Tensor *out_;
+
+#ifdef PADDLE_MOBILE_FPGA
+ private:
+  std::shared_ptr<Tensor> float_input_x_;
+  fpga::BypassArgs fpga_bypass_args;
+
+ public:
+  Tensor *FloatInput() {
+    return float_input_x_ == nullptr ? input_x_ : float_input_x_.get();
+  }
+  void SetFloatInput(Tensor *input) { float_input_x_.reset(input); }
+  const fpga::BypassArgs &FpgaArgs() const { return fpga_bypass_args; }
+  void SetFpgaArgs(const fpga::BypassArgs &args) { fpga_bypass_args = args; }
+#endif
 };
 #endif
@@ -670,16 +685,6 @@ class FeedParam : public OpParam {
   Tensor *input_x_;
   Tensor *out_;
   int batch_size;
-
-#ifdef PADDLE_MOBILE_FPGA
- private:
-  fpga::BypassArgs fpga_bypass_args;
-
- public:
-  const fpga::BypassArgs &FpgaArgs() const { return fpga_bypass_args; }
-  void SetFpgaArgs(const fpga::BypassArgs &args) { fpga_bypass_args = args; }
-#endif
 };
 class FetchParam : public OpParam {
@@ -1143,7 +1148,6 @@ class FusionConvBNParam : public OpParam {
   FusionConvBNParam(const VariableNameMap &inputs,
                     const VariableNameMap &outputs, const AttributeMap &attrs,
                     const Scope &scope) {
-    axis_ = GetAttr<int>("axis", attrs);
     filter_ = FilterFrom<LoDTensor>(inputs, scope);
     input_ = InputFrom<LoDTensor>(inputs, scope);
     output_y_ = OutputYFrom<LoDTensor>(outputs, scope);
@@ -1160,8 +1164,6 @@ class FusionConvBNParam : public OpParam {
     // is_test_ = GetAttr<bool>("is_test", attrs);
   }
-  const int &Axis() const { return axis_; }
   const Tensor *Input() const { return input_; }
 #ifdef PADDLE_MOBILE_FPGA
@@ -1202,7 +1204,6 @@ class FusionConvBNParam : public OpParam {
   const Tensor *NewBias() const { return new_bias_; }
  protected:
-  int axis_;
   Tensor *input_;
   Tensor *output_y_;
   Tensor *filter_;
......
@@ -34,6 +34,7 @@ REGISTER_OPERATOR_CPU(softmax, ops::SoftmaxOp);
 REGISTER_OPERATOR_MALI_GPU(softmax, ops::SoftmaxOp);
 #endif
 #ifdef PADDLE_MOBILE_FPGA
+REGISTER_OPERATOR_FPGA(softmax, ops::SoftmaxOp);
 #endif
 #endif
@@ -55,6 +55,7 @@ USE_OP_CPU(softmax);
 USE_OP_MALI_GPU(softmax);
 #endif
 #ifdef PADDLE_MOBILE_FPGA
+USE_OP_FPGA(softmax);
 #endif
 #endif
@@ -27,10 +27,14 @@ elseif("resnet" IN_LIST NET)
     ADD_EXECUTABLE(test-resnet net/test_resnet.cpp test_helper.h test_include.h executor_for_test.h)
     target_link_libraries(test-resnet paddle-mobile)
 elseif("FPGAnets" IN_LIST NET)
-    # ADD_EXECUTABLE(test-resnet net/test_resnet.cpp test_helper.h test_include.h executor_for_test.h)
-    # target_link_libraries(test-resnet paddle-mobile)
+    ADD_EXECUTABLE(test-resnet net/test_resnet.cpp test_helper.h test_include.h executor_for_test.h)
+    target_link_libraries(test-resnet paddle-mobile)
     ADD_EXECUTABLE(test-tensor-quant fpga/test_tensor_quant.cpp test_helper.h test_include.h executor_for_test.h)
     target_link_libraries(test-tensor-quant paddle-mobile)
+
+    ADD_EXECUTABLE(test-fpga-concat-op fpga/test_concat_op.cpp test_helper.h test_include.h)
+    target_link_libraries(test-fpga-concat-op paddle-mobile)
 elseif("mobilenetssd" IN_LIST NET)
     # gen test
     ADD_EXECUTABLE(test-mobilenetssd net/test_mobilenet+ssd.cpp test_helper.h test_include.h executor_for_test.h)
......
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "../test_include.h"
#include "operators/concat_op.h"
int main() {
paddle_mobile::Loader<paddle_mobile::FPGA> loader;
auto program = loader.Load(g_googlenet);
PADDLE_MOBILE_ENFORCE(program.originProgram != nullptr,
"program file read fail");
Executor4Test<paddle_mobile::FPGA,
paddle_mobile::operators::ConcatOp<paddle_mobile::FPGA, float>>
executor(program, "concat");
// 1. input_tensors;
vector<Tensor> input_tensors;
Tensor input1;
auto input1_data = CreateInput<float>(&input1, {4, 10, 2, 2}, 0, 1);
input_tensors.push_back(input1);
Tensor input2;
auto input2_data = CreateInput<float>(&input2, {4, 20, 2, 2}, 0, 1);
input_tensors.push_back(input2);
Tensor input3;
auto input3_data = CreateInput<float>(&input3, {4, 30, 2, 2}, 0, 1);
input_tensors.push_back(input3);
Tensor input4;
auto input4_data = CreateInput<float>(&input4, {4, 40, 2, 2}, 0, 1);
input_tensors.push_back(input4);
// 2. input_names
vector<string> input_names({
"conv2d_3.tmp_1",
"conv2d_5.tmp_1",
"conv2d_7.tmp_1",
"conv2d_8.tmp_1",
});
// 3. output_names
vector<string> output_names({"concat_0.tmp_0"});
// 4. out_dims;
vector<DDim> out_ddims;
auto out_ddim = paddle_mobile::framework::make_ddim({3, 100, 2, 2});
out_ddims.push_back(out_ddim);
auto output = executor.Predict<LoDTensor>(input_tensors, input_names,
output_names, out_ddims);
auto output0_data = output[0]->data<float>();
// 5. test one example.
int input_n = 1;
int input_c = 2;
int input_h = 0;
int input_w = 1;
int stride0 = input3.numel() / input3.dims()[0];
int stride1 = input3.numel() / input3.dims()[0] / input3.dims()[1];
int stride2 = input3.dims()[3];
/// inputx1 (4,10,2,2),
/// inputx2 (4,20,2,2),
/// inputx3 (4,30,2,2),
/// inputx4 (4,40,2,2),
/// axis = 1
/// output (4,100,2,2)
int input_index =
input_n * stride0 + input_c * stride1 + input_h * stride2 + input_w;
int output_index = input_n * 100 * 2 * 2 +
(input_c + input1.dims()[1] + input2.dims()[1]) * 2 * 2 +
input_h * 2 + input_w;
DLOG << " input3 [1, 2,0,1] = " << input3_data[input_index];
DLOG << " output [1,32,0,1] = " << output0_data[output_index];
return 0;
}
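A sanity check of the index arithmetic in this test: concat along axis 1 places input3's channels after those of input1 and input2, so input3's element (n=1, c=2, h=0, w=1) lands at output channel 2 + 10 + 20 = 32. With strides computed from input3's shape (4, 30, 2, 2), input_index = 1*120 + 2*4 + 0*2 + 1 = 129, and output_index = 1*(100*2*2) + 32*(2*2) + 0*2 + 1 = 529; the two DLOG lines should therefore print the same value.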
@@ -17,7 +17,13 @@ limitations under the License. */
 #include "../test_include.h"
 int main() {
+#ifdef PADDLE_MOBILE_FPGA
+  paddle_mobile::PaddleMobile<paddle_mobile::FPGA> paddle_mobile;
+#endif
+
+#ifdef PADDLE_MOBILE_CPU
   paddle_mobile::PaddleMobile<paddle_mobile::CPU> paddle_mobile;
+#endif
+
   paddle_mobile.SetThreadNum(4);
   auto time1 = time();
   if (paddle_mobile.Load(g_resnet, true)) {
......
@@ -86,6 +86,8 @@ if ("resnet" IN_LIST NET)
     set(RELU_OP ON)
     set(ELEMENTWISEADD_OP ON)
     set(POOL_OP ON)
+    set(BATCHNORM_OP ON)
+    set(MUL_OP ON)
     set(RESHAPE_OP ON)
     set(SOFTMAX_OP ON)
......