未验证 提交 cf66d953 编写于 作者: 走神的阿圆's avatar 走神的阿圆 提交者: GitHub

support float32 and double64 of ndarray

上级 e5150df6
......@@ -60,7 +60,7 @@ def imgarray2bytes(np_array):
"""Convert image ndarray to bytes.
Args:
np_array (numpy.ndarray): Array to converte.
np_array (np.ndarray): Array to converte.
Returns:
Binary bytes of np_array.
......@@ -106,7 +106,7 @@ def convert_to_HWC(tensor, input_format):
"""Convert `NCHW`, `HWC`, `HW` to `HWC`
Args:
tensor (numpy.ndarray): Value of image
tensor (np.ndarray): Value of image
input_format (string): Format of image
Return:
......@@ -138,12 +138,26 @@ def convert_to_HWC(tensor, input_format):
return tensor
def denormalization(image_array):
"""Renormalise ndarray matrix.
Args:
image_array(np.ndarray): Value of image
Return:
Matrix after renormalising.
"""
if image_array.max() <= 1 and image_array.min() >= 0:
image_array *= 255
return image_array.astype(np.uint8)
def image(tag, image_array, step, walltime=None, dataformats="HWC"):
"""Package data to one image.
Args:
tag (string): Data identifier
image_array (numpy.ndarray): Value of iamge
image_array (np.ndarray): Value of image
step (int): Step of image
walltime (int): Wall time of image
dataformats (string): Format of image
......@@ -151,6 +165,7 @@ def image(tag, image_array, step, walltime=None, dataformats="HWC"):
Return:
Package with format of record_pb2.Record
"""
image_array = denormalization(image_array)
image_array = convert_to_HWC(image_array, dataformats)
image_bytes = imgarray2bytes(image_array)
image = Record.Image(encoded_image_string=image_bytes)
......@@ -165,7 +180,7 @@ def embedding(tag, labels, hot_vectors, step, labels_meta=None, walltime=None):
Args:
tag (string): Data identifier
labels (list): A list of labels.
hot_vectors (numpy.array or list): A matrix which each row is
hot_vectors (np.array or list): A matrix which each row is
feature of labels.
step (int): Step of embeddings.
walltime (int): Wall time of embeddings.
......@@ -199,7 +214,7 @@ def audio(tag, audio_array, sample_rate, step, walltime):
Args:
tag (string): Data identifier
audio_array (numpy.ndarray or list): audio represented by a numpy.array
audio_array (np.ndarray or list): audio represented by a np.array
sample_rate (int): Sample rate of audio
step (int): Step of audio
walltime (int): Wall time of audio
......@@ -246,8 +261,8 @@ def histogram(tag, hist, bin_edges, step, walltime):
Args:
tag (string): Data identifier
hist (numpy.ndarray or list): The values of the histogram
bin_edges (numpy.ndarray or list): The bin edges
hist (np.ndarray or list): The values of the histogram
bin_edges (np.ndarray or list): The bin edges
step (int): Step of histogram
walltime (int): Wall time of histogram
......@@ -265,8 +280,8 @@ def compute_curve(labels, predictions, num_thresholds=None, weights=None):
""" Compute precision-recall curve data by labels and predictions.
Args:
labels (numpy.ndarray or list): Binary labels for each element.
predictions (numpy.ndarray or list): The probability that an element be
labels (np.ndarray or list): Binary labels for each element.
predictions (np.ndarray or list): The probability that an element be
classified as true.
num_thresholds (int): Number of thresholds used to draw the curve.
weights (float): Multiple of data to display on the curve.
......@@ -318,8 +333,8 @@ def pr_curve(tag, labels, predictions, step, walltime, num_thresholds=127,
Args:
tag (string): Data identifier
labels (numpy.ndarray or list): Binary labels for each element.
predictions (numpy.ndarray or list): The probability that an element be
labels (np.ndarray or list): Binary labels for each element.
predictions (np.ndarray or list): The probability that an element be
classified as true.
step (int): Step of pr_curve
walltime (int): Wall time of pr_curve
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册