From 9de99a52460209387bc5c4f283d8b1ae887e8ec5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Albin=20Cor=C3=A9n?= <2108U9@gmail.com> Date: Thu, 11 Jan 2018 12:04:21 +0100 Subject: [PATCH] Added Eliptic Curve DiffieHellman with RSA verification --- MLAPI/Data/NetworkingConfiguration.cs | 11 ++- .../MonoBehaviours/Core/NetworkingManager.cs | 96 ++++++++++++++++--- 2 files changed, 91 insertions(+), 16 deletions(-) diff --git a/MLAPI/Data/NetworkingConfiguration.cs b/MLAPI/Data/NetworkingConfiguration.cs index fb3b439..c04022f 100644 --- a/MLAPI/Data/NetworkingConfiguration.cs +++ b/MLAPI/Data/NetworkingConfiguration.cs @@ -12,7 +12,9 @@ namespace MLAPI public SortedDictionary Channels = new SortedDictionary(); public List MessageTypes = new List(); public List PassthroughMessageTypes = new List(); + public List EncryptionMessageTypes = new List(); internal HashSet RegisteredPassthroughMessageTypes = new HashSet(); + internal HashSet RegisteredEncryptionMessageTypes = new HashSet(); public int MessageBufferSize = 65535; public int MaxMessagesPerFrame = 150; public int MaxConnections = 100; @@ -25,10 +27,11 @@ namespace MLAPI public bool HandleObjectSpawning = true; //TODO public bool CompressMessages = false; - //Should only be used for dedicated servers and will require the servers RSA keypair being hard coded into clients in order to exchange a AES key - //TODO - public bool EncryptMessages = false; public bool AllowPassthroughMessages = true; + public bool EnableDiffieHellman = true; + public bool DenyUntrustedKeys = false; + public RSAParameters ServerPublicKey; + public RSAParameters ServerPrivateKey; //Cached config hash private byte[] ConfigHash = null; @@ -59,7 +62,7 @@ namespace MLAPI } writer.Write(HandleObjectSpawning); writer.Write(CompressMessages); - writer.Write(EncryptMessages); + writer.Write(EnableDiffieHellman); writer.Write(AllowPassthroughMessages); } using(SHA256Managed sha256 = new SHA256Managed()) diff --git a/MLAPI/MonoBehaviours/Core/NetworkingManager.cs b/MLAPI/MonoBehaviours/Core/NetworkingManager.cs index 0c54a5a..a31e61a 100644 --- a/MLAPI/MonoBehaviours/Core/NetworkingManager.cs +++ b/MLAPI/MonoBehaviours/Core/NetworkingManager.cs @@ -2,7 +2,9 @@ using System; using System.Collections; using System.Collections.Generic; +using System.Linq; using System.IO; +using System.Security.Cryptography; using UnityEngine; using UnityEngine.Networking; @@ -16,6 +18,9 @@ namespace MLAPI //Client only, what my connectionId is on the server public int MyClientId; internal Dictionary connectedClients; + internal Dictionary server_pendingPublicKeys; + internal ECDiffieHellmanCng client_pendingDiffieHellmanExchange; + internal Dictionary encryptionKeys; public Dictionary ConnectedClients { get @@ -71,6 +76,9 @@ namespace MLAPI MessageManager.reverseMessageTypes = new Dictionary(); SpawnManager.spawnedObjects = new Dictionary(); SpawnManager.releasedNetworkObjectIds = new Stack(); + server_pendingPublicKeys = new Dictionary(); + client_pendingDiffieHellmanExchange = null; + encryptionKeys = new Dictionary(); if (NetworkConfig.HandleObjectSpawning) { NetworkedObject[] sceneObjects = FindObjectsOfType(); @@ -286,12 +294,25 @@ namespace MLAPI if (NetworkConfig.ConnectionApproval) sizeOfStream += 2 + NetworkConfig.ConnectionData.Length; + byte[] clientPubKey = new byte[0]; + if(NetworkConfig.EnableDiffieHellman) + { + ECDiffieHellmanCng diffieHellman = new ECDiffieHellmanCng(); + client_pendingDiffieHellmanExchange = diffieHellman; + clientPubKey = diffieHellman.PublicKey.ToByteArray(); + } + using (MemoryStream writeStream = new MemoryStream(sizeOfStream)) { using (BinaryWriter writer = new BinaryWriter(writeStream)) { writer.Write(NetworkConfig.GetConfig()); - if(NetworkConfig.ConnectionApproval) + if(NetworkConfig.EnableDiffieHellman) + { + writer.Write((ushort)clientPubKey.Length); + writer.Write(clientPubKey); + } + if (NetworkConfig.ConnectionApproval) { writer.Write((ushort)NetworkConfig.ConnectionData.Length); writer.Write(NetworkConfig.ConnectionData); @@ -300,6 +321,7 @@ namespace MLAPI Send(clientId, "MLAPI_CONNECTION_REQUEST", "MLAPI_RELIABLE_FRAGMENTED_SEQUENCED", writeStream.GetBuffer()); } } + break; case NetworkEventType.DataEvent: HandleIncomingData(clientId, messageBuffer, channelId); @@ -433,6 +455,18 @@ namespace MLAPI DisconnectClient(clientId); return; } + if(NetworkConfig.EnableDiffieHellman) + { + ushort pubKeyLength = messageReader.ReadUInt16(); + byte[] clientPubKeyBlob = messageReader.ReadBytes(pubKeyLength); + ECDiffieHellmanPublicKey clientPubKey = ECDiffieHellmanCngPublicKey.FromByteArray(clientPubKeyBlob, CngKeyBlobFormat.EccPublicBlob); + byte[] serverPubKeyBlob = new byte[0]; + using (ECDiffieHellmanCng diffieHellman = new ECDiffieHellmanCng()) + { + encryptionKeys.Add(clientId, diffieHellman.DeriveKeyMaterial(clientPubKey)); + server_pendingPublicKeys.Add(clientId, diffieHellman.PublicKey); + } + } if (NetworkConfig.ConnectionApproval) { ushort bufferSize = messageReader.ReadUInt16(); @@ -455,6 +489,31 @@ namespace MLAPI { using (BinaryReader messageReader = new BinaryReader(messageReadStream)) { + if(NetworkConfig.EnableDiffieHellman) + { + ushort serverPubKeyLength = messageReader.ReadUInt16(); + byte[] serverPubKeyBlob = messageReader.ReadBytes(serverPubKeyLength); + if(NetworkConfig.DenyUntrustedKeys) + { + ushort signedPubKeyLength = messageReader.ReadUInt16(); + byte[] signedPubKey = messageReader.ReadBytes(signedPubKeyLength); + //TODO: NEEDS THREADING!!!! + using(RSA rsa = RSA.Create()) + { + rsa.ImportParameters(NetworkConfig.ServerPublicKey); + byte[] decryptedSignedPubKey = rsa.DecryptValue(signedPubKey); + if(!serverPubKeyBlob.SequenceEqual(decryptedSignedPubKey)) + { + //Man in the middle attack detected. Or the RSA keys are wrong ;) + } + } + } + ECDiffieHellmanCng diffieHellman = client_pendingDiffieHellmanExchange; + ECDiffieHellmanPublicKey serverPubKey = ECDiffieHellmanCngPublicKey.FromByteArray(serverPubKeyBlob, CngKeyBlobFormat.EccPublicBlob); + encryptionKeys.Add(serverClientId, diffieHellman.DeriveKeyMaterial(serverPubKey)); + diffieHellman.Clear(); + client_pendingDiffieHellmanExchange = null; + } MyClientId = messageReader.ReadInt32(); connectedClients.Add(MyClientId, new NetworkedClient() { ClientId = MyClientId }); int clientCount = messageReader.ReadInt32(); @@ -825,6 +884,10 @@ namespace MLAPI { if (pendingClients.Contains(clientId)) pendingClients.Remove(clientId); + if (encryptionKeys.ContainsKey(clientId)) + encryptionKeys.Remove(clientId); + if (server_pendingPublicKeys.ContainsKey(clientId)) + server_pendingPublicKeys.Remove(clientId); if (connectedClients.ContainsKey(clientId)) { if(NetworkConfig.HandleObjectSpawning) @@ -873,8 +936,6 @@ namespace MLAPI connectedClients[clientId].PlayerObject = go; } - - int sizeOfStream = 4 + 4 + ((connectedClients.Count - 1) * 4); int amountOfObjectsToSend = 0; foreach (KeyValuePair pair in SpawnManager.spawnedObjects) { @@ -883,16 +944,28 @@ namespace MLAPI else amountOfObjectsToSend++; } - if(NetworkConfig.HandleObjectSpawning) - { - sizeOfStream += 4; - sizeOfStream += 13 * amountOfObjectsToSend; - } - using (MemoryStream writeStream = new MemoryStream(sizeOfStream)) + using (MemoryStream writeStream = new MemoryStream()) { using (BinaryWriter writer = new BinaryWriter(writeStream)) { + if(NetworkConfig.EnableDiffieHellman) + { + byte[] serverPubKey = server_pendingPublicKeys[clientId].ToByteArray(); + writer.Write((ushort)serverPubKey.Length); + writer.Write(serverPubKey); + if(NetworkConfig.DenyUntrustedKeys) + { + //NEEDS TO BE THREADED!!! + using (RSA rsa = RSA.Create()) + { + rsa.ImportParameters(NetworkConfig.ServerPrivateKey); + byte[] signedPubKey = rsa.EncryptValue(serverPubKey); + writer.Write((ushort)signedPubKey.Length); + writer.Write(signedPubKey); + } + } + } writer.Write(clientId); writer.Write(connectedClients.Count - 1); foreach (KeyValuePair item in connectedClients) @@ -917,15 +990,14 @@ namespace MLAPI } } } - Send(clientId, "MLAPI_CONNECTION_APPROVED", "MLAPI_RELIABLE_FRAGMENTED_SEQUENCED", writeStream.GetBuffer()); + Send(clientId, "MLAPI_CONNECTION_APPROVED", "MLAPI_RELIABLE_FRAGMENTED_SEQUENCED", writeStream.ToArray()); } //Inform old clients of the new player + int sizeOfStream = 4; if(NetworkConfig.HandleObjectSpawning) sizeOfStream = 13; - else - sizeOfStream = 4; using (MemoryStream stream = new MemoryStream(sizeOfStream)) {