Question

Why do NumPy scalars multiply with custom sequences but not with lists?

I have a question for NumPy experts. Consider a NumPy scalar: c = np.arange(3.0).sum(). If I multiply it by a custom sequence such as

import numpy as np

class S:

    def __init__(self, lst):
        self.lst = lst

    def __len__(self):
        return len(self.lst)

    def __getitem__(self, s):
        return self.lst[s]

c = np.arange(3.0).sum()  # a NumPy scalar: np.float64(3.0)
s = S([1, 2, 3])

print(c * s)

it works, and the result is array([3., 6., 9.]) (print shows it as [3. 6. 9.]).

However, this does not work with a list. For instance, if S inherits from list, the same multiplication fails:

import numpy as np

class S(list):

    def __init__(self, lst):
        self.lst = lst

    def __len__(self):
        return len(self.lst)

    def __getitem__(self, s):
        return self.lst[s]

c = np.arange(3.0).sum()
s = S([1, 2, 3])

print(c * s)

and I get "can't multiply sequence by non-int of type 'numpy.float64'".

So how does NumPy distinguish between the two cases?

I am asking because I want to prevent NumPy from performing the multiplication for my "S" class, without inheriting from list.


UPD

Since the question has been misunderstood a few times, let me state the problem more precisely. It is not about why a list cannot be multiplied by a float, or why the error is raised.

It is about why, in the first case (when S does NOT inherit from list), the multiplication is performed by the object "c", while in the second case (when S inherits from list) it is delegated to "s".

Apparently, the method c.__mul__ performs some checks that pass in the first case but fail in the second, so that s.__rmul__ is called. The question is essentially: what are those checks? (I strongly doubt that it is anything like isinstance(other, list).)
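One way to confirm that, in the first case, the product really is computed on the NumPy side is to give the plain class a __rmul__ that records its calls. This is an illustrative sketch: the `calls` list and the recording `__rmul__` are hypothetical additions, not part of the original class. If NumPy handles the product itself, the recorder stays empty and the result is an ndarray.

```python
import numpy as np

calls = []  # hypothetical recorder: collects every call to S.__rmul__


class S:
    def __init__(self, lst):
        self.lst = lst

    def __len__(self):
        return len(self.lst)

    def __getitem__(self, s):
        return self.lst[s]

    def __rmul__(self, other):
        # Only reached if c.__mul__ gives up (returns NotImplemented).
        calls.append(other)
        return [other * x for x in self.lst]


c = np.arange(3.0).sum()
s = S([1, 2, 3])

result = c * s
print(type(result).__name__, result)  # ndarray [3. 6. 9.]
print(calls)                          # []
```

So merely defining __rmul__ on a plain sequence class is not enough: NumPy converts the sequence to an array and never hands control to s.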


Solution


As pointed out by @hpaulj (thanks a lot!), this question has already been asked: Array and __rmul__ operator in Python Numpy.

The mechanics of the problem were explained very well in this answer: stackoverflow.com/a/38230576/901925. However, the proposed solution of inheriting from np.ndarray is certainly a "dirty" one.

The reply immediately after it proposes a solution based on __numpy_ufunc__, which is called __array_ufunc__ in modern NumPy. This attribute can simply be set to None in the definition of the class "S". c.__mul__ then returns NotImplemented instead of attempting the multiplication itself, so the multiplication is delegated to s.__rmul__.
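A minimal sketch of that approach, assuming S should behave like a list and accept only Python ints as factors (that policy, and the error message, are illustrative choices, not something NumPy requires). With `__array_ufunc__ = None`, NumPy's binary operators step aside, so S.__rmul__ decides what `c * s` means.

```python
import numpy as np


class S:
    # Opt out of NumPy's ufunc machinery: NumPy's binary operators return
    # NotImplemented for this operand, so Python falls back to S.__rmul__.
    __array_ufunc__ = None

    def __init__(self, lst):
        self.lst = lst

    def __len__(self):
        return len(self.lst)

    def __getitem__(self, s):
        return self.lst[s]

    def __rmul__(self, other):
        # Illustrative policy: mimic list and accept only Python ints.
        if not isinstance(other, int):
            raise TypeError(
                "can't multiply instance of 'S' by non-int of type "
                f"'{type(other).__name__}'"
            )
        return self.lst * other


c = np.arange(3.0).sum()
s = S([1, 2, 3])

print(2 * s)  # [1, 2, 3, 1, 2, 3]

try:
    c * s
except TypeError as exc:
    print(exc)  # can't multiply instance of 'S' by non-int of type 'float64'
```

A side effect worth knowing (described in NEP 13): calling a ufunc such as np.multiply(c, s) directly now raises TypeError as well. Explicit conversion with np.asarray(s) still works, so array arithmetic remains available when it is actually wanted.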

2024-07-22
Pavlo Bilous

Solution


So how does NumPy distinguish between the two cases?

I don't think it's NumPy's distinction; it seems to be Python's.

Python lists define multiplication by Python ints, but not by other objects:

2. * [1, 2, 3]
# TypeError: can't multiply sequence by non-int of type 'float'

The message is the same in every such case: Python lists refuse to be multiplied by "non-int"s.

So the question becomes why a Python list can be multiplied by other types of objects, such as NumPy integers. That is because the other object has the opportunity to define multiplication, too.
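The asymmetry is easy to observe with NumPy scalars and a plain list. This sketch only shows the observable behavior; it does not claim which method, NumPy's or the list's, ends up performing the repetition.

```python
import numpy as np

lst = [1, 2, 3]

# A NumPy integer behaves like a Python int here: the list is repeated.
print(np.int64(2) * lst)  # [1, 2, 3, 1, 2, 3]

# A NumPy float behaves like a Python float: no repetition, a TypeError instead.
try:
    np.float64(2.0) * lst
except TypeError as exc:
    print("TypeError:", exc)
```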

The built-in definition of NumPy floats is similar to that of Python floats, which cannot be multiplied by lists. However, if we override multiplication in a subclass of np.float64, we can make multiplication with lists work.

import numpy as np

# if we override the multiplication behavior of np.float64,
# it works
class T(np.float64):
    def __mul__(self, other):
        return 2 * other

t = T()
t * [1, 2, 3]
# [1, 2, 3, 1, 2, 3]

You can prevent multiplication of your class with other types of objects by raising an exception from the __mul__ (and __rmul__) method(s).

class S:

    def __init__(self, lst):
        self.lst = lst

    def __len__(self):
        return len(self.lst)

    def __getitem__(self, s):
        return self.lst[s]

    def __mul__(self, other):
        # Reject anything that is not a Python int, mirroring list's error.
        if not isinstance(other, int):
            message = ("can't multiply instance of 'S' by "
                       f"non-int of type '{type(other).__name__}'")
            raise TypeError(message)
        return self.lst * other


s = S([1, 2, 3])

print(s * 2)
# [1, 2, 3, 1, 2, 3]

print(s * 2.)
# TypeError: can't multiply instance of 'S' by non-int of type 'float'
2024-07-20
Matt Haberland