未验证 提交 77fe9936 编写于 作者: S suiyang 提交者: GitHub

Merge branch 'develop' into dev-latest

......@@ -83,7 +83,7 @@ Paddle-Mobile是PaddlePaddle组织下的项目,是一个致力于嵌入式平
- **FPGA**
FPGA实现正在进行中,是基于Xilinx的ZU5目标开发板。
目前已经支持 ZCU102 开发板。
- **灵活性**
......@@ -112,6 +112,7 @@ Paddle-Mobile是PaddlePaddle组织下的项目,是一个致力于嵌入式平
开发文档主要是关于编译、运行等问题。做为开发者,它可以和贡献文档共同结合使用。
* [iOS](https://github.com/PaddlePaddle/paddle-mobile/blob/develop/doc/development_ios.md)
* [Android](https://github.com/PaddlePaddle/paddle-mobile/blob/develop/doc/development_android.md)
* [FPGA](https://github.com/PaddlePaddle/paddle-mobile/blob/develop/doc/development_fpga.md)
### 贡献文档
- [贡献文档链接](https://github.com/PaddlePaddle/paddle-mobile/blob/develop/CONTRIBUTING.md)
......
......@@ -22,6 +22,7 @@ const char *G_OP_TYPE_BATCHNORM = "batch_norm";
const char *G_OP_TYPE_BOX_CODER = "box_coder";
const char *G_OP_TYPE_CONCAT = "concat";
const char *G_OP_TYPE_ELEMENTWISE_ADD = "elementwise_add";
const char *G_OP_TYPE_FILL_CONSTANT = "fill_constant";
const char *G_OP_TYPE_FUSION_CONV_ADD_RELU = "fusion_conv_add_relu";
const char *G_OP_TYPE_FUSION_CONV_ADD_PRELU = "fusion_conv_add_prelu";
const char *G_OP_TYPE_FUSION_CONV_ADD_ADD_PRELU = "fusion_conv_add_add_prelu";
......@@ -99,6 +100,7 @@ std::unordered_map<
{G_OP_TYPE_FC, {{"X", "Y", "Z"}, {"Out"}}},
{G_OP_TYPE_RESHAPE, {{"X"}, {"Out"}}},
{G_OP_TYPE_DEPTHWISE_CONV, {{"Input"}, {"Output"}}},
{G_OP_TYPE_FILL_CONSTANT, {{}, {"Out"}}},
{G_OP_TYPE_FUSION_CONV_ADD_RELU, {{"Input"}, {"Out"}}},
{G_OP_TYPE_FUSION_CONV_ADD_PRELU, {{"Input"}, {"Out"}}},
{G_OP_TYPE_FUSION_CONV_ADD_ADD_PRELU, {{"Input"}, {"Out"}}},
......
......@@ -22,7 +22,7 @@ limitations under the License. */
#include "fpga/filter.h"
#include "fpga/image.h"
#define FPGA_TEST_MODE
// #define PADDLE_MOBILE_OS_LINUX
#define PADDLE_MOBILE_OS_LINUX
namespace paddle_mobile {
namespace fpga {
......
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "framework/data_type.h"
#include <stdint.h>
#include <string>
#include <unordered_map>
namespace paddle_mobile {
namespace framework {
struct DataTypeMap {
std::unordered_map<std::type_index,
_PaddleMobile__Framework__Proto__VarType__Type>
cpp_to_proto_;
std::unordered_map<int, std::type_index> proto_to_cpp_;
std::unordered_map<int, std::string> proto_to_str_;
std::unordered_map<std::type_index, size_t> cpp_to_size_;
};
static DataTypeMap* InitDataTypeMap();
// C++11 removes the need for manual locking. Concurrent execution shall wait if
// a static local variable is already being initialized.
// https://stackoverflow.com/questions/11711920/how-to-implement-multithread-safe-singleton-in-c11-without-using-mutex
static DataTypeMap& gDataTypeMap() {
static DataTypeMap* g_data_type_map_ = InitDataTypeMap();
return *g_data_type_map_;
}
template <typename T>
static inline void RegisterType(
DataTypeMap* map, _PaddleMobile__Framework__Proto__VarType__Type proto_type,
const std::string& name) {
map->proto_to_cpp_.emplace(static_cast<int>(proto_type), typeid(T));
map->cpp_to_proto_.emplace(typeid(T), proto_type);
map->proto_to_str_.emplace(static_cast<int>(proto_type), name);
map->cpp_to_size_.emplace(typeid(T), sizeof(T));
}
static DataTypeMap* InitDataTypeMap() {
auto retv = new DataTypeMap();
#define RegType(cc_type, proto_type) \
RegisterType<cc_type>(retv, proto_type, #cc_type)
// NOTE: Add your customize type here.
// RegType(float16, PADDLE_MOBILE__FRAMEWORK__PROTO__VAR_TYPE__TYPE__FP16);
RegType(float, PADDLE_MOBILE__FRAMEWORK__PROTO__VAR_TYPE__TYPE__FP32);
RegType(double, PADDLE_MOBILE__FRAMEWORK__PROTO__VAR_TYPE__TYPE__FP64);
RegType(int, PADDLE_MOBILE__FRAMEWORK__PROTO__VAR_TYPE__TYPE__INT32);
RegType(int64_t, PADDLE_MOBILE__FRAMEWORK__PROTO__VAR_TYPE__TYPE__INT64);
RegType(bool, PADDLE_MOBILE__FRAMEWORK__PROTO__VAR_TYPE__TYPE__BOOL);
RegType(size_t, PADDLE_MOBILE__FRAMEWORK__PROTO__VAR_TYPE__TYPE__SIZE_T);
RegType(int16_t, PADDLE_MOBILE__FRAMEWORK__PROTO__VAR_TYPE__TYPE__INT16);
RegType(uint8_t, PADDLE_MOBILE__FRAMEWORK__PROTO__VAR_TYPE__TYPE__UINT8);
RegType(int8_t, PADDLE_MOBILE__FRAMEWORK__PROTO__VAR_TYPE__TYPE__INT8);
#undef RegType
return retv;
}
_PaddleMobile__Framework__Proto__VarType__Type ToDataType(
std::type_index type) {
auto it = gDataTypeMap().cpp_to_proto_.find(type);
if (it != gDataTypeMap().cpp_to_proto_.end()) {
return it->second;
}
PADDLE_MOBILE_THROW_EXCEPTION("Not support %s as tensor type", type.name());
}
std::type_index ToTypeIndex(
_PaddleMobile__Framework__Proto__VarType__Type type) {
auto it = gDataTypeMap().proto_to_cpp_.find(static_cast<int>(type));
if (it != gDataTypeMap().proto_to_cpp_.end()) {
return it->second;
}
PADDLE_MOBILE_THROW_EXCEPTION(
"Not support _PaddleMobile__Framework__Proto__VarType__Type(%d) as "
"tensor type",
static_cast<int>(type));
}
std::string DataTypeToString(
const _PaddleMobile__Framework__Proto__VarType__Type type) {
auto it = gDataTypeMap().proto_to_str_.find(static_cast<int>(type));
if (it != gDataTypeMap().proto_to_str_.end()) {
return it->second;
}
PADDLE_MOBILE_THROW_EXCEPTION(
"Not support _PaddleMobile__Framework__Proto__VarType__Type(%d) as "
"tensor type",
static_cast<int>(type));
}
} // namespace framework
} // namespace paddle_mobile
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <string>
#include <typeindex>
#include "common/enforce.h"
#include "framework/framework.pb-c.h"
namespace paddle_mobile {
namespace framework {
extern _PaddleMobile__Framework__Proto__VarType__Type ToDataType(
std::type_index type);
extern std::type_index ToTypeIndex(
_PaddleMobile__Framework__Proto__VarType__Type type);
template <typename Visitor>
inline void VisitDataType(_PaddleMobile__Framework__Proto__VarType__Type type,
Visitor visitor) {
switch (type) {
// case PADDLE_MOBILE__FRAMEWORK__PROTO__VAR_TYPE__TYPE__FP16:
// visitor.template apply<float16>();
// break;
case PADDLE_MOBILE__FRAMEWORK__PROTO__VAR_TYPE__TYPE__FP32:
visitor.template apply<float>();
break;
case PADDLE_MOBILE__FRAMEWORK__PROTO__VAR_TYPE__TYPE__FP64:
visitor.template apply<double>();
break;
case PADDLE_MOBILE__FRAMEWORK__PROTO__VAR_TYPE__TYPE__INT32:
visitor.template apply<int>();
break;
case PADDLE_MOBILE__FRAMEWORK__PROTO__VAR_TYPE__TYPE__INT64:
visitor.template apply<int64_t>();
break;
case PADDLE_MOBILE__FRAMEWORK__PROTO__VAR_TYPE__TYPE__BOOL:
visitor.template apply<bool>();
break;
case PADDLE_MOBILE__FRAMEWORK__PROTO__VAR_TYPE__TYPE__UINT8:
visitor.template apply<uint8_t>();
break;
case PADDLE_MOBILE__FRAMEWORK__PROTO__VAR_TYPE__TYPE__INT16:
visitor.template apply<int16_t>();
break;
case PADDLE_MOBILE__FRAMEWORK__PROTO__VAR_TYPE__TYPE__INT8:
visitor.template apply<int8_t>();
break;
default:
PADDLE_MOBILE_THROW_EXCEPTION("Not supported %d", type);
}
}
extern std::string DataTypeToString(
const _PaddleMobile__Framework__Proto__VarType__Type type);
inline std::ostream& operator<<(
std::ostream& out,
const _PaddleMobile__Framework__Proto__VarType__Type& type) {
out << DataTypeToString(type);
return out;
}
} // namespace framework
} // namespace paddle_mobile
......@@ -64,6 +64,9 @@ limitations under the License. */
// load requared ops
LOAD_OP(feed)
LOAD_OP(fetch)
#ifdef FILL_CONSTANT_OP
LOAD_OP(fill_constant)
#endif
#ifdef BATCHNORM_OP
LOAD_OP2(batch_norm, CPU, MALI_GPU);
#endif
......
......@@ -29,7 +29,14 @@ PaddleMobilePredictor<Dtype, P>::PaddleMobilePredictor(
template <typename Dtype, Precision P>
bool PaddleMobilePredictor<Dtype, P>::Init(const PaddleMobileConfig &config) {
paddle_mobile_.reset(new PaddleMobile<Dtype, P>());
if (!config.model_dir.empty()) {
if (config.memory_pack.from_memory) {
DLOG << "load from memory!";
paddle_mobile_->LoadCombinedMemory(config.memory_pack.model_size,
config.memory_pack.model_buf,
config.memory_pack.combined_params_size,
config.memory_pack.combined_params_buf);
} else if (!config.model_dir.empty()) {
paddle_mobile_->Load(config.model_dir, config.optimize,
config.quantification, config.batch_size);
} else if (!config.prog_file.empty() && !config.param_file.empty()) {
......
......@@ -111,6 +111,14 @@ class PaddlePredictor {
PaddlePredictor() = default;
};
struct PaddleModelMemoryPack {
bool from_memory = false;
size_t model_size = 0;
uint8_t* model_buf = nullptr;
size_t combined_params_size = 0;
uint8_t* combined_params_buf = nullptr;
};
struct PaddleMobileConfig : public PaddlePredictor::Config {
enum Precision { FP32 = 0 };
enum Device { kCPU = 0, kFPGA = 1, kGPU_MALI = 2 };
......@@ -124,6 +132,7 @@ struct PaddleMobileConfig : public PaddlePredictor::Config {
int thread_num = 1;
std::string prog_file;
std::string param_file;
struct PaddleModelMemoryPack memory_pack;
};
// A factory to help create different predictors.
......
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifdef FILL_CONSTANT_OP
#include "operators/fill_constant_op.h"
namespace ops = paddle_mobile::operators;
#ifdef PADDLE_MOBILE_CPU
REGISTER_OPERATOR_CPU(fill_constant, ops::FillConstantOp);
#endif
#ifdef PADDLE_MOBILE_MALI_GPU
REGISTER_OPERATOR_MALI_GPU(fill_constant, ops::FillConstantOp);
#endif
#ifdef PADDLE_MOBILE_FPGA
REGISTER_OPERATOR_FPGA(fill_constant, ops::FillConstantOp);
#endif
#endif
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifdef FILL_CONSTANT_OP
#pragma once
#include <string>
#include "framework/data_type.h"
#include "framework/operator.h"
#include "framework/selected_rows.h"
#include "operators/math/math_function.h"
#include "operators/op_param.h"
namespace paddle_mobile {
namespace operators {
using std::string;
template <typename DeviceType, typename T>
class FillConstantOp : public framework::OperatorBase<DeviceType> {
public:
FillConstantOp(const string &type, const VariableNameMap &inputs,
const VariableNameMap &outputs,
const framework::AttributeMap attrs,
std::shared_ptr<framework::Scope> scope)
: framework::OperatorBase<DeviceType>(type, inputs, outputs, attrs,
scope),
param_(inputs, outputs, attrs, *scope) {}
void RunImpl() const {
auto data_type =
static_cast<_PaddleMobile__Framework__Proto__VarType__Type>(
param_.DataDtype());
framework::Tensor *tensor = nullptr;
auto value = param_.Value();
auto *outvar = param_.OutVar();
if (outvar->template IsType<framework::LoDTensor>()) {
tensor = outvar->template GetMutable<framework::LoDTensor>();
} else if (outvar->template IsType<framework::SelectedRows>()) {
tensor = outvar->template GetMutable<framework::SelectedRows>()
->mutable_value();
} else {
PADDLE_MOBILE_THROW_EXCEPTION(
"fill constant op's output only"
"supports SelectedRows and LoDTensor");
}
tensor->Resize(framework::make_ddim(param_.Shape()));
tensor->mutable_data(framework::ToTypeIndex(data_type));
math::set_constant(tensor, value);
}
void Init() {}
void InferShape() const {
PADDLE_MOBILE_ENFORCE(
param_.Out() != nullptr,
"Output (Out) of fill_constant op should not be null.");
framework::DDim ddim = framework::make_ddim(param_.Shape());
param_.Out()->Resize(ddim);
}
protected:
FillConstantParam<DeviceType> param_;
};
} // namespace operators
} // namespace paddle_mobile
#endif
......@@ -15,12 +15,31 @@ limitations under the License. */
#include "operators/math/math_function.h"
#include <cstring>
#include <string>
#include "framework/data_type.h"
#include "framework/tensor.h"
#include "operators/math/gemm.h"
namespace paddle_mobile {
namespace operators {
namespace math {
struct TensorSetConstant {
TensorSetConstant(framework::Tensor *tensor, float value)
: tensor_(tensor), value_(value) {}
template <typename T>
void apply() const {
auto *begin = tensor_->mutable_data<T>();
std::fill(begin, begin + tensor_->numel(), static_cast<T>(value_));
}
framework::Tensor *tensor_;
float value_;
};
void set_constant(framework::Tensor *tensor, float value) {
framework::VisitDataType(framework::ToDataType(tensor->type()),
TensorSetConstant(tensor, value));
}
template <>
void matmul<float>(const framework::Tensor &matrix_a, bool trans_a,
const framework::Tensor &matrix_b, bool trans_b, float alpha,
......
......@@ -22,6 +22,8 @@ namespace paddle_mobile {
namespace operators {
namespace math {
void set_constant(framework::Tensor *tensor, float value);
template <typename T>
void matmul(const framework::Tensor &matrix_a, bool trans_a,
const framework::Tensor &matrix_b, bool trans_b, T alpha,
......
......@@ -1063,6 +1063,42 @@ class FetchParam : public OpParam {
RType *out_;
};
#ifdef FILL_CONSTANT_OP
template <typename Dtype>
class FillConstantParam : public OpParam {
typedef typename DtypeTensorTrait<Dtype>::gtype GType;
typedef typename DtypeTensorTrait<Dtype>::rtype RType;
public:
FillConstantParam(const VariableNameMap &inputs,
const VariableNameMap &outputs, const AttributeMap &attrs,
const Scope &scope) {
out_var_ = OutVarFrom(outputs, scope);
out_ = OutFrom<GType>(outputs, scope);
dtype_ = GetAttr<int>("dtype", attrs);
shape_ = GetAttr<vector<int>>("shape", attrs);
value_ = GetAttr<float>("value", attrs);
}
Variable *OutVar() const { return out_var_; }
RType *Out() const { return out_; }
const int &DataDtype() const { return dtype_; }
const vector<int> &Shape() const { return shape_; }
const float &Value() const { return value_; }
private:
Variable *out_var_;
RType *out_;
int dtype_;
vector<int> shape_;
float value_;
};
#endif
#ifdef TRANSPOSE_OP
template <typename Dtype>
class TransposeParam : public OpParam {
......
......@@ -12,6 +12,9 @@ if (CON GREATER -1)
ADD_EXECUTABLE(test-googlenet net/test_googlenet.cpp test_helper.h test_include.h executor_for_test.h)
target_link_libraries(test-googlenet paddle-mobile)
# gen test
ADD_EXECUTABLE(test-googlenet-quali net/test_googlenet_quali.cpp test_helper.h test_include.h executor_for_test.h)
target_link_libraries(test-googlenet-quali paddle-mobile)
set(FOUND_MATCH ON)
endif ()
......@@ -133,6 +136,10 @@ if (NOT FOUND_MATCH)
ADD_EXECUTABLE(test-googlenet net/test_googlenet.cpp test_helper.h test_include.h executor_for_test.h)
target_link_libraries(test-googlenet paddle-mobile)
# gen test
ADD_EXECUTABLE(test-googlenet-quali net/test_googlenet_quali.cpp test_helper.h test_include.h executor_for_test.h)
target_link_libraries(test-googlenet-quali paddle-mobile)
# gen test
ADD_EXECUTABLE(test-conv-op operators/test_cov_op.cpp test_helper.h test_include.h executor_for_test.h)
target_link_libraries(test-conv-op paddle-mobile)
......@@ -185,6 +192,10 @@ if (NOT FOUND_MATCH)
ADD_EXECUTABLE(test-polygon-box-transform-op operators/test_polygon_box_transform_op.cpp test_helper.h test_include.h)
target_link_libraries(test-polygon-box-transform-op paddle-mobile)
# gen test
ADD_EXECUTABLE(test-fill-constant-op operators/test_fill_constant_op.cpp test_helper.h test_include.h)
target_link_libraries(test-fill-constant-op paddle-mobile)
# gen test
ADD_EXECUTABLE(test-reshape-op operators/test_reshape_op.cpp test_helper.h test_include.h)
target_link_libraries(test-reshape-op paddle-mobile)
......@@ -225,6 +236,10 @@ if (NOT FOUND_MATCH)
ADD_EXECUTABLE(test-loadmemory framework/test_load_memory.cpp)
target_link_libraries(test-loadmemory paddle-mobile)
# gen test log
ADD_EXECUTABLE(test-loadmemory-inference framework/test_load_memory_inference_api.cpp)
target_link_libraries(test-loadmemory-inference paddle-mobile)
ADD_EXECUTABLE(test-inference-api framework/test_inference_api.cpp)
target_link_libraries(test-inference-api paddle-mobile)
......
......@@ -11,34 +11,107 @@ distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <fstream>
#include "../test_include.h"
static const char *g_resnet_combine = "../models/resnet50";
#include "fpga/api.h"
void readStream(std::string filename, float *buf) {
std::ifstream in;
in.open(filename, std::ios::in);
if (!in.is_open()) {
std::cout << "open File Failed." << std::endl;
return;
}
string strOne;
int i = 0;
while (!in.eof()) {
in >> buf[i];
i++;
}
in.close();
}
void convert_to_chw(int16_t **data_in, int channel, int height, int width,
int16_t *data_tmp) {
int64_t amount_per_side = width * height;
for (int h = 0; h < height; h++) {
for (int w = 0; w < width; w++) {
for (int c = 0; c < channel; c++) {
*(data_tmp + c * amount_per_side + width * h + w) = *((*data_in)++);
}
}
}
}
void dump(std::string filename, const Tensor input_tensor) {
auto dataptr = input_tensor.data<float>();
std::ofstream out(filename.c_str());
float result = 0;
for (int i = 0; i < input_tensor.numel(); ++i) {
result = paddle_mobile::fpga::fp16_2_fp32(dataptr[i]);
out << result << std::endl;
}
out.close();
}
void dump_stride(std::string filename, const Tensor input_tensor,
const int dumpnum) {
int c = (input_tensor.dims())[1];
int h = (input_tensor.dims())[2];
int w = (input_tensor.dims())[3];
auto data_ptr = input_tensor.data<float>();
int16_t *data_tmp = (int16_t *)malloc(c * h * w * sizeof(int16_t));
int16_t *data_ptr_16 = (int16_t *)data_ptr;
convert_to_chw(&data_ptr_16, c, h, w, data_tmp);
// const int16_t *dataptr = input_tensor.data<int16_t>();
std::ofstream out(filename.c_str());
float result = 0;
int stride = input_tensor.numel() / dumpnum;
stride = stride > 0 ? stride : 1;
for (int i = 0; i < input_tensor.numel(); i += stride) {
result = paddle_mobile::fpga::fp16_2_fp32(data_tmp[i]);
out << result << std::endl;
}
out.close();
free(data_tmp);
}
static const char *g_resnet50 = "../models/resnet50";
const std::string g_image_src_float = "../images/image_src_float";
int main() {
DLOG << paddle_mobile::fpga::open_device();
paddle_mobile::fpga::open_device();
paddle_mobile::PaddleMobile<paddle_mobile::FPGA> paddle_mobile;
// if (paddle_mobile.Load(std::string(g_resnet_combine) + "/model",
// std::string(g_resnet_combine) + "/params", true)) {
if (paddle_mobile.Load(std::string(g_resnet_combine), true)) {
std::vector<int64_t> dims{1, 3, 224, 224};
if (paddle_mobile.Load(std::string(g_resnet50), true)) {
Tensor input_tensor;
SetupTensor<float>(&input_tensor, {1, 3, 224, 224}, static_cast<float>(0),
static_cast<float>(1));
std::vector<float> input(input_tensor.data<float>(),
input_tensor.data<float>() + input_tensor.numel());
readStream(g_image_src_float,
input_tensor.mutable_data<float>({1, 3, 224, 224}));
paddle_mobile.FeedData(input_tensor);
for (int i = 0; i < 1000; i++) {
paddle_mobile.Predict_To(-1);
if (i % 100 == 0) std::cout << i << std::endl;
}
// paddle_mobile.Predict_From(73);
// paddle_mobile.Predict_From_To(72, 73);
/*for(int i = 0; i < 73; i++)
{
auto tensor_ptr = paddle_mobile.FetchResult(i);
std::string saveName = "resnet50_result_" + std::to_string (i);
paddle_mobile::fpga::fpga_invalidate((*tensor_ptr).data<float>(),
tensor_ptr->numel()); dump_stride(saveName, (*tensor_ptr), 20);
//dump(saveName, (*tensor_ptr));
}*/
DLOG << "Computation done";
/*std::shared_ptr<Tensor> output_tensor = paddle_mobile.FetchResult(73);
(*output_tensor).dump<float>("resnet50_result_73");
output_tensor = paddle_mobile.FetchResult(74);
(*output_tensor).dump<float>("resnet50_result_74");*/
std::shared_ptr<Tensor> output_tensor = paddle_mobile.FetchResult(74);
float max = 0;
auto data_ptr = output_tensor->data<float>();
int maximumIdx = 0;
for (int i = 0; i < (*output_tensor).numel(); i++) {
if (data_ptr[i] > max) {
maximumIdx = i;
max = data_ptr[i];
}
}
std::cout << "index : " << maximumIdx << ", value : " << max
<< std::endl;
std::cout << "Computation done" << std::endl;
return 0;
}
}
......@@ -58,9 +58,9 @@ int main() {
size_t sizeBuf = ReadBuffer(model_path.c_str(), &bufModel);
uint8_t *bufParams = nullptr;
DLOG << "sizeBuf: " << sizeBuf;
std::cout << "sizeBuf: " << sizeBuf << std::endl;
size_t sizeParams = ReadBuffer(params_path.c_str(), &bufParams);
DLOG << "sizeParams: " << sizeParams;
std::cout << "sizeParams: " << sizeParams << std::endl;
paddle_mobile.LoadCombinedMemory(sizeBuf, bufModel, sizeParams, bufParams);
return 0;
......
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <string>
#include <iostream>
#include "../test_helper.h"
#include "io/paddle_inference_api.h"
static size_t ReadBuffer(const char *file_name, uint8_t **out) {
FILE *fp;
fp = fopen(file_name, "rb");
PADDLE_MOBILE_ENFORCE(fp != nullptr, " %s open failed !", file_name);
fseek(fp, 0, SEEK_END);
auto size = static_cast<size_t>(ftell(fp));
rewind(fp);
DLOG << "model size: " << size;
*out = reinterpret_cast<uint8_t *>(malloc(size));
size_t cur_len = 0;
size_t nread;
while ((nread = fread(*out + cur_len, 1, size - cur_len, fp)) != 0) {
cur_len += nread;
}
fclose(fp);
return cur_len;
}
static char *Get_binary_data(std::string filename) {
FILE *file = fopen(filename.c_str(), "rb");
PADDLE_MOBILE_ENFORCE(file != nullptr, "can't open file: %s ",
filename.c_str());
fseek(file, 0, SEEK_END);
int64_t size = ftell(file);
PADDLE_MOBILE_ENFORCE(size > 0, "size is too small");
rewind(file);
auto *data = new char[size];
size_t bytes_read = fread(data, 1, size, file);
PADDLE_MOBILE_ENFORCE(bytes_read == size,
"read binary file bytes do not match with fseek");
fclose(file);
return data;
}
paddle_mobile::PaddleMobileConfig GetConfig() {
paddle_mobile::PaddleMobileConfig config;
config.precision = paddle_mobile::PaddleMobileConfig::FP32;
config.device = paddle_mobile::PaddleMobileConfig::kCPU;
const std::shared_ptr<paddle_mobile::PaddleModelMemoryPack> &memory_pack =
std::make_shared<paddle_mobile::PaddleModelMemoryPack>();
auto model_path = std::string(g_genet_combine) + "/model";
auto params_path = std::string(g_genet_combine) + "/params";
memory_pack->model_size =
ReadBuffer(model_path.c_str(), &memory_pack->model_buf);
std::cout << "sizeBuf: " << memory_pack->model_size << std::endl;
memory_pack->combined_params_size =
ReadBuffer(params_path.c_str(), &memory_pack->combined_params_buf);
std::cout << "sizeParams: " << memory_pack->combined_params_size << std::endl;
memory_pack->from_memory = true;
config.memory_pack = *memory_pack;
config.thread_num = 4;
return config;
}
int main() {
paddle_mobile::PaddleMobileConfig config = GetConfig();
auto predictor = paddle_mobile::CreatePaddlePredictor<
paddle_mobile::PaddleMobileConfig,
paddle_mobile::PaddleEngineKind::kPaddleMobile>(config);
return 0;
}
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <iostream>
#include "../test_helper.h"
#include "../test_include.h"
int main() {
#ifdef PADDLE_MOBILE_FPGA
paddle_mobile::PaddleMobile<paddle_mobile::FPGA> paddle_mobile;
#endif
#ifdef PADDLE_MOBILE_CPU
paddle_mobile::PaddleMobile<paddle_mobile::CPU> paddle_mobile;
#endif
paddle_mobile.SetThreadNum(4);
bool optimize = true;
bool quli = true;
auto time1 = time();
auto isok = paddle_mobile.Load(std::string(g_googlenet_quali) + "/model",
std::string(g_googlenet_quali) + "/params",
optimize, quli);
if (isok) {
auto time2 = time();
std::cout << "load cost :" << time_diff(time1, time2) << "ms" << std::endl;
std::vector<float> input;
std::vector<int64_t> dims{1, 3, 224, 224};
GetInput<float>(g_test_image_1x3x224x224, &input, dims);
// 预热十次
for (int i = 0; i < 10; ++i) {
auto vec_result = paddle_mobile.Predict(input, dims);
}
auto time3 = time();
for (int i = 0; i < 10; ++i) {
auto vec_result = paddle_mobile.Predict(input, dims);
}
auto time4 = time();
std::cout << "predict cost :" << time_diff(time3, time4) / 10 << "ms"
<< std::endl;
}
return 0;
}
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "../test_include.h"
#include "operators/fill_constant_op.h"
namespace paddle_mobile {
namespace framework {
template <typename Dtype>
class TestFillConstantOp {
public:
explicit TestFillConstantOp(const Program<Dtype> p) : program_(p) {
if (use_optimize_) {
to_predict_program_ = program_.optimizeProgram;
} else {
to_predict_program_ = program_.originProgram;
}
const std::vector<std::shared_ptr<BlockDesc>> blocks =
to_predict_program_->Blocks();
for (auto block_desc : blocks) {
std::vector<std::shared_ptr<OpDesc>> ops = block_desc->Ops();
for (auto op : ops) {
if (op->Type() == "fill_constant") {
DLOG << " attr size: " << op->GetAttrMap().size();
std::unordered_map<std::string, Attribute> attrs = op->GetAttrMap();
for (std::unordered_map<std::string, Attribute>::iterator it =
attrs.begin();
it != attrs.end(); ++it) {
DLOG << " " << it->first << " " << it->second;
}
DLOG << " inputs size: " << op->GetInputs().size();
DLOG << " outputs size: " << op->GetOutputs().size();
DLOG << " output is : " << op->Output("Out")[0];
output_var_name = op->Output("Out")[0];
std::shared_ptr<operators::FillConstantOp<Dtype, float>> op_ptr =
std::make_shared<operators::FillConstantOp<Dtype, float>>(
op->Type(), op->GetInputs(), op->GetOutputs(),
op->GetAttrMap(), program_.scope);
ops_of_block_[*block_desc.get()].push_back(op_ptr);
}
}
}
}
std::shared_ptr<Tensor> predict() {
auto scope = program_.scope;
Variable *output = scope->Var(output_var_name);
auto *output_tensor = output->GetMutable<LoDTensor>();
std::shared_ptr<Tensor> out_tensor = std::make_shared<LoDTensor>();
out_tensor.reset(output_tensor);
predict(0);
return out_tensor;
}
private:
const framework::Program<Dtype> program_;
std::shared_ptr<ProgramDesc> to_predict_program_;
std::map<framework::BlockDesc,
std::vector<std::shared_ptr<OperatorBase<Dtype>>>>
ops_of_block_;
bool use_optimize_ = false;
string output_var_name;
void predict(int block_id) {
std::shared_ptr<BlockDesc> to_predict_block =
to_predict_program_->Block(block_id);
for (int j = 0; j < ops_of_block_[*to_predict_block.get()].size(); ++j) {
auto op = ops_of_block_[*to_predict_block.get()][j];
op->Run();
}
}
};
template class TestFillConstantOp<CPU>;
} // namespace framework
} // namespace paddle_mobile
int main() {
DLOG << "----------**********----------";
DLOG << "begin to run FillConstant Test";
paddle_mobile::Loader<paddle_mobile::CPU> loader;
auto program = loader.Load(std::string(g_ocr) + "/model",
std::string(g_ocr) + "/params");
paddle_mobile::framework::TestFillConstantOp<paddle_mobile::CPU>
testFillConstantOp(program);
auto output = testFillConstantOp.predict();
auto *output_ptr = output->data<float>();
DLOG << "output : ";
for (int i = 0; i < output->numel(); ++i) {
DLOG << " index " << i << " : " << output_ptr[i];
}
return 0;
}
......@@ -34,6 +34,7 @@ static const char *g_googlenetv1_combined = "../models/googlenetv1_combine";
static const char *g_mobilenet_detect = "../models/mobilenet-detect";
static const char *g_squeezenet = "../models/squeezenet";
static const char *g_googlenet = "../models/googlenet";
static const char *g_googlenet_quali = "../models/googlenet_combine_quali";
static const char *g_mobilenet = "../models/mobilenet";
static const char *g_alexnet = "../models/alexnet";
static const char *g_inceptionv4 = "../models/inceptionv4";
......
......@@ -188,6 +188,7 @@ if(NOT FOUND_MATCH)
set(ELEMENTWISEADD_OP ON)
set(ELEMENTWISESUB_OP ON)
set(IM2SEQUENCE_OP ON)
set(FILL_CONSTANT_OP ON)
set(FUSION_CONVADD_OP ON)
set(FUSION_CONVADDPRELU_OP ON)
set(FUSION_CONVADDRELU_OP ON)
......@@ -233,6 +234,7 @@ endif()
# option(CONV_OP "" ON)
# option(DEPTHWISECONV_OP "" ON)
# option(ELEMENTWISEADD_OP "" ON)
# option(FILL_CONSTANT_OP "" ON)
# option(FUSION_CONVADD_OP "" ON)
# option(FUSION_CONVADDRELU_OP "" ON)
# option(FUSION_FC_OP "" ON)
......@@ -270,6 +272,9 @@ endif()
if (ELEMENTWISESUB_OP)
add_definitions(-DELEMENTWISESUB_OP)
endif()
if (FILL_CONSTANT_OP)
add_definitions(-DFILL_CONSTANT_OP)
endif()
if (FUSION_CONVADD_OP)
add_definitions(-DFUSION_CONVADD_OP)
endif()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册