Predicting spotted-wing drosophila trap counts using machine learning

In this study we analyze three years of data, which include environmental parameters from nearby weather stations, information about the sites where traps are located, and the total summer counts of the pest fly spotted-wing drosophila in each trap. We try to use machine learning models to determine if we can predict whether a site is "high risk" (fly totals above the median), or "low risk" (fly totals below the median).

This dataset is likely too small to draw conclusive results, but it is a good case study.

In [ ]:
!pip install silence_tensorflow
Collecting silence_tensorflow
  Downloading https://files.pythonhosted.org/packages/96/d7/076b21d0e79cfc8a085f623e6577b754c50a42cfbcce51d77d0d2206988c/silence_tensorflow-1.1.1.tar.gz
Building wheels for collected packages: silence-tensorflow
  Building wheel for silence-tensorflow (setup.py) ... done
  Created wheel for silence-tensorflow: filename=silence_tensorflow-1.1.1-cp36-none-any.whl size=3743 sha256=0f355d7877783758dbd821898fcd7cc0a9afb2280e734e6c8381f70cabfcb0a3
  Stored in directory: /root/.cache/pip/wheels/51/0b/35/cf3020764bee61daa81fa249df3a448e3806344a087fc12292
Successfully built silence-tensorflow
Installing collected packages: silence-tensorflow
Successfully installed silence-tensorflow-1.1.1
In [ ]:
import numpy as np # for linear algebra
from sklearn import preprocessing # for machine learning
import pandas as pd # for dataframes
import silence_tensorflow.auto  # to silence warning messages from tf
import tensorflow as tf # for ML models
import os # for file paths
from google.colab import drive
drive.mount('/content/drive')
Mounted at /content/drive

Part 1: data preprocessing

In [ ]:
#Load dataframe
raw_csv_data = pd.read_csv("/content/drive/My Drive/SWD Count Data/SWD_traps_HR_Wasco_171819.csv")
raw_csv_data.head()
Out[ ]:
trap_ID town host management setting lure latitude longitude northing_UTM easting_UTM elevation_m year weather_station_uspest first_catch_spring_date first_catch_spring_day total_SWD_spring total_SWD_summer total_SWD_winter SWD_June Tmin_winter Tmax_winter Tmin_spring Tmax_spring Tmin_summer Tmax_summer days_below_5_winter days_below_0_winter DD_winter DD_spring DD_summer precipitation_winter precipitation_spring precipitation_summer
0 301 The Dalles cherry managed agricultural Trece+ACV 45.604670 -121.227535 5051652.993 638219.7744 277 2018 E9627 43245.0 145.0 3.0 5.0 NaN 3.0 1.153846 9.180769 8.129444 21.059667 13.241556 29.894333 6.0 27 36.4 545.700000 1577.200000 105.410 162.306 164.592
1 303 The Dalles cherry managed agricultural Trece+ACV 45.577823 -121.221128 5048681.416 638785.5733 178 2018 F2372 43227.0 127.0 3.0 11.0 NaN 0.0 -0.465385 8.550769 6.931889 21.557111 11.598444 30.004111 0.0 48 12.0 507.600000 1474.200000 NaN NaN NaN
2 305 The Dalles raspberry managed agricultural Trece+ACV 45.565197 -121.190791 5047331.647 641184.0583 252 2018 TD600 43227.0 127.0 3.0 2.0 NaN 0.0 0.347949 8.969744 6.410333 19.805667 11.789778 27.702667 6.0 34 31.9 446.900000 1335.200000 NaN NaN NaN
3 307 The Dalles cherry managed agricultural Trece+ACV 45.532002 -121.227432 5043579.941 638406.2157 461 2018 TD1000 43237.0 137.0 2.0 11.0 NaN 1.0 -0.429615 7.959872 6.238222 20.441667 11.720111 29.793556 7.0 43 25.7 472.200000 1435.000000 186.944 222.504 225.552
4 310 The Dalles cherry managed agricultural Trece+ACV 45.564390 -121.143524 5047326.273 644874.6538 241 2018 TD750 43227.0 127.0 12.0 2.0 NaN 6.0 1.131538 9.702949 7.064556 21.397667 12.946222 31.341444 6.0 29 37.7 538.800000 1605.100000 NaN NaN NaN
... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ...
122 346 Mosier cherry managed agricultural Trece+ACV 45.689726 -121.372700 5060862.584 626707.6676 100 2019 F2359 NaN NaN 21.0 398.0 NaN 9.0 -2.311966 4.063390 7.975309 19.364198 13.680247 26.021605 19.0 51 0.0 387.222222 1287.444444 187.706 284.226 NaN
123 350 Dallesport cherry managed agricultural Trece+ACV 45.621344 -121.181155 5053586.462 641794.5018 62 2019 KDLS NaN NaN 273.0 1330.0 NaN 0.0 -2.568376 4.281339 8.938889 22.122840 15.121605 29.186420 17.0 50 0.0 542.611111 1653.888889 146.812 214.376 267.462
124 351 Dallesport cherry managed agricultural Trece+ACV 45.619840 -121.187080 5053408.904 641336.3806 59 2019 KDLS NaN NaN 82.0 902.0 NaN 1.0 -2.568376 4.281339 8.938889 22.122840 15.121605 29.186420 17.0 50 0.0 542.611111 1653.888889 146.812 214.376 267.462
125 353 Dallesport cherry managed agricultural Trece+ACV 45.630112 -121.135394 5054642.561 645339.2697 61 2019 F2362 NaN NaN 0.0 306.0 NaN 0.0 -3.287037 3.886752 8.277778 21.160494 14.885185 27.754321 19.0 58 0.0 483.111111 1517.500000 124.460 181.610 NaN
126 356 The Dalles cherry managed agricultural Trece+ACV 45.439726 -121.156874 5033452.551 410490.6463 445 2019 TD2000 NaN NaN 21.0 69.0 NaN 10.0 -5.537749 3.094017 4.556790 19.466049 8.303086 26.855556 34.0 73 0.0 309.277778 1018.944444 NaN NaN NaN

127 rows × 33 columns

In [ ]:
# Create a copy of the raw dataset to work on
df = raw_csv_data.copy()
# Inspect dataframe information
df.info()
<class 'pandas.core.frame.DataFrame'>
RangeIndex: 127 entries, 0 to 126
Data columns (total 33 columns):
 #   Column                   Non-Null Count  Dtype  
---  ------                   --------------  -----  
 0   trap_ID                  127 non-null    int64  
 1   town                     127 non-null    object 
 2   host                     127 non-null    object 
 3   management               127 non-null    object 
 4   setting                  127 non-null    object 
 5   lure                     127 non-null    object 
 6   latitude                 127 non-null    float64
 7   longitude                127 non-null    float64
 8   northing_UTM             127 non-null    float64
 9   easting_UTM              127 non-null    float64
 10  elevation_m              127 non-null    int64  
 11  year                     127 non-null    int64  
 12  weather_station_uspest   127 non-null    object 
 13  first_catch_spring_date  82 non-null     float64
 14  first_catch_spring_day   82 non-null     float64
 15  total_SWD_spring         126 non-null    float64
 16  total_SWD_summer         126 non-null    float64
 17  total_SWD_winter         37 non-null     float64
 18  SWD_June                 126 non-null    float64
 19  Tmin_winter              126 non-null    float64
 20  Tmax_winter              127 non-null    float64
 21  Tmin_spring              127 non-null    float64
 22  Tmax_spring              127 non-null    float64
 23  Tmin_summer              127 non-null    float64
 24  Tmax_summer              126 non-null    float64
 25  days_below_5_winter      126 non-null    float64
 26  days_below_0_winter      127 non-null    int64  
 27  DD_winter                127 non-null    float64
 28  DD_spring                127 non-null    float64
 29  DD_summer                127 non-null    float64
 30  precipitation_winter     70 non-null     float64
 31  precipitation_spring     71 non-null     float64
 32  precipitation_summer     66 non-null     float64
dtypes: float64(23), int64(4), object(6)
memory usage: 32.9+ KB
In [ ]:
# Make 'year' categorical

df["year"] = df["year"].astype("category", copy = False)
In [ ]:
# Fill NAs with means from the same town and year, only for numerical variables

df.groupby(['town', 'year'])
for column in df:
    if df[column].dtype == np.float64:
        df[column].fillna(df.groupby(['town', 'year'])[column].transform('mean'), inplace=True)
        df[column].fillna(df[column].mean(), inplace=True)
In [ ]:
# Drop unnecessary columns. axis = 1 indicate that these are column names, not rows

df = df.drop(["trap_ID", "town", "northing_UTM", "easting_UTM", 'weather_station_uspest', 'first_catch_spring_date',
         'first_catch_spring_day', 'total_SWD_spring', 'total_SWD_winter', 'SWD_June'], axis = 1)
In [ ]:
# Reorder columns

column_names_reordered = ['year', 'host', 'management', 'setting', 'lure', 'latitude', 'longitude',
                          'elevation_m', 'Tmin_winter', 'Tmax_winter', 'Tmin_spring', 'Tmax_spring',
                          'Tmin_summer', 'Tmax_summer', 'days_below_5_winter', 'days_below_0_winter',
                          'DD_winter', 'DD_spring', 'DD_summer', 'precipitation_winter',
                          'precipitation_spring', 'precipitation_summer', 'total_SWD_summer']
df = df[column_names_reordered]

Here we end up with 22 independent variables (related to the setting and weather variables), and one outcome variable ("total_SWD_summer") which is the total number of flies collected in a trap in each site during the summer solstice.

In [ ]:
# Create a checkpoint
df1 = df.copy()
In [ ]:
#Create dummy values for categorical variables, in some cases narrow down to only two variables to balance dataset. 

# host = "cherry" or "other"
df1['host'] = df1['host'].map({'cherry':0, 'raspberry':1, 'blackberry':1, 'peach':1, 'blueberry':1}).astype("category", copy = False)
# management = "unmanaged" or "managed"
df1['management'] = df1['management'].map({'unmanaged': 0, 'managed': 1}).astype("category", copy = False)
# setting = "agricultural" or not
df1['setting'] = df1['setting'].map({'agricultural': 0, 'urban': 1, 'forest': 1}).astype("category", copy = False)
# lure = "Trece+ACV" or "CACV"
df1['lure'] = df1['lure'].map({'Trece+ACV': 0, 'CACV': 1}).astype("category", copy = False)
In [ ]:
# Because the dataset is so small, instead of using counts as a continuous outcome variable, 
# we will divide the trap sites into two categories, "low risk" and "high risk"

swd_median = df1['total_SWD_summer'].median()
# This is for creating two categories for targets instead of a continuous variable, 
# depending on whether a trap count is above or below the median
targets = np.where(df1['total_SWD_summer'] > swd_median, 1, 0) 
unscaled_inputs = df1.drop(['total_SWD_summer'], axis = 1) # Eliminate the output variable from the dataset
In [ ]:
# Standardize all the numeric (integers or float) variables, returns data in form of numpy array

scaled_inputs = unscaled_inputs.copy()
for column in scaled_inputs:
    if scaled_inputs[column].dtype == np.float64 or scaled_inputs[column].dtype == np.int64:
        scaled_inputs[column] = preprocessing.scale(scaled_inputs[column].astype(np.float64, copy = False))

The data is all preprocessed here and we can move on to building and training a machine learning model

Part 2: Machine learning model

In [ ]:
# Use a "config" dictionary to determine and modify the hyperparameters of the model

config = {
    # For building the model
    "input_size": len(scaled_inputs.columns), # count number of inputs
    "output_size": 1, # binary choice 1 and 0
    "hidden_layer_size": 64,
    "dense_layer_activation": "relu",
    "output_activation": "sigmoid",
    # For compiling the model
    "optimizer": "adam",
    "loss": "binary_crossentropy",
    "metrics": ["accuracy"],
    # For fitting the model
    "batch_size": 1,
    "epochs": 10,
    "validation_freq": 1,
    # For callbacks
    # This learning rate schedule adjusts a dynamic learning rate as the model trains
    "learning_rate_schedule": [
        {
            "proportion": 0.33,
            "learning_rate": 0.001
        },
        {
            "proportion": 0.33,
            "learning_rate": 0.0001
        },
        {
            "proportion": 0.34,
            "learning_rate": 0.00001
        }
    ],
    "tensorboard": True,
    # For storing logs
    "work_dir": "/content/drive/My Drive/models/swd_trap_counts"
}
In [ ]:
# Create callbacks for model.fit

callbacks = []

if "learning_rate" in config:
    # Define learning rate schedule
    def learning_rate_schedule(epoch_index: int) -> float:
        epoch_until = 0
        for rule in config["learning_rate"]:
            epoch_until += int(rule['proportion'] * config['epochs'])
            if epoch_index < epoch_until:
                return rule['learning_rate']

    learning_rate_callback = tf.keras.callbacks.LearningRateScheduler(
        learning_rate_schedule, verbose=1)
    callbacks.append(learning_rate_callback)
    
if "tensorboard" in config:
    # Writing logs to monitor training in Tensorboard
    checkpoint_tensorboard = tf.keras.callbacks.TensorBoard(
        log_dir=os.path.join(config["work_dir"], 'logs'),
    )
    callbacks.append(checkpoint_tensorboard)
In [ ]:
def get_model():
    # set global random seed for Tensorflow
    tf.random.set_seed(1234)
    # Build model
    model = tf.keras.Sequential([
        tf.keras.layers.Input(shape = config["input_size"]),
        tf.keras.layers.Dense(
            units = config["hidden_layer_size"],
            activation = config["dense_layer_activation"]),
        tf.keras.layers.Dense(
            units = config["output_size"], 
            activation = config["output_activation"])
    ])

    model.compile(
        optimizer= config["optimizer"],
        loss= config["loss"],
        metrics= config["metrics"]
    )
    return model
In [ ]:
# Create a tensorflow dataset 

dataset = tf.data.Dataset.from_tensor_slices((scaled_inputs.values, targets))
dataset = dataset.shuffle(len(scaled_inputs)).batch(1)
In [ ]:
# Because we have so little data, we perform k-fold crossvalidation training
# We split the dataset into "k" folds (parts). Then the model trains k-times, 
# selecting different samples of the dataset as training and test data in each fold. 

# 1. Add index. Each "row" now has an index 0, 1, 2, 3, ..., N
dataset = dataset.enumerate()

# 2. Define number of folds. In this case, the number of folds is the number of
#    data items, i.e., we take out one data item and use the rest for training.
num_folds = len(scaled_inputs)

# 3. Go through each fold
for fold_index in range(num_folds):
    # Split the data into training and testing according to the fold_index
    train_list_ds = dataset.filter(lambda i, data: i % num_folds != fold_index)
    test_list_ds = dataset.filter(lambda i, data: i % num_folds == fold_index)

    # Remove the extra "index" column added so that we can more easily split
    # data items into folds.
    train_list_ds = train_list_ds.map(lambda i, data: data)
    test_list_ds = test_list_ds.map(lambda i, data: data)

    # Get a "fresh" model instance
    model = get_model()

    # Train the model, using the fold's test data as validation data
    model.fit(
        train_list_ds,
        batch_size = config["batch_size"],
        epochs = config["epochs"],
        callbacks = callbacks,
        validation_data = test_list_ds,
        validation_freq = config["epochs"],
        verbose = 0
    )
Out[ ]:
{'tags': ['hide-output']}
In [ ]:
# Load the TensorBoard notebook extension

%load_ext tensorboard
In [ ]:
# Set up Tensorboard to plot accuracy and loss as the model trains

%tensorboard --logdir="/content/drive/My Drive/models/swd_trap_counts"

tensorboard.png

The graph shows the validation and loss of k-folds number of models, so it is a bit messy to read. Because there is only one validation sample, the score is 1 (when the model predicted the outcome correctly), or 0 (when the model predicted the outcome incorrectly). To get a better idea of the validation accuracy, one can add the validation accuracy score from all the folds and divide it by the total number of folds.

In [ ]:
from tensorflow.python.summary.summary_iterator import summary_iterator
from pathlib import Path

validation_accuracy = []

for log_file_path in Path("/content/drive/My Drive/models/swd_trap_counts/logs/validation/").glob("*"):
  for event in summary_iterator(str(log_file_path)):
    for value in event.summary.value:
        if value.tag == 'epoch_accuracy':
            validation_accuracy.append(value.simple_value)
In [ ]:
np.asarray(validation_accuracy).sum()/len(validation_accuracy)
Out[ ]:
0.5934959349593496

With this particular dataset, we see that the training accuracy is only 0.59, not much better than chance. This means that, when given the environmental variables, the model can only correctly predict whether a site will be "high risk" or "low risk" 59% of the times. Either the predictors we chose are poor predictors to determine which sites will have high or low counts of SWD, or the dataset is too small to make a robust model.