PythonPandas.com

numpy.argmax() in Python



numpy.argmax() returns the index (or indices) of the maximum element along the specified axis.

Syntax – numpy.argmax(a, axis=None, out=None)

a : array_like — input array.
axis : int or None — axis along which to find the indices of the maximum. None flattens.
out : ndarray, optional — alternative output array to store result (must have appropriate shape and dtype).

Return: index (or indices) of maximum values. If axis is specified, returns an array of indices with the dimension along axis removed.
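To make the return shapes concrete, here is a minimal sketch on a small array chosen only for illustration (the values and variable names are not from the examples that follow). It shows that axis=None yields a single index into the flattened array, while passing an axis removes that dimension from the result.

```python
import numpy as np

# Small illustrative array with shape (2, 3)
a = np.array([[1, 9, 3],
              [7, 2, 8]])

flat = np.argmax(a)             # axis=None: one index into the flattened array
per_col = np.argmax(a, axis=0)  # axis 0 is removed, so the result has shape (3,)
per_row = np.argmax(a, axis=1)  # axis 1 is removed, so the result has shape (2,)

print(flat)                    # 1
print(per_col, per_col.shape)  # [1 0 1] (3,)
print(per_row, per_row.shape)  # [1 2] (2,)
```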

Example 1 — 1D array (ties: first occurrence)

import numpy as np

arr = np.array([1, 3, 3, 2])
idx = np.argmax(arr)
print("Array:", arr)
print("argmax:", idx)

Output:

Array: [1 3 3 2]
argmax: 1

Explanation: The maximum value is 3 and appears at indices 1 and 2, but argmax returns the first occurrence 1.
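If every tied position is needed rather than just the first, one possible approach (a sketch, not something argmax does itself) is to compare the array against its maximum. The last occurrence can be found by running argmax on the reversed array.

```python
import numpy as np

arr = np.array([1, 3, 3, 2])

first = np.argmax(arr)                        # first occurrence only
all_max = np.flatnonzero(arr == arr.max())    # every position holding the maximum
last = len(arr) - 1 - np.argmax(arr[::-1])    # last occurrence, via the reversed array

print(first)    # 1
print(all_max)  # [1 2]
print(last)     # 2
```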

Example 2 — 2D array, axis=None (flattened) and axis-specific

import numpy as np

a = np.arange(12).reshape(3, 4)
print("Array:\n", a)

# Flattened global maximum index (index in flattened array)
flat_idx = np.argmax(a)            # same as np.argmax(a, axis=None)
print("\nargmax (flattened):", flat_idx)

# Column-wise max indices (axis=0)
col_idx = np.argmax(a, axis=0)
print("argmax axis=0 (per column):", col_idx)

# Row-wise max indices (axis=1)
row_idx = np.argmax(a, axis=1)
print("argmax axis=1 (per row):", row_idx)

Output:

Array:
 [[ 0  1  2  3]
 [ 4  5  6  7]
 [ 8  9 10 11]]

argmax (flattened): 11
argmax axis=0 (per column): [2 2 2 2]
argmax axis=1 (per row): [3 3 3]

Notes:

Flattened index 11 corresponds to element 11 at position (2, 3).

axis=0 gives the row index of the max for each column.

axis=1 gives the column index of the max for each row.
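The per-axis indices can be fed back into the array to recover the maximum values themselves. The sketch below uses np.take_along_axis for this, on a small array chosen for illustration. The index array has to be given the same number of dimensions as the input, which is why a length-1 axis is added before the call.

```python
import numpy as np

a = np.array([[4, 9, 1],
              [7, 2, 8]])

row_idx = np.argmax(a, axis=1)  # column index of each row's maximum
row_max = np.take_along_axis(a, row_idx[:, None], axis=1).ravel()

col_idx = np.argmax(a, axis=0)  # row index of each column's maximum
col_max = np.take_along_axis(a, col_idx[None, :], axis=0).ravel()

print(row_idx, row_max)  # [1 2] [9 8]
print(col_idx, col_max)  # [1 0 1] [7 9 8]
```

The recovered values match a.max(axis=1) and a.max(axis=0), which is a quick way to check that the indices are being read along the right axis.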

Example 3 — explicit 2D matrix

import numpy as np

a = np.array([
    [ 0,  3,  8, 13],
    [12, 11,  2, 11],
    [ 5, 13,  8,  3],
    [12, 15,  3,  4]
])
print("Array:\n", a)

print("\nGlobal max value:", a.max())
print("argmax (flattened):", np.argmax(a))
print("argmax axis=0 (per column):", np.argmax(a, axis=0))
print("argmax axis=1 (per row):", np.argmax(a, axis=1))

Output:

Array:
 [[ 0  3  8 13]
 [12 11  2 11]
 [ 5 13  8  3]
 [12 15  3  4]]

Global max value: 15
argmax (flattened): 13
argmax axis=0 (per column): [1 3 0 0]
argmax axis=1 (per row): [3 0 1 1]

Note that the global maximum value (15) and its flattened index (13) are different things: a.max() returns the value, while np.argmax(a) returns the index i such that a.flat[i] is the maximum. Here the value 15 sits at position (3, 1), so its flattened index is 3 * 4 + 1 = 13. The two numbers coincide only in special cases such as np.arange(n), where every value equals its own index.
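The relationship between the flattened index and the maximum value can be checked directly. This sketch reuses the matrix from Example 3 and assumes the default row-major (C) layout, in which a flat index splits into row and column with a single divmod by the number of columns.

```python
import numpy as np

a = np.array([
    [ 0,  3,  8, 13],
    [12, 11,  2, 11],
    [ 5, 13,  8,  3],
    [12, 15,  3,  4],
])

i = int(np.argmax(a))             # flattened (row-major) index of the maximum
row, col = divmod(i, a.shape[1])  # row = i // ncols, col = i % ncols

print(i)            # 13
print(a.flat[i])    # 15
print(row, col)     # 3 1
print(a[row, col])  # 15
```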

Converting flattened index to coordinates

Use np.unravel_index() to convert a flattened index into multi-dimensional coordinates.

import numpy as np

a = np.arange(12).reshape(3, 4)
flat_idx = np.argmax(a)  # flattened index
coords = np.unravel_index(flat_idx, a.shape)
print("flattened index:", flat_idx)
print("coords:", coords)
print("max element:", a[coords])

Output:

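flattened index: 11
coords: (2, 3)
max element: 11

(On NumPy 2.x the tuple elements are printed with their type, so the second line appears as coords: (np.int64(2), np.int64(3)); the values are the same.)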

Using the out parameter

You can supply an output array to receive the results (useful in advanced pipelines to avoid allocations).

import numpy as np

a = np.arange(12).reshape(3, 4)
out = np.empty(4, dtype=np.intp)   # for axis=0, result length = 4
np.argmax(a, axis=0, out=out)
print("out array after argmax(axis=0):", out)

Output:

out array after argmax(axis=0): [2 2 2 2]

Note: out must have the correct shape and dtype; otherwise NumPy raises an error.
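As a sketch of the allocation-saving use case, the loop below reuses one preallocated buffer for several arrays of the same shape. The two batches are made up for illustration. Because out is overwritten on every call, the sketch copies it whenever a result needs to be kept.

```python
import numpy as np

# Two illustrative 3x4 batches: ascending and descending values
batches = [
    np.arange(12).reshape(3, 4),
    np.arange(12)[::-1].reshape(3, 4),
]

out = np.empty(4, dtype=np.intp)  # one slot per column, reused for every batch
results = []

for batch in batches:
    np.argmax(batch, axis=0, out=out)  # writes the indices into `out`
    results.append(out.copy())         # copy, since `out` is overwritten next pass

print(results[0])  # [2 2 2 2]
print(results[1])  # [0 0 0 0]
```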

Performance notes and tips

  • argmax is implemented in C and is efficient for large arrays; its running time is linear, O(n), in the number of elements scanned.
  • For finding top-k indices (not just the single max), consider np.argpartition, which partitions in linear time on average instead of the O(n log n) of a full sort. The k indices it returns are not ordered among themselves, so sort them afterwards if order matters.
  • When working with floating point arrays, NaN is treated as the maximum: if the data contains NaN, argmax returns the index of the first NaN rather than the index of the largest number. Use np.nanargmax to ignore NaN values; it raises ValueError when a slice contains only NaN.
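The last two tips can be sketched as follows. The scores array and the choice of k are hypothetical; the first part picks the top-k indices with np.argpartition and then orders only those k entries, and the second part contrasts np.argmax with np.nanargmax on data containing NaN.

```python
import numpy as np

scores = np.array([0.2, 0.9, 0.1, 0.7, 0.5])  # illustrative data
k = 2

# argpartition moves the k largest values into the last k slots, in no particular order
top_k = np.argpartition(scores, -k)[-k:]
# order just those k indices by score, largest first
top_k = top_k[np.argsort(scores[top_k])[::-1]]
print(top_k)  # [1 3]

with_nan = np.array([1.0, np.nan, 3.0])
print(np.argmax(with_nan))     # 1, the position of the NaN
print(np.nanargmax(with_nan))  # 2, the largest value once NaN is ignored
```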


Summary

– np.argmax() returns indices of maximum elements; axis=None flattens the array.

– Ties are resolved by returning the first occurrence.

– Use np.unravel_index() to map flattened indices to multi-dimensional coordinates.

– For partial or multiple top values, prefer np.argpartition().

Reference –

Official NumPy docs — numpy.argmax: https://docs.scipy.org/doc/numpy/reference/generated/numpy.argmax.html
