From 57ce5a1bb6f318e0b79a133253ab01dc95a88c42 Mon Sep 17 00:00:00 2001
From: FlyingQianMM <245467267@qq.com>
Date: Tue, 1 Sep 2020 06:39:04 +0000
Subject: [PATCH] modify some details for the reviews
---
docs/apis/models/semantic_segmentation.md | 13 ++++++++-----
.../multi-channel_remote_sensing/README.md | 6 +++---
examples/multi-channel_remote_sensing/README.md | 6 +++---
paddlex/cv/datasets/analysis.py | 4 ++--
paddlex/cv/datasets/imagenet.py | 4 ++++
paddlex/cv/datasets/seg_dataset.py | 4 ++--
paddlex/cv/datasets/voc.py | 4 ++++
paddlex/cv/models/deeplabv3p.py | 12 +++++++-----
paddlex/cv/models/fast_scnn.py | 12 +++++++-----
paddlex/cv/models/hrnet.py | 11 ++++++-----
paddlex/cv/models/unet.py | 11 ++++++-----
paddlex/cv/transforms/seg_transforms.py | 15 +++++++++------
12 files changed, 61 insertions(+), 41 deletions(-)
diff --git a/docs/apis/models/semantic_segmentation.md b/docs/apis/models/semantic_segmentation.md
index 82b758d..5841676 100755
--- a/docs/apis/models/semantic_segmentation.md
+++ b/docs/apis/models/semantic_segmentation.md
@@ -3,8 +3,7 @@
## paddlex.seg.DeepLabv3p
```python
-paddlex.seg.DeepLabv3p(num_classes=2, backbone='MobileNetV2_x1.0', output_stride=16, aspp_with_sep_conv=True, decoder_use_sep_conv=True, encoder_with_aspp=True, enable_decoder=True, use_bce_loss=False, use_dice_loss=False, class_weight=None, ignore_index=255, pooling_crop_size=None)
-
+paddlex.seg.DeepLabv3p(num_classes=2, backbone='MobileNetV2_x1.0', output_stride=16, aspp_with_sep_conv=True, decoder_use_sep_conv=True, encoder_with_aspp=True, enable_decoder=True, use_bce_loss=False, use_dice_loss=False, class_weight=None, ignore_index=255, pooling_crop_size=None, input_channel=3)
```
> 构建DeepLabv3p分割器。
@@ -23,6 +22,7 @@ paddlex.seg.DeepLabv3p(num_classes=2, backbone='MobileNetV2_x1.0', output_stride
> > - **class_weight** (list/str): 交叉熵损失函数各类损失的权重。当`class_weight`为list的时候,长度应为`num_classes`。当`class_weight`为str时, weight.lower()应为'dynamic',这时会根据每一轮各类像素的比重自行计算相应的权重,每一类的权重为:每类的比例 * num_classes。class_weight取默认值None是,各类的权重1,即平时使用的交叉熵损失函数。
> > - **ignore_index** (int): label上忽略的值,label为`ignore_index`的像素不参与损失函数的计算。默认255。
> > - **pooling_crop_size** (int):当backbone为`MobileNetV3_large_x1_0_ssld`时,需设置为训练过程中模型输入大小,格式为[W, H]。例如模型输入大小为[512, 512], 则`pooling_crop_size`应该设置为[512, 512]。在encoder模块中获取图像平均值时被用到,若为None,则直接求平均值;若为模型输入大小,则使用`avg_pool`算子得到平均值。默认值None。
+> > - **input_channel** (int): 输入图像通道数。默认值3。
### train
@@ -115,7 +115,7 @@ batch_predict(self, img_file_list, transforms=None, thread_num=2):
## paddlex.seg.UNet
```python
-paddlex.seg.UNet(num_classes=2, upsample_mode='bilinear', use_bce_loss=False, use_dice_loss=False, class_weight=None, ignore_index=255)
+paddlex.seg.UNet(num_classes=2, upsample_mode='bilinear', use_bce_loss=False, use_dice_loss=False, class_weight=None, ignore_index=255, input_channel=3)
```
> 构建UNet分割器。
@@ -128,6 +128,7 @@ paddlex.seg.UNet(num_classes=2, upsample_mode='bilinear', use_bce_loss=False, us
> > - **use_dice_loss** (bool): 是否使用dice loss作为网络的损失函数,只能用于两类分割,可与bce loss同时使用。当use_bce_loss和use_dice_loss都为False时,使用交叉熵损失函数。默认False。
> > - **class_weight** (list/str): 交叉熵损失函数各类损失的权重。当`class_weight`为list的时候,长度应为`num_classes`。当`class_weight`为str时, weight.lower()应为'dynamic',这时会根据每一轮各类像素的比重自行计算相应的权重,每一类的权重为:每类的比例 * num_classes。class_weight取默认值None是,各类的权重1,即平时使用的交叉熵损失函数。
> > - **ignore_index** (int): label上忽略的值,label为`ignore_index`的像素不参与损失函数的计算。默认255。
+> > - **input_channel** (int): 输入图像通道数。默认值3。
> - train 训练接口说明同 [DeepLabv3p模型train接口](#train)
> - evaluate 评估接口说明同 [DeepLabv3p模型evaluate接口](#evaluate)
@@ -137,7 +138,7 @@ paddlex.seg.UNet(num_classes=2, upsample_mode='bilinear', use_bce_loss=False, us
## paddlex.seg.HRNet
```python
-paddlex.seg.HRNet(num_classes=2, width=18, use_bce_loss=False, use_dice_loss=False, class_weight=None, ignore_index=255)
+paddlex.seg.HRNet(num_classes=2, width=18, use_bce_loss=False, use_dice_loss=False, class_weight=None, ignore_index=255, input_channel=3)
```
> 构建HRNet分割器。
@@ -150,6 +151,7 @@ paddlex.seg.HRNet(num_classes=2, width=18, use_bce_loss=False, use_dice_loss=Fal
> > - **use_dice_loss** (bool): 是否使用dice loss作为网络的损失函数,只能用于两类分割,可与bce loss同时使用。当use_bce_loss和use_dice_loss都为False时,使用交叉熵损失函数。默认False。
> > - **class_weight** (list|str): 交叉熵损失函数各类损失的权重。当`class_weight`为list的时候,长度应为`num_classes`。当`class_weight`为str时, weight.lower()应为'dynamic',这时会根据每一轮各类像素的比重自行计算相应的权重,每一类的权重为:每类的比例 * num_classes。class_weight取默认值None是,各类的权重1,即平时使用的交叉熵损失函数。
> > - **ignore_index** (int): label上忽略的值,label为`ignore_index`的像素不参与损失函数的计算。默认255。
+> > - **input_channel** (int): 输入图像通道数。默认值3。
> - train 训练接口说明同 [DeepLabv3p模型train接口](#train)
> - evaluate 评估接口说明同 [DeepLabv3p模型evaluate接口](#evaluate)
@@ -159,7 +161,7 @@ paddlex.seg.HRNet(num_classes=2, width=18, use_bce_loss=False, use_dice_loss=Fal
## paddlex.seg.FastSCNN
```python
-paddlex.seg.FastSCNN(num_classes=2, use_bce_loss=False, use_dice_loss=False, class_weight=None, ignore_index=255, multi_loss_weight=[1.0])
+paddlex.seg.FastSCNN(num_classes=2, use_bce_loss=False, use_dice_loss=False, class_weight=None, ignore_index=255, multi_loss_weight=[1.0], input_channel=3)
```
> 构建FastSCNN分割器。
@@ -172,6 +174,7 @@ paddlex.seg.FastSCNN(num_classes=2, use_bce_loss=False, use_dice_loss=False, cla
> > - **class_weight** (list/str): 交叉熵损失函数各类损失的权重。当`class_weight`为list的时候,长度应为`num_classes`。当`class_weight`为str时, weight.lower()应为'dynamic',这时会根据每一轮各类像素的比重自行计算相应的权重,每一类的权重为:每类的比例 * num_classes。class_weight取默认值None是,各类的权重1,即平时使用的交叉熵损失函数。
> > - **ignore_index** (int): label上忽略的值,label为`ignore_index`的像素不参与损失函数的计算。默认255。
> > - **multi_loss_weight** (list): 多分支上的loss权重。默认计算一个分支上的loss,即默认值为[1.0]。也支持计算两个分支或三个分支上的loss,权重按[fusion_branch_weight, higher_branch_weight, lower_branch_weight]排列,fusion_branch_weight为空间细节分支和全局上下文分支融合后的分支上的loss权重,higher_branch_weight为空间细节分支上的loss权重,lower_branch_weight为全局上下文分支上的loss权重,若higher_branch_weight和lower_branch_weight未设置则不会计算这两个分支上的loss。
+> > - **input_channel** (int): 输入图像通道数。默认值3。
> - train 训练接口说明同 [DeepLabv3p模型train接口](#train)
> - evaluate 评估接口说明同 [DeepLabv3p模型evaluate接口](#evaluate)
diff --git a/docs/examples/multi-channel_remote_sensing/README.md b/docs/examples/multi-channel_remote_sensing/README.md
index 146379d..1a46a61 100644
--- a/docs/examples/multi-channel_remote_sensing/README.md
+++ b/docs/examples/multi-channel_remote_sensing/README.md
@@ -69,7 +69,7 @@ cd ..
参考文档[数据分析](./analysis.md)对训练集进行统计分析,确定图像像素值的截断范围,并统计截断后的均值和方差。
## 模型训练
-本案例选择`UNet`语义分割模型完成云雪分割,运行以下步骤完成模型训练,模型的最优精度`miou`为`77.99%`。
+本案例选择`UNet`语义分割模型完成云雪分割,运行以下步骤完成模型训练,模型的最优精度`miou`为`78.38%`。
* 设置GPU卡号
```shell script
@@ -88,8 +88,8 @@ python train.py --data_dir dataset/remote_sensing_seg \
--lr 0.01 \
--clip_min_value 7172 6561 5777 5103 4291 4000 4000 4232 6934 7199 \
--clip_max_value 50000 50000 50000 50000 50000 40000 30000 18000 40000 36000 \
---mean 0.14311188522260637 0.14288498042151332 0.14812997807748615 0.16377211813814938 0.2737538363784552 0.2740934379398823 0.27749601919204 0.07767443032935262 0.5694699410349131 0.5549716085195542 \
---std 0.09101632762467489 0.09600705942721106 0.096193618606776 0.10371446736389771 0.10911951586604118 0.11043593115173281 0.12648042598739268 0.027746262217260665 0.06822348076384514 0.062377591186668725 \
+--mean 0.15163569 0.15142828 0.15574491 0.1716084 0.2799778 0.27652043 0.28195933 0.07853807 0.56333154 0.5477584 \
+--std 0.09301891 0.09818967 0.09831126 0.1057784 0.10842132 0.11062996 0.12791838 0.02637859 0.0675052 0.06168227 \
--num_epochs 500 \
--train_batch_size 3
```
diff --git a/examples/multi-channel_remote_sensing/README.md b/examples/multi-channel_remote_sensing/README.md
index d5e699a..8554e3d 100644
--- a/examples/multi-channel_remote_sensing/README.md
+++ b/examples/multi-channel_remote_sensing/README.md
@@ -84,7 +84,7 @@ cd ..
##
模型训练
-本案例选择`UNet`语义分割模型完成云雪分割,运行以下步骤完成模型训练,模型的最优精度`miou`为`77.99%`。
+本案例选择`UNet`语义分割模型完成云雪分割,运行以下步骤完成模型训练,模型的最优精度`miou`为`78.38%`。
* 设置GPU卡号
```shell script
@@ -103,8 +103,8 @@ python train.py --data_dir dataset/remote_sensing_seg \
--lr 0.01 \
--clip_min_value 7172 6561 5777 5103 4291 4000 4000 4232 6934 7199 \
--clip_max_value 50000 50000 50000 50000 50000 40000 30000 18000 40000 36000 \
---mean 0.14311188522260637 0.14288498042151332 0.14812997807748615 0.16377211813814938 0.2737538363784552 0.2740934379398823 0.27749601919204 0.07767443032935262 0.5694699410349131 0.5549716085195542 \
---std 0.09101632762467489 0.09600705942721106 0.096193618606776 0.10371446736389771 0.10911951586604118 0.11043593115173281 0.12648042598739268 0.027746262217260665 0.06822348076384514 0.062377591186668725 \
+--mean 0.15163569 0.15142828 0.15574491 0.1716084 0.2799778 0.27652043 0.28195933 0.07853807 0.56333154 0.5477584 \
+--std 0.09301891 0.09818967 0.09831126 0.1057784 0.10842132 0.11062996 0.12791838 0.02637859 0.0675052 0.06168227 \
--num_epochs 500 \
--train_batch_size 3
```
diff --git a/paddlex/cv/datasets/analysis.py b/paddlex/cv/datasets/analysis.py
index df9c75a..be1d58f 100644
--- a/paddlex/cv/datasets/analysis.py
+++ b/paddlex/cv/datasets/analysis.py
@@ -40,11 +40,11 @@ class Seg:
with open(file_list, encoding=get_encoding(file_list)) as f:
for line in f:
- if line.count(" ") > 1:
+ items = line.strip().split()
+ if len(items) > 2:
raise Exception(
"A space is defined as the separator, but it exists in image or label name {}."
.format(line))
- items = line.strip().split()
items[0] = path_normalization(items[0])
items[1] = path_normalization(items[1])
full_path_im = osp.join(data_dir, items[0])
diff --git a/paddlex/cv/datasets/imagenet.py b/paddlex/cv/datasets/imagenet.py
index ea93d58..9b0f2e7 100644
--- a/paddlex/cv/datasets/imagenet.py
+++ b/paddlex/cv/datasets/imagenet.py
@@ -67,6 +67,10 @@ class ImageNet(Dataset):
with open(file_list, encoding=get_encoding(file_list)) as f:
for line in f:
items = line.strip().split()
+ if len(items):
+ raise Exception(
+ "A space is defined as the separator, but it exists in image or label name {}."
+ .format(line))
items[0] = path_normalization(items[0])
if not is_pic(items[0]):
continue
diff --git a/paddlex/cv/datasets/seg_dataset.py b/paddlex/cv/datasets/seg_dataset.py
index cc80fc1..9a1c049 100644
--- a/paddlex/cv/datasets/seg_dataset.py
+++ b/paddlex/cv/datasets/seg_dataset.py
@@ -63,11 +63,11 @@ class SegDataset(Dataset):
self.labels.append(item)
with open(file_list, encoding=get_encoding(file_list)) as f:
for line in f:
- if line.count(" ") > 1:
+ items = line.strip().split()
+ if len(items) > 2:
raise Exception(
"A space is defined as the separator, but it exists in image or label name {}."
.format(line))
- items = line.strip().split()
items[0] = path_normalization(items[0])
items[1] = path_normalization(items[1])
full_path_im = osp.join(data_dir, items[0])
diff --git a/paddlex/cv/datasets/voc.py b/paddlex/cv/datasets/voc.py
index fae619b..8143501 100644
--- a/paddlex/cv/datasets/voc.py
+++ b/paddlex/cv/datasets/voc.py
@@ -91,6 +91,10 @@ class VOCDetection(Dataset):
line = fr.readline()
if not line:
break
+ if len(line.strip().split()) > 2:
+ raise Exception(
+ "A space is defined as the separator, but it exists in image or label name {}."
+ .format(line))
img_file, xml_file = [osp.join(data_dir, x) \
for x in line.strip().split()[:2]]
img_file = path_normalization(img_file)
diff --git a/paddlex/cv/models/deeplabv3p.py b/paddlex/cv/models/deeplabv3p.py
index f9c7629..9371859 100644
--- a/paddlex/cv/models/deeplabv3p.py
+++ b/paddlex/cv/models/deeplabv3p.py
@@ -54,6 +54,8 @@ class DeepLabv3p(BaseAPI):
pooling_crop_size (list): 当backbone为MobileNetV3_large_x1_0_ssld时,需设置为训练过程中模型输入大小, 格式为[W, H]。
在encoder模块中获取图像平均值时被用到,若为None,则直接求平均值;若为模型输入大小,则使用'pool'算子得到平均值。
默认值为None。
+ input_channel (int): 输入图像通道数。默认值3。
+
Raises:
ValueError: use_bce_loss或use_dice_loss为真且num_calsses > 2。
ValueError: backbone取值不在['Xception65', 'Xception41', 'MobileNetV2_x0.25',
@@ -65,7 +67,6 @@ class DeepLabv3p(BaseAPI):
def __init__(self,
num_classes=2,
- input_channel=3,
backbone='MobileNetV2_x1.0',
output_stride=16,
aspp_with_sep_conv=True,
@@ -76,7 +77,8 @@ class DeepLabv3p(BaseAPI):
use_dice_loss=False,
class_weight=None,
ignore_index=255,
- pooling_crop_size=None):
+ pooling_crop_size=None,
+ input_channel=3):
self.init_params = locals()
super(DeepLabv3p, self).__init__('segmenter')
# dice_loss或bce_loss只适用两类分割中
@@ -115,7 +117,6 @@ class DeepLabv3p(BaseAPI):
self.backbone = backbone
self.num_classes = num_classes
- self.input_channel = input_channel
self.use_bce_loss = use_bce_loss
self.use_dice_loss = use_dice_loss
self.class_weight = class_weight
@@ -151,6 +152,7 @@ class DeepLabv3p(BaseAPI):
if self.output_is_logits:
self.conv_filters = self.num_classes
self.backbone_lr_mult_list = [0.15, 0.35, 0.65, 0.85, 1]
+ self.input_channel = input_channel
def _get_backbone(self, backbone):
def mobilenetv2(backbone):
@@ -217,7 +219,6 @@ class DeepLabv3p(BaseAPI):
def build_net(self, mode='train'):
model = paddlex.cv.nets.segmentation.DeepLabv3p(
self.num_classes,
- input_channel=self.input_channel,
mode=mode,
backbone=self._get_backbone(self.backbone),
output_stride=self.output_stride,
@@ -239,7 +240,8 @@ class DeepLabv3p(BaseAPI):
add_image_level_feature=self.add_image_level_feature,
use_sum_merge=self.use_sum_merge,
conv_filters=self.conv_filters,
- output_is_logits=self.output_is_logits)
+ output_is_logits=self.output_is_logits,
+ input_channel=self.input_channel)
inputs = model.generate_inputs()
model_out = model.build_net(inputs)
outputs = OrderedDict()
diff --git a/paddlex/cv/models/fast_scnn.py b/paddlex/cv/models/fast_scnn.py
index 21003bf..c7ef53d 100644
--- a/paddlex/cv/models/fast_scnn.py
+++ b/paddlex/cv/models/fast_scnn.py
@@ -36,6 +36,8 @@ class FastSCNN(DeepLabv3p):
也支持计算两个分支或三个分支上的loss,权重按[fusion_branch_weight, higher_branch_weight, lower_branch_weight]排列,
fusion_branch_weight为空间细节分支和全局上下文分支融合后的分支上的loss权重,higher_branch_weight为空间细节分支上的loss权重,
lower_branch_weight为全局上下文分支上的loss权重,若higher_branch_weight和lower_branch_weight未设置则不会计算这两个分支上的loss。
+ input_channel (int): 输入图像通道数。默认值3。
+
Raises:
ValueError: use_bce_loss或use_dice_loss为真且num_calsses > 2。
@@ -48,12 +50,12 @@ class FastSCNN(DeepLabv3p):
def __init__(self,
num_classes=2,
- input_channel=3,
use_bce_loss=False,
use_dice_loss=False,
class_weight=None,
ignore_index=255,
- multi_loss_weight=[1.0]):
+ multi_loss_weight=[1.0],
+ input_channel=3):
self.init_params = locals()
super(DeepLabv3p, self).__init__('segmenter')
# dice_loss或bce_loss只适用两类分割中
@@ -87,7 +89,6 @@ class FastSCNN(DeepLabv3p):
)
self.num_classes = num_classes
- self.input_channel = input_channel
self.use_bce_loss = use_bce_loss
self.use_dice_loss = use_dice_loss
self.class_weight = class_weight
@@ -95,18 +96,19 @@ class FastSCNN(DeepLabv3p):
self.ignore_index = ignore_index
self.labels = None
self.fixed_input_shape = None
+ self.input_channel = input_channel
def build_net(self, mode='train'):
model = paddlex.cv.nets.segmentation.FastSCNN(
self.num_classes,
- input_channel=self.input_channel,
mode=mode,
use_bce_loss=self.use_bce_loss,
use_dice_loss=self.use_dice_loss,
class_weight=self.class_weight,
ignore_index=self.ignore_index,
multi_loss_weight=self.multi_loss_weight,
- fixed_input_shape=self.fixed_input_shape)
+ fixed_input_shape=self.fixed_input_shape,
+ input_channel=self.input_channel)
inputs = model.generate_inputs()
model_out = model.build_net(inputs)
outputs = OrderedDict()
diff --git a/paddlex/cv/models/hrnet.py b/paddlex/cv/models/hrnet.py
index cc4154a..80d8b5d 100644
--- a/paddlex/cv/models/hrnet.py
+++ b/paddlex/cv/models/hrnet.py
@@ -34,6 +34,7 @@ class HRNet(DeepLabv3p):
自行计算相应的权重,每一类的权重为:每类的比例 * num_classes。class_weight取默认值None是,各类的权重1,
即平时使用的交叉熵损失函数。
ignore_index (int): label上忽略的值,label为ignore_index的像素不参与损失函数的计算。默认255。
+ input_channel (int): 输入图像通道数。默认值3。
Raises:
ValueError: use_bce_loss或use_dice_loss为真且num_calsses > 2。
@@ -44,12 +45,12 @@ class HRNet(DeepLabv3p):
def __init__(self,
num_classes=2,
- input_channel=3,
width=18,
use_bce_loss=False,
use_dice_loss=False,
class_weight=None,
- ignore_index=255):
+ ignore_index=255,
+ input_channel=3):
self.init_params = locals()
super(DeepLabv3p, self).__init__('segmenter')
# dice_loss或bce_loss只适用两类分割中
@@ -73,7 +74,6 @@ class HRNet(DeepLabv3p):
'Expect class_weight is a list or string but receive {}'.
format(type(class_weight)))
self.num_classes = num_classes
- self.input_channel = input_channel
self.width = width
self.use_bce_loss = use_bce_loss
self.use_dice_loss = use_dice_loss
@@ -81,18 +81,19 @@ class HRNet(DeepLabv3p):
self.ignore_index = ignore_index
self.labels = None
self.fixed_input_shape = None
+ self.input_channel = input_channel
def build_net(self, mode='train'):
model = paddlex.cv.nets.segmentation.HRNet(
self.num_classes,
- input_channel=self.input_channel,
width=self.width,
mode=mode,
use_bce_loss=self.use_bce_loss,
use_dice_loss=self.use_dice_loss,
class_weight=self.class_weight,
ignore_index=self.ignore_index,
- fixed_input_shape=self.fixed_input_shape)
+ fixed_input_shape=self.fixed_input_shape,
+ input_channel=self.input_channel)
inputs = model.generate_inputs()
model_out = model.build_net(inputs)
outputs = OrderedDict()
diff --git a/paddlex/cv/models/unet.py b/paddlex/cv/models/unet.py
index c07879e..f84f048 100644
--- a/paddlex/cv/models/unet.py
+++ b/paddlex/cv/models/unet.py
@@ -33,6 +33,7 @@ class UNet(DeepLabv3p):
自行计算相应的权重,每一类的权重为:每类的比例 * num_classes。class_weight取默认值None是,各类的权重1,
即平时使用的交叉熵损失函数。
ignore_index (int): label上忽略的值,label为ignore_index的像素不参与损失函数的计算。默认255。
+ input_channel (int): 输入图像通道数。默认值3。
Raises:
ValueError: use_bce_loss或use_dice_loss为真且num_calsses > 2。
@@ -43,12 +44,12 @@ class UNet(DeepLabv3p):
def __init__(self,
num_classes=2,
- input_channel=3,
upsample_mode='bilinear',
use_bce_loss=False,
use_dice_loss=False,
class_weight=None,
- ignore_index=255):
+ ignore_index=255,
+ input_channel=3):
self.init_params = locals()
super(DeepLabv3p, self).__init__('segmenter')
# dice_loss或bce_loss只适用两类分割中
@@ -72,7 +73,6 @@ class UNet(DeepLabv3p):
'Expect class_weight is a list or string but receive {}'.
format(type(class_weight)))
self.num_classes = num_classes
- self.input_channel = input_channel
self.upsample_mode = upsample_mode
self.use_bce_loss = use_bce_loss
self.use_dice_loss = use_dice_loss
@@ -80,18 +80,19 @@ class UNet(DeepLabv3p):
self.ignore_index = ignore_index
self.labels = None
self.fixed_input_shape = None
+ self.input_channel = input_channel
def build_net(self, mode='train'):
model = paddlex.cv.nets.segmentation.UNet(
self.num_classes,
- input_channel=self.input_channel,
mode=mode,
upsample_mode=self.upsample_mode,
use_bce_loss=self.use_bce_loss,
use_dice_loss=self.use_dice_loss,
class_weight=self.class_weight,
ignore_index=self.ignore_index,
- fixed_input_shape=self.fixed_input_shape)
+ fixed_input_shape=self.fixed_input_shape,
+ input_channel=self.input_channel)
inputs = model.generate_inputs()
model_out = model.build_net(inputs)
outputs = OrderedDict()
diff --git a/paddlex/cv/transforms/seg_transforms.py b/paddlex/cv/transforms/seg_transforms.py
index 21fd5bd..c482930 100644
--- a/paddlex/cv/transforms/seg_transforms.py
+++ b/paddlex/cv/transforms/seg_transforms.py
@@ -21,6 +21,8 @@ import numpy as np
from PIL import Image
import cv2
import imghdr
+import six
+import sys
from collections import OrderedDict
import paddlex.utils.logging as logging
@@ -67,14 +69,15 @@ class Compose(SegTransform):
img_format = imghdr.what(img_path)
name, ext = osp.splitext(img_path)
if img_format == 'tiff' or ext == '.img':
- import gdal
- gdal.UseExceptions()
- gdal.PushErrorHandler('CPLQuietErrorHandler')
-
try:
- dataset = gdal.Open(img_path)
+ import gdal
except:
- logging.error(gdal.GetLastErrorMsg())
+ six.reraise(*sys.exc_info())
+ raise Exception(
+ "Please refer to https://github.com/PaddlePaddle/PaddleX/tree/develop/examples/multi-channel_remote_sensing/README.md to install gdal"
+ )
+
+ dataset = gdal.Open(img_path)
if dataset == None:
raise Exception('Can not open', img_path)
im_data = dataset.ReadAsArray()
--
GitLab