problem_unittests.py 6.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151
  1. from copy import deepcopy
  2. from unittest import mock
  3. import tensorflow as tf
  def test_safe(func):
      """
      Isolate tests: run ``func`` inside a fresh default ``tf.Graph``.

      Ops and placeholders created by the decorated test go into a graph that
      is private to that call. 'Tests Passed' is printed only after ``func``
      returns without raising; its return value is passed through.
      """
      import functools

      # wraps() keeps the test's __name__/__doc__ on the returned wrapper.
      @functools.wraps(func)
      def func_wrapper(*args, **kwargs):
          with tf.Graph().as_default():
              result = func(*args, **kwargs)
          print('Tests Passed')
          return result

      return func_wrapper
  def _assert_tensor_shape(tensor, shape, display_name):
      """
      Assert that ``tensor`` has rank ``len(shape)`` and matches ``shape``.

      :param tensor: Tensor whose static shape is inspected.
      :param shape: Expected dimensions; a ``None`` entry matches any size.
      :param display_name: Name used in the assertion messages.
      """
      assert tf.assert_rank(tensor, len(shape), message='{} has wrong rank'.format(display_name))

      # A scalar expectation has no dimensions to compare.
      if len(shape) > 0:
          actual_dims = tensor.get_shape().as_list()
      else:
          actual_dims = []

      has_mismatch = any(
          expected is not None and actual != expected
          for actual, expected in zip(actual_dims, shape)
      )
      assert not has_mismatch, \
          '{} has wrong shape. Found {}'.format(display_name, actual_dims)
  def _check_input(tensor, shape, display_name, tf_name=None):
      """
      Assert that ``tensor`` is a correctly shaped (and optionally named) placeholder.

      :param tensor: Tensor returned by the code under test.
      :param shape: Expected shape; ``None`` entries accept any size.
      :param display_name: Human-readable name used in every failure message.
      :param tf_name: If given, the exact value ``tensor.name`` must equal.
      """
      assert tensor.op.type == 'Placeholder', \
          '{} is not a Placeholder.'.format(display_name)

      # Report shape problems under the caller's name; this was previously
      # hard-coded to 'Real Input', which misreported Z/learning-rate failures.
      _assert_tensor_shape(tensor, shape, display_name)

      if tf_name:
          assert tensor.name == tf_name, \
              '{} has bad name. Found name {}'.format(display_name, tensor.name)
  class TmpMock:
      """
      Mock an attribute. Restore the original attribute when exiting the scope.

      ``module.attrib_name`` is replaced by a fresh ``mock.MagicMock`` when the
      object is constructed. Entering the ``with`` block yields that mock.
      Leaving the block, either normally or through an exception, puts the
      original object back.
      """

      def __init__(self, module, attrib_name):
          # Keep a plain reference, not a deepcopy: deepcopy raises on objects
          # such as modules, and it would restore a copy instead of the original.
          self.original_attrib = getattr(module, attrib_name)
          setattr(module, attrib_name, mock.MagicMock())
          self.module = module
          self.attrib_name = attrib_name

      def __enter__(self):
          return getattr(self.module, self.attrib_name)

      def __exit__(self, exc_type, exc_value, traceback):
          setattr(self.module, self.attrib_name, self.original_attrib)
          # Falsy return: exceptions raised inside the block still propagate.
          return False
  @test_safe
  def test_model_inputs(model_inputs):
      """
      Check that ``model_inputs`` returns three correctly shaped placeholders.

      Called with width=28, height=28, channels=3, z_dim=100. The expected
      shapes are [None, width, height, channels] for the real input,
      [None, z_dim] for the z input, and a scalar for the learning rate.
      """
      width, height, channels, z_size = 28, 28, 3, 100

      real_ph, z_ph, lr_ph = model_inputs(width, height, channels, z_size)

      expectations = (
          (real_ph, [None, width, height, channels], 'Real Input'),
          (z_ph, [None, z_size], 'Z Input'),
          (lr_ph, [], 'Learning Rate'),
      )
      for placeholder, expected_shape, label in expectations:
          _check_input(placeholder, expected_shape, label)
  @test_safe
  def test_discriminator(discriminator, tf_module):
      """
      Check ``discriminator`` output shapes and how it calls ``variable_scope``.

      ``tf_module.variable_scope`` is temporarily replaced by a MagicMock. The
      discriminator is called once with only an image placeholder, where it
      must open scope 'discriminator' with reuse=False. It is then called with
      ``True`` as a second argument, where it must use reuse=True. Both returned
      tensors must have shape [None, 1] each time.
      """
      with TmpMock(tf_module, 'variable_scope') as scope_mock:
          images = tf.placeholder(tf.float32, [None, 28, 28, 3])

          # (extra positional args, expected reuse, output label, logits label,
          #  "not called" message, "wrong arguments" message)
          cases = (
              ((), False,
               'Discriminator Training(reuse=false) output',
               'Discriminator Training(reuse=false) Logits',
               'tf.variable_scope not called in Discriminator Training(reuse=false)',
               'tf.variable_scope called with wrong arguments in Discriminator Training(reuse=false)'),
              ((True,), True,
               'Discriminator Inference(reuse=True) output',
               'Discriminator Inference(reuse=True) Logits',
               'tf.variable_scope not called in Discriminator Inference(reuse=True)',
               'tf.variable_scope called with wrong arguments in Discriminator Inference(reuse=True)'),
          )

          for index, case in enumerate(cases):
              extra_args, reuse, out_label, logits_label, not_called_msg, bad_args_msg = case
              if index:
                  # Forget earlier calls so this check only sees its own call.
                  scope_mock.reset_mock()
              out, logits = discriminator(images, *extra_args)
              _assert_tensor_shape(out, [None, 1], out_label)
              _assert_tensor_shape(logits, [None, 1], logits_label)
              assert scope_mock.called, not_called_msg
              assert scope_mock.call_args == mock.call('discriminator', reuse=reuse), bad_args_msg
  @test_safe
  def test_generator(generator, tf_module):
      """
      Check ``generator`` output shapes and how it calls ``variable_scope``.

      ``tf_module.variable_scope`` is temporarily replaced by a MagicMock. The
      generator is called with a [None, 100] z placeholder and 5 output
      channels. By default it must open scope 'generator' with reuse=False.
      When passed ``False`` as a third argument it must use reuse=True. The
      output must have shape [None, 28, 28, 5] both times.
      """
      with TmpMock(tf_module, 'variable_scope') as scope_mock:
          noise = tf.placeholder(tf.float32, [None, 100])
          channels_out = 5

          # (extra positional args, expected reuse, output label,
          #  "not called" message, "wrong arguments" message)
          cases = (
              ((), False,
               'Generator output (is_train=True)',
               'tf.variable_scope not called in Generator Training(reuse=false)',
               'tf.variable_scope called with wrong arguments in Generator Training(reuse=false)'),
              ((False,), True,
               'Generator output (is_train=False)',
               'tf.variable_scope not called in Generator Inference(reuse=True)',
               'tf.variable_scope called with wrong arguments in Generator Inference(reuse=True)'),
          )

          for index, case in enumerate(cases):
              extra_args, reuse, out_label, not_called_msg, bad_args_msg = case
              if index:
                  # Forget earlier calls so this check only sees its own call.
                  scope_mock.reset_mock()
              generated = generator(noise, channels_out, *extra_args)
              _assert_tensor_shape(generated, [None, 28, 28, channels_out], out_label)
              assert scope_mock.called, not_called_msg
              assert scope_mock.call_args == mock.call('generator', reuse=reuse), bad_args_msg
  @test_safe
  def test_model_loss(model_loss):
      """
      Check that ``model_loss`` returns two scalar losses.

      ``model_loss`` receives a [None, 28, 28, 4] real-input placeholder, a
      [None, 100] z placeholder and the channel count 4. Both returned tensors
      (discriminator loss, generator loss) must be scalars.
      """
      channels = 4
      real = tf.placeholder(tf.float32, [None, 28, 28, channels])
      noise = tf.placeholder(tf.float32, [None, 100])

      d_loss, g_loss = model_loss(real, noise, channels)

      for loss, label in ((d_loss, 'Discriminator Loss'), (g_loss, 'Generator Loss')):
          _assert_tensor_shape(loss, [], label)
  96. @test_safe
  97. def test_model_opt(model_opt, tf_module):
  98. with TmpMock(tf_module, 'trainable_variables') as mock_trainable_variables:
  99. with tf.variable_scope('discriminator'):
  100. discriminator_logits = tf.Variable(tf.zeros([3, 3]))
  101. with tf.variable_scope('generator'):
  102. generator_logits = tf.Variable(tf.zeros([3, 3]))
  103. mock_trainable_variables.return_value = [discriminator_logits, generator_logits]
  104. d_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(
  105. logits=discriminator_logits,
  106. labels=[[0.0, 0.0, 1.0], [0.0, 1.0, 0.0], [1.0, 0.0, 0.0]]))
  107. g_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(
  108. logits=generator_logits,
  109. labels=[[0.0, 0.0, 1.0], [0.0, 1.0, 0.0], [1.0, 0.0, 0.0]]))
  110. learning_rate = 0.001
  111. beta1 = 0.9
  112. d_train_opt, g_train_opt = model_opt(d_loss, g_loss, learning_rate, beta1)
  113. assert mock_trainable_variables.called,\
  114. 'tf.mock_trainable_variables not called'