123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239 |
- import math
- import os
- import hashlib
- from urllib.request import urlretrieve
- import zipfile
- import gzip
- import shutil
- import numpy as np
- from PIL import Image
- from tqdm import tqdm
def _read32(bytestream):
    """
    Read one big-endian unsigned 32-bit integer from a bytestream
    :param bytestream: A file-like object whose read(4) returns the next four bytes
    :return: The decoded value as a NumPy unsigned 32-bit scalar
    """
    raw = bytestream.read(4)
    # '>u4' is the big-endian ("network order") unsigned 4-byte integer dtype
    big_endian_u4 = np.dtype('>u4')
    return np.frombuffer(raw, dtype=big_endian_u4)[0]
def _unzip(save_path, _, database_name, data_path):
    """
    Extract every member of a zip archive into data_path
    :param save_path: The path of the zip file
    :param _: HACK - Unused; keeps the signature interchangeable with _ungzip
    :param database_name: Name of database (only used in the printed message)
    :param data_path: Path to extract to
    """
    message = 'Extracting {}...'.format(database_name)
    print(message)

    with zipfile.ZipFile(save_path) as archive:
        archive.extractall(path=data_path)
def _ungzip(save_path, extract_path, database_name, _):
    """
    Decode a gzipped image file and save every image it contains to extract_path
    The decompressed payload is read as four 32-bit header fields (magic number, image
    count, rows, columns) followed by one unsigned byte per pixel.
    :param save_path: The path of the gzip file
    :param extract_path: The location to extract the data to (image_<index>.jpg files)
    :param database_name: Name of database (used in the progress bar description)
    :param _: HACK - Unused; keeps the signature interchangeable with _unzip
    :raises ValueError: If the magic number is not 2051 or the pixel data is shorter than
                        the header announces
    """
    # Get data from save_path
    with open(save_path, 'rb') as f:
        with gzip.GzipFile(fileobj=f) as bytestream:
            magic = int(_read32(bytestream))
            if magic != 2051:
                raise ValueError('Invalid magic number {} in file: {}'.format(magic, f.name))

            # _read32 yields fixed-width 32-bit scalars; convert to Python ints so the
            # size product below cannot wrap around for large datasets.
            num_images = int(_read32(bytestream))
            rows = int(_read32(bytestream))
            cols = int(_read32(bytestream))

            expected_size = rows * cols * num_images
            buf = bytestream.read(expected_size)

    # Fail with a clear message instead of an obscure reshape error on truncated files
    if len(buf) != expected_size:
        raise ValueError(
            'Expected {} bytes of image data but found {} in file: {}'.format(
                expected_size, len(buf), save_path))

    data = np.frombuffer(buf, dtype=np.uint8).reshape(num_images, rows, cols)

    # Save data to extract_path
    progress = tqdm(
        data, unit='File', unit_scale=True, miniters=1, desc='Extracting {}'.format(database_name))
    for image_i, image in enumerate(progress):
        Image.fromarray(image, 'L').save(os.path.join(extract_path, 'image_{}.jpg'.format(image_i)))
def get_image(image_path, width, height, mode):
    """
    Load an image and return its pixels, forcing a width x height size when it differs
    :param image_path: Path of image
    :param width: Width of image
    :param height: Height of image
    :param mode: Mode of image (passed to PIL's convert)
    :return: Image data as a NumPy array
    """
    picture = Image.open(image_path)
    target_size = (width, height)

    # HACK - A size mismatch is taken to mean the image is from the CELEBA dataset
    if picture.size != target_size:
        # Remove most pixels that aren't part of a face: keep a centered 108x108 box
        face_width = face_height = 108
        left = (picture.size[0] - face_width) // 2
        top = (picture.size[1] - face_height) // 2
        face_box = [left, top, left + face_width, top + face_height]
        picture = picture.crop(face_box).resize([width, height], Image.BILINEAR)

    converted = picture.convert(mode)
    return np.array(converted)
def get_batch(image_files, width, height, mode):
    """
    Load a list of image files into a single float32 array
    :param image_files: Paths of the images to load
    :param width: Width passed to get_image for every file
    :param height: Height passed to get_image for every file
    :param mode: Mode passed to get_image for every file
    :return: float32 array of the stacked images, with a trailing axis of length 1
             appended when the stacked result has fewer than 4 dimensions
    """
    frames = [get_image(path, width, height, mode) for path in image_files]
    batch = np.array(frames).astype(np.float32)

    # Make sure the images are in 4 dimensions
    if batch.ndim < 4:
        batch = np.expand_dims(batch, axis=-1)

    return batch
def images_square_grid(images, mode):
    """
    Arrange images into the largest square grid that fits and return it as one image
    :param images: 4-dimensional NumPy array of images, indexed as
                   (image, row, column, channel)
    :param mode: The mode to use for images ('L' expects a single channel)
    :return: Image of images in a square grid; images that do not fit in the
             save_size x save_size grid are left out
    """
    # Get maximum size for square grid of images
    save_size = math.floor(np.sqrt(images.shape[0]))

    # Scale to 0-255
    if not np.issubdtype(images.dtype, np.floating):
        # Integer inputs would overflow in the "* 255" step below
        images = images.astype(np.float64)
    lowest = images.min()
    value_range = images.max() - lowest
    if value_range == 0:
        # Every pixel has the same value: avoid dividing by zero and render a black grid
        images = np.zeros(images.shape, dtype=np.uint8)
    else:
        images = (((images - lowest) * 255) / value_range).astype(np.uint8)

    # Array axes are (row, column) but PIL sizes and positions are (x, y) = (column, row)
    image_height, image_width = images.shape[1], images.shape[2]

    # Put images in a square arrangement
    images_in_square = np.reshape(
        images[:save_size*save_size],
        (save_size, save_size, image_height, image_width, images.shape[3]))
    if mode == 'L':
        images_in_square = np.squeeze(images_in_square, 4)

    # Combine images to grid image
    new_im = Image.new(mode, (image_width * save_size, image_height * save_size))
    for col_i, col_images in enumerate(images_in_square):
        for image_i, image in enumerate(col_images):
            im = Image.fromarray(image, mode)
            new_im.paste(im, (col_i * image_width, image_i * image_height))

    return new_im
def _md5_hexdigest(file_path, chunk_size=1 << 20):
    """
    Compute the MD5 hex digest of a file, reading it in chunks to bound memory use
    :param file_path: Path of the file to hash
    :param chunk_size: Number of bytes to read per chunk
    :return: Hexadecimal MD5 digest string
    """
    digest = hashlib.md5()
    with open(file_path, 'rb') as f:
        for chunk in iter(lambda: f.read(chunk_size), b''):
            digest.update(chunk)
    return digest.hexdigest()


def download_extract(database_name, data_path):
    """
    Download and extract database
    Does nothing except print a message when the extraction folder already exists.
    Otherwise downloads the archive (unless it is already at its save path), checks
    its MD5 hash, extracts it and finally deletes the archive.
    :param database_name: Database name (DATASET_CELEBA_NAME or DATASET_MNIST_NAME)
    :param data_path: Directory that holds the downloaded archive and the extracted data
    :raises ValueError: If database_name is not one of the known databases
    :raises AssertionError: If the downloaded archive does not match the expected hash
    """
    if database_name == DATASET_CELEBA_NAME:
        url = 'https://s3-us-west-1.amazonaws.com/udacity-dlnfd/datasets/celeba.zip'
        hash_code = '00d2c5bc6d35e252742224ab0c1e8fcb'
        extract_path = os.path.join(data_path, 'img_align_celeba')
        save_path = os.path.join(data_path, 'celeba.zip')
        extract_fn = _unzip
    elif database_name == DATASET_MNIST_NAME:
        url = 'http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz'
        hash_code = 'f68b3c2dcbeaaa9fbdd348bbdeb94873'
        extract_path = os.path.join(data_path, 'mnist')
        save_path = os.path.join(data_path, 'train-images-idx3-ubyte.gz')
        extract_fn = _ungzip
    else:
        # Without this branch an unknown name fails later with an UnboundLocalError
        raise ValueError('Unknown database name: {}'.format(database_name))

    if os.path.exists(extract_path):
        print('Found {} Data'.format(database_name))
        return

    os.makedirs(data_path, exist_ok=True)

    if not os.path.exists(save_path):
        with DLProgress(unit='B', unit_scale=True, miniters=1, desc='Downloading {}'.format(database_name)) as pbar:
            urlretrieve(
                url,
                save_path,
                pbar.hook)

    # Raised explicitly instead of through an assert statement so the check still runs
    # under "python -O"; AssertionError is kept so callers see the same exception type.
    if _md5_hexdigest(save_path) != hash_code:
        raise AssertionError(
            '{} file is corrupted. Remove the file and try again.'.format(save_path))

    os.makedirs(extract_path)
    try:
        extract_fn(save_path, extract_path, database_name, data_path)
    except BaseException:
        # Remove extraction folder if there is an error or an interrupt; a partial
        # folder would otherwise be reported as "Found" on the next call.
        shutil.rmtree(extract_path)
        raise

    # Remove compressed data
    os.remove(save_path)
class Dataset(object):
    """
    Collection of image files for one database, served as scaled batches
    """

    def __init__(self, dataset_name, data_files):
        """
        Initialize the class
        :param dataset_name: Database name (DATASET_CELEBA_NAME or DATASET_MNIST_NAME)
        :param data_files: List of files in the database
        :raises ValueError: If dataset_name is not one of the known databases
        """
        if dataset_name == DATASET_CELEBA_NAME:
            self.image_mode = 'RGB'
            image_channels = 3
        elif dataset_name == DATASET_MNIST_NAME:
            self.image_mode = 'L'
            image_channels = 1
        else:
            # Without this branch an unknown name fails below with an UnboundLocalError
            raise ValueError('Unknown dataset name: {}'.format(dataset_name))

        self.data_files = data_files
        # (number of files, image width, image height, channels)
        self.shape = len(data_files), IMAGE_WIDTH, IMAGE_HEIGHT, image_channels

    def get_batches(self, batch_size):
        """
        Generate batches
        Only full batches are produced; trailing files that do not fill a whole batch
        are skipped. Each batch is divided by IMAGE_MAX_VALUE and shifted by -0.5.
        :param batch_size: Batch Size (must be at least 1)
        :return: Batches of data
        :raises ValueError: When iteration starts, if batch_size is smaller than 1
        """
        # A batch size of zero or less would never advance past the end of the data
        if batch_size < 1:
            raise ValueError('batch_size must be at least 1, got {}'.format(batch_size))

        current_index = 0
        while current_index + batch_size <= self.shape[0]:
            data_batch = get_batch(
                self.data_files[current_index:current_index + batch_size],
                *self.shape[1:3],
                self.image_mode)

            current_index += batch_size

            yield data_batch / IMAGE_MAX_VALUE - 0.5
class DLProgress(tqdm):
    """
    Progress bar that tracks a download through its hook method
    """
    # Block count reported by the previous hook call
    last_block = 0

    def hook(self, block_num=1, block_size=1, total_size=None):
        """
        Report hook: advance the bar by the bytes transferred since the previous call.
        Meant to be called once on establishment of the network connection and once after
        each block read thereafter.
        :param block_num: A count of blocks transferred so far
        :param block_size: Block size in bytes
        :param total_size: The total size of the file. This may be -1 on older FTP servers which do not return
                            a file size in response to a retrieval request.
        """
        blocks_since_last_call = block_num - self.last_block
        self.total = total_size
        self.update(blocks_since_last_call * block_size)
        self.last_block = block_num