diff --git a/labml_nn/gan/cycle_gan.py b/labml_nn/gan/cycle_gan.py index dff94354eb9495d86598c7764a9c0b84ddd0de5f..ac6ac51a11bd4d5c5c4e6931b1541790a9e3a382 100644 --- a/labml_nn/gan/cycle_gan.py +++ b/labml_nn/gan/cycle_gan.py @@ -652,7 +652,7 @@ def sample(): conf_dict = experiment.load_configs(trained_run_uuid) # Calculate configurations. We specify the generators `'generator_xy', 'generator_yx'` # so that it only loads those and their dependencies. - # Configs like `device`, `img_height` and `img_width` will be calculated since these are required by + # Configs like `device` and `img_channels` will be calculated since these are required by # `generator_xy` and `generator_yx`. # # If you want other parameters like `dataset_name` you should specify them here. @@ -684,7 +684,7 @@ def sample(): # Load dataset dataset = ImageDataset(images_path, transforms_, True, 'train') # Get an images from dataset - x_image = dataset[0]['x'] + x_image = dataset[10]['x'] # Display the image plot_image(x_image) @@ -698,10 +698,6 @@ def sample(): data = x_image.unsqueeze(0).to(conf.device) generated_y = conf.generator_xy(data) - # Normalize the image - mm_range = generated_y.min(), generated_y.max() - generated_y = (generated_y - mm_range[0]) / (mm_range[1] - mm_range[0] + 1e-5) - # Display the generated image. plot_image(generated_y[0].cpu())