Unverified commit 84581b17 authored by channings, committed by GitHub

Merge pull request #5 from PaddlePaddle/develop

update
......@@ -5,7 +5,7 @@ X2Paddle supports converting models trained with other deep learning frameworks to PaddlePaddle.
X2Paddle is a toolkit for converting trained model to PaddlePaddle from other deep learning frameworks.
## Model Zoo
X2Paddle has been tested on TensorFlow/Caffe/ONNX conversion of many mainstream CV models; see [X2Paddle-Model-Zoo](x2paddle_model_zoo.md) for our list of tested models. If you have tested conversion on a new model, you are welcome to add it to this list; if a model fails to convert, please report it to us via an ISSUE and we will follow up as soon as possible.
X2Paddle has been tested on TensorFlow/Caffe/ONNX conversion of many mainstream CV models; see [X2Paddle-Model-Zoo](x2paddle_model_zoo.md) for our list of tested models, and see [OP-LIST](op_list.md) for the list of OPs currently supported by X2Paddle. If you have tested conversion on a new model, you are welcome to add it to this list; if a model fails to convert, please report it to us via an ISSUE and we will follow up as soon as possible.
## Environment Dependencies
......@@ -19,18 +19,6 @@ onnx : onnx == 1.5.0 onnxruntime == 0.4.0
## Installation
### Installation Method 1 (Recommended)
To use the latest code, install as follows:
```
pip install git+https://github.com/PaddlePaddle/X2Paddle.git@develop
```
### Installation Method 2
We periodically publish updated x2paddle releases to the pip index.
```
pip install x2paddle
```
### Installation Method 3
```
git clone https://github.com/PaddlePaddle/X2Paddle.git
cd X2Paddle
......@@ -38,6 +26,11 @@ git checkout develop
python setup.py install
```
### Installation Method 2
We periodically publish updated x2paddle releases to the pip index.
```
pip install x2paddle --index https://pypi.Python.org/simple/
```
## Usage
### TensorFlow
```
......@@ -45,7 +38,7 @@ x2paddle --framework=tensorflow --model=tf_model.pb --save_dir=pd_model
```
### Caffe
```
x2paddle --framework=caffe --prototxt=deploy.proto --weight=deploy.caffemodel --save_dir=pd_model
x2paddle --framework=caffe --prototxt=deploy.prototxt --weight=deploy.caffemodel --save_dir=pd_model
```
### ONNX
```
......
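For ONNX models the invocation follows the same pattern as the TensorFlow and Caffe commands above (a sketch; the model filename and flags here are assumed by analogy with those two invocations):
```
x2paddle --framework=onnx --model=onnx_model.onnx --save_dir=pd_model
```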
# X2Paddle Supported OP List
> X2Paddle currently supports 40+ TensorFlow OPs, 30+ Caffe layers, and 50+ ONNX OPs, covering most operations commonly used in CV classification models. The tables below list all OPs currently supported by X2Paddle.
**Note:** Some OPs are not yet supported. If you encounter an unsupported OP during conversion, you can add support yourself or report it to us. Feel free to notify us via an [ISSUE](https://github.com/PaddlePaddle/X2Paddle/issues/new) (include the model name and its code or download location) and we will follow up promptly :)
## TensorFlow
| No. | OP | No. | OP | No. | OP | No. | OP |
|------|------|------|------|------|------|------|------|
| 1 | Relu | 2 | Relu6 | 3 | Shape | 4 | Abs |
| 5 | Sigmoid | 6 | Exp | 7 | Rsqrt | 8 | swish_f32 |
| 9 | Tanh | 10 | LeakyRelu | 11 | Add | 12 | RealDiv |
| 13 | Sub | 14 | Maximum | 15 | Mul | 16 | FloorDiv |
| 17 | Placeholder | 18 | Const | 19 | Transpose | 20 | FusedBatchNorm |
| 21 | Conv2D | 22 | BiasAdd | 23 | MaxPool | 24 | DepthwiseConv2dNative |
| 25 | Reshape | 26 | AvgPool | 27 | SplitV | 28 | SquaredDifference |
| 29 | Tile | 30 | Pack | 31 | Pad | 32 | ResizeBilinear |
| 33 | Mean | 34 | MatMul | 35 | ArgMax | 36 | StridedSlice |
| 37 | Slice | 38 | Sum | 39 | Max | 40 | Conv2DBackpropInput |
| 41 | Cast | 42 | Split | 43 | Squeeze | 44 | ResizeNearestNeighbor |
| 45 | Softmax | 46 | Range | 47 | ConcatV2 | | |
## Caffe
| No. | OP | No. | OP | No. | OP | No. | OP |
|------|------|------|------|------|------|------|------|
| 1 | Input | 2 | Convolution | 3 | Deconvolution | 4 | Pooling |
| 5 | LRN | 6 | InnerProduct | 7 | Softmax | 8 | Slice |
| 9 | Concat | 10 | PReLU | 11 | Accuracy | 12 | Eltwise |
| 13 | BatchNorm | 14 | Scale | 15 | Reshape | 16 | ArgMax |
| 17 | Crop | 18 | Flatten | 19 | Power | 20 | Reduction |
| 21 | Axpy | 22 | ROIPooling | 23 | Permute | 24 | DetectionOutput |
| 25 | Normalize | 26 | Select | 27 | ShuffleChannel | 28 | ConvolutionDepthwise |
| 29 | ReLU | 30 | AbsVal | 31 | Sigmoid | 32 | TanH |
## ONNX
| No. | OP | No. | OP | No. | OP | No. | OP |
|------|------|------|------|------|------|------|------|
| 1 | Relu | 2 | LeakyRelu | 3 | Elu | 4 | ThresholdedRelu |
| 5 | Prelu | 6 | Tanh | 7 | Shrink | 8 | Sigmoid |
| 9 | Pow | 10 | Softplus | 11 | Softsign | 12 | HardSigmoid |
| 13 | Exp | 14 | Add | 15 | Div | 16 | Sub |
| 17 | Mul | 18 | Shape | 19 | Clip | 20 | AveragePool |
| 21 | Sqrt | 22 | ReduceSum | 23 | ReduceMin | 24 | ReduceMean |
| 25 | Constant | 26 | Pad | 27 | Unsqueeze | 28 | Resize |
| 29 | Upsample | 30 | Expand | 31 | Gather | 32 | Slice |
| 33 | Cast | 34 | Split | 35 | Reshape | 36 | ConstantOfShape |
| 37 | Ceil | 38 | Concat | 39 | Flatten | 40 | ConvTranspose |
| 41 | MatMul | 42 | Sum | 43 | Transpose | 44 | BatchNormalization |
| 45 | Squeeze | 46 | Equal | 47 | Identity | 48 | GlobalAveragePool |
| 49 | MaxPool | 50 | Conv | 51 | Gemm | | |
......@@ -104,10 +104,14 @@ def tf2paddle(model_path,
# necessary optimization
optimizer.delete_redundance_code()
# optimizer below is experimental
optimizer.optimize_elementwise_op()
optimizer.merge_activation()
optimizer.merge_bias()
optimizer.merge_batch_norm()
optimizer.merge_prelu()
optimizer.optimize_sub_graph()
# optimizer.merge_batch_norm()
# optimizer.merge_prelu()
else:
mapper = TFOpMapperNHWC(model)
optimizer = TFOptimizer(mapper)
......@@ -125,9 +129,12 @@ def caffe2paddle(proto, weight, save_dir, caffe_proto):
from x2paddle.op_mapper.caffe_op_mapper import CaffeOpMapper
from x2paddle.optimizer.caffe_optimizer import CaffeOptimizer
import google.protobuf as gpb
ver_str = gpb.__version__.replace('.', '')
ver_int = int(ver_str[0:2])
assert ver_int >= 36, 'The version of protobuf must be larger than 3.6.0!'
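# The replaced check above collapsed the version to two digits (e.g. '3.6.1' -> 36), which misreads releases like 3.10.x; comparing (major, minor) parts avoids that.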
ver_part = gpb.__version__.split('.')
version_satisfy = False
if (int(ver_part[0]) == 3 and int(ver_part[1]) >= 6) \
or (int(ver_part[0]) > 3):
version_satisfy = True
assert version_satisfy, 'google.protobuf >= 3.6.0 is required'
print("Now translating model from caffe to paddle.")
model = CaffeDecoder(proto, weight, caffe_proto)
mapper = CaffeOpMapper(model)
......
......@@ -60,6 +60,15 @@ class TFGraphNode(GraphNode):
raise Exception("Dtype[{}] not in dtype_map".format(dtype))
return self.dtype_map[dtype]
@property
def raw_dtype(self):
keys = ['dtype', 'Tidx', 'T', 'DstT']
for k in keys:
dtype = self.layer.attr[k].type
if dtype > 0:
break
return dtype
@property
def value(self):
assert self.layer_type == "Const", "Only Const node has value."
......@@ -120,6 +129,7 @@ class TFGraph(Graph):
# tensorflow graph optimize
self._remove_isolated_node()
self._remove_identity_node()
self._remove_cast_node()
def get_node(self, node_name, copy=False):
items = node_name.strip().split(':')
......@@ -190,6 +200,27 @@ class TFGraph(Graph):
idx = self.output_nodes.index(node_name)
self.output_nodes[idx] = input_node.layer_name
def _remove_cast_node(self):
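# Folds Cast nodes whose input is a Placeholder with a single consumer: the placeholder is retyped to the cast's target dtype and the Cast node is dropped.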
cast_node = list()
for node_name, node in self.node_map.items():
if node.layer_type == "Cast":
input = self.get_node(node.inputs[0])
if input.layer_type != "Placeholder" or len(input.outputs) != 1:
continue
cast_node.append(node_name)
for node_name in cast_node:
node = self.get_node(node_name)
input_node = self.get_node(node.inputs[0])
input_node.layer.attr["dtype"].type = node.raw_dtype
self.remove_node(node_name)
self.identity_map[node_name] = input_node.layer_name
if node_name in self.output_nodes:
idx = self.output_nodes.index(node_name)
self.output_nodes[idx] = input_node.layer_name
def data_format_propagation(self, node):
current_node = self.node_map[node.layer_name]
current_node = node.tf_data_format
......@@ -281,7 +312,11 @@ class TFDecoder(object):
right_shape_been_input = False
while not right_shape_been_input:
shape = input("Shape of Input(e.g. None,224,224,3): ")
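# raw_input exists only on Python 2; on Python 3 it raises NameError and we fall back to input().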
try:
shape = raw_input(
"Shape of Input(e.g. None,224,224,3): ")
except NameError:
shape = input("Shape of Input(e.g. None,224,224,3): ")
if shape.count("None") > 1:
print("Only 1 dimension can be None, type again:)")
else:
......
......@@ -31,7 +31,6 @@ def register(kind, shape, layer, weights):
k) is str, 'invalid param "kind" for register, not a list of str'
assert k not in g_custom_layers, 'this type[%s] has already been registered' % (
k)
print('register layer[%s]' % (k))
g_custom_layers[k] = {
'shape': shape,
'layer': layer,
......
......@@ -8,13 +8,7 @@ def shufflechannel_shape(input_shape):
def shufflechannel_layer(inputs, group=None, input_shape=None, name=None):
input = inputs[0]
c_fm = fluid.layers.split(input, num_or_sections=input_shape[0][1], dim=1)
size = int(input_shape[0][1] / group)
new_c_fm = []
for i in range(size):
for j in range(group):
new_c_fm.append(c_fm[j * size + i])
out = fluid.layers.concat(new_c_fm, axis=1)
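# fluid.layers.shuffle_channel reshapes the channels to (group, C/group), transposes the two axes, and flattens back, matching the interleave produced by the split/concat loop above.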
out = fluid.layers.shuffle_channel(x=input, group=group)
return out
......
......@@ -124,9 +124,6 @@ class CaffeOpMapper(OpMapper):
data[idx] = np.squeeze(d, axis=sq_axis)
shape_new = data[idx].shape
if len(shape_old) != shape_new:
print('squeeze idx:%d, with kind:%s,name:%s' % \
(idx, node.layer_type, node.layer.name))
return data
def get_kernel_parameters(self, kind, params):
......@@ -366,7 +363,7 @@ class CaffeOpMapper(OpMapper):
input = self.graph.get_bottom_node(node, idx=0, copy=True)
attr = {
'n': params.local_size,
'k': 1.0,
'k': params.k,
'alpha': alpha,
'beta': params.beta,
'name': string(node.layer_name)
......@@ -453,35 +450,19 @@ class CaffeOpMapper(OpMapper):
slice_dim = params.slice_dim
if slice_dim != 1 and axis == 1:
axis = slice_dim
points = list(params.slice_point)
if len(points) == 0:
dims = node.input_shape[0][axis]
assert dims % top_len == 0, "the parameter of Slice is wrong"
part = dims / top_len
t = part
while t < dims:
points.append(int(t))
t += part
maxint32 = 2147483647
points = [0] + points
points.append(maxint32)
i = 0
node.fluid_code.add_note('{} = []'.format(node.layer_name))
for i in range(len(points)):
attr = {
'axes': [axis],
'starts': [points[i]],
'ends': [points[i + 1]]
}
node.fluid_code.add_layer("slice",
inputs=input,
output=node.layer_name + '_' + str(i),
param_attr=attr)
node.fluid_code.add_note('{}.append({})'.format(
node.layer_name, node.layer_name + '_' + str(i)))
if i == len(points) - 2:
break
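# The slice-point arithmetic above is replaced by a single fluid split; the section sizes come directly from the inferred output shapes along the slice axis.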
output_shape = node.output_shape
sections_list = []
for s in output_shape:
sections_list.append(s[axis])
attr = {
'num_or_sections': sections_list,
'dim': axis,
'name': string(node.layer_name)
}
node.fluid_code.add_layer("split",
inputs=input,
output=node.layer_name,
param_attr=attr)
def Concat(self, node):
assert len(
......@@ -652,7 +633,8 @@ class CaffeOpMapper(OpMapper):
]).astype('float32')
scale = 0
else:
node.data = [np.squeeze(i) for i in node.data]
node.data = [np.squeeze(i).astype('float32') for i in node.data]
mean, variance, scale = node.data
# Prescale the stats
scaling_factor = 1.0 / scale if scale != 0 else 0
......@@ -687,8 +669,10 @@ class CaffeOpMapper(OpMapper):
input_c,
]).astype('float32')
else:
self.weights[node.layer_name + '_scale'] = np.squeeze(node.data[0])
self.weights[node.layer_name + '_offset'] = np.squeeze(node.data[1])
self.weights[node.layer_name + '_scale'] = np.squeeze(
node.data[0]).astype('float32')
self.weights[node.layer_name + '_offset'] = np.squeeze(
node.data[1]).astype('float32')
params = node.layer.scale_param
axis = params.axis
num_axes = params.num_axes
......@@ -956,7 +940,6 @@ class CaffeOpMapper(OpMapper):
input = self.graph.get_bottom_node(node, idx=i, copy=True)
if i == 1 and op == 'DetectionOutput':
input = self.graph.get_bottom_node(node, idx=i, copy=True)
print(input.layer_type)
while input is not None and input.layer_type != 'Softmax':
input = self.graph.get_bottom_node(input, idx=0, copy=True)
assert input is not None, 'This kind of DetectionOutput is not supported!'
......
......@@ -44,7 +44,6 @@ def register(kind, shape, layer, child_func, weights):
k) is str, 'invalid param "kind" for register, not a list of str'
assert k not in g_custom_layers, 'this type[%s] has already been registered' % (
k)
print('register layer[%s]' % (k))
g_custom_layers[k] = {
'shape': shape,
'layer': layer,
......
......@@ -168,9 +168,34 @@ class TFOpMapper(OpMapper):
x_input = y
y_input = x
x_shape = y.out_shapes[0]
if len(x_shape) == 0:
x_shape = [1]
y_shape = x.out_shapes[0]
if len(y_shape) == 0:
y_shape = [1]
else:
raise Exception("Unexpected situation happend")
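# When one operand is a length-C vector matching the other operand's trailing (channel) dimension, reshape it to [1, C, 1, 1], expanding along batch if needed, before the elementwise op.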
if len(x_shape) == 1 and len(y_shape) == 4 and x_shape[
0] == y_shape[-1] and y_shape.count(-1) < 1:
shape = [1, x_shape[0], 1, 1]
attr = {"shape": shape}
node.fluid_code.add_layer("reshape",
inputs=x_input,
output="reshape_x",
param_attr=attr)
if y_shape[0] != 1:
attr = {"expand_times": [y_shape[0], 1, 1, 1]}
node.fluid_code.add_layer("expand",
inputs="reshape_x",
output="reshape_x",
param_attr=attr)
inputs = {"x": "reshape_x", "y": y_input}
node.fluid_code.add_layer(op_type,
inputs=inputs,
output=node,
param_attr=None)
return
else:
raise Exception("Unexpected situation happend")
if len(x_shape) == 4 and len(y_shape) == 1:
if x_input.tf_data_format == "NHWC":
......@@ -985,7 +1010,7 @@ class TFOpMapper(OpMapper):
attr = {
"bias_attr": False,
"param_attr": string(kernel.layer_name),
"num_filters": k_size[3],
"num_filters": k_size[2],
"filter_size": k_size[0:2],
"stride": strides[2:4],
"dilation": dilations[2:4],
......@@ -1091,40 +1116,6 @@ class TFOpMapper(OpMapper):
output=node,
param_attr=attr)
def ResizeNearestNeighbor(self, node):
input = self.graph.get_node(node.layer.input[0], copy=True)
resize_shape = self.graph.get_node(node.layer.input[1], copy=True)
self.add_omit_nodes(resize_shape.layer_name, node.layer_name)
if resize_shape.layer_type == "Const":
resize_shape = resize_shape.value.tolist()
else:
resize_shape = self.decoder.infer_shape_tensor(resize_shape)
align_corners = node.get_attr("align_corners")
attr = {"align_corners": align_corners, "out_shape": resize_shape}
node.fluid_code.add_layer("resize_nearest",
inputs=input,
output=node,
param_attr=attr)
def ResizeBilinear(self, node):
input = self.graph.get_node(node.layer.input[0], copy=True)
resize_shape = self.graph.get_node(node.layer.input[1], copy=True)
self.add_omit_nodes(resize_shape.layer_name, node.layer_name)
if resize_shape.layer_type == "Const":
resize_shape = resize_shape.value.tolist()
else:
resize_shape = self.decoder.infer_shape_tensor(resize_shape)
align_corners = node.get_attr("align_corners")
attr = {
"align_corners": align_corners,
"out_shape": resize_shape,
"align_mode": 1
}
node.fluid_code.add_layer("resize_bilinear",
inputs=input,
output=node,
param_attr=attr)
def ResizeNearestNeighbor(self, node):
input = self.graph.get_node(node.layer.input[0], copy=True)
resize_shape = self.graph.get_node(node.layer.input[1], copy=True)
......@@ -1170,37 +1161,6 @@ class TFOpMapper(OpMapper):
output=node,
param_attr=None)
def RandomUniform(self, node):
shape = self.graph.get_node(node.layer.input[0], copy=True)
self.add_omit_nodes(shape.layer_name, node.layer_name)
if shape.layer_type == "Const":
shape = shape.value.tolist()
else:
shape = self.decoder.infer_shape_tensor(shape)
if node.tf_data_format == "NHWC" and len(shape) == 4:
shape = [shape[i] for i in [0, 3, 1, 2]]
attr = {"shape": shape, "min": 0.0, "max": 0.9999}
if shape[0] < 0:
input = self.batch_node
node.fluid_code.add_layer("uniform_random_batch_size_like",
inputs=input,
output=node,
param_attr=attr)
else:
node.fluid_code.add_layer("uniform_random",
inputs=None,
output=node,
param_attr=attr)
def GreaterEqual(self, node):
x = self.graph.get_node(node.layer.input[0], copy=True)
y = self.graph.get_node(node.layer.input[1], copy=True)
inputs = {"x": x, "y": y}
node.fluid_code.add_layer("greater_equal",
inputs=inputs,
output=node,
param_attr=None)
def RandomUniform(self, node):
shape = self.graph.get_node(node.layer.input[0], copy=True)
self.add_omit_nodes(shape.layer_name, node.layer_name)
......
......@@ -41,6 +41,7 @@ class TFOpMapperNHWC(OpMapper):
'Exp': ['exp'],
'Rsqrt': ['rsqrt'],
'swish_f32': ['swish'],
'Tanh': ['tanh'],
'LeakyRelu': ['leaky_relu', {
'alpha': 'alpha'
}]
......@@ -120,10 +121,29 @@ class TFOpMapperNHWC(OpMapper):
pd_param_name = list(param.values())[0]
tf_param = node.get_attr(tf_param_name)
attr[pd_param_name] = tf_param
node.fluid_code.add_layer(op_info[0],
inputs=input,
output=node,
param_attr=attr)
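# For 4-D inputs (except shape), the mapped op is wrapped in transposes: NHWC -> NCHW, apply the op, then transpose back to NHWC.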
if len(input.out_shapes[0]) == 4 and op_info[0] != 'shape':
attr1 = {"perm": [0, 3, 1, 2]}
node.fluid_code.add_layer('transpose',
inputs=input,
output=node,
param_attr=attr1)
input = node
node.fluid_code.add_layer(op_info[0],
inputs=input,
output=node,
param_attr=attr)
input = node
attr2 = {"perm": [0, 2, 3, 1]}
node.fluid_code.add_layer('transpose',
inputs=input,
output=node,
param_attr=attr2)
else:
node.fluid_code.add_layer(op_info[0],
inputs=input,
output=node,
param_attr=attr)
def elementwise_map(self, node):
assert node.layer_type in self.elementwise_ops
......@@ -148,7 +168,11 @@ class TFOpMapperNHWC(OpMapper):
x_input = y
y_input = x
x_shape = y.out_shapes[0]
if len(x_shape) == 0:
x_shape = [1]
y_shape = x.out_shapes[0]
if len(y_shape) == 0:
y_shape = [1]
else:
raise Exception("Unexpected situation happend")
......@@ -192,11 +216,30 @@ class TFOpMapperNHWC(OpMapper):
output="y_tmp",
param_attr=attr)
y_input = "y_tmp"
inputs = {"x": x_input, "y": y_input}
node.fluid_code.add_layer(op_type,
inputs=inputs,
output=node,
param_attr=None)
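# When both operands are 4-D, transpose them from NHWC to NCHW, run the elementwise op, then transpose the result back to NHWC.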
if len(x_shape) == 4 and len(y_shape) == 4:
node.fluid_code.add_layer("transpose",
inputs=x_input,
output=x_input,
param_attr={'perm': [0, 3, 1, 2]})
node.fluid_code.add_layer("transpose",
inputs=y_input,
output=y_input,
param_attr={'perm': [0, 3, 1, 2]})
inputs = {"x": x_input, "y": y_input}
node.fluid_code.add_layer(op_type,
inputs=inputs,
output=node,
param_attr=None)
node.fluid_code.add_layer("transpose",
inputs=node,
output=node,
param_attr={'perm': [0, 2, 3, 1]})
else:
inputs = {"x": x_input, "y": y_input}
node.fluid_code.add_layer(op_type,
inputs=inputs,
output=node,
param_attr=None)
def Placeholder(self, node):
shape = node.out_shapes[0]
......@@ -913,6 +956,12 @@ class TFOpMapperNHWC(OpMapper):
self.add_omit_nodes(kernel.layer_name, node.layer_name)
self.add_omit_nodes(out_shape.layer_name, node.layer_name)
if out_shape.layer_type == "Const":
out_shape = out_shape.value.tolist()
else:
out_shape = self.decoder.infer_shape_tensor(out_shape,
node.out_shapes[0])
in_shape = input.out_shapes[0]
if in_shape.count(-1) > 2:
in_shape = self.decoder.infer_tensor(input).shape
......@@ -920,7 +969,7 @@ class TFOpMapperNHWC(OpMapper):
if k_size.count(-1) > 2:
k_size = self.decoder.infer_tensor(kernel).shape
pad_mode = node.get_attr("padding")
pad_mode = node.get_attr("padding").decode()
strides = node.get_attr("strides")
dilations = node.get_attr("dilations")
data_format = node.get_attr("data_format").decode()
......@@ -958,7 +1007,7 @@ class TFOpMapperNHWC(OpMapper):
attr = {
"bias_attr": False,
"param_attr": string(kernel.layer_name),
"num_filters": k_size[3],
"num_filters": k_size[2],
"filter_size": k_size[0:2],
"stride": strides[2:4],
"dilation": dilations[2:4],
......@@ -969,6 +1018,22 @@ class TFOpMapperNHWC(OpMapper):
output=node,
param_attr=attr)
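# With SAME padding, conv2d_transpose can produce a larger spatial size than TensorFlow recorded; the result is sliced down to out_shape (unknown dims use the large sentinel end 999999).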
if pad_mode == "SAME":
if node.tf_data_format == "NHWC":
out_shape = [out_shape[i] for i in [0, 3, 1, 2]]
for i in range(4):
if out_shape[i] < 0:
out_shape[i] = 999999
attr = {
"axes": [0, 1, 2, 3],
"starts": [0, 0, 0, 0],
"ends": out_shape
}
node.fluid_code.add_layer("slice",
inputs=node,
output=node,
param_attr=attr)
if not channel_first:
attr = {"perm": [0, 2, 3, 1]}
node.fluid_code.add_layer("transpose",
......@@ -1127,3 +1192,17 @@ class TFOpMapperNHWC(OpMapper):
inputs=None,
output=node,
param_attr=attr)
def SquaredDifference(self, node):
x = self.graph.get_node(node.layer.input[0], copy=True)
y = self.graph.get_node(node.layer.input[1], copy=True)
inputs = {"x": x, "y": y}
node.fluid_code.add_layer("elementwise_sub",
inputs=inputs,
output=node,
param_attr=None)
inputs = {"x": node, "y": node}
node.fluid_code.add_layer("elementwise_mul",
inputs=inputs,
output=node,
param_attr=None)
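# SquaredDifference(x, y) is lowered as (x - y) * (x - y) using elementwise_sub followed by elementwise_mul.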
This diff is collapsed.
......@@ -65,3 +65,4 @@
| mNASNet | [pytorch(personal practice)](https://github.com/rwightman/gen-efficientnet-pytorch) |9|
| EfficientNet | [pytorch(personal practice)](https://github.com/rwightman/gen-efficientnet-pytorch) |9|
| SqueezeNet | [onnx official](https://s3.amazonaws.com/download.onnx/models/opset_9/squeezenet.tar.gz) |9|
| Ultra-Light-Fast-Generic-Face-Detector-1MB | [onnx_model](https://github.com/Linzaer/Ultra-Light-Fast-Generic-Face-Detector-1MB/tree/master/models/onnx) | |