diff --git a/MLAPI/Data/NetworkedClient.cs b/MLAPI/Data/NetworkedClient.cs index 6d1dbd2..135e778 100644 --- a/MLAPI/Data/NetworkedClient.cs +++ b/MLAPI/Data/NetworkedClient.cs @@ -8,5 +8,6 @@ namespace MLAPI public int ClientId; public GameObject PlayerObject; public List OwnedObjects = new List(); + public byte[] AesKey; } } diff --git a/MLAPI/Data/NetworkingConfiguration.cs b/MLAPI/Data/NetworkingConfiguration.cs index fe969cb..072543d 100644 --- a/MLAPI/Data/NetworkingConfiguration.cs +++ b/MLAPI/Data/NetworkingConfiguration.cs @@ -13,6 +13,7 @@ namespace MLAPI public List MessageTypes = new List(); public List PassthroughMessageTypes = new List(); internal HashSet RegisteredPassthroughMessageTypes = new HashSet(); + public HashSet EncryptedChannels = new HashSet(); public List RegisteredScenes = new List(); public int MessageBufferSize = 65535; public int ReceiveTickrate = 64; @@ -28,11 +29,8 @@ namespace MLAPI public byte[] ConnectionData = new byte[0]; public float SecondsHistory = 5; public bool HandleObjectSpawning = true; - //TODO - public bool CompressMessages = false; - //Should only be used for dedicated servers and will require the servers RSA keypair being hard coded into clients in order to exchange a AES key - //TODO - public bool EncryptMessages = false; + + public bool EnableEncryption = true; public bool AllowPassthroughMessages = true; public bool EnableSceneSwitching = false; @@ -72,8 +70,7 @@ namespace MLAPI } } writer.Write(HandleObjectSpawning); - writer.Write(CompressMessages); - writer.Write(EncryptMessages); + writer.Write(EnableEncryption); writer.Write(AllowPassthroughMessages); writer.Write(EnableSceneSwitching); } diff --git a/MLAPI/MLAPI.csproj b/MLAPI/MLAPI.csproj index 446cc8e..bd86046 100644 --- a/MLAPI/MLAPI.csproj +++ b/MLAPI/MLAPI.csproj @@ -65,6 +65,9 @@ + + + diff --git a/MLAPI/MonoBehaviours/Core/NetworkingManager.cs b/MLAPI/MonoBehaviours/Core/NetworkingManager.cs index 7d6ff02..96ea2bd 100644 --- a/MLAPI/MonoBehaviours/Core/NetworkingManager.cs +++ b/MLAPI/MonoBehaviours/Core/NetworkingManager.cs @@ -49,6 +49,10 @@ namespace MLAPI public NetworkingConfiguration NetworkConfig; + private EllipticDiffieHellman clientDiffieHellman; + private Dictionary diffieHellmanPublicKeys; + private byte[] clientAesKey; + private void OnValidate() { if (SpawnablePrefabs != null) @@ -86,6 +90,7 @@ namespace MLAPI pendingClients = new HashSet(); connectedClients = new Dictionary(); messageBuffer = new byte[NetworkConfig.MessageBufferSize]; + diffieHellmanPublicKeys = new Dictionary(); MessageManager.channels = new Dictionary(); MessageManager.messageTypes = new Dictionary(); MessageManager.messageCallbacks = new Dictionary>>(); @@ -372,15 +377,29 @@ namespace MLAPI } else { + byte[] diffiePublic = new byte[0]; + if(NetworkConfig.EnableEncryption) + { + clientDiffieHellman = new EllipticDiffieHellman(EllipticDiffieHellman.DEFAULT_CURVE, EllipticDiffieHellman.DEFAULT_GENERATOR, EllipticDiffieHellman.DEFAULT_ORDER); + diffiePublic = clientDiffieHellman.GetPublicKey(); + } + int sizeOfStream = 32; if (NetworkConfig.ConnectionApproval) sizeOfStream += 2 + NetworkConfig.ConnectionData.Length; + if (NetworkConfig.EnableEncryption) + sizeOfStream += 2 + diffiePublic.Length; using (MemoryStream writeStream = new MemoryStream(sizeOfStream)) { using (BinaryWriter writer = new BinaryWriter(writeStream)) { writer.Write(NetworkConfig.GetConfig()); + if (NetworkConfig.EnableEncryption) + { + writer.Write((ushort)diffiePublic.Length); + writer.Write(diffiePublic); + } if (NetworkConfig.ConnectionApproval) { writer.Write((ushort)NetworkConfig.ConnectionData.Length); @@ -469,6 +488,14 @@ namespace MLAPI ushort bytesToRead = reader.ReadUInt16(); byte[] incommingData = reader.ReadBytes(bytesToRead); + if(NetworkConfig.EncryptedChannels.Contains(channelId)) + { + //Encrypted message + if (isServer) + incommingData = CryptographyHelper.Decrypt(incommingData, connectedClients[clientId].AesKey); + else + incommingData = CryptographyHelper.Decrypt(incommingData, clientAesKey); + } if (isServer && isPassthrough && !NetworkConfig.RegisteredPassthroughMessageTypes.Contains(messageType)) { @@ -550,6 +577,18 @@ namespace MLAPI DisconnectClient(clientId); return; } + byte[] aesKey = new byte[0]; + if(NetworkConfig.EnableEncryption) + { + ushort diffiePublicSize = reader.ReadUInt16(); + byte[] diffiePublic = reader.ReadBytes(diffiePublicSize); + diffieHellmanPublicKeys.Add(clientId, diffiePublic); + /* + EllipticDiffieHellman diffieHellman = new EllipticDiffieHellman(EllipticDiffieHellman.DEFAULT_CURVE, EllipticDiffieHellman.DEFAULT_GENERATOR, EllipticDiffieHellman.DEFAULT_ORDER); + aesKey = diffieHellman.GetSharedSecret(diffiePublic); + */ + + } if (NetworkConfig.ConnectionApproval) { ushort bufferSize = messageReader.ReadUInt16(); @@ -578,6 +617,12 @@ namespace MLAPI sceneIndex = messageReader.ReadUInt32(); } + if (NetworkConfig.EnableEncryption) + { + ushort keyLength = reader.ReadUInt16(); + clientAesKey = clientDiffieHellman.GetSharedSecret(reader.ReadBytes(keyLength)); + } + float netTime = messageReader.ReadSingle(); int remoteStamp = messageReader.ReadInt32(); int msDelay = NetworkTransport.GetRemoteDelayTimeMS(hostId, clientId, remoteStamp, out error); @@ -894,8 +939,18 @@ namespace MLAPI writer.Write(orderId.Value); writer.Write(true); writer.Write(sourceId); - writer.Write((ushort)data.Length); - writer.Write(data); + if(NetworkConfig.EncryptedChannels.Contains(channelId)) + { + //Encrypted message + byte[] encrypted = CryptographyHelper.Encrypt(data, connectedClients[targetId].AesKey); + writer.Write((ushort)encrypted.Length); + writer.Write(encrypted); + } + else + { + writer.Write((ushort)data.Length); + writer.Write(data); + } } NetworkTransport.QueueMessageForSending(hostId, targetId, channelId, stream.GetBuffer(), sizeOfStream, out error); } @@ -944,8 +999,25 @@ namespace MLAPI writer.Write(isPassthrough); if (isPassthrough) writer.Write(clientId); - writer.Write((ushort)data.Length); - writer.Write(data); + + if (NetworkConfig.EncryptedChannels.Contains(MessageManager.channels[channelName])) + { + //This is an encrypted message. + byte[] encrypted; + if (isServer) + encrypted = CryptographyHelper.Encrypt(data, connectedClients[clientId].AesKey); + else + encrypted = CryptographyHelper.Encrypt(data, clientAesKey); + + writer.Write((ushort)encrypted.Length); + writer.Write(encrypted); + } + else + { + //Send in plaintext. + writer.Write((ushort)data.Length); + writer.Write(data); + } } if (isPassthrough) clientId = serverClientId; @@ -958,7 +1030,12 @@ namespace MLAPI internal void Send(int[] clientIds, string messageType, string channelName, byte[] data, uint? networkId = null, ushort? orderId = null) { - int sizeOfStream = 6; + if (NetworkConfig.EncryptedChannels.Contains(MessageManager.channels[channelName])) + { + Debug.LogWarning("MLAPI: Cannot send messages over encrypted channel to multiple clients."); + return; + } + int sizeOfStream = 6; if (networkId != null) sizeOfStream += 4; if (orderId != null) @@ -1000,6 +1077,12 @@ namespace MLAPI internal void Send(List clientIds, string messageType, string channelName, byte[] data, uint? networkId = null, ushort? orderId = null) { + if (NetworkConfig.EncryptedChannels.Contains(MessageManager.channels[channelName])) + { + Debug.LogWarning("MLAPI: Cannot send messages over encrypted channel to multiple clients."); + return; + } + //2 bytes for messageType, 2 bytes for buffer length and one byte for target bool int sizeOfStream = 6; if (networkId != null) @@ -1043,6 +1126,12 @@ namespace MLAPI internal void Send(string messageType, string channelName, byte[] data, uint? networkId = null, ushort? orderId = null) { + if (NetworkConfig.EncryptedChannels.Contains(MessageManager.channels[channelName])) + { + Debug.LogWarning("MLAPI: Cannot send messages over encrypted channel to multiple clients."); + return; + } + //2 bytes for messageType, 2 bytes for buffer length and one byte for target bool int sizeOfStream = 6; if (networkId != null) @@ -1087,6 +1176,12 @@ namespace MLAPI internal void Send(string messageType, string channelName, byte[] data, int clientIdToIgnore, uint? networkId = null, ushort? orderId = null) { + if (NetworkConfig.EncryptedChannels.Contains(MessageManager.channels[channelName])) + { + Debug.LogWarning("MLAPI: Cannot send messages over encrypted channel to multiple clients."); + return; + } + //2 bytes for messageType, 2 bytes for buffer length and one byte for target bool int sizeOfStream = 5; if (networkId != null) @@ -1134,10 +1229,16 @@ namespace MLAPI { if (!isServer) return; + if (pendingClients.Contains(clientId)) pendingClients.Remove(clientId); + if (connectedClients.ContainsKey(clientId)) connectedClients.Remove(clientId); + + if (diffieHellmanPublicKeys.ContainsKey(clientId)) + diffieHellmanPublicKeys.Remove(clientId); + NetworkTransport.Disconnect(hostId, clientId, out error); } @@ -1180,9 +1281,23 @@ namespace MLAPI //Inform new client it got approved if (pendingClients.Contains(clientId)) pendingClients.Remove(clientId); + + byte[] aesKey = new byte[0]; + byte[] publicKey = new byte[0]; + if (NetworkConfig.EnableEncryption) + { + EllipticDiffieHellman diffieHellman = new EllipticDiffieHellman(EllipticDiffieHellman.DEFAULT_CURVE, EllipticDiffieHellman.DEFAULT_GENERATOR, EllipticDiffieHellman.DEFAULT_ORDER); + aesKey = diffieHellman.GetSharedSecret(diffieHellmanPublicKeys[clientId]); + publicKey = diffieHellman.GetPublicKey(); + + if (diffieHellmanPublicKeys.ContainsKey(clientId)) + diffieHellmanPublicKeys.Remove(clientId); + } + NetworkedClient client = new NetworkedClient() { - ClientId = clientId + ClientId = clientId, + AesKey = aesKey }; connectedClients.Add(clientId, client); @@ -1193,7 +1308,6 @@ namespace MLAPI connectedClients[clientId].PlayerObject = go; } - int sizeOfStream = 16 + ((connectedClients.Count - 1) * 4); int amountOfObjectsToSend = SpawnManager.spawnedObjects.Values.Count(x => x.ServerOnly == false); @@ -1203,6 +1317,10 @@ namespace MLAPI sizeOfStream += 4; sizeOfStream += 14 * amountOfObjectsToSend; } + if(NetworkConfig.EnableEncryption) + { + sizeOfStream += 2 + publicKey.Length; + } if(NetworkConfig.EnableSceneSwitching) { sizeOfStream += 4; @@ -1217,8 +1335,16 @@ namespace MLAPI { writer.Write(NetworkSceneManager.CurrentSceneIndex); } + + if(NetworkConfig.EnableEncryption) + { + writer.Write((ushort)publicKey.Length); + writer.Write(publicKey); + } + writer.Write(NetworkTime); writer.Write(NetworkTransport.GetNetworkTimestamp()); + writer.Write(connectedClients.Count - 1); foreach (KeyValuePair item in connectedClients) { @@ -1284,6 +1410,10 @@ namespace MLAPI { if (pendingClients.Contains(clientId)) pendingClients.Remove(clientId); + + if (diffieHellmanPublicKeys.ContainsKey(clientId)) + diffieHellmanPublicKeys.Remove(clientId); + NetworkTransport.Disconnect(hostId, clientId, out error); } } diff --git a/MLAPI/NetworkingManagerComponents/CryptographyHelper.cs b/MLAPI/NetworkingManagerComponents/CryptographyHelper.cs new file mode 100644 index 0000000..02344e8 --- /dev/null +++ b/MLAPI/NetworkingManagerComponents/CryptographyHelper.cs @@ -0,0 +1,50 @@ +using System; +using System.Security.Cryptography; +using System.IO; + +namespace MLAPI.NetworkingManagerComponents +{ + public static class CryptographyHelper + { + public static byte[] Decrypt(byte[] encryptedBuffer, byte[] key) + { + byte[] iv = new byte[16]; + Array.Copy(encryptedBuffer, 0, iv, 0, 16); + + using (MemoryStream stream = new MemoryStream()) + { + using (RijndaelManaged aes = new RijndaelManaged()) + { + aes.IV = iv; + aes.Key = key; + using (CryptoStream cs = new CryptoStream(stream, aes.CreateDecryptor(), CryptoStreamMode.Write)) + { + cs.Write(encryptedBuffer, 16, encryptedBuffer.Length - 16); + } + return stream.ToArray(); + } + } + } + + public static byte[] Encrypt(byte[] clearBuffer, byte[] key) + { + using (MemoryStream stream = new MemoryStream()) + { + using (RijndaelManaged aes = new RijndaelManaged()) + { + aes.Key = key; + aes.GenerateIV(); + using (CryptoStream cs = new CryptoStream(stream, aes.CreateEncryptor(), CryptoStreamMode.Write)) + { + cs.Write(clearBuffer, 0, clearBuffer.Length); + } + byte[] encrypted = stream.ToArray(); + byte[] final = new byte[encrypted.Length + 16]; + Array.Copy(aes.IV, final, 16); + Array.Copy(encrypted, 0, final, 16, encrypted.Length); + return final; + } + } + } + } +} diff --git a/MLAPI/NetworkingManagerComponents/DiffieHellman.cs b/MLAPI/NetworkingManagerComponents/DiffieHellman.cs index 2e675b5..d7885f9 100644 --- a/MLAPI/NetworkingManagerComponents/DiffieHellman.cs +++ b/MLAPI/NetworkingManagerComponents/DiffieHellman.cs @@ -3,7 +3,7 @@ using IntXLib; using System.Text; using System.Security.Cryptography; -namespace ECDH +namespace MLAPI.NetworkingManagerComponents { public class EllipticDiffieHellman { diff --git a/MLAPI/NetworkingManagerComponents/EllipticCurve.cs b/MLAPI/NetworkingManagerComponents/EllipticCurve.cs index 3a37efe..3be71c8 100644 --- a/MLAPI/NetworkingManagerComponents/EllipticCurve.cs +++ b/MLAPI/NetworkingManagerComponents/EllipticCurve.cs @@ -2,7 +2,7 @@ using System.Collections.Generic; using IntXLib; -namespace ECDH +namespace MLAPI.NetworkingManagerComponents { public class CurvePoint {