-
Notifications
You must be signed in to change notification settings - Fork 104
/
Copy pathfwdense1.hs
74 lines (59 loc) · 1.97 KB
/
fwdense1.hs
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
{-# LANGUAGE CPP, BangPatterns #-}
{-# OPTIONS_GHC -Wall -fno-warn-name-shadowing #-}
module Main ( main, test ) where
import System.Environment
import Data.Array.Repa
import Data.Functor.Identity
-- <<Graph
-- | Weight attached to an edge; path weights are sums of these.
type Weight = Int
-- | A graph stored as a two-dimensional Repa array of 'Weight's, indexed
-- by a 'DIM2' shape.  The parameter @r@ is the array representation
-- (this file uses @U@ throughout).
-- NOTE(review): presumably entry @(i, j)@ is the weight of the edge from
-- vertex @i@ to vertex @j@ -- confirm against the callers.
type Graph r = Array r DIM2 Weight
-- >>
-- -----------------------------------------------------------------------------
-- shortestPaths
-- <<shortestPaths
-- | Repeatedly relax every entry of the weight matrix through one
-- intermediate index @k@ at a time, for @k = 0 .. n - 1@.
--
-- Each pass builds a fresh array in which entry @(i, j)@ is the smaller of
-- the current @(i, j)@ weight and the sum of the current @(i, k)@ and
-- @(k, j)@ weights; the array from the final pass is returned.
--
-- Only the innermost extent @n@ is read and both dimensions are indexed up
-- to @n@, so the input is expected to be square.
shortestPaths :: Graph U -> Graph U
shortestPaths g0 = runIdentity (relaxFrom g0 0) -- <1>
  where
    Z :. _ :. n = extent g0

    -- The bang patterns force the current matrix and the counter on every
    -- step of the loop.
    relaxFrom !cur !k
      | k == n    = return cur -- <2>
      | otherwise = do
          next <- computeP (fromFunction (Z :. n :. n) relax) -- <3>
          relaxFrom next (k + 1)
      where
        -- Weight currently stored at row r, column c.
        at r c = cur ! (Z :. r :. c)

        -- Keep the existing weight unless going through k is cheaper.
        relax (Z :. i :. j) = min (at i j) (at i k + at k j)
-- >>
-- -----------------------------------------------------------------------------
-- Testing
-- | Sample 6x6 weight matrix, one inner list per row, fed to
-- 'shortestPaths' by 'test'.
input :: [[Int]]
input =
  [ [  0, inf, inf,  13, inf, inf]
  , [inf,   0, inf, inf,   4,   9]
  , [ 11, inf,   0, inf, inf, inf]
  , [inf,   3, inf,   0, inf,   7]
  , [ 15,   5, inf,   1,   0, inf]
  , [ 11, inf, inf,  14, inf,   0]
  ]
  where
    -- NOTE(review): 999 presumably stands for "no direct edge" -- confirm.
    inf = 999
-- | The correct result: the matrix 'test' expects 'shortestPaths' to
-- produce for 'input', one inner list per row.
result :: [[Int]]
result =
  [ [ 0, 16, inf, 13, 20, 20]
  , [19,  0, inf,  5,  4,  9]
  , [11, 27,   0, 24, 31, 31]
  , [18,  3, inf,  0,  7,  7]
  , [15,  4, inf,  1,  0,  8]
  , [11, 17, inf, 14, 21,  0]
  ]
  where
    -- NOTE(review): 999 is the same large value used in the sample input,
    -- presumably marking pairs with no path -- confirm.
    inf = 999
-- | 'True' when running 'shortestPaths' on 'input' (converted to and from
-- the array representation) yields exactly 'result'.
test :: Bool
test = computed == result
  where
    computed = (fromAdjMatrix . shortestPaths . toAdjMatrix) input
-- | Pack a list of rows into an unboxed two-dimensional array.
--
-- The array is given extent @side x side@, where @side@ is the number of
-- rows; the cells are the rows concatenated in order.
toAdjMatrix :: [[Int]] -> Graph U
toAdjMatrix rows = fromListUnboxed shape cells
  where
    side  = length rows
    shape = Z :. side :. side
    cells = concat rows
-- | Unpack an array into a list of rows by flattening it with 'toList'
-- and cutting the result into pieces whose length is the array's
-- innermost extent.
fromAdjMatrix :: Graph U -> [[Int]]
fromAdjMatrix m =
  let Z :. _ :. rowLen = extent m
  in  chunk rowLen (toList m)
-- | Split a list into consecutive pieces of the given length; the last
-- piece holds whatever is left over and may be shorter.  An empty list
-- yields no pieces.
chunk :: Int -> [a] -> [[a]]
chunk width items
  | null items = []
  | otherwise  = take width items : chunk width (drop width items)
-- | Entry point: read a single size @N@ from the command line, build an
-- @N x N@ matrix filled with @1 .. N^2@, run 'shortestPaths' on it and
-- print the sum of all entries of the result.
--
-- Anything other than exactly one non-negative integer argument is
-- rejected with a usage message (raised as an 'IOError', so the program
-- exits with a failure status) instead of a pattern-match or @read@
-- failure.
main :: IO ()
main = do
  args <- getArgs
  case args of
    [arg]
      | [(n, rest)] <- reads arg
      , [("", "")] <- lex rest -- tolerate trailing whitespace, as 'read' does
      , n >= 0 -> -- a negative size would never satisfy the k == n stop test
          let g = fromListUnboxed (Z :. n :. n) [1 .. n ^ (2 :: Int)] :: Graph U
          in  print (sumAllS (shortestPaths g))
    _ -> do
      prog <- getProgName
      -- 'unwords' rather than (++): Data.Array.Repa is imported unqualified
      -- and exports its own (++), so the operator would be ambiguous here.
      ioError (userError (unwords
        ["usage:", prog, "N", "(N must be a non-negative integer)"]))