未验证 提交 763c5f06 编写于 作者: Z zhouzj 提交者: GitHub

Fix bug of seg demo (#1449)

上级 253c09b5
Global: Global:
reader_config: configs/dataset/cityscapes_1024x512_scale1.0.yml reader_config: configs/deeplabv3/deeplabv3_reader.yml
model_dir: ./RES-paddle2-Deeplabv3-ResNet50 model_dir: ./RES-paddle2-Deeplabv3-ResNet50
model_filename: model model_filename: model
params_filename: params params_filename: params
...@@ -11,15 +11,16 @@ Distillation: ...@@ -11,15 +11,16 @@ Distillation:
- conv2d_123.tmp_1 - conv2d_123.tmp_1
Quantization: Quantization:
onnx_format: True
quantize_op_types: quantize_op_types:
- conv2d - conv2d
- depthwise_conv2d - depthwise_conv2d
TrainConfig: TrainConfig:
epochs: 10 epochs: 1
eval_iter: 360 eval_iter: 360
learning_rate: 0.0001 learning_rate: 0.0001
optimizer_builder: optimizer_builder:
optimizer: optimizer:
type: SGD type: SGD
weight_decay: 0.0005 weight_decay: 0.0005
\ No newline at end of file
batch_size: 4
train_dataset:
type: Cityscapes
dataset_root: data/cityscapes
transforms:
- type: ResizeStepScaling
min_scale_factor: 0.5
max_scale_factor: 2.0
scale_step_size: 0.25
- type: RandomPaddingCrop
crop_size: [1024, 512]
- type: RandomHorizontalFlip
- type: RandomDistort
brightness_range: 0.5
contrast_range: 0.5
saturation_range: 0.5
- type: Normalize
mode: train
val_dataset:
type: Cityscapes
dataset_root: data/cityscapes
transforms:
- type: Normalize
mode: val
...@@ -137,6 +137,7 @@ def main(args): ...@@ -137,6 +137,7 @@ def main(args):
# step1: load dataset config and create dataloader # step1: load dataset config and create dataloader
data_cfg = PaddleSegDataConfig(config['reader_config']) data_cfg = PaddleSegDataConfig(config['reader_config'])
train_dataset = data_cfg.train_dataset train_dataset = data_cfg.train_dataset
global eval_dataset
eval_dataset = data_cfg.val_dataset eval_dataset = data_cfg.val_dataset
batch_sampler = paddle.io.DistributedBatchSampler( batch_sampler = paddle.io.DistributedBatchSampler(
train_dataset, train_dataset,
...@@ -163,7 +164,7 @@ def main(args): ...@@ -163,7 +164,7 @@ def main(args):
save_dir=args.save_dir, save_dir=args.save_dir,
config=all_config, config=all_config,
train_dataloader=train_dataloader, train_dataloader=train_dataloader,
eval_callback=eval_function if nranks > 1 and rank_id != 0 else None, eval_callback=eval_function if rank_id == 0 else None,
deploy_hardware=config.get('deploy_hardware') or None, deploy_hardware=config.get('deploy_hardware') or None,
input_shapes=config.get('input_shapes', None)) input_shapes=config.get('input_shapes', None))
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册