重载 __eq__ 以返回自定义对象
Posted
技术标签:
【中文标题】重载 __eq__ 以返回自定义对象【英文标题】:Overloading __eq__ to return custom objects 【发布时间】:2014-08-17 12:31:20 【问题描述】:我正在用 Python 编写 DSL,我想重载运算符以便能够轻松地编写我的 DSL 表达式。例如,我想写Var("a") + Var("b")
并得到Add(Var("a"), Var("b"))
的等效表示。为此,我重载了__add__
方法,它适用于这个方法。
尽管如此,我尝试重载__eq__
方法以实现类似的效果:我想编写Var("a") == Var("b")
并获得Eq(Var("a"), Var("b"))
的等效表示。通过重载__eq__
方法,返回Eq
的实例,我实现了我的目标。但是在重载__eq__
方法时,显然会干扰标准Python的行为,比如Var("b") in [Var("a")]
返回True
。
有没有办法实现我的目标,即可以写Var("a") == Var("b")
并得到Eq(Var("a"), Var("b"))
,但仍然可以写if Var("a") == Var("b"): blablabla
或将表达式放入内置容器等?
编辑
我尝试实现Eq
类的__bool__
方法,它似乎可以工作(见以下代码)。是我遗漏了什么还是可行的解决方案?
class Expr:
def __add__(self, other):
return Add(self, other)
def __eq__(self, other):
return Eq(self, other)
def __repr__(self):
return str(self)
def __add__(self, other):
return Add(self, other)
def __ne__(self, other):
return Neq(self, other)
class Var(Expr):
def __init__(self, name):
self.name = name
def __str__(self):
return "Var(" + str(self.name) + ")"
def equals(self, other):
if type(self) is type(other):
return self.name == other.name
else:
return False
def __hash__(self):
return 17 + 23 * hash(self.name)
class Add(Expr):
def __init__(self, left, right):
self.left = left
self.right = right
def __str__(self):
return "Add(" + str(self.left) + ", " + str(self.right) + ")"
def equals(self, other):
if type(self) is type(other):
return ( ( self.left.equals(other.left) and
self.right.equals(other.right) ) or
( self.left.equals(other.right) and
self.right.equals(other.left) ) )
else:
return False
def __hash__(self):
return (17 + 23 * hash("+") +
23 * 23 * hash(self.left) + 23 * 23 * hash(self.right))
class Eq(Expr):
def __init__(self, left, right):
self.left = left
self.right = right
def __str__(self):
return "Eq(" + str(self.left) + ", " + str(self.right) + ")"
def equals(self, other):
if type(self) is type(other):
return ( ( self.left.equals(other.left) and
self.right.equals(other.right) ) or
( self.left.equals(other.right) and
self.right.equals(other.left) ) )
else:
return False
def __bool__(self):
return self.left.equals(self.right)
def __hash__(self):
return (17 + 23 * hash("==") +
23 * 23 * hash(self.left) + 23 * 23 * hash(self.right))
class Neq(Expr):
def __init__(self, left, right):
self.left = left
self.right = right
def __str__(self):
return "Neq(" + str(self.left) + ", " + str(self.right) + ")"
def equals(self, other):
if type(self) is type(other):
return ( ( not self.left.equals(other.left) or
not self.right.equals(other.right) ) and
( not self.left.equals(other.right) or
not self.right.equals(other.left) ) )
else:
return False
def __bool__(self):
return not self.left.equals(self.right)
def __hash__(self):
return (17 + 23 * hash("!=") +
23 * 23 * hash(self.left) + 23 * 23 * hash(self.right))
a = Var("a")
aa = Var("a")
b = Var("b")
c = Var("c")
print("a + b", "=>", a + b) # a + b => Add(Var(a), Var(b))
print("a == b", "=>", a == b) # a == b => Eq(Var(a), Var(b))
print("a != b", "=>", a != b) # a != b => Neq(Var(a), Var(b))
print("a if a == b else b", "=>", a if a == b else b)
# a if a == b else b => Var(b)
print("a if a == aa else b", "=>", a if a == aa else b)
# a if a == aa else b => Var(a)
l = [a, a+b]
print("l", "=>", l) # l => [Var(a), Add(Var(a), Var(b))]
print("b in l", "=>", b in l) # b in l => False
print("a in l", "=>", a in l) # a in l => True
print("aa in l", "=>", aa in l) # aa in l => True
print("a+b in l", "=>", a+b in l) # a+b in l => True
print("b+a in l", "=>", b+a in l) # b+a in l => True
print("a+c in l", "=>", a+c in l) # a+c in l => False
if a == b:
print("a == b is True")
else:
print("a == b is False") # a == b is False
if a == aa:
print("a == aa is True") # a == aa is True
else:
print("a == aa is False")
if a != b:
print("a != b is True") # a != b is True
else:
print("a != b is False")
if a != aa:
print("a != aa is True")
else:
print("a != aa is False") # a != aa is False
if a == b or a == aa:
print("a == b or a == aa is True") # a == b or a == aa is True
else:
print("a == b or a == aa is False")
if a == aa and a == b:
print("a == aa and a == b is True")
else:
print("a == aa and a == b is False") # a == aa and a == b is False
if not a == aa:
print("not a == aa is True")
else:
print("not a == aa is False") # not a == aa is False
if not a == b:
print("not a == b is True") # not a == b is True
else:
print("not a == b is False")
if a == 3:
print("a == 3 is True")
else:
print("a == 3 is False") # a == 3 is False
if a != 3:
print("a != 3 is True") # a != 3 is True
else:
print("a != 3 is False")
if 3 == a:
print("3 == a is True")
else:
print("3 == a is False") # 3 == a is False
if 3 != a:
print("3 != a is True") # 3 != a is True
else:
print("3 != a is False")
if a == 'a':
print("a == 'a' is True")
else:
print("a == 'a' is False") # a == 'a' is False
if a != 'a':
print("a != 'a' is True") # a != 'a' is True
else:
print("a != 'a' is False")
if 'a' == a:
print("'a' == a is True")
else:
print("'a' == a is False") # 'a' == a is False
if 'a' != a:
print("'a' != a is True") # 'a' != a is True
else:
print("'a' != a is False")
s = a
print("s", "=>", s) # s => Var(a)
print("a in s", "=>", a in s) # a in s => True
print("b in s", "=>", b in s) # b in s => False
print("aa in s", "=>", aa in s) # aa in s => True
d = a: 1, b: 2
print("d", "=>", d) # d => Var(b): 2, Var(a): 1
print("d[a]", "=>", d[a]) # d[a] => 1
print("d[b]", "=>", d[b]) # d[b] => 2
print("c in d", "=>", c in d) # c in d => False
print("aa in d", "=>", aa in d) # aa in d => True
print("d[aa]", "=>", d[aa]) # d[aa] => 1
【问题讨论】:
请注意numpy
有同样的问题。他们解决了让==
返回一个数组,为了避免混淆,他们实现了__bool__
以便它引发异常。如果你想检查一个布尔值,你必须明确说出你想要什么(例如:(a == b).all()
来比较相等元素)。
@Bakuriu 的方法是我推荐的。类似地用于此答案:***.com/a/9504358/416467
【参考方案1】:
不,你不能。您必须选择一种行为或另一种行为。使用 .__eq__()
方法的上下文无法(可靠地)检测到。
如果两者都需要,则需要使用不同的运算符或方法来表示 DSL 行为。
【讨论】:
以上是关于重载 __eq__ 以返回自定义对象的主要内容,如果未能解决你的问题,请参考以下文章