problem_unittests.py 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312
  1. import numpy as np
  2. import tensorflow as tf
  3. from tensorflow.contrib import rnn
  4. def _print_success_message():
  5. print('Tests Passed')
  6. def test_create_lookup_tables(create_lookup_tables):
  7. with tf.Graph().as_default():
  8. test_text = '''
  9. Moe_Szyslak Moe's Tavern Where the elite meet to drink
  10. Bart_Simpson Eh yeah hello is Mike there Last name Rotch
  11. Moe_Szyslak Hold on I'll check Mike Rotch Mike Rotch Hey has anybody seen Mike Rotch lately
  12. Moe_Szyslak Listen you little puke One of these days I'm gonna catch you and I'm gonna carve my name on your back with an ice pick
  13. Moe_Szyslak Whats the matter Homer You're not your normal effervescent self
  14. Homer_Simpson I got my problems Moe Give me another one
  15. Moe_Szyslak Homer hey you should not drink to forget your problems
  16. Barney_Gumble Yeah you should only drink to enhance your social skills'''
  17. test_text = test_text.lower()
  18. test_text = test_text.split()
  19. vocab_to_int, int_to_vocab = create_lookup_tables(test_text)
  20. # Check types
  21. assert isinstance(vocab_to_int, dict),\
  22. 'vocab_to_int is not a dictionary.'
  23. assert isinstance(int_to_vocab, dict),\
  24. 'int_to_vocab is not a dictionary.'
  25. # Compare lengths of dicts
  26. assert len(vocab_to_int) == len(int_to_vocab),\
  27. 'Length of vocab_to_int and int_to_vocab don\'t match. ' \
  28. 'vocab_to_int is length {}. int_to_vocab is length {}'.format(len(vocab_to_int), len(int_to_vocab))
  29. # Make sure the dicts have the same words
  30. vocab_to_int_word_set = set(vocab_to_int.keys())
  31. int_to_vocab_word_set = set(int_to_vocab.values())
  32. assert not (vocab_to_int_word_set - int_to_vocab_word_set),\
  33. 'vocab_to_int and int_to_vocab don\'t have the same words.' \
  34. '{} found in vocab_to_int, but not in int_to_vocab'.format(vocab_to_int_word_set - int_to_vocab_word_set)
  35. assert not (int_to_vocab_word_set - vocab_to_int_word_set),\
  36. 'vocab_to_int and int_to_vocab don\'t have the same words.' \
  37. '{} found in int_to_vocab, but not in vocab_to_int'.format(int_to_vocab_word_set - vocab_to_int_word_set)
  38. # Make sure the dicts have the same word ids
  39. vocab_to_int_word_id_set = set(vocab_to_int.values())
  40. int_to_vocab_word_id_set = set(int_to_vocab.keys())
  41. assert not (vocab_to_int_word_id_set - int_to_vocab_word_id_set),\
  42. 'vocab_to_int and int_to_vocab don\'t contain the same word ids.' \
  43. '{} found in vocab_to_int, but not in int_to_vocab'.format(vocab_to_int_word_id_set - int_to_vocab_word_id_set)
  44. assert not (int_to_vocab_word_id_set - vocab_to_int_word_id_set),\
  45. 'vocab_to_int and int_to_vocab don\'t contain the same word ids.' \
  46. '{} found in int_to_vocab, but not in vocab_to_int'.format(int_to_vocab_word_id_set - vocab_to_int_word_id_set)
  47. # Make sure the dicts make the same lookup
  48. missmatches = [(word, id, id, int_to_vocab[id]) for word, id in vocab_to_int.items() if int_to_vocab[id] != word]
  49. assert not missmatches,\
  50. 'Found {} missmatche(s). First missmatch: vocab_to_int[{}] = {} and int_to_vocab[{}] = {}'.format(
  51. len(missmatches),
  52. *missmatches[0])
  53. assert len(vocab_to_int) > len(set(test_text))/2,\
  54. 'The length of vocab seems too small. Found a length of {}'.format(len(vocab_to_int))
  55. _print_success_message()
  56. def test_get_batches(get_batches):
  57. with tf.Graph().as_default():
  58. test_batch_size = 128
  59. test_seq_length = 5
  60. test_int_text = list(range(1000*test_seq_length))
  61. batches = get_batches(test_int_text, test_batch_size, test_seq_length)
  62. # Check type
  63. assert isinstance(batches, np.ndarray),\
  64. 'Batches is not a Numpy array'
  65. # Check shape
  66. assert batches.shape == (7, 2, 128, 5),\
  67. 'Batches returned wrong shape. Found {}'.format(batches.shape)
  68. for x in range(batches.shape[2]):
  69. assert np.array_equal(batches[0,0,x], np.array(range(x * 35, x * 35 + batches.shape[3]))),\
  70. 'Batches returned wrong contents. For example, input sequence {} in the first batch was {}'.format(x, batches[0,0,x])
  71. assert np.array_equal(batches[0,1,x], np.array(range(x * 35 + 1, x * 35 + 1 + batches.shape[3]))),\
  72. 'Batches returned wrong contents. For example, target sequence {} in the first batch was {}'.format(x, batches[0,1,x])
  73. last_seq_target = (test_batch_size-1) * 35 + 31
  74. last_seq = np.array(range(last_seq_target, last_seq_target+ batches.shape[3]))
  75. last_seq[-1] = batches[0,0,0,0]
  76. assert np.array_equal(batches[-1,1,-1], last_seq),\
  77. 'The last target of the last batch should be the first input of the first batch. Found {} but expected {}'.format(batches[-1,1,-1], last_seq)
  78. _print_success_message()
  79. def test_tokenize(token_lookup):
  80. with tf.Graph().as_default():
  81. symbols = set(['.', ',', '"', ';', '!', '?', '(', ')', '--', '\n'])
  82. token_dict = token_lookup()
  83. # Check type
  84. assert isinstance(token_dict, dict), \
  85. 'Returned type is {}.'.format(type(token_dict))
  86. # Check symbols
  87. missing_symbols = symbols - set(token_dict.keys())
  88. unknown_symbols = set(token_dict.keys()) - symbols
  89. assert not missing_symbols, \
  90. 'Missing symbols: {}'.format(missing_symbols)
  91. assert not unknown_symbols, \
  92. 'Unknown symbols: {}'.format(unknown_symbols)
  93. # Check values type
  94. bad_value_type = [type(val) for val in token_dict.values() if not isinstance(val, str)]
  95. assert not bad_value_type,\
  96. 'Found token as {} type.'.format(bad_value_type[0])
  97. # Check for spaces
  98. key_has_spaces = [k for k in token_dict.keys() if ' ' in k]
  99. val_has_spaces = [val for val in token_dict.values() if ' ' in val]
  100. assert not key_has_spaces,\
  101. 'The key "{}" includes spaces. Remove spaces from keys and values'.format(key_has_spaces[0])
  102. assert not val_has_spaces,\
  103. 'The value "{}" includes spaces. Remove spaces from keys and values'.format(val_has_spaces[0])
  104. # Check for symbols in values
  105. symbol_val = ()
  106. for symbol in symbols:
  107. for val in token_dict.values():
  108. if symbol in val:
  109. symbol_val = (symbol, val)
  110. assert not symbol_val,\
  111. 'Don\'t use a symbol that will be replaced in your tokens. Found the symbol {} in value {}'.format(*symbol_val)
  112. _print_success_message()
  113. def test_get_inputs(get_inputs):
  114. with tf.Graph().as_default():
  115. input_data, targets, lr = get_inputs()
  116. # Check type
  117. assert input_data.op.type == 'Placeholder',\
  118. 'Input not a Placeholder.'
  119. assert targets.op.type == 'Placeholder',\
  120. 'Targets not a Placeholder.'
  121. assert lr.op.type == 'Placeholder',\
  122. 'Learning Rate not a Placeholder.'
  123. # Check name
  124. assert input_data.name == 'input:0',\
  125. 'Input has bad name. Found name {}'.format(input_data.name)
  126. # Check rank
  127. input_rank = 0 if input_data.get_shape() == None else len(input_data.get_shape())
  128. targets_rank = 0 if targets.get_shape() == None else len(targets.get_shape())
  129. lr_rank = 0 if lr.get_shape() == None else len(lr.get_shape())
  130. assert input_rank == 2,\
  131. 'Input has wrong rank. Rank {} found.'.format(input_rank)
  132. assert targets_rank == 2,\
  133. 'Targets has wrong rank. Rank {} found.'.format(targets_rank)
  134. assert lr_rank == 0,\
  135. 'Learning Rate has wrong rank. Rank {} found'.format(lr_rank)
  136. _print_success_message()
  137. def test_get_init_cell(get_init_cell):
  138. with tf.Graph().as_default():
  139. test_batch_size_ph = tf.placeholder(tf.int32, [])
  140. test_rnn_size = 256
  141. cell, init_state = get_init_cell(test_batch_size_ph, test_rnn_size)
  142. # Check type
  143. assert isinstance(cell, tf.contrib.rnn.MultiRNNCell),\
  144. 'Cell is wrong type. Found {} type'.format(type(cell))
  145. # Check for name attribute
  146. assert hasattr(init_state, 'name'),\
  147. 'Initial state doesn\'t have the "name" attribute. Try using `tf.identity` to set the name.'
  148. # Check name
  149. assert init_state.name == 'initial_state:0',\
  150. 'Initial state doesn\'t have the correct name. Found the name {}'.format(init_state.name)
  151. _print_success_message()
  152. def test_get_embed(get_embed):
  153. with tf.Graph().as_default():
  154. embed_shape = [50, 5, 256]
  155. test_input_data = tf.placeholder(tf.int32, embed_shape[:2])
  156. test_vocab_size = 27
  157. test_embed_dim = embed_shape[2]
  158. embed = get_embed(test_input_data, test_vocab_size, test_embed_dim)
  159. # Check shape
  160. assert embed.shape == embed_shape,\
  161. 'Wrong shape. Found shape {}'.format(embed.shape)
  162. _print_success_message()
  163. def test_build_rnn(build_rnn):
  164. with tf.Graph().as_default():
  165. test_rnn_size = 256
  166. test_rnn_layer_size = 2
  167. test_cell = rnn.MultiRNNCell([rnn.BasicLSTMCell(test_rnn_size) for _ in range(test_rnn_layer_size)])
  168. test_inputs = tf.placeholder(tf.float32, [None, None, test_rnn_size])
  169. outputs, final_state = build_rnn(test_cell, test_inputs)
  170. # Check name
  171. assert hasattr(final_state, 'name'),\
  172. 'Final state doesn\'t have the "name" attribute. Try using `tf.identity` to set the name.'
  173. assert final_state.name == 'final_state:0',\
  174. 'Final state doesn\'t have the correct name. Found the name {}'.format(final_state.name)
  175. # Check shape
  176. assert outputs.get_shape().as_list() == [None, None, test_rnn_size],\
  177. 'Outputs has wrong shape. Found shape {}'.format(outputs.get_shape())
  178. assert final_state.get_shape().as_list() == [test_rnn_layer_size, 2, None, test_rnn_size],\
  179. 'Final state wrong shape. Found shape {}'.format(final_state.get_shape())
  180. _print_success_message()
  181. def test_build_nn(build_nn):
  182. with tf.Graph().as_default():
  183. test_input_data_shape = [128, 5]
  184. test_input_data = tf.placeholder(tf.int32, test_input_data_shape)
  185. test_rnn_size = 256
  186. test_embed_dim = 300
  187. test_rnn_layer_size = 2
  188. test_vocab_size = 27
  189. test_cell = rnn.MultiRNNCell([rnn.BasicLSTMCell(test_rnn_size) for _ in range(test_rnn_layer_size)])
  190. logits, final_state = build_nn(test_cell, test_rnn_size, test_input_data, test_vocab_size, test_embed_dim)
  191. # Check name
  192. assert hasattr(final_state, 'name'), \
  193. 'Final state doesn\'t have the "name" attribute. Are you using build_rnn?'
  194. assert final_state.name == 'final_state:0', \
  195. 'Final state doesn\'t have the correct name. Found the name {}. Are you using build_rnn?'.format(final_state.name)
  196. # Check Shape
  197. assert logits.get_shape().as_list() == test_input_data_shape + [test_vocab_size], \
  198. 'Outputs has wrong shape. Found shape {}'.format(logits.get_shape())
  199. assert final_state.get_shape().as_list() == [test_rnn_layer_size, 2, 128, test_rnn_size], \
  200. 'Final state wrong shape. Found shape {}'.format(final_state.get_shape())
  201. _print_success_message()
  202. def test_get_tensors(get_tensors):
  203. test_graph = tf.Graph()
  204. with test_graph.as_default():
  205. test_input = tf.placeholder(tf.int32, name='input')
  206. test_initial_state = tf.placeholder(tf.int32, name='initial_state')
  207. test_final_state = tf.placeholder(tf.int32, name='final_state')
  208. test_probs = tf.placeholder(tf.float32, name='probs')
  209. input_text, initial_state, final_state, probs = get_tensors(test_graph)
  210. # Check correct tensor
  211. assert input_text == test_input,\
  212. 'Test input is wrong tensor'
  213. assert initial_state == test_initial_state, \
  214. 'Initial state is wrong tensor'
  215. assert final_state == test_final_state, \
  216. 'Final state is wrong tensor'
  217. assert probs == test_probs, \
  218. 'Probabilities is wrong tensor'
  219. _print_success_message()
  220. def test_pick_word(pick_word):
  221. with tf.Graph().as_default():
  222. test_probabilities = np.array([0.1, 0.8, 0.05, 0.05])
  223. test_int_to_vocab = {word_i: word for word_i, word in enumerate(['this', 'is', 'a', 'test'])}
  224. pred_word = pick_word(test_probabilities, test_int_to_vocab)
  225. # Check type
  226. assert isinstance(pred_word, str),\
  227. 'Predicted word is wrong type. Found {} type.'.format(type(pred_word))
  228. # Check word is from vocab
  229. assert pred_word in test_int_to_vocab.values(),\
  230. 'Predicted word not found in int_to_vocab.'
  231. _print_success_message()