Source code for pyspark.sql.dataframe

#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements.  See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License.  You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import sys
import warnings
import random

if sys.version >= '3':
    basestring = unicode = str
    long = int
    from functools import reduce
else:
    from itertools import imap as map

from pyspark import since
from pyspark.rdd import RDD, _load_from_socket, ignore_unicode_prefix
from pyspark.serializers import BatchedSerializer, PickleSerializer, UTF8Deserializer
from pyspark.storagelevel import StorageLevel
from pyspark.traceback_utils import SCCallSiteSync
from pyspark.sql.types import _parse_datatype_json_string
from pyspark.sql.column import Column, _to_seq, _to_list, _to_java_column
from pyspark.sql.readwriter import DataFrameWriter
from pyspark.sql.types import *

__all__ = ["DataFrame", "SchemaRDD", "DataFrameNaFunctions", "DataFrameStatFunctions"]


class DataFrame(object):
    """A distributed collection of data grouped into named columns.

    A :class:`DataFrame` is equivalent to a relational table in Spark SQL,
    and can be created using various functions in :class:`SQLContext`::

        people = sqlContext.read.parquet("...")

    Once created, it can be manipulated using the various domain-specific-language
    (DSL) functions defined in: :class:`DataFrame`, :class:`Column`.

    To select a column from the data frame, use the apply method::

        ageCol = people.age

    A more concrete example::

        # To create DataFrame using SQLContext
        people = sqlContext.read.parquet("...")
        department = sqlContext.read.parquet("...")

        people.filter(people.age > 30).join(department, people.deptId == department.id) \
          .groupBy(department.name, "gender").agg({"salary": "avg", "age": "max"})

    .. note:: Experimental

    .. versionadded:: 1.3
    """

    def __init__(self, jdf, sql_ctx):
        self._jdf = jdf
        self.sql_ctx = sql_ctx
        self._sc = sql_ctx and sql_ctx._sc
        self.is_cached = False
        self._schema = None  # initialized lazily
        self._lazy_rdd = None

    @property
    @since(1.3)
    def rdd(self):
        """Returns the content as an :class:`pyspark.RDD` of :class:`Row`.
        """
        if self._lazy_rdd is None:
            jrdd = self._jdf.javaToPython()
            self._lazy_rdd = RDD(jrdd, self.sql_ctx._sc, BatchedSerializer(PickleSerializer()))
        return self._lazy_rdd

    @property
    @since("1.3.1")
    def na(self):
        """Returns a :class:`DataFrameNaFunctions` for handling missing values.
        """
        return DataFrameNaFunctions(self)

    @property
    @since(1.4)
    def stat(self):
        """Returns a :class:`DataFrameStatFunctions` for statistic functions.
        """
        return DataFrameStatFunctions(self)

    @ignore_unicode_prefix
    @since(1.3)
    def toJSON(self, use_unicode=True):
        """Converts a :class:`DataFrame` into a :class:`RDD` of string.

        Each row is turned into a JSON document as one element in the returned RDD.

        >>> df.toJSON().first()
        u'{"age":2,"name":"Alice"}'
        """
        rdd = self._jdf.toJSON()
        return RDD(rdd.toJavaRDD(), self._sc, UTF8Deserializer(use_unicode))

    def saveAsParquetFile(self, path):
        """Saves the contents as a Parquet file, preserving the schema.

        .. note:: Deprecated in 1.4, use :func:`DataFrameWriter.parquet` instead.
        """
        warnings.warn("saveAsParquetFile is deprecated. Use write.parquet() instead.")
        self._jdf.saveAsParquetFile(path)

    @since(1.3)
    def registerTempTable(self, name):
        """Registers this RDD as a temporary table using the given name.

        The lifetime of this temporary table is tied to the :class:`SQLContext`
        that was used to create this :class:`DataFrame`.

        >>> df.registerTempTable("people")
        >>> df2 = sqlContext.sql("select * from people")
        >>> sorted(df.collect()) == sorted(df2.collect())
        True
        """
        self._jdf.registerTempTable(name)

    def registerAsTable(self, name):
        """
        .. note:: Deprecated in 1.4, use :func:`registerTempTable` instead.
        """
        warnings.warn("Use registerTempTable instead of registerAsTable.")
        self.registerTempTable(name)

    def insertInto(self, tableName, overwrite=False):
        """Inserts the contents of this :class:`DataFrame` into the specified table.

        .. note:: Deprecated in 1.4, use :func:`DataFrameWriter.insertInto` instead.
        """
        warnings.warn("insertInto is deprecated. Use write.insertInto() instead.")
        self.write.insertInto(tableName, overwrite)

    def saveAsTable(self, tableName, source=None, mode="error", **options):
        """Saves the contents of this :class:`DataFrame` to a data source as a table.

        .. note:: Deprecated in 1.4, use :func:`DataFrameWriter.saveAsTable` instead.
        """
        warnings.warn("saveAsTable is deprecated. Use write.saveAsTable() instead.")
        self.write.saveAsTable(tableName, source, mode, **options)

    @since(1.3)
    def save(self, path=None, source=None, mode="error", **options):
        """Saves the contents of the :class:`DataFrame` to a data source.

        .. note:: Deprecated in 1.4, use :func:`DataFrameWriter.save` instead.
        """
        warnings.warn("save is deprecated. Use write.save() instead.")
        return self.write.save(path, source, mode, **options)

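    # Illustrative sketch, not part of the original module: the non-deprecated
    # replacement for save()/saveAsTable()/insertInto() is the DataFrameWriter
    # returned by the ``write`` property defined below. The paths and table name
    # here are hypothetical.
    #
    #   df.write.mode("overwrite").parquet("/tmp/people.parquet")
    #   df.write.format("json").save("/tmp/people.json")
    #   df.write.saveAsTable("people_backup")
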
    @property
    @since(1.4)
    def write(self):
        """
        Interface for saving the content of the :class:`DataFrame` out into external storage.

        :return: :class:`DataFrameWriter`
        """
        return DataFrameWriter(self)

    @property
    @since(1.3)
    def schema(self):
        """Returns the schema of this :class:`DataFrame` as a :class:`types.StructType`.

        >>> df.schema
        StructType(List(StructField(age,IntegerType,true),StructField(name,StringType,true)))
        """
        if self._schema is None:
            try:
                self._schema = _parse_datatype_json_string(self._jdf.schema().json())
            except AttributeError as e:
                raise Exception(
                    "Unable to parse datatype from schema. %s" % e)
        return self._schema

    @since(1.3)
    def printSchema(self):
        """Prints out the schema in the tree format.

        >>> df.printSchema()
        root
         |-- age: integer (nullable = true)
         |-- name: string (nullable = true)
        <BLANKLINE>
        """
        print(self._jdf.schema().treeString())

    @since(1.3)
    def explain(self, extended=False):
        """Prints the (logical and physical) plans to the console for debugging purpose.

        :param extended: boolean, default ``False``. If ``False``, prints only the physical plan.

        >>> df.explain()
        == Physical Plan ==
        Scan ExistingRDD[age#0,name#1]

        >>> df.explain(True)
        == Parsed Logical Plan ==
        ...
        == Analyzed Logical Plan ==
        ...
        == Optimized Logical Plan ==
        ...
        == Physical Plan ==
        ...
        """
        if extended:
            print(self._jdf.queryExecution().toString())
        else:
            print(self._jdf.queryExecution().simpleString())

    @since(1.3)
    def isLocal(self):
        """Returns ``True`` if the :func:`collect` and :func:`take` methods can be run locally
        (without any Spark executors).
        """
        return self._jdf.isLocal()

    @since(1.3)
    def show(self, n=20, truncate=True):
        """Prints the first ``n`` rows to the console.

        :param n: Number of rows to show.
        :param truncate: Whether to truncate long strings and align cells right.

        >>> df
        DataFrame[age: int, name: string]
        >>> df.show()
        +---+-----+
        |age| name|
        +---+-----+
        |  2|Alice|
        |  5|  Bob|
        +---+-----+
        """
        print(self._jdf.showString(n, truncate))

    def __repr__(self):
        return "DataFrame[%s]" % (", ".join("%s: %s" % c for c in self.dtypes))

    @since(1.3)
    def count(self):
        """Returns the number of rows in this :class:`DataFrame`.

        >>> df.count()
        2
        """
        return int(self._jdf.count())

    @ignore_unicode_prefix
    @since(1.3)
    def collect(self):
        """Returns all the records as a list of :class:`Row`.

        >>> df.collect()
        [Row(age=2, name=u'Alice'), Row(age=5, name=u'Bob')]
        """
        with SCCallSiteSync(self._sc) as css:
            port = self._jdf.collectToPython()
        return list(_load_from_socket(port, BatchedSerializer(PickleSerializer())))

    @ignore_unicode_prefix
    @since(1.3)
    def limit(self, num):
        """Limits the result count to the number specified.

        >>> df.limit(1).collect()
        [Row(age=2, name=u'Alice')]
        >>> df.limit(0).collect()
        []
        """
        jdf = self._jdf.limit(num)
        return DataFrame(jdf, self.sql_ctx)

    @ignore_unicode_prefix
    @since(1.3)
    def take(self, num):
        """Returns the first ``num`` rows as a :class:`list` of :class:`Row`.

        >>> df.take(2)
        [Row(age=2, name=u'Alice'), Row(age=5, name=u'Bob')]
        """
        with SCCallSiteSync(self._sc) as css:
            port = self._sc._jvm.org.apache.spark.sql.execution.EvaluatePython.takeAndServe(
                self._jdf, num)
        return list(_load_from_socket(port, BatchedSerializer(PickleSerializer())))

    @ignore_unicode_prefix
    @since(1.3)
    def map(self, f):
        """Returns a new :class:`RDD` by applying the ``f`` function to each :class:`Row`.

        This is a shorthand for ``df.rdd.map()``.

        >>> df.map(lambda p: p.name).collect()
        [u'Alice', u'Bob']
        """
        return self.rdd.map(f)

    @ignore_unicode_prefix
    @since(1.3)
    def flatMap(self, f):
        """Returns a new :class:`RDD` by first applying the ``f`` function to each :class:`Row`,
        and then flattening the results.

        This is a shorthand for ``df.rdd.flatMap()``.

        >>> df.flatMap(lambda p: p.name).collect()
        [u'A', u'l', u'i', u'c', u'e', u'B', u'o', u'b']
        """
        return self.rdd.flatMap(f)

    @since(1.3)
    def mapPartitions(self, f, preservesPartitioning=False):
        """Returns a new :class:`RDD` by applying the ``f`` function to each partition.

        This is a shorthand for ``df.rdd.mapPartitions()``.

        >>> rdd = sc.parallelize([1, 2, 3, 4], 4)
        >>> def f(iterator): yield 1
        >>> rdd.mapPartitions(f).sum()
        4
        """
        return self.rdd.mapPartitions(f, preservesPartitioning)

    @since(1.3)
    def foreach(self, f):
        """Applies the ``f`` function to all :class:`Row` of this :class:`DataFrame`.

        This is a shorthand for ``df.rdd.foreach()``.

        >>> def f(person):
        ...     print(person.name)
        >>> df.foreach(f)
        """
        return self.rdd.foreach(f)

    @since(1.3)
    def foreachPartition(self, f):
        """Applies the ``f`` function to each partition of this :class:`DataFrame`.

        This is a shorthand for ``df.rdd.foreachPartition()``.

        >>> def f(people):
        ...     for person in people:
        ...         print(person.name)
        >>> df.foreachPartition(f)
        """
        return self.rdd.foreachPartition(f)

    @since(1.3)
    def cache(self):
        """Persists with the default storage level (C{MEMORY_ONLY_SER}).
        """
        self.is_cached = True
        self._jdf.cache()
        return self

    @since(1.3)
    def persist(self, storageLevel=StorageLevel.MEMORY_ONLY_SER):
        """Sets the storage level to persist its values across operations
        after the first time it is computed. This can only be used to assign
        a new storage level if the RDD does not have a storage level set yet.
        If no storage level is specified, defaults to (C{MEMORY_ONLY_SER}).
        """
        self.is_cached = True
        javaStorageLevel = self._sc._getJavaStorageLevel(storageLevel)
        self._jdf.persist(javaStorageLevel)
        return self

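    # Illustrative sketch, not part of the original module: persist() also accepts
    # any other StorageLevel, e.g. one that spills to disk when memory is tight.
    #
    #   from pyspark import StorageLevel
    #   df.persist(StorageLevel.MEMORY_AND_DISK)
    #   df.count()      # the first action materializes the cached data
    #   df.unpersist()
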
    @since(1.3)
    def unpersist(self, blocking=True):
        """Marks the :class:`DataFrame` as non-persistent, and removes all blocks for it from
        memory and disk.
        """
        self.is_cached = False
        self._jdf.unpersist(blocking)
        return self

    @since(1.4)
    def coalesce(self, numPartitions):
        """
        Returns a new :class:`DataFrame` that has exactly `numPartitions` partitions.

        Similar to coalesce defined on an :class:`RDD`, this operation results in a
        narrow dependency, e.g. if you go from 1000 partitions to 100 partitions,
        there will not be a shuffle, instead each of the 100 new partitions will
        claim 10 of the current partitions.

        >>> df.coalesce(1).rdd.getNumPartitions()
        1
        """
        return DataFrame(self._jdf.coalesce(numPartitions), self.sql_ctx)

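    # Illustrative sketch, not part of the original module: coalesce() is the usual
    # way to shrink the partition count without a shuffle after a selective filter.
    #
    #   filtered = df.filter(df.age > 3).coalesce(1)
    #   filtered.rdd.getNumPartitions()   # 1
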
    @since(1.3)
    def repartition(self, numPartitions, *cols):
        """
        Returns a new :class:`DataFrame` partitioned by the given partitioning expressions. The
        resulting DataFrame is hash partitioned.

        ``numPartitions`` can be an int to specify the target number of partitions or a Column.
        If it is a Column, it will be used as the first partitioning column. If not specified,
        the default number of partitions is used.

        .. versionchanged:: 1.6
           Added optional arguments to specify the partitioning columns. Also made numPartitions
           optional if partitioning columns are specified.

        >>> df.repartition(10).rdd.getNumPartitions()
        10
        >>> data = df.unionAll(df).repartition("age")
        >>> data.show()
        +---+-----+
        |age| name|
        +---+-----+
        |  2|Alice|
        |  2|Alice|
        |  5|  Bob|
        |  5|  Bob|
        +---+-----+
        >>> data = data.repartition(7, "age")
        >>> data.show()
        +---+-----+
        |age| name|
        +---+-----+
        |  5|  Bob|
        |  5|  Bob|
        |  2|Alice|
        |  2|Alice|
        +---+-----+
        >>> data.rdd.getNumPartitions()
        7
        >>> data = data.repartition("name", "age")
        >>> data.show()
        +---+-----+
        |age| name|
        +---+-----+
        |  5|  Bob|
        |  5|  Bob|
        |  2|Alice|
        |  2|Alice|
        +---+-----+
        """
        if isinstance(numPartitions, int):
            if len(cols) == 0:
                return DataFrame(self._jdf.repartition(numPartitions), self.sql_ctx)
            else:
                return DataFrame(
                    self._jdf.repartition(numPartitions, self._jcols(*cols)), self.sql_ctx)
        elif isinstance(numPartitions, (basestring, Column)):
            cols = (numPartitions, ) + cols
            return DataFrame(self._jdf.repartition(self._jcols(*cols)), self.sql_ctx)
        else:
            raise TypeError("numPartitions should be an int or Column")

    @since(1.3)
    def distinct(self):
        """Returns a new :class:`DataFrame` containing the distinct rows in this :class:`DataFrame`.

        >>> df.distinct().count()
        2
        """
        return DataFrame(self._jdf.distinct(), self.sql_ctx)

    @since(1.3)
    def sample(self, withReplacement, fraction, seed=None):
        """Returns a sampled subset of this :class:`DataFrame`.

        >>> df.sample(False, 0.5, 42).count()
        2
        """
        assert fraction >= 0.0, "Negative fraction value: %s" % fraction
        seed = seed if seed is not None else random.randint(0, sys.maxsize)
        rdd = self._jdf.sample(withReplacement, fraction, long(seed))
        return DataFrame(rdd, self.sql_ctx)

    @since(1.5)
    def sampleBy(self, col, fractions, seed=None):
        """
        Returns a stratified sample without replacement based on the
        fraction given on each stratum.

        :param col: column that defines strata
        :param fractions: sampling fraction for each stratum. If a stratum is not
            specified, we treat its fraction as zero.
        :param seed: random seed
        :return: a new DataFrame that represents the stratified sample

        >>> from pyspark.sql.functions import col
        >>> dataset = sqlContext.range(0, 100).select((col("id") % 3).alias("key"))
        >>> sampled = dataset.sampleBy("key", fractions={0: 0.1, 1: 0.2}, seed=0)
        >>> sampled.groupBy("key").count().orderBy("key").show()
        +---+-----+
        |key|count|
        +---+-----+
        |  0|    5|
        |  1|    9|
        +---+-----+
        """
        if not isinstance(col, str):
            raise ValueError("col must be a string, but got %r" % type(col))
        if not isinstance(fractions, dict):
            raise ValueError("fractions must be a dict but got %r" % type(fractions))
        for k, v in fractions.items():
            if not isinstance(k, (float, int, long, basestring)):
                raise ValueError("key must be float, int, long, or string, but got %r" % type(k))
            fractions[k] = float(v)
        seed = seed if seed is not None else random.randint(0, sys.maxsize)
        return DataFrame(self._jdf.stat().sampleBy(col, self._jmap(fractions), seed), self.sql_ctx)

    @since(1.4)
    def randomSplit(self, weights, seed=None):
        """Randomly splits this :class:`DataFrame` with the provided weights.

        :param weights: list of doubles as weights with which to split the DataFrame. Weights will
            be normalized if they don't sum up to 1.0.
        :param seed: The seed for sampling.

        >>> splits = df4.randomSplit([1.0, 2.0], 24)
        >>> splits[0].count()
        1

        >>> splits[1].count()
        3
        """
        for w in weights:
            if w < 0.0:
                raise ValueError("Weights must be positive. Found weight value: %s" % w)
        seed = seed if seed is not None else random.randint(0, sys.maxsize)
        rdd_array = self._jdf.randomSplit(_to_list(self.sql_ctx._sc, weights), long(seed))
        return [DataFrame(rdd, self.sql_ctx) for rdd in rdd_array]

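    # Illustrative sketch, not part of the original module: a common use of
    # randomSplit() is a reproducible train/test split (weights are normalized).
    #
    #   train, test = df4.randomSplit([0.8, 0.2], seed=17)
    #   train.count() + test.count() == df4.count()   # True
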
    @property
    @since(1.3)
    def dtypes(self):
        """Returns all column names and their data types as a list.

        >>> df.dtypes
        [('age', 'int'), ('name', 'string')]
        """
        return [(str(f.name), f.dataType.simpleString()) for f in self.schema.fields]

    @property
    @since(1.3)
    def columns(self):
        """Returns all column names as a list.

        >>> df.columns
        ['age', 'name']
        """
        return [f.name for f in self.schema.fields]

    @ignore_unicode_prefix
    @since(1.3)
    def alias(self, alias):
        """Returns a new :class:`DataFrame` with an alias set.

        >>> from pyspark.sql.functions import *
        >>> df_as1 = df.alias("df_as1")
        >>> df_as2 = df.alias("df_as2")
        >>> joined_df = df_as1.join(df_as2, col("df_as1.name") == col("df_as2.name"), 'inner')
        >>> joined_df.select(col("df_as1.name"), col("df_as2.name"), col("df_as2.age")).collect()
        [Row(name=u'Alice', name=u'Alice', age=2), Row(name=u'Bob', name=u'Bob', age=5)]
        """
        assert isinstance(alias, basestring), "alias should be a string"
        return DataFrame(getattr(self._jdf, "as")(alias), self.sql_ctx)

    @ignore_unicode_prefix
    @since(1.3)
    def join(self, other, on=None, how=None):
        """Joins with another :class:`DataFrame`, using the given join expression.

        The following performs a full outer join between ``df1`` and ``df2``.

        :param other: Right side of the join
        :param on: a string for the join column name, a list of column names,
            a join expression (Column), or a list of Columns.
            If `on` is a string or a list of strings indicating the name of the join column(s),
            the column(s) must exist on both sides, and this performs an inner equi-join.
        :param how: str, default 'inner'.
            One of `inner`, `outer`, `left_outer`, `right_outer`, `leftsemi`.

        >>> df.join(df2, df.name == df2.name, 'outer').select(df.name, df2.height).collect()
        [Row(name=None, height=80), Row(name=u'Alice', height=None), Row(name=u'Bob', height=85)]

        >>> cond = [df.name == df3.name, df.age == df3.age]
        >>> df.join(df3, cond, 'outer').select(df.name, df3.age).collect()
        [Row(name=u'Bob', age=5), Row(name=u'Alice', age=2)]

        >>> df.join(df2, 'name').select(df.name, df2.height).collect()
        [Row(name=u'Bob', height=85)]

        >>> df.join(df4, ['name', 'age']).select(df.name, df.age).collect()
        [Row(name=u'Bob', age=5)]
        """
        if on is not None and not isinstance(on, list):
            on = [on]

        if on is None or len(on) == 0:
            jdf = self._jdf.join(other._jdf)
        elif isinstance(on[0], basestring):
            if how is None:
                jdf = self._jdf.join(other._jdf, self._jseq(on), "inner")
            else:
                assert isinstance(how, basestring), "how should be basestring"
                jdf = self._jdf.join(other._jdf, self._jseq(on), how)
        else:
            assert isinstance(on[0], Column), "on should be Column or list of Column"
            if len(on) > 1:
                on = reduce(lambda x, y: x.__and__(y), on)
            else:
                on = on[0]
            if how is None:
                jdf = self._jdf.join(other._jdf, on._jc, "inner")
            else:
                assert isinstance(how, basestring), "how should be basestring"
                jdf = self._jdf.join(other._jdf, on._jc, how)
        return DataFrame(jdf, self.sql_ctx)

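    # Illustrative sketch, not part of the original module: the accepted forms of
    # the ``on`` argument shown side by side.
    #
    #   df.join(df2, 'name')                           # column name -> inner equi-join
    #   df.join(df2, ['name'], 'left_outer')           # list of names plus a join type
    #   df.join(df2, df.name == df2.name, 'outer')     # arbitrary Column expression
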
    @since(1.6)
    def sortWithinPartitions(self, *cols, **kwargs):
        """Returns a new :class:`DataFrame` with each partition sorted by the specified column(s).

        :param cols: list of :class:`Column` or column names to sort by.
        :param ascending: boolean or list of boolean (default True).
            Sort ascending vs. descending. Specify list for multiple sort orders.
            If a list is specified, length of the list must equal length of the `cols`.

        >>> df.sortWithinPartitions("age", ascending=False).show()
        +---+-----+
        |age| name|
        +---+-----+
        |  2|Alice|
        |  5|  Bob|
        +---+-----+
        """
        jdf = self._jdf.sortWithinPartitions(self._sort_cols(cols, kwargs))
        return DataFrame(jdf, self.sql_ctx)

    @ignore_unicode_prefix
    @since(1.3)
    def sort(self, *cols, **kwargs):
        """Returns a new :class:`DataFrame` sorted by the specified column(s).

        :param cols: list of :class:`Column` or column names to sort by.
        :param ascending: boolean or list of boolean (default True).
            Sort ascending vs. descending. Specify list for multiple sort orders.
            If a list is specified, length of the list must equal length of the `cols`.

        >>> df.sort(df.age.desc()).collect()
        [Row(age=5, name=u'Bob'), Row(age=2, name=u'Alice')]
        >>> df.sort("age", ascending=False).collect()
        [Row(age=5, name=u'Bob'), Row(age=2, name=u'Alice')]
        >>> df.orderBy(df.age.desc()).collect()
        [Row(age=5, name=u'Bob'), Row(age=2, name=u'Alice')]
        >>> from pyspark.sql.functions import *
        >>> df.sort(asc("age")).collect()
        [Row(age=2, name=u'Alice'), Row(age=5, name=u'Bob')]
        >>> df.orderBy(desc("age"), "name").collect()
        [Row(age=5, name=u'Bob'), Row(age=2, name=u'Alice')]
        >>> df.orderBy(["age", "name"], ascending=[0, 1]).collect()
        [Row(age=5, name=u'Bob'), Row(age=2, name=u'Alice')]
        """
        jdf = self._jdf.sort(self._sort_cols(cols, kwargs))
        return DataFrame(jdf, self.sql_ctx)

    orderBy = sort

    def _jseq(self, cols, converter=None):
        """Return a JVM Seq of Columns from a list of Column or names"""
        return _to_seq(self.sql_ctx._sc, cols, converter)

    def _jmap(self, jm):
        """Return a JVM Scala Map from a dict"""
        return _to_scala_map(self.sql_ctx._sc, jm)

    def _jcols(self, *cols):
        """Return a JVM Seq of Columns from a list of Column or column names

        If `cols` has only one list in it, cols[0] will be used as the list.
        """
        if len(cols) == 1 and isinstance(cols[0], list):
            cols = cols[0]
        return self._jseq(cols, _to_java_column)

    def _sort_cols(self, cols, kwargs):
        """ Return a JVM Seq of Columns that describes the sort order
        """
        if not cols:
            raise ValueError("should sort by at least one column")
        if len(cols) == 1 and isinstance(cols[0], list):
            cols = cols[0]
        jcols = [_to_java_column(c) for c in cols]
        ascending = kwargs.get('ascending', True)
        if isinstance(ascending, (bool, int)):
            if not ascending:
                jcols = [jc.desc() for jc in jcols]
        elif isinstance(ascending, list):
            jcols = [jc if asc else jc.desc()
                     for asc, jc in zip(ascending, jcols)]
        else:
            raise TypeError("ascending can only be boolean or list, but got %s" % type(ascending))
        return self._jseq(jcols)

    @since("1.3.1")
    def describe(self, *cols):
        """Computes statistics for numeric columns.

        This includes count, mean, stddev, min, and max. If no columns are
        given, this function computes statistics for all numerical columns.

        .. note:: This function is meant for exploratory data analysis, as we make no \
            guarantee about the backward compatibility of the schema of the resulting DataFrame.

        >>> df.describe().show()
        +-------+------------------+
        |summary|               age|
        +-------+------------------+
        |  count|                 2|
        |   mean|               3.5|
        | stddev|2.1213203435596424|
        |    min|                 2|
        |    max|                 5|
        +-------+------------------+
        >>> df.describe(['age', 'name']).show()
        +-------+------------------+-----+
        |summary|               age| name|
        +-------+------------------+-----+
        |  count|                 2|    2|
        |   mean|               3.5| null|
        | stddev|2.1213203435596424| null|
        |    min|                 2|Alice|
        |    max|                 5|  Bob|
        +-------+------------------+-----+
        """
        if len(cols) == 1 and isinstance(cols[0], list):
            cols = cols[0]
        jdf = self._jdf.describe(self._jseq(cols))
        return DataFrame(jdf, self.sql_ctx)

    @ignore_unicode_prefix
    @since(1.3)
    def head(self, n=None):
        """Returns the first ``n`` rows.

        :param n: int, default 1. Number of rows to return.
        :return: If n is greater than 1, return a list of :class:`Row`.
            If n is 1, return a single Row.

        >>> df.head()
        Row(age=2, name=u'Alice')
        >>> df.head(1)
        [Row(age=2, name=u'Alice')]
        """
        if n is None:
            rs = self.head(1)
            return rs[0] if rs else None
        return self.take(n)

    @ignore_unicode_prefix
    @since(1.3)
    def first(self):
        """Returns the first row as a :class:`Row`.

        >>> df.first()
        Row(age=2, name=u'Alice')
        """
        return self.head()

    @ignore_unicode_prefix
    @since(1.3)
    def __getitem__(self, item):
        """Returns the column as a :class:`Column`.

        >>> df.select(df['age']).collect()
        [Row(age=2), Row(age=5)]
        >>> df[ ["name", "age"]].collect()
        [Row(name=u'Alice', age=2), Row(name=u'Bob', age=5)]
        >>> df[ df.age > 3 ].collect()
        [Row(age=5, name=u'Bob')]
        >>> df[df[0] > 3].collect()
        [Row(age=5, name=u'Bob')]
        """
        if isinstance(item, basestring):
            jc = self._jdf.apply(item)
            return Column(jc)
        elif isinstance(item, Column):
            return self.filter(item)
        elif isinstance(item, (list, tuple)):
            return self.select(*item)
        elif isinstance(item, int):
            jc = self._jdf.apply(self.columns[item])
            return Column(jc)
        else:
            raise TypeError("unexpected item type: %s" % type(item))

    @since(1.3)
    def __getattr__(self, name):
        """Returns the :class:`Column` denoted by ``name``.

        >>> df.select(df.age).collect()
        [Row(age=2), Row(age=5)]
        """
        if name not in self.columns:
            raise AttributeError(
                "'%s' object has no attribute '%s'" % (self.__class__.__name__, name))
        jc = self._jdf.apply(name)
        return Column(jc)

    @ignore_unicode_prefix
    @since(1.3)
    def select(self, *cols):
        """Projects a set of expressions and returns a new :class:`DataFrame`.

        :param cols: list of column names (string) or expressions (:class:`Column`).
            If one of the column names is '*', that column is expanded to include all columns
            in the current DataFrame.

        >>> df.select('*').collect()
        [Row(age=2, name=u'Alice'), Row(age=5, name=u'Bob')]
        >>> df.select('name', 'age').collect()
        [Row(name=u'Alice', age=2), Row(name=u'Bob', age=5)]
        >>> df.select(df.name, (df.age + 10).alias('age')).collect()
        [Row(name=u'Alice', age=12), Row(name=u'Bob', age=15)]
        """
        jdf = self._jdf.select(self._jcols(*cols))
        return DataFrame(jdf, self.sql_ctx)

    @since(1.3)
    def selectExpr(self, *expr):
        """Projects a set of SQL expressions and returns a new :class:`DataFrame`.

        This is a variant of :func:`select` that accepts SQL expressions.

        >>> df.selectExpr("age * 2", "abs(age)").collect()
        [Row((age * 2)=4, abs(age)=2), Row((age * 2)=10, abs(age)=5)]
        """
        if len(expr) == 1 and isinstance(expr[0], list):
            expr = expr[0]
        jdf = self._jdf.selectExpr(self._jseq(expr))
        return DataFrame(jdf, self.sql_ctx)

    @ignore_unicode_prefix
    @since(1.3)
    def filter(self, condition):
        """Filters rows using the given condition.

        :func:`where` is an alias for :func:`filter`.

        :param condition: a :class:`Column` of :class:`types.BooleanType`
            or a string of SQL expression.

        >>> df.filter(df.age > 3).collect()
        [Row(age=5, name=u'Bob')]
        >>> df.where(df.age == 2).collect()
        [Row(age=2, name=u'Alice')]

        >>> df.filter("age > 3").collect()
        [Row(age=5, name=u'Bob')]
        >>> df.where("age = 2").collect()
        [Row(age=2, name=u'Alice')]
        """
        if isinstance(condition, basestring):
            jdf = self._jdf.filter(condition)
        elif isinstance(condition, Column):
            jdf = self._jdf.filter(condition._jc)
        else:
            raise TypeError("condition should be string or Column")
        return DataFrame(jdf, self.sql_ctx)

    where = filter

    @ignore_unicode_prefix
    @since(1.3)
    def groupBy(self, *cols):
        """Groups the :class:`DataFrame` using the specified columns,
        so we can run aggregation on them. See :class:`GroupedData`
        for all the available aggregate functions.

        :func:`groupby` is an alias for :func:`groupBy`.

        :param cols: list of columns to group by.
            Each element should be a column name (string) or an expression (:class:`Column`).

        >>> df.groupBy().avg().collect()
        [Row(avg(age)=3.5)]
        >>> df.groupBy('name').agg({'age': 'mean'}).collect()
        [Row(name=u'Alice', avg(age)=2.0), Row(name=u'Bob', avg(age)=5.0)]
        >>> df.groupBy(df.name).avg().collect()
        [Row(name=u'Alice', avg(age)=2.0), Row(name=u'Bob', avg(age)=5.0)]
        >>> df.groupBy(['name', df.age]).count().collect()
        [Row(name=u'Bob', age=5, count=1), Row(name=u'Alice', age=2, count=1)]
        """
        jgd = self._jdf.groupBy(self._jcols(*cols))
        from pyspark.sql.group import GroupedData
        return GroupedData(jgd, self.sql_ctx)

    @since(1.4)
    def rollup(self, *cols):
        """
        Create a multi-dimensional rollup for the current :class:`DataFrame` using
        the specified columns, so we can run aggregation on them.

        >>> df.rollup('name', df.age).count().show()
        +-----+----+-----+
        | name| age|count|
        +-----+----+-----+
        |Alice|null|    1|
        |  Bob|   5|    1|
        |  Bob|null|    1|
        | null|null|    2|
        |Alice|   2|    1|
        +-----+----+-----+
        """
        jgd = self._jdf.rollup(self._jcols(*cols))
        from pyspark.sql.group import GroupedData
        return GroupedData(jgd, self.sql_ctx)

    @since(1.4)
    def cube(self, *cols):
        """
        Create a multi-dimensional cube for the current :class:`DataFrame` using
        the specified columns, so we can run aggregation on them.

        >>> df.cube('name', df.age).count().show()
        +-----+----+-----+
        | name| age|count|
        +-----+----+-----+
        | null|   2|    1|
        |Alice|null|    1|
        |  Bob|   5|    1|
        |  Bob|null|    1|
        | null|   5|    1|
        | null|null|    2|
        |Alice|   2|    1|
        +-----+----+-----+
        """
        jgd = self._jdf.cube(self._jcols(*cols))
        from pyspark.sql.group import GroupedData
        return GroupedData(jgd, self.sql_ctx)

    @since(1.3)
    def agg(self, *exprs):
        """ Aggregate on the entire :class:`DataFrame` without groups
        (shorthand for ``df.groupBy.agg()``).

        >>> df.agg({"age": "max"}).collect()
        [Row(max(age)=5)]
        >>> from pyspark.sql import functions as F
        >>> df.agg(F.min(df.age)).collect()
        [Row(min(age)=2)]
        """
        return self.groupBy().agg(*exprs)

    @since(1.3)
    def unionAll(self, other):
        """ Return a new :class:`DataFrame` containing union of rows in this
        frame and another frame.

        This is equivalent to `UNION ALL` in SQL.
        """
        return DataFrame(self._jdf.unionAll(other._jdf), self.sql_ctx)

    @since(1.3)
    def intersect(self, other):
        """ Return a new :class:`DataFrame` containing rows only in
        both this frame and another frame.

        This is equivalent to `INTERSECT` in SQL.
        """
        return DataFrame(self._jdf.intersect(other._jdf), self.sql_ctx)

    @since(1.3)
    def subtract(self, other):
        """ Return a new :class:`DataFrame` containing rows in this frame
        but not in another frame.

        This is equivalent to `EXCEPT` in SQL.
        """
        return DataFrame(getattr(self._jdf, "except")(other._jdf), self.sql_ctx)

    @since(1.4)
    def dropDuplicates(self, subset=None):
        """Return a new :class:`DataFrame` with duplicate rows removed,
        optionally only considering certain columns.

        :func:`drop_duplicates` is an alias for :func:`dropDuplicates`.

        >>> from pyspark.sql import Row
        >>> df = sc.parallelize([ \
            Row(name='Alice', age=5, height=80), \
            Row(name='Alice', age=5, height=80), \
            Row(name='Alice', age=10, height=80)]).toDF()
        >>> df.dropDuplicates().show()
        +---+------+-----+
        |age|height| name|
        +---+------+-----+
        |  5|    80|Alice|
        | 10|    80|Alice|
        +---+------+-----+

        >>> df.dropDuplicates(['name', 'height']).show()
        +---+------+-----+
        |age|height| name|
        +---+------+-----+
        |  5|    80|Alice|
        +---+------+-----+
        """
        if subset is None:
            jdf = self._jdf.dropDuplicates()
        else:
            jdf = self._jdf.dropDuplicates(self._jseq(subset))
        return DataFrame(jdf, self.sql_ctx)

    @since("1.3.1")
    def dropna(self, how='any', thresh=None, subset=None):
        """Returns a new :class:`DataFrame` omitting rows with null values.
        :func:`DataFrame.dropna` and :func:`DataFrameNaFunctions.drop` are aliases of each other.

        :param how: 'any' or 'all'.
            If 'any', drop a row if it contains any nulls.
            If 'all', drop a row only if all its values are null.
        :param thresh: int, default None
            If specified, drop rows that have fewer than `thresh` non-null values.
            This overwrites the `how` parameter.
        :param subset: optional list of column names to consider.

        >>> df4.na.drop().show()
        +---+------+-----+
        |age|height| name|
        +---+------+-----+
        | 10|    80|Alice|
        +---+------+-----+
        """
        if how is not None and how not in ['any', 'all']:
            raise ValueError("how ('" + how + "') should be 'any' or 'all'")

        if subset is None:
            subset = self.columns
        elif isinstance(subset, basestring):
            subset = [subset]
        elif not isinstance(subset, (list, tuple)):
            raise ValueError("subset should be a list or tuple of column names")

        if thresh is None:
            thresh = len(subset) if how == 'any' else 1

        return DataFrame(self._jdf.na().drop(thresh, self._jseq(subset)), self.sql_ctx)

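    # Illustrative sketch, not part of the original module: ``thresh`` overrides
    # ``how`` and keeps a row only if it has at least that many non-null values.
    #
    #   df4.na.drop(thresh=2).show()                   # needs >= 2 non-null columns
    #   df4.na.drop(how='all', subset=['age', 'height']).show()
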
    @since("1.3.1")
    def fillna(self, value, subset=None):
        """Replace null values, alias for ``na.fill()``.
        :func:`DataFrame.fillna` and :func:`DataFrameNaFunctions.fill` are aliases of each other.

        :param value: int, long, float, string, or dict.
            Value to replace null values with.
            If the value is a dict, then `subset` is ignored and `value` must be a mapping
            from column name (string) to replacement value. The replacement value must be
            an int, long, float, or string.
        :param subset: optional list of column names to consider.
            Columns specified in subset that do not have matching data type are ignored.
            For example, if `value` is a string, and subset contains a non-string column,
            then the non-string column is simply ignored.

        >>> df4.na.fill(50).show()
        +---+------+-----+
        |age|height| name|
        +---+------+-----+
        | 10|    80|Alice|
        |  5|    50|  Bob|
        | 50|    50|  Tom|
        | 50|    50| null|
        +---+------+-----+

        >>> df4.na.fill({'age': 50, 'name': 'unknown'}).show()
        +---+------+-------+
        |age|height|   name|
        +---+------+-------+
        | 10|    80|  Alice|
        |  5|  null|    Bob|
        | 50|  null|    Tom|
        | 50|  null|unknown|
        +---+------+-------+
        """
        if not isinstance(value, (float, int, long, basestring, dict)):
            raise ValueError("value should be a float, int, long, string, or dict")

        if isinstance(value, (int, long)):
            value = float(value)

        if isinstance(value, dict):
            return DataFrame(self._jdf.na().fill(value), self.sql_ctx)
        elif subset is None:
            return DataFrame(self._jdf.na().fill(value), self.sql_ctx)
        else:
            if isinstance(subset, basestring):
                subset = [subset]
            elif not isinstance(subset, (list, tuple)):
                raise ValueError("subset should be a list or tuple of column names")

            return DataFrame(self._jdf.na().fill(value, self._jseq(subset)), self.sql_ctx)

    @since(1.4)
    def replace(self, to_replace, value, subset=None):
        """Returns a new :class:`DataFrame` replacing a value with another value.
        :func:`DataFrame.replace` and :func:`DataFrameNaFunctions.replace` are
        aliases of each other.

        :param to_replace: int, long, float, string, list, or dict.
            Value to be replaced.
            If the value is a dict, then `value` is ignored and `to_replace` must be a
            mapping from column name (string) to replacement value. The value to be
            replaced must be an int, long, float, or string.
        :param value: int, long, float, string, or list.
            Value to use to replace holes.
            The replacement value must be an int, long, float, or string. If `value` is a
            list or tuple, `value` should be of the same length with `to_replace`.
        :param subset: optional list of column names to consider.
            Columns specified in subset that do not have matching data type are ignored.
            For example, if `value` is a string, and subset contains a non-string column,
            then the non-string column is simply ignored.

        >>> df4.na.replace(10, 20).show()
        +----+------+-----+
        | age|height| name|
        +----+------+-----+
        |  20|    80|Alice|
        |   5|  null|  Bob|
        |null|  null|  Tom|
        |null|  null| null|
        +----+------+-----+

        >>> df4.na.replace(['Alice', 'Bob'], ['A', 'B'], 'name').show()
        +----+------+----+
        | age|height|name|
        +----+------+----+
        |  10|    80|   A|
        |   5|  null|   B|
        |null|  null| Tom|
        |null|  null|null|
        +----+------+----+
        """
        if not isinstance(to_replace, (float, int, long, basestring, list, tuple, dict)):
            raise ValueError(
                "to_replace should be a float, int, long, string, list, tuple, or dict")

        if not isinstance(value, (float, int, long, basestring, list, tuple)):
            raise ValueError("value should be a float, int, long, string, list, or tuple")

        rep_dict = dict()

        if isinstance(to_replace, (float, int, long, basestring)):
            to_replace = [to_replace]

        if isinstance(to_replace, tuple):
            to_replace = list(to_replace)

        if isinstance(value, tuple):
            value = list(value)

        if isinstance(to_replace, list) and isinstance(value, list):
            if len(to_replace) != len(value):
                raise ValueError("to_replace and value lists should be of the same length")
            rep_dict = dict(zip(to_replace, value))
        elif isinstance(to_replace, list) and isinstance(value, (float, int, long, basestring)):
            rep_dict = dict([(tr, value) for tr in to_replace])
        elif isinstance(to_replace, dict):
            rep_dict = to_replace

        if subset is None:
            return DataFrame(self._jdf.na().replace('*', rep_dict), self.sql_ctx)
        elif isinstance(subset, basestring):
            subset = [subset]

        if not isinstance(subset, (list, tuple)):
            raise ValueError("subset should be a list or tuple of column names")

        return DataFrame(
            self._jdf.na().replace(self._jseq(subset), self._jmap(rep_dict)), self.sql_ctx)

    @since(1.4)
    def corr(self, col1, col2, method=None):
        """
        Calculates the correlation of two columns of a DataFrame as a double value.
        Currently only supports the Pearson Correlation Coefficient.
        :func:`DataFrame.corr` and :func:`DataFrameStatFunctions.corr` are aliases of each other.

        :param col1: The name of the first column
        :param col2: The name of the second column
        :param method: The correlation method. Currently only supports "pearson"
        """
        if not isinstance(col1, str):
            raise ValueError("col1 should be a string.")
        if not isinstance(col2, str):
            raise ValueError("col2 should be a string.")
        if not method:
            method = "pearson"
        if not method == "pearson":
            raise ValueError("Currently only the calculation of the Pearson Correlation " +
                             "coefficient is supported.")
        return self._jdf.stat().corr(col1, col2, method)

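    # Illustrative sketch, not part of the original module (the column names below
    # stand for hypothetical numeric columns):
    #
    #   df.corr('age', 'salary')         # Pearson correlation as a float
    #   df.stat.corr('age', 'salary')    # same computation via the stat interface
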
    @since(1.4)
    def cov(self, col1, col2):
        """
        Calculate the sample covariance for the given columns, specified by their names, as a
        double value. :func:`DataFrame.cov` and :func:`DataFrameStatFunctions.cov` are aliases.

        :param col1: The name of the first column
        :param col2: The name of the second column
        """
        if not isinstance(col1, str):
            raise ValueError("col1 should be a string.")
        if not isinstance(col2, str):
            raise ValueError("col2 should be a string.")
        return self._jdf.stat().cov(col1, col2)

    @since(1.4)
    def crosstab(self, col1, col2):
        """
        Computes a pair-wise frequency table of the given columns. Also known as a contingency
        table. The number of distinct values for each column should be less than 1e4. At most 1e6
        non-zero pair frequencies will be returned.
        The first column of each row will be the distinct values of `col1` and the column names
        will be the distinct values of `col2`. The name of the first column will be `$col1_$col2`.
        Pairs that have no occurrences will have zero as their counts.
        :func:`DataFrame.crosstab` and :func:`DataFrameStatFunctions.crosstab` are aliases.

        :param col1: The name of the first column. Distinct items will make the first item of
            each row.
        :param col2: The name of the second column. Distinct items will make the column names
            of the DataFrame.
        """
        if not isinstance(col1, str):
            raise ValueError("col1 should be a string.")
        if not isinstance(col2, str):
            raise ValueError("col2 should be a string.")
        return DataFrame(self._jdf.stat().crosstab(col1, col2), self.sql_ctx)

    @since(1.4)
    def freqItems(self, cols, support=None):
        """
        Finding frequent items for columns, possibly with false positives. Using the
        frequent element count algorithm described in
        "http://dx.doi.org/10.1145/762471.762473, proposed by Karp, Schenker, and Papadimitriou".
        :func:`DataFrame.freqItems` and :func:`DataFrameStatFunctions.freqItems` are aliases.

        .. note:: This function is meant for exploratory data analysis, as we make no \
            guarantee about the backward compatibility of the schema of the resulting DataFrame.

        :param cols: Names of the columns to calculate frequent items for as a list or tuple of
            strings.
        :param support: The frequency with which to consider an item 'frequent'. Default is 1%.
            The support must be greater than 1e-4.
        """
        if isinstance(cols, tuple):
            cols = list(cols)
        if not isinstance(cols, list):
            raise ValueError("cols must be a list or tuple of column names as strings.")
        if not support:
            support = 0.01
        return DataFrame(self._jdf.stat().freqItems(_to_seq(self._sc, cols), support),
                         self.sql_ctx)

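    # Illustrative sketch, not part of the original module: find items appearing in
    # roughly a quarter of rows or more; the result is a single-row DataFrame with
    # one array column per requested input column.
    #
    #   df.freqItems(['age', 'name'], support=0.25).collect()
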
    @ignore_unicode_prefix
    @since(1.3)
    def withColumn(self, colName, col):
        """
        Returns a new :class:`DataFrame` by adding a column or replacing the
        existing column that has the same name.

        :param colName: string, name of the new column.
        :param col: a :class:`Column` expression for the new column.

        >>> df.withColumn('age2', df.age + 2).collect()
        [Row(age=2, name=u'Alice', age2=4), Row(age=5, name=u'Bob', age2=7)]
        """
        assert isinstance(col, Column), "col should be Column"
        return DataFrame(self._jdf.withColumn(colName, col._jc), self.sql_ctx)

    @ignore_unicode_prefix
    @since(1.3)
    def withColumnRenamed(self, existing, new):
        """Returns a new :class:`DataFrame` by renaming an existing column.

        :param existing: string, name of the existing column to rename.
        :param new: string, new name of the column.

        >>> df.withColumnRenamed('age', 'age2').collect()
        [Row(age2=2, name=u'Alice'), Row(age2=5, name=u'Bob')]
        """
        return DataFrame(self._jdf.withColumnRenamed(existing, new), self.sql_ctx)

    @since(1.4)
    @ignore_unicode_prefix
    def drop(self, col):
        """Returns a new :class:`DataFrame` that drops the specified column.

        :param col: a string name of the column to drop, or a
            :class:`Column` to drop.

        >>> df.drop('age').collect()
        [Row(name=u'Alice'), Row(name=u'Bob')]

        >>> df.drop(df.age).collect()
        [Row(name=u'Alice'), Row(name=u'Bob')]

        >>> df.join(df2, df.name == df2.name, 'inner').drop(df.name).collect()
        [Row(age=5, height=85, name=u'Bob')]

        >>> df.join(df2, df.name == df2.name, 'inner').drop(df2.name).collect()
        [Row(age=5, name=u'Bob', height=85)]
        """
        if isinstance(col, basestring):
            jdf = self._jdf.drop(col)
        elif isinstance(col, Column):
            jdf = self._jdf.drop(col._jc)
        else:
            raise TypeError("col should be a string or a Column")
        return DataFrame(jdf, self.sql_ctx)

    @ignore_unicode_prefix
    def toDF(self, *cols):
        """Returns a new :class:`DataFrame` with the new specified column names.

        :param cols: list of new column names (string)

        >>> df.toDF('f1', 'f2').collect()
        [Row(f1=2, f2=u'Alice'), Row(f1=5, f2=u'Bob')]
        """
        jdf = self._jdf.toDF(self._jseq(cols))
        return DataFrame(jdf, self.sql_ctx)

    @since(1.3)
    def toPandas(self):
        """Returns the contents of this :class:`DataFrame` as Pandas ``pandas.DataFrame``.

        This is only available if Pandas is installed and available.

        >>> df.toPandas()  # doctest: +SKIP
           age   name
        0    2  Alice
        1    5    Bob
        """
        import pandas as pd
        return pd.DataFrame.from_records(self.collect(), columns=self.columns)

    ##########################################################################################
    # Pandas compatibility
    ##########################################################################################

    groupby = groupBy
    drop_duplicates = dropDuplicates


# Having SchemaRDD for backward compatibility (for docs)
class SchemaRDD(DataFrame):
    """SchemaRDD is deprecated, please use :class:`DataFrame`.
    """


def _to_scala_map(sc, jm):
    """
    Convert a dict into a JVM Map.
    """
    return sc._jvm.PythonUtils.toScalaMap(jm)


class DataFrameNaFunctions(object):
    """Functionality for working with missing data in :class:`DataFrame`.

    .. versionadded:: 1.4
    """

    def __init__(self, df):
        self.df = df

    def drop(self, how='any', thresh=None, subset=None):
        return self.df.dropna(how=how, thresh=thresh, subset=subset)

    drop.__doc__ = DataFrame.dropna.__doc__

    def fill(self, value, subset=None):
        return self.df.fillna(value=value, subset=subset)

    fill.__doc__ = DataFrame.fillna.__doc__

    def replace(self, to_replace, value, subset=None):
        return self.df.replace(to_replace, value, subset)

    replace.__doc__ = DataFrame.replace.__doc__


class DataFrameStatFunctions(object):
    """Functionality for statistic functions with :class:`DataFrame`.

    .. versionadded:: 1.4
    """

    def __init__(self, df):
        self.df = df

    def corr(self, col1, col2, method=None):
        return self.df.corr(col1, col2, method)

    corr.__doc__ = DataFrame.corr.__doc__

    def cov(self, col1, col2):
        return self.df.cov(col1, col2)

    cov.__doc__ = DataFrame.cov.__doc__

    def crosstab(self, col1, col2):
        return self.df.crosstab(col1, col2)

    crosstab.__doc__ = DataFrame.crosstab.__doc__

    def freqItems(self, cols, support=None):
        return self.df.freqItems(cols, support)

    freqItems.__doc__ = DataFrame.freqItems.__doc__

    def sampleBy(self, col, fractions, seed=None):
        return self.df.sampleBy(col, fractions, seed)

    sampleBy.__doc__ = DataFrame.sampleBy.__doc__


def _test():
    import doctest
    from pyspark.context import SparkContext
    from pyspark.sql import Row, SQLContext
    import pyspark.sql.dataframe
    globs = pyspark.sql.dataframe.__dict__.copy()
    sc = SparkContext('local[4]', 'PythonTest')
    globs['sc'] = sc
    globs['sqlContext'] = SQLContext(sc)
    globs['df'] = sc.parallelize([(2, 'Alice'), (5, 'Bob')])\
        .toDF(StructType([StructField('age', IntegerType()),
                          StructField('name', StringType())]))
    globs['df2'] = sc.parallelize([Row(name='Tom', height=80), Row(name='Bob', height=85)]).toDF()
    globs['df3'] = sc.parallelize([Row(name='Alice', age=2),
                                   Row(name='Bob', age=5)]).toDF()
    globs['df4'] = sc.parallelize([Row(name='Alice', age=10, height=80),
                                   Row(name='Bob', age=5, height=None),
                                   Row(name='Tom', age=None, height=None),
                                   Row(name=None, age=None, height=None)]).toDF()

    (failure_count, test_count) = doctest.testmod(
        pyspark.sql.dataframe, globs=globs,
        optionflags=doctest.ELLIPSIS | doctest.NORMALIZE_WHITESPACE | doctest.REPORT_NDIFF)
    globs['sc'].stop()
    if failure_count:
        exit(-1)


if __name__ == "__main__":
    _test()