未验证 提交 adfd9c8d 编写于 作者: Z zhengya01 提交者: GitHub

Merge pull request #20 from PaddlePaddle/develop

update
...@@ -20,7 +20,7 @@ cudaid=${deeplabv3plus_m:=0,1,2,3} # use cards 0,1,2,3 by default ...@@ -20,7 +20,7 @@ cudaid=${deeplabv3plus_m:=0,1,2,3} # use cards 0,1,2,3 by default
export CUDA_VISIBLE_DEVICES=$cudaid export CUDA_VISIBLE_DEVICES=$cudaid
FLAGS_benchmark=true python train.py \ FLAGS_benchmark=true python train.py \
--batch_size=2 \ --batch_size=8 \
--train_crop_size=769 \ --train_crop_size=769 \
--total_step=50 \ --total_step=50 \
--save_weights_path=output4 \ --save_weights_path=output4 \
......
...@@ -13,6 +13,8 @@ class GATrainer(): ...@@ -13,6 +13,8 @@ class GATrainer():
self.program = fluid.default_main_program().clone() self.program = fluid.default_main_program().clone()
with fluid.program_guard(self.program): with fluid.program_guard(self.program):
self.fake_B = build_generator_resnet_9blocks(input_A, name="g_A") self.fake_B = build_generator_resnet_9blocks(input_A, name="g_A")
# FIXME: fake_B is marked persistable explicitly only so that the CE run passes
self.fake_B.persistable = True
self.fake_A = build_generator_resnet_9blocks(input_B, name="g_B") self.fake_A = build_generator_resnet_9blocks(input_B, name="g_B")
self.cyc_A = build_generator_resnet_9blocks(self.fake_B, "g_B") self.cyc_A = build_generator_resnet_9blocks(self.fake_B, "g_B")
self.cyc_B = build_generator_resnet_9blocks(self.fake_A, "g_A") self.cyc_B = build_generator_resnet_9blocks(self.fake_A, "g_A")
...@@ -58,6 +60,8 @@ class GBTrainer(): ...@@ -58,6 +60,8 @@ class GBTrainer():
with fluid.program_guard(self.program): with fluid.program_guard(self.program):
self.fake_B = build_generator_resnet_9blocks(input_A, name="g_A") self.fake_B = build_generator_resnet_9blocks(input_A, name="g_A")
self.fake_A = build_generator_resnet_9blocks(input_B, name="g_B") self.fake_A = build_generator_resnet_9blocks(input_B, name="g_B")
# FIXME: fake_A is marked persistable explicitly only so that the CE run passes
self.fake_A.persistable = True
self.cyc_A = build_generator_resnet_9blocks(self.fake_B, "g_B") self.cyc_A = build_generator_resnet_9blocks(self.fake_B, "g_B")
self.cyc_B = build_generator_resnet_9blocks(self.fake_A, "g_A") self.cyc_B = build_generator_resnet_9blocks(self.fake_A, "g_A")
self.infer_program = self.program.clone() self.infer_program = self.program.clone()
......
...@@ -207,10 +207,14 @@ def validation(inference_program, avg_cost, s_probs, e_probs, match, feed_order, ...@@ -207,10 +207,14 @@ def validation(inference_program, avg_cost, s_probs, e_probs, match, feed_order,
""" """
""" """
build_strategy = fluid.BuildStrategy()
build_strategy.enable_inplace = False
build_strategy.memory_optimize = False
parallel_executor = fluid.ParallelExecutor( parallel_executor = fluid.ParallelExecutor(
main_program=inference_program, main_program=inference_program,
use_cuda=bool(args.use_gpu), use_cuda=bool(args.use_gpu),
loss_name=avg_cost.name) loss_name=avg_cost.name,
build_strategy=build_strategy)
print_para(inference_program, parallel_executor, logger, args) print_para(inference_program, parallel_executor, logger, args)
# Use test set as validation each pass # Use test set as validation each pass
...@@ -523,7 +527,7 @@ def evaluate(logger, args): ...@@ -523,7 +527,7 @@ def evaluate(logger, args):
inference_program = main_program.clone(for_test=True) inference_program = main_program.clone(for_test=True)
eval_loss, bleu_rouge = validation( eval_loss, bleu_rouge = validation(
inference_program, avg_cost, s_probs, e_probs, match, inference_program, avg_cost, s_probs, e_probs, match,
feed_order, place, dev_count, vocab, brc_data, logger, args) feed_order, place, dev_count, vocab, brc_data, logger, args)
logger.info('Dev eval loss {}'.format(eval_loss)) logger.info('Dev eval loss {}'.format(eval_loss))
logger.info('Dev eval result: {}'.format(bleu_rouge)) logger.info('Dev eval result: {}'.format(bleu_rouge))
......
...@@ -29,7 +29,9 @@ Tagspace模型学习文本及标签的embedding表示,应用于工业级的标 ...@@ -29,7 +29,9 @@ Tagspace模型学习文本及标签的embedding表示,应用于工业级的标
## 数据下载及预处理 ## 数据下载及预处理
[ag news dataset](https://github.com/mhjabreel/CharCNN/tree/master/data/ag_news_csv) 数据地址: [ag news dataset](https://github.com/mhjabreel/CharCNN/tree/master/data/)
备份数据地址:[ag news dataset](https://paddle-tagspace.bj.bcebos.com/data.tar)
数据格式如下 数据格式如下
...@@ -37,7 +39,7 @@ Tagspace模型学习文本及标签的embedding表示,应用于工业级的标 ...@@ -37,7 +39,7 @@ Tagspace模型学习文本及标签的embedding表示,应用于工业级的标
"3","Wall St. Bears Claw Back Into the Black (Reuters)","Reuters - Short-sellers, Wall Street's dwindling\band of ultra-cynics, are seeing green again." "3","Wall St. Bears Claw Back Into the Black (Reuters)","Reuters - Short-sellers, Wall Street's dwindling\band of ultra-cynics, are seeing green again."
``` ```
将文本数据转为paddle数据,先将数据放到训练数据目录和测试数据目录 如使用备份数据,请先解压;然后将文本数据转为 paddle 数据:先将数据放到训练数据目录和测试数据目录
``` ```
mv train.csv raw_big_train_data mv train.csv raw_big_train_data
mv test.csv raw_big_test_data mv test.csv raw_big_test_data
...@@ -59,7 +61,7 @@ CUDA_VISIBLE_DEVICES=0 python train.py --use_cuda 1 ...@@ -59,7 +61,7 @@ CUDA_VISIBLE_DEVICES=0 python train.py --use_cuda 1
``` ```
CPU 环境 CPU 环境
``` ```
python train.py python train.py
``` ```
全量数据单机单卡训练 全量数据单机单卡训练
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册