提交 8d825246 编写于 作者: M Megvii Engine Team

fix(sdk/load_and_run): remove docs of dump_with_testcase_with_mge.py

GitOrigin-RevId: 4a1138cb55238a44cc988d4e1976bb488695ee55
上级 38b7cfde
...@@ -30,7 +30,7 @@ def main(): ...@@ -30,7 +30,7 @@ def main():
parser.add_argument("--input", help="mace model file") parser.add_argument("--input", help="mace model file")
parser.add_argument("--param", help="mace param file") parser.add_argument("--param", help="mace param file")
parser.add_argument( parser.add_argument(
"--output", help="converted model that can be fed to dump_with_testcase_mge.py" "--output", help="converted mge model"
) )
parser.add_argument("--config", help="config file with yaml format") parser.add_argument("--config", help="config file with yaml format")
args = parser.parse_args() args = parser.parse_args()
......
...@@ -75,8 +75,8 @@ def main(): ...@@ -75,8 +75,8 @@ def main():
for step, minibatch in enumerate(train_dataset): for step, minibatch in enumerate(train_dataset):
if step > 1000: if step > 1000:
break break
data = minibatch["data"] data = mge.tensor(minibatch["data"])
label = minibatch["label"] label = mge.tensor(minibatch["label"])
net.train() net.train()
_, loss = train_fun(data, label) _, loss = train_fun(data, label)
train_loss.append((step, loss.numpy())) train_loss.append((step, loss.numpy()))
...@@ -128,6 +128,11 @@ def main(): ...@@ -128,6 +128,11 @@ def main():
print("Dump model as {}".format(model_name)) print("Dump model as {}".format(model_name))
pred_fun.dump(model_name, arg_names=["data"]) pred_fun.dump(model_name, arg_names=["data"])
model_with_testcase_name = "xornet_with_testcase.mge"
print("Dump model with testcase as {}".format(model_with_testcase_name))
pred_fun.dump(model_with_testcase_name, arg_names=["data"], input_data=["#rand(0.1, 0.8, 4, 2)"])
if __name__ == "__main__": if __name__ == "__main__":
main() main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册