diff --git a/MLAPI/Data/NetworkedClient.cs b/MLAPI/Data/NetworkedClient.cs index 6d1dbd2..135e778 100644 --- a/MLAPI/Data/NetworkedClient.cs +++ b/MLAPI/Data/NetworkedClient.cs @@ -8,5 +8,6 @@ namespace MLAPI public int ClientId; public GameObject PlayerObject; public List OwnedObjects = new List(); + public byte[] AesKey; } } diff --git a/MLAPI/Data/NetworkingConfiguration.cs b/MLAPI/Data/NetworkingConfiguration.cs index fe969cb..072543d 100644 --- a/MLAPI/Data/NetworkingConfiguration.cs +++ b/MLAPI/Data/NetworkingConfiguration.cs @@ -13,6 +13,7 @@ namespace MLAPI public List MessageTypes = new List(); public List PassthroughMessageTypes = new List(); internal HashSet RegisteredPassthroughMessageTypes = new HashSet(); + public HashSet EncryptedChannels = new HashSet(); public List RegisteredScenes = new List(); public int MessageBufferSize = 65535; public int ReceiveTickrate = 64; @@ -28,11 +29,8 @@ namespace MLAPI public byte[] ConnectionData = new byte[0]; public float SecondsHistory = 5; public bool HandleObjectSpawning = true; - //TODO - public bool CompressMessages = false; - //Should only be used for dedicated servers and will require the servers RSA keypair being hard coded into clients in order to exchange a AES key - //TODO - public bool EncryptMessages = false; + + public bool EnableEncryption = true; public bool AllowPassthroughMessages = true; public bool EnableSceneSwitching = false; @@ -72,8 +70,7 @@ namespace MLAPI } } writer.Write(HandleObjectSpawning); - writer.Write(CompressMessages); - writer.Write(EncryptMessages); + writer.Write(EnableEncryption); writer.Write(AllowPassthroughMessages); writer.Write(EnableSceneSwitching); } diff --git a/MLAPI/MLAPI.csproj b/MLAPI/MLAPI.csproj index 446cc8e..bd86046 100644 --- a/MLAPI/MLAPI.csproj +++ b/MLAPI/MLAPI.csproj @@ -65,6 +65,9 @@ + + + diff --git a/MLAPI/MonoBehaviours/Core/NetworkingManager.cs b/MLAPI/MonoBehaviours/Core/NetworkingManager.cs index 2640bff..6678155 100644 --- a/MLAPI/MonoBehaviours/Core/NetworkingManager.cs +++ b/MLAPI/MonoBehaviours/Core/NetworkingManager.cs @@ -51,6 +51,10 @@ namespace MLAPI public NetworkingConfiguration NetworkConfig; + private EllipticDiffieHellman clientDiffieHellman; + private Dictionary diffieHellmanPublicKeys; + private byte[] clientAesKey; + private void OnValidate() { if (SpawnablePrefabs != null) @@ -88,6 +92,7 @@ namespace MLAPI pendingClients = new HashSet(); connectedClients = new Dictionary(); messageBuffer = new byte[NetworkConfig.MessageBufferSize]; + diffieHellmanPublicKeys = new Dictionary(); MessageManager.channels = new Dictionary(); MessageManager.messageTypes = new Dictionary(); MessageManager.messageCallbacks = new Dictionary>>(); @@ -374,15 +379,29 @@ namespace MLAPI } else { + byte[] diffiePublic = new byte[0]; + if(NetworkConfig.EnableEncryption) + { + clientDiffieHellman = new EllipticDiffieHellman(EllipticDiffieHellman.DEFAULT_CURVE, EllipticDiffieHellman.DEFAULT_GENERATOR, EllipticDiffieHellman.DEFAULT_ORDER); + diffiePublic = clientDiffieHellman.GetPublicKey(); + } + int sizeOfStream = 32; if (NetworkConfig.ConnectionApproval) sizeOfStream += 2 + NetworkConfig.ConnectionData.Length; + if (NetworkConfig.EnableEncryption) + sizeOfStream += 2 + diffiePublic.Length; using (MemoryStream writeStream = new MemoryStream(sizeOfStream)) { using (BinaryWriter writer = new BinaryWriter(writeStream)) { writer.Write(NetworkConfig.GetConfig()); + if (NetworkConfig.EnableEncryption) + { + writer.Write((ushort)diffiePublic.Length); + writer.Write(diffiePublic); + } if (NetworkConfig.ConnectionApproval) { writer.Write((ushort)NetworkConfig.ConnectionData.Length); @@ -471,6 +490,14 @@ namespace MLAPI ushort bytesToRead = reader.ReadUInt16(); byte[] incommingData = reader.ReadBytes(bytesToRead); + if(NetworkConfig.EncryptedChannels.Contains(channelId)) + { + //Encrypted message + if (isServer) + incommingData = CryptographyHelper.Decrypt(incommingData, connectedClients[clientId].AesKey); + else + incommingData = CryptographyHelper.Decrypt(incommingData, clientAesKey); + } if (isServer && isPassthrough && !NetworkConfig.RegisteredPassthroughMessageTypes.Contains(messageType)) { @@ -555,6 +582,18 @@ namespace MLAPI DisconnectClient(clientId); return; } + byte[] aesKey = new byte[0]; + if(NetworkConfig.EnableEncryption) + { + ushort diffiePublicSize = reader.ReadUInt16(); + byte[] diffiePublic = reader.ReadBytes(diffiePublicSize); + diffieHellmanPublicKeys.Add(clientId, diffiePublic); + /* + EllipticDiffieHellman diffieHellman = new EllipticDiffieHellman(EllipticDiffieHellman.DEFAULT_CURVE, EllipticDiffieHellman.DEFAULT_GENERATOR, EllipticDiffieHellman.DEFAULT_ORDER); + aesKey = diffieHellman.GetSharedSecret(diffiePublic); + */ + + } if (NetworkConfig.ConnectionApproval) { ushort bufferSize = messageReader.ReadUInt16(); @@ -583,6 +622,12 @@ namespace MLAPI sceneIndex = messageReader.ReadUInt32(); } + if (NetworkConfig.EnableEncryption) + { + ushort keyLength = reader.ReadUInt16(); + clientAesKey = clientDiffieHellman.GetSharedSecret(reader.ReadBytes(keyLength)); + } + float netTime = messageReader.ReadSingle(); int remoteStamp = messageReader.ReadInt32(); int msDelay = NetworkTransport.GetRemoteDelayTimeMS(hostId, clientId, remoteStamp, out error); @@ -901,8 +946,18 @@ namespace MLAPI writer.Write(orderId.Value); writer.Write(true); writer.Write(sourceId); - writer.Write((ushort)data.Length); - writer.Write(data); + if(NetworkConfig.EncryptedChannels.Contains(channelId)) + { + //Encrypted message + byte[] encrypted = CryptographyHelper.Encrypt(data, connectedClients[targetId].AesKey); + writer.Write((ushort)encrypted.Length); + writer.Write(encrypted); + } + else + { + writer.Write((ushort)data.Length); + writer.Write(data); + } } NetworkTransport.QueueMessageForSending(hostId, targetId, channelId, stream.GetBuffer(), sizeOfStream, out error); } @@ -951,8 +1006,25 @@ namespace MLAPI writer.Write(isPassthrough); if (isPassthrough) writer.Write(clientId); - writer.Write((ushort)data.Length); - writer.Write(data); + + if (NetworkConfig.EncryptedChannels.Contains(MessageManager.channels[channelName])) + { + //This is an encrypted message. + byte[] encrypted; + if (isServer) + encrypted = CryptographyHelper.Encrypt(data, connectedClients[clientId].AesKey); + else + encrypted = CryptographyHelper.Encrypt(data, clientAesKey); + + writer.Write((ushort)encrypted.Length); + writer.Write(encrypted); + } + else + { + //Send in plaintext. + writer.Write((ushort)data.Length); + writer.Write(data); + } } if (isPassthrough) clientId = serverClientId; @@ -965,7 +1037,12 @@ namespace MLAPI internal void Send(int[] clientIds, string messageType, string channelName, byte[] data, uint? networkId = null, ushort? orderId = null) { - int sizeOfStream = 6; + if (NetworkConfig.EncryptedChannels.Contains(MessageManager.channels[channelName])) + { + Debug.LogWarning("MLAPI: Cannot send messages over encrypted channel to multiple clients."); + return; + } + int sizeOfStream = 6; if (networkId != null) sizeOfStream += 4; if (orderId != null) @@ -1007,6 +1084,12 @@ namespace MLAPI internal void Send(List clientIds, string messageType, string channelName, byte[] data, uint? networkId = null, ushort? orderId = null) { + if (NetworkConfig.EncryptedChannels.Contains(MessageManager.channels[channelName])) + { + Debug.LogWarning("MLAPI: Cannot send messages over encrypted channel to multiple clients."); + return; + } + //2 bytes for messageType, 2 bytes for buffer length and one byte for target bool int sizeOfStream = 6; if (networkId != null) @@ -1050,6 +1133,12 @@ namespace MLAPI internal void Send(string messageType, string channelName, byte[] data, uint? networkId = null, ushort? orderId = null) { + if (NetworkConfig.EncryptedChannels.Contains(MessageManager.channels[channelName])) + { + Debug.LogWarning("MLAPI: Cannot send messages over encrypted channel to multiple clients."); + return; + } + //2 bytes for messageType, 2 bytes for buffer length and one byte for target bool int sizeOfStream = 6; if (networkId != null) @@ -1094,6 +1183,12 @@ namespace MLAPI internal void Send(string messageType, string channelName, byte[] data, int clientIdToIgnore, uint? networkId = null, ushort? orderId = null) { + if (NetworkConfig.EncryptedChannels.Contains(MessageManager.channels[channelName])) + { + Debug.LogWarning("MLAPI: Cannot send messages over encrypted channel to multiple clients."); + return; + } + //2 bytes for messageType, 2 bytes for buffer length and one byte for target bool int sizeOfStream = 5; if (networkId != null) @@ -1142,10 +1237,16 @@ namespace MLAPI { if (!isServer) return; + if (pendingClients.Contains(clientId)) pendingClients.Remove(clientId); + if (connectedClients.ContainsKey(clientId)) connectedClients.Remove(clientId); + + if (diffieHellmanPublicKeys.ContainsKey(clientId)) + diffieHellmanPublicKeys.Remove(clientId); + NetworkTransport.Disconnect(hostId, clientId, out error); } @@ -1188,9 +1289,23 @@ namespace MLAPI //Inform new client it got approved if (pendingClients.Contains(clientId)) pendingClients.Remove(clientId); + + byte[] aesKey = new byte[0]; + byte[] publicKey = new byte[0]; + if (NetworkConfig.EnableEncryption) + { + EllipticDiffieHellman diffieHellman = new EllipticDiffieHellman(EllipticDiffieHellman.DEFAULT_CURVE, EllipticDiffieHellman.DEFAULT_GENERATOR, EllipticDiffieHellman.DEFAULT_ORDER); + aesKey = diffieHellman.GetSharedSecret(diffieHellmanPublicKeys[clientId]); + publicKey = diffieHellman.GetPublicKey(); + + if (diffieHellmanPublicKeys.ContainsKey(clientId)) + diffieHellmanPublicKeys.Remove(clientId); + } + NetworkedClient client = new NetworkedClient() { - ClientId = clientId + ClientId = clientId, + AesKey = aesKey }; connectedClients.Add(clientId, client); @@ -1201,7 +1316,6 @@ namespace MLAPI connectedClients[clientId].PlayerObject = go; } - int sizeOfStream = 16 + ((connectedClients.Count - 1) * 4); int amountOfObjectsToSend = SpawnManager.spawnedObjects.Values.Count(x => x.ServerOnly == false); @@ -1211,6 +1325,10 @@ namespace MLAPI sizeOfStream += 4; sizeOfStream += 14 * amountOfObjectsToSend; } + if(NetworkConfig.EnableEncryption) + { + sizeOfStream += 2 + publicKey.Length; + } if(NetworkConfig.EnableSceneSwitching) { sizeOfStream += 4; @@ -1225,8 +1343,16 @@ namespace MLAPI { writer.Write(NetworkSceneManager.CurrentSceneIndex); } + + if(NetworkConfig.EnableEncryption) + { + writer.Write((ushort)publicKey.Length); + writer.Write(publicKey); + } + writer.Write(NetworkTime); writer.Write(NetworkTransport.GetNetworkTimestamp()); + writer.Write(connectedClients.Count - 1); foreach (KeyValuePair item in connectedClients) { @@ -1292,6 +1418,10 @@ namespace MLAPI { if (pendingClients.Contains(clientId)) pendingClients.Remove(clientId); + + if (diffieHellmanPublicKeys.ContainsKey(clientId)) + diffieHellmanPublicKeys.Remove(clientId); + NetworkTransport.Disconnect(hostId, clientId, out error); } } diff --git a/MLAPI/NetworkingManagerComponents/CryptographyHelper.cs b/MLAPI/NetworkingManagerComponents/CryptographyHelper.cs new file mode 100644 index 0000000..02344e8 --- /dev/null +++ b/MLAPI/NetworkingManagerComponents/CryptographyHelper.cs @@ -0,0 +1,50 @@ +using System; +using System.Security.Cryptography; +using System.IO; + +namespace MLAPI.NetworkingManagerComponents +{ + public static class CryptographyHelper + { + public static byte[] Decrypt(byte[] encryptedBuffer, byte[] key) + { + byte[] iv = new byte[16]; + Array.Copy(encryptedBuffer, 0, iv, 0, 16); + + using (MemoryStream stream = new MemoryStream()) + { + using (RijndaelManaged aes = new RijndaelManaged()) + { + aes.IV = iv; + aes.Key = key; + using (CryptoStream cs = new CryptoStream(stream, aes.CreateDecryptor(), CryptoStreamMode.Write)) + { + cs.Write(encryptedBuffer, 16, encryptedBuffer.Length - 16); + } + return stream.ToArray(); + } + } + } + + public static byte[] Encrypt(byte[] clearBuffer, byte[] key) + { + using (MemoryStream stream = new MemoryStream()) + { + using (RijndaelManaged aes = new RijndaelManaged()) + { + aes.Key = key; + aes.GenerateIV(); + using (CryptoStream cs = new CryptoStream(stream, aes.CreateEncryptor(), CryptoStreamMode.Write)) + { + cs.Write(clearBuffer, 0, clearBuffer.Length); + } + byte[] encrypted = stream.ToArray(); + byte[] final = new byte[encrypted.Length + 16]; + Array.Copy(aes.IV, final, 16); + Array.Copy(encrypted, 0, final, 16, encrypted.Length); + return final; + } + } + } + } +} diff --git a/MLAPI/NetworkingManagerComponents/DiffieHellman.cs b/MLAPI/NetworkingManagerComponents/DiffieHellman.cs index 2e675b5..d7885f9 100644 --- a/MLAPI/NetworkingManagerComponents/DiffieHellman.cs +++ b/MLAPI/NetworkingManagerComponents/DiffieHellman.cs @@ -3,7 +3,7 @@ using IntXLib; using System.Text; using System.Security.Cryptography; -namespace ECDH +namespace MLAPI.NetworkingManagerComponents { public class EllipticDiffieHellman { diff --git a/MLAPI/NetworkingManagerComponents/EllipticCurve.cs b/MLAPI/NetworkingManagerComponents/EllipticCurve.cs index 3a37efe..3be71c8 100644 --- a/MLAPI/NetworkingManagerComponents/EllipticCurve.cs +++ b/MLAPI/NetworkingManagerComponents/EllipticCurve.cs @@ -2,7 +2,7 @@ using System.Collections.Generic; using IntXLib; -namespace ECDH +namespace MLAPI.NetworkingManagerComponents { public class CurvePoint {