提交 b9ce0794 编写于 作者: qnqinan's avatar qnqinan 提交者: jameswu2014

update io file to support int8 tensor feed and fetch, fixes #1602 (#1603)

* update concat and split kernel and related files in FPGA v2(v3) track

* update

* update

* update kernel and related files in FPGA v2 track

* update

* update

* update kernel and related files for static quantization in FPGA v2 track

* update

* update feed and fetch kernel in FPGA v2 track

* update io file
上级 ae204098
...@@ -32,7 +32,7 @@ void format_image(framework::Tensor *image_tensor) { ...@@ -32,7 +32,7 @@ void format_image(framework::Tensor *image_tensor) {
int8_t *p_data = external_ptr == nullptr ? data_ptr : external_ptr; int8_t *p_data = external_ptr == nullptr ? data_ptr : external_ptr;
image::format_image<int8_t>(&p_data, channel, height, width); image::format_image<int8_t>(&p_data, channel, height, width);
if (p_data != data_ptr && external_ptr == nullptr) { if (p_data != data_ptr) {
image_tensor->reset_data_ptr(p_data); image_tensor->reset_data_ptr(p_data);
} }
} }
...@@ -43,7 +43,6 @@ void format_ofm(framework::Tensor *ofm_tensor) { ...@@ -43,7 +43,6 @@ void format_ofm(framework::Tensor *ofm_tensor) {
} else { } else {
format_int8_ofm(ofm_tensor); format_int8_ofm(ofm_tensor);
} }
format_int8_ofm(ofm_tensor);
} }
void format_int8_ofm(framework::Tensor *ofm_tensor) { void format_int8_ofm(framework::Tensor *ofm_tensor) {
......
...@@ -131,9 +131,12 @@ void ConvertTensors(const framework::Tensor &src, PaddleTensor *des) { ...@@ -131,9 +131,12 @@ void ConvertTensors(const framework::Tensor &src, PaddleTensor *des) {
if (src.type() == type_id<float>()) { if (src.type() == type_id<float>()) {
des->data.Reset(const_cast<float *>(src.data<float>()), des->data.Reset(const_cast<float *>(src.data<float>()),
num * sizeof(float)); num * sizeof(float));
} else { } else if (src.type() == type_id<half>()) {
des->data.Reset(const_cast<int16_t *>(src.data<int16_t>()), des->data.Reset(const_cast<int16_t *>(src.data<int16_t>()),
num * sizeof(int16_t)); num * sizeof(int16_t));
} else {
des->data.Reset(const_cast<int8_t *>(src.data<int8_t>()),
num * sizeof(int8_t));
} }
} }
...@@ -143,7 +146,11 @@ void PaddleMobilePredictor<Device, T>::FeedPaddleTensors( ...@@ -143,7 +146,11 @@ void PaddleMobilePredictor<Device, T>::FeedPaddleTensors(
auto num = inputs.size(); auto num = inputs.size();
std::vector<framework::Tensor> tensors(num, framework::Tensor()); std::vector<framework::Tensor> tensors(num, framework::Tensor());
for (int i = 0; i < num; i++) { for (int i = 0; i < num; i++) {
tensors[i].init(type_id<float>().hash_code()); if (inputs[i].dtypeid == type_id<int8_t>().hash_code()) {
tensors[i].init(type_id<int8_t>().hash_code());
} else {
tensors[i].init(type_id<float>().hash_code());
}
ConvertPaddleTensors(inputs[i], &tensors[i]); ConvertPaddleTensors(inputs[i], &tensors[i]);
} }
paddle_mobile_->FeedTensorData(tensors); paddle_mobile_->FeedTensorData(tensors);
......
...@@ -44,6 +44,7 @@ void FeedKernel<FPGA, float>::Compute(const FeedParam<FPGA> &param) { ...@@ -44,6 +44,7 @@ void FeedKernel<FPGA, float>::Compute(const FeedParam<FPGA> &param) {
} }
fpga::format_image(input); fpga::format_image(input);
output->ShareDataWith(*input); output->ShareDataWith(*input);
input->external_data = nullptr;
} }
template class FeedKernel<FPGA, float>; template class FeedKernel<FPGA, float>;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册