未验证 提交 98a6cacb 编写于 作者: J Jason 提交者: GitHub

Merge pull request #84 from Channingss/develop

modify readme.md
...@@ -12,7 +12,7 @@ paddlepaddle >= 1.5.0 ...@@ -12,7 +12,7 @@ paddlepaddle >= 1.5.0
**以下依赖只需对应安装自己需要的即可** **以下依赖只需对应安装自己需要的即可**
转换tensorflow模型 : tensorflow == 1.14.0 转换tensorflow模型 : tensorflow == 1.14.0
转换caffe模型 : caffe == 1.0.0 转换caffe模型 : caffe == 1.0.0
转换onnx模型 : onnx == 1.5.0 pytorch == 1.1.0
## 安装 ## 安装
``` ```
pip install x2paddle pip install x2paddle
...@@ -32,8 +32,9 @@ x2paddle --framework=tensorflow --model=tf_model.pb --save_dir=pd_model ...@@ -32,8 +32,9 @@ x2paddle --framework=tensorflow --model=tf_model.pb --save_dir=pd_model
x2paddle --framework=caffe --prototxt=deploy.proto --weight=deploy.caffemodel --save_dir=pd_model x2paddle --framework=caffe --prototxt=deploy.proto --weight=deploy.caffemodel --save_dir=pd_model
``` ```
### ONNX ### ONNX
即将release,目前仍可使用[onnx2fluid](https://github.com/PaddlePaddle/X2Paddle/tree/release-0.3/onnx2fluid) ```
x2paddle --framework=onnx --model=onnx_model.onnx --save_dir=pd_model
```
### 参数选项 ### 参数选项
| 参数 | | | 参数 | |
|----------|--------------| |----------|--------------|
......
...@@ -13,6 +13,7 @@ ...@@ -13,6 +13,7 @@
# limitations under the License. # limitations under the License.
from collections import OrderedDict as _dict from collections import OrderedDict as _dict
import numpy as _np
default_op_mapping_field_values = _dict() default_op_mapping_field_values = _dict()
default_op_mapping_field_values['FLUID_OP'] = '' default_op_mapping_field_values['FLUID_OP'] = ''
...@@ -30,6 +31,21 @@ default_op_mapping = { ...@@ -30,6 +31,21 @@ default_op_mapping = {
'Mul': ['elementwise_mul', ['X', 'Y'], ['Out'], 'Mul': ['elementwise_mul', ['X', 'Y'], ['Out'],
dict(), dict(),
dict(axis=-1)], dict(axis=-1)],
'Clip': [
'clip', ['X'], ['Out'],
dict(),
dict(
min=(_np.asarray([255, 255, 127, 255],
dtype=_np.uint8).view(_np.float32)),
max=(_np.asarray([255, 255, 127, 127],
dtype=_np.uint8).view(_np.float32)),
)
],
'ReduceMean': [
'reduce_mean', ['X'], ['Out'],
dict(axes='dim', keepdims='keep_dim'),
dict(keep_dim=1)
]
} }
default_ioa_constraint = { default_ioa_constraint = {
......
...@@ -62,7 +62,7 @@ class ONNXOpMapper(OpMapper): ...@@ -62,7 +62,7 @@ class ONNXOpMapper(OpMapper):
func = getattr(self, op) func = getattr(self, op)
func(node) func(node)
elif op in default_op_mapping: elif op in default_op_mapping:
self._default(node) self.directly_map(node)
def op_checker(self): def op_checker(self):
unsupported_ops = set() unsupported_ops = set()
...@@ -80,7 +80,7 @@ class ONNXOpMapper(OpMapper): ...@@ -80,7 +80,7 @@ class ONNXOpMapper(OpMapper):
print(op) print(op)
return False return False
def _default(self, node, *args, name='', **kwargs): def directly_map(self, node, *args, name='', **kwargs):
inputs = node.layer.input inputs = node.layer.input
outputs = node.layer.output outputs = node.layer.output
op_type = node.layer_type op_type = node.layer_type
...@@ -544,7 +544,7 @@ class ONNXOpMapper(OpMapper): ...@@ -544,7 +544,7 @@ class ONNXOpMapper(OpMapper):
"momentum": momentum, "momentum": momentum,
"epsilon": epsilon, "epsilon": epsilon,
"data_layout": string('NCHW'), "data_layout": string('NCHW'),
"is_test": 'True', "is_test": True,
"param_attr": string(val_scale.layer_name), "param_attr": string(val_scale.layer_name),
"bias_attr": string(val_b.layer_name), "bias_attr": string(val_b.layer_name),
"moving_mean_name": string(val_mean.layer_name), "moving_mean_name": string(val_mean.layer_name),
......
...@@ -47,8 +47,12 @@ ...@@ -47,8 +47,12 @@
``` ```
import torch import torch
import torchvision import torchvision
dummy_input = torch.randn(1, 3, 224, 224) #根据不同模型调整shape
#根据不同模型调整输入的shape
dummy_input = torch.randn(1, 3, 224, 224)
resnet18 = torchvision.models.resnet18(pretrained=True) resnet18 = torchvision.models.resnet18(pretrained=True)
torch.onnx.export(resnet18, dummy_input, "resnet18.onnx",verbose=True)#"resnet18.onnx"为onnx model的存储路径
#"resnet18.onnx"为onnx model的存储路径
torch.onnx.export(resnet18, dummy_input, "resnet18.onnx",verbose=True)
``` ```
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册