RでK-means++を実装する

R に付属の関数 kmeans を使って,K-means++をなるべく高速に実装します.特に必要となるライブラリはありません.

コード

検証

単純な例で速度を比較してみます.まずデータを生成します.

> x <- rbind(matrix(rnorm(1e6, 0, 1), ncol=2), matrix(rnorm(1e6, 5, 1), ncol=2))

データ数1,000,000,次元数2,クラスタ数2のデータセットです.
速度は10回の平均で測ります,結果を入れるベクトルを用意しておきます.

> t <- numeric(10)

まずはベースライン.普通のK-meansです.

> for (i in 1:10) t[i] <- system.time(kmeans(x, 2))[1]; mean(t)
[1] 0.5734  # 単位は秒

次に,上に載せたK-means++です.devtoolsライブラリにあるsource_gist関数を使うと,gistのコードを直接読み込むことができます.

> install.packages("devtools")
> library(devtools)
> source_gist("ef54a3b17fff4629f106")
> for (i in 1:10) t[i] <- system.time(kmeansp2(x, 2))[1]; mean(t)
[1] 0.6614

クラスタ中心の初期値を選ぶ際に必要となる距離計算の分だけ,少し遅くなっていますが,いい感じです.
また,ライブラリLICORSにK-means++の実装があります.これを使ってみます.

> install.packages("LICORS")
> library(LICORS)
> for (i in 1:10) t[i] <- system.time(kmeanspp(x, 2, nstart = 1))[1]; mean(t)
[1] 2.7971

約4倍も遅いです.これは,距離計算がベクトル化されておらず,apply関数(for文)で実装されているためです.
最後に,このコードを試してみます.ライブラリpracma にあるdistmat関数を使って距離計算をしています.

> install.packages("pracma")
> library(pracma)
> for (i in 1:10) t[i] <- system.time(kmpp(x, 2))[1]; mean(t)
[1] 4.7541

8倍ほど遅くなってしまっています.これもfor文(apply関数)の使用が原因です.
このように,for文を避けて処理をベクトル化することで,高速に処理できます. R上でK-means++を使うときには,ぜひ上で紹介したコードを使ってみてください.

高速化のポイント

あるベクトル q から,データセット X 中の各ベクトルへの(ユークリッド)距離を全て求める計算を,どのように実装するかがポイントです(上のコードでは10行目).
例えば,以下のようにデータを生成します.

> x <- matrix(runif(3e6), ncol=3)  # データ数1,000,000,次元数3
> q <- runif(3)

まずは,for文を使って普通に距離を計算してみます.

> dist <- numeric(nrow(x))  # 結果を入れるベクトル
> system.time(for (i in 1:nrow(x)) dist[i] <- sqrt(sum((q - x[i,])^2)))
   user  system elapsed
  2.017   0.007   2.024

apply関数を使うと,もう少し簡潔に書けます.

> system.time(dist <- apply(x, 1, function(x) sqrt(sum((x - q)^2))))
   user  system elapsed
  2.392   0.027   2.418

しかし,速度はほぼ同じ,むしろ遅くなってしまっています.そこで,この処理を以下のようにベクトル化します.

> system.time(dist <- sqrt(colSums((t(x) - q)^2)))
   user  system elapsed
  0.023   0.001   0.023

なんと100倍くらい速くなりました.

参考文献

Arthur, D. and Vassilvitskii, S.: K-means++: The Advantages of Careful Seeding, In Proc. of the 18th Annual ACM-SIAM Symposium on Discrete Algorithms, 1027-1035 (2007).