提交 c1ed3365 编写于 作者: Y Yanjun Peng

fix batch repeat usage order

上级 5aafe2f0
......@@ -129,14 +129,14 @@ tar -zvxf cifar-10-binary.tar.gz
Shuffle data randomly to disorder the data sequence and read data in batches for model training:
```python
# apply repeat operations
cifar_ds = cifar_ds.repeat(repeat_num)
# apply shuffle operations
cifar_ds = cifar_ds.shuffle(buffer_size=10)
# apply batch operations
cifar_ds = cifar_ds.batch(batch_size=args_opt.batch_size, drop_remainder=True)
# apply repeat operations
cifar_ds = cifar_ds.repeat(repeat_num)
```
......
......@@ -145,15 +145,15 @@ def create_dataset(repeat_num=1, batch_size=32, rank_id=0, rank_size=1):
data_set = data_set.map(input_columns="label", operations=type_cast_op)
data_set = data_set.map(input_columns="image", operations=c_trans)
# apply repeat operations
data_set = data_set.repeat(repeat_num)
# apply shuffle operations
data_set = data_set.shuffle(buffer_size=10)
# apply batch operations
data_set = data_set.batch(batch_size=batch_size, drop_remainder=True)
# apply repeat operations
data_set = data_set.repeat(repeat_num)
return data_set
```
......
......@@ -131,14 +131,14 @@ tar -zvxf cifar-10-binary.tar.gz
最后通过数据混洗(shuffle)随机打乱数据的顺序,并按batch读取数据,进行模型训练:
```python
# apply repeat operations
cifar_ds = cifar_ds.repeat(repeat_num)
# apply shuffle operations
cifar_ds = cifar_ds.shuffle(buffer_size=10)
# apply batch operations
cifar_ds = cifar_ds.batch(batch_size=args_opt.batch_size, drop_remainder=True)
# apply repeat operations
cifar_ds = cifar_ds.repeat(repeat_num)
```
......
......@@ -144,15 +144,15 @@ def create_dataset(repeat_num=1, batch_size=32, rank_id=0, rank_size=1):
data_set = data_set.map(input_columns="label", operations=type_cast_op)
data_set = data_set.map(input_columns="image", operations=c_trans)
# apply repeat operations
data_set = data_set.repeat(repeat_num)
# apply shuffle operations
data_set = data_set.shuffle(buffer_size=10)
# apply batch operations
data_set = data_set.batch(batch_size=batch_size, drop_remainder=True)
# apply repeat operations
data_set = data_set.repeat(repeat_num)
return data_set
```
......
......@@ -69,15 +69,15 @@ def create_dataset(repeat_num=1, batch_size=32, rank_id=0, rank_size=1):
data_set = data_set.map(input_columns="label", operations=type_cast_op)
data_set = data_set.map(input_columns="image", operations=c_trans)
# apply repeat operations
data_set = data_set.repeat(repeat_num)
# apply shuffle operations
data_set = data_set.shuffle(buffer_size=10)
# apply batch operations
data_set = data_set.batch(batch_size=batch_size, drop_remainder=True)
# apply repeat operations
data_set = data_set.repeat(repeat_num)
return data_set
......
......@@ -91,15 +91,15 @@ def create_dataset(repeat_num=1, training=True):
cifar_ds = cifar_ds.map(input_columns="label", operations=type_cast_op)
cifar_ds = cifar_ds.map(input_columns="image", operations=c_trans)
# apply repeat operations
cifar_ds = cifar_ds.repeat(repeat_num)
# apply shuffle operations
cifar_ds = cifar_ds.shuffle(buffer_size=10)
# apply batch operations
cifar_ds = cifar_ds.batch(batch_size=args_opt.batch_size, drop_remainder=True)
# apply repeat operations
cifar_ds = cifar_ds.repeat(repeat_num)
return cifar_ds
if __name__ == '__main__':
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册