MRE:
import numpy as np
import ndonnx as ndx

y = ndx.full((2, 1), np.nan)
mask = ndx.asarray([True, False])
x = ndx.asarray([[0]])
y[mask] = x
fails with:
Traceback (most recent call last):
    y[mask] = x
    ~^^^^^^
  File "ndonnx/_array.py", line 262, in __setitem__
    self._tyarray[key._tyarray] = updates
    ~~~~~~~~~~~~~^^^^^^^^^^^^^^
  File "ndonnx/_typed_array/onnx.py", line 485, in __setitem__
    return self._setitem_boolmask(key, value)
           ~~~~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^
  File "ndonnx/_typed_array/onnx.py", line 522, in _setitem_boolmask
    const(1, int64).broadcast_to(key.dynamic_shape).cumulative_sum() - 1,
    ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~^^
  File "ndonnx/_typed_array/onnx.py", line 1083, in cumulative_sum
    raise ValueError(
        "'axis' parameter must be provided for arrays with more than one dimension"
    )
ValueError: 'axis' parameter must be provided for arrays with more than one dimension
Passing axis=0 to cumulative_sum() sidesteps the error, but the resulting array then has shape (2,), whereas NumPy gives (2, 1). I think this is because _setitem_boolmask never assigns the result of arr.reshape back to arr, so the flattened shape is what ends up in self.
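For comparison, here is the same assignment written against plain NumPy. It is only meant to pin down the expected behaviour (the assignment succeeds and `y` keeps its original shape); it does not exercise ndonnx at all.

```python
import numpy as np

# Same inputs as the MRE, but using NumPy arrays.
y = np.full((2, 1), np.nan)
mask = np.asarray([True, False])
x = np.asarray([[0]])

# A rank-1 boolean mask selects whole rows of the rank-2 array.
y[mask] = x

print(y.shape)  # (2, 1)
```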
My proposed fix would be to change _setitem_boolmask as follows:
def _setitem_boolmask(self, key: TyArrayBool, value: Self) -> None:
    if self.ndim < key.ndim:
        raise IndexError("provided boolean mask has higher rank than 'self'")
    if self.ndim > key.ndim:
        diff = self.ndim - key.ndim
        # expand to rank of self
        idx = [...] + diff * [None]
        key = key[tuple(idx)]
    if value.ndim == 0:
        # We can be sure that we don't run into broadcasting
        # issues with a zero-sized value if the value is a
        # scalar. This allows us to use `where`.
        self._var = op.where(key._var, value._var, self._var)
        return
    int_keys = safe_cast(
        TyArrayInt64,
        const(1, int64).broadcast_to(key.dynamic_shape).cumulative_sum(axis=0) - 1,
    )[key]
    arr = self.copy().reshape((-1,))
    arr.put(int_keys, value.reshape((-1,)))
    arr = arr.reshape(self.dynamic_shape)
    self._var = arr._var
though I defer to the experts.
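To make the intended semantics easier to check, below is a NumPy-only sketch of the same "flat indices plus put" strategy. The function name `setitem_boolmask_reference` is hypothetical and this is not ndonnx code. Note one deliberate difference from the patch above: the running count is taken over the flattened array rather than along axis=0, and the mask is broadcast to the full shape of the array first. That is an assumption of this sketch about how flat indices should be computed when trailing dimensions are larger than one, not something confirmed by the maintainers.

```python
import numpy as np


def setitem_boolmask_reference(arr, mask, value):
    """NumPy-only illustration of boolean-mask assignment via flat indices.

    Hypothetical helper for comparison purposes; returns a new array
    instead of mutating `arr`.
    """
    if arr.ndim < mask.ndim:
        raise IndexError("provided boolean mask has higher rank than 'arr'")

    # Append trailing axes so the mask has the same rank as `arr`,
    # then broadcast it to the full shape of `arr`.
    key = mask[(...,) + (None,) * (arr.ndim - mask.ndim)]
    key = np.broadcast_to(key, arr.shape)

    # A running count of ones over the flattened array, minus one,
    # yields the flat index of every element.
    flat_idx = np.cumsum(np.ones(arr.size, dtype=np.int64)) - 1
    selected = flat_idx[key.reshape(-1)]

    # Broadcast the value to the shape NumPy would assign into.
    target_shape = arr[mask].shape
    flat_value = np.broadcast_to(value, target_shape).reshape(-1)

    out = arr.copy().reshape(-1)
    np.put(out, selected, flat_value)
    # Restore the original shape before returning.
    return out.reshape(arr.shape)


# The inputs from the MRE.
y = np.full((2, 1), np.nan)
result = setitem_boolmask_reference(
    y, np.asarray([True, False]), np.asarray([[0]])
)
print(result.shape)  # (2, 1)
```

Comparing the output of a helper like this against direct NumPy assignment for a few shapes, for example a (2, 3) array with a rank-1 mask, might be a cheap way to add regression tests for the fix.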