diff --git a/model.py b/model.py index e6faeb762cccd6e3fc56b99e265d86fc77691690..6fecbf1d29fa3c37ad3073fae0fcdcd819b52937 100644 --- a/model.py +++ b/model.py @@ -42,6 +42,14 @@ __all__ = ['Model', 'Loss', 'CrossEntropy', 'Input', 'set_device'] def set_device(device): + """ + Args: + device (str): specify device type, 'cpu' or 'gpu'. + + Returns: + fluid.CUDAPlace or fluid.CPUPlace: Created GPU or CPU place. + """ + assert isinstance(device, six.string_types) and device.lower() in ['cpu', 'gpu'], \ "Expected device in ['cpu', 'gpu'], but got {}".format(device) @@ -1082,7 +1090,11 @@ class Model(fluid.dygraph.Layer): return eval_result - def predict(self, test_data, batch_size=1, num_workers=0, stack_outputs=True): + def predict(self, + test_data, + batch_size=1, + num_workers=0, + stack_outputs=True): """ FIXME: add more comments and usage Args: