my_model_selectors.py

import math
import warnings

import numpy as np
from hmmlearn.hmm import GaussianHMM
from sklearn.model_selection import KFold

from asl_utils import combine_sequences


class ModelSelector(object):
    """Base class for model selection (strategy design pattern)."""

    def __init__(self, all_word_sequences: dict, all_word_Xlengths: dict,
                 this_word: str, n_constant=3,
                 min_n_components=2, max_n_components=10,
                 random_state=14, verbose=False):
        self.words = all_word_sequences
        self.hwords = all_word_Xlengths
        self.sequences = all_word_sequences[this_word]
        self.X, self.lengths = all_word_Xlengths[this_word]
        self.this_word = this_word
        self.n_constant = n_constant
        self.min_n_components = min_n_components
        self.max_n_components = max_n_components
        self.random_state = random_state
        self.verbose = verbose
        self.n_components = range(self.min_n_components,
                                  self.max_n_components + 1)

    def select(self):
        raise NotImplementedError

    def base_model(self, num_states):
        warnings.filterwarnings("ignore", category=DeprecationWarning)
        try:
            hmm_model = GaussianHMM(n_components=num_states,
                                    covariance_type="diag", n_iter=1000,
                                    random_state=self.random_state,
                                    verbose=False).fit(self.X, self.lengths)
            if self.verbose:
                print("Model created for {} with {} states".format(
                    self.this_word, num_states))
            return hmm_model
        except Exception:
            if self.verbose:
                print("failure on {} with {} states".format(
                    self.this_word, num_states))
            return None
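

# Illustrative sketch of the strategy pattern the base class enables (this
# subclass is hypothetical, not part of the project): a selector only needs
# to implement select() and can reuse base_model() for the actual fitting.
#
#   class SelectorFirst(ModelSelector):
#       def select(self):
#           return self.base_model(self.min_n_components)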


class SelectorConstant(ModelSelector):
    """Select the model with the constant value self.n_constant."""

    def select(self):
        """Select based on n_constant value.

        :return: GaussianHMM object
        """
        best_num_components = self.n_constant
        return self.base_model(best_num_components)


class SelectorBIC(ModelSelector):
    """Select the model with the lowest Bayesian Information Criterion (BIC)
    score -- http://www2.imm.dtu.dk/courses/02433/doc/ch6_slides.pdf.

    Bayesian information criterion: BIC = -2 * logL + p * logN
    """

    def select(self):
        """Select the best model for self.this_word based on BIC score
        for n between self.min_n_components and self.max_n_components.

        :return: GaussianHMM object
        """
        warnings.filterwarnings("ignore", category=DeprecationWarning)
        bic_scores = []
        try:
            for n in self.n_components:
                # BIC = -2 * log L + p * log N, where
                # L is the likelihood of the fitted model,
                # p is the number of free parameters, and
                # N is the number of data points.
                model = self.base_model(n)
                log_l = model.score(self.X, self.lengths)
                # Free parameters of a diagonal-covariance GaussianHMM:
                # n*(n-1) transitions + (n-1) starting probabilities
                # + 2*n*d means and variances = n**2 + 2*n*d - 1.
                p = n ** 2 + 2 * n * model.n_features - 1
                bic_score = -2 * log_l + p * math.log(len(self.X))
                bic_scores.append(bic_score)
        except Exception:
            pass
        # Lower BIC is better, so take the argmin; fall back to the
        # constant selector if no model could be scored.
        states = self.n_components[np.argmin(bic_scores)] \
            if bic_scores else self.n_constant
        return self.base_model(states)
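

# Worked BIC example (toy numbers, not from the dataset): with n = 3 states,
# d = 4 features, and N = 100 frames, p = 3**2 + 2*3*4 - 1 = 32, so a fitted
# log likelihood of -500.0 gives
#   BIC = -2 * (-500.0) + 32 * math.log(100) ≈ 1000 + 147.4 = 1147.4.
# A model with more states must raise log L enough to offset the larger p.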


class SelectorDIC(ModelSelector):
    """Select the best model based on the Discriminative Information
    Criterion (DIC).

    Biem, Alain. "A model selection criterion for classification:
    Application to HMM topology optimization."
    Document Analysis and Recognition, 2003. Proceedings.
    Seventh International Conference on. IEEE, 2003.
    http://citeseerx.ist.psu.edu/viewdoc/download?doi=10.1.1.58.6208

    DIC = log(P(X(i))) - 1/(M-1) * SUM(log(P(X(all but i))))
    """

    def select(self):
        warnings.filterwarnings("ignore", category=DeprecationWarning)
        dic_scores = []
        try:
            for n_component in self.n_components:
                model = self.base_model(n_component)
                # Evidence for the word being trained...
                log_l = model.score(self.X, self.lengths)
                # ...penalized by the average evidence this same model
                # assigns to every competing word.
                other_logs_l = [model.score(X, lengths)
                                for word, (X, lengths) in self.hwords.items()
                                if word != self.this_word]
                # DIC = log(P(X(i))) - 1/(M-1) * SUM(log(P(X(all but i))))
                dic_scores.append(log_l - np.mean(other_logs_l))
        except Exception:
            pass
        # Higher DIC is better: the model should fit its own word well and
        # the competing words poorly.
        states = self.n_components[np.argmax(dic_scores)] \
            if dic_scores else self.n_constant
        return self.base_model(states)
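

# Worked DIC example (toy numbers, not from the dataset): with M = 3 words,
# suppose a candidate model scores its own word at log L = -100.0 and the
# two competing words at -120.0 and -90.0. Then
#   DIC = -100.0 - (-120.0 + -90.0) / 2 = -100.0 + 105.0 = 5.0.
# The candidate with the largest such margin wins.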


class SelectorCV(ModelSelector):
    """Select the best model based on the average log likelihood of
    cross-validation folds.
    """

    def select(self):
        warnings.filterwarnings("ignore", category=DeprecationWarning)
        mean_scores = []
        # Save a reference to 'KFold' in a variable, as shown in the notebook.
        # If a word has too few sequences to split, the exception below is
        # caught and we fall back to the constant selector.
        split_method = KFold()
        try:
            for n_component in self.n_components:
                # Fold the sequences and average the test-fold scores.
                fold_scores = []
                for train_idx, test_idx in split_method.split(self.sequences):
                    # Fit on the training folds only...
                    train_X, train_lengths = combine_sequences(train_idx,
                                                               self.sequences)
                    model = GaussianHMM(n_components=n_component,
                                        covariance_type="diag", n_iter=1000,
                                        random_state=self.random_state,
                                        verbose=False).fit(train_X,
                                                           train_lengths)
                    # ...and score on the held-out test fold.
                    test_X, test_lengths = combine_sequences(test_idx,
                                                             self.sequences)
                    fold_scores.append(model.score(test_X, test_lengths))
                # Mean of all fold scores for this number of states.
                mean_scores.append(np.mean(fold_scores))
        except Exception:
            pass
        states = self.n_components[np.argmax(mean_scores)] \
            if mean_scores else self.n_constant
        return self.base_model(states)
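

# Minimal usage sketch, assuming the surrounding AIND recognizer project:
# `asl_data.AslDb`, its training helpers, the raw coordinate feature names,
# and the word 'BOOK' are assumptions taken from the companion notebook, not
# from this module.
if __name__ == "__main__":
    from asl_data import AslDb  # assumed project module

    asl = AslDb()
    features = ['right-x', 'right-y', 'left-x', 'left-y']  # assumed features
    training = asl.build_training(features)
    selector = SelectorBIC(training.get_all_sequences(),
                           training.get_all_Xlengths(),
                           'BOOK',  # placeholder word from the dataset
                           min_n_components=2, max_n_components=10)
    model = selector.select()
    if model is not None:
        print("BIC chose {} states for 'BOOK'".format(model.n_components))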