pycharmを使ってpysparkの開発を行った際に"from pyspark.sql.functions import lit"でエラーがでたのを調べて見た

pysparkの開発を行った際に"from pyspark.sql.functions import lit"でimportできないとエラーが出たのを確認した時のメモ 実際は以下のようにpyspark.sql.functions.py内で以下のようにして動的にメソッドを追加している。

def _create_function(name, doc=""):
    """ Create a function for aggregator by name"""
    def _(col):
        sc = SparkContext._active_spark_context
        jc = getattr(sc._jvm.functions, name)(col._jc if isinstance(col, Column) else col)
        return Column(jc)
    _.__name__ = name
    _.__doc__ = doc
    return _


_functions = {
    'lit': _lit_doc,
    'col': 'Returns a :class:`Column` based on the given column name.',
    'column': 'Returns a :class:`Column` based on the given column name.',
    'asc': 'Returns a sort expression based on the ascending order of the given column name.',
    'desc': 'Returns a sort expression based on the descending order of the given column name.',

    'upper': 'Converts a string expression to upper case.',
    'lower': 'Converts a string expression to upper case.',
    'sqrt': 'Computes the square root of the specified float value.',
    'abs': 'Computes the absolute value.',

    'max': 'Aggregate function: returns the maximum value of the expression in a group.',
    'min': 'Aggregate function: returns the minimum value of the expression in a group.',
    'count': 'Aggregate function: returns the number of items in a group.',
    'sum': 'Aggregate function: returns the sum of all values in the expression.',
    'avg': 'Aggregate function: returns the average of the values in a group.',
'mean': 'Aggregate function: returns the average of the values in a group.',
    'sumDistinct': 'Aggregate function: returns the sum of distinct values in the expression.',
}

for _name, _doc in _functions.items():
    globals()[_name] = since(1.3)(_create_function(_name, _doc))

ここではcreate_functionでメソッドを生成し、globals()[name]にてname(col)で関数を呼び出せるようにしている。getattrでは"sc.jvm.functions"のnameで指定した関数を呼び出せるようにしており、ここでjvm場で動いているsparkを呼び出すようにしている。pysparkではpythonのコードがjvm場で動くという分けではなくpy4jにより連携するようになっていて、その連携部分が"getattr(sc.jvm.functions, name)(col.jc if isinstance(col, Column) else col)“のようでpyspark自体についてももうちょっと調べたいと思います。

pysparkでの開発時に気になった点のメモでした。