未验证 提交 3766c9c2 编写于 作者: W wuyefeilin 提交者: GitHub

update train.py and benchmark (#374)

* update train.py

* update benchmark
上级 e99b0ac6
...@@ -96,7 +96,7 @@ def parse_args(): ...@@ -96,7 +96,7 @@ def parse_args():
dest='save_interval_iters', dest='save_interval_iters',
help='The interval iters for save a model snapshot', help='The interval iters for save a model snapshot',
type=int, type=int,
default=5) default=1000)
parser.add_argument( parser.add_argument(
'--save_dir', '--save_dir',
dest='save_dir', dest='save_dir',
......
...@@ -96,7 +96,7 @@ def parse_args(): ...@@ -96,7 +96,7 @@ def parse_args():
dest='save_interval_iters', dest='save_interval_iters',
help='The interval iters for save a model snapshot', help='The interval iters for save a model snapshot',
type=int, type=int,
default=5) default=1000)
parser.add_argument( parser.add_argument(
'--save_dir', '--save_dir',
dest='save_dir', dest='save_dir',
......
...@@ -17,7 +17,8 @@ import os ...@@ -17,7 +17,8 @@ import os
import paddle.fluid as fluid import paddle.fluid as fluid
from paddle.fluid.dygraph.parallel import ParallelEnv from paddle.fluid.dygraph.parallel import ParallelEnv
from paddle.fluid.io import DataLoader from paddle.fluid.io import DataLoader
from paddle.incubate.hapi.distributed import DistributedBatchSampler # from paddle.incubate.hapi.distributed import DistributedBatchSampler
from paddle.io import DistributedBatchSampler
import dygraph.utils.logger as logger import dygraph.utils.logger as logger
from dygraph.utils import load_pretrained_model from dygraph.utils import load_pretrained_model
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册