import java.util.*
/** When true, [main] prints intermediate state (distances, DAGs, counts) to stdout. */
val DEBUG: Boolean = false

/** Modulus (10^9 + 7) for all [FiniteField] arithmetic. */
val P: Long = 1_000_000_007L
/**
 * A residue modulo [P].
 *
 * The constructor does not reduce [x]; the operators assume both operands already lie in `[0, P)`.
 */
data class FiniteField(val x: Long) {
    /** Modular sum. */
    operator fun plus(o: FiniteField): FiniteField {
        val sum = x + o.x
        return FiniteField(sum % P)
    }

    /** Modular difference; adding [P] before subtracting keeps the intermediate non-negative. */
    operator fun minus(o: FiniteField): FiniteField {
        val shifted = x + P - o.x
        return FiniteField(shifted % P)
    }

    /** Modular product; with both factors below P (< 2^30) the raw product fits in a Long. */
    operator fun times(o: FiniteField): FiniteField {
        val product = x * o.x
        return FiniteField(product % P)
    }
}
/** Adjacency-list entry: an edge from vertex [from] to vertex [to] with length [cost]. */
data class Edge(
    val from: Int,
    val to: Int,
    val cost: Long
)

/** Priority-queue entry for Dijkstra: vertex [v] reached with tentative distance [cost]. */
data class State(
    val v: Int,
    val cost: Long
)
/**
 * Comparator for [State]s by ascending cost: returns 1 if [s1] is more expensive,
 * -1 if it is cheaper, and 0 on a tie.
 */
fun compareState(s1: State, s2: State): Int = when {
    s1.cost > s2.cost -> 1
    s1.cost < s2.cost -> -1
    else -> 0
}
/**
 * Parses a line of whitespace-separated decimal integers.
 *
 * Tokens may be separated by any run of spaces or tabs. Leading and trailing blanks are ignored,
 * as is a trailing `\r` left over from CRLF input. Negative numbers are accepted.
 * A blank line yields an empty list.
 *
 * @throws NumberFormatException if a token is not a valid [Int] (including overflow).
 */
fun readIntList(s: String): List<Int> =
    s.split(' ', '\t', '\r', '\n')
        .filter { it.isNotEmpty() } // collapse runs of separators
        .map { it.toInt() }
/**
 * Entry point. Reads from stdin:
 *  - `n m`: vertex and edge counts,
 *  - `s t`: 1-indexed start and goal vertices,
 *  - `m` lines `u v d`: an undirected edge between 1-indexed `u` and `v` with length `d`.
 *
 * Counts ordered pairs (shortest s->t path, shortest t->s path) modulo [P]. It then subtracts the
 * pairs that share the route's midpoint: a vertex at exactly half the s-t distance, or an edge
 * whose endpoints lie strictly on either side of that half. Prints the resulting residue.
 */
fun main(args: Array<String>) {
    val (n, m) = readIntList(readLine()!!)
    val (sRaw, tRaw) = readIntList(readLine()!!)
    val s = sRaw - 1
    val t = tRaw - 1
    val graph = Array(n) { mutableListOf<Edge>() }
    repeat(m) {
        val (uRaw, vRaw, dInt) = readIntList(readLine()!!)
        val u = uRaw - 1
        val v = vRaw - 1
        val d = dInt.toLong()
        // Undirected input: record the edge in both adjacency lists.
        graph[u].add(Edge(u, v, d))
        graph[v].add(Edge(v, u, d))
    }

    val dist = shortestDistances(graph, s)
    if (DEBUG) println("dist: ${dist.toList()}")

    val (sGraph, tGraph) = shortestPathDag(graph, dist, t)
    if (DEBUG) println("sGraph: ${sGraph.toList()}")
    if (DEBUG) println("tGraph: ${tGraph.toList()}")

    // Vertices by increasing distance from s. This is a topological order of sGraph (its reverse is
    // one for tGraph) provided edge lengths are positive; zero-length edges would create ties.
    val sSorted = dist.withIndex().sortedBy { it.value }.map { it.index }

    // sCount[v]: number of shortest paths s -> v made of sGraph edges.
    val sCount = Array(n) { FiniteField(0) }
    sCount[s] = FiniteField(1)
    fillCount(s, sGraph, sSorted, sCount)
    if (DEBUG) println("sCount: ${sCount.toList()}")

    // tCount[v]: number of shortest paths v -> t made of tGraph edges.
    val tCount = Array(n) { FiniteField(0) }
    tCount[t] = FiniteField(1)
    fillCount(t, tGraph, sSorted.reversed(), tCount)
    if (DEBUG) println("tCount: ${tCount.toList()}")

    // Every pair of shortest paths, one in each direction.
    var ans = sCount[t] * tCount[s]
    if (DEBUG) println("ans STEP 1: $ans")
    val totalDist = dist[t]

    // Pairs that meet on a vertex lying exactly halfway.
    for (v in 0 until n) {
        if (dist[v] * 2 == totalDist) {
            val through = sCount[v] * tCount[v]
            ans -= through * through
        }
    }
    if (DEBUG) println("ans STEP 2: $ans")

    // Pairs that meet inside a shortest-path edge straddling the halfway point.
    for (v in 0 until n) {
        for (e in graph[v]) {
            if (dist[v] + e.cost == dist[e.to] && dist[v] * 2 < totalDist && dist[e.to] * 2 > totalDist) {
                val through = sCount[v] * tCount[e.to]
                ans -= through * through
            }
        }
    }
    if (DEBUG) println("ans STEP 3: $ans")

    // Print the residue itself; the data-class toString would print "FiniteField(x=...)".
    println(ans.x)
}

/** Dijkstra: distances from [source] to every vertex; unreachable vertices keep Long.MAX_VALUE. */
private fun shortestDistances(graph: Array<MutableList<Edge>>, source: Int): LongArray {
    val dist = LongArray(graph.size) { Long.MAX_VALUE }
    val pq = PriorityQueue<State>(11) { a: State, b: State -> compareState(a, b) }
    dist[source] = 0
    pq.add(State(source, 0L))
    while (pq.isNotEmpty()) {
        val state = pq.poll()
        // Stale entry: this vertex was already settled with a shorter distance.
        if (dist[state.v] < state.cost) continue
        for (e in graph[state.v]) {
            val candidate = state.cost + e.cost
            if (candidate < dist[e.to]) {
                dist[e.to] = candidate
                pq.add(State(e.to, candidate))
            }
        }
    }
    return dist
}

/**
 * Walks backwards from [target] over edges `a -> b` with `dist[b] == dist[a] + cost`, i.e. edges
 * on shortest paths from the Dijkstra source that lead to [target]. Returns the resulting DAG as
 * (forward adjacency toward [target], backward adjacency toward the source).
 */
private fun shortestPathDag(
    graph: Array<MutableList<Edge>>,
    dist: LongArray,
    target: Int
): Pair<Array<MutableList<Int>>, Array<MutableList<Int>>> {
    val n = graph.size
    val forward = Array(n) { mutableListOf<Int>() }
    val backward = Array(n) { mutableListOf<Int>() }
    val stack = ArrayDeque<Int>()
    stack.offerLast(target)
    val visited = BooleanArray(n)
    while (stack.isNotEmpty()) {
        val v = stack.pollLast()
        // A vertex may be pushed several times; expand it (and record its edges) only once.
        if (visited[v]) continue
        visited[v] = true
        if (DEBUG) println("v: $v")
        for (e in graph[v]) {
            if (dist[v] == dist[e.to] + e.cost) {
                backward[v].add(e.to)
                forward[e.to].add(v)
                stack.offerLast(e.to)
            }
        }
    }
    return forward to backward
}
/**
 * Propagates path counts through a DAG, in place. Vertices are visited in the order given by
 * [tSorted]. For each vertex, its current `count` is added into the `count` of every successor
 * listed in [graph].
 *
 * For the counts to be complete, [tSorted] should be a topological order of [graph].
 *
 * @param r unused; kept so existing call sites continue to compile.
 */
@Suppress("UNUSED_PARAMETER")
fun fillCount(r: Int, graph: Array<MutableList<Int>>, tSorted: List<Int>, count: Array<FiniteField>) {
    tSorted.forEach { from ->
        graph[from].forEach { to ->
            count[to] = count[to] + count[from]
        }
    }
}
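// Compiler diagnostic captured from an earlier build (kept as a comment; not part of the program):
// Main.kt:141:15: warning: parameter 'r' is never used
// fun fillCount(r: Int, graph: Array<MutableList<Int>>, tSorted: List<Int>, count: Array<FiniteField>) {
//               ^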