ApacheSparkで扱うobjectのSerializableの必要性について

ApacheSparkで扱うobjectのSerializableの必要性について

hiveやファイルからデータを読み込んだ直後値はRDD, Dataset, DataFrameになっていて、少ないデータに対して何回もfilter処理を行う必要がある場合に一旦collectして配列に変換しdriver内で処理したい場合もあると思うけど、データを読み込んだ際にnon-serializableなクラスに値をセットしていたらcollectで配列への変換時にエラーが発生したのでその際のメモ

例えば以下のようなjavaのクラスがあったとして、これをsparkで利用するとする

public class IdsBean {
  private int id;
  public IdsBean(int id) { /* compiled code */ }
  public int getId() { /* compiled code */ }
  public void setId(int id) { /* compiled code */ }
}
scala> val rdd = sc.makeRDD(1 to 10)
scala> val rddBean = rdd.map(new IdsBean(_))

foreachやmapの処理は行える

scala> rddBean.foreach{ row => println(row.getId) }
6
7
8
9
10
1
2
3
4
5
scala> rddBean.map(row=>row).foreach(row=>println(row.getId))
1
2
3
4
5
6
7
8
9
10

ただcollectの処理はエラーが発生する

scala> rddBean.collect
[Stage 10:>                                                         (0 + 0) / 2]17/10/07 13:43:41 ERROR executor.Executor: Exception in task 0.0 in stage 10.0 (TID 20)
java.io.NotSerializableException: dto.IdsBean
Serialization stack:
    - object not serializable (class: dto.IdsBean, value: dto.IdsBean@2872403e)
    - element of array (index: 0)
    ...

collectの処理ではRDDは各executorに分散されていてそれをdriverに集めるのだが、その際にserialize → deserializeの処理が実行されるのだが、対象のクラスがnon-serializableの場合は シリアライズできないのでcollectの処理を実行する時点でエラーが発生するようだ。

対応策として以下のように対象のクラスを継承してserializableにすることを考えたがそれではダメだった。

scala> class IdsBeanS(id: Int) extends dto.IdsBean(id) with java.io.Serializable

これについてはjava.io.Serializableのjavadocでコメントがあり、non-serializableなクラスのsubtypeをSerilizableにする場合はスーパークラスのpublic, protectedなメンバ変数がシリアライズの対象となるようだ。

To allow subtypes of non-serializable classes to be serialized, the
subtype may assume responsibility for saving and restoring the
state of the supertype's public, protected, and (if accessible)
package fields.  The subtype may assume this responsibility only if
the class it extends has an accessible no-arg constructor to
initialize the class's state.  It is an error to declare a class
Serializable if this is not the case.  The error will be detected at
runtime.

また、executor-driver間でobjectを共有する場合はSerializableでになっている必要がわかった。dirver-executor間で値を共有するタイミングだがmapやforeach、collectなどのapiを実行時のようで、シリアライズ、でシリアライズの対象になっているobjectはapiによっても違っているようだ。 例えば先ほどのnon-serializableなRDDについて以下の処理であれば問題ない。

scala> rddBean.foreach{ row => println(row.getId) }

だが、mapやforeachの処理内でnon-sirializableなクラスを使用しようとしたらエラーが発生する。mapやforeachの内部で使うobjectは各executorに送られるのでserializableである必要があるようだ。

scala> val rdd = sc.makeRDD(1 to 10)
scala> rdd.foreach{rddRow =>
     |   rddBean.foreach(beanRow => println(beanRow.getId))
     | }
17/10/07 14:09:03 ERROR executor.Executor: Exception in task 0.0 in stage 13.0 (TID 26)
org.apache.spark.SparkException: This RDD lacks a SparkContext. It could happen in the following cases:
(1) RDD transformations and actions are NOT invoked by the driver, but inside of other transformations; for example, rdd1.map(x => rdd2.values.count() * x) is invalid because the values transformation and count action cannot be performed inside of the rdd1.map transformation. For more information, see SPARK-5063.

いつexecutor間にobjectが送られるのかは知っていた方が良いけど、基本的にはSerializableなobjectを使うのが安全そう。 あと同一のexecutorにobjectが送られたとしてもtask内でシリアライズ、デシリアライズして使用するからtask間で同じobjectを参照するということはないらしい。 (基本的なことだけど、object自体が状態を持っていてスレッドセーフでなければ問題が発生するかもしれない) https://twitter.com/maropu/status/889747740858568704

Scalaでseqを操作してみる

scalaでSeqを操作してみる

まず以下のcase classがあったとし、

case class Element(id: Int, time: java.sql.Timestamp)

初期のデータとして以下を保持する

val elementSeq = Array(
  Element(1, new java.sql.Timestamp(new DateTime(2017, 8, 10, 16, 13).getMillis))
  , Element(2, new java.sql.Timestamp(new DateTime(2017, 8, 9, 11, 5).getMillis))
  , Element(3, new java.sql.Timestamp(new DateTime(2017, 5, 22, 9, 13).getMillis))
  , Element(4, new java.sql.Timestamp(new DateTime(2017, 9, 1, 22, 13).getMillis))
  , Element(5, new java.sql.Timestamp(new DateTime(2017, 7, 31, 23, 13).getMillis))
  , Element(6, new java.sql.Timestamp(new DateTime(2017, 8, 15, 12, 7).getMillis))
).toSeq

ソート

Timestampの降順でソートするのは以下のようになる。TimestampはComparableをimplementsしていないので、ソートする際はgetTimeでミリ秒に変換などする

val sortedArray = elementSeq.sortBy(-_.time.getTime)
sortedArray.foreach(println)

unionでseqを結合する。

val elementSeq2 = Array(
  Element(1, new java.sql.Timestamp(new DateTime(2017, 8, 12, 15, 11).getMillis))
  , Element(2, new java.sql.Timestamp(new DateTime(2017, 8, 7, 6, 5).getMillis))
  , Element(3, new java.sql.Timestamp(new DateTime(2017, 5, 9, 16, 13).getMillis))
  , Element(4, new java.sql.Timestamp(new DateTime(2017, 9, 1, 23, 59).getMillis))
  , Element(5, new java.sql.Timestamp(new DateTime(2017, 7, 31, 23, 12).getMillis))
  , Element(6, new java.sql.Timestamp(new DateTime(2017, 8, 14, 23, 4).getMillis))
).toSeq

val unionTimeSeq = elementSeq.map(_.time)
  .union(elementSeq2.map(_.time))
  .sortBy(_.getTime)
unionTimeSeq.foreach(println)

要素の追加とフィルター

次にstartDateで指定した日付以降でフィルターして昇順でソートするのは以下のようになる。

val startDate = new java.sql.Timestamp(new DateTime(2017, 6, 1, 0, 0).getMillis)
val timeArray = (startDate +: unionTimeSeq)
  .filter(_.getTime >= startDate.getTime)
  .sortBy(_.getTime)
timeArray.foreach(println)

指定日の直前のレコードの取り出し

日付のSeqでmapし、事前に日付の降順でソートされたSeqから指定日以前の一番新しいレコードを取り出す処理は以下のようになる。

val validRecordSeq = timeArray.map{time=>
    val record = sortedArray
      .filter(_.time.getTime <= time.getTime)
      .headOption.getOrElse(Element(0, new java.sql.Timestamp(new DateTime(2017, 6, 1, 0, 0).getMillis)))
    (time, record)
}
validRecordSeq.foreach(println)

sparkからhiveを利用してみる

spark-shellにてクラスパスを指定する

spark-shell --driver-class-path 対象クラスパス

開発時にちょっと修正後にいちいちビルドしてデプロイして実行するのが面倒なので、インタラクティブシェルにて動作を確認後、ソースに反映の流れにしたい

hive

SQLを実行してみる

パッケージのインポートからselect文実行まで 以下のテーブルを使用するものとする

show create table sample;
+-----------------------------------------------------------------+
                        createtab_stmt                          |
+-----------------------------------------------------------------+
CREATE  TABLE sample(                                           |
  id int)                                                       |
ROW FORMAT SERDE                                                |
  'org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe'          |
STORED AS INPUTFORMAT                                           |
  'org.apache.hadoop.mapred.TextInputFormat'                    |
OUTPUTFORMAT                                                    |
  'org.apache.hadoop.hive.ql.io.HiveIgnoreKeyTextOutputFormat'  |
  ...                                                           |
+-----------------------------------------------------------------+
import org.apache.spark._
import org.apache.spark.sql.hive._
val hc = new HiveContext(sc)

val select="select * from sample"
val sqlResult = hc.sql(select)
sqlResult.foreach(row=>println(row))

取得対象の絡むと型を指定する

sqlResult.foreach(row=>println(row.getAs[Int]("id")))

sqlの実行結果をmapで型にセットする

sqlResult.map(row => new IdsBean(row.getAs[Int]("id")))

この時変換先の型がserizableでないとエラーになるので、既存のjava資源でserizableをimplementしていない型にセットする場合は、 scalaの方で利用できるように拡張する必要がある

case class IdsSBean(id: Int) extends dto.IdsBean(id) with java.io.Serializable
val idsRDD = sqlResult.map(row => new IdsSBean(row.getAs[Int]("id")))

RDDから配列に変換する

idsRDD.collect

RDDからSeqに変換する

idsRDD.collect.toSeq

summarizationsパターンを試してみる

簡単な数値の集計を行ってみたいと思います。

まず動作確認に使うデータを登録します。 テーブル作成

create table numerical_input(
  user_id int
  , input int
);

動作確認に使うファイルをcsvファイルに保存してhdfsにアップロード ^Aは制御文字になっておりvimであればCtrl +V Ctrl + Aでファイルに入力できる

# vim numerical_input.txt

12345^A10
12345^A8
12345^A21
54321^A1
54321^A47
54321^A8
88888^A7
88888^A12

# hdfs dfs -put numerical_input.txt /input/

それからテーブルにデータを取り込む

load data inpath '/input/numerical_input.txt' into table numerical_input;
select * from numerical_input;

次にscalaでデータをRDDとして読み込んでみる

scala> hc.sql("select * from numerical_input")
res24: org.apache.spark.sql.DataFrame = [user_id: int, input: int]

scala> val numericalRDD = hc.sql("select * from numerical_input") map { row =>
     | (row.getAs[Int]("user_id"), row.getAs[Int]("input"), 1)
     | }

scala> numericalRDD.show
+-----+---+---+
|   _1| _2| _3|
+-----+---+---+
|12345| 10|  1|
|12345|  8|  1|
|12345| 21|  1|
|54321|  1|  1|
|54321| 47|  1|
|54321|  8|  1|
|88888|  7|  1|
|88888| 12|  1|
+-----+---+---+

Datasetのapiを実行してみる

where

scala> numericalRDD.where($"_1" > 60000).show
+-----+---+---+
|   _1| _2| _3|
+-----+---+---+
|88888|  7|  1|
|88888| 12|  1|
+-----+---+---+

sort

scala> numericalRDD.sort($"_2").show
+-----+---+---+
|   _1| _2| _3|
+-----+---+---+
|54321|  1|  1|
|88888|  7|  1|
|12345|  8|  1|
|54321|  8|  1|
|12345| 10|  1|
|88888| 12|  1|
|12345| 21|  1|
|54321| 47|  1|
+-----+---+---+

scala側でデータが読み込めるようになったのでタプルの一番目にuser_idでグルーピングを行い、タプルに2番目の要素に最大値、3番目に最小値、4番目にカウント結果が入るようにしてみる。

scala> numericalRDD.groupBy($"_1" as "user_group").agg(max($"_2"), min($"_2"), count($"_3")).show
+----------+-------+-------+---------+
|user_group|max(_2)|min(_2)|count(_3)|
+----------+-------+-------+---------+
|     54321|     47|      1|        3|
|     88888|     12|      7|        2|
|     12345|     21|      8|        3|
+----------+-------+-------+---------+

Hiveの環境構築

Hive環境構築

インストール

1.javaのインストール

7系のjavaをインストールしてパスを通しておきます。

export JAVA_HOME=/usr/local/jdk1.7.0_71
export PATH=$PATH:$JAVA_HOME/bin

2.Hadoopのインストール

hadoop version

パスの設定

export HADOOP_HOME=/usr/local/hadoop
export HADOOP_MAPRED_HOME=$HADOOP_HOME
export HADOOP_COMMON_HOME=$HADOOP_HOME
export HADOOP_HDFS_HOME=$HADOOP_HOME
export YARN_HOME=$HADOOP_HOME
export HADOOP_COMMON_LIB_NATIVE_DIR=$HADOOP_HOME/lib/native export
PATH=$PATH:$HADOOP_HOME/sbin:$HADOOP_HOME/bin

core-site.xmlにNameNodeの情報を設定する

<configuration>

   <property>
      <name>fs.default.name</name>
      <value>hdfs://localhost:9000</value>
   </property>

</configuration>

hdfs-site.xmlを編集してNameNode,DataNodeのデータ保存先を設定する

<configuration>

   <property>
      <name>dfs.replication</name>
      <value>1</value>
   </property>
   <property>
      <name>dfs.namenode.name.dir</name>
      <value>file:///home/hadoop/hadoopinfra/hdfs/namenode</value>
   </property>
   <property>
      <name>dfs.namenode.data.dir</name>
      <value>file:///home/hadoop/hadoopinfra/hdfs/datanode</value >
   </property>

</configuration>

yarn-site.xmlを編集yarn.nodemanager.aux-servicesにmapreduce_shuffleを設定する

<configuration>

   <property>
      <name>yarn.nodemanager.aux-services</name>
      <value>mapreduce_shuffle</value>
   </property>

</configuration>

map-red-site.xmlを変数しmapreduce.framework.nameにyarnを設定する

<configuration>

   <property>
      <name>mapreduce.framework.name</name>
      <value>yarn</value>
   </property>

</configuration>

NameNodeをフォーマットする

hdfs namenode -format

3.Hiveのインストール

Hiveをダウンロードする。インストール済みのhadoopにあっているバージョンを選びます。 https://hive.apache.org/downloads.html ダウンロード後に解凍します。それから、パスを通します。

export HIVE_HOME=/opt/
export PATH=$PATH:$HIVE_HOME/bin
export CLASSPATH=$CLASSPATH:$HADOOP_HOME/lib/*:.
export CLASSPATH=$CLASSPATH:$HIVE_HOME/lib/*:.

hive-env.shを有効にする

cd $HIVE_HOME/conf cp hive-env.sh.template hive-env.sh

hiveのメタ情報保存先の設定

今回はPostgreSQLにhiveのめた情報を保存するようにします。 postgresqlをインストールします。hiveインストール環境からアクセスできるようにしておきます。

yum install postgresql-server postgresql-setup initdb systemctl start postgresql systemctl enable postgresql

ドライバをhiveのlibに移動します。

wget https://jdbc.postgresql.org/download/postgresql-9.3-1103.jdbc4.jar
mv postgresql-9.3-1103.jdbc4.jar /opt/hive-0.12.0/lib/

PostgreSQLにHiveで使うユーザとDBを作成します。

createuser -P hive createdb -O hive hive

メタ情報保存に使スキーマを実行します /opt/hive-0.12.0/scripts/metastore/upgrade/postgres/hive-schema-0.12.0.postgres.sql

hive-site.xmlを編集する

<property>
  <name>javax.jdo.option.ConnectionURL</name>
  <value>jdbc:postgresql://ポスグレインストール先:5432/hive</value>
  <description>JDBC connect string for a JDBC metastore</description>
</property>

<property>
  <name>javax.jdo.option.ConnectionDriverName</name>
  <value>org.postgresql.Driver</value>
  <description></description>
</property>

<property>
  <name>javax.jdo.option.ConnectionUserName</name>
  <value>hive</value>
  <description>username to use against metastore database</description>
</property>

<property>
  <name>javax.jdo.option.ConnectionPassword</name>
  <value>hive</value>
  <description>password to use against metastore database</description>
</property>

動作確認

# hive

hive> show tables;
OK
Time taken: 3.323 seconds

Hiveserver2を起動してbeelineで接続してみる

hive-site.xmlを変数する 今回は動作確認のため認証を行わなくても接続できるようにする。

<property>
  <name>hive.server2.authentication</name>
  <value>NOSASL</value> <!-- default NONE is for SASLTransport -->
</property>

<property>
  <name>hive.server2.enable.doAs</name>
  <value>false</value> <!-- Execute query as hiveserver2 process user -->
</property>

hiveserver2を起動する

$HIVE_HOME/bin/hiveserver2 &

beelineで接続してみる

# beeline !connect jdbc:hive2://localhost:10000/default;auth=noSasl hive org.apache.hive.jdbc.HiveDriver

※接続先のポートはhive-site.xmlのhive.server2.thrift.portを確認する

hiveqlを実行してみる

データ登録の確認のためまず以下のようなテキストファイルを作成しhdfsに上げておく

1
2
3
4
5

それからbeelineで接続し以下を実行する

create table sample (
  id INT
);

load data inpath 'ファイルパス' into table sample;

select *
from sample;

インサートしたレコードはhdfs上にあるのが以下のコマンドで確認できる

hdfs dfs -ls hive-site.xmlでhive.metastore.warehouse.dirに指定しているパス/DB名

またメタ情報がPostgreSQLに保存されていることも確認できる、例えばテーブル名の情報が保存される

hive=# \d "TBLS"
 TBL_ID             | bigint                 | not null
 CREATE_TIME        | bigint                 | not null
 DB_ID              | bigint                 |
 LAST_ACCESS_TIME   | bigint                 | not null
 OWNER              | character varying(767) | default NULL::character varying
 RETENTION          | bigint                 | not null
 SD_ID              | bigint                 |
 TBL_NAME           | character varying(128) | default NULL::character varying
 TBL_TYPE           | character varying(128) | default NULL::character varying
 VIEW_EXPANDED_TEXT | text                   |
 VIEW_ORIGINAL_TEXT | text                   |

 hive=# select * from "TBLS";
      1 |  1506262292 |     1 |                0 | root  |         0 |     1 | sample   | MANAGED_TABLE |
  |

Apache SparkからHiveを利用する

spark-shellでインタラクティブシェルから実行してみる

spark-shellコマンドを実行することでインタラクティブにsparkを実行することができます。

# spark-shell

spark-shell実行時に以下のようなエラーが出た場合は、

org.apache.hadoop.hive.metastore.api.MetaException: Hive Schema version 1.2.0 does not match metastore's schema version 0.12.0 Metastore is not upgraded or corrupt

hive-site.xmlのhive.metastore.schema.verificationにfalseを指定することでうまくいくようになるかもしれないです。

<property>
  <name>hive.metastore.schema.verification</name>
  <value>false</value>
   <description>
   </description>
</property>

spark-shellではscにSparkContextがセットされている

scala> sc
res0: org.apache.spark.SparkContext = org.apache.spark.SparkContext@519e67e
import org.apache.spark._
import org.apache.spark.sql.hive._
val hc = new HiveContext(sc)



val select="select * from sample"
val sqlResult = hc.sql(select)
sqlResult.foreach(row=>println(row))

次に必要なクラスのインポート後hiveのコンテキストを初期化してみます。

scala> import org.apache.spark._
import org.apache.spark._

scala> import org.apache.spark.sql.hive._
import org.apache.spark.sql.hive._

scala> val hc = new HiveContext(sc)
warning: there was one deprecation warning; re-run with -deprecation for details
hc: org.apache.spark.sql.hive.HiveContext = org.apache.spark.sql.hive.HiveContext@8167f57

試しにselectを実行してみます。

scala> val select="select * from sample"
select: String = select * from sample

scala> val sqlResult = hc.sql(select)
sqlResult: org.apache.spark.sql.DataFrame = [id: int]

scala> sqlResult.foreach(row=>println(row))
[Stage 0:>                                                          (0 + 2) / 2]
[1]
...

pysparkから実行してみる

次にpysparkからpythonスクリプトでhiveを利用してみたいと思います。

pyspark

pysparkでもspark-shellと同様にscにSparkContenxtがセットされています。

>>> sc
<pyspark.context.SparkContext object at 0x11029d0>

HiveContextをインポートしてコンテキストを初期化します。

>>> from pyspark.sql import HiveContext
>>> sqlContext = HiveContext(sc)

それからSQLを実行してみます。

>>> sqlContext.sql("select * from sample").show()
+---+
| id|
+---+
|  1|
|  2|
|  3|
|  4|
|  5|
+---+

pysparkからでもhiveに接続してデータを取ってこれることが確認できました。

Apache Sparkのアプリをデバッグする

sparkアプリケーションのデバッグ

1.sbt assemblyでjarファイルを生成しspark-submitコマンド実行サーバにアップロードする

2.spark-submitコマンド実行サーバにポートフォワードの設定付きでssh接続する
とりあえず5039ポートを使ってみる

ssh -L 5039:remote:5039 target

3.spark-submitコマンド実行

spark-submit --master local[*] \    
--driver-java-options -agentlib:jdwp=transport=dt_socket,server=y,suspend=y,address=5039 \    
--class 実行対象クラス \    
--name アプリケーション名 jarファイル アプリの引数    

4.ローカルの開発環境でリモートデバッグ

pycharmを使ってpysparkの開発を行った際に"from pyspark.sql.functions import lit"でエラーがでたのを調べて見た

pysparkの開発を行った際に"from pyspark.sql.functions import lit"でimportできないとエラーが出たのを確認した時のメモ 実際は以下のようにpyspark.sql.functions.py内で以下のようにして動的にメソッドを追加している。

def _create_function(name, doc=""):
    """ Create a function for aggregator by name"""
    def _(col):
        sc = SparkContext._active_spark_context
        jc = getattr(sc._jvm.functions, name)(col._jc if isinstance(col, Column) else col)
        return Column(jc)
    _.__name__ = name
    _.__doc__ = doc
    return _


_functions = {
    'lit': _lit_doc,
    'col': 'Returns a :class:`Column` based on the given column name.',
    'column': 'Returns a :class:`Column` based on the given column name.',
    'asc': 'Returns a sort expression based on the ascending order of the given column name.',
    'desc': 'Returns a sort expression based on the descending order of the given column name.',

    'upper': 'Converts a string expression to upper case.',
    'lower': 'Converts a string expression to upper case.',
    'sqrt': 'Computes the square root of the specified float value.',
    'abs': 'Computes the absolute value.',

    'max': 'Aggregate function: returns the maximum value of the expression in a group.',
    'min': 'Aggregate function: returns the minimum value of the expression in a group.',
    'count': 'Aggregate function: returns the number of items in a group.',
    'sum': 'Aggregate function: returns the sum of all values in the expression.',
    'avg': 'Aggregate function: returns the average of the values in a group.',
'mean': 'Aggregate function: returns the average of the values in a group.',
    'sumDistinct': 'Aggregate function: returns the sum of distinct values in the expression.',
}

for _name, _doc in _functions.items():
    globals()[_name] = since(1.3)(_create_function(_name, _doc))

ここではcreate_functionでメソッドを生成し、globals()[name]にてname(col)で関数を呼び出せるようにしている。getattrでは"sc.jvm.functions"のnameで指定した関数を呼び出せるようにしており、ここでjvm場で動いているsparkを呼び出すようにしている。pysparkではpythonのコードがjvm場で動くという分けではなくpy4jにより連携するようになっていて、その連携部分が"getattr(sc.jvm.functions, name)(col.jc if isinstance(col, Column) else col)“のようでpyspark自体についてももうちょっと調べたいと思います。

pysparkでの開発時に気になった点のメモでした。

Sparkで単体テストをしてみる

Apache Sparkで単体テストをしてみる

Intelij IDEAでsparkの単体テストを書いてみたのでメモ

build.sbtの設定を変更

まず、build.sbtに以下の設定を追加する。

parallelExecution in Test := false

“build sbt"で複数のテストが同時に動いた場合に発生するSparkContext周りのエラーを防ぐのに必要なようである。

テストを書いてみる

まず、以下のようにcsvをDataFrameとして読み込んでデータを取得するclassのテストを書く場合

package intoroduction.spark.dataframe

import org.apache.spark.sql.SQLContext
import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.sql.hive.HiveContext
import org.apache.spark.sql.types.DataType._
import org.apache.spark.sql.types.IntegerType

case class Dessert(menuId: String, name: String, price: Int, kcal: Int)

class DesertFrame(sc: SparkContext, sqlContext: SQLContext, filePath: String) {
  import sqlContext.implicits._
  lazy val dessertDF = {
    val dessertRDD = sc.textFile(filePath)
    sc.textFile(filePath)
    // データフレームとして読み込む
    dessertRDD.map { record =>
      val splitRecord = record.split(",")
      val menuId = splitRecord(0)
      val name = splitRecord(1)
      val price = splitRecord(2).toInt
      val kcal = splitRecord(3).toInt
      Dessert(menuId, name, price, kcal)
    }.toDF
  }
  dessertDF.createOrReplaceTempView("desert_table")


  def findByMenuId(menuId: String) = {
    dessertDF.where(dessertDF("menuId") === menuId)
  }
}


object DesertFrame {

  def main(args: Array[String]): Unit ={

    val conf = new SparkConf().setAppName("DesertFrame").setMaster("local[*]")
    val sc = new SparkContext(conf)
    val sqlContext = new SQLContext(sc)
    import sqlContext.implicits._

    val filePath = "src/test/resources/data/dessert-menu.csv"
    val desertFrame = new DesertFrame(sc, sqlContext, filePath)

    val d19DF = desertFrame.findByMenuId("D-19").head
    print(d19DF)

  }
}

上記のDesertFrameのテストを書く場合は以下のようになる。

package intoroduction.spark.dataframe

import org.apache.spark.sql.SQLContext
import org.apache.spark.{SparkConf, SparkContext}
import org.scalatest.{BeforeAndAfterAll, FunSuite}

class DessertFrameTest extends FunSuite with BeforeAndAfterAll{
  private var sparkConf: SparkConf = _
  private var sc: SparkContext = _
  private var sqlContext: SQLContext = _

  override def beforeAll() {
    print("before...")
    sparkConf = new SparkConf().setAppName("DessertFrameTest").setMaster("local")
    sc = new SparkContext(sparkConf)
    sqlContext = new SQLContext(sc)
  }

  override def afterAll() {
    sc.stop()
    print("after...")
  }

  test("dessert_frame"){
    val filePath = "src/test/resources/data/dessert-menu.csv"
    val desertFrame = new DesertFrame(sc, sqlContext, filePath)

    val d19DF = desertFrame.findByMenuId("D-19").head
    assert(d19DF.get(0) == "D-19")
    assert(d19DF.get(1) == "キャラメルロール")
    assert(d19DF.get(2) == 370)
    assert(d19DF.get(3) == 230)
  }
}

ここでは"SparkConf().setAppName(“DessertFrameTest”).setMaster(“local”)“と指定しており、ローカルの環境で動かすことができるようになりテストで使うデータを"src/test/resources/data/dessert-menu.csv"にしているのでテストデータもそのままgitで管理できるようになっている。

テスト実行

あとは"sbt test"か"sbt test:testOnly クラス指定"でテストを実行できるはずである。