2008年10月17日金曜日

Scalaのリスト-高階操作-

リスト操作でよくあるパターンは、以下のようなものらしい。
  1. リストのすべての要素を何かしらの方法で変換する
  2. 何かしらの基準を満たす要素を取り出す
  3. 何かの演算子を使ってリスト要素を結合する
関数型言語では、こういうのは高階関数を使って実現する。

Map
リストのすべての要素を何かしらの方法で変換する操作は、Mapと言う。MapReduceのMapと同じだよね、たぶん。例えば、DoubleのListのすべての要素に、何かしらの数を掛け合わせたい場合を考えてみる。
def scaleList(xs: List[Double], factor: Double): List[Double] = xs match {
  case Nil => xs
  case x :: xs1 => x * factor :: scaleList(xs1, factor)
}
通常ならこう書くが、共通部分を抽出すればmap関数として定義できる。これは実際にListのメソッドとして定義されている。
abstract class List[A] { ...
  def map[B](f: A => B): List[B] = this match {
  case Nil => this
  case x :: xs => f(x) :: xs.map(f)
}
このmapメソッドを使って先ほどの関数を書き直すと、こんな風になる。簡単。
def scaleList(xs: List[Double], factor: Double) = xs map (x => x * factor)

Filter
何かしらの基準を満たすList要素を取り出す操作は、Filterと言う。例えば、IntのListから0より大きな要素だけ取り出す関数は、以下のように書ける。
def posElems(xs: List[Int]): List[Int] = xs match {
  case Nil => xs
  case x :: xs1 => if (x > 0) x :: posElems(xs1) else posElems(xs1)
}
これらFilter操作の共通的な要素を抜き出すと、このようになる。これも、Listのメソッド。
def filter(p: A => Boolean): List[A] = this match {
  case Nil => this
  case x :: xs => if (p(x)) x :: xs.filter(p) else xs.filter(p)
}
filterメソッドを使えば、postElemsは以下のように書ける。
def posElems(xs: List[Int]): List[Int] = xs filter (x => x > 0)
似たようなものとして、forallとexistsがある。
def forall(p: A => Boolean): Boolean = isEmpty || (p(head) && (tail forall p))
def exists(p: A => Boolean): Boolean = !isEmpty && (p(head) || (tail exists p))
素数かどうか判定するプログラムとか、かなり素敵。
package scala
object List { ...
  def range(from: Int, end: Int): List[Int] = if (from >= end) Nil else from :: range(from + 1, end)

def isPrime(n: Int) = List.range(2, n) forall (x => n % x != 0)

Fold/Reduce
これは、リスト内のすべての要素を何かしらの演算で結合するというもの。例えば、リスト要素をすべて足し合わせるとか、すべて掛け合わせるとか。これは、以下のように書くことができる。
def sum(xs: List[Int]): Int = xs match {
  case Nil => 0
  case y :: ys => y + sum(ys)
}
def product(xs: List[Int]): Int = xs match {
  case Nil => 1
  case y :: ys => y * product(ys)
}
こういう操作は、ListのreduceLeftメソッドがうまくやれる。

List(x1, ..., xn).reduceLeft(op) = (...(x1 op x2) op ... ) op xn

上記の関数をreduceLeftを使って書き換えると、こうなる。

def sum(xs: List[Int]) = (0 :: xs) reduceLeft {(x, y) => x + y}
def product(xs: List[Int]) = (1 :: xs) reduceLeft {(x, y) => x * y}
ScalaライブラリのreduceLeftの実装はこのようになっていて、foldLeftを使っている。
def reduceLeft(op: (A, A) => A): A = this match {
  case Nil => error("Nil.reduceLeft")
  case x :: xs => (xs foldLeft x)(op)
}
def foldLeft[B](z: B)(op: (B, A) => B): B = this match {
  case Nil => z
  case x :: xs => (xs foldLeft op(z, x))(op)
}
Foldは少し分かりづらいけど、初期値を渡せるreduceのような感じ。同じようにreduceRightやfoldRightがあるけど、適用順序が右からになるだけなので省略。foldLeft・foldRightは別名があり、それぞれ「/:」「:\」とも書ける。

Nested Mapping
例えば、与えられたnに対して 1 ≦ j < i < n かつ i + jが素数になるようなすべてのiとjの組み合わせを探したいとする。具体的に言うと、n=7の場合には、(2, 1), (3, 2), (4, 1), (4, 3), (5, 2), (6, 1), (6, 5)となる。 
これを実現するためには、まずはnより小さなすべての(i, j)の組み合わせを作り出し、それに対してFilterをかければよい。
ここでは、どうやってn以下のすべての(i, j)の組み合わせを作るかということにフォーカスするが、これは以下のようにすれば作ることができる。
List.range(1, n)
.map(i => List.range(1, i).map(x => (i, x)))
.foldRight(List[(Int, Int)]()) {(xs, ys) => xs ::: ys}
.filter(pair => isPrime(pair._1 + pair._2))
何をやっているかというと…
  1. まず最初のList.range(1, n)で1からnまでのListを作る。
  2. 次の.map(i => List.range(1, i)でList(1..n)の要素をList(1), List(1, 2), List(1, 2, 3)...というように置き換えたListを作る。この時点で、Listの中にListが入ったネスト構造になっている。
  3. 次の.map(x => (i, x))でList内のListに対してList(1) -> List((1, 1)) List(1, 2) -> List((2, 1), (2, 2))...というように、タプルに置き換えていく。この段階で、List(List((1,1)), List((2, 1), (2, 2)), List((3, 1), (3, 2), (3, 3)), ....)のような状態になっている。
  4. 最後に、foldRightを使って内部のListを連結してやる。
  5. めでたくList((1, 1), (2, 1), (2, 2), (3, 1), (3, 2), (3, 3), ....)という構造物が出来る。
という感じなのだが、こんなの自分じゃ思いつかない…やっぱこういうのは数学…というより数字に強い人が得意なんだろうか。まぁ、こんなこともできるよ、ということで。

Flattering Maps
Map+内部Listの連結というのはよく使う組み合わせらしく、既にメソッドが用意されている。それがflatMap。flatMapを使えば、先ほどのメソッドは
List.range(1, n)
.flatMap(i => List.range(1, i).map(x => (i, x)))
.filter(pair => isPrime(pair._1 + pair._2))
こうなる。map操作が、List内の各要素に対してListを返す場合に有効ということらしい。それが"よくある"と言われてもよくわからないのは、自分がまだ関数型言語に慣れていないせいなんだろうな。

配列との比較、再び
手続き型で一般的に使われる配列とは違って、Listのインデックスアクセスは(applyで可能であるにも関わらず)滅多に使わない。これは、Listのインデックスアクセスが非効率であるということに加えて、用意されているメソッド群を組み合わせて操作した方が便利だから。

確かに、ScalaのListに幾分慣れた後でJavaの(Array)ListをIteratorやindex使ってアクセスしたりFilterしたりしているのを見ると、イライラしてくる。Commons CollectionsのCollectionUtils.filterやCollectionUtils.forAllDoを使えば似たようなことはできるんだけど、どちらもList自体に変更を加えちゃうし、戻り値はvoidだしでいまいち使いづらい。

0 件のコメント: