A stratified train-test split in pandas creates training and test sets that preserve the class distribution (or distribution of another categorical variable) from the original DataFrame.
When your target variable is imbalanced (think churn prediction with only 6% churners, or a fraud model with 0.3% positives), a random train-test split can leave you with too few minority-class examples for model training or evaluation. A stratified split solves this by sampling each class proportionally, ensuring both training and test sets mirror the overall distribution.
A stratified split is a sampling technique that divides a dataset into two or more subsets, but does so stratum-by-stratum. A stratum is usually the target label, but it can also be any categorical column (e.g., geography, customer segment). By preserving class proportions, stratification guarantees that rare classes are represented in all subsets.
The vanilla `train_test_split()` from `sklearn.model_selection` uses uniform random sampling. On highly imbalanced data, this can lead to a test set that contains zero examples of the minority class, making metrics such as recall or AUC meaningless. Stratification maintains statistical power without needing to oversample or otherwise manipulate the data.
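To quantify the risk, here is a back-of-the-envelope sketch with made-up numbers (not from scikit-learn itself): with 1,000 rows and only three positives, roughly half of all plain 20% random splits produce a test set containing no positives at all.

```python
import numpy as np

rng = np.random.default_rng(0)
y = np.zeros(1_000, dtype=int)
y[:3] = 1  # exactly 3 minority examples (~0.3%)

# How often does a random 200-row test set contain zero positives?
misses = sum(
    y[rng.permutation(1_000)[:200]].sum() == 0
    for _ in range(10_000)
)
print(misses / 10_000)  # ~0.51: about half of random splits see no positives
```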
Although pandas itself does not ship a dedicated stratified splitter, we can marry pandas with scikit-learn's `StratifiedShuffleSplit` or `train_test_split()` with the `stratify` argument.
```python
import pandas as pd
from sklearn.model_selection import train_test_split

df = pd.read_csv("transactions.csv")
```

Here the stratum is the target `y`, but any categorical column works:

```python
y = df["is_fraud"]  # minority class ~0.3%

train_df, test_df = train_test_split(
    df,               # full DataFrame
    test_size=0.2,    # 20% test
    stratify=y,       # preserve class distribution
    random_state=42,  # reproducibility
)

print(df["is_fraud"].value_counts(normalize=True))
print(train_df["is_fraud"].value_counts(normalize=True))
print(test_df["is_fraud"].value_counts(normalize=True))
```
Scikit-learn's splitter computes indices for each class separately and then concatenates them. Note that `train_test_split()` shuffles rows by default, and stratification actually requires shuffling: passing `shuffle=False` together with `stratify` raises an error. The returned `train_df` and `test_df` keep their original DataFrame index labels, so every row can be traced back to the source frame.
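If you prefer scikit-learn's splitter classes, `StratifiedShuffleSplit` does the same job and can generate several independently resampled splits. A minimal sketch, reusing the `df` from the fraud example above:

```python
from sklearn.model_selection import StratifiedShuffleSplit

# One 80/20 stratified split; raise n_splits for repeated resampled splits
sss = StratifiedShuffleSplit(n_splits=1, test_size=0.2, random_state=42)
train_idx, test_idx = next(sss.split(df, df["is_fraud"]))
train_df, test_df = df.iloc[train_idx], df.iloc[test_idx]
```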
Always specify `random_state` for reproducible experiments. Without it, every run will yield a different split.
If a class has too few samples, stratification will error out: `train_test_split()` needs at least two members per class, and a cross-validation splitter needs at least `n_splits`. Consider combining tiny classes or using cross-validation with grouping constraints.
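One pandas-only way to combine tiny classes is to collapse ultra-rare categories into an "other" bucket before stratifying. A sketch, where the `segment` column and the threshold are hypothetical:

```python
# Replace categories with fewer than 10 rows by a shared "other" label
counts = df["segment"].value_counts()
rare = counts[counts < 10].index
df["segment_grouped"] = df["segment"].where(~df["segment"].isin(rare), "other")
```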
For k-fold cross-validation, use `StratifiedKFold` to achieve the same benefits across folds.
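A minimal sketch, again assuming the fraud DataFrame from earlier:

```python
from sklearn.model_selection import StratifiedKFold

skf = StratifiedKFold(n_splits=5, shuffle=True, random_state=42)
for fold, (train_idx, test_idx) in enumerate(skf.split(df, df["is_fraud"])):
    fold_train, fold_test = df.iloc[train_idx], df.iloc[test_idx]
    # every fold preserves the ~0.3% fraud rate
    print(fold, round(fold_test["is_fraud"].mean(), 4))
```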
For purely temporal data, use forward-chaining splits instead. Stratification can leak future information into the past.
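One common forward-chaining option is scikit-learn's `TimeSeriesSplit`, sketched here on a DataFrame assumed to be sorted by time:

```python
from sklearn.model_selection import TimeSeriesSplit

# Each fold trains on an expanding past window and tests on the next block
tscv = TimeSeriesSplit(n_splits=5)
for train_idx, test_idx in tscv.split(df):
    train_fold, test_fold = df.iloc[train_idx], df.iloc[test_idx]
```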
Putting it all together, here is a complete end-to-end example:

```python
import pandas as pd
from sklearn.datasets import load_breast_cancer
from sklearn.model_selection import train_test_split
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import classification_report

# Load sample data
X, y = load_breast_cancer(return_X_y=True, as_frame=True)
df = X.copy()
df["target"] = y

# Stratified split
train_df, test_df = train_test_split(
    df,
    test_size=0.3,
    stratify=df["target"],
    random_state=0,
)

# Train model
clf = RandomForestClassifier(random_state=0)
clf.fit(train_df.drop("target", axis=1), train_df["target"])

# Evaluate
preds = clf.predict(test_df.drop("target", axis=1))
print(classification_report(test_df["target"], preds))
```
If you mistakenly pass a continuous column to `stratify`, scikit-learn will either error out or treat every unique value as its own stratum, defeating the purpose. Always use a categorical or discrete label.
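If you genuinely need to stratify on a continuous variable, one workaround is to bin it first. A sketch, where the `price` column is hypothetical:

```python
# Bin a continuous column into quartiles, then stratify on the bins
df["price_bin"] = pd.qcut(df["price"], q=4, labels=False)
train_df, test_df = train_test_split(
    df, test_size=0.2, stratify=df["price_bin"], random_state=42
)
```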
After splitting, the original indices are preserved; downstream concatenation might duplicate index labels. Use `.reset_index(drop=True)` if index uniqueness matters.
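A minimal sketch, continuing from the split above:

```python
# Row labels are a shuffled subset of the original index after the split
print(train_df.index[:5])

# Reset to a clean 0..n-1 index if downstream code assumes uniqueness
train_df = train_df.reset_index(drop=True)
test_df = test_df.reset_index(drop=True)
```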
Some practitioners oversample/undersample before the split, causing train–test leakage. Always split first, then apply resampling only on the training data.
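A naive pandas-only sketch of the correct order of operations, reusing the fraud split from earlier (libraries such as imbalanced-learn offer more principled resamplers):

```python
# Split first (stratified), then oversample the minority class in training data only
minority = train_df[train_df["is_fraud"] == 1]
n_extra = len(train_df) - 2 * len(minority)  # copies needed to reach a 50/50 balance
extra = minority.sample(n=n_extra, replace=True, random_state=42)
balanced_train = pd.concat([train_df, extra]).sample(frac=1, random_state=42)  # shuffle
```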
Galaxy is primarily a SQL editor. A stratified split happens in Python during model development, not in SQL, so Galaxy isn’t directly involved. However, once you’ve trained a model, you might store results in a database and analyze them via Galaxy’s lightning-fast SQL editor.
Reach for `train_test_split(..., stratify=target)` for quick one-off splits.

In analytics workflows, improper sampling can skew evaluation metrics, leading to false confidence in model performance. A stratified split guarantees that both training and test sets are representative of the underlying distribution, which is critical for valid A/B tests, machine-learning pipelines, and regulatory reporting.
Does pandas have a built-in stratified split?
No. Pandas focuses on data manipulation. Use scikit-learn's `train_test_split` with the `stratify` argument or `StratifiedShuffleSplit`.

What if a class is too small to stratify?
If a class has too few rows, merge it with a similar class or use cross-validation techniques like `StratifiedKFold` with careful parameter choices. You can also employ resampling after splitting.

Is stratification only for classification targets?
It's most common in classification, but any categorical subgroup you need represented (e.g., business segment) can be used.

Can I stratify on multiple columns at once?
Yes. Create a combined categorical column, e.g., `df["strata"] = df["gender"] + "_" + df["region"]`, then pass it to `stratify`.