array_api_extra.at

class array_api_extra.at(x, idx=<object object>, /)

Update operations for read-only arrays.

This implements the semantics of jax.numpy.ndarray.at for all writeable backends (those that support __setitem__) and delegates to the native .at[] method for JAX arrays.

Parameters:
  • x (array) – Input array.

  • idx (index, optional) –

    Only array API standard compliant indices are supported.

    You may use two alternate syntaxes:

    >>> import array_api_extra as xpx
    >>> xpx.at(x, idx).set(value)  # or add(value), etc.
    >>> xpx.at(x)[idx].set(value)
    

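  • copy (bool, optional) – Keyword argument of the update methods (set(), add(), etc.) rather than of the constructor. Controls whether x may be modified in place: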

    None (default)

    The array parameter may be modified in place if it is possible and beneficial for performance. You should not reuse it after calling this function.

    True

    Ensure that the inputs are not modified.

    False

    Ensure that the update operation writes back to the input. Raise ValueError if a copy cannot be avoided.

  • xp (array_namespace, optional) – The standard-compatible namespace for x. Default: infer. Like copy, this is a keyword argument of the update methods.

Returns:

Updated input array.

Warning

(a) When you omit the copy parameter, you should always immediately overwrite the parameter array:

>>> import array_api_extra as xpx
>>> x = xpx.at(x, 0).set(2)

The anti-pattern below must be avoided, as it will result in different behaviour on read-only versus writeable arrays:

>>> x = xp.asarray([0, 0, 0])
>>> y = xpx.at(x, 0).set(2)
>>> z = xpx.at(x, 1).set(3)

In the above example, x == [0, 0, 0], y == [2, 0, 0] and z == [0, 3, 0] when x is read-only, whereas x == y == z == [2, 3, 0] when x is writeable!

(b) The array API standard does not support integer array indices. The behaviour of update methods when the index is an array of integers is undefined and will vary between backends; this is particularly true when the index contains multiple occurrences of the same index, e.g.:

>>> import numpy as np
>>> import jax.numpy as jnp
>>> import array_api_extra as xpx
>>> xpx.at(np.asarray([123]), np.asarray([0, 0])).add(1)
array([124])
>>> xpx.at(jnp.asarray([123]), jnp.asarray([0, 0])).add(1)
Array([125], dtype=int32)
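The NumPy result above is consistent with NumPy's buffered fancy-index assignment: x[idx] += 1 computes x[idx] + 1 once and writes it back, so a repeated index is only incremented once. NumPy's own unbuffered ufunc.at method instead accumulates repeated indices, matching the JAX result. This is plain NumPy shown for illustration; it is not part of array-api-extra.

```python
import numpy as np

x = np.asarray([123])
idx = np.asarray([0, 0])  # the same position listed twice

# Buffered: reads x[[0, 0]] -> [123, 123], adds 1, writes 124 to position 0 twice.
buffered = x.copy()
buffered[idx] += 1

# Unbuffered: applies the addition once per occurrence of the index.
unbuffered = x.copy()
np.add.at(unbuffered, idx, 1)

print(buffered, unbuffered)  # [124] [125]
```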

See also

jax.numpy.ndarray.at

Equivalent array method in JAX.

Notes

Arrays from the sparse library, as well as read-only arrays from libraries not explicitly covered by array-api-compat, are not supported by update methods.

Examples

Given either of these equivalent expressions:

>>> import array_api_extra as xpx
>>> x = xpx.at(x)[1].add(2)
>>> x = xpx.at(x, 1).add(2)

If x is a JAX array, they are the same as:

>>> x = x.at[1].add(2)

If x is a read-only numpy array, they are the same as:

>>> x = x.copy()
>>> x[1] += 2

For other known backends, they are the same as:

>>> x[1] += 2
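The three cases above can be summarised in a short dispatch function. This is a hypothetical sketch for add() only, written against NumPy; it is not array-api-extra's implementation, and its JAX check (looking for an .at attribute) is a simplifying assumption.

```python
import numpy as np


def at_add_sketch(x, idx, y):
    """Hypothetical stand-in for xpx.at(x, idx).add(y) with copy=None."""
    if hasattr(x, "at"):
        # JAX-style arrays: delegate to the out-of-place .at[] API (assumed check).
        return x.at[idx].add(y)
    if isinstance(x, np.ndarray) and not x.flags.writeable:
        # Read-only NumPy array: update a writeable copy instead.
        x = x.copy()
    # Writeable arrays (including the fresh copy) are updated in place.
    x[idx] += y
    return x


read_only = np.asarray([1, 2, 3])
read_only.setflags(write=False)
out_ro = at_add_sketch(read_only, 1, 2)

writeable = np.asarray([1, 2, 3])
out_rw = at_add_sketch(writeable, 1, 2)
```

In the read-only case the caller receives a new array and the input is untouched; in the writeable case the returned object is the input itself, which is exactly the aliasing the warning above describes.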

Methods

__init__(x[, idx])

add(y, /[, copy, xp])

Apply x[idx] += y and return the updated array.

divide(y, /[, copy, xp])

Apply x[idx] /= y and return the updated array.

max(y, /[, copy, xp])

Apply x[idx] = maximum(x[idx], y) and return the updated array.

min(y, /[, copy, xp])

Apply x[idx] = minimum(x[idx], y) and return the updated array.

multiply(y, /[, copy, xp])

Apply x[idx] *= y and return the updated array.

power(y, /[, copy, xp])

Apply x[idx] **= y and return the updated array.

set(y, /[, copy, xp])

Apply x[idx] = y and return the updated array.

subtract(y, /[, copy, xp])

Apply x[idx] -= y and return the updated array.