diff --git a/.gitignore b/.gitignore
index 855557650f..fb6fe2a03c 100644
--- a/.gitignore
+++ b/.gitignore
@@ -159,3 +159,5 @@ gradle-app.setting
/wai2k/
.idea/modules.xml
+
+assets/models/
diff --git a/.idea/inspectionProfiles/Project_Default.xml b/.idea/inspectionProfiles/Project_Default.xml
index cfe33df0cb..bdf373599f 100644
--- a/.idea/inspectionProfiles/Project_Default.xml
+++ b/.idea/inspectionProfiles/Project_Default.xml
@@ -3,7 +3,7 @@
-
+
diff --git a/build.gradle.kts b/build.gradle.kts
index 47ed62ba03..558fd66884 100644
--- a/build.gradle.kts
+++ b/build.gradle.kts
@@ -23,6 +23,8 @@ import org.jetbrains.kotlin.gradle.tasks.KotlinCompile
import java.nio.file.Files
import java.nio.file.Paths
import java.nio.file.StandardOpenOption
+import java.security.MessageDigest
+import javax.xml.bind.DatatypeConverter
plugins {
java
@@ -60,6 +62,8 @@ dependencies {
implementation("org.controlsfx", "controlsfx", "8.40.14")
implementation("org.reflections", "reflections", "0.9.12")
implementation("com.squareup.okhttp3:okhttp:4.7.2")
+ implementation("ai.djl.pytorch:pytorch-engine:0.8.0")
+ implementation("ai.djl.pytorch:pytorch-native-auto:1.6.0")
implementation("net.sourceforge.tess4j", "tess4j", "4.5.1") {
exclude("org.ghost4j")
@@ -98,11 +102,16 @@ tasks {
dependencies {
include { it.moduleGroup.startsWith("com.waicool20") }
}
+ doLast { md5sum(archiveFile.get()) }
}
}
+task("prepare-deploy") {
+ dependsOn("build", "deps-list", "packAssets")
+}
+
task("deps-list") {
- val file = Paths.get("$projectDir/build/dependencies.txt")
+ val file = Paths.get("$buildDir/deploy/dependencies.txt")
doFirst {
if (Files.notExists(file)) {
Files.createDirectories(file.parent)
@@ -141,7 +150,28 @@ task("versioning") {
task("packAssets") {
archiveFileName.set("assets.zip")
- destinationDirectory.set(file("$buildDir"))
+ destinationDirectory.set(file("$buildDir/deploy/"))
+
+ from(projectDir)
+ include("/assets/**")
+ exclude("/assets/models/**")
+ doLast { md5sum(archiveFile.get()) }
+}
+
+task("packModels") {
+ archiveFileName.set("models.zip")
+ destinationDirectory.set(file("$buildDir/deploy/"))
+
+ from(projectDir)
+ include("/assets/models/**")
+ doLast { md5sum(archiveFile.get()) }
+}
- from("$projectDir/assets")
+fun md5sum(file: RegularFile) {
+ val path = file.asFile.toPath()
+ val md5File = Paths.get("$path.md5")
+ val md5sum = MessageDigest.getInstance("MD5")
+ .digest(Files.readAllBytes(path))
+ .let { DatatypeConverter.printHexBinary(it) }
+ Files.write(md5File, md5sum.toByteArray())
}
\ No newline at end of file
diff --git a/launcher/build.gradle.kts b/launcher/build.gradle.kts
index cbf92dcf2a..ed796284a0 100644
--- a/launcher/build.gradle.kts
+++ b/launcher/build.gradle.kts
@@ -25,6 +25,10 @@
import com.github.jengelman.gradle.plugins.shadow.tasks.ShadowJar
import org.jetbrains.kotlin.gradle.plugin.KotlinPluginWrapper
import org.jetbrains.kotlin.gradle.tasks.KotlinCompile
+import java.nio.file.Files
+import java.nio.file.Paths
+import java.security.MessageDigest
+import javax.xml.bind.DatatypeConverter
plugins {
java
@@ -68,5 +72,15 @@ tasks {
archiveClassifier.value("")
archiveVersion.value("")
exclude("kotlin/reflect/**")
+ doLast { md5sum(archiveFile.get()) }
}
}
+
+fun md5sum(file: RegularFile) {
+ val path = file.asFile.toPath()
+ val md5File = Paths.get("$path.md5")
+ val md5sum = MessageDigest.getInstance("MD5")
+ .digest(Files.readAllBytes(path))
+ .let { DatatypeConverter.printHexBinary(it) }
+ Files.write(md5File, md5sum.toByteArray())
+}
diff --git a/src/main/kotlin/com/waicool20/wai2k/Wai2K.kt b/src/main/kotlin/com/waicool20/wai2k/Wai2K.kt
index 705006ac25..4b8f56ab82 100644
--- a/src/main/kotlin/com/waicool20/wai2k/Wai2K.kt
+++ b/src/main/kotlin/com/waicool20/wai2k/Wai2K.kt
@@ -48,6 +48,9 @@ class Wai2K : App(Wai2KWorkspace::class) {
} catch (t: Throwable) {
logger.warn("Could not set locale to C, application may crash if using tesseract 4.0+")
}
+ // Set inference thread count, anything above 8 seems to be ignored
+ val cores = Runtime.getRuntime().availableProcessors().coerceAtMost(8)
+ CLib.setEnv("OMP_NUM_THREADS", cores.toString(), true)
}
private fun isRunningJar(): Boolean {
diff --git a/src/main/kotlin/com/waicool20/wai2k/script/modules/CombatSimModule.kt b/src/main/kotlin/com/waicool20/wai2k/script/modules/CombatSimModule.kt
index ffed112247..bcbbe47418 100644
--- a/src/main/kotlin/com/waicool20/wai2k/script/modules/CombatSimModule.kt
+++ b/src/main/kotlin/com/waicool20/wai2k/script/modules/CombatSimModule.kt
@@ -64,7 +64,7 @@ class CombatSimModule(
private val mapRunner = object : AbsoluteMapRunner(scriptRunner, region, config, profile) {
override val isCorpseDraggingMap = false
- override suspend fun execute() {
+ override suspend fun begin() {
if (profile.combatSimulation.neuralFragment == Level.OFF) return
if (OffsetDateTime.now(ZoneOffset.ofHours(-8)).dayOfWeek !in dataSimDays) {
diff --git a/src/main/kotlin/com/waicool20/wai2k/script/modules/combat/HomographyMapRunner.kt b/src/main/kotlin/com/waicool20/wai2k/script/modules/combat/HomographyMapRunner.kt
index d19b1a9364..3460ee3574 100644
--- a/src/main/kotlin/com/waicool20/wai2k/script/modules/combat/HomographyMapRunner.kt
+++ b/src/main/kotlin/com/waicool20/wai2k/script/modules/combat/HomographyMapRunner.kt
@@ -19,22 +19,25 @@
package com.waicool20.wai2k.script.modules.combat
-import boofcv.struct.image.GrayF32
+import ai.djl.metric.Metrics
+import ai.djl.modality.cv.Image
+import ai.djl.modality.cv.ImageFactory
import com.fasterxml.jackson.module.kotlin.jacksonObjectMapper
import com.fasterxml.jackson.module.kotlin.readValue
import com.waicool20.cvauto.android.AndroidRegion
-import com.waicool20.cvauto.util.homography
import com.waicool20.cvauto.util.transformRect
import com.waicool20.wai2k.config.Wai2KConfig
import com.waicool20.wai2k.config.Wai2KProfile
-import com.waicool20.wai2k.script.NodeNotFoundException
import com.waicool20.wai2k.script.ScriptRunner
-import com.waicool20.wai2k.util.extractNodes
+import com.waicool20.wai2k.util.ai.MatchingModel
+import com.waicool20.wai2k.util.ai.MatchingTranslator
import com.waicool20.waicoolutils.logging.loggerFor
import georegression.struct.homography.Homography2D_F64
-import kotlinx.coroutines.*
+import kotlinx.coroutines.Dispatchers
+import kotlinx.coroutines.async
+import kotlinx.coroutines.delay
+import kotlinx.coroutines.runBlocking
import java.nio.file.Files
-import javax.imageio.ImageIO
import kotlin.math.max
import kotlin.math.pow
import kotlin.random.Random
@@ -54,24 +57,30 @@ abstract class HomographyMapRunner(
private const val minScroll = 75
/**
- * Difference theresholds
+ * Difference thresholds
*/
private const val maxMapDiff = 80.0
private const val maxSideDiff = 5.0
}
+ private val metrics = Metrics()
+
+ private val predictor by lazy {
+ val model = MatchingModel(
+ config.assetsDirectory.resolve("models/SuperPoint.pt"),
+ config.assetsDirectory.resolve("models/SuperGlue.pt")
+ )
+ model.newPredictor(MatchingTranslator(480, 360)).also { it.setMetrics(metrics) }
+ }
+
/**
* Map homography cache
*/
private var mapH: Homography2D_F64? = null
- protected open val extractBlueNodes: Boolean = true
- protected open val extractWhiteNodes: Boolean = false
- protected open val extractYellowNodes: Boolean = true
-
final override val nodes: List
- val fullMap: GrayF32
+ val fullMap: Image
init {
val n = async(Dispatchers.IO) {
@@ -83,41 +92,41 @@ abstract class HomographyMapRunner(
}
}
val fm = async(Dispatchers.IO) {
- val path = config.assetsDirectory.resolve("$PREFIX/map.png")
- if (Files.exists(path)) {
- ImageIO.read(path.toFile()).extractNodes(extractBlueNodes, extractWhiteNodes, extractYellowNodes)
- } else {
- GrayF32()
- }
+ ImageFactory.getInstance().fromFile(config.assetsDirectory.resolve("$PREFIX/map.png"))
}
nodes = runBlocking { n.await() }
fullMap = runBlocking { fm.await() }
}
+ override suspend fun cleanup() {
+ mapH = null
+ }
+
override suspend fun MapNode.findRegion(): AndroidRegion {
val window = mapRunnerRegions.window
- var h: Homography2D_F64? = null
- while (h == null) {
- h = try {
- mapH
- ?: fullMap.homography(window.capture().extractNodes(extractBlueNodes, extractWhiteNodes, extractYellowNodes))
- } catch (e: IllegalStateException) {
- continue
- }
+
+ val h = mapH ?: run {
+ logger.info("Finding map transformation")
+ val prediction = predictor.predict(fullMap to ImageFactory.getInstance().fromImage(window.capture()))
+ logger.debug("Homography prediction metrics:")
+ logger.debug("Preprocess: ${metrics.latestMetric("Preprocess").value.toLong() / 1000000} ms")
+ logger.debug("Inference: ${metrics.latestMetric("Inference").value.toLong() / 1000000} ms")
+ logger.debug("Postprocess: ${metrics.latestMetric("Postprocess").value.toLong() / 1000000} ms")
+ logger.debug("Total: ${metrics.latestMetric("Total").value.toLong() / 1000000} ms")
+ mapH = prediction
+ prediction
}
suspend fun retry(): AndroidRegion {
- if (Random.nextBoolean()) {
- logger.info("Zoom out")
- region.pinch(
- Random.nextInt(500, 700),
- Random.nextInt(300, 400),
- 0.0,
- 500
- )
- delay(1000)
- }
+ logger.info("Zoom out")
+ region.pinch(
+ Random.nextInt(500, 700),
+ Random.nextInt(300, 400),
+ 0.0,
+ 500
+ )
+ delay(1000)
mapH = null
return findRegion()
}
@@ -192,28 +201,6 @@ abstract class HomographyMapRunner(
delay(200)
return findRegion()
}
-
- while (isActive) {
- val targets = mutableListOf>()
- val img = roi.capture().extractNodes()
- for (y in 0 until img.height) {
- var index = img.startIndex + y * img.stride
- for (x in 0 until img.width) {
- if (img.data[index++] >= 175) targets += x to y
- }
- }
- yield()
- if (targets.isEmpty()) {
- logger.debug("No targets found, retry")
- if (Random.nextBoolean()) continue else return retry()
- }
- logger.debug("${targets.size} target candidates for node $this")
- val target = targets.random().let { (cX, cY) ->
- region.subRegionAs(roi.x + cX, roi.y + cY, 5, 5)
- }
- logger.debug("Node target: (x=${target.x},y=${target.y})")
- return target
- }
- throw NodeNotFoundException(this)
+ return roi
}
}
\ No newline at end of file
diff --git a/src/main/kotlin/com/waicool20/wai2k/script/modules/combat/MapRunner.kt b/src/main/kotlin/com/waicool20/wai2k/script/modules/combat/MapRunner.kt
index 91aee6b4ea..ee7223dd9b 100644
--- a/src/main/kotlin/com/waicool20/wai2k/script/modules/combat/MapRunner.kt
+++ b/src/main/kotlin/com/waicool20/wai2k/script/modules/combat/MapRunner.kt
@@ -116,9 +116,22 @@ abstract class MapRunner(
abstract val isCorpseDraggingMap: Boolean
/**
- * Main execution function that is executed when map is entered
+ * Main run function that goes through whole life cycle of MapRunner
*/
- abstract suspend fun execute()
+ suspend fun execute() {
+ begin()
+ cleanup()
+ }
+
+ /**
+ * Function that is executed when map is entered
+ */
+ abstract suspend fun begin()
+
+ /**
+ * Cleanup function run after execute()
+ */
+ open suspend fun cleanup() = Unit
/**
* Executes when entering a battle
diff --git a/src/main/kotlin/com/waicool20/wai2k/script/modules/combat/maps/Map0_2.kt b/src/main/kotlin/com/waicool20/wai2k/script/modules/combat/maps/Map0_2.kt
index 818a945e24..71c7e08fc4 100644
--- a/src/main/kotlin/com/waicool20/wai2k/script/modules/combat/maps/Map0_2.kt
+++ b/src/main/kotlin/com/waicool20/wai2k/script/modules/combat/maps/Map0_2.kt
@@ -25,7 +25,10 @@ import com.waicool20.wai2k.config.Wai2KProfile
import com.waicool20.wai2k.script.ScriptRunner
import com.waicool20.wai2k.script.modules.combat.HomographyMapRunner
import com.waicool20.waicoolutils.logging.loggerFor
+import kotlinx.coroutines.delay
import kotlinx.coroutines.yield
+import kotlin.math.roundToLong
+import kotlin.random.Random
class Map0_2(
scriptRunner: ScriptRunner,
@@ -36,8 +39,17 @@ class Map0_2(
private val logger = loggerFor()
override val isCorpseDraggingMap = true
- override suspend fun execute() {
- nodes[2].findRegion() // Try focus boss node to get map centered
+ override suspend fun begin() {
+ if (gameState.requiresMapInit) {
+ logger.info("Zoom out")
+ region.pinch(
+ Random.nextInt(900, 1000),
+ Random.nextInt(300, 400),
+ 0.0,
+ 1000)
+ delay((900 * gameState.delayCoefficient).roundToLong()) //Wait to settle
+ gameState.requiresMapInit = false
+ }
val rEchelons = deployEchelons(nodes[14], nodes[13])
mapRunnerRegions.startOperation.click(); yield()
waitForGNKSplash()
diff --git a/src/main/kotlin/com/waicool20/wai2k/script/modules/combat/maps/Map0_2_EX.kt b/src/main/kotlin/com/waicool20/wai2k/script/modules/combat/maps/Map0_2_EX.kt
index 6b58b58a99..7daca659e9 100644
--- a/src/main/kotlin/com/waicool20/wai2k/script/modules/combat/maps/Map0_2_EX.kt
+++ b/src/main/kotlin/com/waicool20/wai2k/script/modules/combat/maps/Map0_2_EX.kt
@@ -40,7 +40,7 @@ class Map0_2_EX(
private val logger = loggerFor()
override val isCorpseDraggingMap = true
- override suspend fun execute() {
+ override suspend fun begin() {
if (gameState.requiresMapInit) {
// Check to see if its already good
diff --git a/src/main/kotlin/com/waicool20/wai2k/script/modules/combat/maps/Map10_4E.kt b/src/main/kotlin/com/waicool20/wai2k/script/modules/combat/maps/Map10_4E.kt
index b5c3f2f0d6..7cbdbc1f07 100644
--- a/src/main/kotlin/com/waicool20/wai2k/script/modules/combat/maps/Map10_4E.kt
+++ b/src/main/kotlin/com/waicool20/wai2k/script/modules/combat/maps/Map10_4E.kt
@@ -38,7 +38,7 @@ class Map10_4E(
private val logger = loggerFor()
override val isCorpseDraggingMap = false
- override suspend fun execute() {
+ override suspend fun begin() {
if (gameState.requiresMapInit) {
logger.info("Zoom out")
repeat(2) {
diff --git a/src/main/kotlin/com/waicool20/wai2k/script/modules/combat/maps/Map10_4E_Drag.kt b/src/main/kotlin/com/waicool20/wai2k/script/modules/combat/maps/Map10_4E_Drag.kt
index 314874cb91..c31da4f5e0 100644
--- a/src/main/kotlin/com/waicool20/wai2k/script/modules/combat/maps/Map10_4E_Drag.kt
+++ b/src/main/kotlin/com/waicool20/wai2k/script/modules/combat/maps/Map10_4E_Drag.kt
@@ -43,7 +43,7 @@ class Map10_4E_Drag(
private val logger = loggerFor()
override val isCorpseDraggingMap = true
- override suspend fun execute() {
+ override suspend fun begin() {
// Mostly empty region to the left
val r = region.subRegionAs(300, 500, 150, 8)
diff --git a/src/main/kotlin/com/waicool20/wai2k/script/modules/combat/maps/Map11_5.kt b/src/main/kotlin/com/waicool20/wai2k/script/modules/combat/maps/Map11_5.kt
index 183c5afe70..e96d2193d3 100644
--- a/src/main/kotlin/com/waicool20/wai2k/script/modules/combat/maps/Map11_5.kt
+++ b/src/main/kotlin/com/waicool20/wai2k/script/modules/combat/maps/Map11_5.kt
@@ -38,7 +38,7 @@ class Map11_5(
private val logger = loggerFor()
override val isCorpseDraggingMap = true
- override suspend fun execute() {
+ override suspend fun begin() {
// No need to zoom, delay for map lag
delay((1000 * gameState.delayCoefficient).roundToLong())
val rEchelons = deployEchelons(nodes[1], nodes[0])
diff --git a/src/main/kotlin/com/waicool20/wai2k/script/modules/combat/maps/Map1_6.kt b/src/main/kotlin/com/waicool20/wai2k/script/modules/combat/maps/Map1_6.kt
index a40207e526..ac1a26be4e 100644
--- a/src/main/kotlin/com/waicool20/wai2k/script/modules/combat/maps/Map1_6.kt
+++ b/src/main/kotlin/com/waicool20/wai2k/script/modules/combat/maps/Map1_6.kt
@@ -40,7 +40,7 @@ class Map1_6(
private val logger = loggerFor()
override val isCorpseDraggingMap = false
- override suspend fun execute() {
+ override suspend fun begin() {
logger.info("Zoom out")
region.pinch(
Random.nextInt(700, 800),
diff --git a/src/main/kotlin/com/waicool20/wai2k/script/modules/combat/maps/Map2_6.kt b/src/main/kotlin/com/waicool20/wai2k/script/modules/combat/maps/Map2_6.kt
index e044b71840..143e31e296 100644
--- a/src/main/kotlin/com/waicool20/wai2k/script/modules/combat/maps/Map2_6.kt
+++ b/src/main/kotlin/com/waicool20/wai2k/script/modules/combat/maps/Map2_6.kt
@@ -39,7 +39,7 @@ class Map2_6(
private val logger = loggerFor()
override val isCorpseDraggingMap = false
- override suspend fun execute() {
+ override suspend fun begin() {
logger.info("Zoom out")
region.pinch(
Random.nextInt(700, 800),
diff --git a/src/main/kotlin/com/waicool20/wai2k/script/modules/combat/maps/Map3_4E.kt b/src/main/kotlin/com/waicool20/wai2k/script/modules/combat/maps/Map3_4E.kt
index b5a25169a1..981c859c2d 100644
--- a/src/main/kotlin/com/waicool20/wai2k/script/modules/combat/maps/Map3_4E.kt
+++ b/src/main/kotlin/com/waicool20/wai2k/script/modules/combat/maps/Map3_4E.kt
@@ -38,7 +38,7 @@ class Map3_4E(
private val logger = loggerFor()
override val isCorpseDraggingMap = false
- override suspend fun execute() {
+ override suspend fun begin() {
logger.info("Zoom out")
region.pinch(
Random.nextInt(700, 800),
diff --git a/src/main/kotlin/com/waicool20/wai2k/script/modules/combat/maps/Map3_6.kt b/src/main/kotlin/com/waicool20/wai2k/script/modules/combat/maps/Map3_6.kt
index fd5841f4e9..de97e10ee5 100644
--- a/src/main/kotlin/com/waicool20/wai2k/script/modules/combat/maps/Map3_6.kt
+++ b/src/main/kotlin/com/waicool20/wai2k/script/modules/combat/maps/Map3_6.kt
@@ -39,7 +39,7 @@ class Map3_6(
private val logger = loggerFor()
override val isCorpseDraggingMap = false
- override suspend fun execute() {
+ override suspend fun begin() {
logger.info("Zoom out")
region.pinch(
Random.nextInt(700, 800),
diff --git a/src/main/kotlin/com/waicool20/wai2k/script/modules/combat/maps/Map4_6.kt b/src/main/kotlin/com/waicool20/wai2k/script/modules/combat/maps/Map4_6.kt
index e76a6bf63d..c9a31cede8 100644
--- a/src/main/kotlin/com/waicool20/wai2k/script/modules/combat/maps/Map4_6.kt
+++ b/src/main/kotlin/com/waicool20/wai2k/script/modules/combat/maps/Map4_6.kt
@@ -38,7 +38,7 @@ class Map4_6(
private val logger = loggerFor()
override val isCorpseDraggingMap = false
- override suspend fun execute() {
+ override suspend fun begin() {
if (gameState.requiresMapInit) {
logger.info("Zoom out")
repeat(2) {
diff --git a/src/main/kotlin/com/waicool20/wai2k/script/modules/combat/maps/Map4_6_Data.kt b/src/main/kotlin/com/waicool20/wai2k/script/modules/combat/maps/Map4_6_Data.kt
index d102e45197..8baf782ec7 100644
--- a/src/main/kotlin/com/waicool20/wai2k/script/modules/combat/maps/Map4_6_Data.kt
+++ b/src/main/kotlin/com/waicool20/wai2k/script/modules/combat/maps/Map4_6_Data.kt
@@ -42,7 +42,7 @@ class Map4_6_Data(
//Allow interruption of waiting for turn if necessary
private var combatComplete = false
- override suspend fun execute() {
+ override suspend fun begin() {
if (gameState.requiresMapInit) {
logger.info("Zoom out")
repeat(2) {
diff --git a/src/main/kotlin/com/waicool20/wai2k/script/modules/combat/maps/Map5_6.kt b/src/main/kotlin/com/waicool20/wai2k/script/modules/combat/maps/Map5_6.kt
index c4fba3c437..bdea64ea90 100644
--- a/src/main/kotlin/com/waicool20/wai2k/script/modules/combat/maps/Map5_6.kt
+++ b/src/main/kotlin/com/waicool20/wai2k/script/modules/combat/maps/Map5_6.kt
@@ -38,7 +38,7 @@ class Map5_6(
private val logger = loggerFor()
override val isCorpseDraggingMap = false
- override suspend fun execute() {
+ override suspend fun begin() {
if (gameState.requiresMapInit) {
logger.info("Zoom out")
repeat(2) {
diff --git a/src/main/kotlin/com/waicool20/wai2k/script/modules/combat/maps/Map6_4N.kt b/src/main/kotlin/com/waicool20/wai2k/script/modules/combat/maps/Map6_4N.kt
index 685ffacd0a..8966f1102a 100644
--- a/src/main/kotlin/com/waicool20/wai2k/script/modules/combat/maps/Map6_4N.kt
+++ b/src/main/kotlin/com/waicool20/wai2k/script/modules/combat/maps/Map6_4N.kt
@@ -38,7 +38,7 @@ class Map6_4N(
private val logger = loggerFor()
override val isCorpseDraggingMap = false
- override suspend fun execute() {
+ override suspend fun begin() {
if (gameState.requiresMapInit) {
logger.info("Zoom out")
zoom()
diff --git a/src/main/kotlin/com/waicool20/wai2k/script/modules/combat/maps/Map6_6.kt b/src/main/kotlin/com/waicool20/wai2k/script/modules/combat/maps/Map6_6.kt
index fde6014794..a6ab9992a1 100644
--- a/src/main/kotlin/com/waicool20/wai2k/script/modules/combat/maps/Map6_6.kt
+++ b/src/main/kotlin/com/waicool20/wai2k/script/modules/combat/maps/Map6_6.kt
@@ -38,7 +38,7 @@ class Map6_6(
private val logger = loggerFor()
override val isCorpseDraggingMap = false
- override suspend fun execute() {
+ override suspend fun begin() {
if (gameState.requiresMapInit) {
logger.info("Zoom out")
repeat(2) {
diff --git a/src/main/kotlin/com/waicool20/wai2k/script/modules/combat/maps/Map8_1N.kt b/src/main/kotlin/com/waicool20/wai2k/script/modules/combat/maps/Map8_1N.kt
index 3b3ccd424d..226b1a0a27 100644
--- a/src/main/kotlin/com/waicool20/wai2k/script/modules/combat/maps/Map8_1N.kt
+++ b/src/main/kotlin/com/waicool20/wai2k/script/modules/combat/maps/Map8_1N.kt
@@ -38,7 +38,7 @@ class Map8_1N(
private val logger = loggerFor()
override val isCorpseDraggingMap = true
- override suspend fun execute() {
+ override suspend fun begin() {
if (gameState.requiresMapInit) {
logger.info("Zoom out")
repeat(2) {
diff --git a/src/main/kotlin/com/waicool20/wai2k/util/ImageUtils.kt b/src/main/kotlin/com/waicool20/wai2k/util/ImageUtils.kt
index d085799895..b3d0195d74 100644
--- a/src/main/kotlin/com/waicool20/wai2k/util/ImageUtils.kt
+++ b/src/main/kotlin/com/waicool20/wai2k/util/ImageUtils.kt
@@ -49,13 +49,17 @@ fun BufferedImage.extractNodes(
val yellowNodes = hsv.clone().apply { hsvFilter(hueRange = 40..50, satRange = 12..100) }.getBand(2)
img += yellowNodes
}
+
+ return img.binarizeImage(0.75)
+}
+
+fun GrayF32.binarizeImage(threshold: Double = 0.4): GrayF32 {
for (y in 0 until height) {
- var index = img.startIndex + y * img.stride
+ var index = startIndex + y * stride
for (x in 0 until width) {
- img.data[index] = if (img.data[index] >= 175) 255f else 0f
+ data[index] = if (data[index] >= 255 * threshold) 255f else 0f
index++
}
}
-
- return img
+ return this
}
\ No newline at end of file
diff --git a/src/main/kotlin/com/waicool20/wai2k/util/ai/GFLObject.kt b/src/main/kotlin/com/waicool20/wai2k/util/ai/GFLObject.kt
new file mode 100644
index 0000000000..cb7b77ad61
--- /dev/null
+++ b/src/main/kotlin/com/waicool20/wai2k/util/ai/GFLObject.kt
@@ -0,0 +1,73 @@
+/*
+ * GPLv3 License
+ *
+ * Copyright (c) WAI2K by waicool20
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the GNU General Public License as published by
+ * the Free Software Foundation, either version 3 of the License, or
+ * (at your option) any later version.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * GNU General Public License for more details.
+ *
+ * You should have received a copy of the GNU General Public License
+ * along with this program. If not, see .
+ */
+
+package com.waicool20.wai2k.util.ai
+
+import ai.djl.modality.cv.output.DetectedObjects
+import ai.djl.modality.cv.output.Rectangle
+
+sealed class GFLObject {
+ abstract val probability: Double
+ abstract val bbox: Rectangle
+
+ open class Node(override val probability: Double, override val bbox: Rectangle) : GFLObject()
+ class CommandPost(override val probability: Double, override val bbox: Rectangle) : Node(probability, bbox)
+ class Heliport(override val probability: Double, override val bbox: Rectangle) : Node(probability, bbox)
+ class SupplyCrate(override val probability: Double, override val bbox: Rectangle) : Node(probability, bbox)
+ class Radar(override val probability: Double, override val bbox: Rectangle) : Node(probability, bbox)
+
+ abstract class Unit : GFLObject()
+
+ open class Enemy(override val probability: Double, override val bbox: Rectangle) : Unit()
+ class SangvisFerri(override val probability: Double, override val bbox: Rectangle) : Enemy(probability, bbox)
+ class Military(override val probability: Double, override val bbox: Rectangle) : Enemy(probability, bbox)
+ class Paradeus(override val probability: Double, override val bbox: Rectangle) : Enemy(probability, bbox)
+
+ open class Friendly(override val probability: Double, override val bbox: Rectangle) : Unit()
+
+ companion object {
+ /**
+ * All possible values of [GFLObject], must be declared in the order the model was created with
+ */
+ val values = listOf(
+ Node::class,
+ CommandPost::class,
+ Heliport::class,
+ Enemy::class,
+ SangvisFerri::class,
+ Military::class,
+ Radar::class,
+ Paradeus::class,
+ SupplyCrate::class,
+ Friendly::class
+ )
+ }
+
+ override fun toString(): String {
+ return this::class.simpleName?.replace(Regex("(.)([A-Z])"), "$1 $2") ?: "Unknown"
+ }
+}
+
+fun List.toDetectedObjects(): DetectedObjects {
+ return DetectedObjects(
+ map { it.toString() },
+ map { it.probability },
+ map { it.bbox }
+ )
+}
\ No newline at end of file
diff --git a/src/main/kotlin/com/waicool20/wai2k/util/ai/MatchingModel.kt b/src/main/kotlin/com/waicool20/wai2k/util/ai/MatchingModel.kt
new file mode 100644
index 0000000000..533eed1c9c
--- /dev/null
+++ b/src/main/kotlin/com/waicool20/wai2k/util/ai/MatchingModel.kt
@@ -0,0 +1,73 @@
+/*
+ * GPLv3 License
+ *
+ * Copyright (c) WAI2K by waicool20
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the GNU General Public License as published by
+ * the Free Software Foundation, either version 3 of the License, or
+ * (at your option) any later version.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * GNU General Public License for more details.
+ *
+ * You should have received a copy of the GNU General Public License
+ * along with this program. If not, see .
+ */
+
+package com.waicool20.wai2k.util.ai
+
+import ai.djl.Device
+import ai.djl.Model
+import ai.djl.ndarray.NDList
+import ai.djl.ndarray.NDManager
+import ai.djl.ndarray.types.Shape
+import ai.djl.nn.AbstractBlock
+import ai.djl.pytorch.engine.PtEngine
+import ai.djl.training.ParameterStore
+import ai.djl.util.PairList
+import java.nio.file.Path
+
+class MatchingModel(
+ superpointModelPath: Path,
+ superglueModelPath: Path
+) : Model by engine.newModel("MatchingModel", Device.defaultDevice()) {
+ companion object {
+ private val engine by lazy { PtEngine.getInstance() }
+ }
+
+ private val superpoint = ModelLoader.loadModel(superpointModelPath)
+ private val superglue = ModelLoader.loadModel(superglueModelPath)
+
+ init {
+ block = object : AbstractBlock(2) {
+ override fun forward(parameterStore: ParameterStore, inputs: NDList, training: Boolean, params: PairList?): NDList {
+ val (img0, img1) = inputs
+ val pred0 = superpoint.block.forward(parameterStore, NDList(img0), training, params)
+ val pred1 = superpoint.block.forward(parameterStore, NDList(img1), training, params)
+ val data = NDList().apply {
+ add(img0)
+ addAll(pred0)
+ add(img1)
+ addAll(pred1)
+ }
+ return NDList().apply {
+ addAll(pred0)
+ addAll(pred1)
+ addAll(superglue.block.forward(parameterStore, data, training, params))
+ }
+ }
+
+ override fun getOutputShapes(manager: NDManager, inputShapes: Array): Array {
+ return emptyArray()
+ }
+ }
+ }
+
+ override fun close() {
+ superpoint.close()
+ superglue.close()
+ }
+}
diff --git a/src/main/kotlin/com/waicool20/wai2k/util/ai/MatchingTranslator.kt b/src/main/kotlin/com/waicool20/wai2k/util/ai/MatchingTranslator.kt
new file mode 100644
index 0000000000..cc58d69df8
--- /dev/null
+++ b/src/main/kotlin/com/waicool20/wai2k/util/ai/MatchingTranslator.kt
@@ -0,0 +1,109 @@
+/*
+ * GPLv3 License
+ *
+ * Copyright (c) WAI2K by waicool20
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the GNU General Public License as published by
+ * the Free Software Foundation, either version 3 of the License, or
+ * (at your option) any later version.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * GNU General Public License for more details.
+ *
+ * You should have received a copy of the GNU General Public License
+ * along with this program. If not, see .
+ */
+
+package com.waicool20.wai2k.util.ai
+
+import ai.djl.modality.cv.Image
+import ai.djl.modality.cv.transform.Resize
+import ai.djl.ndarray.NDList
+import ai.djl.ndarray.index.NDIndex
+import ai.djl.translate.Batchifier
+import ai.djl.translate.Pipeline
+import ai.djl.translate.Translator
+import ai.djl.translate.TranslatorContext
+import boofcv.struct.geo.AssociatedPair
+import com.waicool20.cvauto.util.wrapper.Config
+import com.waicool20.cvauto.util.wrapper.KFactoryMultiViewRobust
+import georegression.struct.homography.Homography2D_F64
+import georegression.struct.point.Point2D_F64
+
+class MatchingTranslator(
+ private val resizeWidth: Int = 640,
+ private val resizeHeight: Int = 480
+) : Translator, Homography2D_F64> {
+
+ private var img0Width = -1
+ private var img0Height = -1
+ private var img1Width = -1
+ private var img1Height = -1
+
+ override fun getBatchifier() = Batchifier.STACK
+
+ override fun getPipeline() = Pipeline(Resize(resizeWidth, resizeHeight), TransposeNormalizeTransform())
+
+ override fun processInput(ctx: TranslatorContext, input: Pair): NDList {
+ val (img0, img1) = input
+ img0Width = img0.width
+ img0Height = img0.height
+ img1Width = img1.width
+ img1Height = img1.height
+
+ val img0array = img0.toNDArray(ctx.ndManager, Image.Flag.GRAYSCALE)
+ val img1array = img1.toNDArray(ctx.ndManager, Image.Flag.GRAYSCALE)
+ return NDList().apply {
+ addAll(pipeline.transform(NDList(img0array)))
+ addAll(pipeline.transform(NDList(img1array)))
+ }
+ }
+
+ override fun processOutput(ctx: TranslatorContext, list: NDList): Homography2D_F64 {
+ var kpts0 = list[0]
+ val kpts1Temp = list[3]
+ var matches = list[6]
+ var conf = list[8]
+
+ val keepMask = matches.gt(-1)
+
+ matches = matches.booleanMask(keepMask)
+ conf = conf.booleanMask(keepMask)
+
+ kpts0 = kpts0.booleanMask(keepMask.stack(keepMask, 1), 1)
+ kpts0 = kpts0.reshape(kpts0.shape[0] / 2, 2)
+
+ var kpts1 = kpts0.zerosLike()
+
+ for ((i, j) in matches.toLongArray().withIndex()) {
+ kpts1.set(NDIndex(i.toLong()), kpts1Temp.get(j))
+ }
+
+ val scales0 = ctx.ndManager.create(floatArrayOf(img0Width.toFloat() / resizeWidth, img0Height.toFloat() / resizeHeight))
+ val scales1 = ctx.ndManager.create(floatArrayOf(img1Width.toFloat() / resizeWidth, img1Height.toFloat() / resizeHeight))
+
+ kpts0 = kpts0 * scales0
+ kpts1 = kpts1 * scales1
+
+ val points1 = kpts0.toPoint2D().map { Point2D_F64(it.x, it.y) }
+ val points2 = kpts1.toPoint2D().map { Point2D_F64(it.x, it.y) }
+ val confArr = conf.toFloatArray()
+
+ val pairs = mutableListOf()
+ for (i in points1.indices) {
+ if (confArr[i] > 0.45) {
+ pairs.add(AssociatedPair(points1[i], points2[i], false))
+ }
+ }
+
+ val modelMatcher = KFactoryMultiViewRobust.homographyRansac(
+ Config.Ransac(60, 3.0)
+ )
+
+ check(modelMatcher.process(pairs)) { "Model matching failed!" }
+ return modelMatcher.modelParameters.copy()
+ }
+}
\ No newline at end of file
diff --git a/src/main/kotlin/com/waicool20/wai2k/util/ai/ModelLoader.kt b/src/main/kotlin/com/waicool20/wai2k/util/ai/ModelLoader.kt
new file mode 100644
index 0000000000..49929e3b2b
--- /dev/null
+++ b/src/main/kotlin/com/waicool20/wai2k/util/ai/ModelLoader.kt
@@ -0,0 +1,40 @@
+/*
+ * GPLv3 License
+ *
+ * Copyright (c) WAI2K by waicool20
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the GNU General Public License as published by
+ * the Free Software Foundation, either version 3 of the License, or
+ * (at your option) any later version.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * GNU General Public License for more details.
+ *
+ * You should have received a copy of the GNU General Public License
+ * along with this program. If not, see .
+ */
+
+package com.waicool20.wai2k.util.ai
+
+import ai.djl.Device
+import ai.djl.Model
+import ai.djl.pytorch.engine.PtEngine
+import java.nio.file.Files
+import java.nio.file.Path
+
+object ModelLoader {
+ val engine by lazy { PtEngine.getInstance() }
+
+ fun loadModel(path: Path): Model {
+ require(Files.isRegularFile(path)) { "Must be path to model file" }
+ require(Files.exists(path)) { "Model file does not exist" }
+ require("$path".endsWith(".pt")) { "Model must have .pt extension" }
+ val name = "${path.fileName}".dropLastWhile { it != '.' }.dropLast(1)
+ return engine.newModel(name, Device.cpu()).apply {
+ load(path.parent, name)
+ }
+ }
+}
\ No newline at end of file
diff --git a/src/main/kotlin/com/waicool20/wai2k/util/ai/NDArrayExtensions.kt b/src/main/kotlin/com/waicool20/wai2k/util/ai/NDArrayExtensions.kt
new file mode 100644
index 0000000000..9155cd9c94
--- /dev/null
+++ b/src/main/kotlin/com/waicool20/wai2k/util/ai/NDArrayExtensions.kt
@@ -0,0 +1,101 @@
+/*
+ * GPLv3 License
+ *
+ * Copyright (c) WAI2K by waicool20
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the GNU General Public License as published by
+ * the Free Software Foundation, either version 3 of the License, or
+ * (at your option) any later version.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * GNU General Public License for more details.
+ *
+ * You should have received a copy of the GNU General Public License
+ * along with this program. If not, see .
+ */
+
+package com.waicool20.wai2k.util.ai
+
+import ai.djl.modality.cv.output.Point
+import ai.djl.ndarray.NDArray
+import ai.djl.ndarray.NDManager
+import ai.djl.ndarray.types.DataType
+
+operator fun NDArray.plus(array: NDArray): NDArray = add(array)
+operator fun NDArray.plus(number: Number): NDArray = add(number)
+operator fun NDArray.minus(array: NDArray): NDArray = sub(array)
+operator fun NDArray.minus(number: Number): NDArray = sub(number)
+operator fun NDArray.times(array: NDArray): NDArray = mul(array)
+operator fun NDArray.times(number: Number): NDArray = mul(number)
+
+operator fun NDArray.unaryMinus(): NDArray = neg()
+
+operator fun Number.plus(array: NDArray): NDArray = array + this
+operator fun Number.minus(array: NDArray): NDArray = array.neg() + this
+operator fun Number.times(array: NDArray): NDArray = array * this
+operator fun Number.div(array: NDArray): NDArray = array.pow(-1) * this
+
+operator fun NDArray.plusAssign(array: NDArray) {
+ addi(array)
+}
+
+operator fun NDArray.plusAssign(number: Number) {
+ addi(number)
+}
+
+operator fun NDArray.minusAssign(array: NDArray) {
+ subi(array)
+}
+
+operator fun NDArray.minusAssign(number: Number) {
+ subi(number)
+}
+
+operator fun NDArray.timesAssign(array: NDArray) {
+ muli(array)
+}
+
+operator fun NDArray.timesAssign(number: Number) {
+ muli(number)
+}
+
+operator fun NDArray.divAssign(array: NDArray) {
+ divi(array)
+}
+
+operator fun NDArray.divAssign(number: Number) {
+ divi(number)
+}
+
+fun NDArray.sum(vararg axes: Int) = sum(axes)
+
+fun NDArray.diag() = manager.eye(shape[0].toInt()) * tile(shape[0]).reshape(shape[0], shape[0])
+
+fun NDArray._trace() = (manager.eye(shape.dimension()) * this).sum()
+fun NDArray._stack(repeats: Long) = tile(repeats).reshape(repeats, shape[0])
+
+fun NDArray.toPoint2D(): List {
+ return when (dataType) {
+ DataType.FLOAT32 -> {
+ val array = toFloatArray()
+ List((size() / 2).toInt()) { i -> Point(array[i * 2].toDouble(), array[i * 2 + 1].toDouble()) }
+ }
+ DataType.FLOAT64 -> {
+ val array = toDoubleArray()
+ List((size() / 2).toInt()) { i -> Point(array[i * 2], array[i * 2 + 1]) }
+ }
+ else -> error("Unsupported data type: $dataType")
+ }
+}
+
+fun List.toNDArray(): NDArray {
+ val array = DoubleArray(size * 2)
+ for (i in indices) {
+ array[i * 2] = get(i).x
+ array[i * 2 + 1] = get(i).y
+ }
+ return NDManager.newBaseManager().create(array).reshape(size.toLong(), 2)
+}
\ No newline at end of file
diff --git a/src/main/kotlin/com/waicool20/wai2k/util/ai/PredictorExtensions.kt b/src/main/kotlin/com/waicool20/wai2k/util/ai/PredictorExtensions.kt
new file mode 100644
index 0000000000..370c1aa235
--- /dev/null
+++ b/src/main/kotlin/com/waicool20/wai2k/util/ai/PredictorExtensions.kt
@@ -0,0 +1,34 @@
+/*
+ * GPLv3 License
+ *
+ * Copyright (c) WAI2K by waicool20
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the GNU General Public License as published by
+ * the Free Software Foundation, either version 3 of the License, or
+ * (at your option) any later version.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * GNU General Public License for more details.
+ *
+ * You should have received a copy of the GNU General Public License
+ * along with this program. If not, see .
+ */
+
+package com.waicool20.wai2k.util.ai
+
+import ai.djl.inference.Predictor
+import ai.djl.modality.cv.Image
+import ai.djl.modality.cv.ImageFactory
+import com.waicool20.cvauto.core.Region
+import java.awt.image.BufferedImage
+
+fun Predictor.predict(image: BufferedImage): T {
+ return predict(ImageFactory.getInstance().fromImage(image))
+}
+
+fun Predictor.predict(region: Region<*>): T {
+ return predict(ImageFactory.getInstance().fromImage(region.capture()))
+}
\ No newline at end of file
diff --git a/src/main/kotlin/com/waicool20/wai2k/util/ai/ShapeExtensions.kt b/src/main/kotlin/com/waicool20/wai2k/util/ai/ShapeExtensions.kt
new file mode 100644
index 0000000000..7ee7dd2d66
--- /dev/null
+++ b/src/main/kotlin/com/waicool20/wai2k/util/ai/ShapeExtensions.kt
@@ -0,0 +1,33 @@
+/*
+ * GPLv3 License
+ *
+ * Copyright (c) WAI2K by waicool20
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the GNU General Public License as published by
+ * the Free Software Foundation, either version 3 of the License, or
+ * (at your option) any later version.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * GNU General Public License for more details.
+ *
+ * You should have received a copy of the GNU General Public License
+ * along with this program. If not, see .
+ */
+
+package com.waicool20.wai2k.util.ai
+
+import ai.djl.ndarray.types.Shape
+
+operator fun Shape.component1() = get(0)
+operator fun Shape.component2() = get(1)
+operator fun Shape.component3() = get(2)
+operator fun Shape.component4() = get(3)
+operator fun Shape.component5() = get(4)
+operator fun Shape.component6() = get(5)
+operator fun Shape.component7() = get(6)
+operator fun Shape.component8() = get(7)
+operator fun Shape.component9() = get(8)
+operator fun Shape.component10() = get(9)
diff --git a/src/main/kotlin/com/waicool20/wai2k/util/ai/TransposeNormalizeTransform.kt b/src/main/kotlin/com/waicool20/wai2k/util/ai/TransposeNormalizeTransform.kt
new file mode 100644
index 0000000000..326b5a609f
--- /dev/null
+++ b/src/main/kotlin/com/waicool20/wai2k/util/ai/TransposeNormalizeTransform.kt
@@ -0,0 +1,29 @@
+/*
+ * GPLv3 License
+ *
+ * Copyright (c) WAI2K by waicool20
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the GNU General Public License as published by
+ * the Free Software Foundation, either version 3 of the License, or
+ * (at your option) any later version.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * GNU General Public License for more details.
+ *
+ * You should have received a copy of the GNU General Public License
+ * along with this program. If not, see .
+ */
+
+package com.waicool20.wai2k.util.ai
+
+import ai.djl.ndarray.NDArray
+import ai.djl.translate.Transform
+
+class TransposeNormalizeTransform: Transform {
+ override fun transform(array: NDArray): NDArray {
+ return array.transpose(2, 0, 1).div(255f)
+ }
+}
\ No newline at end of file
diff --git a/src/main/kotlin/com/waicool20/wai2k/util/ai/YoloTranslator.kt b/src/main/kotlin/com/waicool20/wai2k/util/ai/YoloTranslator.kt
new file mode 100644
index 0000000000..e4d779840e
--- /dev/null
+++ b/src/main/kotlin/com/waicool20/wai2k/util/ai/YoloTranslator.kt
@@ -0,0 +1,170 @@
+/*
+ * GPLv3 License
+ *
+ * Copyright (c) WAI2K by waicool20
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the GNU General Public License as published by
+ * the Free Software Foundation, either version 3 of the License, or
+ * (at your option) any later version.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * GNU General Public License for more details.
+ *
+ * You should have received a copy of the GNU General Public License
+ * along with this program. If not, see .
+ */
+
+package com.waicool20.wai2k.util.ai
+
+import ai.djl.Model
+import ai.djl.modality.cv.Image
+import ai.djl.modality.cv.ImageFactory
+import ai.djl.modality.cv.output.Rectangle
+import ai.djl.modality.cv.translator.BaseImageTranslator
+import ai.djl.modality.cv.util.NDImageUtils
+import ai.djl.ndarray.NDArray
+import ai.djl.ndarray.NDList
+import ai.djl.ndarray.index.NDIndex
+import ai.djl.ndarray.types.Shape
+import ai.djl.translate.Pipeline
+import ai.djl.translate.Transform
+import ai.djl.translate.TranslatorContext
+import com.waicool20.waicoolutils.createCompatibleCopy
+import com.waicool20.waicoolutils.pad
+import java.awt.Color
+import java.awt.Graphics2D
+import java.awt.RenderingHints
+import java.awt.image.BufferedImage
+import java.awt.image.ImageObserver
+import java.util.concurrent.CountDownLatch
+import kotlin.math.roundToInt
+import kotlin.reflect.full.isSubclassOf
+import kotlin.reflect.full.isSuperclassOf
+import kotlin.reflect.full.primaryConstructor
+
+class YoloTranslator(
+ model: Model,
+ val threshold: Double,
+ val iouThreshold: Double = 0.4
+) : BaseImageTranslator>(
+ Builder().setPipeline(Pipeline(TransposeNormalizeTransform()))
+) {
+ private val size = model.getProperty("InputSize")?.toInt()
+ ?: error("Model property 'InputSize' must be set")
+
+ private class Builder : BaseImageTranslator.BaseBuilder() {
+ override fun self() = this
+ }
+
+ private var imageWidth = -1.0
+ private var imageHeight = -1.0
+
+ override fun processInput(ctx: TranslatorContext, input: Image): NDList {
+ imageWidth = input.width.toDouble()
+ imageHeight = input.height.toDouble()
+
+ val inputImage= input.wrappedImage as BufferedImage
+ val networkInput = inputImage.createCompatibleCopy(size, size)
+ val g = (networkInput.graphics as Graphics2D).apply {
+ setRenderingHint(RenderingHints.KEY_INTERPOLATION, RenderingHints.VALUE_INTERPOLATION_BILINEAR)
+ paint = Color.BLACK
+ fillRect(0, 0, size, size)
+ }
+
+ if (input.width < size && input.height < size) {
+ g.drawImage(inputImage, 0, 0, null)
+ } else {
+ val width = input.width.toDouble()
+ val height = input.height.toDouble()
+ when {
+ width > height -> {
+ val newHeight = (size * (height / width)).roundToInt()
+ g.drawImage(inputImage, 0, 0, size, newHeight, null)
+ }
+ height < width -> {
+ val newWidth = (size * (width / height)).roundToInt()
+ g.drawImage(inputImage, 0, 0, newWidth, size, null)
+ }
+ width == height -> {
+ g.drawImage(inputImage, 0, 0, size, size, null)
+ }
+ }
+ }
+ g.dispose()
+ return super.processInput(ctx, ImageFactory.getInstance().fromImage(networkInput))
+ }
+
+ override fun processOutput(ctx: TranslatorContext, list: NDList): List {
+ var output = list[0]
+ val inputArraySize = list[1].shape[1] * 32
+ val mask = output[NDIndex(":, 4")].gte(threshold).repeat(15).reshape(output.shape)
+ output = output.booleanMask(mask)
+ output = output.reshape(output.shape[0] / 15, 15)
+ val objects = mutableListOf()
+ for (i in 0 until output.shape[0]) {
+ // Array format is x1, y1, x2, y2, conf, cls
+ val detection = output[i].toFloatArray()
+ val (centerX, centerY, w, h, p) = detection
+ var x = (centerX - w / 2).toDouble() / inputArraySize
+ var y = (centerY - h / 2).toDouble() / inputArraySize
+ var width = w.toDouble() / inputArraySize
+ var height = h.toDouble() / inputArraySize
+
+ if (imageWidth < size && imageHeight < size) {
+ x *= size / imageWidth
+ y *= size / imageHeight
+ width *= size / imageWidth
+ height *= size / imageHeight
+ } else {
+ when {
+ imageWidth > imageHeight -> {
+ val scale = imageWidth / imageHeight
+ y *= scale
+ height *= scale
+ }
+ imageWidth < imageHeight -> {
+ val scale = imageHeight / imageWidth
+ x *= scale
+ width *= scale
+ }
+ imageWidth == imageHeight -> Unit // Do Nothing
+ }
+ }
+
+ x = x.coerceIn(0.0, 1.0)
+ y = y.coerceIn(0.0, 1.0)
+ width = width.coerceIn(0.0, 1.0)
+ height = height.coerceIn(0.0, 1.0)
+
+ val c = detection.slice(5..detection.lastIndex)
+ val cMaxIdx = c.indexOf(c.max())
+ val obj = try {
+ GFLObject.values[cMaxIdx].primaryConstructor?.call(p.toDouble(), Rectangle(x, y, width, height))
+ } catch (e: Exception) {
+ null
+ }
+ if (obj != null) objects.add(obj)
+ }
+ return nms(objects)
+ }
+
+ fun nms(boxes: List): List {
+ val input = boxes.toMutableList()
+ val output = mutableListOf()
+ while (input.isNotEmpty()) {
+ val best = input.maxBy { it.probability } ?: continue
+ input.remove(best)
+ input.removeAll {
+ (it::class == best::class ||
+ it::class.isSubclassOf(best::class) ||
+ it::class.isSuperclassOf(best::class)) &&
+ it.bbox.getIoU(best.bbox) >= iouThreshold
+ }
+ output.add(best)
+ }
+ return output
+ }
+}
diff --git a/src/main/kotlin/com/waicool20/wai2k/util/cpd/AffineCPDRegistration.kt b/src/main/kotlin/com/waicool20/wai2k/util/cpd/AffineCPDRegistration.kt
new file mode 100644
index 0000000000..8ae978c44e
--- /dev/null
+++ b/src/main/kotlin/com/waicool20/wai2k/util/cpd/AffineCPDRegistration.kt
@@ -0,0 +1,77 @@
+/*
+ * GPLv3 License
+ *
+ * Copyright (c) WAI2K by waicool20
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the GNU General Public License as published by
+ * the Free Software Foundation, either version 3 of the License, or
+ * (at your option) any later version.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * GNU General Public License for more details.
+ *
+ * You should have received a copy of the GNU General Public License
+ * along with this program. If not, see .
+ */
+
+package com.waicool20.wai2k.util.cpd
+
+import ai.djl.ndarray.NDArray
+import ai.djl.ndarray.types.Shape
+import com.waicool20.wai2k.util.ai.*
+import kotlin.math.abs
+import kotlin.math.ln
+
+/**
+ * Coherent Point Drift Algorithm, this estimates and registers a source point set to a target
+ * point set. This class provides affine transformation registration.
+ *
+ * Magical code ported from: https://github.com/siavashk/pycpd/blob/master/pycpd/affine_registration.py
+ */
+class AffineCPDRegistration(
+ X: NDArray,
+ Y: NDArray,
+ initB: NDArray? = null,
+ initT: NDArray? = null
+) : CPDRegistration(X, Y) {
+ private var B = initB ?: manager.eye(D.toInt())
+ private var t = initT ?: manager.zeros(Shape(1, D))
+ private lateinit var X_hat: NDArray
+ private lateinit var Y_hat: NDArray
+ private lateinit var A: NDArray
+ private lateinit var YPY: NDArray
+
+ override fun updateTransform() {
+ val muX = P.dot(target).sum(0) / Np
+ val muY = P.transpose().dot(source).sum(0) / Np
+
+ X_hat = target - muX._stack(N)
+ Y_hat = source - muY._stack(M)
+ A = X_hat.transpose().dot(P.transpose()).dot(Y_hat)
+ YPY = Y_hat.transpose().dot(P1.diag()).dot(Y_hat)
+ B = solve(YPY.transpose(), A.transpose())
+ t = muX.transpose() - (B.transpose().matMul(muY.transpose()))
+ }
+
+ override fun transformPointCloud() {
+ TY = source.dot(B) + t.tile(M).reshape(M, D)
+ }
+
+ override fun updateVariance() {
+ val qprev = error
+ // Cant use .trace(), not implemented in PtNDArray
+ val trAB = A.dot(B)._trace().getDouble()
+ val xPx = Pt1.transpose().dot((X_hat * X_hat).sum(1)).getDouble()
+ val trBYPYP = B.dot(YPY).dot(B)._trace().getDouble()
+ error = (xPx - 2 * trAB + trBYPYP) / (2 * sigma2) + D * Np / 2 * ln(sigma2)
+ diff = abs(error - qprev)
+ sigma2 = (xPx - trAB) / (Np * D)
+
+ if (sigma2 <= 0) {
+ sigma2 = tolerance / 10
+ }
+ }
+}
\ No newline at end of file
diff --git a/src/main/kotlin/com/waicool20/wai2k/util/cpd/CPDRegistration.kt b/src/main/kotlin/com/waicool20/wai2k/util/cpd/CPDRegistration.kt
new file mode 100644
index 0000000000..ab77e4ddd4
--- /dev/null
+++ b/src/main/kotlin/com/waicool20/wai2k/util/cpd/CPDRegistration.kt
@@ -0,0 +1,131 @@
+/*
+ * GPLv3 License
+ *
+ * Copyright (c) WAI2K by waicool20
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the GNU General Public License as published by
+ * the Free Software Foundation, either version 3 of the License, or
+ * (at your option) any later version.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * GNU General Public License for more details.
+ *
+ * You should have received a copy of the GNU General Public License
+ * along with this program. If not, see .
+ */
+
+package com.waicool20.wai2k.util.cpd
+
+import ai.djl.ndarray.NDArray
+import ai.djl.ndarray.NDManager
+import ai.djl.ndarray.types.Shape
+import com.waicool20.wai2k.util.ai.*
+import kotlin.math.PI
+import kotlin.math.pow
+
+/**
+ * Coherent Point Drift Algorithm, this estimates and registers a source point set to a target
+ * point set. This is the base class where all transformation types inherit from.
+ *
+ * Magical code ported from: https://github.com/siavashk/pycpd/blob/master/pycpd/emregistration.py
+ */
+abstract class CPDRegistration(
+ val target: NDArray,
+ val source: NDArray,
+ initSigma2: Double? = null,
+ val tolerance: Double = 0.001,
+ private val w: Double = 0.0
+) : Iterator, Sequence {
+ protected val manager = NDManager.newBaseManager()
+
+ init {
+ require(target.shape.dimension() == 2) { "Target point cloud must be 2D NDArray" }
+ require(source.shape.dimension() == 2) { "Source point cloud must be 2D NDArray" }
+ if (initSigma2 != null) {
+ require(initSigma2 > 0) { "Expected positive sigma2 value: ${this.sigma2}" }
+ }
+ require(tolerance > 0) { "Tolerance must be larger than 0: $tolerance" }
+ require(w in 0.0..1.0) { "w must be within 0 until 1: $w" }
+ }
+
+ protected var sigma2 = initSigma2 ?: run {
+ val (N, D) = target.shape
+ val (M, _) = source.shape
+ val diff = target.expandDims(0) - source.expandDims(1)
+ val err = diff.square()
+ (err.sum() / (D * M * N)).getDouble()
+ }
+
+ var TY = source
+ protected set
+ var error = Double.POSITIVE_INFINITY
+ protected set
+
+ protected val N = target.shape[0]
+ protected val D = target.shape[1]
+ protected val M = source.shape[0]
+
+ protected var diff = Double.POSITIVE_INFINITY
+ protected var P = manager.zeros(Shape(M, N))
+ protected var Pt1 = manager.zeros(Shape(N))
+ protected var P1 = manager.zeros(Shape(M))
+ protected var Np = 0.0
+
+ protected abstract fun updateTransform()
+ protected abstract fun transformPointCloud()
+ protected abstract fun updateVariance()
+
+ override fun iterator(): Iterator {
+ return this
+ }
+
+ override fun hasNext() = true
+
+ override fun next(): CPDRegistration {
+ expectation()
+ maximization()
+ return this
+ }
+
+ fun getTranslatedPoints() = TY.toPoint2D()
+
+ private fun expectation() {
+ P = (target.expandDims(0) - TY.expandDims(1)).square().sum(2)
+
+ val c = run {
+ var c = (2 * PI * sigma2).pow(D / 2.0)
+ c = c * w / (1 - w)
+ c * M / N
+ }
+
+ P = (-P / (2 * sigma2)).exp()
+ val den = P.sum(0)._stack(M) + Double.MIN_VALUE + c
+
+ P = P / den
+ Pt1 = P.sum(0)
+ P1 = P.sum(1)
+ Np = P1.sum().getDouble()
+ }
+
+ private fun maximization() {
+ updateTransform()
+ transformPointCloud()
+ updateVariance()
+ }
+
+ protected fun solve(a: NDArray, b: NDArray): NDArray {
+ val A = a.toDoubleArray()
+ val detA = A[0] * A[3] - A[1] * A[2]
+ val temp = A[0]
+ A[0] = A[3]
+ A[1] = -A[1]
+ A[2] = -A[2]
+ A[3] = temp
+ a.set(A)
+ val invA = manager.create(A, a.shape) / detA
+ return invA.dot(b)
+ }
+}
\ No newline at end of file
diff --git a/src/main/kotlin/com/waicool20/wai2k/views/DebugView.kt b/src/main/kotlin/com/waicool20/wai2k/views/DebugView.kt
index 6409e5cc8f..1313c7fd5c 100644
--- a/src/main/kotlin/com/waicool20/wai2k/views/DebugView.kt
+++ b/src/main/kotlin/com/waicool20/wai2k/views/DebugView.kt
@@ -19,16 +19,18 @@
package com.waicool20.wai2k.views
+import ai.djl.modality.cv.ImageFactory
import com.waicool20.cvauto.android.ADB
import com.waicool20.cvauto.android.AndroidDevice
import com.waicool20.cvauto.core.Region
import com.waicool20.cvauto.core.template.FileTemplate
-import com.waicool20.cvauto.util.asBufferedImage
import com.waicool20.cvauto.util.asGrayF32
import com.waicool20.wai2k.config.Wai2KContext
import com.waicool20.wai2k.script.ScriptRunner
import com.waicool20.wai2k.util.Ocr
-import com.waicool20.wai2k.util.extractNodes
+import com.waicool20.wai2k.util.ai.ModelLoader
+import com.waicool20.wai2k.util.ai.YoloTranslator
+import com.waicool20.wai2k.util.ai.toDetectedObjects
import com.waicool20.wai2k.util.useCharFilter
import com.waicool20.waicoolutils.javafx.CoroutineScopeView
import com.waicool20.waicoolutils.javafx.addListener
@@ -38,16 +40,25 @@ import javafx.scene.control.*
import javafx.scene.control.SpinnerValueFactory.IntegerSpinnerValueFactory
import javafx.scene.image.ImageView
import javafx.scene.layout.VBox
+import javafx.stage.DirectoryChooser
import javafx.stage.FileChooser
import kotlinx.coroutines.*
import kotlinx.coroutines.javafx.JavaFx
import net.sourceforge.tess4j.ITesseract
import tornadofx.*
+import java.awt.image.BufferedImage
import java.nio.file.Files
import java.nio.file.Paths
+import javax.xml.parsers.DocumentBuilderFactory
+import javax.xml.transform.OutputKeys
+import javax.xml.transform.TransformerFactory
+import javax.xml.transform.dom.DOMSource
+import javax.xml.transform.stream.StreamResult
+import kotlin.streams.asSequence
import kotlin.time.ExperimentalTime
import kotlin.time.measureTimedValue
+
class DebugView : CoroutineScopeView() {
override val root: VBox by fxml("/views/debug.fxml")
private val openButton: Button by fxid()
@@ -62,9 +73,8 @@ class DebugView : CoroutineScopeView() {
private val ocrImageView: ImageView by fxid()
private val OCRButton: Button by fxid()
private val resetOCRButton: Button by fxid()
- private val filterBlueCheckBox: CheckBox by fxid()
- private val filterWhiteCheckBox: CheckBox by fxid()
- private val filterYellowCheckBox: CheckBox by fxid()
+ private val annotateSetButton: Button by fxid()
+ private val saveAnnotationsCheckBox: CheckBox by fxid()
private val useLSTMCheckBox: CheckBox by fxid()
private val filterCheckBox: CheckBox by fxid()
@@ -81,6 +91,16 @@ class DebugView : CoroutineScopeView() {
private val logger = loggerFor()
+ private val predictor by lazy {
+ try {
+ val model = ModelLoader.loadModel(wai2KContext.wai2KConfig.assetsDirectory.resolve("models/gfl.pt"))
+ model.setProperty("InputSize", "640")
+ model.newPredictor(YoloTranslator(model, 0.6))
+ } catch (e: Exception) {
+ null
+ }
+ }
+
init {
title = "WAI2K - Debugging tools"
}
@@ -93,6 +113,7 @@ class DebugView : CoroutineScopeView() {
assetOCRButton.setOnAction { doAssetOCR() }
OCRButton.setOnAction { doOCR() }
resetOCRButton.setOnAction { createNewRenderJob() }
+ annotateSetButton.setOnAction { annotateSet() }
}
private fun uiSetup() {
@@ -137,20 +158,18 @@ class DebugView : CoroutineScopeView() {
}
}
while (isActive) {
- var image = device.screens[0].capture().let {
- if (wSpinner.value > 0 && hSpinner.value > 0) {
- it.getSubimage(xSpinner.value, ySpinner.value, wSpinner.value, hSpinner.value)
- } else it
- }
- if (filterBlueCheckBox.isSelected || filterWhiteCheckBox.isSelected || filterYellowCheckBox.isSelected) {
- image = image.extractNodes(
- includeBlue = filterBlueCheckBox.isSelected,
- includeWhite = filterWhiteCheckBox.isSelected,
- includeYellow = filterYellowCheckBox.isSelected
- ).asBufferedImage()
- }
- withContext(Dispatchers.JavaFx) {
- ocrImageView.image = SwingFXUtils.toFXImage(image, null)
+ val predictor = this@DebugView.predictor
+ if (predictor == null) {
+ withContext(Dispatchers.JavaFx) {
+ ocrImageView.image = SwingFXUtils.toFXImage(device.screens[0].capture(), null)
+ }
+ } else {
+ val image = ImageFactory.getInstance().fromImage(device.screens[0].capture())
+ val objects = predictor.predict(image)
+ image.drawBoundingBoxes(objects.toDetectedObjects())
+ withContext(Dispatchers.JavaFx) {
+ ocrImageView.image = SwingFXUtils.toFXImage(image.wrappedImage as BufferedImage, null)
+ }
}
}
}
@@ -231,7 +250,6 @@ class DebugView : CoroutineScopeView() {
}
logger.info("Result: \n${getOCR().doOCR(image)}\n----------")
}
-
}
}
@@ -246,4 +264,66 @@ class DebugView : CoroutineScopeView() {
}
return ocr
}
+
+ private fun annotateSet() {
+ launch(Dispatchers.IO) {
+ val predictor = predictor ?: return@launch
+
+ val dir = withContext(Dispatchers.JavaFx) {
+ DirectoryChooser().apply {
+ title = "Annotate which directory?"
+ }.showDialog(null)?.toPath()
+ } ?: return@launch
+
+ val output = dir.resolve("out")
+ logger.info("Annotating images in $dir")
+ Files.createDirectories(output)
+
+ val doc = DocumentBuilderFactory.newInstance().newDocumentBuilder().newDocument()
+ val root = doc.createElement("annotations").also { doc.appendChild(it) }
+ val version = doc.createElement("version").apply {
+ appendChild(doc.createTextNode("1.1"))
+ }
+ root.appendChild(version)
+
+ Files.walk(dir).asSequence()
+ .filterNot { it.parent.endsWith("out") }
+ .filter { "$it".endsWith(".png", true) || "$it".endsWith(".jpg", true) }
+ .sorted()
+ .forEachIndexed { i, path ->
+ val image = ImageFactory.getInstance().fromFile(path)
+ val objects = predictor.predict(image)
+ val imageNode = doc.createElement("image").apply {
+ setAttribute("id", "$i")
+ setAttribute("name", "${dir.parent.relativize(path)}")
+ setAttribute("width", "${image.width}")
+ setAttribute("height", "${image.height}")
+ }
+ objects.forEach { obj ->
+ val bbox = obj.bbox
+ doc.createElement("box").apply {
+ setAttribute("label", "$obj")
+ setAttribute("occluded", "0")
+ setAttribute("xtl", "${bbox.x * image.width}")
+ setAttribute("ytl", "${bbox.y * image.height}")
+ setAttribute("xbr", "${(bbox.x + bbox.width) * image.width}")
+ setAttribute("ybr", "${(bbox.y + bbox.height) * image.height}")
+ }.also { imageNode.appendChild(it) }
+ }
+ root.appendChild(imageNode)
+ if (saveAnnotationsCheckBox.isSelected) {
+ image.drawBoundingBoxes(objects.toDetectedObjects())
+ image.save(Files.newOutputStream(output.resolve(path.fileName)), "png")
+ }
+ logger.info("Image: $path\n$objects")
+ }
+
+ TransformerFactory.newInstance().newTransformer().apply {
+ setOutputProperty(OutputKeys.INDENT, "yes")
+ setOutputProperty("{http://xml.apache.org/xslt}indent-amount", "2");
+ }.transform(DOMSource(doc), StreamResult(Files.newOutputStream(output.resolve("annotations.xml"))))
+
+ logger.info("All annotations done")
+ }
+ }
}
\ No newline at end of file
diff --git a/src/main/kotlin/com/waicool20/wai2k/views/LoaderView.kt b/src/main/kotlin/com/waicool20/wai2k/views/LoaderView.kt
index cba1145564..513533a39c 100644
--- a/src/main/kotlin/com/waicool20/wai2k/views/LoaderView.kt
+++ b/src/main/kotlin/com/waicool20/wai2k/views/LoaderView.kt
@@ -19,6 +19,7 @@
package com.waicool20.wai2k.views
+import ai.djl.Device
import com.fasterxml.jackson.module.kotlin.jacksonObjectMapper
import com.fasterxml.jackson.module.kotlin.readValue
import com.waicool20.cvauto.android.ADB
@@ -29,6 +30,7 @@ import com.waicool20.wai2k.config.Wai2KContext
import com.waicool20.wai2k.config.Wai2KProfile
import com.waicool20.wai2k.script.ScriptContext
import com.waicool20.wai2k.script.ScriptRunner
+import com.waicool20.wai2k.util.ai.ModelLoader
import com.waicool20.waicoolutils.javafx.CoroutineScopeView
import com.waicool20.waicoolutils.logging.LoggingEventBus
import com.waicool20.waicoolutils.logging.loggerFor
@@ -86,6 +88,8 @@ class LoaderView : CoroutineScopeView() {
loadWai2KProfile()
loadScriptRunner()
FileTemplate.checkPaths.add(wai2KConfig.assetsDirectory)
+ logger.info("Loading detection model...")
+ ModelLoader.engine.newModel("Loading", Device.defaultDevice()).close()
closeAndShowMainApp()
}
diff --git a/src/main/resources/views/debug.fxml b/src/main/resources/views/debug.fxml
index d638a54752..89b60dc174 100644
--- a/src/main/resources/views/debug.fxml
+++ b/src/main/resources/views/debug.fxml
@@ -34,80 +34,78 @@
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
-
-
+
+
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+