未验证 提交 adfd9c8d 编写于 作者: Z zhengya01 提交者: GitHub

Merge pull request #20 from PaddlePaddle/develop

update
......@@ -20,7 +20,7 @@ cudaid=${deeplabv3plus_m:=0,1,2,3} # use 0,1,2,3 card as default
export CUDA_VISIBLE_DEVICES=$cudaid
FLAGS_benchmark=true python train.py \
--batch_size=2 \
--batch_size=8 \
--train_crop_size=769 \
--total_step=50 \
--save_weights_path=output4 \
......
......@@ -13,6 +13,8 @@ class GATrainer():
self.program = fluid.default_main_program().clone()
with fluid.program_guard(self.program):
self.fake_B = build_generator_resnet_9blocks(input_A, name="g_A")
#FIXME set persistable explicitly to pass CE
self.fake_B.persistable = True
self.fake_A = build_generator_resnet_9blocks(input_B, name="g_B")
self.cyc_A = build_generator_resnet_9blocks(self.fake_B, "g_B")
self.cyc_B = build_generator_resnet_9blocks(self.fake_A, "g_A")
......@@ -58,6 +60,8 @@ class GBTrainer():
with fluid.program_guard(self.program):
self.fake_B = build_generator_resnet_9blocks(input_A, name="g_A")
self.fake_A = build_generator_resnet_9blocks(input_B, name="g_B")
#FIXME set persistable explicitly to pass CE
self.fake_A.persistable = True
self.cyc_A = build_generator_resnet_9blocks(self.fake_B, "g_B")
self.cyc_B = build_generator_resnet_9blocks(self.fake_A, "g_A")
self.infer_program = self.program.clone()
......
......@@ -207,10 +207,14 @@ def validation(inference_program, avg_cost, s_probs, e_probs, match, feed_order,
"""
"""
build_strategy = fluid.BuildStrategy()
build_strategy.enable_inplace = False
build_strategy.memory_optimize = False
parallel_executor = fluid.ParallelExecutor(
main_program=inference_program,
use_cuda=bool(args.use_gpu),
loss_name=avg_cost.name)
loss_name=avg_cost.name,
build_strategy=build_strategy)
print_para(inference_program, parallel_executor, logger, args)
# Use test set as validation each pass
......@@ -523,7 +527,7 @@ def evaluate(logger, args):
inference_program = main_program.clone(for_test=True)
eval_loss, bleu_rouge = validation(
inference_program, avg_cost, s_probs, e_probs, match,
inference_program, avg_cost, s_probs, e_probs, match,
feed_order, place, dev_count, vocab, brc_data, logger, args)
logger.info('Dev eval loss {}'.format(eval_loss))
logger.info('Dev eval result: {}'.format(bleu_rouge))
......
......@@ -29,7 +29,9 @@ Tagspace模型学习文本及标签的embedding表示,应用于工业级的标
## 数据下载及预处理
[ag news dataset](https://github.com/mhjabreel/CharCNN/tree/master/data/ag_news_csv)
数据地址: [ag news dataset](https://github.com/mhjabreel/CharCNN/tree/master/data/)
备份数据地址:[ag news dataset](https://paddle-tagspace.bj.bcebos.com/data.tar)
数据格式如下
......@@ -37,7 +39,7 @@ Tagspace模型学习文本及标签的embedding表示,应用于工业级的标
"3","Wall St. Bears Claw Back Into the Black (Reuters)","Reuters - Short-sellers, Wall Street's dwindling\band of ultra-cynics, are seeing green again."
```
将文本数据转为paddle数据,先将数据放到训练数据目录和测试数据目录
备份数据解压后,将文本数据转为paddle数据,先将数据放到训练数据目录和测试数据目录
```
mv train.csv raw_big_train_data
mv test.csv raw_big_test_data
......@@ -59,7 +61,7 @@ CUDA_VISIBLE_DEVICES=0 python train.py --use_cuda 1
```
CPU 环境
```
python train.py
python train.py
```
全量数据单机单卡训练
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册