提交 a09a8126 编写于 作者: F Flowingsun007

modify functions with new api >> global_function

上级 8b9075ab
......@@ -7,6 +7,7 @@ import numpy as np
from PIL import Image
import config as configs
parser = configs.get_parser()
args = parser.parse_args()
configs.print_args(args)
......@@ -38,9 +39,9 @@ def load_image(image_path='test_img/ILSVRC2012_val_00020287.JPEG'):
return np.ascontiguousarray(im, 'float32')
@flow.global_function(flow.function_config())
def InferenceNet(images:tp.Numpy.Placeholder((1, 3, 224, 224), dtype=flow.float)):
logits = model_dict[args.model](images,training=False)
@flow.global_function("predict", flow.function_config())
def InferenceNet(images: tp.Numpy.Placeholder((1, 3, 224, 224), dtype=flow.float)) -> tp.Numpy:
logits = model_dict[args.model](images, training=False)
predictions = flow.nn.softmax(logits)
return predictions
......@@ -52,9 +53,9 @@ def main():
check_point.load(args.model_load_dir)
image = load_image(args.image_path)
predictions = InferenceNet(image).get()
clsidx = predictions.numpy().argmax()
print(predictions.numpy().max(), clsidx_2_labels[clsidx])
predictions = InferenceNet(image)
clsidx = predictions.argmax()
print(predictions.max(), clsidx_2_labels[clsidx])
if __name__ == "__main__":
......
......@@ -43,7 +43,7 @@ def label_smoothing(labels, classes, eta, dtype):
on_value=1 - eta + eta / classes, off_value=eta/classes)
@flow.global_function(get_train_config(args))
@flow.global_function("train", get_train_config(args))
def TrainNet():
if args.train_data_dir:
assert os.path.exists(args.train_data_dir)
......@@ -68,7 +68,7 @@ def TrainNet():
return outputs
@flow.global_function(get_val_config(args))
@flow.global_function("predict", get_val_config(args))
def InferenceNet():
if args.val_data_dir:
assert os.path.exists(args.val_data_dir)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册