未验证 提交 184b7388 编写于 作者: S SunAhong1993 提交者: GitHub

Merge pull request #324 from driftcloudy/develop

caffe2paddle, support relu6/upsample,compatible with LeakReLU,repair axpy and dropuout
目前,代码中已经提供了8个非官方op(不在[官网](http://caffe.berkeleyvision.org/tutorial/layers)上的op)的转换,这些op对应的Caffe实现源码如下:
目前,代码中已经提供了10个非官方op(不在[官网](http://caffe.berkeleyvision.org/tutorial/layers)上的op)的转换,这些op对应的Caffe实现源码如下:
| op | 该版本实现源码 |
|-------|--------|
......@@ -10,3 +10,5 @@
| Normalize | [code](https://github.com/weiliu89/caffe/blob/ssd/src/caffe/layers/normalize_layer.cpp) |
| ROIPooling | [code](https://github.com/rbgirshick/caffe-fast-rcnn/blob/0dcd397b29507b8314e252e850518c5695efbb83/src/caffe/layers/roi_pooling_layer.cpp) |
| Axpy | [code](https://github.com/hujie-frank/SENet/blob/master/src/caffe/layers/axpy_layer.cpp) |
| ReLU6 | [code](https://github.com/chuanqi305/ssd/blob/ssd/src/caffe/layers/relu6_layer.cpp) |
| Upsample | [code](https://github.com/eric612/MobileNet-YOLO/blob/master/src/caffe/layers/upsample_layer.cpp) |
......@@ -34,6 +34,7 @@
| 21 | Axpy | 22 | ROIPolling | 23 | Permute | 24 | DetectionOutput |
| 25 | Normalize | 26 | Select | 27 | ShuffleChannel | 28 | ConvolutionDepthwise |
| 29 | ReLU | 30 | AbsVal | 31 | Sigmoid | 32 | TanH |
| 33 | ReLU6 | 34 | Upsample |
## ONNX
......
......@@ -88,6 +88,19 @@ class CaffeGraph(Graph):
# filter them out here.
if (not exclude) and (phase == 'test'):
exclude = (type_str == 'Dropout')
if layer.type == 'Dropout':
drop_layer_top = layer.top[0]
drop_layer_bottom = layer.bottom[0]
if drop_layer_top != drop_layer_bottom:
for next_layer in layers:
for next_layer_bottom_idx, next_layer_bottom in enumerate(
next_layer.bottom):
if drop_layer_top == next_layer_bottom:
next_layer.bottom.remove(drop_layer_top)
next_layer.bottom.insert(
next_layer_bottom_idx,
drop_layer_bottom)
if not exclude:
filtered_layers.append(layer)
# Guard against dupes.
......
......@@ -10,6 +10,8 @@ from . import select
from . import shufflechannel
from . import convolutiondepthwise
from . import axpy
from . import upsample
from . import relu6
#custom layer import ends
custom_layers = get_registered_layers()
......
......@@ -2,7 +2,7 @@ from .register import register
from x2paddle.core.util import *
def axpy_shape(input_shape):
def axpy_shape(input_shapes):
assert len(input_shapes) == 3, "not valid input shape for axpy layer"
assert len(input_shapes[0]) == len(input_shapes[1]), 'should have same dims'
output_shape = input_shapes[1]
......@@ -18,7 +18,7 @@ def axpy_layer(inputs, input_shape=None, name=None):
y = inputs[2]
out = fluid.layers.elementwise_mul(x, alpha, axis=0)
out = fluid.layers.elementwise_add(out, y, name=name)
print(out)
return out
def axpy_weights(name, data=None):
......
from .register import register
from x2paddle.core.util import *
def relu6_shape(input_shape):
return input_shape
def relu6_layer(inputs, input_shape=None, name=None):
input = inputs[0]
out = fluid.layers.relu6(x=input)
return out
def relu6_weights(name, data=None):
weights_name = []
return weights_name
register(
kind='ReLU6', shape=relu6_shape, layer=relu6_layer, weights=relu6_weights)
# -*- coding: utf-8 -*-
################################################################################
#
# Copyright (c) 2020 Baidu.com, Inc. All Rights Reserved
#
################################################################################
"""
Author: Drift
Email: wutuobang@baidu.com
Date: 2020/04/22 18:45
"""
from .register import register
from x2paddle.core.util import *
def upsample_shape(input_shapes, scale):
"""
:param input_shapes:
:param scale:
:return:
"""
assert len(input_shapes) == 1, "not valid input shape for upsample layer"
assert type(scale) is int
input_shape = input_shapes[0]
new_h = scale * input_shape[2]
new_w = scale * input_shape[3]
output_shape = [input_shape[0], input_shape[1], new_h, new_w]
return [output_shape]
def upsample_layer(inputs, scale, input_shape=None, name=None):
"""
:param inputs:
:param scale:
:param input_shape:
:param name:
:return:
"""
x = inputs[0]
out = fluid.layers.resize_nearest(
x, align_corners=False, scale=scale, name=name)
return out
def upsample_weights(name, data=None):
"""
:param name:
:param data:
:return:
"""
weights_name = []
return weights_name
register(
kind='Upsample',
shape=upsample_shape,
layer=upsample_layer,
weights=upsample_weights)
......@@ -23,7 +23,6 @@ from x2paddle.op_mapper.caffe_custom_layer import *
class CaffeOpMapper(OpMapper):
directly_map_ops = {
'ReLU': 'relu',
'AbsVal': 'abs',
'Sigmoid': 'sigmoid',
'TanH': 'tanh',
......@@ -435,6 +434,26 @@ class CaffeOpMapper(OpMapper):
node.fluid_code.add_layer(
"concat", inputs=inputs, output=node, param_attr=attr)
def ReLU(self, node):
"""
:param node:
:return:
"""
assert len(
node.inputs) == 1, 'The count of ReLU node\'s input is not 1.'
input = self.graph.get_bottom_node(node, idx=0, copy=True)
params = node.layer.relu_param
if params.HasField('negative_slope') and params.negative_slope != 0:
negative_slope = float(params.negative_slope)
attr = {'alpha': negative_slope}
node.fluid_code.add_layer(
'leaky_relu', inputs=input, output=node, param_attr=attr)
else:
node.fluid_code.add_layer('relu', inputs=input, output=node)
def PReLU(self, node):
assert len(
node.inputs) == 1, 'The count of PReLU node\'s input is not 1.'
......
......@@ -33,6 +33,8 @@
| MobileNet_V1 | [code](https://github.com/shicai/MobileNet-Caffe) |
| MobileNet_V2 | [code](https://github.com/shicai/MobileNet-Caffe) |
| ShuffleNet_v2 | [code](https://github.com/miaow1988/ShuffleNet_V2_pytorch_caffe/releases/tag/v0.1.0) |
| InceptionV3 | [code](https://github.com/soeaver/caffe-model/blob/master/cls/inception/) |
| InceptionV4 | [code](https://github.com/soeaver/caffe-model/blob/master/cls/inception/) |
| mNASNet | [code](https://github.com/LiJianfei06/MnasNet-caffe) |
| MTCNN | [code](https://github.com/kpzhang93/MTCNN_face_detection_alignment/tree/master/code/codes/MTCNNv1/model) |
| Mobilenet_SSD | [code](https://github.com/chuanqi305/MobileNet-SSD) |
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册