app.py 3.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128
  1. import base64
  2. import logging
  3. import io
  4. import os
  5. from flask import Flask, render_template, request
  6. from load import init_model
  7. from PIL import Image
  8. from util import decode_prob
  9. logger = logging.getLogger("dog_breed_classifier")
  10. logger.setLevel(logging.DEBUG)
  11. app = Flask(__name__)
  12. # Initialize
  13. MODEL_DIR = os.path.abspath("./models")
  14. RESNET_CONFIG = {'arch':
  15. os.path.join(MODEL_DIR,
  16. 'model.Resnet50.json'),
  17. 'weights':
  18. os.path.join(MODEL_DIR,
  19. 'weights.Resnet50.hdf5')}
  20. INCEPTION_CONFIG = {'arch':
  21. os.path.join(MODEL_DIR,
  22. 'model.inceptionv3.json'),
  23. 'weights':
  24. os.path.join(MODEL_DIR,
  25. 'weights.inceptionv3.h5')}
  26. MODELS = {'resnet': RESNET_CONFIG,
  27. 'inception': INCEPTION_CONFIG}
  28. @app.route('/index')
  29. @app.route('/')
  30. def index():
  31. return render_template('settings.html')
  32. @app.route('/settings', methods=['GET', 'POST'])
  33. def settings():
  34. """Select Model Architecture and Initialize
  35. """
  36. global model, graph, preprocess
  37. # grab model selected
  38. model_name = request.form['model']
  39. config = MODELS[model_name]
  40. # init the model with pre-trained architecture and weights
  41. model, graph = init_model(config['arch'], config['weights'])
  42. # use the proper preprocessing method
  43. if model_name == 'inception':
  44. from util import preprocess_inception
  45. preprocess = preprocess_inception
  46. else:
  47. from util import preprocess_resnet
  48. preprocess = preprocess_resnet
  49. return render_template('select_files.html', model_name=model_name)
  50. @app.route('/predict', methods=['GET', 'POST'])
  51. def predict():
  52. """File selection and display results
  53. """
  54. if request.method == 'POST' and 'file[]' in request.files:
  55. inputs = []
  56. files = request.files.getlist('file[]')
  57. for file_obj in files:
  58. # Check if no files uploaded
  59. if file_obj.filename == '':
  60. if len(files) == 1:
  61. return render_template('select_files.html')
  62. continue
  63. entry = {}
  64. entry.update({'filename': file_obj.filename})
  65. try:
  66. img_bytes = io.BytesIO(file_obj.stream.getvalue())
  67. entry.update({'data':
  68. Image.open(
  69. img_bytes
  70. )})
  71. except AttributeError:
  72. img_bytes = io.BytesIO(file_obj.stream.read())
  73. entry.update({'data':
  74. Image.open(
  75. img_bytes
  76. )})
  77. # keep image in base64 for later use
  78. img_b64 = base64.b64encode(img_bytes.getvalue()).decode()
  79. entry.update({'img': img_b64})
  80. inputs.append(entry)
  81. outputs = []
  82. with graph.as_default():
  83. for input_ in inputs:
  84. # convert to 4D tensor to feed into our model
  85. x = preprocess(input_['data'])
  86. # perform prediction
  87. out = model.predict(x)
  88. outputs.append(out)
  89. # decode output prob
  90. outputs = decode_prob(outputs)
  91. results = []
  92. for input_, probs in zip(inputs, outputs):
  93. results.append({'filename': input_['filename'],
  94. 'image': input_['img'],
  95. 'predict_probs': probs})
  96. return render_template('results.html', results=results)
  97. # if no files uploaded
  98. return render_template('select_files.html')
  99. if __name__ == '__main__':
  100. app.run(debug=True)