9.2 K-means Clustering
Another popular machine learning technique is k-means clustering. It seeks to group your data into a fixed number of clusters based on a measure of distance. Because the algorithm works with distances, an important pre-processing step is to center and scale your data so that no single variable dominates simply because of its units.
# Center and scale each variable; as.numeric() drops the matrix attributes scale() returns
cats <- cats %>%
  mutate(scale_weight = as.numeric(scale(weight)),
         scale_wander = as.numeric(scale(wander_dist)),
         scale_age    = as.numeric(scale(age)))
We use the kmeans function to perform the clustering. We need to pass it a data frame, the number of centers we want, and another argument, nstart, which runs the algorithm from several random starting configurations and keeps the best solution, so a single unlucky start is less likely to leave us stuck in a poor local optimum.
cats_cluster <-
  kmeans(x = cats %>% dplyr::select(scale_weight, scale_wander, scale_age),
         centers = 3,
         nstart = 20)
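If you want to convince yourself that nstart matters, one possible check (not part of the lesson above, and the exact numbers will depend on your data and the random seed) is to compare the total within-cluster sum of squares from a single random start with the multi-start fit; the multi-start result should be no worse, and is often better.
# Sketch: compare one random start with 20 restarts via tot.withinss
set.seed(42)
single_start <- kmeans(cats %>% dplyr::select(scale_weight, scale_wander, scale_age),
                       centers = 3, nstart = 1)
multi_start <- kmeans(cats %>% dplyr::select(scale_weight, scale_wander, scale_age),
                      centers = 3, nstart = 20)
single_start$tot.withinss
multi_start$tot.withinss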
We can inspect the structure of the result with str, and look at the coordinates of the three cluster centers (on the scaled variables) with the centers component.
str(cats_cluster)
cats_cluster$centers
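The centers are reported on the scaled variables, which can be hard to interpret directly. One possible sketch for converting them back to the original units, assuming the columns were scaled with the default scale() behaviour (subtract the mean, divide by the standard deviation) as above, is:
# Sketch: undo the scaling so the centers are in the original units
# (column order matches the centers matrix: weight, wander_dist, age)
orig_means <- c(mean(cats$weight), mean(cats$wander_dist), mean(cats$age))
orig_sds   <- c(sd(cats$weight),   sd(cats$wander_dist),   sd(cats$age))
centers_original <- sweep(cats_cluster$centers, 2, orig_sds, "*")   # scaled value * sd
centers_original <- sweep(centers_original, 2, orig_means, "+")     # ... + mean
centers_original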
We can use ggplot to visualize the clusters. First we attach the cluster assignments to the data and convert the cluster centers to a data frame, then plot each pair of variables, drawing the cluster centers in black.
cats$cluster <- factor(cats_cluster$cluster)
cluster_centers <- as.data.frame(cats_cluster$centers)
ggplot(data = cats) +
  geom_point(aes(x = scale_age, y = scale_weight, color = cluster), size = 4) +
  geom_point(data = cluster_centers, aes(x = scale_age, y = scale_weight),
             color = 'black', size = 8) +
  theme_bw(base_size = 18)
ggplot(data = cats) +
  geom_point(aes(x = scale_age, y = scale_wander, color = cluster), size = 4) +
  geom_point(data = cluster_centers, aes(x = scale_age, y = scale_wander),
             color = 'black', size = 8) +
  theme_bw(base_size = 18)
ggplot(data = cats) +
  geom_point(aes(x = scale_weight, y = scale_wander, color = cluster), size = 4) +
  geom_point(data = cluster_centers, aes(x = scale_weight, y = scale_wander),
             color = 'black', size = 8) +
  theme_bw(base_size = 18)
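As a quick sanity check on what the plots show, the kmeans object also reports how many points fell in each cluster and how tight each cluster is. This is just a suggested follow-up, not part of the plots above:
# Number of points assigned to each cluster
cats_cluster$size
# Within-cluster sum of squares for each cluster (smaller = tighter)
cats_cluster$withinss
# Cross-check against the assignments we attached to the data frame
table(cats$cluster)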