Python学习之路32-运算符重载

《流畅的Python》笔记。

本篇是“面向对象惯用方法”的第六篇,也是最后一篇。本篇将讨论Python中的运算符重载。

1. 前言

Python中的运算符重载和C++中的运算符重载并不一样,C++中同一运算符可以有多个重载函数,Python中的运算符重载其实是实现运算符的同名特殊方法。

本篇只讨论一元运算符和中缀运算符,内容如下:

  • Python如何处理中缀运算符中不同类型的操作数;
  • 使用鸭子类型或白鹅类型处理不同类型的操作数;
  • 中缀运算符如何表明自己无法处理操作数;
  • 众多比较运算符的特殊行为;
  • 增量运算符的默认处理方式和重载方式。

不过,需要说明的是,并不是所有的运算符都能重载:

  • 不能重载内置类型的运算符;
  • 不能新建运算符,只能重载现有的;
  • isandornot不能重载。

本文中的示例延用《Python学习之路29》中的多维向量Vector

2. 一元运算符

本节主要介绍4个一元运算符,它们分别是:

  • - (__neg__):一元取负运算符,如x = 2,则-x == 2
  • +(__pos__):一元取正运算符,通常是x == +x,但也有特例;
  • ~(__invert__):对整数按位取反,定义为~x == -(x + 1)
  • abs()函数:Python语言参考手册把它也列为了一元运算符,它对应的就是之前多次用到的__abs__

在实现过程中需要遵循这些运算符的一个基本规则:始终返回一个新对象!也就是说不能修改self,要创建并返回合适类型的实例。以下补充两个Vector类的运算符重载:

1
2
3
4
5
def __neg__(self):
return Vector(-x for x in self)

def __pos__(self):
return Vector(self)

x+x何时不等?以下是两个例子:

  • 如果decimal.Decimal所在上下文的精度不同,则有可能不等,如下:

    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    >>> import decimal
    >>> ctx = decimal.getcontext()
    >>> ctx.prec = 40
    >>> one_third = decimal.Decimal("1") / decimal.Decimal("3")
    >>> one_third
    Decimal('0.3333333333333333333333333333333333333333')
    >>> one_third == +one_third
    True
    >>> ctx.prec = 28 # 这是默认精度
    >>> one_third == +one_third
    False
    >>> +one_third
    Decimal('0.3333333333333333333333333333')
  • collections.Counter在相加时,负值和零值计数会从结果中剔除,而一元运算符+对它来说等同于加上一个空Counter,如下:

    1
    2
    3
    4
    5
    6
    7
    >>> ct = Counter("abracadabra")
    >>> ct["r"] = -3
    >>> ct["d"] = 0
    >>> ct
    Counter({'a': 5, 'b': 2, 'c': 1, 'd': 0, 'r': -3})
    >>> +ct # 与ct不等
    Counter({'a': 5, 'b': 2, 'c': 1})

3. 重载向量加法运算符+

目前版本的Vector不支持向量相加,因为没有重载+运算符。我们的要求如下:

  • 它能实现两个Vector相加,并且两个长度不等的Vector也能相加,短的那个用0.0填充;
  • 能与任何可迭代对象相加,但当这个可迭代对象中的元素不能与浮点数做加法运算时,则抛出NotImplemented异常;
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
def __add__(self, other):
try:
pairs = itertools.zip_longest(self, other, fillvalue=0.0) # 自动填充
return Vector(a + b for a, b in pairs)
except TypeError:
# 它不是一个异常类,而是一个单例值!所以用的是return,而不是raise
return NotImplemented

def __radd__(self, other): # 实现反向相加
return self + other

# 在控制台中运行的示例,省略了import语句
>>> v1 = Vector([1, 2, 3])
>>> v1 + Vector([2, 3, 4]) # 可以和同类型的相加
Vector([3.0, 5.0, 7.0])
>>> v1 + (1, 2, 3) # 和其他可迭代对象也能相加
Vector([2.0, 4.0, 6.0])
>>> v1 + (1, 2) # 长度不同也能相加
Vector([2.0, 4.0, 3.0])
>>> v1 + Vector2d(1, 2) # 由于我们之前实现的Vector2d也是可迭代对象,所以也能和Vector相加
Vector([2.0, 4.0, 3.0])
>>> (1, 2, 3) + v1 # <1> 反向也能相加,见解释
Vector([2.0, 4.0, 6.0])

解释

  • __radd____rsub__这种前面带r的方法一般被称作“反向”运算方法或“右向”运算方法,如果没有实现这种方法,上述代码<1>处的语句就会抛出TypeError

  • 对于表达式a + b来说,解释器会执行如下几步:

    • 如果a__add__方法,调用a.__add__(b)
    • 如果a.__add__(b)返回NotImplemented,或者a没有__add__方法,则检查b有没有__radd__方法,如果有,则调用b.__radd__(a)
    • 如果b.__radd__(a)返回NotImplemented,或者b没有__radd__方法,则抛出TypeError,并在错误消息中指明操作数类型不支持

    其他有反向运算方法的运算符在调用时也是上面这个逻辑。

  • __radd__等反向运算的实现通常就如上述代码这么简单暴力:直接委托给正向运算。

  • 在实现__add__时,我们并没有去判断other的类型或者它的元素的类型,而是捕获TypeError异常。这是在给other调用反向运算方法的一个机会。如果调用成功,other就能被当做另一个操作数的“同类”,这也遵循了鸭子类型精神。

4. 重载乘法运算符

4.1 重载数乘运算*

这里实现的是向量的数乘运算,我们希望任何实数都能和Vector做数乘预算(也叫做元素级乘法, elementwise multiplication),添加的两个方法如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
def __mul__(self, scalar):
if isinstance(scalar, numbers.Real):
return Vector(n * scalar for n in self)
else:
return NotImplemented

def __rmul__(self, scalar):
return self * scalar

# 以下是在控制台中运行的示例
>>> v1 = Vector([1,2,3])
>>> 2 * v1
Vector([2.0, 4.0, 6.0])
>>> v1 * True # bool是int的子类
Vector([1.0, 2.0, 3.0])
>>> from fractions import Fraction
>>> v1 * Fraction(1, 3)
Vector([0.3333333333333333, 0.6666666666666666, 1.0])

解释:这里并没有像__add__中那样,采用鸭子类型技术,在__mul__中捕获TyperError;而是采用更易于理解和更合理的方式,即白鹅类型,使用isinstance()函数来判断操作数是否为实数。

4.2 重载点乘运算@

从Python3.5开始,已经支持点乘运算符@,它相应的特殊方法时__matmul__(矩阵乘法”matrix multiplication”的缩写),以下是对点乘运算的重载:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
def __matmul__(self, other):
try:
return sum(a * b for a, b in zip(self, other))
except TypeError:
return NotImplemented

def __rmatmul__(self, other):
return self @ other

# 下面是它的运行示例:
>>> Vector([1, 2, 3]) @ Vector([4, 5, 6])
32
>>> [1, 2, 3] @ Vector([4, 5, 6])
32

5. 比较运算符

Python对比较运算符的处理与前文类似,不过在两个方面有重大区别:

  • 正向和反向调用使用的是同一系列方法,即没有r前缀。例如,对于==来说,正向和反向调用都是__eq__方法,只是掉换个参数;正向的__gt__方法调用的则是反向的__lt__方法,并调换参数。
  • ==!=来说,如果反向调用失败,Python会比较对象的ID,而不是抛出TypeError

5.1 重载 ==

之前版本的Vector中,__eq__的实现与行为如下:

1
2
3
4
5
6
def __eq__(self, other):
return (len(self) == len(other) and all(a == b for a, b in zip(self, other)))

# 它的行为如下:
>>> Vector([1, 2, 3]) == (1, 2, 3) # 除此之外还能和Vector与Vector2d比较
True

有时候我们并不想兼容这么多类型的操作数,但当遇到某些类型时(比如上面的元组),我们也不想武断地直接抛出TypeError,而是让另一个操作数判断这俩是否相等,于是我们将上述代码改为如下形式:

1
2
3
4
5
6
7
8
9
10
11
def __eq__(self, other):
if isinstance(other, Vector):
return (len(self) == len(other) and all(a == b for a, b in zip(self, other)))
else:
return NotImplemented

# 它的行为如下:
>>> va = Vector([1, 2, 3])
>>> t3 = (1, 2, 3)
>>> va == t3
False

以下是Vector([1, 2, 3]) == (1, 2, 3)这段代码的运行过程:

  • 为计算va == t3,Python调用Vector.__eq__(va, t3)
  • 由于t3不是Vector类,所以上述调用返回NotImplemented
  • Python得到NotImplemented结果,尝试调用tuple.__eq__(t3, va)
  • 由于tuple.__eq__(t3, va)不知道Vector是什么,因此返回NotImplemented
  • ==来说,如果反向调用也返回了NotImplemented,则最后比较对象的ID,发现两者不等,返回False

5.2 重载 !=

!=不用重载!从object继承而来的__ne__已经够用了,由于原版的__ne__是用C语言写到,下面的代码是它的Python版本:

1
2
3
4
5
6
def __ne__(self, other):
eq_result = self == other
if eq_result is NotImplemented:
return NotImplemented
else:
return not eq_result

意思就是:如果__eq__返回NotImplemented,那它也返回这个值;否则,返回__eq__结果的相反值。

6. 增量赋值运算符

其实目前版本的Vector已经支持了+=*=操作,因为我们为它实现了__add____mul__操作,当运行a += b时,会被转换成a = a + b。但也正因此,大家可以看出,这不是一个就地运算,这样的+=*=会创建新的实例。如果想实现就地预算,则需要重写以i开头的特殊方法,比如+=对应的__iadd__

由于Vector被定义为不可变类型,这里我们新建一个简单的MyList类来示范+=运算符的重载。为简答起见,以两个操作数的最小长度为准:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
>>> class MyList:
... def __init__(self, iterable):
... self._list = list(iterable)
...
... def __iadd__(self, other):
... for i in range(min(len(self._list), len(other))):
... self._list[i] += other[i]
... return self
...
>>> test = MyList(range(10))
>>> id(test)
2848410583560
>>> test
[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
>>> test += range(9)
>>> id(test)
2848410583560 # ID没有改变
>>> test
[0, 2, 4, 6, 8, 10, 12, 14, 16, 9] # 确实是就地运算

其实这里只为强调一点:增量赋值特殊方法最后一定要返回self!

7. 总结

本文开篇先介绍了不能重载运算符的情况,随后依次介绍了一元运算符,中缀运算符(包括加法、乘法和比较运算)和增量运算符的重载情况。

其中需要注意NotImplemented这个值,它不是异常,而是个单例值,Python在进行中缀运算时会专门检测这个值。

期间,我们还讨论了如何处理不同类型的操作数:是按照鸭子类型技术,捕获TypeError,还是根据白鹅类型,用isinstance进行类型判断。这两种方式各有利弊:鸭子类型更灵活,但白鹅类型更能预知结果。如果选用isinstance,则不要检测具体类,而应检测抽象基类,比如numbers.Real

最后给出各运算符对应的特殊方法的表格,第一个表格是中缀运算符的名称:

运算符 正向方法 反向方法 就地方法 说明
+ __add__ __radd__ __iadd__ 加法或拼接
- __sub__ __rsub__ __isub__ 减法
* __mul__ __rmul__ __imul__ 乘法或重复复制
/ __truediv__ __rtruediv__ __itruediv__ 除法
// __floordiv__ __rfloordiv__ __ifloordiv__ 整除
% __mod__ __rmod__ __imod__ 取模
divmod() __divmod__ __rdivmod__ __idivmod__ 返回由整除的商和模构成的元组
**pow() __pow__ __rpow__ __ipow__ 幂运算
@ __matmul__ __rmatmul__ __imatmul__ 矩阵乘法
& __and__ __rand__ __iand__ 位与
| __or__ __ror__ __ior__ 位或
^ __xor__ __rxor__ __ixor__ 位异或
<< __lshift__ __rlshift__ __ilshift__ 按位左移
>> __rshift__ __rrshift__ __irshift__ 按位右移

下面这个表格是比较运算符的名称:

分组 中缀运算符 正向方法调用 反向方法调用 后备机制
相等性 a == b a.__eq__(b) b.__eq__(a) 返回id(a) == id(b)
a != b a.__ne__(b) b.__ne__(a) 返回not (a == b)
排序 a > b a.__gt__(b) b.__lt__(a) 抛出TypeError
a < b a.__lt__(b) b.__gt__(a) 抛出TypeError
a >= b a.__ge__(b) b.__le__(a) 抛出TypeError
a <= b a.__le__(b) b.__ge__(a) 抛出TypeError
VPointer wechat
欢迎大家关注我的微信公众号"代码港"~~
您的慷慨将鼓励我继续创作~~