未验证 提交 7f182d54 编写于 作者: J Jason 提交者: GitHub

Merge pull request #510 from SunAhong1993/develop

添加LRN+修复文档
......@@ -9,7 +9,7 @@ X2Paddle在多个主流的CV模型上,测试过TensorFlow/Caffe/ONNX/PyTorch
## 环境依赖
python == 2.7 | python >= 3.5
python >= 3.5
paddlepaddle 2.0.0-rc1 或者 develop
**按需安装以下依赖**
......
......@@ -42,7 +42,7 @@
| 21 | Axpy | 22 | ROIPolling | 23 | Permute | 24 | DetectionOutput |
| 25 | Normalize | 26 | Select | 27 | ShuffleChannel | 28 | ConvolutionDepthwise |
| 29 | ReLU | 30 | AbsVal | 31 | Sigmoid | 32 | TanH |
| 33 | ReLU6 | 34 | Upsample | | | | |
| 33 | ReLU6 | 34 | Upsample | 35 | MemoryData | | |
## ONNX
......@@ -63,7 +63,7 @@
| 49 | MaxPool | 50 | Conv | 51 | Gemm | 52 | NonZero |
| 53 | Abs | 54 | Floor | 56 | ArgMax | 57 | Sign |
| 58 | Reciprocal | 59 | Size | 60 | OneHot | 61 | ReduceProd |
| 62 | LogSoftmax | 63 | LSTM | | | | |
| 62 | LogSoftmax | 63 | LSTM | 64 | LRN | | |
......@@ -99,7 +99,7 @@ Aten:
| 101 | aten::upsample\_bilinear2d | 102 | aten::values |103|aten::view|104|aten::warn|
| 105 | aten::where | 106 | aten::zeros |107|aten::zeros\_like|108|aten::bmm|
| 109 | aten::sub\_ | 110 | aten:erf |111|aten::lstm|112|aten::gather|
| 113 | aten::upsample\_nearest2d || |||||
| 113 | aten::upsample\_nearest2d | 114 |aten::split\_with\_sizes |||||
Prim:
| 序号 | OP | 序号 | OP | 序号 | OP | 序号 | OP |
......
# X2Paddle模型测试库
> 目前X2Paddle支持70+的TensorFlow OP,40+的Caffe Layer,覆盖了大部分CV分类模型常用的操作。我们在如下模型列表中测试了X2Paddle的转换。
> 目前X2Paddle支持80+的TensorFlow OP,30+的Caffe Layer,60+的ONNX OP,110+的PyTorch Aten,10+的PyTorch Prim,覆盖了大部分CV分类模型常用的操作。我们在如下模型列表中测试了X2Paddle的转换。
**注:** 受限于不同框架的差异,部分模型可能会存在目前无法转换的情况,如TensorFlow中包含控制流的模型,NLP模型等。对于CV常见的模型,如若您发现无法转换或转换失败,存在较大diff等问题,欢迎通过[ISSUE反馈](https://github.com/PaddlePaddle/X2Paddle/issues/new)的方式告知我们(模型名,代码实现或模型获取方式),我们会及时跟进:)
......
......@@ -229,6 +229,11 @@ def main():
assert args.paddle_type in ["dygraph", "static"], "--paddle_type must be 'dygraph' or 'static'"
try:
import platform
v0, v1, v2 = platform.python_version().split('.')
if not(int(v0) >= 3 and int(v1) >= 5):
print("[ERROR] python>=3.5 is required")
return
import paddle
v0, v1, v2 = paddle.__version__.split('.')
print("paddle.__version__ = {}".format(paddle.__version__))
......
......@@ -18,3 +18,4 @@ from .pad_two_input import PadWithTwoInput
from .pad_all_dim2 import PadAllDim2
from .pad_all_dim4 import PadAllDim4
from .pad_all_dim4_one_input import PadAllDim4WithOneInput
from .lrn import LocalResponseNorm
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
class LocalResponseNorm(object):
def __init__(self,
size,
alpha=1e-4,
beta=0.75,
k=1.):
self.size = size
self.alpha = alpha
self.beta = beta
self.k = k
def __call__(self, x):
sizes = x.shape
dim = len(sizes)
if dim < 3:
raise ValueError(
'Expected 3D or higher dimensionality input, but got {} dimensions'.
format(dim))
div = paddle.unsqueeze(paddle.multiply(x, x), axis=1)
pad4d_shape = [self.size // 2, (self.size - 1) // 2, 0, 0]
pool2d_shape = (1, self.size)
pad5d_shape = [self.size // 2, (self.size - 1) // 2, 0, 0, 0, 0]
pool3d_shape = (1, 1, self.size)
if dim == 3:
div = paddle.nn.functional.pad(div, pad=pad4d_shape)
div = paddle.nn.functional.avg_pool2d(
div, kernel_size=pool2d_shape, stride=1)
div = paddle.squeeze(div, axis=1)
else:
tmp = paddle.unsqueeze(x, axis=1)
reshape_shape = paddle.shape(tmp)
new_reshape_shape = paddle.cast(reshape_shape, "float32")
index = paddle.full(shape=[1], fill_value=-2, dtype="int32")
value = paddle.full(shape=[1], fill_value=-1, dtype="float32")
new_reshape_shape = paddle.scatter(new_reshape_shape, index, value)
new_reshape_shape = paddle.cast(new_reshape_shape, "int32")
div = paddle.reshape(div, shape=reshape_shape)
div = paddle.nn.functional.pad(div,
pad=pad5d_shape,
data_format='NCDHW')
div = paddle.nn.functional.avg_pool3d(
div, kernel_size=pool3d_shape, stride=1)
div = paddle.reshape(paddle.squeeze(div, axis=1), sizes)
div = paddle.scale(div, scale=self.alpha, bias=self.k)
div = paddle.pow(div, self.beta)
res = paddle.divide(x, div)
return res
\ No newline at end of file
......@@ -542,7 +542,6 @@ class OpSet9():
value=value,
mode=string(mode))
else:
print(pads_len)
raise Exception("The padding value is wrong!")
if not op_independent:
return node.name + '_paded'
......@@ -2013,3 +2012,25 @@ class OpSet9():
"k": val_k.name},
outputs=["{}_p{}".format(node.layer_name, 0), "{}_p{}".format(node.layer_name, 1)],
**layer_attrs)
@print_mapping_info
def LRN(self, node):
op_name = name_generator("lrn", self.nn_name2id)
output_name = node.name
layer_outputs = [op_name, output_name]
val_x = self.graph.get_input_node(node, idx=0, copy=True)
alpha = node.get_attr('alpha', 0.0001)
beta = node.get_attr('beta', 0.75)
bias = node.get_attr('bias', 1.0)
size = node.get_attr('size')
layer_attrs = {
'size': size,
'alpha': alpha,
'beta': beta,
'k': bias
}
self.paddle_graph.add_layer(
"custom_layer:LocalResponseNorm",
inputs={"x": val_x.name},
outputs=layer_outputs,
**layer_attrs)
......@@ -17,4 +17,5 @@ from .one_hot import one_hot
from .pad_two_input import pad_with_two_input
from .pad_all_dim2 import pad_all_dim2
from .pad_all_dim4 import pad_all_dim4
from .pad_all_dim4_one_input import pad_all_dim4_one_input
\ No newline at end of file
from .pad_all_dim4_one_input import pad_all_dim4_one_input
from .lrn import local_response_norm
\ No newline at end of file
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
def local_response_norm(x, size, alpha=1e-4, beta=0.75, k=1.):
sizes = x.shape
dim = len(sizes)
if dim < 3:
raise ValueError(
'Expected 3D or higher dimensionality input, but got {} dimensions'.
format(dim))
div = paddle.unsqueeze(paddle.multiply(x, x), axis=1)
pad4d_shape = [size // 2, (size - 1) // 2, 0, 0]
pool2d_shape = (1, size)
pad5d_shape = [size // 2, (size - 1) // 2, 0, 0, 0, 0]
pool3d_shape = (1, 1, size)
if dim == 3:
div = paddle.nn.functional.pad(div, pad=pad4d_shape)
div = paddle.nn.functional.avg_pool2d(
div, kernel_size=pool2d_shape, stride=1)
div = paddle.squeeze(div, axis=1)
else:
tmp = paddle.unsqueeze(x, axis=1)
reshape_shape = paddle.shape(tmp)
new_reshape_shape = paddle.cast(reshape_shape, "float32")
index = paddle.full(shape=[1], fill_value=-2, dtype="int32")
value = paddle.full(shape=[1], fill_value=-1, dtype="float32")
new_reshape_shape = paddle.scatter(new_reshape_shape, index, value)
new_reshape_shape = paddle.cast(new_reshape_shape, "int32")
div = paddle.reshape(div, shape=reshape_shape)
div = paddle.nn.functional.pad(div,
pad=pad5d_shape,
data_format='NCDHW')
div = paddle.nn.functional.avg_pool3d(
div, kernel_size=pool3d_shape, stride=1)
div = paddle.reshape(paddle.squeeze(div, axis=1), sizes)
div = paddle.scale(div, scale=alpha, bias=k)
div = paddle.pow(div, beta)
res = paddle.divide(x, div)
return res
\ No newline at end of file
......@@ -1780,4 +1780,24 @@ class OpSet9():
inputs={"x": val_x.name,
"k": val_k.name},
outputs=["{}_p{}".format(node.layer_name, 0), "{}_p{}".format(node.layer_name, 1)],
**layer_attrs)
\ No newline at end of file
**layer_attrs)
@print_mapping_info
def LRN(self, node):
val_x = self.graph.get_input_node(node, idx=0, copy=True)
val_x = self.graph.get_input_node(node, idx=0, copy=True)
alpha = node.get_attr('alpha', 0.0001)
beta = node.get_attr('beta', 0.75)
bias = node.get_attr('bias', 1.0)
size = node.get_attr('size')
layer_attrs = {
'size': size,
'alpha': alpha,
'beta': beta,
'k': bias
}
self.paddle_graph.add_layer(
"custom_layer:local_response_norm",
inputs={"x": val_x.name},
outputs=[node.name],
**layer_attrs)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册