diff --git a/x2paddle/op_mapper/onnx_op_mapper.py b/x2paddle/op_mapper/onnx_op_mapper.py index 4bdfc660e840e7cb963a3dcf7b77139eda5fd5af..3b8ff04c6b31a27639e65ac717f9a8b60479ced1 100644 --- a/x2paddle/op_mapper/onnx_op_mapper.py +++ b/x2paddle/op_mapper/onnx_op_mapper.py @@ -121,7 +121,13 @@ class ONNXOpMapper(OpMapper): for data_node in data_nodes: value_info = value_infos[data_node] - ipt = np.random.random(value_info['shape']).astype( + shape = value_info['shape'] + for i, dim_shape in enumerate(shape): + if dim_shape==0 and i==0: + shape[i]=1 + if dim_shape==0 and i!=0: + assert 'shape of input is not assigned' + ipt = np.random.random(shape).astype( value_info['dtype']) np.save(os.path.join(self.tmp_data_dir, data_node), ipt)