提交 502e70b0 编写于 作者: X xiaoting 提交者: lujun

remove run_ce and fix typo (#704)

* remove run_ce and fix typo

* remove run_ce for md
上级 35d1b92e
...@@ -311,7 +311,7 @@ train_reader = paddle.batch( ...@@ -311,7 +311,7 @@ train_reader = paddle.batch(
```python ```python
if use_gpu: if use_gpu:
exe = fluid.Executor(fluid.CUDAPlace(0)) exe = fluid.Executor(fluid.CUDAPlace(0))
else else:
exe = fluid.Executor(fluid.CPUPlace()) exe = fluid.Executor(fluid.CPUPlace())
exe.run(fluid.default_startup_program()) exe.run(fluid.default_startup_program())
...@@ -390,7 +390,7 @@ for pass_id in range(epoch): ...@@ -390,7 +390,7 @@ for pass_id in range(epoch):
fetch_list={dg_loss})[0][0] fetch_list={dg_loss})[0][0]
losses[1].append(dg_loss_n) losses[1].append(dg_loss_n)
t_time += (time.time() - s_time) t_time += (time.time() - s_time)
if batch_id % 10 == 0 and not run_ce: if batch_id % 10 == 0 :
if not os.path.exists(output): if not os.path.exists(output):
os.makedirs(output) os.makedirs(output)
# 每轮的生成结果 # 每轮的生成结果
......
...@@ -41,7 +41,6 @@ add_arg('batch_size', int, 128, "Minibatch size.") ...@@ -41,7 +41,6 @@ add_arg('batch_size', int, 128, "Minibatch size.")
add_arg('epoch', int, 20, "The number of epoched to be trained.") add_arg('epoch', int, 20, "The number of epoched to be trained.")
add_arg('output', str, "./output_dcgan", "The directory the model and the test result to be saved to.") add_arg('output', str, "./output_dcgan", "The directory the model and the test result to be saved to.")
add_arg('use_gpu', bool, True, "Whether to use GPU to train.") add_arg('use_gpu', bool, True, "Whether to use GPU to train.")
add_arg('run_ce', bool, False, "Whether to run for model ce.")
# yapf: enable # yapf: enable
...@@ -52,9 +51,6 @@ def loss(x, label): ...@@ -52,9 +51,6 @@ def loss(x, label):
def train(args): def train(args):
if args.run_ce:
np.random.seed(10)
fluid.default_startup_program().random_seed = 90
d_program = fluid.Program() d_program = fluid.Program()
dg_program = fluid.Program() dg_program = fluid.Program()
...@@ -92,13 +88,9 @@ def train(args): ...@@ -92,13 +88,9 @@ def train(args):
exe = fluid.Executor(fluid.CUDAPlace(0)) exe = fluid.Executor(fluid.CUDAPlace(0))
exe.run(fluid.default_startup_program()) exe.run(fluid.default_startup_program())
if args.run_ce: train_reader = paddle.batch(
train_reader = paddle.batch( paddle.reader.shuffle(paddle.dataset.mnist.train(), buf_size=60000),
paddle.dataset.mnist.train(), batch_size=args.batch_size) batch_size=args.batch_size)
else:
train_reader = paddle.batch(
paddle.reader.shuffle(paddle.dataset.mnist.train(), buf_size=60000),
batch_size=args.batch_size)
NUM_TRAIN_TIMES_OF_DG = 2 NUM_TRAIN_TIMES_OF_DG = 2
const_n = np.random.uniform( const_n = np.random.uniform(
...@@ -155,7 +147,7 @@ def train(args): ...@@ -155,7 +147,7 @@ def train(args):
fetch_list={dg_loss})[0][0] fetch_list={dg_loss})[0][0]
losses[1].append(dg_loss_n) losses[1].append(dg_loss_n)
t_time += (time.time() - s_time) t_time += (time.time() - s_time)
if batch_id % 10 == 0 and not args.run_ce: if batch_id % 10 == 0:
if not os.path.exists(args.output): if not os.path.exists(args.output):
os.makedirs(args.output) os.makedirs(args.output)
# generate image each batch # generate image each batch
...@@ -174,10 +166,6 @@ def train(args): ...@@ -174,10 +166,6 @@ def train(args):
batch_id), batch_id),
bbox_inches='tight') bbox_inches='tight')
plt.close(fig) plt.close(fig)
if args.run_ce:
print("kpis,dcgan_d_train_cost,{}".format(np.mean(losses[0])))
print("kpis,dcgan_g_train_cost,{}".format(np.mean(losses[1])))
print("kpis,dcgan_duration,{}".format(t_time / args.epoch))
if __name__ == "__main__": if __name__ == "__main__":
......
...@@ -353,7 +353,7 @@ train_reader = paddle.batch( ...@@ -353,7 +353,7 @@ train_reader = paddle.batch(
```python ```python
if use_gpu: if use_gpu:
exe = fluid.Executor(fluid.CUDAPlace(0)) exe = fluid.Executor(fluid.CUDAPlace(0))
else else:
exe = fluid.Executor(fluid.CPUPlace()) exe = fluid.Executor(fluid.CPUPlace())
exe.run(fluid.default_startup_program()) exe.run(fluid.default_startup_program())
...@@ -432,7 +432,7 @@ for pass_id in range(epoch): ...@@ -432,7 +432,7 @@ for pass_id in range(epoch):
fetch_list={dg_loss})[0][0] fetch_list={dg_loss})[0][0]
losses[1].append(dg_loss_n) losses[1].append(dg_loss_n)
t_time += (time.time() - s_time) t_time += (time.time() - s_time)
if batch_id % 10 == 0 and not run_ce: if batch_id % 10 == 0 :
if not os.path.exists(output): if not os.path.exists(output):
os.makedirs(output) os.makedirs(output)
# 每轮的生成结果 # 每轮的生成结果
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册