未验证 提交 a2a5c8b0 编写于 作者: Y Yanzhan Yang 提交者: GitHub

refine paddle_inference_api.h test=develop (#2048)

上级 eb42f9ee
......@@ -12,6 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "common/type_define.h"
#include "cstring"
#include "io/paddle_inference_api.h"
......
......@@ -18,6 +18,7 @@
#include <utility>
#include <vector>
#include "common/enforce.h"
#include "common/type_define.h"
#include "framework/tensor.h"
#ifdef PADDLE_MOBILE_FPGA
#include <fpga/common/fpga_common.h>
......@@ -35,7 +36,9 @@ PaddleMobilePredictor<Device, T>::PaddleMobilePredictor(
template <typename Device, typename T>
bool PaddleMobilePredictor<Device, T>::Init(const PaddleMobileConfig &config) {
paddle_mobile_.reset(new PaddleMobile<Device, T>());
PaddleMobileConfigInternal configInternal;
configInternal.load_when_predict = config.load_when_predict;
paddle_mobile_.reset(new PaddleMobile<Device, T>(configInternal));
#ifdef PADDLE_MOBILE_CL
paddle_mobile_->SetCLPath(config.cl_path);
#endif
......@@ -135,14 +138,14 @@ bool PaddleMobilePredictor<Device, T>::Run(
void ConvertPaddleTensors(const PaddleTensor &src, framework::Tensor *des) {
des->Resize(framework::make_ddim(src.shape));
des->external_data = src.data.data();
des->set_type(src.dtypeid);
des->set_type(static_cast<kTypeId_t>(static_cast<int>(src.dtypeid)));
des->layout =
src.layout == LAYOUT_HWC ? framework::LAYOUT_HWC : framework::LAYOUT_CHW;
}
void ConvertTensors(const framework::Tensor &src, PaddleTensor *des) {
des->shape = framework::vectorize2int(src.dims());
des->dtypeid = src.type();
des->dtypeid = static_cast<PaddlekTypeId_t>(static_cast<int>(src.type()));
des->layout = src.layout == framework::LAYOUT_HWC ? LAYOUT_HWC : LAYOUT_CHW;
auto num = src.numel();
......@@ -164,7 +167,8 @@ void PaddleMobilePredictor<Device, T>::FeedPaddleTensors(
auto num = inputs.size();
std::vector<framework::Tensor> tensors(num, framework::Tensor());
for (int i = 0; i < num; i++) {
if (inputs[i].dtypeid == type_id<int8_t>().hash_code()) {
if (static_cast<kTypeId_t>(static_cast<int>(inputs[i].dtypeid)) ==
type_id<int8_t>().hash_code()) {
tensors[i].init(type_id<int8_t>().hash_code());
} else {
tensors[i].init(type_id<float>().hash_code());
......
......@@ -25,7 +25,6 @@ limitations under the License. */
#include <memory>
#include <string>
#include <vector>
#include "common/type_define.h"
namespace paddle_mobile {
......@@ -86,6 +85,56 @@ class PaddleBuf {
bool memory_owned_{true};
};
typedef enum {
paddle_void = 0,
paddle_float,
paddle_int,
paddle_uint16_t,
paddle_double,
paddle_int64_t,
paddle_size_t,
paddle_int16_t,
paddle_int8_t,
paddle_uint8_t,
paddle_bool,
paddle_string,
paddle_floats = 100,
paddle_ints,
paddle_int64_ts,
paddle_size_ts,
paddle_bools,
paddle_strings,
paddle_const_float = 200,
paddle_const_int,
paddle_block = 300,
paddle_tensor,
paddle_lod_tensor,
paddle_blocks,
paddle_tensors,
paddle_lod_tensors,
paddle_p_block = 400,
paddle_p_tensor,
paddle_p_lod_tensor,
paddle_p_blocks,
paddle_p_tensors,
paddle_p_lod_tensors,
paddle_scopes = 500,
paddle_selected_rows,
paddle_dim0 = 600,
paddle_dim1,
paddle_dim2,
paddle_dim3,
paddle_dim4,
paddle_dim5,
paddle_dim6,
paddle_dim7,
paddle_dim8,
paddle_dim9,
#ifdef PADDLE_MOBILE_CL
paddle_cl_image,
#endif
} PaddlekTypeId_t;
struct PaddleTensor {
PaddleTensor() = default;
std::string name; // variable name.
......@@ -93,7 +142,7 @@ struct PaddleTensor {
std::vector<int> lod;
PaddleBuf data; // blob of data.
PaddleDType dtype;
kTypeId_t dtypeid;
PaddlekTypeId_t dtypeid;
LayoutType layout;
};
......@@ -166,6 +215,7 @@ struct PaddleMobileConfig : public PaddlePredictor::Config {
bool quantification = false;
bool lod_mode = false;
int thread_num = 1;
bool load_when_predict = false;
std::string cl_path;
struct PaddleModelMemoryPack memory_pack;
};
......
......@@ -61,7 +61,7 @@ build_for_android() {
elif [ "${PLATFORM}" = "arm-v8a" ]; then
ABI="arm64-v8a"
ARM_PLATFORM="V8"
CXX_FLAGS="-march=armv8-a -pie -fPIE -w -Wno-error=format-security -llog"
CXX_FLAGS="-march=armv8-a -pie -fPIE -w -Wno-error=format-security -llog -fuse-ld=gold"
else
echo "unknown platform!"
exit -1
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册