PR

融合変換で不要な同期を除去する

 ここまでは,単一の処理を並列化することだけを考えてきました。実際の並列化では,複数の別の処理を並列に実行することで全体の処理性能の向上を図るといったこともよく行われます。しかし,個々の処理にとって最適な並列化が,全体の処理にとって最適だとは限りません。複数の並列処理を組み合わせる際の境界部分に,不要な同期が入り込んでしまうことがあるからです。

 並列処理は複数のスレッドやCPUにまたがって行われるため,次の計算を行うために必要な処理が終了しているかどうかはわかりません。このため,前の計算処理が終了しているかどうかを確認し,前の計算処理が終わっていない場合には計算が終了するまで待つような同期処理が必要です。並列Haskellやデータ並列Haskellではこうした同期はプログラマからは隠ぺいされているものの,内部では同期処理を行っています。こうした同期処理は必ずしもプログラム全体にとって好ましい形で行われるわけではありません。

 例えば,map関数を並列処理するよう定義されたmapP関数があるとしましょう。mapP関数の定義によっては,「mapP f . mapP g」のような単純な処理でも,「mapP g」の並列処理の結果をいったん集めるために同期し,それから「mapP f」の処理を行う,という愚直な同期が行われる可能性があります。第37回で説明した中間データ構造の問題に類似した問題が発生するのです。「mapP f . mapP g」という処理は,本来は「mapP (f . g)」という中間的な同期を伴わない形で処理すべきです。

 並列Haskellを使う場合には,個々の処理での並列化を避けて,全体に影響するような処理を並列評価するように記述することで,この問題をある程度回避できます。

 例えば以下の式では,「map g xs」と,「map g xs」の結果であるysを利用する「map f ys」に対し,別々に「parListChunk n strat」を使って並列評価を行っています。これは望ましいプログラムではありません。

let ys = map g xs `using` parListChunk n strat
in  map f ys `using` parListChunk n strat

 このプログラムは以下のように書き直すべきでしょう。個々の処理を別々に並列評価するのではなく,「map g」と「map f」の両方を適用する処理全体に対して「parListChunk n strat」を使って並列評価を行うのです。

(map f . map g) xs `using` parListChunk n strat

 ただし,この方法は並列Haskell以外ではほとんど利用できません。たいていの場合,プログラムを並列化するための方法は,並列化されたプログラム自体に組み込まれているからです。並列HaskellでControl.Parallel.Strategiesモジュールの関数を利用する場合のように,プログラムの評価戦略を後から差し替えられるような構造になっているケースはあまりありません。また,並列化したい処理がいつもこのように整った形になっていることも期待できません。

 より広く利用できる解決方法は,不要な同期処理を除去するための融合変換を書き換え規則によって用意することです。ちょうど,第37回から第42回にかけて,融合変換を使って中間データ構造を除去したのと同様の手法を利用するのです。

 第37回で説明した組み合わせ爆発を防ぐため,アドホックな融合変換の定義を避けて,処理の並列化と同期を行う基本構造に対して融合変換を定義します。融合変換で不要な同期処理を除去している例として,第10回および第24回のコラムで紹介したデータ並列HaskellのDPHライブラリで使われている融合変換を見てみましょう(参考リンク)。

 DPHライブラリでは,並列処理を行うmapUP関数,filterUP関数,foldUP関数が用意されています。それぞれ,map関数,filter関数,fold関数に対応するものです(参考リンク1参考リンク2参考リンク3)。

-- | Apply a worker to all elements of a vector.
mapUP :: (Unbox a, Unbox b) => (a -> b) -> Vector a -> Vector b
{-# INLINE mapUP #-}
mapUP f xs 
        = splitJoinD theGang (mapD theGang (Seq.map f)) xs


-- | Keep elements that match the given predicate.
filterUP :: Unbox a => (a -> Bool) -> Vector a -> Vector a
{-# INLINE filterUP #-}
filterUP f
        = joinD  theGang unbalanced
        . mapD   theGang (Seq.filter f)
        . splitD theGang unbalanced

-- | Undirected fold.
--   Note that this function has more constraints on its parameters than the
--   standard fold function from the Haskell Prelude.
--
--   * The worker function must be associative.
--   * The provided starting element must be neutral with respect to the worker.
--     For example 0 is neutral wrt (+) and 1 is neutral wrt (*).
--
--   We need these constraints so that we can partition the fold across 
--   several threads. Each thread folds a chunk of the input vector, 
--   then we fold together all the results in the main thread.
--
foldUP  :: (Unbox a, DT a) => (a -> a -> a) -> a -> Vector a -> a
{-# INLINE foldUP #-}
foldUP f !z xs
        = foldD theGang f
                (mapD   theGang (Seq.fold f z)
                (splitD theGang unbalanced xs))

 これらの関数は,並列処理の対象であるデータを分割して各CPUコアに分配するsplitD関数,並列処理の結果のデータを集約するjoinD関数,splitDとjoinDを組み合わせたsplitJoinD関数,「joinD関数などで分配され,splitD関数などで集約されるデータ」に対して実際に並列処理を行うmapD関数などにより定義されています。

-- | Distribute an array over a 'Gang'.
--
--   NOTE: This is defined in terms of splitD_impl to avoid introducing loops
--         through RULES. Without it, splitJoinD would be a loop breaker.
splitD :: Unbox a => Gang -> Distribution -> Vector a -> Dist (Vector a)
{-# INLINE_DIST splitD #-}
splitD g _ arr = splitD_impl g arr

splitD_impl :: Unbox a => Gang -> Vector a -> Dist (Vector a)
{-# INLINE_DIST splitD_impl #-}
splitD_impl g !arr = generateD_cheap g (\i -> Seq.slice arr (idx i) (len i))
  where
    n = Seq.length arr
    !p = gangSize g
    !l = n `quotInt` p
    !m = n `remInt` p

    {-# INLINE [0] idx #-}
    idx i | i < m     = (l+1)*i
          | otherwise = l*i + m

    {-# INLINE [0] len #-}
    len i | i < m     = l+1
          | otherwise = l


-- | Join a distributed array.
--
--   NOTE: This is defined in terms of joinD_impl to avoid introducing loops
--         through RULES. Without it, splitJoinD would be a loop breaker.
joinD :: Unbox a => Gang -> Distribution -> Dist (Vector a) -> Vector a
{-# INLINE CONLIKE [1] joinD #-}
joinD g _ darr  = joinD_impl g darr

joinD_impl :: forall a. Unbox a => Gang -> Dist (Vector a) -> Vector a
{-# INLINE_DIST joinD_impl #-}
joinD_impl g !darr = checkGangD (here "joinD") g darr $
                     Seq.new n (\ma -> zipWithDST_ g (copy ma) di darr)
  where
    (!di,!n) = scanD g (+) 0 $ lengthD darr
    copy :: forall s. MVector s a -> Int -> Vector a -> DistST s ()
    copy ma i arr = stToDistST (Seq.copy (mslice i (Seq.length arr) ma) arr)


-- | Split a vector over a gang, run a distributed computation, then
--   join the pieces together again.
splitJoinD
        :: (Unbox a, Unbox b)
        => Gang
        -> (Dist (Vector a) -> Dist (Vector b))
        -> Vector a
        -> Vector b
{-# INLINE_DIST splitJoinD #-}
splitJoinD g f !xs = joinD_impl g (f (splitD_impl g xs))

-- | Map a function over a distributed value.
mapD :: (DT a, DT b) => Gang -> (a -> b) -> Dist a -> Dist b
{-# INLINE [1] mapD #-}
mapD g f !d = checkGangD (here "mapD") g d
             (runDistST g (myD d >>= return . f))

 mapUP,filterUP,foldUPの定義のうち,joinD,splitD,splitJoinDを使って定義された部分で無駄な同期処理を行っている可能性があります。例えば,「mapUP f . mapUP g」や「foldUP f x . filterUP g」といった式では,「joinD*により同期を行って集めたデータを,そのまますぐにsplitD*を使って再分配する」という無駄が発生しています。DPHライブラリでは,書き換え規則を使って,こうした無駄な同期処理を除去しています。

{-# RULES

"splitD[unbalanced]/joinD" forall g b da.
  splitD g unbalanced (joinD g b da) = da

"splitD[balanced]/joinD" forall g da.
  splitD g balanced (joinD g balanced da) = da

"splitD/splitJoinD" forall g b f xs.
  splitD g b (splitJoinD g f xs) = f (splitD g b xs)

"splitJoinD/joinD" forall g b f da.
  splitJoinD g f (joinD g b da) = joinD g b (f da)

"splitJoinD/splitJoinD" forall g f1 f2 xs.
  splitJoinD g f1 (splitJoinD g f2 xs) = splitJoinD g (f1 . f2) xs

~ 略 ~
  #-}

mapUP f . mapUP g
==> { 「.」のインライン化
      「\xs ->」の部分は便宜上省略する }
    mapUP f (mapUP g xs)
==> { mapUP関数のインライン化 }
    splitJoinD theGang (mapD theGang (Seq.map f))
        (mapUP g xs))
==> { mapUP関数のインライン化 }
    splitJoinD theGang (mapD theGang (Seq.map f))
        (splitJoinD theGang (mapD theGang (Seq.map g)) xs)
==> { "splitJoinD/splitJoinD"規則の適用 }
    splitJoinD theGang (f . g) xs

foldUP f x . filterUP g
==> { 「.」のインライン化
      「\ys ->」の部分は便宜上省略する }
    foldUP f x (filterUP g ys)
==> { foldUP関数のインライン化 }
    foldD theGang f
        (mapD   theGang (Seq.fold f x)
        (splitD theGang unbalanced 
                        (filterUP g ys)))
==> { filterUP関数のインライン化 }
    foldD theGang f
        (mapD   theGang (Seq.fold f x)
        (splitD theGang unbalanced
             (( joinD  theGang unbalanced
              . mapD   theGang (Seq.filter g)
              . splitD theGang unbalanced) ys)))
==> { 「.」のインライン化 + β-簡約 }
    foldD theGang f
        (mapD   theGang (Seq.fold f x)
        (splitD theGang unbalanced
             ( joinD  theGang unbalanced
             ( mapD   theGang (Seq.filter g)
             . splitD theGang unbalanced) ys)))
==> { "splitD[unbalanced]/joinD"規則の適用 }
    foldD theGang f
        (mapD   theGang (Seq.fold f x)
          ( mapD   theGang (Seq.filter g)
          . splitD theGang unbalanced) ys)

 並列プログラムが不要な同期によって十分に性能を発揮できない場合には,融合変換を使って不要な同期を除去することを検討してみましょう。