diff --git a/src/ios_io/PaddleMobile.h b/src/ios_io/PaddleMobile.h index 091a892ab9be086c6842eb52ef6266f7fba9ce96..1844bbf037a50339ecfe27bb09043a5a26d84222 100644 --- a/src/ios_io/PaddleMobile.h +++ b/src/ios_io/PaddleMobile.h @@ -29,6 +29,11 @@ */ - (BOOL)load:(NSString *)modelPath andWeightsPath:(NSString *)weighsPath; +/* + 加载散开形式的模型, 需传入模型的目录 +*/ +- (BOOL)load:(NSString *)modelAndWeightPath; + /* 进行预测, means 和 scale 为训练模型时的预处理参数, 如训练时没有做这些预处理则直接使用 predict */ diff --git a/src/ios_io/PaddleMobile.mm b/src/ios_io/PaddleMobile.mm index f5ec2afb2a996ec4932d99ea93362e06ddf28a14..e3ed909394a1057302fb0f747b582b944c89cc65 100644 --- a/src/ios_io/PaddleMobile.mm +++ b/src/ios_io/PaddleMobile.mm @@ -62,6 +62,15 @@ static std::mutex shared_mutex; } } +- (BOOL)load:(NSString *)modelAndWeightPath{ + std::string model_path_str = std::string([modelAndWeightPath UTF8String]); + if (loaded_ = pam_->Load(model_path_str)) { + return YES; + } else { + return NO; + } +} + -(void)preprocess:(const UInt8 *)input output:(float *)output imageWidth:(int)imageWidth imageHeight:(int)imageHeight imageChannels:(int)imageChannels means:(NSArray *)means scale:(float)scale dim:(std::vector)dim{ if (means == nil) { means = @[@0, @0, @0];