提交 433e325b 编写于 作者: W wuyefeilin 提交者: wuzewu

update loss_select.md and model_builder.py (#67)

* update model_builder.py

* update loss_select.md
上级 411b8fce
......@@ -55,13 +55,13 @@ python pdseg/check.py --cfg ./configs/deepglobe_road_extraction.yaml
* 训练
```shell
python pdseg/train.py --cfg ./configs/deepglobe_road_extraction.yaml --use_gpu SOLVER.LOSS ['dice_loss','bce_loss']
python pdseg/train.py --cfg ./configs/deepglobe_road_extraction.yaml --use_gpu SOLVER.LOSS "['dice_loss','bce_loss']"
```
* 评估
```
python pdseg/eval.py --cfg ./configs/deepglobe_road_extraction.yaml --use_gpu SOLVER.LOSS ['dice_loss','bce_loss']
python pdseg/eval.py --cfg ./configs/deepglobe_road_extraction.yaml --use_gpu SOLVER.LOSS "['dice_loss','bce_loss']"
```
......
......@@ -158,6 +158,8 @@ def build_model(main_prog, start_prog, phase=ModelPhase.TRAIN):
model_func = get_func("modeling." + model_name)
loss_type = cfg.SOLVER.LOSS
if not isinstance(loss_type, list):
loss_type = list(loss_type)
if class_num > 2 and (("dice_loss" in loss_type) or ("bce_loss" in loss_type)):
raise Exception("dice loss and bce loss is only applicable to binary classfication")
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册