提交 8c034956 编写于 作者: L lilong12 提交者: Yi Liu

fix the bug that the argument data_dir has no effect (#2415)

* modify single queue to a queue per process for data processing

* add transforms.py and rename torchvision_reader.py to reader.py

* add datasets.py

* remove pytorch apis

* fix some small bugs

* remove torch and torchvision from requirements.txt

* modify core.CUDAPlace to fluid.CUDAPlace

* bug fix: data_dir has no effect

* bug fix: data_dir has no effect

* bug fix: modify  in directory

* bug fix: fix a typo in README.md
上级 1230a298
......@@ -19,7 +19,7 @@ PaddlePaddle Fast ImageNet using the dynmiac batch size, dynamic image size, rec
|-train
`-validation
```
1. Install the requirements by `pip install -r requirement.txt`.
1. Install the requirements by `pip install -r requirements.txt`.
1. Launch the training job: `python train.py --data_dir /data/imagenet`
1. Learning curve, we launch the training job on V100 GPU card:
<p align="center">
......
......@@ -292,9 +292,9 @@ def refresh_program(args,
def prepare_reader(epoch_id, train_py_reader, train_bs, val_bs, trn_dir,
img_dim, min_scale, rect_val):
img_dim, min_scale, rect_val, args):
train_reader = reader.train(
traindir="/data/imagenet/%strain" % trn_dir,
traindir="%s/%strain" % (args.data_dir, trn_dir),
sz=img_dim,
min_scale=min_scale,
shuffle_seed=epoch_id + 1)
......@@ -303,7 +303,7 @@ def prepare_reader(epoch_id, train_py_reader, train_bs, val_bs, trn_dir,
train_reader, batch_size=train_bs))
test_reader = reader.test(
valdir="/data/imagenet/%svalidation" % trn_dir,
valdir="%s/%svalidation" % (args.data_dir, trn_dir),
bs=val_bs * DEVICE_NUM,
sz=img_dim,
rect_val=rect_val)
......@@ -324,7 +324,7 @@ def train_parallel(args):
## dynamic batch size, image size...
bs = 224
val_bs = 64
trn_dir = "sz/160/"
trn_dir = "160/"
img_dim = 128
min_scale = 0.08
rect_val = False
......@@ -341,7 +341,7 @@ def train_parallel(args):
need_update_start_prog=True)
elif epoch_id == 13: #13
bs = 96
trn_dir = "sz/352/"
trn_dir = "352/"
img_dim = 224
min_scale = 0.087
train_args, test_args, test_prog, exe, test_exe = refresh_program(
......@@ -384,7 +384,8 @@ def train_parallel(args):
trn_dir,
img_dim=img_dim,
min_scale=min_scale,
rect_val=rect_val)
rect_val=rect_val,
args=args)
train_py_reader.start() # start pyreader
batch_start_time = time.time()
while True:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册