未验证 提交 6e30a1d7 编写于 作者: W wangna11BD 提交者: GitHub

fix tipc (#635)

* fix tipc

* fix tipc
上级 9a31bdf9
===========================train_params===========================
model_name:GPEN
python:python3.7
gpu_list:0
##
auto_cast:null
total_iters:lite_train_lite_infer=10
output_dir:./output/
snapshot_config.interval:lite_train_lite_infer=10
pretrained_model:null
train_model_name:gpen*/*checkpoint.pdparams
train_infer_img_dir:null
null:null
##
trainer:norm_train
norm_train:tools/main.py -c configs/gpen_256_ffhq.yaml --seed 100 -o log_config.interval=1
pact_train:null
fpgm_train:null
distill_train:null
null:null
null:null
##
===========================eval_params===========================
eval:null
null:null
##
===========================infer_params===========================
--output_dir:./output/
load:null
norm_export:tools/export_model.py -c configs/gpen_256_ffhq.yaml --inputs_size=1,3,256,256 --model_name inference --load
quant_export:null
fpgm_export:null
distill_export:null
export1:null
export2:null
inference_dir:inference
train_model:./inference/gpen/gpenmodel_g_ema
infer_export:null
infer_quant:False
inference:tools/inference.py --model_type GPEN --seed 100 -c configs/gpen_256_ffhq.yaml --output_path test_tipc/output/ -o dataset.test.amount=5
--device:gpu
null:null
null:null
null:null
null:null
null:null
--model_path:
null:null
null:null
--benchmark:True
null:null
===========================train_params===========================
model_name:prenet
python:python3.7
gpu_list:0
##
auto_cast:null
total_iters:lite_train_lite_infer=10|lite_train_whole_infer=10|whole_train_whole_infer=200
output_dir:./output/
dataset.train.batch_size:lite_train_lite_infer=1|whole_train_whole_infer=1
pretrained_model:null
train_model_name:prenet*/*checkpoint.pdparams
train_infer_img_dir:./data/prenet/test
null:null
##
trainer:norm_train
norm_train:tools/main.py -c configs/prenet.yaml --seed 123 -o dataset.train.num_workers=0 log_config.interval=1 snapshot_config.interval=5
pact_train:null
fpgm_train:null
distill_train:null
null:null
null:null
##
===========================eval_params===========================
eval:null
null:null
##
===========================infer_params===========================
--output_dir:./output/
load:null
norm_export:tools/export_model.py -c configs/prenet.yaml --inputs_size="-1,3,-1,-1" --model_name inference --load
quant_export:null
fpgm_export:null
distill_export:null
export1:null
export2:null
inference_dir:inference
train_model:./inference/prenet/prenet_generator
infer_export:null
infer_quant:False
inference:tools/inference.py --model_type prenet -c configs/prenet.yaml --seed 123 --output_path test_tipc/output/
--device:gpu
null:null
null:null
null:null
null:null
null:null
--model_path:
null:null
null:null
--benchmark:True
null:null
===========================train_benchmark_params==========================
batch_size:2|4
fp_items:fp32
total_iters:50
--profiler_options:batch_range=[10,20];state=GPU;tracer_option=Default;profile_path=model.profile
flags:null
===========================infer_benchmark_params==========================
random_infer_input:[{float32,[6,3,180,320]}]
......@@ -15,7 +15,8 @@ Linux端基础训练预测功能测试的主程序为`test_train_inference_pytho
| FOMM |FOMM | 生成 | 支持 | 多机多卡 | | |
| BasicVSR |BasicVSR | 超分 | 支持 | 多机多卡 | | |
|PP-MSVSR|PP-MSVSR | 超分|
|SinGAN|SinGAN | 生成|支持|
|edvr|edvr | 超分|支持|
|esrgan|esrgan | 超分|支持|
- 预测相关:预测功能汇总如下,
......
......@@ -54,9 +54,6 @@ if [ ${MODE} = "lite_train_lite_infer" ];then
rm -rf ./data/ffhq*
wget -nc -P ./data/ https://paddlegan.bj.bcebos.com/datasets/ffhq.tar --no-check-certificate
cd ./data/ && tar xf ffhq.tar && cd ../ ;;
GPEN)
rm -rf ./data/ffhq*
wget -nc -P ./data/ https://paddlegan.bj.bcebos.com/datasets/ffhq.tar --no-check-certificate
cd ./data/ && tar xf ffhq.tar && cd ../ ;;
FOMM)
rm -rf ./data/fom_lite*
......@@ -110,10 +107,6 @@ elif [ ${MODE} = "lite_train_whole_infer" ];then
rm -rf ./data/ffhq*
wget -nc -P ./data/ https://paddlegan.bj.bcebos.com/datasets/ffhq.tar --no-check-certificate
cd ./data/ && tar xf ffhq.tar && cd ../
elif [ ${model_name} == "GPEN" ]; then
rm -rf ./data/ffhq*
wget -nc -P ./data/ https://paddlegan.bj.bcebos.com/datasets/ffhq.tar --no-check-certificate
cd ./data/ && tar xf ffhq.tar && cd ../
elif [ ${model_name} == "basicvsr" ]; then
rm -rf ./data/reds*
wget -nc -P ./data/ https://paddlegan.bj.bcebos.com/datasets/reds_lite.tar --no-check-certificate
......
......@@ -19,7 +19,7 @@ from ppgan.metrics import build_metric
MODEL_CLASSES = ["pix2pix", "cyclegan", "wav2lip", "esrgan", \
"edvr", "fom", "stylegan2", "basicvsr", "msvsr", "singan","prenet","GPEN"]
"edvr", "fom", "stylegan2", "basicvsr", "msvsr", "singan"]
def parse_args():
......@@ -317,47 +317,6 @@ def main():
metric_file = os.path.join(args.output_path, "singan/metric.txt")
for metric in metrics.values():
metric.update(prediction, data['A'])
elif model_type == "prenet":
lq = data['lq'].numpy()
gt = data['gt'].numpy()
input_handles[0].copy_from_cpu(lq)
predictor.run()
prediction = output_handle.copy_to_cpu()
prediction = paddle.to_tensor(prediction)
gt = paddle.to_tensor(gt)
image_numpy = tensor2img(prediction, min_max)
gt_img = tensor2img(gt, min_max)
save_image(
image_numpy,
os.path.join(args.output_path, "prenet/{}.png".format(i)))
metric_file = os.path.join(args.output_path, "prenet/metric.txt")
for metric in metrics.values():
metric.update(image_numpy, gt_img)
elif model_type == "GPEN":
lq = data[0].numpy()
input_handles[0].copy_from_cpu(lq)
predictor.run()
prediction = output_handle.copy_to_cpu()
target = data[1].numpy()
metric_file = os.path.join(args.output_path, model_type,
"metric.txt")
for metric in metrics.values():
metric.update(prediction, target)
lq = paddle.to_tensor(lq)
target = paddle.to_tensor(target)
prediction = paddle.to_tensor(prediction)
lq = lq.transpose([0, 2, 3, 1])
target = target.transpose([0, 2, 3, 1])
prediction = prediction.transpose([0, 2, 3, 1])
sample_result = paddle.concat((lq[0], prediction[0], target[0]), 1)
sample = cv2.cvtColor((sample_result.numpy() + 1) / 2 * 255,
cv2.COLOR_RGB2BGR)
file_name = os.path.join(args.output_path, model_type,
"{}.png".format(i))
cv2.imwrite(file_name, sample)
if metrics:
log_file = open(metric_file, 'a')
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册