From cf66d9530cca1b8582a1a170bb569317c817be2a Mon Sep 17 00:00:00 2001 From: ShenYuhan Date: Wed, 23 Dec 2020 15:02:15 +0800 Subject: [PATCH] support float32 and double64 of ndarray --- visualdl/component/base_component.py | 37 +++++++++++++++++++--------- 1 file changed, 26 insertions(+), 11 deletions(-) diff --git a/visualdl/component/base_component.py b/visualdl/component/base_component.py index 7cbfec22..594e43fd 100644 --- a/visualdl/component/base_component.py +++ b/visualdl/component/base_component.py @@ -60,7 +60,7 @@ def imgarray2bytes(np_array): """Convert image ndarray to bytes. Args: - np_array (numpy.ndarray): Array to converte. + np_array (np.ndarray): Array to converte. Returns: Binary bytes of np_array. @@ -106,7 +106,7 @@ def convert_to_HWC(tensor, input_format): """Convert `NCHW`, `HWC`, `HW` to `HWC` Args: - tensor (numpy.ndarray): Value of image + tensor (np.ndarray): Value of image input_format (string): Format of image Return: @@ -138,12 +138,26 @@ def convert_to_HWC(tensor, input_format): return tensor +def denormalization(image_array): + """Renormalise ndarray matrix. + + Args: + image_array(np.ndarray): Value of image + + Return: + Matrix after renormalising. + """ + if image_array.max() <= 1 and image_array.min() >= 0: + image_array *= 255 + return image_array.astype(np.uint8) + + def image(tag, image_array, step, walltime=None, dataformats="HWC"): """Package data to one image. Args: tag (string): Data identifier - image_array (numpy.ndarray): Value of iamge + image_array (np.ndarray): Value of image step (int): Step of image walltime (int): Wall time of image dataformats (string): Format of image @@ -151,6 +165,7 @@ def image(tag, image_array, step, walltime=None, dataformats="HWC"): Return: Package with format of record_pb2.Record """ + image_array = denormalization(image_array) image_array = convert_to_HWC(image_array, dataformats) image_bytes = imgarray2bytes(image_array) image = Record.Image(encoded_image_string=image_bytes) @@ -165,7 +180,7 @@ def embedding(tag, labels, hot_vectors, step, labels_meta=None, walltime=None): Args: tag (string): Data identifier labels (list): A list of labels. - hot_vectors (numpy.array or list): A matrix which each row is + hot_vectors (np.array or list): A matrix which each row is feature of labels. step (int): Step of embeddings. walltime (int): Wall time of embeddings. @@ -199,7 +214,7 @@ def audio(tag, audio_array, sample_rate, step, walltime): Args: tag (string): Data identifier - audio_array (numpy.ndarray or list): audio represented by a numpy.array + audio_array (np.ndarray or list): audio represented by a np.array sample_rate (int): Sample rate of audio step (int): Step of audio walltime (int): Wall time of audio @@ -246,8 +261,8 @@ def histogram(tag, hist, bin_edges, step, walltime): Args: tag (string): Data identifier - hist (numpy.ndarray or list): The values of the histogram - bin_edges (numpy.ndarray or list): The bin edges + hist (np.ndarray or list): The values of the histogram + bin_edges (np.ndarray or list): The bin edges step (int): Step of histogram walltime (int): Wall time of histogram @@ -265,8 +280,8 @@ def compute_curve(labels, predictions, num_thresholds=None, weights=None): """ Compute precision-recall curve data by labels and predictions. Args: - labels (numpy.ndarray or list): Binary labels for each element. - predictions (numpy.ndarray or list): The probability that an element be + labels (np.ndarray or list): Binary labels for each element. + predictions (np.ndarray or list): The probability that an element be classified as true. num_thresholds (int): Number of thresholds used to draw the curve. weights (float): Multiple of data to display on the curve. @@ -318,8 +333,8 @@ def pr_curve(tag, labels, predictions, step, walltime, num_thresholds=127, Args: tag (string): Data identifier - labels (numpy.ndarray or list): Binary labels for each element. - predictions (numpy.ndarray or list): The probability that an element be + labels (np.ndarray or list): Binary labels for each element. + predictions (np.ndarray or list): The probability that an element be classified as true. step (int): Step of pr_curve walltime (int): Wall time of pr_curve -- GitLab