提交 0c1bde02 编写于 作者: C chonwhite

merged from pm_sync

...@@ -22,6 +22,9 @@ if (WITH_PADDLE_MOBILE) ...@@ -22,6 +22,9 @@ if (WITH_PADDLE_MOBILE)
return() return()
endif(WITH_PADDLE_MOBILE) endif(WITH_PADDLE_MOBILE)
# set(CMAKE_BUILD_TYPE DEBUG)
set(PADDLE_SOURCE_DIR ${CMAKE_CURRENT_SOURCE_DIR}) set(PADDLE_SOURCE_DIR ${CMAKE_CURRENT_SOURCE_DIR})
set(PADDLE_BINARY_DIR ${CMAKE_CURRENT_BINARY_DIR}) set(PADDLE_BINARY_DIR ${CMAKE_CURRENT_BINARY_DIR})
set(CMAKE_CXX_STANDARD 11) set(CMAKE_CXX_STANDARD 11)
......
...@@ -223,6 +223,25 @@ if(LITE_WITH_LIGHT_WEIGHT_FRAMEWORK AND WITH_TESTING) ...@@ -223,6 +223,25 @@ if(LITE_WITH_LIGHT_WEIGHT_FRAMEWORK AND WITH_TESTING)
CL_DEPS ${opencl_kernels} CL_DEPS ${opencl_kernels}
FPGA_DEPS ${fpga_kernels}) FPGA_DEPS ${fpga_kernels})
lite_cc_test(test_ssd_fpga SRCS test_ssd_fpga.cc
DEPS ${lite_model_test_DEPS}
CL_DEPS ${opencl_kernels}
FPGA_DEPS ${fpga_kernels})
lite_cc_test(test_inceptionv3_fpga SRCS inceptionv3_test_fpga.cc
DEPS ${lite_model_test_DEPS}
CL_DEPS ${opencl_kernels}
FPGA_DEPS ${fpga_kernels})
lite_cc_test(test_inceptionv4 SRCS inceptionv4_test.cc
DEPS ${lite_model_test_DEPS}
CL_DEPS ${opencl_kernels}
ARGS --cl_path=${CMAKE_SOURCE_DIR}/lite/backends/opencl
--model_dir=${LITE_MODEL_DIR}/inception_v4 SERIAL)
add_dependencies(test_inceptionv4 extern_lite_download_inception_v4_simple_tar_gz)
lite_cc_test(test_ocr_attention_fpga SRCS ocr_attention_test_fpga.cc
DEPS ${lite_model_test_DEPS})
# lite_cc_test(model_run_test_image SRCS model_run_test_image.cc # lite_cc_test(model_run_test_image SRCS model_run_test_image.cc
# DEPS ${lite_model_test_DEPS} # DEPS ${lite_model_test_DEPS}
# CL_DEPS ${opencl_kernels} # CL_DEPS ${opencl_kernels}
......
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <gflags/gflags.h>
#include <gtest/gtest.h>
#include <vector>
#include "lite/api/cxx_api.h"
#include "lite/api/paddle_use_kernels.h"
#include "lite/api/paddle_use_ops.h"
#include "lite/api/paddle_use_passes.h"
#include "lite/api/test_helper.h"
#include "lite/core/op_registry.h"
namespace paddle {
namespace lite {
#ifdef LITE_WITH_FPGA
TEST(ResNet50, test) {
lite::Predictor predictor;
std::vector<Place> valid_places({
Place{TARGET(kFPGA), PRECISION(kFP16), DATALAYOUT(kNHWC)},
Place{TARGET(kHost), PRECISION(kFloat)},
Place{TARGET(kARM), PRECISION(kFloat)},
});
predictor.Build("",
FLAGS_model_dir + "/model",
FLAGS_model_dir + "/params",
valid_places);
auto* input_tensor = predictor.GetInput(0);
input_tensor->Resize(DDim(std::vector<DDim::value_type>({1, 3, 224, 224})));
auto* data = input_tensor->mutable_data<float>();
auto item_size = input_tensor->dims().production();
for (int i = 0; i < item_size; i++) {
data[i] = 1;
}
for (int i = 0; i < FLAGS_warmup; ++i) {
predictor.Run();
}
auto start = GetCurrentUS();
for (int i = 0; i < 2; ++i) {
predictor.Run();
}
LOG(INFO) << "================== Speed Report ===================";
}
#endif
} // namespace lite
} // namespace paddle
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <gflags/gflags.h>
#include <gtest/gtest.h>
#include <vector>
#include "lite/api/cxx_api.h"
#include "lite/api/paddle_use_kernels.h"
#include "lite/api/paddle_use_ops.h"
#include "lite/api/paddle_use_passes.h"
#include "lite/api/test_helper.h"
#include "lite/core/op_registry.h"
DEFINE_string(input_file, "", "input_file");
namespace paddle {
namespace lite {
void read_from_file(const std::string& path, float* data, int num) {
std::ifstream file_stream;
file_stream.open(path);
if (!file_stream) {
exit(-1);
return;
}
for (int i = 0; i < num; ++i) {
float value = 0;
file_stream >> value;
data[i] = value;
}
}
void chw_to_hwc(float* src, float* dst, int channel, int height, int width) {
int amount_per_row = width * channel;
int index = 0;
for (int c = 0; c < channel; c++) {
for (int h = 0; h < height; h++) {
int offset_height = h * amount_per_row;
for (int w = 0; w < width; w++) {
int dst_index = offset_height + w * channel + c;
dst[dst_index] = src[index];
index = index + 1;
}
}
}
}
void TestModel(const std::vector<Place>& valid_places,
const Place& preferred_place,
bool use_npu = false) {
DeviceInfo::Init();
DeviceInfo::Global().SetRunMode(lite_api::LITE_POWER_HIGH, FLAGS_threads);
lite::Predictor predictor;
// predictor.Build(FLAGS_model_dir, "", "", preferred_place, valid_places);
predictor.Build("", "attention/model", "attention/params", valid_places);
auto* input_tensor = predictor.GetInput(0);
input_tensor->Resize(DDim(std::vector<DDim::value_type>({1, 1, 100, 200})));
auto* data = input_tensor->mutable_data<float>();
auto item_size = input_tensor->dims().production();
for (int i = 0; i < item_size; i++) {
data[i] = 1;
}
read_from_file(FLAGS_input_file, data, 100 * 200);
//=============================================
auto* init_ids = predictor.GetInput(1);
init_ids->Resize(DDim(std::vector<DDim::value_type>({1, 1})));
auto* data_ids = init_ids->mutable_data<float>();
auto ids_size = init_ids->dims().production();
for (int i = 0; i < ids_size; i++) {
data_ids[i] = 0;
}
auto lod_ids = init_ids->mutable_lod();
std::vector<std::vector<uint64_t>> lod_i{{0, 1}, {0, 1}};
*lod_ids = lod_i;
//=============================================
auto* init_scores = predictor.GetInput(2);
init_scores->Resize(DDim(std::vector<DDim::value_type>({1, 1})));
auto* data_scores = init_scores->mutable_data<float>();
auto scores_size = input_tensor->dims().production();
for (int i = 0; i < scores_size; i++) {
data_scores[i] = 0;
}
auto lod_scores = init_scores->mutable_lod();
std::vector<std::vector<uint64_t>> lod_s{{0, 1}, {0, 1}};
*lod_scores = lod_s;
//=============================================
auto* position_encoding = predictor.GetInput(3);
position_encoding->Resize(
DDim(std::vector<DDim::value_type>({1, 33, 10, 23})));
auto* position_encoding_data = position_encoding->mutable_data<float>();
float* temp_data = position_encoding_data;
for (int i = 0; i < position_encoding->dims().production(); ++i) {
temp_data[i] = 0;
}
int index = 0;
for (int i = 0; i < 10; i++) {
for (int row = 0; row < 10; row++) {
for (int col = 0; col < 23; col++) {
if (i == row) {
temp_data[index] = 1.0f;
} else {
temp_data[index] = 0.0f;
}
index++;
}
}
}
for (int i = 0; i < 23; i++) {
for (int row = 0; row < 10; row++) {
for (int col = 0; col < 23; col++) {
if (i == col) {
temp_data[index] = 1.0f;
} else {
temp_data[index] = 0.0f;
}
index++;
}
}
}
// chw_to_hwc(temp_data, position_encoding_data, 33, 10, 23);
// delete[] temp_data;
// read_from_file("position_encoding.data", position_encoding_data, 33 * 10 *
// 23);
auto start = GetCurrentUS();
for (int i = 0; i < 2; ++i) {
predictor.Run();
}
std::cout << "================== Speed Report ===================";
std::cout << "Model: " << FLAGS_model_dir << ", threads num " << FLAGS_threads
<< ", warmup: " << FLAGS_warmup << ", repeats: " << FLAGS_repeats
<< ", spend " << (GetCurrentUS() - start) / FLAGS_repeats / 1000.0
<< " ms in average.";
auto* out = predictor.GetOutput(0);
std::string file = "plate_data/" + FLAGS_input_file.substr(9);
std::cout << "file:::" << file << std::endl;
std::ofstream ofs;
ofs.open(file);
for (int i = 0; i < out->dims().production(); i++) {
float value = out->data<float>()[i];
ofs << value << std::endl;
}
ofs.close();
}
TEST(OcrAttention, test_arm) {
std::vector<Place> valid_places({
Place{TARGET(kFPGA), PRECISION(kFP16), DATALAYOUT(kNHWC)},
Place{TARGET(kHost), PRECISION(kFloat)},
Place{TARGET(kARM), PRECISION(kFloat)},
});
TestModel(valid_places, Place{TARGET(kARM), PRECISION(kFloat)});
}
} // namespace lite
} // namespace paddle
...@@ -14,6 +14,8 @@ ...@@ -14,6 +14,8 @@
#pragma once #pragma once
#include <fstream>
#include <iostream>
#include <string> #include <string>
#include <unordered_map> #include <unordered_map>
...@@ -22,7 +24,7 @@ ...@@ -22,7 +24,7 @@
namespace paddle { namespace paddle {
namespace lite { namespace lite {
#define FPGA_PRINT_TENSOR // #define FPGA_PRINT_TENSOR
class Debugger { class Debugger {
public: public:
...@@ -37,25 +39,42 @@ class Debugger { ...@@ -37,25 +39,42 @@ class Debugger {
} }
} }
void tick(std::string key) {
float value = 0;
if (tick_tock_map.count(key) > 0) {
value += tick_tock_map[key] = value;
}
}
void tock(std::string key) {}
void setEnable(bool en) { enabled_ = en; }
private: private:
bool enabled_ = false;
std::unordered_map<std::string, bool> op_config; std::unordered_map<std::string, bool> op_config;
std::unordered_map<std::string, float> tick_tock_map;
Debugger() { Debugger() {
op_config["concat"] = true; op_config["concat"] = true;
op_config["pooling"] = true; op_config["pooling"] = true;
op_config["conv"] = true; op_config["conv"] = true;
op_config["dropout"] = true;
op_config["dwconv"] = true; op_config["dwconv"] = true;
op_config["ew_add"] = true; op_config["ew_add"] = true;
op_config["ew_mul"] = true;
op_config["crop"] = true; op_config["crop"] = true;
op_config["feed"] = true; op_config["feed"] = true;
op_config["mul"] = true;
op_config["fetch"] = true; op_config["fetch"] = true;
op_config["fc"] = true;
op_config["mul"] = true;
op_config["boxes"] = true; op_config["boxes"] = true;
op_config["scores"] = true; op_config["scores"] = true;
op_config["nms"] = true; op_config["nms"] = true;
op_config["pb_boxes"] = true; op_config["pb_boxes"] = true;
op_config["pb_variances"] = true; op_config["pb_variances"] = true;
// op_config["fc"] = true;
op_config["softmax"] = true; op_config["softmax"] = true;
op_config["split"] = true;
} }
}; };
......
...@@ -15,6 +15,7 @@ limitations under the License. */ ...@@ -15,6 +15,7 @@ limitations under the License. */
#pragma once #pragma once
#include <stdio.h> #include <stdio.h>
#include "lite/backends/fpga/KD/llapi/filter.h" #include "lite/backends/fpga/KD/llapi/filter.h"
#include "lite/backends/fpga/KD/llapi/zynqmp_api.h" #include "lite/backends/fpga/KD/llapi/zynqmp_api.h"
......
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. /* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License"); Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License. you may not use this file except in compliance with the License.
...@@ -12,11 +12,11 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ...@@ -12,11 +12,11 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include <cl_common.h> #include "io.hpp"
__kernel void sigmoid(__global const CL_DTYPE* x_data, const int count, __global CL_DTYPE* out_data) { namespace paddle {
const int index = get_global_id(0); namespace zynqmp {
if (index < count) {
out_data[index] = 1 / (1 + exp(-x_data[index]));
} } // namespace zynqmp
} } // namespace paddle
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <stdio.h>
namespace paddle {
namespace zynqmp {
class FpgaIO {
public:
static FpgaIO& get_instance() {
static FpgaIO s_instance;
return s_instance;
}
void allocData(size_t s) { data_ = new float[s]; }
float* getData() { return data_; }
private:
float* data_ = nullptr;
FpgaIO();
};
} // namespace zynqmp
} // namespace paddle
文件模式从 100644 更改为 100755
...@@ -240,8 +240,10 @@ int8_t* format_filter(float* data_in, ...@@ -240,8 +240,10 @@ int8_t* format_filter(float* data_in,
for (int n = 0; n < num; n++) { for (int n = 0; n < num; n++) {
float* filter_start = data_in + n * chw; float* filter_start = data_in + n * chw;
int8_t* quantized_start = quantized_data + n * chw; int8_t* quantized_start = quantized_data + n * chw;
quantize(filter_start, quantized_start, chw, max); // float f_max = find_max(filter_start, chw);
filter_max.push_back(1); float f_max = max;
quantize(filter_start, quantized_start, chw, f_max);
filter_max.push_back(f_max);
} }
int8_t* hwc_data = int8_t* hwc_data =
......
...@@ -204,7 +204,7 @@ int get_device_info(const struct DeviceInfo &args) { ...@@ -204,7 +204,7 @@ int get_device_info(const struct DeviceInfo &args) {
int perform_bypass(const struct BypassArgs &args) { int perform_bypass(const struct BypassArgs &args) {
int ret = -1; int ret = -1;
int size = args.image.channels * args.image.width * args.image.height; int size = args.image.channels * args.image.width * args.image.height;
int max_size = 1 << 21; int max_size = 1 << 20;
float times = 1.0 * size / max_size; float times = 1.0 * size / max_size;
int count = static_cast<int>(times); int count = static_cast<int>(times);
......
...@@ -83,26 +83,34 @@ struct ConvParam : PEParam { ...@@ -83,26 +83,34 @@ struct ConvParam : PEParam {
std::vector<int> kernelSize; std::vector<int> kernelSize;
std::vector<int> dilations; std::vector<int> dilations;
Tensor* scale() { return scale_; } Tensor* scale() { return &scale_; }
Tensor* bias() { return bias_; } Tensor* bias() { return &bias_; }
std::vector<BasicConvParam*>& splitParams() { return splitParams_; } std::vector<BasicConvParam*>& splitParams() { return splitParams_; }
~ConvParam() {
for (int i = 0; i < splitParams_.size(); i++) {
BasicConvParam* basic_param = splitParams_[i];
delete basic_param;
}
splitParams_.clear();
}
protected: protected:
std::vector<BasicConvParam*> splitParams_; std::vector<BasicConvParam*> splitParams_;
Tensor* scale_ = new Tensor(); Tensor scale_;
Tensor* bias_ = new Tensor(); Tensor bias_;
}; };
struct DepthwiseConvParam : ConvParam { struct DepthwiseConvParam : ConvParam {
public: public:
Tensor* quantizedFilter() { return quantizedFilter_; } Tensor* quantizedFilter() { return &quantizedFilter_; }
DWconvArgs args; DWconvArgs args;
protected: protected:
Tensor* quantizedFilter_ = new Tensor(); Tensor quantizedFilter_;
}; };
enum PoolingType : int { enum PoolingType : int {
...@@ -142,7 +150,7 @@ struct ElementwiseAddParam : PEParam { ...@@ -142,7 +150,7 @@ struct ElementwiseAddParam : PEParam {
struct ElementwiseMulParam : PEParam { struct ElementwiseMulParam : PEParam {
public: public:
Tensor* input_x; Tensor* input_x = nullptr;
Tensor* input_y = nullptr; Tensor* input_y = nullptr;
Tensor* output = nullptr; Tensor* output = nullptr;
}; };
...@@ -154,13 +162,13 @@ struct FullyConnectedParam : PEParam { ...@@ -154,13 +162,13 @@ struct FullyConnectedParam : PEParam {
Tensor* bias = nullptr; Tensor* bias = nullptr;
Tensor* output = nullptr; Tensor* output = nullptr;
Tensor* quantizedFilter() { return quantizedFilter_; } Tensor* quantizedFilter() { return &quantizedFilter_; }
Tensor* biasScale() { return biasScale_; } Tensor* biasScale() { return &biasScale_; }
protected: protected:
Tensor* quantizedFilter_ = new Tensor(); Tensor quantizedFilter_;
Tensor* biasScale_ = new Tensor(); Tensor biasScale_;
}; };
struct SoftmaxParam : PEParam { struct SoftmaxParam : PEParam {
...@@ -193,10 +201,10 @@ struct NormParam : PEParam { ...@@ -193,10 +201,10 @@ struct NormParam : PEParam {
}; };
struct PriorBoxParam : PEParam { struct PriorBoxParam : PEParam {
Tensor* input; Tensor* input = nullptr;
Tensor* image; Tensor* image = nullptr;
Tensor* outputBoxes; Tensor* outputBoxes = nullptr;
Tensor* outputVariances; Tensor* outputVariances = nullptr;
std::vector<float> minSizes; std::vector<float> minSizes;
std::vector<float> maxSizes; std::vector<float> maxSizes;
...@@ -212,10 +220,10 @@ struct PriorBoxParam : PEParam { ...@@ -212,10 +220,10 @@ struct PriorBoxParam : PEParam {
}; };
struct YoloBoxParam : PEParam { struct YoloBoxParam : PEParam {
Tensor* input; Tensor* input = nullptr;
Tensor* imgSize; Tensor* imgSize = nullptr;
Tensor* outputBoxes; Tensor* outputBoxes = nullptr;
Tensor* outputScores; Tensor* outputScores = nullptr;
int downsampleRatio; int downsampleRatio;
std::vector<int> anchors; std::vector<int> anchors;
int classNum; int classNum;
...@@ -229,15 +237,15 @@ struct ScaleParam : PEParam { ...@@ -229,15 +237,15 @@ struct ScaleParam : PEParam {
Tensor* scale = nullptr; Tensor* scale = nullptr;
Tensor* bias = nullptr; Tensor* bias = nullptr;
Tensor* alignedScale() { return alignedScale_; } Tensor* alignedScale() { return &alignedScale_; }
Tensor* alignedBias() { return alignedBias_; } Tensor* alignedBias() { return &alignedBias_; }
ScaleArgs args = {0}; ScaleArgs args = {0};
protected: protected:
Tensor* alignedScale_ = new Tensor(); Tensor alignedScale_;
Tensor* alignedBias_ = new Tensor(); Tensor alignedBias_;
}; };
struct ResizeParam : PEParam { struct ResizeParam : PEParam {
......
...@@ -212,6 +212,8 @@ class ConvPE : public PE { ...@@ -212,6 +212,8 @@ class ConvPE : public PE {
ConvParam& param() { return param_; } ConvParam& param() { return param_; }
~ConvPE() {}
private: private:
bool use_cpu_ = false; bool use_cpu_ = false;
bool split_channel = false; bool split_channel = false;
......
...@@ -38,7 +38,7 @@ class FullyConnectedPE : public PE { ...@@ -38,7 +38,7 @@ class FullyConnectedPE : public PE {
Tensor* input = param_.input; Tensor* input = param_.input;
convParam_.input = param_.input; convParam_.input = param_.input;
convParam_.output = param_.output; convParam_.output = param_.output;
// convParam_.relu = param_.relu;
convParam_.activeParam.type = param_.activeParam.type; convParam_.activeParam.type = param_.activeParam.type;
convParam_.groups = 1; convParam_.groups = 1;
convParam_.strides = {1, 1}; convParam_.strides = {1, 1};
......
...@@ -14,6 +14,7 @@ limitations under the License. */ ...@@ -14,6 +14,7 @@ limitations under the License. */
#pragma once #pragma once
#include "lite/backends/fpga/KD/llapi/zynqmp_api.h"
#include "lite/backends/fpga/KD/pe.hpp" #include "lite/backends/fpga/KD/pe.hpp"
#include "lite/backends/fpga/KD/pe_params.hpp" #include "lite/backends/fpga/KD/pe_params.hpp"
...@@ -52,6 +53,12 @@ class OutputPE : public PE { ...@@ -52,6 +53,12 @@ class OutputPE : public PE {
memcpy(DLEngine::get_instance().out_data, memcpy(DLEngine::get_instance().out_data,
output->data<void>(), output->data<void>(),
output->shape().numel() * sizeof(float)); output->shape().numel() * sizeof(float));
fpga_reset();
auto max = fpga_get_memory_size_max();
std::cout << "PL ===== Max: ===== :: " << max << std::endl;
return true; return true;
} }
......
...@@ -241,10 +241,13 @@ void PriorBoxPE::compute_prior_box() { ...@@ -241,10 +241,13 @@ void PriorBoxPE::compute_prior_box() {
} }
boxes.flush(); boxes.flush();
boxes.syncToCPU(); // boxes.syncToCPU();
variances.flush(); variances.flush();
output_boxes->copyFrom(&boxes); output_boxes->copyFrom(&boxes);
output_variances->copyFrom(&variances); output_variances->copyFrom(&variances);
output_boxes->invalidate();
output_variances->invalidate();
} }
void PriorBoxPE::apply() {} void PriorBoxPE::apply() {}
...@@ -253,8 +256,9 @@ bool PriorBoxPE::dispatch() { ...@@ -253,8 +256,9 @@ bool PriorBoxPE::dispatch() {
if (cachedBoxes_ == nullptr) { if (cachedBoxes_ == nullptr) {
cachedBoxes_ = new Tensor(); cachedBoxes_ = new Tensor();
cachedVariances_ = new Tensor(); cachedVariances_ = new Tensor();
cachedBoxes_->mutableData<float>(FP32, param_.outputBoxes->shape()); cachedBoxes_->mutableData<float16>(FP16, param_.outputBoxes->shape());
cachedVariances_->mutableData<float>(FP32, param_.outputVariances->shape()); cachedVariances_->mutableData<float16>(FP16,
param_.outputVariances->shape());
cachedBoxes_->setDataLocation(CPU); cachedBoxes_->setDataLocation(CPU);
cachedVariances_->setDataLocation(CPU); cachedVariances_->setDataLocation(CPU);
compute_prior_box(); compute_prior_box();
......
...@@ -23,43 +23,27 @@ class ReluPE : public PE { ...@@ -23,43 +23,27 @@ class ReluPE : public PE {
public: public:
bool init() { bool init() {
Tensor* output = param_.output; Tensor* output = param_.output;
output->setAligned(true); output->setAligned(param_.input->aligned());
output->setDataLocation(Device); output->setDataLocation(CPU);
return true; return true;
} }
void apply() { void apply() {}
Tensor* src = param_.input;
args_.input_data_type = DATA_TYPE_FP16;
args_.output_data_type = DATA_TYPE_FP16;
args_.input_layout_type = LAYOUT_HWC;
args_.output_layout_type = LAYOUT_HWC;
args_.image = {.address = src->data<void>(),
.scale_address = src->scale(),
.channels = (uint32_t)src->shape().channel(),
.width = (uint32_t)src->shape().width(),
.height = (uint32_t)src->shape().height(),
.pad_width = 0u,
.pad_height = 0u};
args_.output = {
.address = param_.output->data<void>(),
.scale_address = param_.output->scale(),
};
inplace_.relu_enable = false;
inplace_.power_enable = false;
inplace_.normalize_enable = false;
}
bool dispatch() { bool dispatch() {
inplace_.relu_enable = true; param_.input->invalidate();
config_inplace(inplace_); int16_t* input_data = param_.input->data<int16_t>();
param_.input->syncToDevice(); float16* out_data = param_.output->data<float16>();
param_.output->copyFrom(param_.input); for (int i = 0; i < param_.input->shape().alignedElementCount(); i++) {
param_.output->invalidate(); int16_t v = param_.input->data<float16>()[i];
inplace_.relu_enable = false; if (v > 0) {
config_inplace(inplace_); out_data[i] = input_data[i];
} else {
out_data[i] = zero;
}
}
param_.output->copyScaleFrom(param_.input);
param_.output->flush();
return true; return true;
} }
...@@ -67,8 +51,7 @@ class ReluPE : public PE { ...@@ -67,8 +51,7 @@ class ReluPE : public PE {
private: private:
InputParam param_; InputParam param_;
BypassArgs args_; float16 zero = float_to_half(0.0f);
InplaceArgs inplace_;
}; };
} // namespace zynqmp } // namespace zynqmp
......
...@@ -36,6 +36,7 @@ class ScalePE : public PE { ...@@ -36,6 +36,7 @@ class ScalePE : public PE {
} }
inline int lcm(int a, int b) { return a * b / gcd(a, b); } inline int lcm(int a, int b) { return a * b / gcd(a, b); }
bool init() { bool init() {
Tensor* output = param_.output; Tensor* output = param_.output;
output->setAligned(true); output->setAligned(true);
......
...@@ -103,12 +103,14 @@ class Tensor { ...@@ -103,12 +103,14 @@ class Tensor {
return reinterpret_cast<Dtype*>(ptr); return reinterpret_cast<Dtype*>(ptr);
} }
void releaseData() {
released = true;
placeHolder_.reset();
}
template <typename Dtype> template <typename Dtype>
Dtype* mutableData(DataType dataType, const Shape& shape) { Dtype* mutableData(DataType dataType, const Shape& shape) {
if (this->shape_ != nullptr) { this->shape_.reset(new Shape(shape));
delete shape_;
}
this->shape_ = new Shape(shape);
this->dataType_ = dataType; this->dataType_ = dataType;
return mutableData<Dtype>(); return mutableData<Dtype>();
} }
...@@ -138,7 +140,7 @@ class Tensor { ...@@ -138,7 +140,7 @@ class Tensor {
DataType dataType() { return this->dataType_; } DataType dataType() { return this->dataType_; }
Shape& shape() { return *shape_; } Shape& shape() { return *(shape_.get()); }
bool aligned() { return this->aligned_; } bool aligned() { return this->aligned_; }
...@@ -247,15 +249,12 @@ class Tensor { ...@@ -247,15 +249,12 @@ class Tensor {
void shareDataWith(Tensor* src) { shareDataWith(src, src->shape()); } void shareDataWith(Tensor* src) { shareDataWith(src, src->shape()); }
void shareDataWith(Tensor* src, const Shape& shape, int offset = 0) { void shareDataWith(Tensor* src, const Shape& shape, int offset = 0) {
if (shape_ != nullptr) {
delete shape_;
}
this->placeHolder_ = src->placeHolder_; this->placeHolder_ = src->placeHolder_;
this->dataType_ = src->dataType_; this->dataType_ = src->dataType_;
this->aligned_ = src->aligned_; this->aligned_ = src->aligned_;
this->dateLocation_ = src->dateLocation_; this->dateLocation_ = src->dateLocation_;
this->offset = offset; this->offset = offset;
shape_ = new Shape(const_cast<Shape&>(shape)); shape_.reset(new Shape(shape));
} }
void copyFrom(Tensor* src) { void copyFrom(Tensor* src) {
...@@ -284,7 +283,6 @@ class Tensor { ...@@ -284,7 +283,6 @@ class Tensor {
.address = data<void>(), .scale_address = scale(), .address = data<void>(), .scale_address = scale(),
}; };
args.output = output; args.output = output;
src->syncToDevice();
size_t aligned_remainder = src->shape().numel() % 16; size_t aligned_remainder = src->shape().numel() % 16;
if (aligned_remainder > 0) { if (aligned_remainder > 0) {
size_t dtype_size = size_t dtype_size =
...@@ -294,12 +292,14 @@ class Tensor { ...@@ -294,12 +292,14 @@ class Tensor {
fpga_flush(dst, aligned_remainder * dtype_size); fpga_flush(dst, aligned_remainder * dtype_size);
} }
src->syncToDevice(); src->syncToDevice();
this->invalidate();
perform_bypass(args); perform_bypass(args);
this->invalidate(); this->invalidate();
} }
void flush() { void flush() {
if (released) {
return;
}
size_t memorySize = placeHolder_->memorySize(); size_t memorySize = placeHolder_->memorySize();
fpga_flush(placeHolder_->data(), memorySize); fpga_flush(placeHolder_->data(), memorySize);
} }
...@@ -380,7 +380,6 @@ class Tensor { ...@@ -380,7 +380,6 @@ class Tensor {
} }
void save_file_with_name(std::string path) { void save_file_with_name(std::string path) {
invalidate();
std::ofstream ofs; std::ofstream ofs;
ofs.open(path); ofs.open(path);
ofs << scale()[0] << " / " << scale()[1] << std::endl; ofs << scale()[0] << " / " << scale()[1] << std::endl;
...@@ -389,11 +388,17 @@ class Tensor { ...@@ -389,11 +388,17 @@ class Tensor {
float value = 0; float value = 0;
if (dataType_ == FP32) { if (dataType_ == FP32) {
value = data<float>()[i]; value = data<float>()[i];
} else if (dataType_ == FP16) { }
if (dataType_ == FP16) {
value = half_to_float(data<float16>()[i]); value = half_to_float(data<float16>()[i]);
} else { }
if (dataType_ == INT8) {
value = data<int8_t>()[i]; value = data<int8_t>()[i];
} }
if (dataType_ == INT32) {
value = data<int32_t>()[i];
}
ofs << value << std::endl; ofs << value << std::endl;
} }
ofs.close(); ofs.close();
...@@ -451,18 +456,12 @@ class Tensor { ...@@ -451,18 +456,12 @@ class Tensor {
return os; return os;
} }
~Tensor() {
if (shape_ != nullptr) {
delete shape_;
shape_ = nullptr;
}
}
private: private:
bool released = false;
int offset = 0; int offset = 0;
float mem_scale_factor_ = 1.0f; float mem_scale_factor_ = 1.0f;
std::shared_ptr<PlaceHolder> placeHolder_; std::shared_ptr<PlaceHolder> placeHolder_;
Shape* shape_ = nullptr; std::shared_ptr<Shape> shape_;
DataType dataType_ = FP32; DataType dataType_ = FP32;
bool aligned_ = false; bool aligned_ = false;
DataSyncStatus synchedStatus_ = Synched; DataSyncStatus synchedStatus_ = Synched;
......
...@@ -69,7 +69,7 @@ std::string DDimLite::repr() const { ...@@ -69,7 +69,7 @@ std::string DDimLite::repr() const {
} }
void TensorLite::ShareDataWith(const TensorLite &other) { void TensorLite::ShareDataWith(const TensorLite &other) {
buffer_ = other.buffer_; buffer_ = other.buffer_; // TODO(chonwhite) delete buffer;
dims_ = other.dims_; dims_ = other.dims_;
zynq_tensor_ = other.zynq_tensor_; zynq_tensor_ = other.zynq_tensor_;
target_ = other.target_; target_ = other.target_;
...@@ -79,10 +79,8 @@ void TensorLite::ShareDataWith(const TensorLite &other) { ...@@ -79,10 +79,8 @@ void TensorLite::ShareDataWith(const TensorLite &other) {
} }
void *TensorLite::mutable_data(size_t memory_size) { void *TensorLite::mutable_data(size_t memory_size) {
memory_size_ = memory_size; memory_size_ = memory_size; // TODO(chonwhite) delete buffer;
buffer_->ResetLazy(target_, memory_size_); buffer_->ResetLazy(target_, memory_size_);
// throw -1;
std::cout << memory_size << std::endl;
return buffer_->data(); return buffer_->data();
} }
...@@ -95,13 +93,20 @@ void TensorLite::CopyDataFrom(const TensorLite &other) { ...@@ -95,13 +93,20 @@ void TensorLite::CopyDataFrom(const TensorLite &other) {
dims_ = other.dims_; dims_ = other.dims_;
target_ = other.target_; target_ = other.target_;
lod_ = other.lod_; lod_ = other.lod_;
auto dt = zynq_tensor_->dataType();
auto shape = other.zynq_tensor_->shape(); if (zynq_tensor_.get() == nullptr) {
zynq_tensor_.reset(new zynqmp::Tensor());
}
auto dt = zynq_tensor_->dataType();
Resize(other.dims()); Resize(other.dims());
auto shape = other.zynq_tensor_->shape();
zynq_tensor_->mutableData<void>(zynq_tensor_->dataType(), shape); zynq_tensor_->mutableData<void>(zynq_tensor_->dataType(), shape);
this->ZynqTensor()->copyFrom(other.ZynqTensor());
// this->ZynqTensor()->copyFrom(other.ZynqTensor());
memcpy(this->ZynqTensor()->data<void>(),
other.ZynqTensor()->data<void>(),
other.ZynqTensor()->shape().numel() * sizeof(float));
} }
} // namespace lite } // namespace lite
......
...@@ -81,6 +81,8 @@ class DDimLite { ...@@ -81,6 +81,8 @@ class DDimLite {
return !(a == b); return !(a == b);
} }
~DDimLite() {}
private: private:
std::vector<value_type> data_; std::vector<value_type> data_;
}; };
...@@ -142,7 +144,9 @@ class TensorLite { ...@@ -142,7 +144,9 @@ class TensorLite {
void *mutable_data(size_t memory_size); void *mutable_data(size_t memory_size);
void *mutable_data(TargetType target, size_t memory_size); void *mutable_data(TargetType target, size_t memory_size);
const void *raw_data() const { return buffer_->data(); } const void *raw_data() const {
return buffer_->data();
} // TODO(chonwhite) delete buffer;
size_t data_size() const { return this->dims().production(); } size_t data_size() const { return this->dims().production(); }
...@@ -150,17 +154,19 @@ class TensorLite { ...@@ -150,17 +154,19 @@ class TensorLite {
size_t offset() const { return offset_; } size_t offset() const { return offset_; }
bool IsInitialized() const { return buffer_->data(); } bool IsInitialized() const {
void clear() { return buffer_->data();
buffer_->Free(); } // TODO(chonwhite) delete buffer;
offset_ = 0;
}
// Other share data to this. // Other share data to this.
void ShareDataWith(const TensorLite &other); void ShareDataWith(const TensorLite &other);
void CopyDataFrom(const TensorLite &other); void CopyDataFrom(const TensorLite &other);
void clear() {
// zynq_tensor_->releaseData();
}
template <typename T> template <typename T>
TensorLite Slice(int64_t begin, int64_t end) const; TensorLite Slice(int64_t begin, int64_t end) const;
...@@ -169,7 +175,10 @@ class TensorLite { ...@@ -169,7 +175,10 @@ class TensorLite {
TargetType target() const { return target_; } TargetType target() const { return target_; }
zynqmp::Tensor *ZynqTensor() const { return zynq_tensor_; } // template <typename T>
// TensorLite Slice(int64_t begin, int64_t end) const;
zynqmp::Tensor *ZynqTensor() const { return zynq_tensor_.get(); }
friend std::ostream &operator<<(std::ostream &os, const TensorLite &tensor) { friend std::ostream &operator<<(std::ostream &os, const TensorLite &tensor) {
os << "Tensor:" << '\n'; os << "Tensor:" << '\n';
...@@ -198,12 +207,34 @@ class TensorLite { ...@@ -198,12 +207,34 @@ class TensorLite {
size_t memory_size_{}; size_t memory_size_{};
size_t offset_{0}; size_t offset_{0};
zynqmp::Tensor *zynq_tensor_ = new zynqmp::Tensor(); std::shared_ptr<zynqmp::Tensor> zynq_tensor_;
template <typename T> template <typename T>
void mutable_data_internal(); void mutable_data_internal();
}; };
template <typename T>
zynqmp::DataType get_date_type() {
zynqmp::DataType data_type = zynqmp::FP32;
if (typeid(T) == typeid(float)) {
data_type = zynqmp::FP32;
}
if (typeid(T) == typeid(zynqmp::float16)) {
data_type = zynqmp::FP16;
}
if (typeid(T) == typeid(int)) {
data_type = zynqmp::INT32;
}
if (typeid(T) == typeid(int32_t)) {
data_type = zynqmp::INT32;
}
if (typeid(T) == typeid(int8_t)) {
data_type = zynqmp::INT8;
}
return data_type;
}
template <typename T, typename R> template <typename T, typename R>
R *TensorLite::mutable_data() { R *TensorLite::mutable_data() {
std::vector<int> v; std::vector<int> v;
...@@ -229,14 +260,12 @@ R *TensorLite::mutable_data() { ...@@ -229,14 +260,12 @@ R *TensorLite::mutable_data() {
break; break;
} }
zynqmp::Shape input_shape(layout_type, v); zynqmp::Shape input_shape(layout_type, v);
zynqmp::DataType data_type = get_date_type<T>();
zynqmp::DataType data_type = zynqmp::FP32; if (zynq_tensor_.get() == nullptr) {
if (typeid(T) == typeid(float)) { zynq_tensor_.reset(new zynqmp::Tensor());
data_type = zynqmp::FP32;
}
if (typeid(T) == typeid(zynqmp::float16)) {
data_type = zynqmp::FP16;
} }
return zynq_tensor_->mutableData<R>(data_type, input_shape); return zynq_tensor_->mutableData<R>(data_type, input_shape);
} }
...@@ -276,6 +305,7 @@ TensorLite TensorLite::Slice(int64_t begin, int64_t end) const { ...@@ -276,6 +305,7 @@ TensorLite TensorLite::Slice(int64_t begin, int64_t end) const {
template <typename T> template <typename T>
void TensorLite::Slice(TensorLite &dst, int64_t begin, int64_t end) const { void TensorLite::Slice(TensorLite &dst, int64_t begin, int64_t end) const {
// TODO(chonwhite) delete this function;
CHECK_GE(begin, 0); CHECK_GE(begin, 0);
CHECK_LE(end, dims_[0]); CHECK_LE(end, dims_[0]);
CHECK_LT(begin, end); CHECK_LT(begin, end);
......
...@@ -25,7 +25,7 @@ namespace mir { ...@@ -25,7 +25,7 @@ namespace mir {
void ConvActivationFusePass::Apply(const std::unique_ptr<SSAGraph>& graph) { void ConvActivationFusePass::Apply(const std::unique_ptr<SSAGraph>& graph) {
std::vector<std::string> act_types{"relu"}; std::vector<std::string> act_types{"relu"};
for (auto& place : graph->valid_places()) { for (auto& place : graph->valid_places()) {
if (place.target == TARGET(kCUDA)) { if (place.target == TARGET(kCUDA) || place.target == TARGET(kFPGA)) {
act_types.push_back("leaky_relu"); act_types.push_back("leaky_relu");
break; break;
} }
......
...@@ -103,6 +103,12 @@ void DeleteDynamicQuantOpFuser::InsertNewNode(SSAGraph* graph, ...@@ -103,6 +103,12 @@ void DeleteDynamicQuantOpFuser::InsertNewNode(SSAGraph* graph,
// obtain values, save values and relink node // obtain values, save values and relink node
int bit_length = quant_node->stmt()->op_info()->GetAttr<int>("bit_length"); int bit_length = quant_node->stmt()->op_info()->GetAttr<int>("bit_length");
int range = ((1 << (bit_length - 1)) - 1);
auto* scope = quant_node->stmt()->op()->scope();
auto* scale_tensor = scope->FindVar(output_scale_node->arg()->name)
->GetMutable<lite::Tensor>();
float scale_value = scale_tensor->data<float>()[0] / range;
auto outlinks = output_act_node->outlinks; auto outlinks = output_act_node->outlinks;
for (auto* quantized_node : outlinks) { for (auto* quantized_node : outlinks) {
auto* op_desc = quantized_node->stmt()->mutable_op_info(); auto* op_desc = quantized_node->stmt()->mutable_op_info();
...@@ -208,9 +214,11 @@ void DequantOpFuser::InsertNewNode(SSAGraph* graph, ...@@ -208,9 +214,11 @@ void DequantOpFuser::InsertNewNode(SSAGraph* graph,
for (int i = 0; i < weight_scale_size; i++) { for (int i = 0; i < weight_scale_size; i++) {
weight_scale.push_back(whole_weight_scale); weight_scale.push_back(whole_weight_scale);
} }
#ifndef LITE_WITH_FPGA #ifndef LITE_WITH_FPGA
op_desc.SetAttr("enable_int8", true); op_desc.SetAttr("enable_int8", true);
#endif #endif
if (quantized_op->stmt()->op_info()->HasAttr("input_scale")) { if (quantized_op->stmt()->op_info()->HasAttr("input_scale")) {
op_desc.SetAttr("input_scale", input_scale); op_desc.SetAttr("input_scale", input_scale);
} }
...@@ -689,13 +697,16 @@ void DynamicQuantDequantOpFuser::InsertNewNode(SSAGraph* graph, ...@@ -689,13 +697,16 @@ void DynamicQuantDequantOpFuser::InsertNewNode(SSAGraph* graph,
float* temp_data = temp_tensor.mutable_data<float>(); float* temp_data = temp_tensor.mutable_data<float>();
size_t weight_num = quantized_weight_t->data_size(); size_t weight_num = quantized_weight_t->data_size();
quantized_weight_t->set_persistable(true); quantized_weight_t->set_persistable(true);
std::cout << "DynamicQuantDequantOpFuser::InsertNewNode====================" std::cout << "DynamicQuantDequantOpFuser::InsertNewNode===================="
"========================================" "========================================"
<< std::endl; << std::endl;
#ifdef LITE_WITH_FPGA #ifdef LITE_WITH_FPGA
float* quantized_weight_data = quantized_weight_t->mutable_data<float>(); float* quantized_weight_data = quantized_weight_t->mutable_data<float>();
for (size_t i = 0; i < weight_num; i++) { for (size_t i = 0; i < weight_num; i++) {
quantized_weight_data[i] = temp_data[i] * whole_weight_scale; quantized_weight_data[i] = temp_data[i] * whole_weight_scale;
std::cout << whole_weight_scale << "," << temp_data[i] << "," std::cout << whole_weight_scale << "," << temp_data[i] << ","
<< quantized_weight_data[i] << std::endl; << quantized_weight_data[i] << std::endl;
} }
......
...@@ -86,6 +86,8 @@ class KernelPlaceCorrectPass : public DebugPass { ...@@ -86,6 +86,8 @@ class KernelPlaceCorrectPass : public DebugPass {
<< node_name; << node_name;
VLOG(4) << "-- input arg_name:" << arg_name << " " VLOG(4) << "-- input arg_name:" << arg_name << " "
<< "-- node name:" << node_name; << "-- node name:" << node_name;
auto type = inst.picked_kernel().GetInputDeclType(arg_name);
if (!x_in->AsArg().type) { if (!x_in->AsArg().type) {
need_correct_place &= false; need_correct_place &= false;
} else { } else {
...@@ -107,6 +109,8 @@ class KernelPlaceCorrectPass : public DebugPass { ...@@ -107,6 +109,8 @@ class KernelPlaceCorrectPass : public DebugPass {
<< node_name << " in Inst " << node_name << " in Inst "
<< inst.op_type(); << inst.op_type();
VLOG(4) << "-- output arg_name " << arg_name; VLOG(4) << "-- output arg_name " << arg_name;
auto type = inst.picked_kernel().GetOutputDeclType(arg_name);
if (!x_out->AsArg().type) { if (!x_out->AsArg().type) {
need_correct_place &= false; need_correct_place &= false;
} else { } else {
......
文件模式从 100644 更改为 100755
...@@ -139,7 +139,14 @@ void RuntimeProgram::Run() { ...@@ -139,7 +139,14 @@ void RuntimeProgram::Run() {
for (auto& inst : instructions_) { for (auto& inst : instructions_) {
#ifndef LITE_WITH_FPGA #ifndef LITE_WITH_FPGA
if (inst.is_feed_fetch_op()) continue; if (inst.is_feed_fetch_op()) continue;
#endif
std::string op_type = inst.op()->op_info()->Type(); std::string op_type = inst.op()->op_info()->Type();
VLOG(4) << ">> Running kernel: " << inst.op()->op_info()->Repr()
<< " on Target " << TargetToStr(inst.kernel()->target());
#ifndef LITE_WITH_FPGA
if (op_type == "feed" || op_type == "fetch") continue;
#endif #endif
inst.Run(); inst.Run();
#ifdef LITE_WITH_PROFILE #ifdef LITE_WITH_PROFILE
......
...@@ -46,7 +46,7 @@ class Tensor { ...@@ -46,7 +46,7 @@ class Tensor {
*/ */
class PaddlePredictor { class PaddlePredictor {
public: public:
void Init(); void Init() {}
std::unique_ptr<Tensor> GetTensor(const std::string &id) const; std::unique_ptr<Tensor> GetTensor(const std::string &id) const;
std::unique_ptr<Tensor> GetMutableTensor(const std::string &id); std::unique_ptr<Tensor> GetMutableTensor(const std::string &id);
......
...@@ -59,6 +59,7 @@ void SequencePoolCompute::Run() { ...@@ -59,6 +59,7 @@ void SequencePoolCompute::Run() {
for (int i = 0; i <= batch_size; i++) { for (int i = 0; i <= batch_size; i++) {
offset_new[i] = i; offset_new[i] = i;
} }
(output->mutable_lod())->clear();
(output->mutable_lod())->push_back(offset_new); (output->mutable_lod())->push_back(offset_new);
} }
......
...@@ -5,28 +5,32 @@ endif() ...@@ -5,28 +5,32 @@ endif()
set(fpga_deps fpga_target_wrapper kernel_fpga) set(fpga_deps fpga_target_wrapper kernel_fpga)
# add_kernel(activation_compute_fpga FPGA basic SRCS activation_compute.cc DEPS ${fpga_deps}) add_kernel(activation_compute_fpga FPGA basic SRCS activation_compute.cc DEPS ${fpga_deps})
# add_kernel(box_coder_compute_fpga FPGA basic SRCS box_coder_compute.cc DEPS ${fpga_deps}) # add_kernel(box_coder_compute_fpga FPGA basic SRCS box_coder_compute.cc DEPS ${fpga_deps})
# add_kernel(concat_compute_fpga FPGA basic SRCS concat_compute.cc DEPS ${fpga_deps})
add_kernel(concat_compute_fpga FPGA basic SRCS concat_compute.cc DEPS ${fpga_deps})
add_kernel(conv_compute_fpga FPGA basic SRCS conv_compute.cc DEPS ${fpga_deps}) add_kernel(conv_compute_fpga FPGA basic SRCS conv_compute.cc DEPS ${fpga_deps})
# add_kernel(density_prior_box_compute_fpga FPGA basic SRCS density_prior_box_compute.cc DEPS ${fpga_deps}) # add_kernel(density_prior_box_compute_fpga FPGA basic SRCS density_prior_box_compute.cc DEPS ${fpga_deps})
add_kernel(dropout_compute_fpga FPGA basic SRCS dropout_compute.cc DEPS ${fpga_deps}) add_kernel(dropout_compute_fpga FPGA basic SRCS dropout_compute.cc DEPS ${fpga_deps})
add_kernel(elementwise_compute_fpga FPGA basic SRCS elementwise_compute.cc DEPS ${fpga_deps}) add_kernel(elementwise_compute_fpga FPGA basic SRCS elementwise_compute.cc DEPS ${fpga_deps})
# add_kernel(feed_compute_fpga FPGA basic SRCS fc_compute.cc DEPS ${fpga_deps})
add_kernel(fc_compute_fpga FPGA basic SRCS fc_compute.cc DEPS ${fpga_deps}) add_kernel(fc_compute_fpga FPGA basic SRCS fc_compute.cc DEPS ${fpga_deps})
add_kernel(gru_compute_fpga FPGA extra SRCS gru_compute.cc DEPS ${fpga_deps}) add_kernel(gru_compute_fpga FPGA extra SRCS gru_compute.cc DEPS ${fpga_deps})
# add_kernel(mul_compute_fpga FPGA basic SRCS mul_compute.cc DEPS ${fpga_deps}) # add_kernel(mul_compute_fpga FPGA basic SRCS mul_compute.cc DEPS ${fpga_deps})
add_kernel(multiclass_nms_compute_fpga FPGA basic SRCS multiclass_nms_compute.cc DEPS ${fpga_deps}) add_kernel(multiclass_nms_compute_fpga FPGA basic SRCS multiclass_nms_compute.cc DEPS ${fpga_deps})
add_kernel(norm_compute_fpga FPGA basic SRCS norm_compute.cc DEPS ${fpga_deps}) add_kernel(norm_compute_fpga FPGA basic SRCS norm_compute.cc DEPS ${fpga_deps})
# add_kernel(im2sequence_compute_fpga FPGA basic SRCS im2sequence_compute.cc DEPS ${fpga_deps}) # add_kernel(im2sequence_compute_fpga FPGA basic SRCS im2sequence_compute.cc DEPS ${fpga_deps})
add_kernel(pooling_compute_fpga FPGA basic SRCS pooling_compute.cc DEPS ${fpga_deps}) add_kernel(pooling_compute_fpga FPGA basic SRCS pooling_compute.cc DEPS ${fpga_deps})
add_kernel(prior_box_compute_fpga FPGA basic SRCS prior_box_compute.cc DEPS ${fpga_deps}) add_kernel(prior_box_compute_fpga FPGA basic SRCS prior_box_compute.cc DEPS ${fpga_deps})
# add_kernel(reshape_compute_fpga FPGA basic SRCS reshape_compute.cc DEPS ${fpga_deps} reshape_op) add_kernel(reshape_compute_fpga FPGA basic SRCS reshape_compute.cc DEPS ${fpga_deps} reshape_op)
# add_kernel(sequence_pool_compute_fpga FPGA basic SRCS sequence_pool_compute.cc DEPS ${fpga_deps}) # add_kernel(sequence_pool_compute_fpga FPGA basic SRCS sequence_pool_compute.cc DEPS ${fpga_deps})
add_kernel(scale_compute_fpga FPGA basic SRCS scale_compute.cc DEPS ${fpga_deps}) add_kernel(scale_compute_fpga FPGA basic SRCS scale_compute.cc DEPS ${fpga_deps})
# add_kernel(softmax_compute_fpga FPGA basic SRCS softmax_compute.cc DEPS ${fpga_deps}) add_kernel(softmax_compute_fpga FPGA basic SRCS softmax_compute.cc DEPS ${fpga_deps})
# add_kernel(transpose_compute_fpga FPGA basic SRCS transpose_compute.cc DEPS ${fpga_deps}) add_kernel(split_compute_fpga FPGA basic SRCS split_compute.cc DEPS ${fpga_deps})
add_kernel(transpose_compute_fpga FPGA basic SRCS transpose_compute.cc DEPS ${fpga_deps})
add_kernel(io_copy_compute_fpga FPGA basic SRCS io_copy_compute.cc DEPS ${fpga_deps}) add_kernel(io_copy_compute_fpga FPGA basic SRCS io_copy_compute.cc DEPS ${fpga_deps})
add_kernel(calib_compute_fpga FPGA basic SRCS calib_compute.cc DEPS ${fpga_deps}) add_kernel(calib_compute_fpga FPGA basic SRCS calib_compute.cc DEPS ${fpga_deps})
......
...@@ -25,10 +25,10 @@ using float16 = zynqmp::float16; ...@@ -25,10 +25,10 @@ using float16 = zynqmp::float16;
void ReluCompute::PrepareForRun() { void ReluCompute::PrepareForRun() {
auto& param = this->Param<param_t>(); auto& param = this->Param<param_t>();
auto output_data = param.Out->mutable_data<float16>(); auto output_data = param.Out->mutable_data<float16>();
zynqmp::InputParam& input_param = pe_.param(); zynqmp::InputParam& relu_param = pe_.param();
input_param.input = param.X->ZynqTensor(); relu_param.input = param.X->ZynqTensor();
input_param.output = param.Out->ZynqTensor(); relu_param.output = param.Out->ZynqTensor();
pe_.init(); pe_.init();
pe_.apply(); pe_.apply();
} }
......
...@@ -72,12 +72,6 @@ void ConvCompute::PrepareForRun() { ...@@ -72,12 +72,6 @@ void ConvCompute::PrepareForRun() {
conv_param.activeParam.type = zynqmp::TYPE_RELU; conv_param.activeParam.type = zynqmp::TYPE_RELU;
} }
// conv_param.filter->saveToFile("conv_filter_", true);
// if (param.bias != nullptr) {
// std::cout << "param.bias != nullptr" << std::endl;
// conv_param.bias()->saveToFile("conv_bias_", true);
// }
conv_pe_.init(); conv_pe_.init();
conv_pe_.apply(); conv_pe_.apply();
} }
......
...@@ -80,21 +80,21 @@ void ElementwiseMulCompute::PrepareForRun() { ...@@ -80,21 +80,21 @@ void ElementwiseMulCompute::PrepareForRun() {
scale_param.activeParam.type = zynqmp::TYPE_NONE; scale_param.activeParam.type = zynqmp::TYPE_NONE;
int channel = scale_param.input->shape().channel(); int channel = scale_param.input->shape().channel();
zynqmp::Tensor* scale = new zynqmp::Tensor(); scale_param.scale = &scale_;
zynqmp::Tensor* bias = new zynqmp::Tensor(); scale_param.bias = &bias_;
scale_param.scale = scale;
scale_param.bias = bias;
zynqmp::Shape shape(zynqmp::N, {channel}); zynqmp::Shape shape(zynqmp::N, {channel});
float* scale_data = scale->mutableData<float>(zynqmp::FP32, shape); zynqmp::float16* scale_data =
float* bias_data = bias->mutableData<float>(zynqmp::FP32, shape); scale_.mutableData<zynqmp::float16>(zynqmp::FP16, shape);
zynqmp::float16* bias_data =
bias_.mutableData<zynqmp::float16>(zynqmp::FP16, shape);
float scale_value = param.Y->data<float>()[0]; float scale_value = param.Y->data<float>()[0];
for (int i = 0; i < channel; ++i) { for (int i = 0; i < channel; i++) {
if (param.Y->dims().production() != 1) { if (param.Y->dims().production() != 1) {
scale_value = param.Y->ZynqTensor()->data<float>()[i]; scale_value = param.Y->ZynqTensor()->data<float>()[i];
} }
scale_data[i] = scale_value; scale_data[i] = zynqmp::float_to_half(scale_value);
bias_data[i] = 0; bias_data[i] = zero_;
} }
pe_.init(); pe_.init();
...@@ -102,6 +102,10 @@ void ElementwiseMulCompute::PrepareForRun() { ...@@ -102,6 +102,10 @@ void ElementwiseMulCompute::PrepareForRun() {
} }
void ElementwiseMulCompute::Run() { void ElementwiseMulCompute::Run() {
auto& param = Param<operators::ElementwiseParam>();
param.Y->ZynqTensor()->flush();
scale_.copyFrom(param.Y->ZynqTensor());
scale_.invalidate();
pe_.dispatch(); pe_.dispatch();
#ifdef FPGA_PRINT_TENSOR #ifdef FPGA_PRINT_TENSOR
zynqmp::ScaleParam& scale_param = pe_.param(); zynqmp::ScaleParam& scale_param = pe_.param();
......
...@@ -61,6 +61,9 @@ class ElementwiseMulCompute ...@@ -61,6 +61,9 @@ class ElementwiseMulCompute
private: private:
zynqmp::ScalePE pe_; zynqmp::ScalePE pe_;
zynqmp::Tensor scale_;
zynqmp::Tensor bias_;
zynqmp::float16 zero_ = zynqmp::float_to_half(0.0f);
}; };
} // namespace fpga } // namespace fpga
......
...@@ -93,21 +93,6 @@ void elementwise_compute_ref(const operators::ElementwiseParam& param, ...@@ -93,21 +93,6 @@ void elementwise_compute_ref(const operators::ElementwiseParam& param,
} }
// do elementwise add/sub/max... // do elementwise add/sub/max...
if (elt_type == "add") { if (elt_type == "add") {
// for (int i = 0; i < batch; ++i) {
// for (int j = 0; j < channels; ++j) {
// int offset = (i * channels + j) * num;
// const dtype* din_ptr = x_data + offset;
// const dtype diny_data = y_data[j];
// dtype* dout_ptr = out_data + offset;
// for (int k = 0; k < num; ++k) {
// *dout_ptr =
// zynqmp::float_to_half(sum(zynqmp::half_to_float(*din_ptr),
// zynqmp::half_to_float(diny_data)));
// dout_ptr++;
// din_ptr++;
// }
// }
// }
int count = x_dims[0] * x_dims[1] * x_dims[2] * x_dims[3]; int count = x_dims[0] * x_dims[1] * x_dims[2] * x_dims[3];
for (int i = 0; i < count; ++i) { for (int i = 0; i < count; ++i) {
out_data[i] = zynqmp::float_to_half(sum( out_data[i] = zynqmp::float_to_half(sum(
...@@ -229,75 +214,6 @@ TEST(fusion_elementwise_add_activation_fpga, retrive_op) { ...@@ -229,75 +214,6 @@ TEST(fusion_elementwise_add_activation_fpga, retrive_op) {
ASSERT_TRUE(fusion_elementwise_add_activation.front()); ASSERT_TRUE(fusion_elementwise_add_activation.front());
} }
// TEST(fusion_elementwise_add_activation_fpga, init) {
// ElementwiseAddActivationCompute fusion_elementwise_add_activation;
// ASSERT_EQ(fusion_elementwise_add_activation.precision(), PRECISION(kFP16));
// ASSERT_EQ(fusion_elementwise_add_activation.target(), TARGET(kFPGA));
// }
// TEST(fusion_elementwise_add_activation_fpga, compute) {
// ElementwiseAddActivationCompute fusion_elementwise_add_activation;
// operators::FusionElementwiseActivationParam param;
// lite::Tensor x, y, output, output_ref;
// for (auto act_type : {"relu"}) {
// for (auto n : {1}) {
// for (auto c : {8}) {
// for (auto h : {8}) {
// for (auto w : {8}) {
// for (auto axis : {0}) {
// for (auto yd : {std::vector<int64_t>({n, c, h, w})}) {
// auto x_dim = DDim(std::vector<int64_t>({n, c, h, w}));
// auto y_dim = DDim(yd);
// int axis_t = axis < 0 ? x_dim.size() - y_dim.size() : axis;
// if (axis_t + y_dim.size() > 4) continue;
// bool flag = false;
// for (int i = 0; i < y_dim.size(); i++) {
// if (x_dim[i + axis_t] != y_dim[i]) flag = true;
// }
// if (flag) continue;
// x.Resize(x_dim);
// y.Resize(y_dim);
// output.Resize(x_dim);
// output_ref.Resize(x_dim);
// auto* x_data = x.mutable_data<float16>(TARGET(kFPGA));
// auto* y_data = y.mutable_data<float16>(TARGET(kFPGA));
// auto* output_data =
// output.mutable_data<float16>(TARGET(kFPGA));
// auto* output_ref_data =
// output_ref.mutable_data<float16>(TARGET(kFPGA));
// for (int i = 0; i < x_dim.production(); i++) {
// float sign = i % 3 == 0 ? -1.0f : 1.0f;
// x_data[i] = zynqmp::float_to_half(i * sign);
// }
// for (int i = 0; i < y_dim.production(); i++) {
// float sign = i % 2 == 0 ? 0.5f : -0.5f;
// y_data[i] = zynqmp::float_to_half(i * sign);
// }
// param.X = &x;
// param.Y = &y;
// param.axis = axis;
// param.Out = &output;
// param.act_type = act_type;
// fusion_elementwise_add_activation.SetParam(param);
// fusion_elementwise_add_activation.PrepareForRun();
// fusion_elementwise_add_activation.Run();
// param.Out = &output_ref;
// elementwise_compute_ref<float16>(param, "add", act_type);
// for (int i = 0; i < output.dims().production(); i++) {
// EXPECT_NEAR(output_data[i], output_ref_data[i], 1e-5);
// }
// }
// }
// }
// }
// }
// }
// }
// }
} // namespace fpga } // namespace fpga
} // namespace kernels } // namespace kernels
} // namespace lite } // namespace lite
......
...@@ -40,8 +40,8 @@ void FeedCompute::PrepareForRun() { ...@@ -40,8 +40,8 @@ void FeedCompute::PrepareForRun() {
void FeedCompute::Run() { void FeedCompute::Run() {
auto& param = this->Param<param_t>(); auto& param = this->Param<param_t>();
Tensor& x = param.feed_list->at(param.col); Tensor& x = param.feed_list->at(param.col);
pe_.param().input = x.ZynqTensor();
pe_.dispatch(); pe_.dispatch();
auto out_lod = param.out->mutable_lod(); auto out_lod = param.out->mutable_lod();
*out_lod = x.lod(); *out_lod = x.lod();
......
...@@ -55,6 +55,7 @@ void FetchCompute::Run() { ...@@ -55,6 +55,7 @@ void FetchCompute::Run() {
#ifdef FPGA_PRINT_TENSOR #ifdef FPGA_PRINT_TENSOR
zynqmp::OutputParam& fetch_param = pe_.param(); zynqmp::OutputParam& fetch_param = pe_.param();
Debugger::get_instance().registerOutput("fetch", fetch_param.output); Debugger::get_instance().registerOutput("fetch", fetch_param.output);
Debugger::get_instance().setEnable(true);
#endif #endif
} }
......
...@@ -45,21 +45,32 @@ class IoCopyHostToFpgaCompute ...@@ -45,21 +45,32 @@ class IoCopyHostToFpgaCompute
auto& param = Param<operators::IoCopyParam>(); auto& param = Param<operators::IoCopyParam>();
CHECK(param.x->target() == TARGET(kHost) || CHECK(param.x->target() == TARGET(kHost) ||
param.x->target() == TARGET(kFPGA)); param.x->target() == TARGET(kFPGA));
param.y->mutable_data<float16>(); param.x->ZynqTensor()->flush();
if (param.x->ZynqTensor()->aligned() &&
param.x->ZynqTensor()->shape().shouldAlign()) { if (param.x->ZynqTensor()->dataType() == zynqmp::INT32) {
zynqmp::Tensor tempTensor; param.y->mutable_data<int>();
tempTensor.mutableData<float16>(zynqmp::FP16,
param.x->ZynqTensor()->shape());
tempTensor.copyFrom(param.x->ZynqTensor());
tempTensor.setAligned(true);
tempTensor.unalignImage();
param.y->ZynqTensor()->copyFrom(&tempTensor);
} else {
param.y->ZynqTensor()->copyFrom(param.x->ZynqTensor()); param.y->ZynqTensor()->copyFrom(param.x->ZynqTensor());
return;
} }
param.y->ZynqTensor()->invalidate();
param.y->ZynqTensor()->copyScaleFrom(param.x->ZynqTensor()); if (param.x->ZynqTensor()->dataType() == zynqmp::FP32) {
param.y->mutable_data<float16>();
if (param.x->ZynqTensor()->aligned() &&
param.x->ZynqTensor()->shape().shouldAlign()) {
zynqmp::Tensor tempTensor;
tempTensor.mutableData<float16>(zynqmp::FP16,
param.x->ZynqTensor()->shape());
tempTensor.copyFrom(param.x->ZynqTensor());
tempTensor.setAligned(true);
tempTensor.unalignImage();
param.y->ZynqTensor()->copyFrom(&tempTensor);
} else {
param.y->ZynqTensor()->copyFrom(param.x->ZynqTensor());
}
param.y->ZynqTensor()->invalidate();
param.y->ZynqTensor()->copyScaleFrom(param.x->ZynqTensor());
}
auto out_lod = param.y->mutable_lod(); auto out_lod = param.y->mutable_lod();
*out_lod = param.x->lod(); *out_lod = param.x->lod();
} }
......
...@@ -80,7 +80,8 @@ void mul(MulCompute* k) { ...@@ -80,7 +80,8 @@ void mul(MulCompute* k) {
} }
void MulCompute::Run() { void MulCompute::Run() {
pe_.dispatch(); // pe_.dispatch();
mul(this);
#ifdef FPGA_PRINT_TENSOR #ifdef FPGA_PRINT_TENSOR
zynqmp::FullyConnectedParam& fc_param = pe_.param(); zynqmp::FullyConnectedParam& fc_param = pe_.param();
Debugger::get_instance().registerOutput("mul", fc_param.output); Debugger::get_instance().registerOutput("mul", fc_param.output);
......
...@@ -318,14 +318,29 @@ void MultiClassOutput(const Tensor& scores, ...@@ -318,14 +318,29 @@ void MultiClassOutput(const Tensor& scores,
void MulticlassNmsCompute::Run() { void MulticlassNmsCompute::Run() {
auto& param = Param<operators::MulticlassNmsParam>(); auto& param = Param<operators::MulticlassNmsParam>();
auto* boxes = param.bboxes; auto* boxes_in = param.bboxes;
auto* scores = param.scores; auto* scores_in = param.scores;
auto* outs = param.out; auto* outs = param.out;
outs->mutable_data<float>(); outs->mutable_data<float>();
auto score_dims = scores->dims(); auto score_dims = boxes_in->dims();
auto score_size = score_dims.size(); auto score_size = score_dims.size();
Tensor boxes_float;
Tensor scores_float;
boxes_float.Resize(boxes_in->dims());
scores_float.Resize(scores_in->dims());
boxes_float.mutable_data<float>();
scores_float.mutable_data<float>();
boxes_float.ZynqTensor()->copyFrom(boxes_in->ZynqTensor());
scores_float.ZynqTensor()->copyFrom(scores_in->ZynqTensor());
Tensor* boxes = &boxes_float;
Tensor* scores = &scores_float;
auto box_dims = boxes->dims(); auto box_dims = boxes->dims();
int64_t box_dim = boxes->dims()[2]; int64_t box_dim = boxes->dims()[2];
...@@ -383,6 +398,7 @@ void MulticlassNmsCompute::Run() { ...@@ -383,6 +398,7 @@ void MulticlassNmsCompute::Run() {
MultiClassOutput<float>( MultiClassOutput<float>(
scores_slice, boxes_slice, all_indices[i], score_dims.size(), &out); scores_slice, boxes_slice, all_indices[i], score_dims.size(), &out);
outs->ZynqTensor()->copyFrom(out.ZynqTensor()); outs->ZynqTensor()->copyFrom(out.ZynqTensor());
out.ZynqTensor()->saveToFile("nms_oo", true);
} }
outs->Resize({static_cast<int64_t>(e - s), out_dim}); outs->Resize({static_cast<int64_t>(e - s), out_dim});
} }
...@@ -402,16 +418,16 @@ void MulticlassNmsCompute::Run() { ...@@ -402,16 +418,16 @@ void MulticlassNmsCompute::Run() {
} // namespace lite } // namespace lite
} // namespace paddle } // namespace paddle
REGISTER_LITE_KERNEL(multiclass_nms, // REGISTER_LITE_KERNEL(multiclass_nms,
kFPGA, // kFPGA,
kFP16, // kFP16,
kNHWC, // kNHWC,
paddle::lite::kernels::fpga::MulticlassNmsCompute, // paddle::lite::kernels::fpga::MulticlassNmsCompute,
def) // def)
.BindInput("BBoxes", {LiteType::GetTensorTy(TARGET(kHost))}) // .BindInput("BBoxes", {LiteType::GetTensorTy(TARGET(kHost))})
.BindInput("Scores", {LiteType::GetTensorTy(TARGET(kHost))}) // .BindInput("Scores", {LiteType::GetTensorTy(TARGET(kHost))})
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kHost))}) // .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kHost))})
.Finalize(); // .Finalize();
REGISTER_LITE_KERNEL(multiclass_nms, REGISTER_LITE_KERNEL(multiclass_nms,
kFPGA, kFPGA,
...@@ -427,5 +443,8 @@ REGISTER_LITE_KERNEL(multiclass_nms, ...@@ -427,5 +443,8 @@ REGISTER_LITE_KERNEL(multiclass_nms,
{LiteType::GetTensorTy(TARGET(kFPGA), {LiteType::GetTensorTy(TARGET(kFPGA),
PRECISION(kFP16), PRECISION(kFP16),
DATALAYOUT(kNHWC))}) DATALAYOUT(kNHWC))})
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kHost))}) .BindOutput("Out",
{LiteType::GetTensorTy(TARGET(kFPGA),
PRECISION(kFloat),
DATALAYOUT(kNHWC))})
.Finalize(); .Finalize();
...@@ -131,3 +131,27 @@ REGISTER_LITE_KERNEL(prior_box, ...@@ -131,3 +131,27 @@ REGISTER_LITE_KERNEL(prior_box,
.BindOutput("Boxes", {LiteType::GetTensorTy(TARGET(kARM))}) .BindOutput("Boxes", {LiteType::GetTensorTy(TARGET(kARM))})
.BindOutput("Variances", {LiteType::GetTensorTy(TARGET(kARM))}) .BindOutput("Variances", {LiteType::GetTensorTy(TARGET(kARM))})
.Finalize(); .Finalize();
// REGISTER_LITE_KERNEL(prior_box,
// kFPGA,
// kFP16,
// kNHWC,
// paddle::lite::kernels::fpga::PriorBoxCompute,
// def)
// .BindInput("Input",
// {LiteType::GetTensorTy(TARGET(kFPGA),
// PRECISION(kFP16),
// DATALAYOUT(kNHWC))})
// .BindInput("Image",
// {LiteType::GetTensorTy(TARGET(kFPGA),
// PRECISION(kFP16),
// DATALAYOUT(kNHWC))})
// .BindOutput("Boxes",
// {LiteType::GetTensorTy(TARGET(kFPGA),
// PRECISION(kFP16),
// DATALAYOUT(kNHWC))})
// .BindOutput("Variances",
// {LiteType::GetTensorTy(TARGET(kFPGA),
// PRECISION(kFP16),
// DATALAYOUT(kNHWC))})
// .Finalize();
...@@ -38,18 +38,15 @@ void ReshapeCompute::Run() { ...@@ -38,18 +38,15 @@ void ReshapeCompute::Run() {
auto* actual_shape_data = actual_shape->data<int>(); auto* actual_shape_data = actual_shape->data<int>();
auto shape = std::vector<int>( auto shape = std::vector<int>(
actual_shape_data, actual_shape_data + actual_shape_dims.production()); actual_shape_data, actual_shape_data + actual_shape_dims.production());
output_dims = lite::operators::ValidateShape(shape, x_dims); // output_dims = lite::operators::ValidateShape(shape, x_dims); //TODO
output->Resize(output_dims); output->Resize(output_dims);
} }
if (inplace) { // if (inplace) {
output->ShareDataWith(*x); // output->ShareDataWith(*x);
} else { // } else {
output->CopyDataFrom(*x); // output->CopyDataFrom(*x);
} // }
output->ZynqTensor()->copyFrom(x->ZynqTensor());
param.x->ZynqTensor()->saveToFile("reshape_in", true);
output->ZynqTensor()->saveToFile("reshape_out", true);
output->Resize(output_dims); output->Resize(output_dims);
} }
......
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/kernels/fpga/split_compute.h"
#include <vector>
#include "lite/backends/arm/math/funcs.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace fpga {
void SplitCompute::PrepareForRun() {
auto& param = Param<operators::SplitParam>();
zynqmp::SplitParam& split_param = pe_.param();
split_param.input = param.x->ZynqTensor();
auto& dout = param.output;
for (int i = 0; i < dout.size(); i++) {
dout[i]->mutable_data<zynqmp::float16>();
split_param.outputs.push_back(dout[i]->ZynqTensor());
}
pe_.init();
pe_.apply();
}
void SplitCompute::Run() {
zynqmp::SplitParam& split_param = pe_.param();
pe_.dispatch();
#ifdef FPGA_PRINT_TENSOR
auto& dout = param.output;
for (int i = 0; i < dout.size(); i++) {
Debugger::get_instance().registerOutput("split", split_param.outputs[0]);
}
#endif
}
} // namespace fpga
} // namespace kernels
} // namespace lite
} // namespace paddle
REGISTER_LITE_KERNEL(
split, kFPGA, kFP16, kNHWC, paddle::lite::kernels::fpga::SplitCompute, def)
.BindInput("X",
{LiteType::GetTensorTy(TARGET(kFPGA),
PRECISION(kFP16),
DATALAYOUT(kNHWC))})
.BindInput("AxisTensor",
{LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt32))})
.BindInput("SectionsTensorList",
{LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt32))})
.BindOutput("Out",
{LiteType::GetTensorTy(TARGET(kFPGA),
PRECISION(kFP16),
DATALAYOUT(kNHWC))})
.Finalize();
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <algorithm>
#include "lite/core/kernel.h"
#include "lite/core/op_registry.h"
#include "lite/backends/fpga/KD/float16.hpp"
#include "lite/backends/fpga/KD/pes/split_pe.hpp"
namespace paddle {
namespace lite {
namespace kernels {
namespace fpga {
class SplitCompute
: public KernelLite<TARGET(kFPGA), PRECISION(kFP16), DATALAYOUT(kNHWC)> {
public:
void PrepareForRun() override;
void Run() override;
virtual ~SplitCompute() = default;
private:
zynqmp::SplitPE pe_;
};
} // namespace fpga
} // namespace kernels
} // namespace lite
} // namespace paddle
...@@ -81,7 +81,17 @@ void transposeCompute(operators::TransposeParam param) { ...@@ -81,7 +81,17 @@ void transposeCompute(operators::TransposeParam param) {
} }
// Transpose // Transpose
void TransposeCompute::Run() { auto& param = this->Param<param_t>(); } void TransposeCompute::Run() {
auto& param = this->Param<param_t>();
param.output->mutable_data<zynqmp::float16>();
param.x->ZynqTensor()->invalidate();
param.x->ZynqTensor()->unalignImage();
if (param.x->dims().size() != 4) {
transposeCompute(param);
} else {
param.output->ZynqTensor()->copyFrom(param.x->ZynqTensor());
}
}
// Transpose2 // Transpose2
void Transpose2Compute::Run() { void Transpose2Compute::Run() {
......
...@@ -4,6 +4,7 @@ add_kernel(feed_compute_host Host basic SRCS feed_compute.cc DEPS ${lite_kernel_ ...@@ -4,6 +4,7 @@ add_kernel(feed_compute_host Host basic SRCS feed_compute.cc DEPS ${lite_kernel_
add_kernel(fetch_compute_host Host basic SRCS fetch_compute.cc DEPS ${lite_kernel_deps}) add_kernel(fetch_compute_host Host basic SRCS fetch_compute.cc DEPS ${lite_kernel_deps})
add_kernel(reshape_compute_host Host basic SRCS reshape_compute.cc DEPS ${lite_kernel_deps} reshape_op) add_kernel(reshape_compute_host Host basic SRCS reshape_compute.cc DEPS ${lite_kernel_deps} reshape_op)
add_kernel(multiclass_nms_compute_host Host basic SRCS multiclass_nms_compute.cc DEPS ${lite_kernel_deps}) add_kernel(multiclass_nms_compute_host Host basic SRCS multiclass_nms_compute.cc DEPS ${lite_kernel_deps})
add_kernel(one_hot_compute_host Host extra SRCS one_hot_compute.cc DEPS ${lite_kernel_deps}) add_kernel(one_hot_compute_host Host extra SRCS one_hot_compute.cc DEPS ${lite_kernel_deps})
#lite_cc_test(test_reshape_compute_host SRCS reshape_compute_test.cc DEPS reshape_compute_host any) #lite_cc_test(test_reshape_compute_host SRCS reshape_compute_test.cc DEPS reshape_compute_host any)
......
文件模式从 100644 更改为 100755
文件模式从 100644 更改为 100755
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册