Skip to content

schema_flat

laktory.spark.dataframe.schema_flat

FUNCTION DESCRIPTION
schema_flat

Returns a flattened list of column names, including nested struct fields.

Functions

schema_flat

schema_flat(df)

Returns a flattened list of column names, including nested struct fields.

PARAMETER DESCRIPTION
df

Input DataFrame

TYPE: DataFrame

RETURNS DESCRIPTION
list[str]

List of columns

Examples:

import laktory  # noqa: F401
import pyspark.sql.types as T

schema = T.StructType(
    [
        T.StructField("indexx", T.IntegerType()),
        T.StructField(
            "stock",
            T.StructType(
                [
                    T.StructField("symbol", T.StringType()),
                    T.StructField("name", T.StringType()),
                ]
            ),
        ),
        T.StructField(
            "prices",
            T.ArrayType(
                T.StructType(
                    [
                        T.StructField("open", T.IntegerType()),
                        T.StructField("close", T.IntegerType()),
                    ]
                )
            ),
        ),
    ]
)

data = [
    (
        1,
        {"symbol": "AAPL", "name": "Apple"},
        [{"open": 1, "close": 2}, {"open": 1, "close": 2}],
    ),
    (
        2,
        {"symbol": "MSFT", "name": "Microsoft"},
        [{"open": 1, "close": 2}, {"open": 1, "close": 2}],
    ),
    (
        3,
        {"symbol": "GOOGL", "name": "Google"},
        [{"open": 1, "close": 2}, {"open": 1, "close": 2}],
    ),
]

df = spark.createDataFrame(data, schema=schema)
print(df.laktory.schema_flat())
'''
[
    'indexx',
    'stock',
    'stock.symbol',
    'stock.name',
    'prices',
    'prices[*].open',
    'prices[*].close',
]
'''
Source code in laktory/spark/dataframe/schema_flat.py (lines 6–110):
def schema_flat(df: DataFrame) -> list[str]:
    """
    Returns a flattened list of columns

    Every top-level column is listed, followed directly by its nested
    fields. Fields of a struct column are listed with dot notation
    (`parent.child`), and fields of structs stored inside an array column
    are listed as `parent[*].child`. Map columns and user-defined-type
    columns (e.g. ML vectors) are listed as single columns.

    Parameters
    ----------
    df:
        Input DataFrame

    Returns
    -------
    :
        List of columns

    Raises
    ------
    ValueError
        If the schema contains a complex data type that is not supported.

    Examples
    --------
    ```py
    import laktory  # noqa: F401
    import pyspark.sql.types as T

    schema = T.StructType(
        [
            T.StructField("indexx", T.IntegerType()),
            T.StructField(
                "stock",
                T.StructType(
                    [
                        T.StructField("symbol", T.StringType()),
                        T.StructField("name", T.StringType()),
                    ]
                ),
            ),
            T.StructField(
                "prices",
                T.ArrayType(
                    T.StructType(
                        [
                            T.StructField("open", T.IntegerType()),
                            T.StructField("close", T.IntegerType()),
                        ]
                    )
                ),
            ),
        ]
    )

    data = [
        (
            1,
            {"symbol": "AAPL", "name": "Apple"},
            [{"open": 1, "close": 2}, {"open": 1, "close": 2}],
        ),
        (
            2,
            {"symbol": "MSFT", "name": "Microsoft"},
            [{"open": 1, "close": 2}, {"open": 1, "close": 2}],
        ),
        (
            3,
            {"symbol": "GOOGL", "name": "Google"},
            [{"open": 1, "close": 2}, {"open": 1, "close": 2}],
        ),
    ]

    df = spark.createDataFrame(data, schema=schema)
    print(df.laktory.schema_flat())
    '''
    [
        'indexx',
        'stock',
        'stock.symbol',
        'stock.name',
        'prices',
        'prices[*].open',
        'prices[*].close',
    ]
    '''
    ```
    """

    def get_fields(json_schema):
        # `json_schema` is the JSON form of a struct type. A type dict without
        # a "fields" key (e.g. a nested array or a map) yields no children.
        field_names = []
        for f in json_schema.get("fields", []):
            f_name = f["name"]
            f_type = f["type"]

            # Simple types ("integer", "string", "decimal(10,2)", ...) are
            # serialized as plain strings and have no nested fields.
            if not isinstance(f_type, dict):
                field_names.append(f_name)
                continue

            type_name = f_type["type"]
            if type_name == "array":
                field_names.append(f_name)
                e_type = f_type["elementType"]
                # Only dict element types can carry sub-fields; for an element
                # that is itself an array or a map, get_fields finds none.
                if isinstance(e_type, dict):
                    field_names += [f"{f_name}[*].{v}" for v in get_fields(e_type)]
            elif type_name == "struct":
                field_names.append(f_name)
                field_names += [f"{f_name}.{v}" for v in get_fields(f_type)]
            elif type_name in ("map", "udt"):
                # Map entries and user-defined types (serialized by Spark as
                # {"type": "udt", ...}) are not addressable with dot notation,
                # so they are listed as leaf columns.
                field_names.append(f_name)
            else:
                raise ValueError(
                    f"Unsupported data type '{type_name}' for column '{f_name}'"
                )
        return field_names

    return get_fields(json.loads(df.schema.json()))