Fix various bugs

This commit is contained in:
Gabriel Tofvesson 2019-06-04 00:04:27 +02:00
parent ba427e0a6f
commit b635b80187
4 changed files with 158 additions and 67 deletions

11
.idea/artifacts/Bungee_jar.xml generated Normal file
View File

@ -0,0 +1,11 @@
<component name="ArtifactManager">
<artifact type="jar" name="Bungee:jar">
<output-path>$PROJECT_DIR$/out/artifacts/Bungee_jar</output-path>
<root id="archive" name="Bungee.jar">
<element id="module-output" name="Bungee" />
<element id="extracted-dir" path="$KOTLIN_BUNDLED$/lib/kotlin-stdlib.jar" path-in-jar="/" />
<element id="extracted-dir" path="$KOTLIN_BUNDLED$/lib/kotlin-reflect.jar" path-in-jar="/" />
<element id="extracted-dir" path="$KOTLIN_BUNDLED$/lib/kotlin-test.jar" path-in-jar="/" />
</root>
</artifact>
</component>

View File

@ -1,31 +0,0 @@
import dev.w1zzrd.bungee.BungeeRTCPRouter
import dev.w1zzrd.bungee.BungeeRTCPServer
import java.io.File
import java.net.InetAddress
import java.nio.file.Files
import java.nio.file.Path
import java.security.KeyFactory
import java.security.spec.PKCS8EncodedKeySpec
import java.security.spec.X509EncodedKeySpec
fun main(args: Array<String>){
val privKey = KeyFactory.getInstance("RSA")
.generatePrivate(PKCS8EncodedKeySpec(Files.readAllBytes(Path.of("./private_key.der"))))
val pubKey = KeyFactory.getInstance("RSA")
.generatePublic(X509EncodedKeySpec(Files.readAllBytes(Path.of("./public_key.der"))))
Thread(Runnable {
BungeeRTCPRouter(InetAddress.getByName("0.0.0.0"), 6969, pubKey).listen()
}).start()
Thread.sleep(20)
BungeeRTCPServer(
InetAddress.getByName("0.0.0.0"),
6969,
InetAddress.getByName("192.168.1.145"),
25565,
privKey
).start()
}

View File

@ -1,9 +1,13 @@
package dev.w1zzrd.bungee package dev.w1zzrd.bungee
import java.io.File
import java.net.* import java.net.*
import java.nio.ByteBuffer import java.nio.ByteBuffer
import java.security.KeyFactory
import java.security.PublicKey import java.security.PublicKey
import java.security.Signature import java.security.Signature
import java.security.spec.X509EncodedKeySpec
import java.util.*
import java.util.concurrent.ThreadLocalRandom import java.util.concurrent.ThreadLocalRandom
// TODO: Inherit BungeeServer // TODO: Inherit BungeeServer
@ -12,17 +16,29 @@ import java.util.concurrent.ThreadLocalRandom
class BungeeRTCPRouter( class BungeeRTCPRouter(
private val listenAddr: InetAddress, private val listenAddr: InetAddress,
private val port: Int, private val port: Int,
private val routePK: PublicKey private val routePK: PublicKey,
var verbose: Boolean = true
){ ){
constructor(listenAddr: InetAddress, port: Int, keyName: String, verbose: Boolean = true):
this(
listenAddr,
port,
KeyFactory.getInstance("RSA")
.generatePublic(X509EncodedKeySpec(File(keyName).readBytes())),
verbose
)
// Map a client to a unique (long) id // Map a client to a unique (long) id
private val clients = HashMap<Socket, Long>() private val clients = HashMap<Socket, Long>()
private val router = ServerSocket() private var router = ServerSocket()
private lateinit var routeSocket: Socket private lateinit var routeSocket: Socket
private var alive = false private var alive = false
private var canStart = true private var canStart = true
private val headerBuffer = ByteBuffer.wrap(ByteArray(13)) // [ID (byte)][(CUID) long][DLEN (int)] private val headerBuffer = ByteBuffer.wrap(ByteArray(13)) // [ID (byte)][(CUID) long][DLEN (int)]
fun listen(){
fun listen() = try{
if(!canStart) throw IllegalStateException("Already started/stopped") if(!canStart) throw IllegalStateException("Already started/stopped")
canStart = false canStart = false
alive = true alive = true
@ -47,12 +63,23 @@ class BungeeRTCPRouter(
readBytes = 0 readBytes = 0
} }
var timeout = -1L
fun status(pref: String, msg: String) = if(verbose) println("$pref: $msg") else Unit
fun info(msg: String) = status("INFO", msg)
fun fail(msg: String) = status("FAIL", msg)
fun success(msg: String) = status("SUCCESS", msg)
while(true){ while(true){
if(tryRoute == null){ if(tryRoute == null){
tryRoute = router.tryAccept() tryRoute = router.tryAccept()
if(tryRoute != null){ if(tryRoute != null){
timeout = System.currentTimeMillis() + 2000L
info("Got RTCP candidate: "+(tryRoute!!.remoteSocketAddress))
rand.nextBytes(checkBytes) rand.nextBytes(checkBytes)
try{ try{
info("Sending stage 1: ${Arrays.toString(checkBytes)}")
// Send the bytes to be signed to remove host // Send the bytes to be signed to remove host
tryRoute!!.getOutputStream().write(checkBytes) tryRoute!!.getOutputStream().write(checkBytes)
}catch (e: Throwable){ }catch (e: Throwable){
@ -62,8 +89,12 @@ class BungeeRTCPRouter(
}else continue }else continue
} }
if(tryRoute!!.isClosed || !tryRoute!!.isConnected){ // Auth timeout
val timedOut = (timeout > 0 && timeout < System.currentTimeMillis())
if(tryRoute!!.isClosed || !tryRoute!!.isConnected || timedOut){
disconnectRouteServer() disconnectRouteServer()
fail(if(timedOut) "Candidate timed out!" else "Candidate disconnected!")
timeout = -1L
continue continue
} }
try { try {
@ -71,6 +102,7 @@ class BungeeRTCPRouter(
if (read.available() > 0) { if (read.available() > 0) {
if(read.available() + readBytes > fromClients.size){ if(read.available() + readBytes > fromClients.size){
disconnectRouteServer() disconnectRouteServer()
fail("Candidate sent too much data!")
continue continue
} }
readBytes += read.read(fromClients, readBytes, fromClients.size - readBytes) readBytes += read.read(fromClients, readBytes, fromClients.size - readBytes)
@ -82,11 +114,21 @@ class BungeeRTCPRouter(
} }
// We have a client. Let's check if they can authenticate // We have a client. Let's check if they can authenticate
if(readBytes >= 4 && wrappedClientBuffer.getInt(0) == 0x13376969){ // Tell router that you would like to authenticate if(readBytes >= 4){ // Tell router that you would like to authenticate
if(wrappedClientBuffer.getInt(0) != 0x13376969){
disconnectRouteServer()
fail("Candidate sent improper header")
continue
}
info("Got valid header")
if(readBytes >= (4 + 4)){ if(readBytes >= (4 + 4)){
val signedDataLength = wrappedClientBuffer.getInt(4) val signedDataLength = wrappedClientBuffer.getInt(4)
if(readBytes >= (4 + 4 + signedDataLength)){ if(readBytes >= (4 + 4 + signedDataLength)){
info("Checking signature...")
// We have the signed data; let's verify its integrity // We have the signed data; let's verify its integrity
val sig = Signature.getInstance("NONEwithRSA") // Raw bytes signed with RSA ;) val sig = Signature.getInstance("NONEwithRSA") // Raw bytes signed with RSA ;)
sig.initVerify(routePK) sig.initVerify(routePK)
@ -94,11 +136,12 @@ class BungeeRTCPRouter(
if(sig.verify(fromClients, 4 + 4, signedDataLength)){ if(sig.verify(fromClients, 4 + 4, signedDataLength)){
// We have a verified remote route! :D // We have a verified remote route! :D
routeSocket = tryRoute!! routeSocket = tryRoute!!
println("RTCP server verified!") success("Candidate RTCP server verified!")
break break
}else{ }else{
// Verification failed :( // Verification failed :(
disconnectRouteServer() disconnectRouteServer()
fail("Candidate RTCP server failed verification step!")
continue continue
} }
} }
@ -160,9 +203,11 @@ class BungeeRTCPRouter(
if(routeStream.available() > 0){ if(routeStream.available() > 0){
val read = routeStream.read(fromRoute, routeBytes, fromRoute.size - routeBytes) val read = routeStream.read(fromRoute, routeBytes, fromRoute.size - routeBytes)
var parsed = 0 var parsed = 0
parseLoop@while((routeBytes + read) - parsed > 9){ parseLoop@while((routeBytes + read) - parsed > 0){
when(fromRoute[parsed]){ when(fromRoute[parsed]){
0.toByte() -> { 0.toByte() -> {
if((routeBytes + read) - parsed < 10) break@parseLoop
// Parse data packet // Parse data packet
if((routeBytes + read) - parsed < 13) break@parseLoop // Not enough data if((routeBytes + read) - parsed < 13) break@parseLoop // Not enough data
@ -182,6 +227,8 @@ class BungeeRTCPRouter(
} }
1.toByte() -> { 1.toByte() -> {
if((routeBytes + read) - parsed < 10) break@parseLoop
// Handle disconnection // Handle disconnection
val uid = wrappedRouteBuffer.getLong(parsed + 1) val uid = wrappedRouteBuffer.getLong(parsed + 1)
if(clients.values.contains(uid)){ if(clients.values.contains(uid)){
@ -193,6 +240,12 @@ class BungeeRTCPRouter(
} }
parsed += 9 parsed += 9
} }
2.toByte() -> {
for(client in clients) client.key.forceClose()
clients.clear()
break@acceptLoop
}
} }
} }
@ -200,6 +253,11 @@ class BungeeRTCPRouter(
routeBytes = (routeBytes + read) - parsed // Amount of unread bytes after parsing routeBytes = (routeBytes + read) - parsed // Amount of unread bytes after parsing
} }
} }
}catch(e: Exception){
e.printStackTrace()
}finally{
try{ router.close() }catch(e: Exception){}
router = ServerSocket()
} }
private fun ServerSocket.tryAccept() = try{ accept() }catch(e: SocketTimeoutException){ null } private fun ServerSocket.tryAccept() = try{ accept() }catch(e: SocketTimeoutException){ null }

View File

@ -1,12 +1,16 @@
package dev.w1zzrd.bungee package dev.w1zzrd.bungee
import java.net.InetAddress import java.io.File
import java.net.InetSocketAddress import java.net.*
import java.net.ServerSocket
import java.net.Socket
import java.nio.ByteBuffer import java.nio.ByteBuffer
import java.nio.file.Files
import java.nio.file.Path
import java.security.KeyFactory
import java.security.PrivateKey import java.security.PrivateKey
import java.security.Signature import java.security.Signature
import java.security.spec.PKCS8EncodedKeySpec
import java.security.spec.X509EncodedKeySpec
import java.util.concurrent.atomic.AtomicBoolean
// TODO: Inherit BungeeServer // TODO: Inherit BungeeServer
// Private key used to authenticate against router // Private key used to authenticate against router
@ -14,22 +18,39 @@ class BungeeRTCPServer(
private val routerAddr: InetAddress, private val routerAddr: InetAddress,
private val routerPort: Int, private val routerPort: Int,
private val routeTo: InetAddress, private val routeTo: InetAddress,
private val routePort: Int, private var routePort: Int,
private val privateKey: PrivateKey private val privateKey: PrivateKey
){ ){
constructor(routerAddr: InetAddress, routerPort: Int, routeTo: InetAddress, routePort: Int, keyName: String):
this(
routerAddr,
routerPort,
routeTo,
routePort,
KeyFactory.getInstance("RSA")
.generatePrivate(PKCS8EncodedKeySpec(File(keyName).readBytes()))
)
// A map of a client UID to the "virtual client" (a socket from this server to the provided route endpoint) // A map of a client UID to the "virtual client" (a socket from this server to the provided route endpoint)
private val vClients = HashMap<Long, Socket>() private val vClients = HashMap<Long, Socket>()
private var canStart = true private val canStart = AtomicBoolean(true)
private val serverSocket = Socket() private var serverSocket = Socket()
private val buffer = ByteArray(BUFFER_SIZE) private val buffer = ByteArray(BUFFER_SIZE)
private val wrappedServerBuffer = ByteBuffer.wrap(buffer) private val wrappedServerBuffer = ByteBuffer.wrap(buffer)
private val clientBuffer = ByteArray(BUFFER_SIZE) private val clientBuffer = ByteArray(BUFFER_SIZE)
private var alive = false private val alive = AtomicBoolean(false)
private val headerBuffer = ByteBuffer.wrap(ByteArray(13)) private val headerBuffer = ByteBuffer.wrap(ByteArray(13))
fun start(){ fun start(){
if(!canStart) throw IllegalStateException("Already started/stopped") synchronized(canStart) {
canStart = false if (!canStart.get()) return@start
canStart.set(false)
}
println("Starting RTCP server")
serverSocket.connect(InetSocketAddress(routerAddr, routerPort)) serverSocket.connect(InetSocketAddress(routerAddr, routerPort))
@ -37,8 +58,13 @@ class BungeeRTCPServer(
val read = serverSocket.getInputStream() val read = serverSocket.getInputStream()
val write = serverSocket.getOutputStream() val write = serverSocket.getOutputStream()
var readCount = 0 var readCount = 0
while(readCount < 256) readCount += read.read(buffer, readCount, buffer.size - readCount) try {
while (readCount < 256) readCount += read.read(buffer, readCount, buffer.size - readCount)
}catch(e: Exception){
println("Encountered an error when authenticating")
stop()
return
}
val sig = Signature.getInstance("NONEwithRSA") val sig = Signature.getInstance("NONEwithRSA")
sig.initSign(privateKey) sig.initSign(privateKey)
sig.update(buffer, 0, 256) sig.update(buffer, 0, 256)
@ -50,26 +76,29 @@ class BungeeRTCPServer(
write.write(buffer, 0, 8 + signLen) write.write(buffer, 0, 8 + signLen)
var bufferBytes = 0 var bufferBytes = 0
alive = true synchronized(alive){alive.set(true)}
while(alive){ while(synchronized(alive){alive.get()}){
if(read.available() > 0) if(read.available() > 0)
bufferBytes += read.read(buffer, bufferBytes, buffer.size - bufferBytes) bufferBytes += read.read(buffer, bufferBytes, buffer.size - bufferBytes)
var parsed = 0 var parsed = 0
parseLoop@while(bufferBytes - parsed > 9){ parseLoop@while((bufferBytes - parsed) > 9){
val action = wrappedServerBuffer.get(parsed)
val uid = wrappedServerBuffer.getLong(parsed + 1) val uid = wrappedServerBuffer.getLong(parsed + 1)
when(buffer[parsed]){ when(action){
0.toByte() -> { // New client
// New client 0.toByte() -> vClients[uid] = Socket(routeTo, routePort, InetAddress.getByName("localhost"), 0)
vClients[uid] = Socket(routeTo, routePort)
}
1.toByte() -> { 1.toByte() -> {
// Data from client // Data from client
if(bufferBytes - parsed > 13){ if((bufferBytes - parsed) > 13){
// Get packet size
val dLen = wrappedServerBuffer.getInt(parsed + 9) val dLen = wrappedServerBuffer.getInt(parsed + 9)
if(bufferBytes < parsed + dLen) break@parseLoop // Not enough data
// Check if entire packet has been received yet
if((bufferBytes - parsed - 13) < dLen) break@parseLoop // Not enough data
try { try {
// Send data to server // Send data to server
vClients[uid]?.getOutputStream()?.write(buffer, parsed + 13, dLen) vClients[uid]?.getOutputStream()?.write(buffer, parsed + 13, dLen)
@ -81,18 +110,21 @@ class BungeeRTCPServer(
}else break@parseLoop // Not enough data }else break@parseLoop // Not enough data
} }
2.toByte() -> { // Remote disconnection
// Remote disconnection 2.toByte() -> vClients.remove(uid)?.forceClose()
vClients[uid]?.forceClose()
vClients.remove(uid)
}
} }
parsed += 9 parsed += 9
} }
System.arraycopy(buffer, parsed, buffer, 0, bufferBytes - parsed) try{
bufferBytes -= parsed if(parsed > bufferBytes) println("Packet read overflow (by ${parsed - bufferBytes} bytes) detected!")
System.arraycopy(buffer, Math.min(bufferBytes, parsed), buffer, 0, Math.max(0, bufferBytes - parsed))
bufferBytes = Math.max(0, bufferBytes - parsed)
}catch(e: Exception){
println("bufferBytes: $bufferBytes\nparsed: $parsed\nlength: ${buffer.size}\n")
throw e
}
// Accept data from route endpoint // Accept data from route endpoint
@ -124,4 +156,25 @@ class BungeeRTCPServer(
fun sendMessageToRouter(data: ByteArray, off: Int, len: Int){ fun sendMessageToRouter(data: ByteArray, off: Int, len: Int){
serverSocket.getOutputStream().write(data, off, len) serverSocket.getOutputStream().write(data, off, len)
} }
fun stop(newPort: Int = routePort) = synchronized(canStart){
synchronized(alive) {
if (alive.get()) {
try {
sendMessageToRouter(byteArrayOf(2), 0, 1)
} catch (e: Exception) {
} finally {
try {
serverSocket.forceClose()
} catch (e: Exception) {
}
}
}
alive.set(false)
canStart.set(true)
routePort = newPort
serverSocket = Socket()
println("RTCP server Stopped")
}
}
} }