util.py 2.0 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879
  1. import numpy as np
  2. import os
  3. from keras.preprocessing import image
  4. from operator import itemgetter
  5. IMAGE_DIM = (224, 224)
  6. LABELS_PATH = os.path.abspath('labels.txt')
  7. global labels
  8. with open(LABELS_PATH) as f:
  9. labels = f.readlines()
  10. labels = np.array([label.strip() for label in labels])
  11. def img_to_tensor(img):
  12. """Loads a PIL image and outputs a 4d tensor
  13. :param img, PIL Image object
  14. :return 4D numpy array/tensor
  15. """
  16. if img.size != IMAGE_DIM:
  17. img = img.resize(IMAGE_DIM)
  18. # convert to 3d array
  19. x = image.img_to_array(img)
  20. # conver to 4d
  21. return np.expand_dims(x, axis=0)
  22. def preprocess_inception(img):
  23. """Prepare image for Inception model
  24. :param img, PIL Image object
  25. :return preprocssed input for inception model
  26. """
  27. from keras.applications.inception_v3 import preprocess_input
  28. img = preprocess_input(img_to_tensor(img))
  29. return img
  30. def extract_bottleneck_features_resnet(tensor):
  31. from keras.applications.resnet50 import ResNet50
  32. return ResNet50(weights='imagenet', include_top=False).predict(tensor)
  33. def preprocess_resnet(img):
  34. """Prepare image for Resnet50 model
  35. """
  36. from keras.applications.resnet50 import preprocess_input
  37. img = preprocess_input(img_to_tensor(img))
  38. return extract_bottleneck_features_resnet(img)
  39. def decode_prob(output_arr, top_probs=5):
  40. """Label class probabilities with class names
  41. :param output_arr, list: class probabilities
  42. :param top_probs, int: number of class probabilities to return out of 133
  43. :return list[dict]:
  44. """
  45. results = []
  46. for row in output_arr:
  47. entries = []
  48. for name, prob in zip(labels, row[0]):
  49. entries.append({'name': name,
  50. 'prob': prob})
  51. entries = sorted(entries,
  52. key=itemgetter('prob'),
  53. reverse=True)[:top_probs]
  54. for entry in entries:
  55. entry['prob'] = '{:.5f}'.format(entry['prob'])
  56. results.append(entries)
  57. return results