#!/usr/bin/env python3 # -*- coding: utf-8 -*- """ Created on Fri Mar 22 11:19:45 2019 @author: Macrobull Not all ops in this file are supported by both PyTorch and ONNX This only demostrates the conversion/validation workflow from PyTorch to ONNX to Paddle fluid """ from __future__ import print_function import torch import torch.nn as nn import torch.nn.functional as F from onnx2fluid.torch_export_helper import export_onnx_with_validation prefix = 'sample_' idx = 0 ######## example: RNN cell ######## class Model(nn.Module): def __init__(self): super(Model, self).__init__() self.gru = nn.GRUCell(6, 5) self.lstm = nn.LSTMCell(5, 4) def forward(self, x, h1, h2, c2): h = self.gru(x, h1) h, c = self.lstm(h, (h2, c2)) return h, c model = Model() model.eval() xb = torch.rand((7, 6)) h1 = torch.zeros((7, 5)) h2 = torch.zeros((7, 4)) c2 = torch.zeros((7, 4)) yp = model(xb, h1, h2, c2) idx += 1 print('index: ', idx) export_onnx_with_validation(model, [xb, h1, h2, c2], prefix + str(idx), ['x', 'h1', 'h2', 'c2'], ['h', 'c'], verbose=True, training=False) ######## example: RNN ######## class Model(nn.Module): def __init__(self): super(Model, self).__init__() self.gru = nn.GRU(6, 5, 3) self.lstm = nn.LSTM(5, 4, 2) def forward(self, x, h1, h2, c2): y, h1 = self.gru(x, h1) y, (h2, c2) = self.lstm(y, (h2, c2)) return y model = Model() model.eval() xb = torch.rand((8, 1, 6)) h1 = torch.zeros((3, 1, 5)) h2 = torch.zeros((2, 1, 4)) c2 = torch.zeros((2, 1, 4)) yp = model(xb, h1, h2, c2) idx += 1 print('index: ', idx) export_onnx_with_validation(model, [xb, h1, h2, c2], prefix + str(idx), ['x', 'h1', 'h2', 'c2'], ['y'], verbose=True, training=False) ######## example: random ######## """ symbolic registration: def rand(g, *shapes): shapes_list = list(shapes) shape = _maybe_get_const(shapes_list[0], "is") return g.op('RandomUniform', shape_i=shape) """ class Model(nn.Module): def __init__(self): super(Model, self).__init__() def forward(self, x): y = torch.rand((2, 3)) # + torch.rand_like(x) y = y + torch.randn((2, 3)) # + torch.randn_like(x) y = y + x return y model = Model() model.eval() xb = torch.rand((2, 3)) yp = model(xb) idx += 1 print('index: ', idx) export_onnx_with_validation(model, [xb], prefix + str(idx), ['x'], ['y'], verbose=True, training=False) ######## example: fc ######## class Model(nn.Module): def __init__(self): super(Model, self).__init__() self.fc = nn.Linear(3, 8) def forward(self, x): y = x y = self.fc(y) return y model = Model() model.eval() xb = torch.rand((2, 3)) yp = model(xb) idx += 1 print('index: ', idx) export_onnx_with_validation(model, [xb], prefix + str(idx), ['x'], ['y'], verbose=True, training=False) ######## example: compare ######## class Model(nn.Module): def __init__(self): super(Model, self).__init__() def forward(self, x0, x1): x0 = x0.clamp(-1, 1) a = torch.max(x0, x1) == x1 b = x0 < x1 c = x0 > x1 return a, b, c model = Model() model.eval() xb0 = torch.rand((2, 3)) xb1 = torch.rand((2, 3)) ya, yb, yc = model(xb0, xb1) idx += 1 print('index: ', idx) export_onnx_with_validation(model, [xb0, xb1], prefix + str(idx), ['x0', 'x1'], ['ya', 'yb', 'yc'], verbose=True, training=False) ######## example: affine_grid ######## """ symbolic registration: @parse_args('v', 'is') def affine_grid_generator(g, theta, size): return g.op('AffineGrid', theta, size_i=size) """ class Model(nn.Module): def __init__(self): super(Model, self).__init__() def forward(self, theta): grid = F.affine_grid(theta, (2, 2, 8, 8)) return grid model = Model() model.eval() theta = torch.rand((2, 2, 3)) grid = model(theta) idx += 1 print('index: ', idx) export_onnx_with_validation(model, (theta, ), prefix + str(idx), ['theta'], ['grid'], verbose=True, training=False) ######## example: conv2d_transpose ######## class Model(nn.Module): def __init__(self): super(Model, self).__init__() self.conv = nn.ConvTranspose2d(3, 8, 3) self.dropout = nn.Dropout2d() def forward(self, x): y = x y = self.conv(y) y = self.dropout(y) return y model = Model() model.eval() xb = torch.rand((2, 3, 4, 5)) yp = model(xb) idx += 1 print('index: ', idx) export_onnx_with_validation(model, [xb], prefix + str(idx), ['x'], ['y'], verbose=True, training=False) ######## example: conv2d ######## class Model(nn.Module): def __init__(self): super(Model, self).__init__() self.conv = nn.Conv2d(3, 8, 3) self.batch_norm = nn.BatchNorm2d(8) self.pool = nn.AdaptiveAvgPool2d(1) def forward(self, x): y = x y = self.conv(y) y = self.batch_norm(y) y = self.pool(y) return y model = Model() model.eval() xb = torch.rand((2, 3, 4, 5)) yp = model(xb) idx += 1 print('index: ', idx) export_onnx_with_validation(model, [xb], prefix + str(idx), ['x'], ['y'], verbose=True, training=False) ######### example: conv1d ######## # #class Model(nn.Module): # def __init__(self): # super(Model, self).__init__() # self.batch_norm = nn.BatchNorm2d(3) # # def forward(self, x): # y = x # y = self.batch_norm(y) # return y # # #model = Model() #model.eval() #xb = torch.rand((2, 3, 4, 5)) #yp = model(xb) #idx += 1 #print('index: ', idx) #export_onnx_with_validation( # model, [xb], prefix + str(idx), # ['x'], ['y'], # verbose=True, training=False) ######## example: empty ######## class Model(nn.Module): def __init__(self): super(Model, self).__init__() def forward(self, x): return x model = Model() model.eval() xb = torch.rand((2, 3)) yp = model(xb) idx += 1 print('index: ', idx) export_onnx_with_validation(model, [xb], prefix + str(idx), ['y'], ['y'], verbose=True, training=False)