Scalaで作って理解するモナド

Haskellでのモナド型クラスが有名で、Haskell創始者のひとりであるフィリップ・ワドラーはモナドについて"自己関手の圏におけるモノイド対象"と答えているらしい。そのほかにも調べているとモナドは箱に入った値を返す関数を箱の中の値に適用するといったものもあり、大体が数学の専門的な内容であったり抽象的な表現だったりするの勉強を始めたばかりの人には分かりづらい印象です。ScalaではOptionやFutureがよく使われるモナドだと思います。自分がScalaを勉強し始めたころはOptionについてはそんなに抵抗がありませんでしたが、Webアプリを作るときとかに出てくるFutureはfor式などで出てきてややこしかった気がします。

今回は自分でモナドインスタンスを実装してみてfor式に対する抵抗がなくなることを目指したいと思います。

モナドの定義

ScalaHaskellモナド型クラスのようなtraitがあってそれを実装すればよいというわけではないので、Scala関数型デザインの内容を参考に実装してみたいと思います。

Functorトレイト

MonadトレイトはFunctorトレイトを拡張しているので、まずはFunctorトレイトを作成します。

trait Functor[F[_]] {
  def map[A, B](fa: F[A])(f: A => B): F[B]
}

Functorトレイトではmap関数一つのみ定義されていまして、内容としては F[A]型A型を受け取ってB型を返す関数 の2つを受け取って F[B]型 を返すといったものになっています。これに対して自作のオプション型がファンクターとなるようにしてみます。まず自作の型を以下のように定義します。

sealed trait MyOption[+A] {}
case class MySome[+A](get: A) extends MyOption[A]
case object MyNone extends MyOption[Nothing]

これに対してFunctorのインスタンスを実装してみます。

val optionFuncotor = new Functor[MyOption] {
  def map[A, B](fa: MyOption[A])(f: A => B): MyOption[B] = fa match {
    case MyNone => MyNone
    case MySome(a) => MySome(f(a))
  }
}

内容としては単純でmap関数の引数の MyNone であれば常に MyNone を返し、 MySome[A]型 であれば受け取った関数をMySomeの中身に適用して結果をMySomeでくるむといったものになります。 動きを見ていると大丈夫そうな感じです。

println(
  optionFuncotor.map(MySome(1))
  (a => a + 2)
)

MySome(3)

MonadはFunctorを拡張したものになりますが、それではFunctorになにが足りないかといいますと複数の値を扱うとき想定外の結果になるといったものがあります。例えば以下を実行してみたとき

println(
  optionFuncotor.map(MySome(1))
  (a =>
    optionFuncotor.map(MySome(2))
    (b =>
      optionFuncotor.map(MySome(3))
      (c => a + b + c)
    ))
)

MySome(MySome(MySome(6)))

結果として MySome(6) を期待すると思いますが結果は MySome(MySome(MySome(6))) になってしまいました。

Monadトレイト

それではFunctorで問題となっていた複数の値を扱うときの動きを改善するためのMonadトレイトを実装してみたいと思います。Monadトレイトの定義は以下のようになります。

trait Monad[F[_]] extends Functor[F] {
  def unit[A](a: => A): F[A]
  def flatMap[A, B](ma: F[A])(f: A => F[B]): F[B]
  def map[A, B](ma: F[A])(f: A => B): F[B] =
    flatMap(ma)(a => unit(f(a)))
}

MonadoトレイトはFunctorトレイトを拡張したものになっていまして unit関数flatMap関数 が新たに追加されています。Functorの方にあった map 関数については flatMap を使ったデフォルト定義がされています。まず unit関数 ですが内容としては引数で受け取った => A 型F で包むといっただけのものです。このF型の部分は先ほどのMyOption型のようなもので、引数として2や3を受け取ったらMySome(2)、MySome(3)を返すだけです。次に flatMap関数 ですが、これがFunctorでの問題を改善するものになっていまして、ないよとしては引数として受け取る関数が A => B 型 から A => F[B] 型 に代わっています。 このMonadトレイトを先ほどのMyOption型で実装すると以下のようになります。

val optionMonad = new Monad[MyOption] {
  def unit[A](a: => A) = MySome(a)
  override def flatMap[A, B](ma: MyOption[A])(f: A => MyOption[B]): MyOption[B] =
    ma match {
      case MyNone => MyNone
      case a: MySome[A] => f(a.a)
    }
}

それでは、Functorで問題であった複数の値を使った時の動きを見てみます。

println(
  optionMonad.flatMap(MySome(1))
  (a =>
    optionMonad.flatMap(MySome(2))
    (b =>
      optionMonad.map(MySome(3))
      (c => a + b + c)
    ))
)

MySome(6)

結果としては想定通り MySome(6) となりました。ここでのポイントとして最後のみ map を使い、それ以外は flatMap を使っているというものです。FunctorとMonadインスタンスを実際に実装して並べてみるとMonadの方が処理をつないでいくときに向いているのが実感てきたかと思います。

for式

モナドの方が処理をつないでいきやすいとしましたが、このままであれば optionMonad.flatMap がネストして続いていって大分見づらくなります。Scalaではfor式を使うことで見やすく書くことができます。ただScala自体にはMonadのトレイトが存在するわけでもなく対象のクラスがmap, flatMapを実装されていることでfor式で書くことが出来るようになります。それでは、今まで使っていたMyOptionに対してmap, flatMapを実装してみます。

object MyOption {
  def apply[A](x: A): MyOption[A] = if (x == null) MyNone else MySome(x)
  def unit[A](a: => A) = apply(a)
  def empty[A]: MyOption[A] = MyNone
}
sealed abstract class MyOption[+A] {
  def get: A
  def isEmpty: Boolean = this eq MyNone
  def isDefined: Boolean = !isEmpty
  def map[B](f: A => B): MyOption[B] =
    if (isEmpty) MyNone else MySome(f(this.get))
  def flatMap[B](f: A => MyOption[B]): MyOption[B] =
    if (isEmpty) MyNone else f(this.get)
}
final case class MySome[+A](value: A) extends MyOption[A] {
  def get: A = value
}
case object MyNone extends MyOption[Nothing] {
  def get: Nothing = throw new NoSuchElementException("None.get")
}

今までであれば、以下のような書き方をしていました。

println(MyOption(1).flatMap(a =>
  MyOption(2).flatMap(b =>
    MyOption(3).map(c =>
      a + b + c
    ))))

単純にクラスの関数にmap, flatMapがあるのでモナドインスタンス経由で関数を呼び出すのに比べて見やすくなっていますが、ここではfor式を使って更に見やすい書き方を確認してみます。

println(for {
  a <- MyOption(1)
  b <- MyOption(2)
  c <- MyOption(3)
} yield a + b + c)

内容は同じもので、for式の方も実際にはflatMap, map関数の呼び出しに変換されます。比べてみるとfor式の方が見やすいことがわかります。これでfor式がどんな風に動いているかイメージできたかと思います。今回はOptionだけでしたがFutureであれば、モナドにより複数の処理を合成したりしやすいといったイメージがわきそうです。

自分でモナドを実装するときはモナド則を満たしているかのチェックも必要だと思いますが、今回は省略しておきます。

scalaで使われるモナド

これまではOptionについてばかり説明していましたが、それだけではMonadを使うことで何がうれしいのかあまり実感できないと思います。ここではScalaで使われるモナドの名前だけ出しておきます。

これらを利用することで処理の詳細を分離することができ、柔軟な実装にも役に立つとのことらしいです。