Unverified commit 23fa4be9, authored by Jason, committed by GitHub

Merge pull request #903 from wjj19950828/Add_onnx_tests

Fixed ToPILImage && rm SymbolicShapeInference
````diff
@@ -18,7 +18,8 @@ from paddle.vision.transforms import functional as F
 class ToPILImage(BaseTransform):
     def __init__(self, mode=None, keys=None):
-        super(ToTensor, self).__init__(keys)
+        super(ToPILImage, self).__init__(keys)
+        self.mode = mode
 
     def _apply_image(self, pic):
         """
@@ -53,7 +54,7 @@ class ToPILImage(BaseTransform):
         npimg = pic
         if isinstance(pic, paddle.Tensor) and "float" in str(pic.numpy(
-        ).dtype) and mode != 'F':
+        ).dtype) and self.mode != 'F':
             pic = pic.mul(255).byte()
         if isinstance(pic, paddle.Tensor):
             npimg = np.transpose(pic.numpy(), (1, 2, 0))
@@ -74,40 +75,40 @@ class ToPILImage(BaseTransform):
                 expected_mode = 'I'
             elif npimg.dtype == np.float32:
                 expected_mode = 'F'
-            if mode is not None and mode != expected_mode:
+            if self.mode is not None and self.mode != expected_mode:
                 raise ValueError(
                     "Incorrect mode ({}) supplied for input type {}. Should be {}"
-                    .format(mode, np.dtype, expected_mode))
-            mode = expected_mode
+                    .format(self.mode, np.dtype, expected_mode))
+            self.mode = expected_mode
         elif npimg.shape[2] == 2:
             permitted_2_channel_modes = ['LA']
-            if mode is not None and mode not in permitted_2_channel_modes:
+            if self.mode is not None and self.mode not in permitted_2_channel_modes:
                 raise ValueError("Only modes {} are supported for 2D inputs".
                                  format(permitted_2_channel_modes))
-            if mode is None and npimg.dtype == np.uint8:
-                mode = 'LA'
+            if self.mode is None and npimg.dtype == np.uint8:
+                self.mode = 'LA'
         elif npimg.shape[2] == 4:
             permitted_4_channel_modes = ['RGBA', 'CMYK', 'RGBX']
-            if mode is not None and mode not in permitted_4_channel_modes:
+            if self.mode is not None and self.mode not in permitted_4_channel_modes:
                 raise ValueError("Only modes {} are supported for 4D inputs".
                                  format(permitted_4_channel_modes))
-            if mode is None and npimg.dtype == np.uint8:
-                mode = 'RGBA'
+            if self.mode is None and npimg.dtype == np.uint8:
+                self.mode = 'RGBA'
         else:
             permitted_3_channel_modes = ['RGB', 'YCbCr', 'HSV']
-            if mode is not None and mode not in permitted_3_channel_modes:
+            if self.mode is not None and self.mode not in permitted_3_channel_modes:
                 raise ValueError("Only modes {} are supported for 3D inputs".
                                  format(permitted_3_channel_modes))
-            if mode is None and npimg.dtype == np.uint8:
-                mode = 'RGB'
+            if self.mode is None and npimg.dtype == np.uint8:
+                self.mode = 'RGB'
-        if mode is None:
+        if self.mode is None:
             raise TypeError('Input type {} is not supported'.format(
                 npimg.dtype))
-        return Image.fromarray(npimg, mode=mode)
+        return Image.fromarray(npimg, mode=self.mode)
 ```
````
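
The diff above fixes two copy-paste slips in `ToPILImage`: `__init__` called `super(ToTensor, ...)` instead of `super(ToPILImage, ...)`, and `_apply_image` read an undefined local `mode` rather than the stored `self.mode`. As a rough illustration of the conversion path the fixed code takes for a float CHW tensor with `mode=None`, here is a minimal numpy/Pillow sketch (not X2Paddle's actual API; the input array is fabricated for the example):

```python
import numpy as np
from PIL import Image

# Stand-in for a decoded image: float CHW array with values in [0, 1].
chw = np.random.rand(3, 32, 32).astype(np.float32)

# Mirror of the fixed _apply_image path: scale to uint8, move channels last.
hwc = np.transpose((chw * 255).astype(np.uint8), (1, 2, 0))

# Same mode inference as the diff: 3-channel uint8 with mode=None -> "RGB".
mode = None
if mode is None and hwc.shape[2] == 3 and hwc.dtype == np.uint8:
    mode = "RGB"

img = Image.fromarray(hwc, mode=mode)
print(img.mode, img.size)  # "RGB" (32, 32)
```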
```diff
@@ -184,14 +184,15 @@ class ONNXGraph(Graph):
         self.value_infos = {}
         self.graph = onnx_model.graph
         self.get_place_holder_nodes()
-        print("shape inferencing ...")
-        self.graph = SymbolicShapeInference.infer_shapes(
-            onnx_model, fixed_input_shape=self.fixed_input_shape)
-        if self.graph is None:
+        print("Shape inferencing ...")
+        try:
+            self.graph = SymbolicShapeInference.infer_shapes(
+                onnx_model, fixed_input_shape=self.fixed_input_shape)
+        except:
             print('[WARNING] Shape inference by ONNX offical interface.')
             onnx_model = shape_inference.infer_shapes(onnx_model)
             self.graph = onnx_model.graph
-        print("shape inferenced.")
+        print("Shape inferenced.")
         self.build()
         self.collect_value_infos()
         self.allocate_shapes()
```
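
The `ONNXGraph` constructor now wraps the custom symbolic pass in `try/except` and falls back to ONNX's built-in inference, instead of checking for a `None` graph afterwards. Below is a standalone sketch of that fallback pattern using only the official `onnx.shape_inference` API; `symbolic_infer` is a hypothetical stand-in for `SymbolicShapeInference.infer_shapes`, and the diff itself uses a bare `except:` rather than `except Exception`:

```python
import onnx
from onnx import shape_inference


def symbolic_infer(model):
    # Stand-in for X2Paddle's SymbolicShapeInference.infer_shapes; here it
    # simply fails so the fallback branch below is exercised.
    raise NotImplementedError("symbolic pass unavailable")


def infer_graph_shapes(model: onnx.ModelProto) -> onnx.GraphProto:
    # Same pattern as the diff: prefer the symbolic pass, and on any error
    # fall back to onnx's built-in shape inference.
    try:
        return symbolic_infer(model)
    except Exception:
        print("[WARNING] Shape inference by ONNX official interface.")
        return shape_inference.infer_shapes(model).graph
```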
```diff
@@ -265,7 +265,7 @@ class SymbolicShapeInference:
         if pending_nodes and self.verbose_ > 0:
             print('SymbolicShapeInference: orphaned nodes discarded: ')
             print(
-                *[n.op_type + ': ' + n.output[0] for n in pending_nodes],
+                * [n.op_type + ': ' + n.output[0] for n in pending_nodes],
                 sep='\n')
 
         if input_shapes is not None:
@@ -1588,7 +1588,9 @@ class SymbolicShapeInference:
         assert version.parse(onnx.__version__) >= version.parse("1.5.0")
         onnx_opset = get_opset(in_mp)
         if not onnx_opset or onnx_opset < 7:
-            print('[WARNING] Symbolic shape inference only support models of onnx opset 7 and above.')
+            print(
+                '[WARNING] Symbolic shape inference only support models of onnx opset 7 and above.'
+            )
             return
         symbolic_shape_inference = SymbolicShapeInference(
             int_max, auto_merge, guess_output_rank, verbose)
@@ -1608,4 +1610,4 @@ class SymbolicShapeInference:
             print('[WARNING] Incomplete symbolic shape inference')
             symbolic_shape_inference.out_mp_ = shape_inference.infer_shapes(
                 symbolic_shape_inference.out_mp_)
-        return symbolic_shape_inference.out_mp_.graph
\ No newline at end of file
+        return symbolic_shape_inference.out_mp_.graph
```
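
`infer_shapes` also refuses to run on models below opset 7 (and asserts onnx >= 1.5.0). `get_opset` is X2Paddle's own helper, so the sketch below only illustrates the kind of information that check consumes, using the official `onnx` API; `model.onnx` is a placeholder path:

```python
import onnx


def get_default_opset(model: onnx.ModelProto) -> int:
    # Read the opset version registered for the default ONNX domain
    # ('' or 'ai.onnx'); roughly what get_opset(in_mp) relies on.
    for entry in model.opset_import:
        if entry.domain in ("", "ai.onnx"):
            return entry.version
    return 0


model = onnx.load("model.onnx")  # hypothetical model file
if get_default_opset(model) < 7:
    print("[WARNING] Symbolic shape inference only supports opset 7 and above.")
```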