Data Leakage: The Achilles’ Heel of Machine Learning Models
In the ever-evolving landscape of artificial intelligence and machine learning, data is often hailed as the lifeblood of algorithms, fueling their ability to make accurate predictions and decisions. However, amidst the excitement of developing sophisticated models, there lurks a pervasive threat that has the potential to undermine the integrity and performance of these systems: data leakage.
Data leakage occurs when information that would not be available at prediction time inadvertently makes its way into the training process, skewing the model’s apparent performance and leading to erroneous conclusions. The phenomenon is particularly insidious because it often goes unnoticed until the model is deployed in real-world scenarios, where its flawed predictions can have significant consequences.
More precisely, data leakage refers to situations where some form of the label “leaks” into the set of features used for making predictions, even though that same information is not available during inference.
Data leakage is challenging because the leakage is often nonobvious. It’s dangerous because it can cause your models to fail in an unexpected and spectacular way, even after extensive evaluation and testing. Let’s go over an example to demonstrate what data leakage is.
Suppose you want to build an ML model to predict whether a CT scan of a lung shows signs of cancer. You obtained the data from hospital A, removed the doctors’ diagnosis from the data, and trained your model. It did really well on the test data from hospital A, but poorly on the data from hospital B.
After extensive investigation, you learned that at hospital A, when doctors think a patient has lung cancer, they send that patient to a more advanced scan machine, which outputs slightly different CT scan images. Your model learned to rely on information about which scan machine was used to predict whether a scan image shows signs of lung cancer. Hospital B assigns patients to CT scan machines at random, so your model had no such signal to rely on. We say that the labels leaked into the features during training.
Data leakage doesn’t happen only to newcomers to the field; it has also happened to several experienced researchers whose work I admire, and in one of my own projects. Despite its prevalence, data leakage is rarely covered in ML curricula.
Common Causes of Data Leakage
1. Splitting time-correlated data randomly instead of by time
When we study machine learning, we’re usually told to divide our data into three parts: training, validation, and testing, and to do so randomly. This is also how researchers often divide their data in their studies. But when the data is correlated in time, a random split can itself introduce data leakage.
Suppose we’re trying to predict something from data that changes over time, like stock prices. Prices of similar stocks tend to move together, and today’s prices are strongly correlated with tomorrow’s. If we randomly mix data from different days into the training and test sets, we accidentally give the model information from the future. This can make the model seem better than it actually is, because it’s using information it shouldn’t have had.
The problem isn’t limited to obvious cases like stock prices. Even when the time correlation is less clear, such as predicting whether someone will click on a song recommendation, it can still appear. For example, if a famous artist passes away, listens to their songs spike that day. If samples from that day end up in both the training and test splits, the model can exploit that day-specific spike, even though it’s not the signal we want it to learn.
To avoid this problem, we should split our data based on time whenever we can. For instance, if we have data from five weeks, we should use the first four weeks for training and the fifth week for testing. This way, we make sure our model isn’t getting any hints about the future when it’s learning.
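As a rough sketch of what that looks like in practice, here is a chronological split with pandas, assuming a hypothetical DataFrame with a timestamp column and five weeks of daily data:

```python
import pandas as pd

# Hypothetical dataset: five weeks of daily observations with a timestamp column.
df = pd.DataFrame({
    "timestamp": pd.date_range("2024-01-01", periods=35, freq="D"),
    "feature": range(35),
    "label": [i % 2 for i in range(35)],
})

# Sort chronologically, then split by time: the first four weeks go to training,
# the fifth week to testing, so no information from the future reaches the model.
df = df.sort_values("timestamp")
cutoff = df["timestamp"].min() + pd.Timedelta(weeks=4)
train = df[df["timestamp"] < cutoff]
test = df[df["timestamp"] >= cutoff]
```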
2. Scaling before splitting
Scaling your features is an important step in preparing your data for machine learning. It helps to ensure that all features have a similar scale, which can improve the performance of certain algorithms.
However, there’s a common mistake that can lead to data leakage when scaling. This happens when you calculate statistics like the mean and variance using the entire dataset before splitting it into training and testing sets. When you do this, you’re inadvertently including information from the testing set in the statistics used to scale the training set. This can cause your model to perform better during testing than it would in real-world situations because it’s using information it shouldn’t have had access to.
To avoid this problem, always split your data into training and testing sets before scaling. Then, calculate the mean and variance using only the training set and apply those statistics to scale all the sets. Some experts even recommend splitting your data before doing any exploratory data analysis or processing to prevent accidentally gaining insights into the testing set. This ensures that your model learns only from the training data and generalizes well to unseen data.
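A minimal sketch of the fit-on-train-only pattern with scikit-learn, assuming a generic feature matrix X and labels y:

```python
import numpy as np
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler

# Hypothetical feature matrix and labels.
X = np.random.rand(100, 3)
y = np.random.randint(0, 2, size=100)

# Split first...
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

# ...then compute the scaling statistics from the training split only,
# and reuse those same statistics to scale the test split.
scaler = StandardScaler()
X_train_scaled = scaler.fit_transform(X_train)
X_test_scaled = scaler.transform(X_test)
```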
3. Filling in missing data with statistics from the test split
When dealing with missing values in a feature, a common approach is to replace them with the mean or median of the available values. However, if you calculate the mean or median using all the data instead of just the training set, it can cause data leakage. To prevent this, only use statistics from the training set to fill in missing values for all sets.
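The same pattern applies to imputation. A short sketch using a hypothetical "age" column:

```python
import numpy as np
import pandas as pd

# Hypothetical train and test splits with missing values in "age".
train = pd.DataFrame({"age": [25, np.nan, 40, 33]})
test = pd.DataFrame({"age": [np.nan, 52]})

# Compute the fill value from the training split only...
train_median = train["age"].median()

# ...and apply that same value to every split.
train["age"] = train["age"].fillna(train_median)
test["age"] = test["age"].fillna(train_median)
```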
4. Poor handling of data duplication before splitting
If your data has duplicates or very similar samples, leaving them in before splitting can cause the same data to end up in both the training and testing sets. This is a common issue, even in widely used datasets like CIFAR-10 and CIFAR-100. For example, it was only discovered in 2019 that some images from the test sets of these datasets were duplicates of images in the training set.
Data duplication can happen during collection or when merging different datasets. For instance, in COVID-19 research, one dataset combined several others without realizing they overlapped. It can also occur during data processing, such as when oversampling leads to duplicates.
To avoid this, always check for duplicates before splitting your data, and again after splitting to double-check. If you’re oversampling your data, do it after splitting to prevent duplicates from affecting your model.
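One way to do both checks with pandas, sketched on a hypothetical DataFrame:

```python
import pandas as pd
from sklearn.model_selection import train_test_split

# Hypothetical dataset containing a duplicated row.
df = pd.DataFrame({"text": ["a", "b", "a", "c", "d", "e"], "label": [0, 1, 0, 1, 0, 1]})

# Drop exact duplicates before splitting.
df = df.drop_duplicates()

# Split, then double-check that no row appears in both splits.
train, test = train_test_split(df, test_size=0.3, random_state=42)
overlap = pd.merge(train, test, how="inner")
assert overlap.empty, f"{len(overlap)} rows appear in both train and test"
```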
5. Group leakage
Group leakage happens when a group of examples with strongly correlated labels is divided across different splits. For example, a patient might have two lung CT scans taken a week apart, which likely carry the same label for whether they show signs of lung cancer, yet one ends up in the train split and the other in the test split. This type of leakage is common in object detection tasks that contain photos of the same object taken milliseconds apart: some of them land in the train split while others land in the test split. It’s hard to avoid this type of data leakage without understanding how your data was generated.
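When you know which rows belong together (for example, through a hypothetical patient_id), a group-aware splitter keeps them in the same split. A minimal sketch with scikit-learn’s GroupShuffleSplit:

```python
import numpy as np
from sklearn.model_selection import GroupShuffleSplit

# Hypothetical data: each row is one CT scan; patient_ids marks which scans
# come from the same patient (two scans per patient here).
X = np.random.rand(10, 5)
y = np.random.randint(0, 2, size=10)
patient_ids = [0, 0, 1, 1, 2, 2, 3, 3, 4, 4]

# GroupShuffleSplit never puts scans from the same patient into different splits.
splitter = GroupShuffleSplit(n_splits=1, test_size=0.2, random_state=42)
train_idx, test_idx = next(splitter.split(X, y, groups=patient_ids))
X_train, X_test = X[train_idx], X[test_idx]
y_train, y_test = y[train_idx], y[test_idx]
```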
6. Leakage from data generation process
The example mentioned earlier, where information about lung cancer diagnosis leaks through the CT scan machine, illustrates this type of data leakage. Detecting such leakage requires a deep understanding of how data is gathered. For instance, it’s challenging to realize that a model performs poorly in one hospital due to differences in scan machines unless you’re aware of these variations.
Preventing this type of leakage completely is difficult, but you can reduce the risk by understanding how your data is collected and processed. Normalize your data so that information from different sources has similar characteristics. For example, if CT scan machines produce images with different resolutions, standardizing the resolution of all images can help obscure their origin. Additionally, involve subject matter experts in your machine learning process to gain insights into data collection nuances and improve model design.
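For the CT scan example, one small step in that direction is to resample every image to a common resolution before training. A sketch with Pillow, where the target size and helper name are assumptions:

```python
from PIL import Image

TARGET_SIZE = (512, 512)  # assumed common resolution for all scans


def standardize_scan(path: str) -> Image.Image:
    """Load a scan and resize it so machine-specific resolution differences
    no longer hint at which machine (or hospital) produced it."""
    img = Image.open(path).convert("L")  # single-channel grayscale
    return img.resize(TARGET_SIZE)
```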
Detecting Data Leakage
Data leakage can occur at various stages of a machine learning project, including data generation, collection, sampling, splitting, processing, and feature engineering. It’s crucial to monitor for data leakage throughout the entire project lifecycle.
One way to detect potential leakage is to measure the predictive power of each feature, or group of features, with respect to the target variable. If a feature shows unusually high correlation with the target, investigate how it’s generated and whether that correlation makes sense. Sometimes two features contain no leakage on their own but leak together. For instance, when predicting how long an employee will stay at a company, the start date or the end date alone reveals little, but the two together give away the tenure itself.
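One simple way to screen for suspicious features is to score each one against the target, for example with mutual information, and investigate the outliers. A sketch on hypothetical data:

```python
import numpy as np
import pandas as pd
from sklearn.feature_selection import mutual_info_classif

# Hypothetical feature matrix and binary target.
rng = np.random.default_rng(0)
X = pd.DataFrame(rng.normal(size=(200, 4)), columns=["f1", "f2", "f3", "f4"])
y = rng.integers(0, 2, size=200)

# Mutual information between each feature and the target; a feature with a
# suspiciously high score is a candidate for a leakage investigation.
scores = pd.Series(mutual_info_classif(X, y, random_state=0), index=X.columns)
print(scores.sort_values(ascending=False))
```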
Ablation studies can help assess the importance of features to your model. If removing a feature significantly worsens the model’s performance, delve into why that feature is crucial. While conducting ablation studies on every feature combination may be impractical for large datasets, focusing on suspected crucial features can still yield insights. Subject matter expertise can guide this process. Ablation studies can be performed offline during downtime.
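A bare-bones ablation loop might look like the following, assuming a feature DataFrame X and labels y; the model and metric here are placeholders, not a prescription:

```python
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import cross_val_score


def ablation_scores(X, y):
    """Cross-validated score with all features, then with each feature removed."""
    baseline = cross_val_score(RandomForestClassifier(random_state=0), X, y, cv=5).mean()
    results = {"all features": baseline}
    for col in X.columns:
        score = cross_val_score(
            RandomForestClassifier(random_state=0), X.drop(columns=[col]), y, cv=5
        ).mean()
        # A large drop relative to the baseline flags a feature worth investigating.
        results[f"without {col}"] = score
    return results
```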
When adding new features, be cautious if their inclusion drastically boosts model performance. It could mean the feature is genuinely valuable, or it might contain leaked information about the labels.
Finally, be vigilant when analyzing the test split. Any use of the test split beyond reporting final model performance, such as generating new feature ideas or tuning hyperparameters, risks leaking information from the test split into the training process.
Summary
- Split data by time into train/valid/test splits instead of doing it randomly.
- If you oversample your data, do it after splitting.
- Scale and normalize your data after splitting to avoid data leakage.
- Use statistics from only the train split, instead of the entire data, to scale your features and handle missing values.
- Understand how your data is generated, collected, and processed. Involve domain experts if possible.
- Keep track of your data’s lineage.
- Understand feature importance to your model.
- Use features that generalize well.
- Remove no longer useful features from your models.