Unverified commit 8dc06771, authored by jm_12138, committed by GitHub

Update the U2Net and U2Netp

Parent 13c0ceaf
......@@ -6,7 +6,8 @@
## Demo
![](https://ai-studio-static-online.cdn.bcebos.com/4d77bc3a05cf48bba6f67b797978f4cdf10f38288b9645d59393dd85cef58eff)
![](https://ai-studio-static-online.cdn.bcebos.com/d7839c7207024747b32e42e49f7881bd2554d8408ab44e669fb340b50d4e38de)
![](https://ai-studio-static-online.cdn.bcebos.com/865b7b6a262b4ce3bbba4a5c0d973d9eea428bc3e8af4f76a1cdab0c04e3dd33)
![](https://ai-studio-static-online.cdn.bcebos.com/11c9eba8de6d4316b672f10b285245061821f0a744e441f3b80c223881256ca0)
## API
```python
......
......@@ -33,13 +33,6 @@ class Processor():
     def preprocess(self, imgs, batch_size=1, input_size=320):
         input_datas = []
         for image in imgs:
-            # h, w = image.shape[:2]
-            # if h > w:
-            # new_h, new_w = input_size*h/w,input_size
-            # else:
-            # new_h, new_w = input_size,input_size*w/h
-            # image = cv2.resize(image, (int(new_w), int(new_h)))
             image = cv2.resize(image, (input_size, input_size))
             tmpImg = np.zeros((image.shape[0],image.shape[1],3))
             image = image/np.max(image)
......@@ -78,19 +71,25 @@ class Processor():
         for i, image in enumerate(self.imgs):
             # normalization
-            pred = 1.0 - outputs[i,0,:,:]
+            pred = outputs[i,0,:,:]
             pred = self.normPRED(pred)
             # convert torch tensor to numpy array
             pred = pred.squeeze()
-            pred = (pred*255).astype(np.uint8)
             h, w = image.shape[:2]
-            pred = cv2.resize(pred, (w, h))
-            results.append(pred)
+            mask = cv2.resize(pred, (w, h))
+            output_img = (image*mask[..., np.newaxis] + (1-mask[..., np.newaxis])*255).astype(np.uint8)
+            mask = (mask*255).astype(np.uint8)
             if visualization:
-                cv2.imwrite(os.path.join(output_dir, 'result_%d.png' % i), pred)
+                cv2.imwrite(os.path.join(output_dir, 'result_mask_%d.png' % i), mask)
+                cv2.imwrite(os.path.join(output_dir, 'result_%d.png' % i), output_img)
+            results.append({
+                'mask': mask,
+                'front': output_img
+            })
         return results
\ No newline at end of file
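In both modules the postprocess change works the same way: the saliency output is no longer inverted (`pred = outputs[i,0,:,:]`), it is resized back to the source resolution as a float mask, and that mask serves both as the returned `'mask'` (scaled to 0-255) and as an alpha channel that composites the original image over a white background for the `'front'` result. A minimal sketch of that compositing step, assuming a float mask in [0, 1] and an H×W×3 uint8 image (the helper name is made up for illustration):

```python
import numpy as np

def composite_on_white(image: np.ndarray, mask: np.ndarray) -> np.ndarray:
    """Blend `image` (H, W, 3) over a white background using `mask` (H, W) in [0, 1]."""
    alpha = mask[..., np.newaxis]              # (H, W, 1) so it broadcasts over the channels
    front = image * alpha + (1 - alpha) * 255  # mask==1 keeps the pixel, mask==0 turns it white
    return front.astype(np.uint8)
```

Because the blend needs the mask in [0, 1], the old `pred = (pred*255).astype(np.uint8)` conversion is removed and the scale-to-255 step now happens only after the compositing, as the diff shows.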
......@@ -7,7 +7,8 @@
## Demo
![](https://ai-studio-static-online.cdn.bcebos.com/4d77bc3a05cf48bba6f67b797978f4cdf10f38288b9645d59393dd85cef58eff)
![](https://ai-studio-static-online.cdn.bcebos.com/d7839c7207024747b32e42e49f7881bd2554d8408ab44e669fb340b50d4e38de)
![](https://ai-studio-static-online.cdn.bcebos.com/865b7b6a262b4ce3bbba4a5c0d973d9eea428bc3e8af4f76a1cdab0c04e3dd33)
![](https://ai-studio-static-online.cdn.bcebos.com/11c9eba8de6d4316b672f10b285245061821f0a744e441f3b80c223881256ca0)
## API
```python
......
......@@ -33,13 +33,6 @@ class Processor():
     def preprocess(self, imgs, batch_size=1, input_size=320):
         input_datas = []
         for image in imgs:
-            # h, w = image.shape[:2]
-            # if h > w:
-            # new_h, new_w = input_size*h/w,input_size
-            # else:
-            # new_h, new_w = input_size,input_size*w/h
-            # image = cv2.resize(image, (int(new_w), int(new_h)))
             image = cv2.resize(image, (input_size, input_size))
             tmpImg = np.zeros((image.shape[0],image.shape[1],3))
             image = image/np.max(image)
......@@ -78,19 +71,25 @@ class Processor():
         for i, image in enumerate(self.imgs):
             # normalization
-            pred = 1.0 - outputs[i,0,:,:]
+            pred = outputs[i,0,:,:]
             pred = self.normPRED(pred)
             # convert torch tensor to numpy array
             pred = pred.squeeze()
-            pred = (pred*255).astype(np.uint8)
             h, w = image.shape[:2]
-            pred = cv2.resize(pred, (w, h))
-            results.append(pred)
+            mask = cv2.resize(pred, (w, h))
+            output_img = (image*mask[..., np.newaxis] + (1-mask[..., np.newaxis])*255).astype(np.uint8)
+            mask = (mask*255).astype(np.uint8)
             if visualization:
-                cv2.imwrite(os.path.join(output_dir, 'result_%d.png' % i), pred)
+                cv2.imwrite(os.path.join(output_dir, 'result_mask_%d.png' % i), mask)
+                cv2.imwrite(os.path.join(output_dir, 'result_%d.png' % i), output_img)
+            results.append({
+                'mask': mask,
+                'front': output_img
+            })
         return results
\ No newline at end of file
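After this change, each element of the list returned by `postprocess` is a dict with `'mask'` and `'front'` keys rather than a bare mask array, so downstream code has to unpack it. A usage sketch under stated assumptions: the module is loaded through PaddleHub and exposes a `Segmentation` entry point as documented in the (truncated) API section of the READMEs above; the exact call signature is not shown in this diff and may differ.

```python
import cv2
import paddlehub as hub

# Assumed entry point; see the module README's API section (truncated in this diff).
model = hub.Module(name='U2Net')
results = model.Segmentation(
    images=[cv2.imread('test.jpg')],
    batch_size=1,
    output_dir='output',
    visualization=True)

for i, result in enumerate(results):
    mask = result['mask']    # (H, W) uint8 saliency mask in 0-255
    front = result['front']  # (H, W, 3) uint8 foreground composited on a white background
    cv2.imwrite('mask_%d.png' % i, mask)
    cv2.imwrite('front_%d.png' % i, front)
```

Existing callers that treated each result as a bare mask array now need to read `result['mask']` instead.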