helper.py

import math
import os
import hashlib
from urllib.request import urlretrieve
import zipfile
import gzip
import shutil

import numpy as np
from PIL import Image
from tqdm import tqdm


def _read32(bytestream):
    """
    Read a 32-bit integer from a bytestream
    :param bytestream: A bytestream
    :return: 32-bit integer
    """
    dt = np.dtype(np.uint32).newbyteorder('>')
    return np.frombuffer(bytestream.read(4), dtype=dt)[0]
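

# The MNIST image file is stored in IDX format: a big-endian header of four
# 32-bit integers (magic number 2051, image count, rows, columns) followed by
# the raw uint8 pixel data. _ungzip below parses that header with _read32.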


def _unzip(save_path, _, database_name, data_path):
    """
    Unzip wrapper with the same interface as _ungzip
    :param save_path: The path of the zip file
    :param _: HACK - Used to have the same interface as _ungzip
    :param database_name: Name of database
    :param data_path: Path to extract to
    """
    print('Extracting {}...'.format(database_name))
    with zipfile.ZipFile(save_path) as zf:
        zf.extractall(data_path)


def _ungzip(save_path, extract_path, database_name, _):
    """
    Unzip a gzip file and extract it to extract_path
    :param save_path: The path of the gzip file
    :param extract_path: The location to extract the data to
    :param database_name: Name of database
    :param _: HACK - Used to have the same interface as _unzip
    """
    # Get data from save_path
    with open(save_path, 'rb') as f:
        with gzip.GzipFile(fileobj=f) as bytestream:
            magic = _read32(bytestream)
            if magic != 2051:
                raise ValueError('Invalid magic number {} in file: {}'.format(magic, f.name))
            num_images = _read32(bytestream)
            rows = _read32(bytestream)
            cols = _read32(bytestream)
            buf = bytestream.read(rows * cols * num_images)
            data = np.frombuffer(buf, dtype=np.uint8)
            data = data.reshape(num_images, rows, cols)

    # Save data to extract_path
    for image_i, image in enumerate(
            tqdm(data, unit='File', unit_scale=True, miniters=1, desc='Extracting {}'.format(database_name))):
        Image.fromarray(image, 'L').save(os.path.join(extract_path, 'image_{}.jpg'.format(image_i)))


def get_image(image_path, width, height, mode):
    """
    Read image from image_path
    :param image_path: Path of image
    :param width: Width of image
    :param height: Height of image
    :param mode: Mode of image
    :return: Image data
    """
    image = Image.open(image_path)

    if image.size != (width, height):  # HACK - Check if image is from the CELEBA dataset
        # Remove most pixels that aren't part of a face
        face_width = face_height = 108
        j = (image.size[0] - face_width) // 2
        i = (image.size[1] - face_height) // 2
        image = image.crop([j, i, j + face_width, i + face_height])
        image = image.resize([width, height], Image.BILINEAR)

    return np.array(image.convert(mode))
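

# Note: the aligned CelebA images (img_align_celeba) are 178x218 pixels, so the
# 108x108 center crop in get_image keeps mostly the face region before
# downscaling to the requested width and height.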


def get_batch(image_files, width, height, mode):
    """
    Read a batch of images and return them as a 4-D float32 array
    """
    data_batch = np.array(
        [get_image(sample_file, width, height, mode) for sample_file in image_files]).astype(np.float32)

    # Make sure the images are in 4 dimensions
    if len(data_batch.shape) < 4:
        data_batch = data_batch.reshape(data_batch.shape + (1,))

    return data_batch


def images_square_grid(images, mode):
    """
    Arrange images on a square grid
    :param images: Images to be used for the grid
    :param mode: The mode to use for images
    :return: Image of images in a square grid
    """
    # Get maximum size for square grid of images
    save_size = math.floor(np.sqrt(images.shape[0]))

    # Scale to 0-255
    images = (((images - images.min()) * 255) / (images.max() - images.min())).astype(np.uint8)

    # Put images in a square arrangement
    images_in_square = np.reshape(
        images[:save_size * save_size],
        (save_size, save_size, images.shape[1], images.shape[2], images.shape[3]))
    if mode == 'L':
        images_in_square = np.squeeze(images_in_square, 4)

    # Combine images to grid image
    new_im = Image.new(mode, (images.shape[1] * save_size, images.shape[2] * save_size))
    for col_i, col_images in enumerate(images_in_square):
        for image_i, image in enumerate(col_images):
            im = Image.fromarray(image, mode)
            new_im.paste(im, (col_i * images.shape[1], image_i * images.shape[2]))

    return new_im


def download_extract(database_name, data_path):
    """
    Download and extract database
    :param database_name: Database name
    :param data_path: Directory to download and extract the data into
    """
    DATASET_CELEBA_NAME = 'celeba'
    DATASET_MNIST_NAME = 'mnist'

    if database_name == DATASET_CELEBA_NAME:
        url = 'https://s3-us-west-1.amazonaws.com/udacity-dlnfd/datasets/celeba.zip'
        hash_code = '00d2c5bc6d35e252742224ab0c1e8fcb'
        extract_path = os.path.join(data_path, 'img_align_celeba')
        save_path = os.path.join(data_path, 'celeba.zip')
        extract_fn = _unzip
    elif database_name == DATASET_MNIST_NAME:
        url = 'http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz'
        hash_code = 'f68b3c2dcbeaaa9fbdd348bbdeb94873'
        extract_path = os.path.join(data_path, 'mnist')
        save_path = os.path.join(data_path, 'train-images-idx3-ubyte.gz')
        extract_fn = _ungzip
    else:
        raise ValueError('Unknown database name: {}'.format(database_name))

    if os.path.exists(extract_path):
        print('Found {} Data'.format(database_name))
        return

    if not os.path.exists(data_path):
        os.makedirs(data_path)

    if not os.path.exists(save_path):
        with DLProgress(unit='B', unit_scale=True, miniters=1, desc='Downloading {}'.format(database_name)) as pbar:
            urlretrieve(
                url,
                save_path,
                pbar.hook)

    assert hashlib.md5(open(save_path, 'rb').read()).hexdigest() == hash_code, \
        '{} file is corrupted. Remove the file and try again.'.format(save_path)

    os.makedirs(extract_path)
    try:
        extract_fn(save_path, extract_path, database_name, data_path)
    except Exception as err:
        shutil.rmtree(extract_path)  # Remove extraction folder if there is an error
        raise err

    # Remove compressed data
    os.remove(save_path)


class Dataset(object):
    """
    Dataset
    """
    def __init__(self, dataset_name, data_files):
        """
        Initialize the class
        :param dataset_name: Database name
        :param data_files: List of files in the database
        """
        DATASET_CELEBA_NAME = 'celeba'
        DATASET_MNIST_NAME = 'mnist'
        IMAGE_WIDTH = 28
        IMAGE_HEIGHT = 28

        if dataset_name == DATASET_CELEBA_NAME:
            self.image_mode = 'RGB'
            image_channels = 3
        elif dataset_name == DATASET_MNIST_NAME:
            self.image_mode = 'L'
            image_channels = 1

        self.data_files = data_files
        self.shape = len(data_files), IMAGE_WIDTH, IMAGE_HEIGHT, image_channels

    def get_batches(self, batch_size):
        """
        Generate batches
        :param batch_size: Batch Size
        :return: Batches of data
        """
        IMAGE_MAX_VALUE = 255

        current_index = 0
        while current_index + batch_size <= self.shape[0]:
            data_batch = get_batch(
                self.data_files[current_index:current_index + batch_size],
                *self.shape[1:3],
                self.image_mode)

            current_index += batch_size

            # Scale pixel values from [0, 255] to [-0.5, 0.5]
            yield data_batch / IMAGE_MAX_VALUE - 0.5


class DLProgress(tqdm):
    """
    Handle Progress Bar while Downloading
    """
    last_block = 0

    def hook(self, block_num=1, block_size=1, total_size=None):
        """
        A hook function that will be called once on establishment of the network connection and
        once after each block read thereafter.
        :param block_num: A count of blocks transferred so far
        :param block_size: Block size in bytes
        :param total_size: The total size of the file. This may be -1 on older FTP servers which do not return
                           a file size in response to a retrieval request.
        """
        self.total = total_size
        self.update((block_num - self.last_block) * block_size)
        self.last_block = block_num
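

# A minimal usage sketch, assuming MNIST is wanted under './data' (the path,
# the batch size, and the use of glob are illustrative assumptions):
if __name__ == '__main__':
    import glob

    # Download, verify, and extract MNIST, then wrap the images in a Dataset
    download_extract('mnist', './data')
    mnist_files = glob.glob(os.path.join('./data', 'mnist', '*.jpg'))
    mnist_dataset = Dataset('mnist', mnist_files)

    # Pull one batch: float32 with shape (64, 28, 28, 1), values in [-0.5, 0.5]
    batch = next(mnist_dataset.get_batches(64))
    print(batch.shape, batch.min(), batch.max())

    # Preview the batch as an 8x8 grid image
    images_square_grid(batch, 'L').save('mnist_grid.png')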