Commit 69e0699a authored by BBuf

match oneflow0.5.0, still has bug

Parent bb38c529
......@@ -13,128 +13,65 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import numpy as np
import oneflow as flow
import oneflow.typing as tp
import oneflow.core.operator.op_conf_pb2 as op_conf_util
import onnxruntime as ort
import onnx
import oneflow.nn as nn
from oneflow_onnx.oneflow2onnx.util import convert_to_onnx_and_check
def _conv2d_layer(
name,
input,
filters,
kernel_size=3,
strides=1,
padding="SAME",
data_format="NCHW",
dilation_rate=1,
activation=op_conf_util.kRelu,
use_bias=False,
weight_initializer=flow.random_uniform_initializer(),
bias_initializer=flow.random_uniform_initializer(),
):
weight_shape = (filters, input.shape[1], kernel_size, kernel_size)
weight = flow.get_variable(
name + "-weight",
shape=weight_shape,
dtype=input.dtype,
initializer=weight_initializer,
)
output = flow.nn.conv2d(
input, weight, strides, padding, data_format, dilation_rate, name=name
)
if use_bias:
bias = flow.get_variable(
name + "-bias",
shape=(filters,),
dtype=input.dtype,
initializer=bias_initializer,
class AlexNet(nn.Module):
def __init__(self, num_classes: int = 1000) -> None:
super(AlexNet, self).__init__()
self.features = nn.Sequential(
nn.Conv2d(3, 64, kernel_size=11, stride=4, padding=2),
nn.ReLU(inplace=True),
nn.MaxPool2d(kernel_size=3, stride=2),
nn.Conv2d(64, 192, kernel_size=5, padding=2),
nn.ReLU(inplace=True),
nn.MaxPool2d(kernel_size=3, stride=2),
nn.Conv2d(192, 384, kernel_size=3, padding=1),
nn.ReLU(inplace=True),
nn.Conv2d(384, 256, kernel_size=3, padding=1),
nn.ReLU(inplace=True),
nn.Conv2d(256, 256, kernel_size=3, padding=1),
nn.ReLU(inplace=True),
nn.MaxPool2d(kernel_size=3, stride=2),
)
self.avgpool = nn.AdaptiveAvgPool2d((6, 6))
self.classifier = nn.Sequential(
nn.Dropout(),
nn.Linear(256 * 6 * 6, 4096),
nn.ReLU(inplace=True),
nn.Dropout(),
nn.Linear(4096, 4096),
nn.ReLU(inplace=True),
nn.Linear(4096, num_classes),
)
output = flow.nn.bias_add(output, bias, data_format)
if activation is not None:
if activation == op_conf_util.kRelu:
output = flow.nn.relu(output)
else:
raise NotImplementedError
return output
def alexnet(images, labels, trainable=True):
transposed = flow.transpose(images, name="transpose", perm=[0, 3, 1, 2])
conv1 = _conv2d_layer(
"conv1", transposed, filters=64, kernel_size=11, strides=4, padding="VALID"
)
pool1 = flow.nn.avg_pool2d(conv1, 3, 2, "VALID", "NCHW", name="pool1")
conv2 = _conv2d_layer("conv2", pool1, filters=192, kernel_size=5)
pool2 = flow.nn.avg_pool2d(conv2, 3, 2, "VALID", "NCHW", name="pool2")
conv3 = _conv2d_layer("conv3", pool2, filters=384)
conv4 = _conv2d_layer("conv4", conv3, filters=384)
conv5 = _conv2d_layer("conv5", conv4, filters=256)
pool5 = flow.nn.avg_pool2d(conv5, 3, 2, "VALID", "NCHW", name="pool5")
def _get_initializer():
return flow.random_uniform_initializer()
if len(pool5.shape) > 2:
pool5 = flow.reshape(pool5, shape=(pool5.shape[0], -1))
fc1 = flow.layers.dense(
inputs=pool5,
units=4096,
activation=flow.math.relu,
use_bias=False,
kernel_initializer=_get_initializer(),
bias_initializer=False,
trainable=trainable,
name="fc1",
)
dropout1 = fc1
fc2 = flow.layers.dense(
inputs=dropout1,
units=4096,
activation=flow.math.relu,
use_bias=False,
kernel_initializer=_get_initializer(),
bias_initializer=False,
trainable=trainable,
name="fc2",
)
dropout2 = fc2
fc3 = flow.layers.dense(
inputs=dropout2,
units=1001,
activation=None,
use_bias=False,
kernel_initializer=_get_initializer(),
bias_initializer=False,
trainable=trainable,
name="fc3",
)
return fc3
def forward(self, x: flow.Tensor) -> flow.Tensor:
x = self.features(x)
x = self.avgpool(x)
x = flow.flatten(x, 1)
x = self.classifier(x)
return x
alexnet = AlexNet()
alexnet.eval()
# flow.save(alexnet.state_dict(), "/tmp/alexnet")
class AlexNetGraph(flow.nn.Graph):
def __init__(self):
super().__init__()
self.m = alexnet
def build(self, x):
out = self.m(x)
return out
def test_alexnet():
@flow.global_function()
def alexnet_eval_job(x: tp.Numpy.Placeholder((1, 227, 227, 3))):
return alexnet(x, None, False)
alexnet_graph = AlexNetGraph()
alexnet_graph._compile(flow.randn(1, 3, 224, 224))
convert_to_onnx_and_check(alexnet_eval_job, flow_weight_dir=None, onnx_model_path="/tmp")
convert_to_onnx_and_check(alexnet_graph, flow_weight_dir="/tmp/alexnet", onnx_model_path="/tmp")
test_alexnet()
......@@ -222,7 +222,7 @@ def TopologicalSort(g, continue_on_error):
pass
def Export(
job_func: Callable,
graph: Callable,
model_save_dir: Text,
onnx_filename: Text,
continue_on_error: bool = False,
......@@ -235,7 +235,7 @@ def Export(
r"""Export a oneflow model into ONNX format.
Args:
job_func: The job function
graph: oneflow.nn.Graph
model_save_dir: The directory containing oneflow model weights. Users are expected to call check_point.save(dir), wait for the model saving to finish, and pass the argument 'dir' as model_save_dir.
onnx_filename: a string for the output filename
continue_on_error: if an op can't be processed (aka there is no mapping), continue
......@@ -246,38 +246,34 @@ def Export(
"""
assert os.getenv("ENABLE_USER_OP") != "False"
assert os.path.isdir(model_save_dir)
job_set = oneflow.experimental.get_job_set()
job_name = job_func.__name__
for job in job_set.job:
# TODO(OYY) Modify the interface before modifying it
if job.job_conf.job_name == job_name:
onnx_graph = ProcessFlowGraph(
job,
model_save_dir,
continue_on_error=continue_on_error,
opset=opset,
extra_opset=extra_opset,
shape_override=shape_override,
)
onnx_graph = optimizer.OptimizeGraph(onnx_graph)
model_proto = onnx_graph.MakeModel(
job_name, onnx_filename, external_data=external_data
)
job = graph._full_graph_proto
onnx_graph = ProcessFlowGraph(
job,
model_save_dir,
continue_on_error=continue_on_error,
opset=opset,
extra_opset=extra_opset,
shape_override=shape_override,
)
onnx_graph = optimizer.OptimizeGraph(onnx_graph)
model_proto = onnx_graph.MakeModel(
"tmp", onnx_filename, external_data=external_data
)
if dynamic_batch_size == True:
model_proto.graph.input[0].type.tensor_type.shape.dim[0].dim_param = 'None'
if dynamic_batch_size == True:
model_proto.graph.input[0].type.tensor_type.shape.dim[0].dim_param = 'None'
with open(onnx_filename, "wb") as f:
try:
f.write(model_proto.SerializeToString())
except ValueError as e:
raise ValueError(
"Error occured when running model_proto.SerializeToString(). If the model is larger than 2GB, please specify external_data=True when calling flow.onnx.export. Original error message:\n{}".format(
e
)
)
return
raise ValueError('Cannot find job "{}" in jobset'.format(job_name))
with open(onnx_filename, "wb") as f:
try:
f.write(model_proto.SerializeToString())
except ValueError as e:
raise ValueError(
"Error occured when running model_proto.SerializeToString(). If the model is larger than 2GB, please specify external_data=True when calling flow.onnx.export. Original error message:\n{}".format(
e
)
)
return
def ProcessFlowGraph(
......
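Given the Export signature and docstring in the hunk above, a minimal usage sketch of the new graph-based entry point might look as follows (not part of this commit; the import path, the file paths, and the flow.save step are assumptions for illustration):

import oneflow as flow
from oneflow_onnx.oneflow2onnx.flow2onnx import Export  # assumed import path for Export

model = AlexNet()   # any flow.nn.Module, e.g. the AlexNet defined in the test above
model.eval()
flow.save(model.state_dict(), "/tmp/alexnet")  # weights must already be on disk

class ModelGraph(flow.nn.Graph):
    def __init__(self):
        super().__init__()
        self.m = model

    def build(self, x):
        return self.m(x)

graph = ModelGraph()
graph._compile(flow.randn(1, 3, 224, 224))  # trace once so graph._full_graph_proto is populated

Export(
    graph,              # oneflow.nn.Graph, replacing the old job_func argument
    "/tmp/alexnet",     # model_save_dir: directory holding the saved weights
    "/tmp/model.onnx",  # onnx_filename
    opset=None,
    external_data=False,
    dynamic_batch_size=False,
)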
......@@ -124,7 +124,7 @@ class GraphBuilder(object):
res = tensor
if isinstance(tensor, list):
res = self.graph.MakeConst(
oneflow.util.unique_str("const_slice"), np.array(tensor, dtype)
oneflow._oneflow_internal.UniqueStr("const_slice"), np.array(tensor, dtype)
).output[0]
util.MakeSure(
......
......@@ -69,7 +69,7 @@ def _WrapConcatWithCast(ctx, node):
next_nodes = ctx.FindOutputConsumers(node.output_tensor_names[0])
# cast output back to dtype unless the next op is a cast
if next_nodes[0].op_type != "Cast":
op_name = oneflow.util.unique_str(node.name)
op_name = oneflow._oneflow_internal.UniqueStr(node.name)
output_cast = ctx.InsertNewNodeOnOutput("Cast", output_name, name=op_name)
output_cast.attrs["to"] = dtype
ctx.set_dtype(output_cast.output_tensor_names[0], dtype)
......@@ -87,7 +87,7 @@ class Reshape:
onnx_pb.TensorProto.INT64,
]
shape_node = ctx.MakeConst(
oneflow.util.unique_str("shape"), np.array(node.attrs.get("shape"), None)
oneflow._oneflow_internal.UniqueStr("shape"), np.array(node.attrs.get("shape"), None)
)
node.input_tensor_names = node.input_tensor_names + [shape_node.name]
if ctx.opset >= 8 or not need_casting:
......@@ -102,7 +102,7 @@ class Reshape:
# if the next node is already a cast we don't need to insert another one
next_nodes = ctx.FindOutputConsumers(node.output_tensor_names[0])
if len(next_nodes) != 1 or next_nodes[0].op_type != "Cast":
op_name = oneflow.util.unique_str(node.name)
op_name = oneflow._oneflow_internal.UniqueStr(node.name)
output_cast = ctx.InsertNewNodeOnOutput(
"Cast", node.output_tensor_names[0], name=op_name
)
......
......@@ -57,7 +57,7 @@ class ScalarBinaryOp:
)
np_dtype = util.Onnx2NumpyDtype(ctx.get_dtype(node.input_tensor_names[0]))
scalar_node = ctx.MakeConst(
oneflow.util.unique_str("scalar"), np.array([scalar_val]).astype(np_dtype)
oneflow._oneflow_internal.UniqueStr("scalar"), np.array([scalar_val]).astype(np_dtype)
)
node.input_tensor_names.append(scalar_node.output_tensor_names[0])
......@@ -189,7 +189,7 @@ def _MakeMinOrMaxOp(
if output_dtypes is not None:
origin_dtype = output_dtypes[0]
ctx.set_dtype(node.output_tensor_names[0], target_dtype)
cast_name = oneflow.util.unique_str(node.name)
cast_name = oneflow._oneflow_internal.UniqueStr(node.name)
cast_node = ctx.InsertNewNodeOnOutput(
"Cast", node.output_tensor_names[0], name=cast_name, to=origin_dtype
)
......@@ -272,7 +272,7 @@ class ClipOps:
np_dtype = util.ONNX_2_NUMPY_DTYPE[onnx_dtype]
if min_val is not None:
clip_min = ctx.MakeConst(
oneflow.util.unique_str("{}_min".format(node.name)),
oneflow._oneflow_internal.UniqueStr("{}_min".format(node.name)),
np.array(min_val, dtype=np_dtype),
)
node.input_tensor_names.append(clip_min.output_tensor_names[0])
......@@ -280,7 +280,7 @@ class ClipOps:
node.input_tensor_names.append("")
if max_val is not None:
clip_max = ctx.MakeConst(
oneflow.util.unique_str("{}_max".format(node.name)),
oneflow._oneflow_internal.UniqueStr("{}_max".format(node.name)),
np.array(max_val, dtype=np_dtype),
)
node.input_tensor_names.append(clip_max.output_tensor_names[0])
......@@ -338,7 +338,7 @@ class Rsqrt:
@classmethod
def Version_1(cls, ctx, node, **kwargs):
node.op_type = "Sqrt"
op_name = oneflow.util.unique_str(node.name)
op_name = oneflow._oneflow_internal.UniqueStr(node.name)
reciprocal = ctx.InsertNewNodeOnOutput(
"Reciprocal", node.output_tensor_names[0], name=op_name
)
......@@ -350,7 +350,7 @@ class SquaredDifference:
@classmethod
def Version_1(cls, ctx, node, **kwargs):
node.op_type = "Sub"
op_name = oneflow.util.unique_str(node.name)
op_name = oneflow._oneflow_internal.UniqueStr(node.name)
mul = ctx.InsertNewNodeOnOutput(
"Mul", node.output_tensor_names[0], name=op_name
)
......@@ -372,7 +372,7 @@ class Sign:
raise ValueError(
"dtype " + str(node_dtype) + " is not supported in onnx for now"
)
zero_name = oneflow.util.unique_str("{}_zero".format(node.name))
zero_name = oneflow._oneflow_internal.UniqueStr("{}_zero".format(node.name))
ctx.MakeConst(zero_name, np.array(0, dtype=np.float32))
if node_dtype not in [
onnx_pb.TensorProto.FLOAT16,
......@@ -716,7 +716,7 @@ def _AddCastToOutput(graph, node):
# oneflow logical ops produce int8 tensor while onnx logical ops produce bool tensor
output = node.output_tensor_names[0]
cast_node = graph.InsertNewNodeOnOutput(
"Cast", output, oneflow.util.unique_str("cast"), to=graph.get_dtype(output)
"Cast", output, oneflow._oneflow_internal.UniqueStr("cast"), to=graph.get_dtype(output)
)
graph.CopyShape(output, node.output_tensor_names[0])
graph.set_dtype(node.output_tensor_names[0], TensorProto.BOOL)
......@@ -765,7 +765,7 @@ class Equal:
node.op_type = "Equal"
output_name = node.output_tensor_names[0]
not_node = ctx.InsertNewNodeOnOutput(
"Not", output_name, name=oneflow.util.unique_str(node.name)
"Not", output_name, name=oneflow._oneflow_internal.UniqueStr(node.name)
)
ctx.CopyShape(output_name, not_node.output_tensor_names[0])
ctx.CopyDtype(output_name, not_node.output_tensor_names[0])
......@@ -779,7 +779,7 @@ class Equal:
node.op_type = "Equal"
output_name = node.output_tensor_names[0]
not_node = ctx.InsertNewNodeOnOutput(
"Not", output_name, name=oneflow.util.unique_str(node.name)
"Not", output_name, name=oneflow._oneflow_internal.UniqueStr(node.name)
)
ctx.CopyShape(output_name, not_node.output_tensor_names[0])
ctx.CopyDtype(output_name, not_node.output_tensor_names[0])
......@@ -809,7 +809,7 @@ class GreaterLessEqual:
GreaterLess.Version_7(ctx, node, **kwargs)
output_name = node.output_tensor_names[0]
new_node = ctx.InsertNewNodeOnOutput(
"Not", output_name, name=oneflow.util.unique_str(node.name)
"Not", output_name, name=oneflow._oneflow_internal.UniqueStr(node.name)
)
ctx.CopyShape(output_name, new_node.output_tensor_names[0])
ctx.set_dtype(new_node.output_tensor_names[0], ctx.get_dtype(output_name))
......
......@@ -104,7 +104,7 @@ def _ConvConvertInputs(
reshape.skip_conversion = True
else:
# new reshape takes new shape as input_tensor_names[1]
shape_name = oneflow.util.unique_str(node.name)
shape_name = oneflow._oneflow_internal.UniqueStr(node.name)
ctx.MakeConst(shape_name, np.array(new_kernel_shape, dtype=np.int64))
input_name = node.input_tensor_names[1]
reshape = ctx.MakeNode("Reshape", [input_name, shape_name])
......@@ -139,7 +139,7 @@ def _ConvConvertInputs(
for idx in output_indices:
output_name = node.output_tensor_names[idx]
output_shape = ctx.get_shape(node.output_tensor_names[idx])
op_name = oneflow.util.unique_str(node.name)
op_name = oneflow._oneflow_internal.UniqueStr(node.name)
transpose = ctx.InsertNewNodeOnOutput(
"Transpose", output_name, name=op_name
)
......@@ -263,8 +263,8 @@ class ConvOp:
cls.Version_1(ctx, node, **kwargs)
@flow_op(["avg_pool_2d"], onnx_op="AveragePool")
@flow_op(["max_pool_2d"], onnx_op="MaxPool")
@flow_op(["avgpool_2d"], onnx_op="AveragePool")
@flow_op(["maxpool_2d"], onnx_op="MaxPool")
class PoolOp:
@classmethod
def Version_1(cls, ctx, node, **kwargs):
......@@ -306,6 +306,7 @@ class PoolOp:
_ConvConvertInputs(ctx, node, with_kernel=False)
@flow_op(["pad"], onnx_op="Pad")
class Pad:
@classmethod
......@@ -328,7 +329,7 @@ class Pad:
padding_before = node.attrs["padding_before"]
padding_after = node.attrs["padding_after"]
paddings = np.array(padding_before + padding_after).astype(np.int64)
padding_node = ctx.MakeConst(oneflow.util.unique_str("const"), paddings)
padding_node = ctx.MakeConst(oneflow._oneflow_internal.UniqueStr("const"), paddings)
node.input_tensor_names.append(padding_node.output_tensor_names[0])
dtype = ctx.get_dtype(node.input_tensor_names[0])
const_val = (
......@@ -337,7 +338,7 @@ class Pad:
else node.attrs["floating_constant_value"]
)
const_val = np.array(const_val).astype(util.Onnx2NumpyDtype(dtype))
const_val_node = ctx.MakeConst(oneflow.util.unique_str("const"), const_val)
const_val_node = ctx.MakeConst(oneflow._oneflow_internal.UniqueStr("const"), const_val)
node.input_tensor_names.append(const_val_node.output_tensor_names[0])
......@@ -388,7 +389,7 @@ class BatchNorm:
),
dtype=val_type,
)
new_mean_node_name = oneflow.util.unique_str(node.name)
new_mean_node_name = oneflow._oneflow_internal.UniqueStr(node.name)
ctx.MakeConst(new_mean_node_name, new_mean_value)
node.input_tensor_names[3] = new_mean_node_name
......@@ -399,7 +400,7 @@ class BatchNorm:
),
dtype=val_type,
)
new_val_node_name = oneflow.util.unique_str(node.name)
new_val_node_name = oneflow._oneflow_internal.UniqueStr(node.name)
ctx.MakeConst(new_val_node_name, new_var_value)
node.input_tensor_names[4] = new_val_node_name
......
......@@ -163,7 +163,7 @@ class FakeQuantization:
dequant_node = ctx.InsertNewNodeOnOutput(
"DequantizeLinear",
node.output_tensor_names[0],
name=oneflow.util.unique_str(node.name),
name=oneflow._oneflow_internal.UniqueStr(node.name),
)
if opset < 13:
scale_shape = ctx.get_shape(node.input_tensor_names[1])
......
......@@ -80,7 +80,7 @@ class ArgMax:
if ctx.get_dtype(node.output_tensor_names[0]) == onnx_pb.TensorProto.INT32:
# current node will return int64 after conversion, which differs from the previous dtype obtained from oneflow
ctx.set_dtype(node.output_tensor_names[0], onnx_pb.TensorProto.INT64)
op_name = oneflow.util.unique_str("Cast")
op_name = oneflow._oneflow_internal.UniqueStr("Cast")
cast_node = ctx.InsertNewNodeOnOutput(
"Cast",
node.output_tensor_names[0],
......
......@@ -111,7 +111,7 @@ class ConstFoldOptimizer(GraphOptimizerBase):
"length of node outputs and const vals should be same",
)
for old_input, val in zip(node.output_tensor_names, vals):
const_node = graph.MakeConst(oneflow.util.unique_str("const_fold_opt"), val)
const_node = graph.MakeConst(oneflow._oneflow_internal.UniqueStr("const_fold_opt"), val)
graph.set_dtype(
const_node.output_tensor_names[0], util.Numpy2OnnxDtype(val.dtype)
)
......
......@@ -106,7 +106,7 @@ class LoopOptimizer(GraphOptimizerBase):
new_perm = [0] + [
i + 1 for i in ori_perm
] # the body output's rank is m, while the rank of the loop's output is m+1
name = oneflow.util.unique_str("trans_moved_from_loop_body")
name = oneflow._oneflow_internal.UniqueStr("trans_moved_from_loop_body")
_ = parent_graph.InsertNewNodeOnOutput(
"Transpose", name_in_parent, name, perm=new_perm
)
......
......@@ -120,7 +120,7 @@ class TransposeOptimizer(GraphOptimizerBase):
input_shape[1],
]
return graph.MakeConst(
oneflow.util.unique_str("new_shape"), np.array(new_shape, dtype=np.int64)
oneflow._oneflow_internal.UniqueStr("new_shape"), np.array(new_shape, dtype=np.int64)
).output_tensor_names[0]
# reshape requires that the output shape contain at most one -1; if not, some extra ops are needed.
......@@ -129,11 +129,11 @@ class TransposeOptimizer(GraphOptimizerBase):
).output_tensor_names[0]
if IsNchwTranspose(op):
indice = graph.MakeConst(
oneflow.util.unique_str("indice"), np.array(NHWC_TO_NCHW)
oneflow._oneflow_internal.UniqueStr("indice"), np.array(NHWC_TO_NCHW)
).output_tensor_names[0]
else:
indice = graph.MakeConst(
oneflow.util.unique_str("indice"), np.array(NCHW_TO_NHWC)
oneflow._oneflow_internal.UniqueStr("indice"), np.array(NCHW_TO_NHWC)
).output_tensor_names[0]
return graph.MakeNode("Gather", [input_shape, indice]).output_tensor_names[
......@@ -449,7 +449,7 @@ class TransposeOptimizer(GraphOptimizerBase):
# for example, if the shape of n is [x, y], then the output shape of reshape will be [1, 1, x, y]
if shape is None:
const_4 = self._g.MakeConst(
oneflow.util.unique_str("const_4"), np.array([4], np.int64)
oneflow._oneflow_internal.UniqueStr("const_4"), np.array([4], np.int64)
).output_tensor_names[0]
tensor_1 = onnx.helper.make_tensor(
"value", onnx.TensorProto.INT64, [1], [1]
......@@ -480,7 +480,7 @@ class TransposeOptimizer(GraphOptimizerBase):
if shape_4d is None:
return False
const = self._g.MakeConst(
oneflow.util.unique_str("reshape_shape"), np.array(shape_4d, np.int64)
oneflow._oneflow_internal.UniqueStr("reshape_shape"), np.array(shape_4d, np.int64)
).output_tensor_names[0]
reshape = self._g.MakeNode(
"Reshape", [input_id, const]
......@@ -542,7 +542,7 @@ class TransposeOptimizer(GraphOptimizerBase):
]
conv_node = self._g.MakeNode(t_p.op_type, conv_inputs, attr=t_p.attrs)
ops = self._g.get_nodes()
trans.input_tensor_names[0] = oneflow.util.unique_str(conv_node.name)
trans.input_tensor_names[0] = oneflow._oneflow_internal.UniqueStr(conv_node.name)
self._g.ReplaceAllInputs(
ops, node.output_tensor_names[0], trans.output_tensor_names[0]
)
......@@ -771,7 +771,7 @@ class TransposeOptimizer(GraphOptimizerBase):
node.input_nodes[3].set_tensor_value(new_axes)
else:
new_axes_const = self._g.MakeConst(
oneflow.util.unique_str(node.input_nodes[3].name), new_axes
oneflow._oneflow_internal.UniqueStr(node.input_nodes[3].name), new_axes
)
self._g.ReplaceAllInputs(
node,
......@@ -795,7 +795,7 @@ class TransposeOptimizer(GraphOptimizerBase):
self._g.RemoveNode(node.name)
shape_node = self._g.MakeNode("Shape", [trans.input_tensor_names[0]])
const_node = self._g.MakeConst(
oneflow.util.unique_str("Const"), np.array(trans.attrs["perm"])
oneflow._oneflow_internal.UniqueStr("Const"), np.array(trans.attrs["perm"])
)
gather_node = self._g.MakeNode(
"Gather",
......
......@@ -59,43 +59,25 @@ def run_onnx(
def export_onnx_model(
job_func,
graph,
external_data=False,
opset=None,
flow_weight_dir=None,
onnx_model_path="/tmp",
dynamic_batch_size=False,
):
if flow_weight_dir == None:
flow_weight_dir = tempfile.TemporaryDirectory()
flow.checkpoint.save(flow_weight_dir.name)
# TODO(daquexian): a more elegant way?
while not os.path.exists(os.path.join(flow_weight_dir.name, "snapshot_done")):
pass
onnx_model_dir = onnx_model_path
onnx_model_path = os.path.join(onnx_model_dir, "model.onnx")
Export(
job_func,
flow_weight_dir.name,
onnx_model_path,
opset=opset,
external_data=external_data,
dynamic_batch_size=dynamic_batch_size,
)
flow_weight_dir.cleanup()
else:
while not os.path.exists(os.path.join(flow_weight_dir, "snapshot_done")):
pass
onnx_model_dir = onnx_model_path
onnx_model_path = os.path.join(onnx_model_dir, "model.onnx")
Export(
job_func,
flow_weight_dir,
onnx_model_path,
opset=opset,
external_data=external_data,
dynamic_batch_size=dynamic_batch_size,
)
while not os.path.exists(os.path.join(flow_weight_dir, "snapshot_done")):
pass
onnx_model_dir = onnx_model_path
onnx_model_path = os.path.join(onnx_model_dir, "model.onnx")
Export(
graph,
flow_weight_dir,
onnx_model_path,
opset=opset,
external_data=external_data,
dynamic_batch_size=dynamic_batch_size,
)
def cleanup():
if os.path.exists(onnx_model_path):
......@@ -121,7 +103,7 @@ def compare_result(
def convert_to_onnx_and_check(
job_func,
graph,
print_outlier=False,
explicit_init=False,
external_data=False,
......@@ -131,21 +113,8 @@ def convert_to_onnx_and_check(
onnx_model_path="/tmp",
dynamic_batch_size=False,
):
if explicit_init:
# it is a trick to keep check_point.save() from hanging when there is no variable
@flow.global_function()
def add_var():
return flow.get_variable(
name="trick",
shape=(1,),
dtype=flow.float,
initializer=flow.random_uniform_initializer(),
)
flow.train.CheckPoint().init()
onnx_model_path, cleanup = export_onnx_model(
job_func, external_data, opset, flow_weight_dir, onnx_model_path, dynamic_batch_size
graph, external_data, opset, flow_weight_dir, onnx_model_path, dynamic_batch_size
)
......@@ -153,7 +122,7 @@ def convert_to_onnx_and_check(
ipt_dict, onnx_res = run_onnx(
onnx_model_path, ["CPUExecutionProvider"], ort_optimize=ort_optimize
)
oneflow_res = job_func(*ipt_dict.values())
oneflow_res = graph(*ipt_dict.values())
if not isinstance(oneflow_res, np.ndarray):
oneflow_res = oneflow_res.get().numpy()
......
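With export_onnx_model and convert_to_onnx_and_check now taking an oneflow.nn.Graph, the end-to-end check mirrors the updated AlexNet test above: the helper waits for the snapshot_done marker in flow_weight_dir, exports model.onnx into onnx_model_path, runs it with onnxruntime's CPUExecutionProvider, and compares the result against graph(*inputs). A hedged sketch, reusing the graph from the Export sketch earlier (paths are illustrative):

from oneflow_onnx.oneflow2onnx.util import convert_to_onnx_and_check

# graph wraps weights assumed to be saved under /tmp/alexnet, which must also
# contain the snapshot_done marker the helper polls for before exporting.
convert_to_onnx_and_check(
    graph,
    flow_weight_dir="/tmp/alexnet",
    onnx_model_path="/tmp",   # model.onnx is written into this directory
)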
......@@ -427,7 +427,7 @@ class Graph(object):
# add identity node after each output, in case it is renamed during conversion.
for o in self.outputs:
n = self.get_node_by_output_in_current_graph(o)
new_output_name = oneflow.util.unique_str(n.name + "_raw_output")
new_output_name = oneflow._oneflow_internal.UniqueStr(n.name + "_raw_output")
n_shapes = n.output_shapes
n_dtypes = n.output_dtypes
body_graphs = n.graph.contained_graphs.pop(n.name, None)
......@@ -530,7 +530,7 @@ class Graph(object):
dtypes = []
if name is None:
name = oneflow.util.unique_str(op_type)
name = oneflow._oneflow_internal.UniqueStr(op_type)
if op_name_scope:
name = "_".join([op_name_scope, name])
......@@ -859,7 +859,7 @@ class Graph(object):
tensor_name = node.output_tensor_names[0]
# TODO(daquexian): node.output_tensor_names[0] is "node_name/output_name", so this pathjoin doesn't work
# on windows (where path separator is "\")
path = pathjoin(self._model_save_dir, node.output_tensor_names[0])
path = pathjoin(self._model_save_dir, node.output_tensor_names[0][2:])
tensor_value = np.fromfile(
path, dtype=util.Onnx2NumpyDtype(self.get_dtype(tensor_name))
).reshape(self.get_shape(tensor_name))
......@@ -1112,7 +1112,7 @@ class Graph(object):
space,
node.op_type,
node.name,
self.get_shape(oneflow.util.unique_str(node.name)),
self.get_shape(oneflow._oneflow_internal.UniqueStr(node.name)),
)
)
space += " "
......@@ -1167,8 +1167,8 @@ class Graph(object):
node that was inserted
"""
if name is None:
name = oneflow.util.unique_str(node.name)
new_output = oneflow.util.unique_str(name)
name = oneflow._oneflow_internal.UniqueStr(node.name)
new_output = oneflow._oneflow_internal.UniqueStr(name)
if not isinstance(input_name, list):
input_name = [input_name]
......@@ -1208,7 +1208,7 @@ class Graph(object):
type(op_type),
)
new_output = oneflow.util.unique_str(name)
new_output = oneflow._oneflow_internal.UniqueStr(name)
new_node = self.MakeNode(
op_type,
[output_name],
......
......@@ -136,7 +136,7 @@ def MakeOnnxShape(shape):
"""shape with -1 is not valid in onnx ... make it a name."""
if shape:
# don't do this if input is a scalar
return [oneflow.util.unique_str("unk") if i == -1 else i for i in shape]
return [oneflow._oneflow_internal.UniqueStr("unk") if i == -1 else i for i in shape]
return shape
......@@ -217,7 +217,7 @@ def TensorProtoFromNumpy(
arr: np.ndarray, name=None, external_data=False, export_path=None
):
if name is None:
name = oneflow.util.unique_str("tensor_")
name = oneflow._oneflow_internal.UniqueStr("tensor_")
tp = numpy_helper.from_array(arr, name)
# value with size < 1024 bytes will remain in .onnx file
# (like what pytorch does)
......
......@@ -24,7 +24,7 @@ long_description += "Email: zhangxiaoyu@oneflow.org"
setuptools.setup(
name="oneflow_onnx",
version="0.3.4",
version="0.5.0.rc",
author="zhangxiaoyu",
author_email="zhangxiaoyu@oneflow.org",
description="a toolkit for converting trained model of OneFlow to ONNX and ONNX to OneFlow.",
......