python

numpy.where使い方

条件に合ったindexを返す

numpy.where()は、条件式(condition)がTrueのときはxFalseのときはyとするndarrayを返す関数。numpy.where(condition, x, y)の形になります。

array = np.arange(10)
np.where(array < 5, 3, 6)
=>array([3, 3, 3, 3, 3, 6, 6, 6, 6, 6])

x,yを省略した場合は、条件式に当てはまる要素のindexが返ります(ndarryのタプルで返る)。

array = np.arange(10)
np.where(array < 5)
=>(array([0, 1, 2, 3, 4]),)
result, = np.where(array < 5)
result
=>array([0, 1, 2, 3, 4])

np.where(array<5)の戻り値はタプルですが(今回は1つ)、最後のカンマがいらないので、アンパックすることで1つ目のarrayを取り出すことができます。

条件式を複数指定することもできます。その場合は条件式を()で囲み、&や|を使用して条件式を指定することで処理が可能となります。