Sparkで単体テストをしてみる

Apache Sparkで単体テストをしてみる

Intelij IDEAでsparkの単体テストを書いてみたのでメモ

build.sbtの設定を変更

まず、build.sbtに以下の設定を追加する。

parallelExecution in Test := false

“build sbt"で複数のテストが同時に動いた場合に発生するSparkContext周りのエラーを防ぐのに必要なようである。

テストを書いてみる

まず、以下のようにcsvをDataFrameとして読み込んでデータを取得するclassのテストを書く場合

package intoroduction.spark.dataframe

import org.apache.spark.sql.SQLContext
import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.sql.hive.HiveContext
import org.apache.spark.sql.types.DataType._
import org.apache.spark.sql.types.IntegerType

case class Dessert(menuId: String, name: String, price: Int, kcal: Int)

class DesertFrame(sc: SparkContext, sqlContext: SQLContext, filePath: String) {
  import sqlContext.implicits._
  lazy val dessertDF = {
    val dessertRDD = sc.textFile(filePath)
    sc.textFile(filePath)
    // データフレームとして読み込む
    dessertRDD.map { record =>
      val splitRecord = record.split(",")
      val menuId = splitRecord(0)
      val name = splitRecord(1)
      val price = splitRecord(2).toInt
      val kcal = splitRecord(3).toInt
      Dessert(menuId, name, price, kcal)
    }.toDF
  }
  dessertDF.createOrReplaceTempView("desert_table")


  def findByMenuId(menuId: String) = {
    dessertDF.where(dessertDF("menuId") === menuId)
  }
}


object DesertFrame {

  def main(args: Array[String]): Unit ={

    val conf = new SparkConf().setAppName("DesertFrame").setMaster("local[*]")
    val sc = new SparkContext(conf)
    val sqlContext = new SQLContext(sc)
    import sqlContext.implicits._

    val filePath = "src/test/resources/data/dessert-menu.csv"
    val desertFrame = new DesertFrame(sc, sqlContext, filePath)

    val d19DF = desertFrame.findByMenuId("D-19").head
    print(d19DF)

  }
}

上記のDesertFrameのテストを書く場合は以下のようになる。

package intoroduction.spark.dataframe

import org.apache.spark.sql.SQLContext
import org.apache.spark.{SparkConf, SparkContext}
import org.scalatest.{BeforeAndAfterAll, FunSuite}

class DessertFrameTest extends FunSuite with BeforeAndAfterAll{
  private var sparkConf: SparkConf = _
  private var sc: SparkContext = _
  private var sqlContext: SQLContext = _

  override def beforeAll() {
    print("before...")
    sparkConf = new SparkConf().setAppName("DessertFrameTest").setMaster("local")
    sc = new SparkContext(sparkConf)
    sqlContext = new SQLContext(sc)
  }

  override def afterAll() {
    sc.stop()
    print("after...")
  }

  test("dessert_frame"){
    val filePath = "src/test/resources/data/dessert-menu.csv"
    val desertFrame = new DesertFrame(sc, sqlContext, filePath)

    val d19DF = desertFrame.findByMenuId("D-19").head
    assert(d19DF.get(0) == "D-19")
    assert(d19DF.get(1) == "キャラメルロール")
    assert(d19DF.get(2) == 370)
    assert(d19DF.get(3) == 230)
  }
}

ここでは"SparkConf().setAppName(“DessertFrameTest”).setMaster(“local”)“と指定しており、ローカルの環境で動かすことができるようになりテストで使うデータを"src/test/resources/data/dessert-menu.csv"にしているのでテストデータもそのままgitで管理できるようになっている。

テスト実行

あとは"sbt test"か"sbt test:testOnly クラス指定"でテストを実行できるはずである。