SCombのコードを読み込んでみる

SCombとはScalaで作られたパーサコンビネータライブラリで実装はこちらになります。ソースコードの行数が600行行程度で読みやすいので、scalaのパーサコンビネータの実装の学習のために手を動かしながらソースコードを読み込んでみたいと思います。

まずはテキストの判定をできるようにする。

いきなり全体のソースを読み込むのは大変なので、最小の機能を動かす部分だけ確認してみたいと思います。ということで、テキストの判定を行えるようにするために必要な部分だけ抽出すると以下のようになりました。

Location.scala

case class Location(line: Int, column: Int)

SCombinator.scala

import scala.collection.mutable

trait SCombinator { self =>
  type P[+T] = Parser[T]

  protected var input: String = ""

  protected var recent: Option[Failure] = None

  private[this] val DefaultLabel: String = "fail"

  protected final def isEOF(index: Int): Boolean = index >= input.length

  protected final def current(index: Int): String = input.substring(index)

  protected val locations: mutable.Map[Int, Location] = mutable.Map[Int, Location]()

  sealed abstract class ParseResult[+T] {
    def index: Int
    def value: Option[T]
  }

  case class Success[+T](semanticValue: T, override val index: Int) extends ParseResult[T] {
    override def value: Option[T] = Some(semanticValue)
  }
  sealed abstract class ParseNonSuccess extends ParseResult[Nothing] {
    def message: String
  }
  case class Error(override val message: String, override val index: Int) extends ParseNonSuccess {
    override def value: Option[Nothing] = None
  }
  case class Failure(override val message: String, override val index: Int, label: String) extends ParseNonSuccess {
    self.recent match {
      case None => self.recent = Some(this)
      case Some(failure) if index >= failure.index => self.recent = Some(this)
      case _ => // Do nothing
    }
    override def value: Option[Nothing] = None
  }
  object Failure {
    def apply(message: String, index: Int): Failure = Failure(message, index, DefaultLabel)
  }

  final def parsePartial[R](rule: Parser[R], input: String): ParseResult[R] = synchronized {
    this.input = input
    this.recent = None
    this.locations.clear()
    calculateLocations()
    rule(0) match {
      case s@Success(_, _) => s
      case f@Failure(_, _, label) =>
        if(label == DefaultLabel) {
          this.recent.get
        } else {
          f
        }
      case f@Error(_, _) =>
        f
    }
  }

  protected final def calculateLocations(): Unit = {
    var i: Int = 0
    var line: Int = 1
    var column: Int = 1
    val chars = input.toCharArray
    while(i < chars.length) {
      val ch = chars(i)
      ch match {
        case '\n' =>
          locations(i) = Location(line, column)
          line += 1
          column = 1
          i += 1
        case '\r' =>
          if(i == chars.length - 1) {
            locations(i) = Location(line, column)
            line += 1
            column = 1
            i += 1
          } else {
            locations(i) = Location(line, column)
            if(chars(i + 1) == '\n') {
              locations(i + 1) = Location(line, column + 1)
              line += 1
              column = 1
              i += 2
            } else {
              line += 1
              column = 1
              i += 1
            }
          }
        case _ =>
          locations(i) = Location(line, column)
          column += 1
          i += 1
      }
    }
    locations(i) = Location(line, column)
  }

  def parserOf[T](function: Int => ParseResult[T]): Parser[T] = new Parser[T] {
    override def apply(index: Int): ParseResult[T] = function(index)
  }

  abstract class Parser[+T] extends (Int => ParseResult[T]) {

    def apply(index: Int): ParseResult[T]

  }

  final def string(literal: String): Parser[String] = parserOf{index =>
    if(literal.length > 0 && isEOF(index)) {
      Failure(s"expected:`${literal}` actual:EOF", index)
    } else if(current(index).startsWith(literal)) {
      Success(literal, index + literal.length)
    } else {
      Failure(s"expected:`${literal}` actual:`${current(index)(0)}`", index)
    }
  }

  final def $(literal: String): Parser[String] = string(literal)

}

パース対象の文字列の座標を表すのには Locationクラス を使用しています。それからパーサコンビネーターはトレイトなので、利用するときは object P1 extends SCombinator とSCombinatorを継承しています。パース対象の文字列や現在のパース結果、パース完了しているかどうか、現在のパース対象の座標など管理しています。

  protected var input: String = ""

  protected var recent: Option[Failure] = None

  private[this] val DefaultLabel: String = "fail"

  protected final def isEOF(index: Int): Boolean = index >= input.length

  protected final def current(index: Int): String = input.substring(index)

  protected val locations: mutable.Map[Int, Location] = mutable.Map[Int, Location]()

パース結果を表すために、ParseResult, Success, ParseNonSuccess, Error, Failureを用意しています。ParseNonSuccessではトレイトのメンバ変数にアクセスするため、 trait SCombinator { self => としてselfでアクセスできるようにしています。

  sealed abstract class ParseResult[+T] {
    def index: Int
    def value: Option[T]
  }

  case class Success[+T](semanticValue: T, override val index: Int) extends ParseResult[T] {
    override def value: Option[T] = Some(semanticValue)
  }
  sealed abstract class ParseNonSuccess extends ParseResult[Nothing] {
    def message: String
  }
  case class Error(override val message: String, override val index: Int) extends ParseNonSuccess {
    override def value: Option[Nothing] = None
  }
  case class Failure(override val message: String, override val index: Int, label: String) extends ParseNonSuccess {
    self.recent match {
      case None => self.recent = Some(this)
      case Some(failure) if index >= failure.index => self.recent = Some(this)
      case _ => // Do nothing
    }
    override def value: Option[Nothing] = None
  }
  object Failure {
    def apply(message: String, index: Int): Failure = Failure(message, index, DefaultLabel)
  }

パーサーに対してパース対象の文字の座標を与えると、事前にパースコンビネータに渡していた関数を実行してParseResult型を返すようにします。

  abstract class Parser[+T] extends (Int => ParseResult[T]) {

    def apply(index: Int): ParseResult[T]

  }

  def parserOf[T](function: Int => ParseResult[T]): Parser[T] = new Parser[T] {
    override def apply(index: Int): ParseResult[T] = function(index)
  }

文字列をパースするためのパーサを初期化するための関数は以下のようになります。

  final def string(literal: String): Parser[String] = parserOf{index =>
    if(literal.length > 0 && isEOF(index)) {
      Failure(s"expected:`${literal}` actual:EOF", index)
    } else if(current(index).startsWith(literal)) {
      Success(literal, index + literal.length)
    } else {
      Failure(s"expected:`${literal}` actual:`${current(index)(0)}`", index)
    }
  }

  final def $(literal: String): Parser[String] = string(literal)

それから、実際にパースを実行する部分とパース実行前のパース対象の文字座標情報を計算するのが以下の部分になります。パースをするときはtrait内のメンバ変数に対してアクセス、更新しておりパース中に別の文字列のパースが始まると困るのでsynchronizedをつけています。

  final def parsePartial[R](rule: Parser[R], input: String): ParseResult[R] = synchronized {
    this.input = input
    this.recent = None
    this.locations.clear()
    calculateLocations()
    rule(0) match {
      case s@Success(_, _) => s
      case f@Failure(_, _, label) =>
        if(label == DefaultLabel) {
          this.recent.get
        } else {
          f
        }
      case f@Error(_, _) =>
        f
    }
  }

  protected final def calculateLocations(): Unit = {
    var i: Int = 0
    var line: Int = 1
    var column: Int = 1
    val chars = input.toCharArray
    while(i < chars.length) {
      val ch = chars(i)
      ch match {
        case '\n' =>
          locations(i) = Location(line, column)
          line += 1
          column = 1
          i += 1
        case '\r' =>
          if(i == chars.length - 1) {
            locations(i) = Location(line, column)
            line += 1
            column = 1
            i += 1
          } else {
            locations(i) = Location(line, column)
            if(chars(i + 1) == '\n') {
              locations(i + 1) = Location(line, column + 1)
              line += 1
              column = 1
              i += 2
            } else {
              line += 1
              column = 1
              i += 1
            }
          }
        case _ =>
          locations(i) = Location(line, column)
          column += 1
          i += 1
      }
    }
    locations(i) = Location(line, column)
  }

それから type P[+T] = Parser[T] と型エイリアスを定義しています。 テキスト判定のパーサの動作確認は、以下のテストコードで確認できます。

class MySpec extends FunSpec with DiagrammedAssertions {
  import jp.co.teruuu.scomb._


  object P1 extends SCombinator {
    def root: P[String] = $("")
  }
  object P2 extends SCombinator {
    def root: P[String] = $("H")
  }
  object P3 extends SCombinator {
    def root: P[String] = $("Hello")
  }


  describe("test strParse") {
    it("""$("") and string("") always succeed""") {
      P1.parsePartial(P1.root, "") match {
        case P1.Success(v, index) =>
          assert("" == v)
          assert(0 == index)
        case _ =>
          assert(false)
      }
    }

    it("""$("H") succeed for string starts with 'H'""") {
      P2.parsePartial(P2.root, "H") match {
        case P2.Success(v, index) =>
          assert("H" == v)
          assert(1 == index)
        case _ =>
          assert(false)
      }
      P2.parsePartial(P2.root, "Hello") match {
        case P2.Success(v, index) =>
          assert("H" == v)
          assert(1 == index)
        case _ =>
          assert(false)
      }
      P2.parsePartial(P2.root, "I") match {
        case P2.Success(_, _) =>
          assert(false)
        case n:P2.ParseNonSuccess =>
          assert(0 == n.index)
          assert(None == n.value)
      }
    }
  }

  it("""$("Hello") succeed for string starts with 'Hello'""") {
    P3.parsePartial(P3.root, "Hello") match {
      case P3.Success(v, index) =>
        assert("Hello" == v)
        assert(5 == index)
      case _ => assert(false)
    }
    P3.parsePartial(P3.root, "H") match {
      case P3.Failure(message, index, label) =>
        assert(index == 0)
      case _ => assert(false)
    }
  }

}

複数のパーサを連結する

次に複数のパーサを連結できるようにします。パース結果の値としてトレイト内に以下のクラスを定義します。

case class ~[+A, +B](a: A, b: B)

それから Parser クラスに対して以下の~メソッドを追加します。~メソッドは2つのパースが成功した場合に~型の値を返すようにしています。

  abstract class Parser[+T] extends (Int => ParseResult[T]) {
    def apply(index: Int): ParseResult[T]

    def ~[U](right: Parser[U]) : Parser[T ~ U] = parserOf{index =>
      this(index) match {
        case Success(value1, next1) =>
          right(next1) match {
            case Success(value2, next2) =>
              Success(new ~(value1, value2), next2)
            case failure@Failure(_, _, _) =>
              failure
            case fatal@Error(_, _) =>
              fatal
          }
        case failure@Failure(_, _, _) =>
          failure
        case fatal@Error(_, _) =>
          fatal
      }
    }
  }

これに対して、以下のテストコードで動作確認ができますがトレイトの外側では~型のメンバにアクセスできないのでパース後の値は分からない状態です。

  object P4 extends SCombinator {
    def hello: P[String] = $("Hello")
    def world: P[String] = $("World")
    def root: P[String ~ String] = hello ~ world
  }

  it("""$("Hello") ~ $("World") succeed for string starts with 'HelloWorld'""") {
    P4.parsePartial(P4.root, "HelloWorld") match {
      case P4.Success(_, index) =>
        assert(index == 10)
      case _ => assert(false)
    }
  }

パース後の結果を扱いやすいように変換する

先ほどのパースの結果を扱いやすくするため、パースの結果を変換できるようにしてみます。パース後の結果の変換のために Parser クラスにmapメソッドを定義します。

    def map[U](function: T => U): Parser[U] = parserOf{index =>
      this(index) match {
        case Success(value, next) => Success(function(value), next)
        case failure@Failure(_, _, _) => failure
        case fatal@Error(_, _) => fatal
      }
    }
    
    def ^^[U](function: T => U): Parser[U] = map(function)

それから、以下のようにパース結果の値を変換することで評価できるようになることを確認できます。

  object P5 extends SCombinator {
    def hello: P[String] = $("Hello")
    def world: P[String] = $("World")
    def root: P[String] = (hello ~ world) ^^ {
      case (a ~ b) => a + b
    }
  }
  P5.parsePartial(P5.root, "Hello") match {
    case P5.Failure(_, index, _) =>
      assert(index == 5)
    case _ => assert(false)
  }  

0回以上の繰り返しをパースできるようにする

0回以上の繰り返しをパースできるようにするため Parseクラス に以下の関数を定義しています。

def * : Parser[List[T]] = parserOf{index =>
    def repeat(index: Int): ParseResult[List[T]] = this(index) match {
    case Success(value, next1) =>
        repeat(next1) match {
        case Success(result, next2) =>
            Success(value::result, next2)
        case r => throw new RuntimeException("cannot be " + r)
        }
    case Failure(message, next, DefaultLabel) =>
        Success(Nil, index)
    case failure@Failure(message, next, label) =>
        failure
    case f@Error(_, _) =>
        f
    }
    repeat(index) match {
    case r@Success(_, _) => r
    case r:ParseNonSuccess => r
    }
}

この関数を使うことで以下のようにスペースの繰り返しもパースできるようになります。

  object P6 extends SCombinator {
    def hello: P[String] = $("Hello")
    def space: P[List[String]] = $(" ").*
    def world: P[String] = $("World")
    def root: P[String] = (hello ~ space ~ world) ^^ {
      case (a ~ b ~ c) => a + b.mkString("") + c
    }
  }

  it("""$("Hello") ~ $("World")3 succeed for string starts with 'HelloWorld'""") {
    P6.parsePartial(P6.root, "Hello  World") match {
      case P6.Success(v, index) =>
        assert(v == "Hello  World")
        assert(index == 12)
      case _ => assert(false)
    }

    P6.parsePartial(P6.root, "HelloWorld") match {
      case P6.Success(v, index) =>
        assert(v == "HelloWorld")
        assert(index == 10)
      case _ =>
        assert(false)
    }
  }

複数のパーサの連続をfor式で扱えるようにする

先ほどまでは複数のパーサの連続は~メソッドを使うようにしていたのですが、連続するパーサが多くなってくると分かりづらくなってきそうなのでfor式で記述できるようにします。 flatMapは以下のようにパース後の値を受け取ったらパーサを返すように定義してあげればよいので

  object P7 extends SCombinator {
    def hello: P[String] = $("Hello")
    def space: P[List[String]] = $(" ").*
    def world: P[String] = $("World")

    def root: P[String] =
      hello.flatMap(a =>
        space.flatMap(b =>
          world.map(c =>
            a + b.mkString("") + c)))

Parserクラスに対して以下のようにflatMapを定義します。

def flatMap[U](function: T => Parser[U]): Parser[U] = parserOf{index =>
    this(index) match {
    case Success(value, next) =>
        function.apply(value).apply(next)
    case failure@Failure(_, _, _) =>
        failure
    case fatal@Error(_, _)=>
        fatal
    }
}

実際のfor式の利用は以下のようになります

  object P7 extends SCombinator {
    def hello: P[String] = $("Hello")
    def space: P[List[String]] = $(" ").*
    def world: P[String] = $("World")

    def root: P[String] = for{
      a <- hello
      b <- space
      c <- world
    } yield a + b.mkString("") + c
  }

長くなったのでここで終了しますが、元のコードも読みやすく大まかな実装の方針も把握できたと思うのでどうにか読み進めて行けそうです。