未验证 提交 5a2b578e 编写于 作者: Z zhoujun 提交者: GitHub

Merge branch 'dygraph' into fix_doc

...@@ -50,6 +50,11 @@ int main(int argc, char **argv) { ...@@ -50,6 +50,11 @@ int main(int argc, char **argv) {
cv::Mat srcimg = cv::imread(img_path, cv::IMREAD_COLOR); cv::Mat srcimg = cv::imread(img_path, cv::IMREAD_COLOR);
if (!srcimg.data) {
std::cerr << "[ERROR] image read failed! image path: " << img_path << "\n";
exit(1);
}
DBDetector det(config.det_model_dir, config.use_gpu, config.gpu_id, DBDetector det(config.det_model_dir, config.use_gpu, config.gpu_id,
config.gpu_mem, config.cpu_math_library_num_threads, config.gpu_mem, config.cpu_math_library_num_threads,
config.use_mkldnn, config.max_side_len, config.det_db_thresh, config.use_mkldnn, config.max_side_len, config.det_db_thresh,
......
...@@ -89,8 +89,10 @@ def main(config, device, logger, vdl_writer): ...@@ -89,8 +89,10 @@ def main(config, device, logger, vdl_writer):
# load pretrain model # load pretrain model
pre_best_model_dict = init_model(config, model, logger, optimizer) pre_best_model_dict = init_model(config, model, logger, optimizer)
logger.info('train dataloader has {} iters, valid dataloader has {} iters'. logger.info('train dataloader has {} iters'.format(len(train_dataloader)))
format(len(train_dataloader), len(valid_dataloader))) if valid_dataloader is not None:
logger.info('valid dataloader has {} iters'.format(
len(valid_dataloader)))
# start train # start train
program.train(config, train_dataloader, valid_dataloader, device, model, program.train(config, train_dataloader, valid_dataloader, device, model,
loss_class, optimizer, lr_scheduler, post_process_class, loss_class, optimizer, lr_scheduler, post_process_class,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册