Unverified commit 23fa4be9, authored by Jason and committed by GitHub

Merge pull request #903 from wjj19950828/Add_onnx_tests

Fixed ToPILImage && rm SymbolicShapeInference
@@ -18,7 +18,8 @@ from paddle.vision.transforms import functional as F
 class ToPILImage(BaseTransform):
     def __init__(self, mode=None, keys=None):
-        super(ToTensor, self).__init__(keys)
+        super(ToPILImage, self).__init__(keys)
+        self.mode = mode

     def _apply_image(self, pic):
         """
@@ -53,7 +54,7 @@ class ToPILImage(BaseTransform):
         npimg = pic
         if isinstance(pic, paddle.Tensor) and "float" in str(pic.numpy(
-        ).dtype) and mode != 'F':
+        ).dtype) and self.mode != 'F':
             pic = pic.mul(255).byte()

         if isinstance(pic, paddle.Tensor):
             npimg = np.transpose(pic.numpy(), (1, 2, 0))
@@ -74,40 +75,40 @@ class ToPILImage(BaseTransform):
                 expected_mode = 'I'
             elif npimg.dtype == np.float32:
                 expected_mode = 'F'
-            if mode is not None and mode != expected_mode:
+            if self.mode is not None and self.mode != expected_mode:
                 raise ValueError(
                     "Incorrect mode ({}) supplied for input type {}. Should be {}"
-                    .format(mode, np.dtype, expected_mode))
-            mode = expected_mode
+                    .format(self.mode, np.dtype, expected_mode))
+            self.mode = expected_mode

         elif npimg.shape[2] == 2:
             permitted_2_channel_modes = ['LA']
-            if mode is not None and mode not in permitted_2_channel_modes:
+            if self.mode is not None and self.mode not in permitted_2_channel_modes:
                 raise ValueError("Only modes {} are supported for 2D inputs".
                                  format(permitted_2_channel_modes))

-            if mode is None and npimg.dtype == np.uint8:
-                mode = 'LA'
+            if self.mode is None and npimg.dtype == np.uint8:
+                self.mode = 'LA'

         elif npimg.shape[2] == 4:
             permitted_4_channel_modes = ['RGBA', 'CMYK', 'RGBX']
-            if mode is not None and mode not in permitted_4_channel_modes:
+            if self.mode is not None and self.mode not in permitted_4_channel_modes:
                 raise ValueError("Only modes {} are supported for 4D inputs".
                                  format(permitted_4_channel_modes))

-            if mode is None and npimg.dtype == np.uint8:
-                mode = 'RGBA'
+            if self.mode is None and npimg.dtype == np.uint8:
+                self.mode = 'RGBA'

         else:
             permitted_3_channel_modes = ['RGB', 'YCbCr', 'HSV']
-            if mode is not None and mode not in permitted_3_channel_modes:
+            if self.mode is not None and self.mode not in permitted_3_channel_modes:
                 raise ValueError("Only modes {} are supported for 3D inputs".
                                  format(permitted_3_channel_modes))
-            if mode is None and npimg.dtype == np.uint8:
-                mode = 'RGB'
+            if self.mode is None and npimg.dtype == np.uint8:
+                self.mode = 'RGB'

-        if mode is None:
+        if self.mode is None:
             raise TypeError('Input type {} is not supported'.format(
                 npimg.dtype))

-        return Image.fromarray(npimg, mode=mode)
+        return Image.fromarray(npimg, mode=self.mode)
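The remaining hunks in this file route every `mode` lookup through `self.mode`, so `_apply_image` now sees the value passed to the constructor. A condensed standalone sketch of the same dtype-to-mode selection, using only NumPy and Pillow and covering just the single-channel and RGB paths (not the full transform above):

```python
import numpy as np
from PIL import Image


def chw_to_pil(arr, mode=None):
    """Rough sketch: convert a CHW array to a PIL Image, picking a mode by dtype."""
    if arr.dtype.kind == 'f' and mode != 'F':
        arr = (arr * 255).astype(np.uint8)  # scale floats in [0, 1] to uint8
    npimg = np.transpose(arr, (1, 2, 0))    # CHW -> HWC

    if npimg.shape[2] == 1:                 # single channel: 'L' (uint8) or 'F' (float)
        npimg = npimg[:, :, 0]
        mode = mode or ('L' if npimg.dtype == np.uint8 else 'F')
    elif mode is None and npimg.dtype == np.uint8:
        mode = 'RGB'                        # default for 3-channel uint8
    if mode is None:
        raise TypeError('Input type {} is not supported'.format(npimg.dtype))
    return Image.fromarray(npimg, mode=mode)


# e.g. chw_to_pil(np.random.rand(3, 8, 8).astype('float32')).size == (8, 8)
```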
@@ -184,14 +184,15 @@ class ONNXGraph(Graph):
         self.value_infos = {}
         self.graph = onnx_model.graph
         self.get_place_holder_nodes()
-        print("shape inferencing ...")
-        self.graph = SymbolicShapeInference.infer_shapes(
-            onnx_model, fixed_input_shape=self.fixed_input_shape)
-        if self.graph is None:
+        print("Shape inferencing ...")
+        try:
+            self.graph = SymbolicShapeInference.infer_shapes(
+                onnx_model, fixed_input_shape=self.fixed_input_shape)
+        except:
             print('[WARNING] Shape inference by ONNX official interface.')
             onnx_model = shape_inference.infer_shapes(onnx_model)
             self.graph = onnx_model.graph
-        print("shape inferenced.")
+        print("Shape inferenced.")
         self.build()
         self.collect_value_infos()
         self.allocate_shapes()
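The try/except above means any failure inside the project's symbolic pass degrades gracefully to ONNX's built-in inference. A rough sketch of that fallback pattern in isolation; `symbolic_infer` is a hypothetical stand-in for `SymbolicShapeInference.infer_shapes`, while `onnx.shape_inference` is the real library API:

```python
import onnx
from onnx import shape_inference


def infer_graph(onnx_model, symbolic_infer=None):
    """Prefer a project-specific symbolic pass; fall back to onnx.shape_inference."""
    if symbolic_infer is not None:
        try:
            graph = symbolic_infer(onnx_model)  # hypothetical callable, may raise
            if graph is not None:
                return graph
        except Exception:
            print('[WARNING] Shape inference by ONNX official interface.')
    # infer_shapes returns a new ModelProto whose graph carries inferred value_info.
    return shape_inference.infer_shapes(onnx_model).graph
```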
@@ -265,7 +265,7 @@ class SymbolicShapeInference:
         if pending_nodes and self.verbose_ > 0:
             print('SymbolicShapeInference: orphaned nodes discarded: ')
             print(
-                *[n.op_type + ': ' + n.output[0] for n in pending_nodes],
+                * [n.op_type + ': ' + n.output[0] for n in pending_nodes],
                 sep='\n')

         if input_shapes is not None:
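The only change in this hunk is yapf spacing around the `*`; for reference, unpacking a list into `print` with `sep='\n'` emits one entry per line, e.g.:

```python
pending = ['Conv: conv_output', 'Relu: relu_output']  # placeholder node descriptions
print(*[item for item in pending], sep='\n')
# Conv: conv_output
# Relu: relu_output
```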
@@ -1588,7 +1588,9 @@ class SymbolicShapeInference:
         assert version.parse(onnx.__version__) >= version.parse("1.5.0")
         onnx_opset = get_opset(in_mp)
         if not onnx_opset or onnx_opset < 7:
-            print('[WARNING] Symbolic shape inference only support models of onnx opset 7 and above.')
+            print(
+                '[WARNING] Symbolic shape inference only support models of onnx opset 7 and above.'
+            )
             return
         symbolic_shape_inference = SymbolicShapeInference(
             int_max, auto_merge, guess_output_rank, verbose)
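`get_opset` here is a helper defined elsewhere in this file; reading the default-domain opset straight off a `ModelProto` can be done with the standard `opset_import` field, roughly as follows (the load path is a placeholder):

```python
import onnx


def default_opset_version(model: onnx.ModelProto):
    """Return the ai.onnx opset version recorded in the model, if any."""
    for opset in model.opset_import:
        if opset.domain in ('', 'ai.onnx'):  # default domain may be empty or named
            return opset.version
    return None


# model = onnx.load('model.onnx')  # hypothetical path
# assert (default_opset_version(model) or 0) >= 7
```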
@@ -1608,4 +1610,4 @@ class SymbolicShapeInference:
             print('[WARNING] Incomplete symbolic shape inference')
             symbolic_shape_inference.out_mp_ = shape_inference.infer_shapes(
                 symbolic_shape_inference.out_mp_)
-        return symbolic_shape_inference.out_mp_.graph
\ No newline at end of file
+        return symbolic_shape_inference.out_mp_.graph
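Once inference has run (symbolically or via the fallback), the returned graph exposes shapes through `value_info`. A quick way to inspect them with the standard onnx API; the model path is a placeholder:

```python
import onnx
from onnx import shape_inference

model = shape_inference.infer_shapes(onnx.load('model.onnx'))  # placeholder path
for vi in model.graph.value_info:
    # Each dimension is either a concrete dim_value or a symbolic dim_param.
    dims = [d.dim_param or d.dim_value for d in vi.type.tensor_type.shape.dim]
    print(vi.name, dims)
```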