Undersampling majority samples

This example creates an imbalanced classification dataset and under-samples the majority class to move the class ratio toward balance. Under-sampling is generally only advisable if your dataset is large enough that you can afford to discard some majority-class samples. If it is not, a model trained on the reduced data may suffer from high variance.
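To make the idea concrete, here is a minimal NumPy-only sketch of random under-sampling. It is not skoot's implementation: `naive_under_sample` is a hypothetical helper, and it assumes a binary target and that `balance_ratio` means the desired minority-to-majority ratio.

```python
import numpy as np


def naive_under_sample(X, y, balance_ratio=0.2, random_state=None):
    """Hypothetical helper (not part of skoot): randomly drop majority-class
    rows until n_minority / n_majority is approximately ``balance_ratio``."""
    X = np.asarray(X)
    y = np.asarray(y)
    rng = np.random.RandomState(random_state)

    classes, counts = np.unique(y, return_counts=True)
    if classes.shape[0] != 2:
        raise ValueError("this sketch only handles binary targets")
    minority = classes[np.argmin(counts)]
    n_minority = counts.min()

    # Number of majority rows to keep so that minority / majority == ratio,
    # capped at the number of majority rows actually available.
    n_keep = min(int(counts.max()), int(round(n_minority / balance_ratio)))

    minority_idx = np.flatnonzero(y == minority)
    majority_idx = np.flatnonzero(y != minority)
    kept_majority = rng.choice(majority_idx, size=n_keep, replace=False)

    # Keep all minority rows plus the sampled majority rows, in original order.
    keep = np.sort(np.concatenate([minority_idx, kept_majority]))
    return X[keep], y[keep]


# Toy data: 100 minority rows (label 0) and 900 majority rows (label 1).
y_toy = np.array([0] * 100 + [1] * 900)
X_toy = np.arange(2000).reshape(1000, 2)
X_small, y_small = naive_under_sample(X_toy, y_toy, balance_ratio=0.2,
                                      random_state=0)
```

On the toy data, all 100 minority rows are kept and the majority class is cut from 900 to 500 rows, which is the kind of information loss the paragraph above warns about.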


Out:

Num zero class (pre-balance): 513
Num one class (pre-balance): 4487

Num zero class (post-balance): 513
Num one class (post-balance): 2565
Num samples (post-balance): 3078
             0  ...        19
3475  1.291530  ...  0.193104
497  -1.520067  ... -3.660372
2767  2.465927  ... -2.666850
3226 -0.161823  ... -0.314925
4905  0.103988  ...  0.427975

[5 rows x 20 columns]
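The post-balance counts above can be checked with a little arithmetic. The sketch below assumes that `balance_ratio` is the target minority-to-majority ratio, which is consistent with the printed numbers but is an inference from the output rather than a statement from the library's documentation.

```python
# Counts taken from the printed output above.
n_minority = 513          # zero class, unchanged by balancing
balance_ratio = 0.2       # value passed in this example

# Assumption: balance_ratio is the target minority-to-majority ratio.
n_majority_expected = round(n_minority / balance_ratio)
n_total_expected = n_minority + n_majority_expected

print(n_majority_expected)  # 2565
print(n_total_expected)     # 3078
```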

print(__doc__)

# Author: Taylor Smith <taylor.smith@alkaline-ml.com>

from sklearn.datasets import make_classification
from skoot.balance import under_sample_balance
import pandas as pd

# #############################################################################
# Create an imbalanced dataset
X, y = make_classification(n_samples=5000, n_classes=2, weights=[0.1, 0.9],
                           random_state=42)

# get counts:
zero_mask = y == 0
print("Num zero class (pre-balance): %i" % zero_mask.sum())
print("Num one class (pre-balance): %i\n" % (~zero_mask).sum())

# #############################################################################
# Balance the dataset
X_balance, y_balance = under_sample_balance(X, y, balance_ratio=0.2,
                                            random_state=42)

# get the new counts
new_mask = y_balance == 0
print("Num zero class (post-balance): %i" % new_mask.sum())
print("Num one class (post-balance): %i" % (~new_mask).sum())
print("Num samples (post-balance): %i" % X_balance.shape[0])

# #############################################################################
# This also works for pandas DataFrames

X_balance_df, _ = under_sample_balance(pd.DataFrame.from_records(X),
                                       y, balance_ratio=0.2,
                                       random_state=42)

print(X_balance_df.head())
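The frame printed in the output keeps row labels such as 3475 and 497, which suggests that the original index travels with the selected rows. A pandas-only sketch of that behaviour follows. It is not skoot's code, and the toy frame, `labels`, and `balanced` are hypothetical names used only for illustration.

```python
import numpy as np
import pandas as pd

# Toy frame: 20 minority rows (label 0) and 180 majority rows (label 1).
df = pd.DataFrame({"feature": np.arange(200.0)})
labels = pd.Series([0] * 20 + [1] * 180, index=df.index)

# Keep every minority row; sample majority rows down to 20 / 0.2 = 100.
majority_rows = df[labels == 1].sample(n=100, random_state=42)
balanced = pd.concat([df[labels == 0], majority_rows])

# The original row labels survive, so the matching targets can be looked up.
balanced_labels = labels.loc[balanced.index]
```

Because the index is preserved, the retained rows can be aligned with any other data keyed on the original row labels.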

