babs_visualizations.py 6.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167
  1. import numpy as np
  2. import pandas as pd
  3. import matplotlib.pyplot as plt
  4. import seaborn as sns
  5. def filter_data(data, condition):
  6. """
  7. Remove elements that do not match the condition provided.
  8. Takes a data list as input and returns a filtered list.
  9. Conditions should be a list of strings of the following format:
  10. '<field> <op> <value>'
  11. where the following operations are valid: >, <, >=, <=, ==, !=
  12. Example: ["duration < 15", "start_city == 'San Francisco'"]
  13. """
  14. # Only want to split on first two spaces separating field from operator and
  15. # operator from value: spaces within value should be retained.
  16. field, op, value = condition.split(" ", 2)
  17. # check if field is valid
  18. if field not in data.columns.values :
  19. raise Exception("'{}' is not a feature of the dataframe. Did you spell something wrong?".format(field))
  20. # convert value into number or strip excess quotes if string
  21. try:
  22. value = float(value)
  23. except:
  24. value = value.strip("\'\"")
  25. # get booleans for filtering
  26. if op == ">":
  27. matches = data[field] > value
  28. elif op == "<":
  29. matches = data[field] < value
  30. elif op == ">=":
  31. matches = data[field] >= value
  32. elif op == "<=":
  33. matches = data[field] <= value
  34. elif op == "==":
  35. matches = data[field] == value
  36. elif op == "!=":
  37. matches = data[field] != value
  38. else: # catch invalid operation codes
  39. raise Exception("Invalid comparison operator. Only >, <, >=, <=, ==, != allowed.")
  40. # filter data and outcomes
  41. data = data[matches].reset_index(drop = True)
  42. return data
  43. def usage_stats(data, filters = [], verbose = True):
  44. """
  45. Report number of trips and average trip duration for data points that meet
  46. specified filtering criteria.
  47. """
  48. n_data_all = data.shape[0]
  49. # Apply filters to data
  50. for condition in filters:
  51. data = filter_data(data, condition)
  52. # Compute number of data points that met the filter criteria.
  53. n_data = data.shape[0]
  54. # Compute statistics for trip durations.
  55. duration_mean = data['duration'].mean()
  56. duration_qtiles = data['duration'].quantile([.25, .5, .75]).as_matrix()
  57. # Report computed statistics if verbosity is set to True (default).
  58. if verbose:
  59. if filters:
  60. print('There are {:d} data points ({:.2f}%) matching the filter criteria.'.format(n_data, 100. * n_data / n_data_all))
  61. else:
  62. print('There are {:d} data points in the dataset.'.format(n_data))
  63. print('The average duration of trips is {:.2f} minutes.'.format(duration_mean))
  64. print('The median trip duration is {:.2f} minutes.'.format(duration_qtiles[1]))
  65. print('25% of trips are shorter than {:.2f} minutes.'.format(duration_qtiles[0]))
  66. print('25% of trips are longer than {:.2f} minutes.'.format(duration_qtiles[2]))
  67. # Return three-number summary
  68. return duration_qtiles
  69. def usage_plot(data, key = '', filters = [], **kwargs):
  70. """
  71. Plot number of trips, given a feature of interest and any number of filters
  72. (including no filters). Function takes a number of optional arguments for
  73. plotting data on continuously-valued variables:
  74. - n_bins: number of bars (default = 10)
  75. - bin_width: width of each bar (default divides the range of the data by
  76. number of bins). "n_bins" and "bin_width" cannot be used simultaneously.
  77. - boundary: specifies where one of the bar edges will be placed; other
  78. bar edges will be placed around that value (may result in an additional
  79. bar being plotted). Can be used with "n_bins" and "bin_width".
  80. """
  81. # Check that the key exists
  82. if not key:
  83. raise Exception("No key has been provided. Make sure you provide a variable on which to plot the data.")
  84. if key not in data.columns.values :
  85. raise Exception("'{}' is not a feature of the dataframe. Did you spell something wrong?".format(key))
  86. # Apply filters to data
  87. for condition in filters:
  88. data = filter_data(data, condition)
  89. # Create plotting figure
  90. plt.figure(figsize=(8,6))
  91. if isinstance(data[key][0] , str): # Categorical features
  92. # For strings, collect unique strings and then count number of
  93. # outcomes for survival and non-survival.
  94. # Summarize dataframe to get counts in each group
  95. data['count'] = 1
  96. data = data.groupby(key, as_index = False).count()
  97. levels = data[key].unique()
  98. n_levels = len(levels)
  99. bar_width = 0.8
  100. for i in range(n_levels):
  101. trips_bar = plt.bar(i - bar_width/2, data.loc[i]['count'], width = bar_width)
  102. # add labels to ticks for each group of bars.
  103. plt.xticks(range(n_levels), levels)
  104. else: # Numeric features
  105. # For numbers, divide the range of data into bins and count
  106. # number of outcomes for survival and non-survival in each bin.
  107. # Set up bin boundaries for plotting
  108. if kwargs and 'n_bins' in kwargs and 'bin_width' in kwargs:
  109. raise Exception("Arguments 'n_bins' and 'bin_width' cannot be used simultaneously.")
  110. min_value = data[key].min()
  111. max_value = data[key].max()
  112. value_range = max_value - min_value
  113. n_bins = 10
  114. bin_width = float(value_range) / n_bins
  115. if kwargs and 'n_bins' in kwargs:
  116. n_bins = int(kwargs['n_bins'])
  117. bin_width = float(value_range) / n_bins
  118. elif kwargs and 'bin_width' in kwargs:
  119. bin_width = kwargs['bin_width']
  120. n_bins = int(np.ceil(float(value_range) / bin_width))
  121. if kwargs and 'boundary' in kwargs:
  122. bound_factor = np.floor(( min_value - kwargs['boundary'] ) / bin_width)
  123. min_value = kwargs['boundary'] + bound_factor * bin_width
  124. if min_value + n_bins * bin_width <= max_value:
  125. n_bins += 1
  126. bins = [i*bin_width + min_value for i in range(n_bins+1)]
  127. # plot the data
  128. plt.hist(data[key], bins = bins)
  129. # Common attributes for plot formatting
  130. key_name = ' '.join([x.capitalize() for x in key.split('_')])
  131. plt.xlabel(key_name)
  132. plt.ylabel("Number of Trips")
  133. plt.title("Number of Trips by {:s}".format(key_name))
  134. plt.show()