
[Bug]: Boolean mask assignment failure #177

@VictoriaAdjeiQC

Minimal reproducible example (MRE):

    import numpy as np
    import ndonnx as ndx

    y = ndx.full((2, 1), np.nan)
    mask = ndx.asarray([True, False])
    x = ndx.asarray([[0]])

    y[mask] = x

fails with:

Traceback (most recent call last):
    y[mask] = x
    ~^^^^^^
  File "ndonnx/_array.py", line 262, in __setitem__
    self._tyarray[key._tyarray] = updates
    ~~~~~~~~~~~~~^^^^^^^^^^^^^^
  File "ndonnx/_typed_array/onnx.py", line 485, in __setitem__
    return self._setitem_boolmask(key, value)
           ~~~~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^
  File "ndonnx/_typed_array/onnx.py", line 522, in _setitem_boolmask
    const(1, int64).broadcast_to(key.dynamic_shape).cumulative_sum() - 1,
    ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~^^
  File "ndonnx/_typed_array/onnx.py", line 1083, in cumulative_sum
    raise ValueError(
        "'axis' parameter must be provided for arrays with more than one dimension"
    )
ValueError: 'axis' parameter must be provided for arrays with more than one dimension
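The reason `cumulative_sum()` receives a 2-D array can be mirrored in NumPy. This is a hypothetical sketch that copies only the mask-expansion step of `_setitem_boolmask`. The variable names are illustrative and the snippet does not call ndonnx:

```python
import numpy as np

# Hypothetical NumPy mirror of the first step of `_setitem_boolmask`
# for the MRE: `self` has shape (2, 1) and the mask has shape (2,).
self_shape = (2, 1)
mask = np.asarray([True, False])

# The mask is expanded to the rank of `self` by appending `None` axes,
# mirroring `idx = [...] + diff * [None]` in the ndonnx code.
diff = len(self_shape) - mask.ndim
key = mask[tuple([...] + diff * [None])]

# A ones array broadcast to the key's shape is therefore 2-D, which is the
# case the ValueError above rejects when no `axis` is given.
ones = np.broadcast_to(np.int64(1), key.shape)
print(key.shape, ones.ndim)
```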

Passing `axis=0` to `cumulative_sum()` sidesteps the error, but the resulting array has shape (2,), whereas NumPy gives (2, 1). I think this is because `_setitem_boolmask` calls `arr.reshape(...)` without assigning the result back to `arr`.
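For comparison, here is the same MRE evaluated with NumPy, the reference behaviour referred to above. It keeps the (2, 1) shape and only overwrites the row selected by the mask:

```python
import numpy as np

# Same MRE as above, but evaluated with NumPy as the reference behaviour.
y = np.full((2, 1), np.nan)
mask = np.asarray([True, False])
x = np.asarray([[0]])

# `y[mask]` selects row 0 and has shape (1, 1), matching `x`.
y[mask] = x

print(y.shape)  # (2, 1)
print(y)        # row 0 becomes 0.0, row 1 stays nan
```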

My proposed fix is to change `_setitem_boolmask` to the following:

    def _setitem_boolmask(self, key: TyArrayBool, value: Self) -> None:
        if self.ndim < key.ndim:
            raise IndexError("provided boolean mask has higher rank than 'self'")
        if self.ndim > key.ndim:
            diff = self.ndim - key.ndim
            # expand to rank of self
            idx = [...] + diff * [None]
            key = key[tuple(idx)]

        if value.ndim == 0:
            # We can be sure that we don't run into broadcasting
            # issues with a zero-sized value if the value is a
            # scalar. This allows us to use `where`.
            self._var = op.where(key._var, value._var, self._var)
            return

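        # Position of each element along axis 0; this equals its flat
        # (row-major) index only when every dimension after the first has
        # size 1, as in the MRE.
        int_keys = safe_cast(
            TyArrayInt64,
            const(1, int64).broadcast_to(key.dynamic_shape).cumulative_sum(axis=0) - 1,
        )[key]
        # Scatter the values into a flattened copy of `self` ...
        arr = self.copy().reshape((-1,))
        arr.put(int_keys, value.reshape((-1,)))
        # ... and restore the original shape. The reshaped array must be
        # assigned back to `arr`, otherwise `self` ends up 1-D.
        arr = arr.reshape(self.dynamic_shape)
        self._var = arr._var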

That said, I defer to the experts.
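The flat-index strategy in `_setitem_boolmask` works in five steps: broadcast a ones array to the mask's shape, cumulative-sum it to get positions, select the masked positions, `put` the values into a flattened copy, and reshape back. This can be sketched in NumPy. `setitem_boolmask_sketch` is a hypothetical helper, not part of ndonnx, and it assumes the mask has already been expanded to the rank of the target.

Under those assumptions, the sketch suggests that `cumulative_sum(axis=0)` is enough for the MRE. However, it yields correct flat positions only when every dimension after the first has size 1. Flattening the ones array before summing is one way to get row-major positions for any rank:

```python
import numpy as np


def setitem_boolmask_sketch(arr, key, value, use_axis0_cumsum=False):
    """Hypothetical NumPy sketch of the flat-index strategy (not ndonnx code).

    `key` is a boolean array with the same shape as `arr`.
    """
    ones = np.ones(key.shape, dtype=np.int64)
    if use_axis0_cumsum:
        # Mirrors `cumulative_sum(axis=0) - 1` from the proposed fix.
        positions = np.cumsum(ones, axis=0) - 1
    else:
        # Flat (row-major) position of every element, valid for any rank.
        positions = np.cumsum(ones.reshape(-1)).reshape(key.shape) - 1
    int_keys = positions[key]

    flat = arr.copy().reshape(-1)
    np.put(flat, int_keys, value.reshape(-1))
    # The reshape result must be kept; discarding it would leave a 1-D array.
    return flat.reshape(arr.shape)


# MRE case: the trailing dimension has size 1, so axis=0 positions are flat indices.
y = np.full((2, 1), np.nan)
key = np.asarray([True, False])[:, None]
out = setitem_boolmask_sketch(y, key, np.asarray([[0.0]]), use_axis0_cumsum=True)
print(out.shape)  # (2, 1)

# Full 2-D mask on a (2, 3) array: axis=0 positions are no longer flat indices.
z = np.zeros((2, 3))
key2 = np.asarray([[False, True, False], [False, False, True]])
vals = np.asarray([1.0, 2.0])
ref = z.copy()
ref[key2] = vals
good = setitem_boolmask_sketch(z, key2, vals)
bad = setitem_boolmask_sketch(z, key2, vals, use_axis0_cumsum=True)
```

In the (2, 3) case, NumPy produces `[[0, 1, 0], [0, 0, 2]]`, and the flattened variant agrees with it. The axis=0 variant instead writes `[[1, 2, 0], [0, 0, 0]]` because both selected elements map to positions 0 and 1.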
