提交 502e70b0 编写于 作者: X xiaoting 提交者: lujun

remove run_ce and fix typo (#704)

* remove run_ce and fix typo

* remove run_ce for md
上级 35d1b92e
......@@ -311,7 +311,7 @@ train_reader = paddle.batch(
```python
if use_gpu:
exe = fluid.Executor(fluid.CUDAPlace(0))
else
else:
exe = fluid.Executor(fluid.CPUPlace())
exe.run(fluid.default_startup_program())
......@@ -390,7 +390,7 @@ for pass_id in range(epoch):
fetch_list={dg_loss})[0][0]
losses[1].append(dg_loss_n)
t_time += (time.time() - s_time)
if batch_id % 10 == 0 and not run_ce:
if batch_id % 10 == 0 :
if not os.path.exists(output):
os.makedirs(output)
# 每轮的生成结果
......
......@@ -41,7 +41,6 @@ add_arg('batch_size', int, 128, "Minibatch size.")
add_arg('epoch', int, 20, "The number of epoched to be trained.")
add_arg('output', str, "./output_dcgan", "The directory the model and the test result to be saved to.")
add_arg('use_gpu', bool, True, "Whether to use GPU to train.")
add_arg('run_ce', bool, False, "Whether to run for model ce.")
# yapf: enable
......@@ -52,9 +51,6 @@ def loss(x, label):
def train(args):
if args.run_ce:
np.random.seed(10)
fluid.default_startup_program().random_seed = 90
d_program = fluid.Program()
dg_program = fluid.Program()
......@@ -92,13 +88,9 @@ def train(args):
exe = fluid.Executor(fluid.CUDAPlace(0))
exe.run(fluid.default_startup_program())
if args.run_ce:
train_reader = paddle.batch(
paddle.dataset.mnist.train(), batch_size=args.batch_size)
else:
train_reader = paddle.batch(
paddle.reader.shuffle(paddle.dataset.mnist.train(), buf_size=60000),
batch_size=args.batch_size)
train_reader = paddle.batch(
paddle.reader.shuffle(paddle.dataset.mnist.train(), buf_size=60000),
batch_size=args.batch_size)
NUM_TRAIN_TIMES_OF_DG = 2
const_n = np.random.uniform(
......@@ -155,7 +147,7 @@ def train(args):
fetch_list={dg_loss})[0][0]
losses[1].append(dg_loss_n)
t_time += (time.time() - s_time)
if batch_id % 10 == 0 and not args.run_ce:
if batch_id % 10 == 0:
if not os.path.exists(args.output):
os.makedirs(args.output)
# generate image each batch
......@@ -174,10 +166,6 @@ def train(args):
batch_id),
bbox_inches='tight')
plt.close(fig)
if args.run_ce:
print("kpis,dcgan_d_train_cost,{}".format(np.mean(losses[0])))
print("kpis,dcgan_g_train_cost,{}".format(np.mean(losses[1])))
print("kpis,dcgan_duration,{}".format(t_time / args.epoch))
if __name__ == "__main__":
......
......@@ -353,7 +353,7 @@ train_reader = paddle.batch(
```python
if use_gpu:
exe = fluid.Executor(fluid.CUDAPlace(0))
else
else:
exe = fluid.Executor(fluid.CPUPlace())
exe.run(fluid.default_startup_program())
......@@ -432,7 +432,7 @@ for pass_id in range(epoch):
fetch_list={dg_loss})[0][0]
losses[1].append(dg_loss_n)
t_time += (time.time() - s_time)
if batch_id % 10 == 0 and not run_ce:
if batch_id % 10 == 0 :
if not os.path.exists(output):
os.makedirs(output)
# 每轮的生成结果
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册