Python演算子オーバーロード

Python演算子オーバーロード

やさPyに載せろ!!

私はやさしいPythonでPythonを勉強した身で、インターンでスキルを磨いています(まだまだだが)

でも、そのやさしいPythonに書いてないことがあります。そのうちの1つが演算子オーバーロードです。

演算子オーバーロード

さて、演算子オーバーロードの説明をします。

演算子オーバーロードとは、既存の演算子に別の機能を持たせようというものです。

友人が「演算子オーバーロードの説明ならあるよ。そう、オライリーならね」と言っていたので、調べてみたらあったので使ってみた。

class MyList(object):

    def __init__(self, li):

        self.li = li


    def __add__(self, addend):

        # 「+」演算子オーバーロード
        res = MyList(
            list(map(lambda a, b:
                a + b,
                self.li, addend.li
            ))
        )
        return res


    def __str__(self):

        # 出力をオーバーロード
        return '{}'.format(self.li)


if __name__ == '__main__':
    m0 = MyList([1, 1])
    m1 = MyList([2, 3])

    print('m0      = {}'.format(m0))
    print('     m1 = {}'.format(m0))
    print('m0 + m1 = {}'.format(m0 + m1))
m0      = [1, 1]
     m1 = [1, 1]
m0 + m1 = [3, 4]

リスト同士の足し算が実装できた!

でもこれ以上のList同士の足し算は車輪の再発明になるのでやめます。NumPyがいるし、しかもあれの場合はC/C++とFORTRANで最適化されているので、Pythonでこんなものを実装したところで意味がないのでやめました。

実際に演算子オーバーロードを応用している例

ググってみると、Pythonで用いられるNumPyのモダンな書き方として、演算子オーバーロードを用いた方法が見つかりました。出力例は以下のようになります。

import numpy as np

A = np.array([
    [ 1, 2, 3],
    [ 4, 5, 6],
    [ 7, 8, 0]
])
print('*を使った場合')
print(A * A.T)    # これだとアダマール積になる
print('@を使った場合')
print(A @ A.T)    # これで数式的なAA^Tが計算できる
*を使った場合
[[ 1  8 21]
 [ 8 25 48]
 [21 48  0]]
@を使った場合
[[ 14  32  23]
 [ 32  77  68]
 [ 23  68 113]]

これでNumPy上で掛け算をするときでも、わざわざ「np.dot」を使わなくても良くなります。打つのも短くできるし、リーダブルになるので、非常にオススメです。