app.py 4.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146
  1. import base64
  2. import logging
  3. import io
  4. import os
  5. from flask import Flask, render_template, request
  6. from load import init_model
  7. from PIL import Image
  8. from util import decode_prob
  9. logger = logging.getLogger("dog_breed_classifier")
  10. logger.setLevel(logging.DEBUG)
  11. app = Flask(__name__)
  12. # Initialize
  13. MODEL_DIR = os.path.abspath("./models")
  14. RESNET_CONFIG = {'arch':
  15. os.path.join(MODEL_DIR,
  16. 'model.Resnet50.json'),
  17. 'weights':
  18. os.path.join(MODEL_DIR,
  19. 'weights.Resnet50.hdf5')}
  20. INCEPTION_CONFIG = {'arch':
  21. os.path.join(MODEL_DIR,
  22. 'model.inceptionv3.json'),
  23. 'weights':
  24. os.path.join(MODEL_DIR,
  25. 'weights.inceptionv3.h5')}
  26. MODELS = {'resnet': RESNET_CONFIG,
  27. 'inception': INCEPTION_CONFIG}
  28. @app.route('/index')
  29. @app.route('/')
  30. def index():
  31. return render_template('select_files.html')
  32. @app.route('/settings', methods=['GET', 'POST'])
  33. def settings():
  34. """Select Model Architecture and Initialize
  35. """
  36. global model, graph, preprocess
  37. # grab model selected
  38. model_name = request.form['model']
  39. config = MODELS[model_name]
  40. # init the model with pre-trained architecture and weights
  41. model, graph = init_model(config['arch'], config['weights'])
  42. # use the proper preprocessing method
  43. if model_name == 'inception':
  44. from util import preprocess_inception
  45. preprocess = preprocess_inception
  46. else:
  47. from util import preprocess_resnet
  48. preprocess = preprocess_resnet
  49. return render_template('select_files.html', model_name=model_name)
  50. @app.route('/predict', methods=['GET', 'POST'])
  51. def predict():
  52. """File selection and display results
  53. """
  54. if request.method == 'POST' and 'file[]' in request.files:
  55. global model, graph, preprocess
  56. # grab model selected
  57. model_name = request.form['model']
  58. config = MODELS[model_name]
  59. # init the model with pre-trained architecture and weights
  60. model, graph = init_model(config['arch'], config['weights'])
  61. # use the proper preprocessing method
  62. if model_name == 'inception':
  63. from util import preprocess_inception
  64. preprocess = preprocess_inception
  65. else:
  66. from util import preprocess_resnet
  67. preprocess = preprocess_resnet
  68. inputs = []
  69. files = request.files.getlist('file[]')
  70. for file_obj in files:
  71. # Check if no files uploaded
  72. if file_obj.filename == '':
  73. if len(files) == 1:
  74. return render_template('select_files.html')
  75. continue
  76. entry = {}
  77. entry.update({'filename': file_obj.filename})
  78. try:
  79. img_bytes = io.BytesIO(file_obj.stream.getvalue())
  80. entry.update({'data':
  81. Image.open(
  82. img_bytes
  83. )})
  84. except AttributeError:
  85. img_bytes = io.BytesIO(file_obj.stream.read())
  86. entry.update({'data':
  87. Image.open(
  88. img_bytes
  89. )})
  90. # keep image in base64 for later use
  91. img_b64 = base64.b64encode(img_bytes.getvalue()).decode()
  92. entry.update({'img': img_b64})
  93. inputs.append(entry)
  94. outputs = []
  95. with graph.as_default():
  96. for input_ in inputs:
  97. # convert to 4D tensor to feed into our model
  98. x = preprocess(input_['data'])
  99. # perform prediction
  100. out = model.predict(x)
  101. outputs.append(out)
  102. # decode output prob
  103. outputs = decode_prob(outputs)
  104. results = []
  105. for input_, probs in zip(inputs, outputs):
  106. results.append({'filename': input_['filename'],
  107. 'image': input_['img'],
  108. 'predict_probs': probs})
  109. return render_template('results.html', results=results)
  110. # if no files uploaded
  111. return render_template('select_files.html')
  112. if __name__ == '__main__':
  113. #app.run(debug=True)
  114. app.run(host="0.0.0.0")