提交 28f195d7 编写于 作者: S Smirnov Egor

add test for prelu negative slope access pattern

上级 4830352b
......@@ -8,6 +8,7 @@ import paddle # version 2.1.1
import numpy as np
import os.path
import onnx
import onnxsim
import google.protobuf.text_format
import io
......@@ -72,6 +73,14 @@ def save_onnx_data_and_model(input, output, name, operation, *args, **kwargs):
model = onnx.helper.make_model(graph, producer_name=name)
onnx.save(model, models_files)
def simplify(name, rename=False, **kwargs):
model, check = onnxsim.simplify(name, **kwargs)
assert check, "couldn't valide"
name = name[:-5]
if rename:
name += '_optimized'
onnx.save(model, name + '.onnx')
torch.manual_seed(0)
np.random.seed(0)
......@@ -127,6 +136,18 @@ input = Variable(torch.randn(1, 2, 10, 10))
relu = nn.ReLU(inplace=True)
save_data_and_model("ReLU", input, relu)
class PReLU_slope(nn.Module):
def __init__(self, *args, **kwargs):
super(PReLU_slope, self).__init__()
def forward(self, x):
return nn.PReLU()(x)
model = PReLU_slope()
input_ = Variable(torch.randn(1, 1, 5, 5, dtype=torch.float32))
save_data_and_model("PReLU_slope", input_, model, export_params=True)
simplify('models/PReLU_slope.onnx', False)
input = Variable(torch.randn(2, 3))
dropout = nn.Dropout()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册