Python正确重载运算符的方法示例详解

时间:2022-11-17 23:23:06

前言

说到运算符重载相信大家都不陌生,运算符重载的作用是让用户定义的对象使用中缀运算符(如 + 和 |)或一元运算符(如 - 和 ~)。说得宽泛一些,在 Python 中,函数调用(())、属性访问(.)和元素访问 / 切片([])也是运算符。

我们为 Vector 类简略实现了几个运算符。__add__ 和 __mul__ 方法是为了展示如何使用特殊方法重载运算符,不过有些小问题被我们忽视了。此外,我们定义的Vector2d.__eq__ 方法认为 Vector(3, 4) == [3, 4] 是真的(True),这可能并不合理。下面来一起看看详细的介绍吧。

运算符重载基础

在某些圈子中,运算符重载的名声并不好。这个语言特性可能(已经)被滥用,让程序员困惑,导致缺陷和意料之外的性能瓶颈。但是,如果使用得当,API 会变得好用,代码会变得易于阅读。Python 施加了一些限制,做好了灵活性、可用性和安全性方面的平衡:

  • 不能重载内置类型的运算符
  • 不能新建运算符,只能重载现有的
  • 某些运算符不能重载——is、and、or 和 not(不过位运算符
  • &、| 和 ~ 可以)

前面的博文已经为 Vector 定义了一个中缀运算符,即 ==,这个运算符由__eq__ 方法支持。我们将改进 __eq__ 方法的实现,更好地处理不是Vector 实例的操作数。然而,在运算符重载方面,众多比较运算符(==、!=、>、<、>=、<=)是特例,因此我们首先将在 Vector 中重载四个算术运算符:一元运算符 - 和 +,以及中缀运算符 + 和 *。

一元运算符

  -(__neg__)

    一元取负算术运算符。如果 x 是 -2,那么 -x == 2。

  +(__pos__)

    一元取正算术运算符。通常,x == +x,但也有一些例外。如果好奇,请阅读“x 和 +x 何时不相等”附注栏。

  ~(__invert__)

    对整数按位取反,定义为 ~x == -(x+1)。如果 x 是 2,那么 ~x== -3。

支持一元运算符很简单,只需实现相应的特殊方法。这些特殊方法只有一个参数,self。然后,使用符合所在类的逻辑实现。不过,要遵守运算符的一个基本规则:始终返回一个新对象。也就是说,不能修改self,要创建并返回合适类型的新实例。

对 - 和 + 来说,结果可能是与 self 同属一类的实例。多数时候,+ 最好返回 self 的副本。abs(...) 的结果应该是一个标量。但是对 ~ 来说,很难说什么结果是合理的,因为可能不是处理整数的位,例如在ORM 中,SQL WHERE 子句应该返回反集。

?
1
2
3
4
5
6
7
8
def __abs__(self):
  return math.sqrt(sum(x * x for x in self))
 
 def __neg__(self):
  return Vector(-x for x in self)   #为了计算 -v,构建一个新 Vector 实例,把 self 的每个分量都取反
 
 def __pos__(self):
  return Vector(self)      #为了计算 +v,构建一个新 Vector 实例,传入 self 的各个分量

x 和 +x 何时不相等

每个人都觉得 x == +x,而且在 Python 中,几乎所有情况下都是这样。但是,我在标准库中找到两例 x != +x 的情况。

第一例与 decimal.Decimal 类有关。如果 x 是 Decimal 实例,在算术运算的上下文中创建,然后在不同的上下文中计算 +x,那么 x!= +x。例如,x 所在的上下文使用某个精度,而计算 +x 时,精度变了,例如下面的

算术运算上下文的精度变化可能导致 x 不等于 +x

?
1
2
3
4
5
6
7
8
9
10
>>> import decimal
>>> ctx = decimal.getcontext()                  #获取当前全局算术运算符的上下文引用
>>> ctx.prec = 40                          #把算术运算上下文的精度设为40
>>> one_third = decimal.Decimal('1') / decimal.Decimal('3') #使用当前精度计算1/3
>>> one_third
Decimal('0.3333333333333333333333333333333333333333')     #查看结果,小数点后的40个数字
>>> one_third == +one_third                    #one_third = +one_thied返回TRUE
True
>>> ctx.prec = 28                          #把精度降为28
>>> one_third == +one_third                    #one_third = +one_thied返回FalseFalse >>> +one_third Decimal('0.3333333333333333333333333333')   #查看+one_third,小术后的28位数字

虽然每个 +one_third 表达式都会使用 one_third 的值创建一个新 Decimal 实例,但是会使用当前算术运算上下文的精度。

x != +x 的第二例在 collections.Counter 的文档中(https://docs.python.org/3/library/collections.html#collections.Counter)。类实现了几个算术运算符,例如中缀运算符 +,作用是把两个Counter 实例的计数器加在一起。然而,从实用角度出发,Counter 相加时,负值和零值计数会从结果中剔除。而一元运算符 + 等同于加上一个空 Counter,因此它产生一个新的Counter 且仅保留大于零的计数器。

  一元运算符 + 得到一个新 Counter 实例,但是没有零值和负值计数器

?
1
2
3
4
5
6
7
8
>>> from collections import Counter
>>> ct = Counter('abracadabra')
>>> ct['r'] = -3
>>> ct['d'] = 0
>>> ct
Counter({'a': 5, 'r': -3, 'b': 2, 'c': 1, 'd': 0})
>>> +ct
Counter({'a': 5, 'b': 2, 'c': 1})

重载向量加法运算符+

两个欧几里得向量加在一起得到的是一个新向量,它的各个分量是两个向量中相应的分量之和。比如说:

?
1
2
3
4
5
6
>>> v1 = Vector([3, 4, 5])
>>> v2 = Vector([6, 7, 8])
>>> v1 + v2
Vector([9.0, 11.0, 13.0])
>>> v1 + v2 == Vector([3+6, 4+7, 5+8])
True

确定这些基本的要求之后,__add__ 方法的实现短小精悍, 如下

?
1
2
3
def __add__(self, other):
 pairs = itertools.zip_longest(self, other, fillvalue=0.0)   #生成一个元祖,a来自self,b来自other,如果两个长度不够,通过fillvalue设置的补全值自动补全短的
 return Vector(a + b for a, b in pairs)        #使用生成器表达式计算pairs中的各个元素的和

还可以把Vector 加到元组或任何生成数字的可迭代对象上:

?
1
2
3
4
5
6
7
8
# 在Vector类中定义
 
 def __add__(self, other):
  pairs = itertools.zip_longest(self, other, fillvalue=0.0)   #生成一个元祖,a来自self,b来自other,如果两个长度不够,通过fillvalue设置的补全值自动补全短的
  return Vector(a + b for a, b in pairs)        #使用生成器表达式计算pairs中的各个元素的和
 
 def __radd__(self, other):            #会直接委托给__add__
  return self + other

__radd__ 通常就这么简单:直接调用适当的运算符,在这里就是委托__add__。任何可交换的运算符都能这么做。处理数字和向量时,+ 可以交换,但是拼接序列时不行。

重载标量乘法运算符*

Vector([1, 2, 3]) * x 是什么意思?如果 x 是数字,就是计算标量积(scalar product),结果是一个新 Vector 实例,各个分量都会乘以x——这也叫元素级乘法(elementwise multiplication)。

?
1
2
3
4
5
>>> v1 = Vector([1, 2, 3])
>>> v1 * 10
Vector([10.0, 20.0, 30.0])
>>> 11 * v1
Vector([11.0, 22.0, 33.0])

涉及 Vector 操作数的积还有一种,叫两个向量的点积(dotproduct);如果把一个向量看作 1×N 矩阵,把另一个向量看作 N×1 矩阵,那么就是矩阵乘法。NumPy 等库目前的做法是,不重载这两种意义的 *,只用 * 计算标量积。例如,在 NumPy 中,点积使用numpy.dot() 函数计算。

回到标量积的话题。我们依然先实现最简可用的 __mul__ 和 __rmul__方法:

?
1
2
3
4
5
6
7
8
def __mul__(self, scalar):
  if isinstance(scalar, numbers.Real):
   return Vector(n * scalar for n in self)
  else:
   return NotImplemented
 
 def __rmul__(self, scalar):
  return self * scalar

这两个方法确实可用,但是提供不兼容的操作数时会出问题。scalar参数的值要是数字,与浮点数相乘得到的积是另一个浮点数(因为Vector 类在内部使用浮点数数组)。因此,不能使用复数,但可以是int、bool(int 的子类),甚至 fractions.Fraction 实例等标量。

提供了点积所需的 @ 记号(例如,a @ b 是 a 和 b 的点积)。@ 运算符由特殊方法 __matmul__、__rmatmul__ 和__imatmul__ 提供支持,名称取自“matrix multiplication”(矩阵乘法)

?
1
2
3
4
5
6
7
8
9
10
>>> va = Vector([1, 2, 3])
>>> vz = Vector([5, 6, 7])
>>> va @ vz == 38.0 # 1*5 + 2*6 + 3*7
True
>>> [10, 20, 30] @ vz
380.0
>>> va @ 3
Traceback (most recent call last):
...
TypeError: unsupported operand type(s) for @: 'Vector' and 'int'

下面是相应特殊方法的代码:

?
1
2
3
4
5
6
7
8
9
10
>>> va = Vector([1, 2, 3])
>>> vz = Vector([5, 6, 7])
>>> va @ vz == 38.0 # 1*5 + 2*6 + 3*7
True
>>> [10, 20, 30] @ vz
380.0
>>> va @ 3
Traceback (most recent call last):
...
TypeError: unsupported operand type(s) for @: 'Vector' and 'int'

众多比较运算符

Python 解释器对众多比较运算符(==、!=、>、<、>=、<=)的处理与前文类似,不过在两个方面有重大区别。

  • 正向和反向调用使用的是同一系列方法。例如,对 == 来说,正向和反向调用都是 __eq__ 方法,只是把参数对调了;而正向的 __gt__ 方法调用的是反向的 __lt__方法,并把参数对调。
  • 对 == 和 != 来说,如果反向调用失败,Python 会比较对象的 ID,而不抛出 TypeError。

众多比较运算符:正向方法返回NotImplemented的话,调用反向方法

 


分组

 

中缀运算符

 

正向方法调用

 

反向方法调用

 

后备机制

 

相等性

 

a == b

 

a.__eq__(b)

 

b.__eq__(a)

 

返回 id(a) == id(b)

 

 

a != b

 

a.__ne__(b)

 

b.__ne__(a)

 

返回 not (a == b)

 

排序

 

a > b

 

a.__gt__(b)

 

b.__lt__(a)

 

抛出 TypeError

 

 

a < b

 

a.__lt__(b)

 

b.__gt__(a)

 

抛出 TypeError

 

 

a >= b

 

a.__ge__(b)

 

b.__le__(a)

 

抛出 TypeError

 

 

a <= b

 

a.__le__(b)

 

b.__ge__(a)

 

抛出T ypeError

 

看下面的

?
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
from array import array
import reprlib
import math
import numbers
import functools
import operator
import itertools
 
 
class Vector:
 typecode = 'd'
 
 def __init__(self, components):
  self._components = array(self.typecode, components)
 
 def __iter__(self):
  return iter(self._components)
 
 def __repr__(self):
  components = reprlib.repr(self._components)
  components = components[components.find('['):-1]
  return 'Vector({})'.format(components)
 
 def __str__(self):
  return str(tuple(self))
 
 def __bytes__(self):
  return (bytes([ord(self.typecode)]) + bytes(self._components))
 
 def __eq__(self, other):
  return (len(self) == len(other) and all(a == b for a, b in zip(self, other)))
 
 def __hash__(self):
  hashes = map(hash, self._components)
  return functools.reduce(operator.xor, hashes, 0)
 
 def __add__(self, other):
  pairs = itertools.zip_longest(self, other, fillvalue=0.0)   #生成一个元祖,a来自self,b来自other,如果两个长度不够,通过fillvalue设置的补全值自动补全短的
  return Vector(a + b for a, b in pairs)        #使用生成器表达式计算pairs中的各个元素的和
 
 def __radd__(self, other):            #会直接委托给__add__
  return self + other
 
 def __mul__(self, scalar):
  if isinstance(scalar, numbers.Real):
   return Vector(n * scalar for n in self)
  else:
   return NotImplemented
 
 def __rmul__(self, scalar):
  return self * scalar
 
 def __matmul__(self, other):
  try:
   return sum(a * b for a, b in zip(self, other))
  except TypeError:
   return NotImplemented
 
 def __rmatmul__(self, other):
  return self @ other
 
 def __abs__(self):
  return math.sqrt(sum(x * x for x in self))
 
 def __neg__(self):
  return Vector(-x for x in self)   #为了计算 -v,构建一个新 Vector 实例,把 self 的每个分量都取反
 
 def __pos__(self):
  return Vector(self)       #为了计算 +v,构建一个新 Vector 实例,传入 self 的各个分量
 
 def __bool__(self):
  return bool(abs(self))
 
 def __len__(self):
  return len(self._components)
 
 def __getitem__(self, index):
  cls = type(self)
 
  if isinstance(index, slice):
   return cls(self._components[index])
  elif isinstance(index, numbers.Integral):
   return self._components[index]
  else:
   msg = '{.__name__} indices must be integers'
   raise TypeError(msg.format(cls))
 
 shorcut_names = 'xyzt'
 
 def __getattr__(self, name):
  cls = type(self)
 
  if len(name) == 1:
   pos = cls.shorcut_names.find(name)
   if 0 <= pos < len(self._components):
    return self._components[pos]
  msg = '{.__name__!r} object has no attribute {!r}'
  raise AttributeError(msg.format(cls, name))
 
 def angle(self, n):
  r = math.sqrt(sum(x * x for x in self[n:]))
  a = math.atan2(r, self[n-1])
  if (n == len(self) - 1 ) and (self[-1] < 0):
   return math.pi * 2 - a
  else:
   return a
 
 def angles(self):
  return (self.angle(n) for n in range(1, len(self)))
 
 def __format__(self, fmt_spec=''):
  if fmt_spec.endswith('h'):
   fmt_spec = fmt_spec[:-1]
   coords = itertools.chain([abs(self)], self.angles())
   outer_fmt = '<{}>'
  else:
   coords = self
   outer_fmt = '({})'
  components = (format(c, fmt_spec) for c in coords)
  return outer_fmt.format(', '.join(components))
 
 @classmethod
 def frombytes(cls, octets):
  typecode = chr(octets[0])
  memv = memoryview(octets[1:]).cast(typecode)
  return cls(memv)
 
va = Vector([1.0, 2.0, 3.0])
vb = Vector(range(1, 4))
print('va == vb:', va == vb)     #两个具有相同数值分量的 Vector 实例是相等的
t3 = (1, 2, 3)
print('va == t3:', va == t3)
 
print('[1, 2] == (1, 2):', [1, 2] == (1, 2))

上面代码执行返回的结果为:

?
1
2
3
va == vb: True
va == t3: True
[1, 2] == (1, 2): False

从 Python 自身来找线索,我们发现 [1,2] == (1, 2) 的结果是False。因此,我们要保守一点,做些类型检查。如果第二个操作数是Vector 实例(或者 Vector 子类的实例),那么就使用 __eq__ 方法的当前逻辑。否则,返回 NotImplemented,让 Python 处理。

vector_v8.py:改进 Vector 类的 __eq__ 方法

?
1
2
3
4
5
def __eq__(self, other):
 if isinstance(other, Vector):          #判断对比的是否和Vector同属一个实例
  return (len(self) == len(other) and all(a == b for a, b in zip(self, other)))
 else:
  return NotImplemented           #否则,返回NotImplemented

改进以后的代码执行结果:

?
1
2
3
4
5
6
7
>>> va = Vector([1.0, 2.0, 3.0])
>>> vb = Vector(range(1, 4))
>>> va == vb
True
>>> t3 = (1, 2, 3)
>>> va == t3
False

增量赋值运算符

  Vector 类已经支持增量赋值运算符 += 和 *= 了,示例如下

  增量赋值不会修改不可变目标,而是新建实例,然后重新绑定

?
1
2
3
4
5
6
7
8
9
10
11
12
13
>>> v1 = Vector([1, 2, 3])
>>> v1_alias = v1             # 复制一份,供后面审查Vector([1, 2, 3])对象
>>> id(v1)                 # 记住一开始绑定给v1的Vector实例的ID
>>> v1 += Vector([4, 5, 6])       # 增量加法运算
>>> v1                    # 结果与预期相符
Vector([5.0, 7.0, 9.0])
>>> id(v1)                 # 但是创建了新的Vector实例
>>> v1_alias                # 审查v1_alias,确认原来的Vector实例没被修改
Vector([1.0, 2.0, 3.0])
>>> v1 *= 11                # 增量乘法运算
>>> v1                   # 同样,结果与预期相符,但是创建了新的Vector实例
Vector([55.0, 77.0, 99.0])
>>> id(v1)

完整代码:

?
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
from array import array
import reprlib
import math
import numbers
import functools
import operator
import itertools
 
 
class Vector:
 typecode = 'd'
 
 def __init__(self, components):
  self._components = array(self.typecode, components)
 
 def __iter__(self):
  return iter(self._components)
 
 def __repr__(self):
  components = reprlib.repr(self._components)
  components = components[components.find('['):-1]
  return 'Vector({})'.format(components)
 
 def __str__(self):
  return str(tuple(self))
 
 def __bytes__(self):
  return (bytes([ord(self.typecode)]) + bytes(self._components))
 
 def __eq__(self, other):
  if isinstance(other, Vector):         
   return (len(self) == len(other) and all(a == b for a, b in zip(self, other)))
  else:
   return NotImplemented         
 
 def __hash__(self):
  hashes = map(hash, self._components)
  return functools.reduce(operator.xor, hashes, 0)
 
 def __add__(self, other):
  pairs = itertools.zip_longest(self, other, fillvalue=0.0)  
  return Vector(a + b for a, b in pairs)       
 
 def __radd__(self, other):           
  return self + other
 
 def __mul__(self, scalar):
  if isinstance(scalar, numbers.Real):
   return Vector(n * scalar for n in self)
  else:
   return NotImplemented
 
 def __rmul__(self, scalar):
  return self * scalar
 
 def __matmul__(self, other):
  try:
   return sum(a * b for a, b in zip(self, other))
  except TypeError:
   return NotImplemented
 
 def __rmatmul__(self, other):
  return self @ other
 
 def __abs__(self):
  return math.sqrt(sum(x * x for x in self))
 
 def __neg__(self):
  return Vector(-x for x in self)  
 
 def __pos__(self):
  return Vector(self)      
 
 def __bool__(self):
  return bool(abs(self))
 
 def __len__(self):
  return len(self._components)
 
 def __getitem__(self, index):
  cls = type(self)
 
  if isinstance(index, slice):
   return cls(self._components[index])
  elif isinstance(index, numbers.Integral):
   return self._components[index]
  else:
   msg = '{.__name__} indices must be integers'
   raise TypeError(msg.format(cls))
 
 shorcut_names = 'xyzt'
 
 def __getattr__(self, name):
  cls = type(self)
 
  if len(name) == 1:
   pos = cls.shorcut_names.find(name)
   if 0 <= pos < len(self._components):
    return self._components[pos]
  msg = '{.__name__!r} object has no attribute {!r}'
  raise AttributeError(msg.format(cls, name))
 
 def angle(self, n):
  r = math.sqrt(sum(x * x for x in self[n:]))
  a = math.atan2(r, self[n-1])
  if (n == len(self) - 1 ) and (self[-1] < 0):
   return math.pi * 2 - a
  else:
   return a
 
 def angles(self):
  return (self.angle(n) for n in range(1, len(self)))
 
 def __format__(self, fmt_spec=''):
  if fmt_spec.endswith('h'):
   fmt_spec = fmt_spec[:-1]
   coords = itertools.chain([abs(self)], self.angles())
   outer_fmt = '<{}>'
  else:
   coords = self
   outer_fmt = '({})'
  components = (format(c, fmt_spec) for c in coords)
  return outer_fmt.format(', '.join(components))
 
 @classmethod
 def frombytes(cls, octets):
  typecode = chr(octets[0])
  memv = memoryview(octets[1:]).cast(typecode)
  return cls(memv)

总结

以上就是这篇文章的全部内容了,希望本文的内容对大家的学习或者工作能带来一定的帮助,如果有疑问大家可以留言交流,谢谢大家对服务器之家的支持。

原文链接:http://www.cnblogs.com/demon89/p/7422454.html