我想用静态图实现如下代码，但找不到对应的 API。
Created by: lxk767363331
# I want to implement the following code with the static-graph API, but I
# cannot find the corresponding API.
import numpy as np
import torch
from PIL import Image

# `scipy.ndimage.morphology` is a deprecated alias namespace (SciPy >= 1.8);
# the public location of the function is `scipy.ndimage`.
from scipy.ndimage import distance_transform_edt
from torchvision import transforms

# NOTE(review): `extended_transforms` is not used anywhere in the visible code.
import transforms.transforms as extended_transforms  # noqa: F401

path = "Blowhole0.png"
spath = "Blowhole_edge.png"


def mask_to_onehot(mask, num_classes):
    """Split an integer label mask into one binary channel per class.

    Channel ``i`` is 1 where ``mask == i + 1`` and 0 elsewhere, so class ids
    ``1..num_classes`` each get a channel and id 0 gets none (it is treated
    as background).

    Args:
        mask: Array of integer class ids.
        num_classes: Number of channels to produce.

    Returns:
        ``np.uint8`` array of shape ``(num_classes, *mask.shape)``.
    """
    channels = [mask == (i + 1) for i in range(num_classes)]
    # Build and cast in one step instead of np.array(...).astype(...).
    return np.array(channels, dtype=np.uint8)
def onehot_to_binary_edges(mask, radius, num_classes):
    """Collapse a stack of one-hot class channels into one binary edge map.

    For each of the first ``num_classes`` channels, every pixel receives the
    sum of two Euclidean distance transforms: one over the channel and one
    over its complement. Pixels whose summed distance exceeds ``radius`` are
    discarded. The surviving distances are accumulated across channels, and
    any pixel with a positive total is marked as an edge.

    Args:
        mask: One-hot stack indexed as ``mask[class, row, col]``. The padding
            spec below has three entries, so the array must be 3-D.
        radius: Maximum summed distance kept as edge. A negative value
            disables edge extraction.
        num_classes: Number of leading channels to process.

    Returns:
        ``np.uint8`` array of shape ``(1, rows, cols)`` with 1 on edge pixels.
        If ``radius < 0``, ``mask`` itself is returned unchanged.
    """
    if radius < 0:
        return mask

    # Zero-pad the spatial borders so that a class touching the image border
    # still gets a boundary there.
    padded = np.pad(
        mask,
        ((0, 0), (1, 1), (1, 1)),
        mode="constant",
        constant_values=0,
    )
    accumulated = np.zeros(mask.shape[1:])

    for class_idx in range(num_classes):
        channel = padded[class_idx]
        inside = distance_transform_edt(channel)
        outside = distance_transform_edt(1.0 - channel)
        # Strip the padding ring again before thresholding.
        band = (inside + outside)[1:-1, 1:-1]
        band[band > radius] = 0
        accumulated += band

    # NOTE(review): for a class absent from the image, `1.0 - channel` has no
    # zero pixels; confirm how the EDT behaves in that case.
    return (accumulated[np.newaxis] > 0).astype(np.uint8)
# Load the label mask. The `with` block closes the file handle
# deterministically instead of leaving it to Pillow / garbage collection.
with Image.open(path) as mask_image:
    # mask_to_onehot only compares against integer class ids, so a plain
    # int64 NumPy array suffices. The former
    # torch.from_numpy(...).long().numpy() round trip produced the same values.
    # NOTE(review): assumes a single-channel ("L"/"P"-mode) label PNG. An RGB
    # image would give an (H, W, 3) array, which the 3-entry pad spec in
    # onehot_to_binary_edges cannot handle; confirm the input format.
    _mask = np.array(mask_image, dtype=np.int64)

# One channel per class id 1..3 (id 0 gets no channel), then a single-channel
# binary map marking pixels within distance 1 of a class boundary.
_edgemap = mask_to_onehot(_mask, 3)
_edgemap = onehot_to_binary_edges(_edgemap, 1, 3)

# The edge map is built from a NumPy array, so it already lives on the CPU and
# ToPILImage does not modify it in place; no .cpu()/.clone() is needed.
# Float input is used so that ToPILImage rescales 0/1 values to 0/255
# (per the torchvision docs), which makes edges visible as white pixels.
edgemap = torch.from_numpy(_edgemap).float()
unloader = transforms.ToPILImage()
image = unloader(edgemap.squeeze(0))  # drop the leading channel axis
image.save(spath)