diff --git a/fluid/PaddleCV/image_classification/dist_train/dist_train.py b/fluid/PaddleCV/image_classification/dist_train/dist_train.py index 75213420ab7bbfc845e869645127ed49177d98b9..fd9877ab06771f77ff2909ca72b2e85ad8968aff 100644 --- a/fluid/PaddleCV/image_classification/dist_train/dist_train.py +++ b/fluid/PaddleCV/image_classification/dist_train/dist_train.py @@ -101,7 +101,7 @@ def prepare_reader(is_train, pyreader, args, pass_id=0): batch_size=bs)) def prepare_visreader(is_train, pyreader, args): - import datareader.example.imagenet_demo as imagenet + import visreader.example.imagenet_demo as imagenet def _parse_kv(r): """ parse kv data from sequence file for imagenet """ @@ -367,8 +367,9 @@ def train_parallel(args): test_ret = test_single(startup_exe, test_prog, args, test_pyreader,test_fetch_list) print("Pass: %d, Test Loss %s, test acc1: %s, test acc5: %s\n" % (pass_id, test_ret[0], test_ret[1], test_ret[2])) - - startup_exe.close() + # TODO(Yancey1989): need to fix on + if args.update_method == "pserver": + startup_exe.close() print("total train time: ", time.time() - over_all_start)