提交 cacb362e 编写于 作者: xiebaiyuan's avatar xiebaiyuan

Merge remote-tracking branch 'upstream/develop' into develop

......@@ -104,7 +104,7 @@ int fpga_invalidate(void *address, size_t size) {
}
half fp32_2_fp16(float fp32_num) {
unsigned long tmp = *(unsigned long *)(&fp32_num);
unsigned long tmp = *(unsigned long *)(&fp32_num); // NOLINT
half t = ((tmp & 0x007fffff) >> 13) | ((tmp & 0x80000000) >> 16) |
(((tmp & 0x7f800000) >> 13) - (112 << 10));
if (tmp & 0x1000) {
......@@ -120,7 +120,7 @@ float fp16_2_fp32(half fp16_num) {
int tmp = 0;
float fp32_num;
tmp = s << 16 | exp << 23 | frac << 13;
fp32_num = *(float *)&tmp;
fp32_num = *(float *)&tmp; // NOLINT
return fp32_num;
}
......@@ -347,6 +347,20 @@ void format_filter(framework::Tensor *filter_tensor, float max_value,
filter_tensor->reset_data_ptr(new_data);
}
void format_fc_filter(framework::Tensor *filter_tensor, float max_value) {
filter_tensor->scale[0] = float(max_value / 127.0); // NOLINT
filter_tensor->scale[1] = float(127.0 / max_value); // NOLINT
auto dims = filter_tensor->dims();
auto num = dims[0], channel = dims[1], height = dims[2], width = dims[3];
auto data_ptr = filter_tensor->data<float>();
size_t memory_size = num * channel * height * width * sizeof(float);
auto new_data = (float *)fpga_malloc(memory_size); // NOLINT
fpga_copy(new_data, data_ptr, memory_size);
filter::format_fc_filter(&new_data, num, channel, height, width, 1,
max_value);
filter_tensor->reset_data_ptr(new_data);
}
void format_bias_scale_array(float **bias_scale_array,
int element_num_per_division, int num) {
bias_scale::format_bias_scale_array(bias_scale_array,
......
......@@ -109,8 +109,8 @@ struct PoolingArgs {
struct EWAddArgs {
bool relu_enabled;
half const0; // output0 = const0 x input0 + const1 x input1;
half const1;
uint32_t const0; // output0 = const0 x input0 + const1 x input1;
uint32_t const1;
struct ImageInputArgs image0;
struct ImageInputArgs image1;
struct ImageOutputArgs output;
......@@ -214,6 +214,7 @@ int get_aligned_filter_element_num(int chw);
int get_aligned_filter_num(int num);
void format_filter(framework::Tensor* filter_tensor, float max_value,
int group_num);
void format_fc_filter(framework::Tensor* filter_tensor, float max_value);
void format_bias_scale_array(float** bias_scale_array,
int element_num_per_division, int num);
void format_concat_output(framework::Tensor* out, int height, int width,
......
......@@ -225,6 +225,45 @@ void format_filter(float **data_in, int num, int channel, int height, int width,
num_after_alignment * sizeof(char));
}
void convert_fc_filter(char **data_in, int num, int chw) {
char *tmp = *data_in;
char *data_tmp = (char *)fpga_malloc(chw * num * sizeof(char)); // NOLINT
for (int n = 0; n < num; n++) {
for (int c = 0; c < chw; c++) {
data_tmp[n * chw + c] = (*data_in)[num * c + n];
}
}
*data_in = data_tmp;
fpga_free(tmp);
}
void format_fc_filter(float **data_in, int num, int channel, int height,
int width, int group_num, float max) {
int data_size = channel * height * width * num;
int chw = channel * height * width;
int division_capacity = calc_division_capacity(chw);
int num_per_div_before_alignment =
calc_num_per_div(num, group_num, division_capacity);
int num_per_div_after_alignment =
align_to_x(num_per_div_before_alignment, FILTER_NUM_ALIGNMENT);
int div_num =
(num + num_per_div_before_alignment - 1) / num_per_div_before_alignment;
int num_after_alignment = num_per_div_after_alignment * div_num;
quantize(data_in, data_size, max);
char **quantize_data = (char **)data_in; // NOLINT
convert_fc_filter(quantize_data, num, chw);
align_element(quantize_data, num, chw);
align_num(quantize_data, num_per_div_before_alignment, num, chw);
reorder(quantize_data, num_after_alignment, chw);
interleave(quantize_data, num_after_alignment, chw);
fpga_flush(*quantize_data, align_to_x(chw, FILTER_ELEMENT_ALIGNMENT) *
num_after_alignment * sizeof(char));
}
} // namespace filter
} // namespace fpga
} // namespace paddle_mobile
......@@ -25,7 +25,7 @@ int calc_division_capacity(int chw);
int calc_split_num(int num, int division_capacity);
int calc_division_number(int num, int group_num, int division_capacity);
int calc_num_per_div(int num, int group_num, int division_capacity);
void convert_to_hwc(float** data_in, int num, int channel, int height,
void convert_to_hwc(char** data_in, int num, int channel, int height,
int width);
float find_max(float* data_in, int data_size);
void quantize(float** data_in, int data_size, float max);
......@@ -36,6 +36,11 @@ void reorder(float** data_in, int num_after_alignment, int chw);
void interleave(float** data_in, int num_after_alignment, int chw);
void format_filter(float** data_in, int num, int channel, int height, int width,
int group_num, float max);
void convert_fc_filter(char** data_in, int num, int chw);
void format_fc_filter(float** data_in, int num, int channel, int height,
int width, int group_num, float max);
} // namespace filter
} // namespace fpga
} // namespace paddle_mobile
......@@ -101,6 +101,11 @@ bool PaddleMobilePredictor<Dtype, P>::Run(
return true;
}
template <typename Dtype, Precision P>
PaddleMobilePredictor<Dtype, P>::~PaddleMobilePredictor() {
paddle_mobile_->Clear();
}
// A factory to help create difference predictor.
template <>
std::unique_ptr<PaddlePredictor>
......
......@@ -32,7 +32,7 @@ namespace paddle_mobile {
template <typename Dtype = CPU, Precision P = Precision::FP32>
class PaddleMobilePredictor : public PaddlePredictor {
public:
PaddleMobilePredictor() {}
PaddleMobilePredictor() = delete;
explicit PaddleMobilePredictor(const PaddleMobileConfig& config);
......@@ -40,7 +40,7 @@ class PaddleMobilePredictor : public PaddlePredictor {
std::vector<PaddleTensor>* output_data,
int batch_size = -1) override;
~PaddleMobilePredictor() override{};
~PaddleMobilePredictor() override;
private:
std::unique_ptr<PaddleMobile<Dtype, P>> paddle_mobile_;
......
......@@ -87,7 +87,6 @@ enum class PaddleEngineKind {
class PaddlePredictor {
public:
struct Config;
PaddlePredictor() = default;
PaddlePredictor(const PaddlePredictor&) = delete;
PaddlePredictor& operator=(const PaddlePredictor&) = delete;
......@@ -107,6 +106,9 @@ class PaddlePredictor {
struct Config {
std::string model_dir; // path to the model directory.
};
protected:
PaddlePredictor() = default;
};
struct PaddleMobileConfig : public PaddlePredictor::Config {
......
......@@ -46,7 +46,7 @@ bool FusionFcReluKernel<FPGA, float>::Init(FusionFcReluParam<FPGA> *param) {
filter->Resize(framework::make_ddim({num, filter_channel, height, width}));
float max_value = fpga::filter_find_max(filter);
fpga::format_filter(filter, max_value, 1);
fpga::format_fc_filter(filter, max_value);
int element_num_per_div = fpga::get_filter_num_per_div(filter, 1);
fpga::format_bias_scale_array(&bs_ptr, element_num_per_div, channel);
......
......@@ -47,7 +47,7 @@ bool FusionFcKernel<FPGA, float>::Init(FusionFcParam<FPGA> *param) {
filter->Resize(framework::make_ddim({num, filter_channel, height, width}));
float max_value = fpga::filter_find_max(filter);
fpga::format_filter(filter, max_value, 1);
fpga::format_fc_filter(filter, max_value);
int element_num_per_div = fpga::get_filter_num_per_div(filter, 1);
fpga::format_bias_scale_array(&bs_ptr, element_num_per_div, channel);
......
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifdef MUL_OP
#include "operators/kernel/mul_kernel.h"
namespace paddle_mobile {
namespace operators {
template <>
bool MulKernel<FPGA, float>::Init(MulParam<FPGA> *param) {
bool relu_enabled = false;
auto input_x = const_cast<LoDTensor *>(param->InputX());
auto filter = const_cast<LoDTensor *>(param->InputY());
auto out = param->Out();
PADDLE_MOBILE_ENFORCE(input_x->dims()[1] == filter->dims()[0],
"Image channel should be equal to weight number");
int channel = (uint32_t)out->dims()[1];
auto bs_ptr =
(float *)fpga::fpga_malloc(2 * channel * sizeof(float)); // NOLINT
for (int i = 0; i < channel; i++) {
bs_ptr[i + channel] = 1;
bs_ptr[i] = 0;
}
int num = (uint32_t)filter->dims()[1];
int chw = (uint32_t)filter->dims()[0];
PADDLE_MOBILE_ENFORCE(
chw == input_x->numel(),
"Filter element num should be equal to IFM element num");
int height = (uint32_t)input_x->dims()[2];
int width = (uint32_t)input_x->dims()[3];
int filter_channel = chw / height / width;
filter->Resize(framework::make_ddim({num, filter_channel, height, width}));
float max_value = fpga::filter_find_max(filter);
fpga::format_fc_filter(filter, max_value);
int element_num_per_div = fpga::get_filter_num_per_div(filter, 1);
fpga::format_bias_scale_array(&bs_ptr, element_num_per_div, channel);
fpga::format_fp16_ofm(out);
fpga::WrapperConvArgs conv_arg = {0};
fpga::fill_conv_arg(&conv_arg, input_x, out, filter, relu_enabled, 1, 1, 1, 0,
0, bs_ptr);
param->SetFpgaArgs(conv_arg);
return true;
}
template <>
void MulKernel<FPGA, float>::Compute(const MulParam<FPGA> &param) const {
fpga::ComputeFpgaConv(param.FpgaArgs());
}
} // namespace operators
} // namespace paddle_mobile
#endif
......@@ -61,5 +61,7 @@ REGISTER_OPERATOR_CPU(mul, ops::MulOp);
#ifdef PADDLE_MOBILE_MALI_GPU
REGISTER_OPERATOR_MALI_GPU(mul, ops::MulOp);
#endif
#ifdef PADDLE_MOBILE_FPGA
REGISTER_OPERATOR_FPGA(mul, ops::MulOp);
#endif
#endif
......@@ -441,6 +441,15 @@ class MulParam : OpParam {
GType *out_;
int x_num_col_dims_;
int y_num_col_dims_;
#ifdef PADDLE_MOBILE_FPGA
private:
fpga::WrapperConvArgs fpga_conv_args;
public:
const fpga::WrapperConvArgs &FpgaArgs() const { return fpga_conv_args; }
void SetFpgaArgs(const fpga::WrapperConvArgs &args) { fpga_conv_args = args; }
#endif
};
#endif
......
......@@ -18,8 +18,9 @@ static const char *g_resnet_combine = "../models/resnet50";
int main() {
DLOG << paddle_mobile::fpga::open_device();
paddle_mobile::PaddleMobile<paddle_mobile::FPGA> paddle_mobile;
if (paddle_mobile.Load(std::string(g_resnet_combine) + "/model",
std::string(g_resnet_combine) + "/params", true)) {
// if (paddle_mobile.Load(std::string(g_resnet_combine) + "/model",
// std::string(g_resnet_combine) + "/params", true)) {
if (paddle_mobile.Load(std::string(g_resnet_combine), true)) {
std::vector<int64_t> dims{1, 3, 224, 224};
Tensor input_tensor;
SetupTensor<float>(&input_tensor, {1, 3, 224, 224}, static_cast<float>(0),
......
......@@ -46,7 +46,12 @@ int main() {
tensor_out.dtype = PaddleDType::FLOAT32;
std::vector<PaddleTensor> outputs(1, tensor_out);
assert(predictor->Run(paddle_tensor_feeds, &outputs));
std::cout << " before predict " << std::endl;
predictor->Run(paddle_tensor_feeds, &outputs);
std::cout << " after predict " << std::endl;
// assert();
float* data_o = static_cast<float*>(outputs[0].data.data());
for (size_t j = 0; j < outputs[0].data.length() / sizeof(float); ++j) {
......
......@@ -52,8 +52,8 @@ int main() {
#else
auto time3 = time();
paddle_mobile.FeedData(input_tensor);
paddle_mobile.Predict_To(10);
paddle_mobile.Predict_From(10);
paddle_mobile.Predict_To(-1);
/*paddle_mobile.Predict_From(10);
auto tensor_ptr = paddle_mobile.FetchResult(9);
std::cout << "Tensor element number for op[9]: " << tensor_ptr->numel()
<< std::endl;
......@@ -63,7 +63,7 @@ int main() {
auto time4 = time();
std::cout << "predict cost :" << time_diff(time3, time4) << "ms"
<< std::endl;
<< std::endl;*/
#endif
}
return 0;
......
......@@ -121,6 +121,7 @@ if (CON GREATER -1)
set(FUSION_CONVBNRELU_OP ON)
set(FUSION_CONVBN_OP ON)
set(FUSION_CONVADD_OP ON)
set(MUL_OP ON)
set(FOUND_MATCH ON)
endif()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册