未验证 提交 a2a5c8b0 编写于 作者: Y Yanzhan Yang 提交者: GitHub

refine paddle_inference_api.h test=develop (#2048)

上级 eb42f9ee
...@@ -12,6 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ...@@ -12,6 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include "common/type_define.h"
#include "cstring" #include "cstring"
#include "io/paddle_inference_api.h" #include "io/paddle_inference_api.h"
......
...@@ -18,6 +18,7 @@ ...@@ -18,6 +18,7 @@
#include <utility> #include <utility>
#include <vector> #include <vector>
#include "common/enforce.h" #include "common/enforce.h"
#include "common/type_define.h"
#include "framework/tensor.h" #include "framework/tensor.h"
#ifdef PADDLE_MOBILE_FPGA #ifdef PADDLE_MOBILE_FPGA
#include <fpga/common/fpga_common.h> #include <fpga/common/fpga_common.h>
...@@ -35,7 +36,9 @@ PaddleMobilePredictor<Device, T>::PaddleMobilePredictor( ...@@ -35,7 +36,9 @@ PaddleMobilePredictor<Device, T>::PaddleMobilePredictor(
template <typename Device, typename T> template <typename Device, typename T>
bool PaddleMobilePredictor<Device, T>::Init(const PaddleMobileConfig &config) { bool PaddleMobilePredictor<Device, T>::Init(const PaddleMobileConfig &config) {
paddle_mobile_.reset(new PaddleMobile<Device, T>()); PaddleMobileConfigInternal configInternal;
configInternal.load_when_predict = config.load_when_predict;
paddle_mobile_.reset(new PaddleMobile<Device, T>(configInternal));
#ifdef PADDLE_MOBILE_CL #ifdef PADDLE_MOBILE_CL
paddle_mobile_->SetCLPath(config.cl_path); paddle_mobile_->SetCLPath(config.cl_path);
#endif #endif
...@@ -135,14 +138,14 @@ bool PaddleMobilePredictor<Device, T>::Run( ...@@ -135,14 +138,14 @@ bool PaddleMobilePredictor<Device, T>::Run(
void ConvertPaddleTensors(const PaddleTensor &src, framework::Tensor *des) { void ConvertPaddleTensors(const PaddleTensor &src, framework::Tensor *des) {
des->Resize(framework::make_ddim(src.shape)); des->Resize(framework::make_ddim(src.shape));
des->external_data = src.data.data(); des->external_data = src.data.data();
des->set_type(src.dtypeid); des->set_type(static_cast<kTypeId_t>(static_cast<int>(src.dtypeid)));
des->layout = des->layout =
src.layout == LAYOUT_HWC ? framework::LAYOUT_HWC : framework::LAYOUT_CHW; src.layout == LAYOUT_HWC ? framework::LAYOUT_HWC : framework::LAYOUT_CHW;
} }
void ConvertTensors(const framework::Tensor &src, PaddleTensor *des) { void ConvertTensors(const framework::Tensor &src, PaddleTensor *des) {
des->shape = framework::vectorize2int(src.dims()); des->shape = framework::vectorize2int(src.dims());
des->dtypeid = src.type(); des->dtypeid = static_cast<PaddlekTypeId_t>(static_cast<int>(src.type()));
des->layout = src.layout == framework::LAYOUT_HWC ? LAYOUT_HWC : LAYOUT_CHW; des->layout = src.layout == framework::LAYOUT_HWC ? LAYOUT_HWC : LAYOUT_CHW;
auto num = src.numel(); auto num = src.numel();
...@@ -164,7 +167,8 @@ void PaddleMobilePredictor<Device, T>::FeedPaddleTensors( ...@@ -164,7 +167,8 @@ void PaddleMobilePredictor<Device, T>::FeedPaddleTensors(
auto num = inputs.size(); auto num = inputs.size();
std::vector<framework::Tensor> tensors(num, framework::Tensor()); std::vector<framework::Tensor> tensors(num, framework::Tensor());
for (int i = 0; i < num; i++) { for (int i = 0; i < num; i++) {
if (inputs[i].dtypeid == type_id<int8_t>().hash_code()) { if (static_cast<kTypeId_t>(static_cast<int>(inputs[i].dtypeid)) ==
type_id<int8_t>().hash_code()) {
tensors[i].init(type_id<int8_t>().hash_code()); tensors[i].init(type_id<int8_t>().hash_code());
} else { } else {
tensors[i].init(type_id<float>().hash_code()); tensors[i].init(type_id<float>().hash_code());
......
...@@ -25,7 +25,6 @@ limitations under the License. */ ...@@ -25,7 +25,6 @@ limitations under the License. */
#include <memory> #include <memory>
#include <string> #include <string>
#include <vector> #include <vector>
#include "common/type_define.h"
namespace paddle_mobile { namespace paddle_mobile {
...@@ -86,6 +85,56 @@ class PaddleBuf { ...@@ -86,6 +85,56 @@ class PaddleBuf {
bool memory_owned_{true}; bool memory_owned_{true};
}; };
// API-side copy of the framework's internal type-id enumeration, exposed so
// that paddle_inference_api.h no longer needs to include
// common/type_define.h.
// NOTE(review): the numeric values appear intended to mirror kTypeId_t — the
// implementation converts between the two via static_cast through int — so
// the enums must be kept in sync; TODO confirm against common/type_define.h.
typedef enum {
// Scalar element types.
paddle_void = 0,
paddle_float,
paddle_int,
paddle_uint16_t,
paddle_double,
paddle_int64_t,
paddle_size_t,
paddle_int16_t,
paddle_int8_t,
paddle_uint8_t,
paddle_bool,
paddle_string,
// Container-of-scalar types.
paddle_floats = 100,
paddle_ints,
paddle_int64_ts,
paddle_size_ts,
paddle_bools,
paddle_strings,
// Constant scalar types.
paddle_const_float = 200,
paddle_const_int,
// Framework object types (values and containers thereof).
paddle_block = 300,
paddle_tensor,
paddle_lod_tensor,
paddle_blocks,
paddle_tensors,
paddle_lod_tensors,
// Pointers to framework object types.
paddle_p_block = 400,
paddle_p_tensor,
paddle_p_lod_tensor,
paddle_p_blocks,
paddle_p_tensors,
paddle_p_lod_tensors,
// Scope / selected-rows types.
paddle_scopes = 500,
paddle_selected_rows,
// Dimension ranks 0 through 9.
paddle_dim0 = 600,
paddle_dim1,
paddle_dim2,
paddle_dim3,
paddle_dim4,
paddle_dim5,
paddle_dim6,
paddle_dim7,
paddle_dim8,
paddle_dim9,
#ifdef PADDLE_MOBILE_CL
// OpenCL image type, only when the OpenCL backend is compiled in.
paddle_cl_image,
#endif
} PaddlekTypeId_t;
struct PaddleTensor { struct PaddleTensor {
PaddleTensor() = default; PaddleTensor() = default;
std::string name; // variable name. std::string name; // variable name.
...@@ -93,7 +142,7 @@ struct PaddleTensor { ...@@ -93,7 +142,7 @@ struct PaddleTensor {
std::vector<int> lod; std::vector<int> lod;
PaddleBuf data; // blob of data. PaddleBuf data; // blob of data.
PaddleDType dtype; PaddleDType dtype;
kTypeId_t dtypeid; PaddlekTypeId_t dtypeid;
LayoutType layout; LayoutType layout;
}; };
...@@ -166,6 +215,7 @@ struct PaddleMobileConfig : public PaddlePredictor::Config { ...@@ -166,6 +215,7 @@ struct PaddleMobileConfig : public PaddlePredictor::Config {
bool quantification = false; bool quantification = false;
bool lod_mode = false; bool lod_mode = false;
int thread_num = 1; int thread_num = 1;
bool load_when_predict = false;
std::string cl_path; std::string cl_path;
struct PaddleModelMemoryPack memory_pack; struct PaddleModelMemoryPack memory_pack;
}; };
......
...@@ -61,7 +61,7 @@ build_for_android() { ...@@ -61,7 +61,7 @@ build_for_android() {
elif [ "${PLATFORM}" = "arm-v8a" ]; then elif [ "${PLATFORM}" = "arm-v8a" ]; then
ABI="arm64-v8a" ABI="arm64-v8a"
ARM_PLATFORM="V8" ARM_PLATFORM="V8"
CXX_FLAGS="-march=armv8-a -pie -fPIE -w -Wno-error=format-security -llog" CXX_FLAGS="-march=armv8-a -pie -fPIE -w -Wno-error=format-security -llog -fuse-ld=gold"
else else
echo "unknown platform!" echo "unknown platform!"
exit -1 exit -1
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册