提交 d16aeef6 编写于 作者: A Alexander Alekhin

Merge pull request #986 from zihaomu:bug_fix_22195_3_4

......@@ -90,6 +90,23 @@ def save_data_and_onnx_model(name, input_np, output_np, onnx_model):
with open(models_files, 'wb') as file:
file.write(model_def.SerializeToString())
def save_data_and_onnx_model_multy_inputs(name, input_list, output_np, onnx_model):
for index in range(len(input_list)):
print(name + " input "+str(index)+" has sizes", input_list[index].shape)
input_files = os.path.join("data", "input_" + name + "_" + str(index))
np.save(input_files, input_list[index])
print(name + " output has sizes", output_np.shape)
print()
output_files = os.path.join("data", "output_" + name)
np.save(output_files, np.ascontiguousarray(output_np.data))
models_files = os.path.join("models", name + ".onnx")
onnx_model_pb = onnx._serialize(onnx_model)
model_def = assertONNXExpected(onnx_model_pb)
with open(models_files, 'wb') as file:
file.write(model_def.SerializeToString())
def simplify(name, rename=False, **kwargs):
model, check = onnxsim.simplify(name, **kwargs)
......@@ -1725,4 +1742,23 @@ graph2 = onnx.helper.make_graph(nodes2,
outputs, initializer=[weight_tensor])
gemm_model2 = onnx.helper.make_model(graph2)
output_np = gemm_reference_implementation(input_np, weight_np)
save_data_and_onnx_model("gemm_transB_0", input_np, output_np, gemm_model2)
\ No newline at end of file
save_data_and_onnx_model("gemm_transB_0", input_np, output_np, gemm_model2)
# ########################## DivBroadcast ##########################
input_np = np.random.rand(1, 4).astype("float32")
input2_np = np.random.rand(1, 1).astype(np.float32)
inputs = [onnx.helper.make_tensor_value_info("input1", onnx.mapping.NP_TYPE_TO_TENSOR_TYPE[input_np.dtype], shape=input_np.shape), \
onnx.helper.make_tensor_value_info("input2", onnx.mapping.NP_TYPE_TO_TENSOR_TYPE[input2_np.dtype], shape=input2_np.shape)]
outputs = [onnx.helper.make_tensor_value_info("output", onnx.TensorProto.FLOAT, shape=(1, 4))]
nodes = [onnx.helper.make_node("Div", ["input1", "input2"], ["output"])]
graph = onnx.helper.make_graph(nodes,
"div_test",
inputs,
outputs)
onnx_model = onnx.helper.make_model(graph)
output_np = input_np/input2_np
save_data_and_onnx_model_multy_inputs("div_test_1x1", [input_np, input2_np], output_np, onnx_model)
\ No newline at end of file
:w

input1
input2output"Divdiv_testZ
input1


Z
input2


b
output


B
\ No newline at end of file
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册