Compare commits

...

1 Commits

Author SHA1 Message Date
Albin Corén
9de99a5246 Added Eliptic Curve DiffieHellman with RSA verification 2018-01-11 12:04:21 +01:00
2 changed files with 91 additions and 16 deletions

View File

@ -12,7 +12,9 @@ namespace MLAPI
public SortedDictionary<string, QosType> Channels = new SortedDictionary<string, QosType>(); public SortedDictionary<string, QosType> Channels = new SortedDictionary<string, QosType>();
public List<string> MessageTypes = new List<string>(); public List<string> MessageTypes = new List<string>();
public List<string> PassthroughMessageTypes = new List<string>(); public List<string> PassthroughMessageTypes = new List<string>();
public List<string> EncryptionMessageTypes = new List<string>();
internal HashSet<ushort> RegisteredPassthroughMessageTypes = new HashSet<ushort>(); internal HashSet<ushort> RegisteredPassthroughMessageTypes = new HashSet<ushort>();
internal HashSet<ushort> RegisteredEncryptionMessageTypes = new HashSet<ushort>();
public int MessageBufferSize = 65535; public int MessageBufferSize = 65535;
public int MaxMessagesPerFrame = 150; public int MaxMessagesPerFrame = 150;
public int MaxConnections = 100; public int MaxConnections = 100;
@ -25,10 +27,11 @@ namespace MLAPI
public bool HandleObjectSpawning = true; public bool HandleObjectSpawning = true;
//TODO //TODO
public bool CompressMessages = false; public bool CompressMessages = false;
//Should only be used for dedicated servers and will require the servers RSA keypair being hard coded into clients in order to exchange a AES key
//TODO
public bool EncryptMessages = false;
public bool AllowPassthroughMessages = true; public bool AllowPassthroughMessages = true;
public bool EnableDiffieHellman = true;
public bool DenyUntrustedKeys = false;
public RSAParameters ServerPublicKey;
public RSAParameters ServerPrivateKey;
//Cached config hash //Cached config hash
private byte[] ConfigHash = null; private byte[] ConfigHash = null;
@ -59,7 +62,7 @@ namespace MLAPI
} }
writer.Write(HandleObjectSpawning); writer.Write(HandleObjectSpawning);
writer.Write(CompressMessages); writer.Write(CompressMessages);
writer.Write(EncryptMessages); writer.Write(EnableDiffieHellman);
writer.Write(AllowPassthroughMessages); writer.Write(AllowPassthroughMessages);
} }
using(SHA256Managed sha256 = new SHA256Managed()) using(SHA256Managed sha256 = new SHA256Managed())

View File

@ -2,7 +2,9 @@
using System; using System;
using System.Collections; using System.Collections;
using System.Collections.Generic; using System.Collections.Generic;
using System.Linq;
using System.IO; using System.IO;
using System.Security.Cryptography;
using UnityEngine; using UnityEngine;
using UnityEngine.Networking; using UnityEngine.Networking;
@ -16,6 +18,9 @@ namespace MLAPI
//Client only, what my connectionId is on the server //Client only, what my connectionId is on the server
public int MyClientId; public int MyClientId;
internal Dictionary<int, NetworkedClient> connectedClients; internal Dictionary<int, NetworkedClient> connectedClients;
internal Dictionary<int, ECDiffieHellmanPublicKey> server_pendingPublicKeys;
internal ECDiffieHellmanCng client_pendingDiffieHellmanExchange;
internal Dictionary<int, byte[]> encryptionKeys;
public Dictionary<int, NetworkedClient> ConnectedClients public Dictionary<int, NetworkedClient> ConnectedClients
{ {
get get
@ -71,6 +76,9 @@ namespace MLAPI
MessageManager.reverseMessageTypes = new Dictionary<ushort, string>(); MessageManager.reverseMessageTypes = new Dictionary<ushort, string>();
SpawnManager.spawnedObjects = new Dictionary<uint, NetworkedObject>(); SpawnManager.spawnedObjects = new Dictionary<uint, NetworkedObject>();
SpawnManager.releasedNetworkObjectIds = new Stack<uint>(); SpawnManager.releasedNetworkObjectIds = new Stack<uint>();
server_pendingPublicKeys = new Dictionary<int, ECDiffieHellmanPublicKey>();
client_pendingDiffieHellmanExchange = null;
encryptionKeys = new Dictionary<int, byte[]>();
if (NetworkConfig.HandleObjectSpawning) if (NetworkConfig.HandleObjectSpawning)
{ {
NetworkedObject[] sceneObjects = FindObjectsOfType<NetworkedObject>(); NetworkedObject[] sceneObjects = FindObjectsOfType<NetworkedObject>();
@ -286,12 +294,25 @@ namespace MLAPI
if (NetworkConfig.ConnectionApproval) if (NetworkConfig.ConnectionApproval)
sizeOfStream += 2 + NetworkConfig.ConnectionData.Length; sizeOfStream += 2 + NetworkConfig.ConnectionData.Length;
byte[] clientPubKey = new byte[0];
if(NetworkConfig.EnableDiffieHellman)
{
ECDiffieHellmanCng diffieHellman = new ECDiffieHellmanCng();
client_pendingDiffieHellmanExchange = diffieHellman;
clientPubKey = diffieHellman.PublicKey.ToByteArray();
}
using (MemoryStream writeStream = new MemoryStream(sizeOfStream)) using (MemoryStream writeStream = new MemoryStream(sizeOfStream))
{ {
using (BinaryWriter writer = new BinaryWriter(writeStream)) using (BinaryWriter writer = new BinaryWriter(writeStream))
{ {
writer.Write(NetworkConfig.GetConfig()); writer.Write(NetworkConfig.GetConfig());
if(NetworkConfig.ConnectionApproval) if(NetworkConfig.EnableDiffieHellman)
{
writer.Write((ushort)clientPubKey.Length);
writer.Write(clientPubKey);
}
if (NetworkConfig.ConnectionApproval)
{ {
writer.Write((ushort)NetworkConfig.ConnectionData.Length); writer.Write((ushort)NetworkConfig.ConnectionData.Length);
writer.Write(NetworkConfig.ConnectionData); writer.Write(NetworkConfig.ConnectionData);
@ -300,6 +321,7 @@ namespace MLAPI
Send(clientId, "MLAPI_CONNECTION_REQUEST", "MLAPI_RELIABLE_FRAGMENTED_SEQUENCED", writeStream.GetBuffer()); Send(clientId, "MLAPI_CONNECTION_REQUEST", "MLAPI_RELIABLE_FRAGMENTED_SEQUENCED", writeStream.GetBuffer());
} }
} }
break; break;
case NetworkEventType.DataEvent: case NetworkEventType.DataEvent:
HandleIncomingData(clientId, messageBuffer, channelId); HandleIncomingData(clientId, messageBuffer, channelId);
@ -433,6 +455,18 @@ namespace MLAPI
DisconnectClient(clientId); DisconnectClient(clientId);
return; return;
} }
if(NetworkConfig.EnableDiffieHellman)
{
ushort pubKeyLength = messageReader.ReadUInt16();
byte[] clientPubKeyBlob = messageReader.ReadBytes(pubKeyLength);
ECDiffieHellmanPublicKey clientPubKey = ECDiffieHellmanCngPublicKey.FromByteArray(clientPubKeyBlob, CngKeyBlobFormat.EccPublicBlob);
byte[] serverPubKeyBlob = new byte[0];
using (ECDiffieHellmanCng diffieHellman = new ECDiffieHellmanCng())
{
encryptionKeys.Add(clientId, diffieHellman.DeriveKeyMaterial(clientPubKey));
server_pendingPublicKeys.Add(clientId, diffieHellman.PublicKey);
}
}
if (NetworkConfig.ConnectionApproval) if (NetworkConfig.ConnectionApproval)
{ {
ushort bufferSize = messageReader.ReadUInt16(); ushort bufferSize = messageReader.ReadUInt16();
@ -455,6 +489,31 @@ namespace MLAPI
{ {
using (BinaryReader messageReader = new BinaryReader(messageReadStream)) using (BinaryReader messageReader = new BinaryReader(messageReadStream))
{ {
if(NetworkConfig.EnableDiffieHellman)
{
ushort serverPubKeyLength = messageReader.ReadUInt16();
byte[] serverPubKeyBlob = messageReader.ReadBytes(serverPubKeyLength);
if(NetworkConfig.DenyUntrustedKeys)
{
ushort signedPubKeyLength = messageReader.ReadUInt16();
byte[] signedPubKey = messageReader.ReadBytes(signedPubKeyLength);
//TODO: NEEDS THREADING!!!!
using(RSA rsa = RSA.Create())
{
rsa.ImportParameters(NetworkConfig.ServerPublicKey);
byte[] decryptedSignedPubKey = rsa.DecryptValue(signedPubKey);
if(!serverPubKeyBlob.SequenceEqual(decryptedSignedPubKey))
{
//Man in the middle attack detected. Or the RSA keys are wrong ;)
}
}
}
ECDiffieHellmanCng diffieHellman = client_pendingDiffieHellmanExchange;
ECDiffieHellmanPublicKey serverPubKey = ECDiffieHellmanCngPublicKey.FromByteArray(serverPubKeyBlob, CngKeyBlobFormat.EccPublicBlob);
encryptionKeys.Add(serverClientId, diffieHellman.DeriveKeyMaterial(serverPubKey));
diffieHellman.Clear();
client_pendingDiffieHellmanExchange = null;
}
MyClientId = messageReader.ReadInt32(); MyClientId = messageReader.ReadInt32();
connectedClients.Add(MyClientId, new NetworkedClient() { ClientId = MyClientId }); connectedClients.Add(MyClientId, new NetworkedClient() { ClientId = MyClientId });
int clientCount = messageReader.ReadInt32(); int clientCount = messageReader.ReadInt32();
@ -825,6 +884,10 @@ namespace MLAPI
{ {
if (pendingClients.Contains(clientId)) if (pendingClients.Contains(clientId))
pendingClients.Remove(clientId); pendingClients.Remove(clientId);
if (encryptionKeys.ContainsKey(clientId))
encryptionKeys.Remove(clientId);
if (server_pendingPublicKeys.ContainsKey(clientId))
server_pendingPublicKeys.Remove(clientId);
if (connectedClients.ContainsKey(clientId)) if (connectedClients.ContainsKey(clientId))
{ {
if(NetworkConfig.HandleObjectSpawning) if(NetworkConfig.HandleObjectSpawning)
@ -873,8 +936,6 @@ namespace MLAPI
connectedClients[clientId].PlayerObject = go; connectedClients[clientId].PlayerObject = go;
} }
int sizeOfStream = 4 + 4 + ((connectedClients.Count - 1) * 4);
int amountOfObjectsToSend = 0; int amountOfObjectsToSend = 0;
foreach (KeyValuePair<uint, NetworkedObject> pair in SpawnManager.spawnedObjects) foreach (KeyValuePair<uint, NetworkedObject> pair in SpawnManager.spawnedObjects)
{ {
@ -883,16 +944,28 @@ namespace MLAPI
else else
amountOfObjectsToSend++; amountOfObjectsToSend++;
} }
if(NetworkConfig.HandleObjectSpawning)
{
sizeOfStream += 4;
sizeOfStream += 13 * amountOfObjectsToSend;
}
using (MemoryStream writeStream = new MemoryStream(sizeOfStream)) using (MemoryStream writeStream = new MemoryStream())
{ {
using (BinaryWriter writer = new BinaryWriter(writeStream)) using (BinaryWriter writer = new BinaryWriter(writeStream))
{ {
if(NetworkConfig.EnableDiffieHellman)
{
byte[] serverPubKey = server_pendingPublicKeys[clientId].ToByteArray();
writer.Write((ushort)serverPubKey.Length);
writer.Write(serverPubKey);
if(NetworkConfig.DenyUntrustedKeys)
{
//NEEDS TO BE THREADED!!!
using (RSA rsa = RSA.Create())
{
rsa.ImportParameters(NetworkConfig.ServerPrivateKey);
byte[] signedPubKey = rsa.EncryptValue(serverPubKey);
writer.Write((ushort)signedPubKey.Length);
writer.Write(signedPubKey);
}
}
}
writer.Write(clientId); writer.Write(clientId);
writer.Write(connectedClients.Count - 1); writer.Write(connectedClients.Count - 1);
foreach (KeyValuePair<int, NetworkedClient> item in connectedClients) foreach (KeyValuePair<int, NetworkedClient> item in connectedClients)
@ -917,15 +990,14 @@ namespace MLAPI
} }
} }
} }
Send(clientId, "MLAPI_CONNECTION_APPROVED", "MLAPI_RELIABLE_FRAGMENTED_SEQUENCED", writeStream.GetBuffer()); Send(clientId, "MLAPI_CONNECTION_APPROVED", "MLAPI_RELIABLE_FRAGMENTED_SEQUENCED", writeStream.ToArray());
} }
//Inform old clients of the new player //Inform old clients of the new player
int sizeOfStream = 4;
if(NetworkConfig.HandleObjectSpawning) if(NetworkConfig.HandleObjectSpawning)
sizeOfStream = 13; sizeOfStream = 13;
else
sizeOfStream = 4;
using (MemoryStream stream = new MemoryStream(sizeOfStream)) using (MemoryStream stream = new MemoryStream(sizeOfStream))
{ {