One-hot encoding

Demonstrates how to use the DummyEncoder. For a more comprehensive explanation, take a look at the demo on alkaline-ml.com.
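One-hot (dummy) encoding replaces a categorical column with one indicator column per level observed during fitting. The sketch below is a minimal pure-Python illustration of that idea, not skoot's implementation. The class name `TinyDummyEncoder` and its methods are hypothetical. It learns the levels in `fit` and, mirroring the spirit of `handle_unknown='ignore'`, emits all zeros for a level it has never seen.

```python
class TinyDummyEncoder:
    """Hypothetical minimal one-hot encoder (illustration only, not skoot's API)."""

    def fit(self, rows, cols):
        # Learn the sorted set of levels observed for each categorical column
        self.cols = list(cols)
        self.levels = {c: sorted({row[c] for row in rows}) for c in self.cols}
        return self

    def feature_names(self):
        # One output column per (column, level) pair, named "<column>_<level>"
        return [f"{c}_{lvl}" for c in self.cols for lvl in self.levels[c]]

    def transform(self, rows):
        out = []
        for row in rows:
            # Non-categorical fields pass through unchanged
            new = {k: v for k, v in row.items() if k not in self.cols}
            for c in self.cols:
                for lvl in self.levels[c]:
                    # An unseen level matches no learned level, so every
                    # indicator for this column stays 0 (the "ignore" behavior)
                    new[f"{c}_{lvl}"] = 1 if row[c] == lvl else 0
            out.append(new)
        return out


train = [
    {"age": 27, "native-country": "United-States"},
    {"age": 45, "native-country": "Vietnam"},
]
enc = TinyDummyEncoder().fit(train, cols=["native-country"])
print(enc.feature_names())
# ['native-country_United-States', 'native-country_Vietnam']
print(enc.transform([{"age": 30, "native-country": "Atlantis"}]))
# [{'age': 30, 'native-country_United-States': 0, 'native-country_Vietnam': 0}]
```

The design choice worth noting is that the output columns are fixed at fit time, so training and test data always produce the same column layout regardless of which levels appear later.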


Out:

Test transformation:
       age  ...  native-country_Vietnam
14160   27  ...                     0.0
27048   45  ...                     0.0
28868   29  ...                     0.0
5667    30  ...                     0.0
7827    29  ...                     0.0

[5 rows x 99 columns]

Applied on a row with a new native-country:
       age  ...  native-country_Vietnam
14160   27  ...                     0.0

[1 rows x 99 columns]

print(__doc__)

# Author: Taylor Smith <taylor.smith@alkaline-ml.com>

from skoot.datasets import load_adult_df
from skoot.preprocessing import DummyEncoder
from skoot.utils.dataframe import get_categorical_columns
from sklearn.model_selection import train_test_split
import pandas as pd

# #############################################################################
# load & split the data
adult = load_adult_df(tgt_name="target")
y = adult.pop("target")

# we don't want this column
_ = adult.pop("education-num")

X_train, X_test, y_train, y_test = train_test_split(adult, y, random_state=42,
                                                    test_size=0.2)

# #############################################################################
# Fit a dummy encoder
obj_cols = get_categorical_columns(X_train).columns
encoder = DummyEncoder(cols=obj_cols, handle_unknown='ignore', n_jobs=4)
encoder.fit(X_train, y_train)

# #############################################################################
# Apply to the test set
print("Test transformation:")
print(encoder.transform(X_test).head())

# #############################################################################
# Show we can work with levels we've never seen before
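test_row = X_test.iloc[0].copy()  # copy so X_test itself is not modified
test_row["native-country"] = "Atlantis"  # a level never seen during fit
trans = encoder.transform(pd.DataFrame([test_row]))
print("\nApplied on a row with a new native-country:")
print(trans)

# with handle_unknown='ignore', an unseen level should map to all zeros
# across every native-country dummy column
nc_mask = trans.columns.str.contains("native-country")
assert trans[trans.columns[nc_mask]].sum().sum() == 0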
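The same "unseen level maps to all zeros" behavior can be approximated with plain pandas, which may be handy for sanity-checking the encoder's output. This is a swapped-in technique, not how skoot implements `DummyEncoder`: `pd.get_dummies` encodes each frame, and `reindex(columns=..., fill_value=0)` aligns the test frame to the training columns, dropping dummies for unseen levels and zero-filling training levels absent from the test rows. The toy frames below are made up for illustration.

```python
import pandas as pd

# Toy stand-ins for X_train / X_test (illustrative data, not the adult set)
train = pd.DataFrame({"age": [27, 45, 29],
                      "native-country": ["United-States", "Vietnam", "Mexico"]})
test = pd.DataFrame({"age": [30], "native-country": ["Atlantis"]})

# Encode the training frame and remember its output columns
train_enc = pd.get_dummies(train, columns=["native-country"])
train_cols = train_enc.columns

# Encode the test frame, then align it to the training columns:
# "native-country_Atlantis" is dropped, missing training dummies become 0
test_enc = pd.get_dummies(test, columns=["native-country"]).reindex(
    columns=train_cols, fill_value=0)

# Every native-country indicator should be 0 for the unseen level
nc_mask = test_enc.columns.str.contains("native-country")
assert test_enc.loc[:, nc_mask].to_numpy().sum() == 0
print(test_enc)
```

Unlike a fitted encoder, this approach requires you to keep `train_cols` around yourself, which is exactly the state an encoder object stores after `fit`.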

Total running time of the script: ( 0 minutes 1.526 seconds)

Gallery generated by Sphinx-Gallery