diff --git a/Client/BinaryCollector.cs b/Client/BinaryCollector.cs new file mode 100644 index 0000000..b2a08e3 --- /dev/null +++ b/Client/BinaryCollector.cs @@ -0,0 +1,192 @@ +using System; +using System.Collections.Generic; +using System.IO; +using System.Linq; +using System.Reflection; +using System.Runtime.InteropServices; +using System.Text; +using System.Threading.Tasks; + +namespace Client +{ + public sealed class BinaryCollector : IDisposable + { + // Collects reusable + private static readonly List> expired = new List>(); + + private static readonly List supportedTypes = new List() + { + typeof(bool), + typeof(byte), + typeof(sbyte), + typeof(char), + typeof(short), + typeof(ushort), + typeof(int), + typeof(uint), + typeof(long), + typeof(ulong), + typeof(float), + typeof(double), + typeof(decimal) + }; + + private static readonly FieldInfo + dec_lo, + dec_mid, + dec_hi, + dec_flags; + + static BinaryCollector() + { + dec_lo = typeof(decimal).GetField("lo", BindingFlags.NonPublic); + dec_mid = typeof(decimal).GetField("mid", BindingFlags.NonPublic); + dec_hi = typeof(decimal).GetField("hi", BindingFlags.NonPublic); + dec_flags = typeof(decimal).GetField("flags", BindingFlags.NonPublic); + } + + private object[] collect; + private readonly int bufferSize; + private int collectCount = 0; + + /// + /// Allocates a new binary collector. + /// + public BinaryCollector(int bufferSize) + { + this.bufferSize = bufferSize; + for (int i = expired.Count - 1; i >= 0; --i) + if (expired[i].TryGetTarget(out collect)) + { + if (collect.Length >= bufferSize) + { + expired.RemoveAt(i); // This entry he been un-expired for now + break; + } + } + else expired.RemoveAt(i); // Entry has been collected by GC + if (collect == null || collect.Length < bufferSize) + collect = new object[bufferSize]; + } + + public void Push(T b) + { + if (b is string || b.GetType().IsArray || IsSupportedType(b.GetType())) + collect[collectCount++] = b is string ? Encoding.UTF8.GetBytes(b as string) : b as object; + //else + // Debug.LogWarning("MLAPI: The type \"" + b.GetType() + "\" is not supported by the Binary Serializer. It will be ignored"); + } + + public byte[] ToArray() + { + long bitCount = 0; + for (int i = 0; i < collectCount; ++i) bitCount += GetBitCount(collect[i]); + + byte[] alloc = new byte[(bitCount / 8) + (bitCount % 8 == 0 ? 0 : 1)]; + long bitOffset = 0; + foreach (var item in collect) + Serialize(item, alloc, ref bitOffset); + + return alloc; + } + + private static void Serialize(T t, byte[] writeTo, ref long bitOffset) + { + Type type = t.GetType(); + if (type.IsArray) + { + var array = t as Array; + Serialize(array.Length, writeTo, ref bitOffset); + foreach (var element in array) + Serialize(element, writeTo, ref bitOffset); + } + else if (IsSupportedType(type)) + { + long offset = GetBitAllocation(type); + if (type == typeof(bool)) WriteBit(writeTo, t as bool? ?? false, bitOffset); + else if(type == typeof(decimal)) + { + WriteDynamic(writeTo, dec_lo.GetValue(t), 4, bitOffset); + WriteDynamic(writeTo, dec_mid.GetValue(t), 4, bitOffset + 32); + WriteDynamic(writeTo, dec_hi.GetValue(t), 4, bitOffset + 64); + WriteDynamic(writeTo, dec_flags.GetValue(t), 4, bitOffset + 96); + } + else if(type == typeof(float)) + { + + } + bitOffset += offset; + } + } + + private static long GetBitCount(T t) + { + Type type = t.GetType(); + long count = 0; + if (type.IsArray) + { + Type elementType = type.GetElementType(); + long allocSize = GetBitAllocation(elementType); + var array = t as Array; + + count += 2; // Int16 array size. Arrays shouldn't be syncing more than 65k elements + + if (allocSize != 0) // The array contents is known: compute the data size + count += allocSize * array.Length; + else // Unknown array contents type: iteratively assess serialization size + foreach (var element in t as Array) + count += GetBitCount(element); + } + else if(IsSupportedType(type)) count += GetBitAllocation(type); + //else + // Debug.LogWarning("MLAPI: The type \"" + b.GetType() + "\" is not supported by the Binary Serializer. It will be ignored"); + return count; + } + + private static void WriteBit(byte[] b, bool bit, long index) + => b[index / 8] = (byte)((b[index / 8] & (1 << (int)(index % 8))) | (bit ? 1 << (int)(index % 8) : 0)); + private static void WriteByte(byte[] b, byte value, long index) + { + int byteIndex = (int)(index / 8); + int shift = (int)(index % 8); + byte upper_mask = (byte)(0xFF << shift); + byte lower_mask = (byte)~upper_mask; + + b[byteIndex] = (byte)((b[byteIndex] & lower_mask) | (value << shift)); + if(shift != 0 && byteIndex + 1 < b.Length) + b[byteIndex + 1] = (byte)((b[byteIndex + 1] & upper_mask) | (value << (8 - shift))); + } + private static void WriteDynamic(byte[] b, dynamic value, int byteCount, long index) + { + for (int i = 0; i < byteCount; ++i) + WriteByte(b, (byte)((value >> (8 * i)) & 0xFF), index + (8 * i)); + } + + // Supported datatypes for serialization + private static bool IsSupportedType(Type t) => supportedTypes.Contains(t); + + // Specifies how many bits will be written + private static long GetBitAllocation(Type t) => + t == typeof(bool) ? 1 : + t == typeof(byte) ? 8 : + t == typeof(sbyte) ? 8 : + t == typeof(short) ? 16 : + t == typeof(char) ? 16 : + t == typeof(ushort) ? 16 : + t == typeof(int) ? 32 : + t == typeof(uint) ? 32 : + t == typeof(long) ? 64 : + t == typeof(ulong) ? 64 : + t == typeof(float) ? 32 : + t == typeof(double) ? 64 : + t == typeof(decimal) ? 128 : + 0; // Unknown type + + // Creates a weak reference to the allocated collector so that reuse may be possible + public void Dispose() + { + expired.Add(new WeakReference(collect)); + collect = null; + } + } +} diff --git a/Client/Client.csproj b/Client/Client.csproj index 9814e17..85807e1 100644 --- a/Client/Client.csproj +++ b/Client/Client.csproj @@ -34,6 +34,7 @@ + @@ -42,6 +43,7 @@ + diff --git a/Client/ConsoleForms/ConsoleController.cs b/Client/ConsoleForms/ConsoleController.cs index 771f448..f018520 100644 --- a/Client/ConsoleForms/ConsoleController.cs +++ b/Client/ConsoleForms/ConsoleController.cs @@ -54,7 +54,11 @@ namespace Client.ConsoleForms Draw(); }); - RegisterListener((w1, h1, w2, h2) => Console.Clear()); + RegisterListener((w1, h1, w2, h2) => + { + Console.BackgroundColor = ConsoleColor.Black; + Console.Clear(); + }); } public void AddView(View v, bool redraw = true) => AddView(v, LayoutMeta.Centering(v), redraw); @@ -194,12 +198,12 @@ namespace Client.ConsoleForms } } - private static void ClearRegion(Region r, ConsoleColor clearColor = ConsoleColor.Black) + public static void ClearRegion(Region r, ConsoleColor clearColor = ConsoleColor.Black) { foreach (var rect in r.SubRegions) ClearRegion(rect, clearColor); } - private static void ClearRegion(Rectangle rect, ConsoleColor clearColor = ConsoleColor.Black) + public static void ClearRegion(Rectangle rect, ConsoleColor clearColor = ConsoleColor.Black) { Console.BackgroundColor = clearColor; Console.ForegroundColor = ConsoleColor.White; diff --git a/Client/ConsoleForms/Graphics/InputView.cs b/Client/ConsoleForms/Graphics/InputView.cs index 9955adf..d744298 100644 --- a/Client/ConsoleForms/Graphics/InputView.cs +++ b/Client/ConsoleForms/Graphics/InputView.cs @@ -157,6 +157,7 @@ namespace Client.ConsoleForms.Graphics computedSize += splitInputs[i].Length; } ContentHeight += computedSize + Inputs.Length * 2; + ++ContentWidth; // Idk, it works, though... } protected override void _Draw(int left, ref int top) diff --git a/Client/ConsoleForms/Graphics/ListView.cs b/Client/ConsoleForms/Graphics/ListView.cs index ab84136..7abf0a6 100644 --- a/Client/ConsoleForms/Graphics/ListView.cs +++ b/Client/ConsoleForms/Graphics/ListView.cs @@ -57,10 +57,15 @@ namespace Client.ConsoleForms.Graphics if(view == innerViews[SelectedView]) { view.Item2.BackgroundColor = SelectBackground; - view.Item2.TextColor = SelectText; + //view.Item2.TextColor = SelectText; } + Region sub = new Region(new Rectangle(0, 0, ContentWidth, view.Item2.ContentHeight)).Subtract(view.Item2.Occlusion); - DrawView(left, ref top, view.Item2); + sub.Offset(left, top); + + ConsoleController.ClearRegion(sub, view.Item2.BackgroundColor); + + DrawView(left - 1, ref top, view.Item2); if (view == innerViews[SelectedView]) { diff --git a/Client/ConsoleForms/Graphics/TextView.cs b/Client/ConsoleForms/Graphics/TextView.cs index 2bf337f..3b749e4 100644 --- a/Client/ConsoleForms/Graphics/TextView.cs +++ b/Client/ConsoleForms/Graphics/TextView.cs @@ -36,7 +36,7 @@ namespace Client.ConsoleForms.Graphics Dirty = true; } } - public override Region Occlusion => new Region(new Rectangle(0, -1, ContentWidth + 2, ContentHeight)); + public override Region Occlusion => new Region(new Rectangle(0, DrawBorder ? -1 : 0, ContentWidth + (DrawBorder ? 2 : 0), ContentHeight)); //public char Border { get; set; } //public ConsoleColor BorderColor { get; set; } diff --git a/Client/ConsoleForms/Rectangle.cs b/Client/ConsoleForms/Rectangle.cs index a33965a..ceeded9 100644 --- a/Client/ConsoleForms/Rectangle.cs +++ b/Client/ConsoleForms/Rectangle.cs @@ -8,10 +8,10 @@ namespace Client.ConsoleForms { public class Rectangle { - public int Top { get; private set; } - public int Bottom { get; private set; } - public int Left { get; private set; } - public int Right { get; private set; } + public int Top { get; internal set; } + public int Bottom { get; internal set; } + public int Left { get; internal set; } + public int Right { get; internal set; } public Rectangle(int left, int top, int right, int bottom) { Left = left; @@ -20,15 +20,15 @@ namespace Client.ConsoleForms Bottom = bottom; } - public bool Intersects(Rectangle rect) => ((Left < rect.Right && Right >= rect.Left) || (Left <= rect.Right && Right > rect.Left)) && ((Top > rect.Bottom && Bottom <= rect.Top) || (Top >= rect.Bottom && Bottom < rect.Top)); + public bool Intersects(Rectangle rect) => ((Left < rect.Right && Right >= rect.Left) || (Left <= rect.Right && Right > rect.Left)) && ((Top < rect.Bottom && Bottom >= rect.Top) || (Top <= rect.Bottom && Bottom > rect.Top)); public bool Occludes(Rectangle rect) => Top >= rect.Top && Right >= rect.Right && Left >= rect.Left && Bottom >= rect.Bottom; public Rectangle GetIntersecting(Rectangle rect) => Intersects(rect) ? new Rectangle( - Left < rect.Right ? Left : rect.Left, - Bottom < rect.Top ? rect.Top : Top, - Left < rect.Right ? rect.Right : Right, - Bottom < rect.Top ? Bottom : rect.Bottom + Math.Max(Left, rect.Left), + Math.Max(rect.Top, Top), + Math.Min(rect.Right, Right), + Math.Min(Bottom, rect.Bottom) ) : null; @@ -42,11 +42,11 @@ namespace Client.ConsoleForms if (intersect.Left > Left) components[rectangles++] = new Rectangle(Left, Math.Max(intersect.Top, Top), intersect.Left, Math.Min(intersect.Bottom, Bottom)); if (intersect.Right < Right) - components[rectangles++] = new Rectangle(intersect.Right, Math.Max(intersect.Top, Top), Left, Math.Min(intersect.Bottom, Bottom)); + components[rectangles++] = new Rectangle(intersect.Right, Math.Max(intersect.Top, Top), Right, Math.Min(intersect.Bottom, Bottom)); if (intersect.Top > Top) - components[rectangles++] = new Rectangle(Math.Min(Left, intersect.Left), Top, Math.Max(Right, intersect.Right), intersect.Top); + components[rectangles++] = new Rectangle(Left, Top, Right, intersect.Top); if (intersect.Bottom < Bottom) - components[rectangles] = new Rectangle(Math.Min(Left, intersect.Left), intersect.Bottom, Math.Max(Right, intersect.Right), Bottom); + components[rectangles] = new Rectangle(Left, intersect.Bottom, Right, Bottom); return components; } diff --git a/Client/Context/SessionContext.cs b/Client/Context/SessionContext.cs index e1a0055..8394e8d 100644 --- a/Client/Context/SessionContext.cs +++ b/Client/Context/SessionContext.cs @@ -29,13 +29,16 @@ namespace Client public override void OnCreate() { - controller.AddView(views.GetNamed("Success")); + //controller.AddView(views.GetNamed("Success")); + controller.AddView(views.GetNamed("menu_options")); } public override void OnDestroy() { controller.CloseView(views.GetNamed("Success")); +#pragma warning disable CS4014 // Because this call is not awaited, execution of the current method continues before the call is completed interactor.Disconnect(); +#pragma warning restore CS4014 // Because this call is not awaited, execution of the current method continues before the call is completed } } } diff --git a/Client/Context/WelcomeContext.cs b/Client/Context/WelcomeContext.cs index 8ce522b..e42543e 100644 --- a/Client/Context/WelcomeContext.cs +++ b/Client/Context/WelcomeContext.cs @@ -167,7 +167,9 @@ namespace Client // Stop listening interactor.UnregisterListener(token); +#pragma warning disable CS4014 // Because this call is not awaited, execution of the current method continues before the call is completed if (forceDestroy) interactor.Disconnect(); +#pragma warning restore CS4014 // Because this call is not awaited, execution of the current method continues before the call is completed } } } diff --git a/Client/Networking.cs b/Client/Networking.cs index 1f303e2..9e60a83 100644 --- a/Client/Networking.cs +++ b/Client/Networking.cs @@ -1,6 +1,9 @@ -using System; +using Common; +using Common.Cryptography.KeyExchange; +using System; using System.Collections.Generic; using System.Linq; +using System.Numerics; using System.Text; using System.Threading.Tasks; using Tofvesson.Crypto; @@ -25,9 +28,9 @@ namespace Client if(checkIdentity) new Task(() => { - AuthenticatedKeys = NetClient.CheckServerIdentity(address, port, provider); + //AuthenticatedKeys = NetClient.CheckServerIdentity(address, port, provider); authenticating = false; - authenticated = AuthenticatedKeys != null; + authenticated = true;// AuthenticatedKeys != null; }).Start(); else { @@ -36,10 +39,7 @@ namespace Client } var addr = System.Net.IPAddress.Parse(address); client = new NetClient( - new Rijndael128( - Convert.ToBase64String(provider.GetBytes(64)), // 64-byte key (converted to base64) - Convert.ToBase64String(provider.GetBytes(64)) // 64-byte salt (converted to base64) - ), + EllipticDiffieHellman.Curve25519(EllipticDiffieHellman.Curve25519_GeneratePrivate(provider)), addr, port, MessageRecievedHandler, diff --git a/Client/Program.cs b/Client/Program.cs index 7542146..8867aad 100644 --- a/Client/Program.cs +++ b/Client/Program.cs @@ -18,6 +18,14 @@ namespace ConsoleForms // Set up timestamps in debug output DebugStream = new TimeStampWriter(DebugStream, "HH:mm:ss.fff"); + byte[] serialized; + + using (BinaryCollector collector = new BinaryCollector(1)) + { + collector.Push(5f); + serialized = collector.ToArray(); + } + Padding p = new AbsolutePadding(2, 2, 1, 1); @@ -52,6 +60,7 @@ namespace ConsoleForms } while (!info.ValidEvent || info.Event.Key != ConsoleKey.Escape); } + // Detects if a key has been hit without blocking [DllImport("msvcrt")] public static extern int _kbhit(); } diff --git a/Client/Resources/Layout/Session.xml b/Client/Resources/Layout/Session.xml index b4f9846..d6e7494 100644 --- a/Client/Resources/Layout/Session.xml +++ b/Client/Resources/Layout/Session.xml @@ -25,17 +25,17 @@ padding_top="1" padding_bottom="1"> - + - + + @string/SE_tx + - + + @string/SE_pwdu + \ No newline at end of file diff --git a/Common/Common.csproj b/Common/Common.csproj index f5547d8..17ac4f9 100644 --- a/Common/Common.csproj +++ b/Common/Common.csproj @@ -63,14 +63,20 @@ - - + + + + + + - + + + - + diff --git a/Common/AES.cs b/Common/Cryptography/AES.cs similarity index 100% rename from Common/AES.cs rename to Common/Cryptography/AES.cs diff --git a/Common/CBC.cs b/Common/Cryptography/CBC.cs similarity index 100% rename from Common/CBC.cs rename to Common/Cryptography/CBC.cs diff --git a/Common/Cryptography/EllipticCurve.cs b/Common/Cryptography/EllipticCurve.cs new file mode 100644 index 0000000..10cfbb3 --- /dev/null +++ b/Common/Cryptography/EllipticCurve.cs @@ -0,0 +1,199 @@ +using System; +using System.Collections.Generic; +using System.Linq; +using System.Numerics; +using System.Text; +using System.Threading.Tasks; +using Tofvesson.Crypto; + +namespace Common.Cryptography +{ + public class Point + { + public static readonly Point POINT_AT_INFINITY = new Point(); + public BigInteger X { get; private set; } + public BigInteger Y { get; private set; } + private bool pai = false; + public Point(BigInteger x, BigInteger y) + { + X = x; + Y = y; + } + private Point() { pai = true; } // Accessing corrdinates causes undocumented behaviour + public override string ToString() + { + return pai ? "(POINT_AT_INFINITY)" : "(" + X + ", " + Y + ")"; + } + } + + public class EllipticCurve + { + public enum CurveType { Weierstrass, Montgomery } + + protected readonly BigInteger a, b, modulo; + protected readonly CurveType type; + + public EllipticCurve(BigInteger a, BigInteger b, BigInteger modulo, CurveType type = CurveType.Weierstrass) + { + if ( + (type == CurveType.Weierstrass && (4 * a * a * a) + (27 * b * b) == 0) || // Unfavourable Weierstrass curves + (type == CurveType.Montgomery && b * (a * a - 4) == 0) // Unfavourable Montgomery curves + ) throw new Exception("Unfavourable curve"); + this.a = a; + this.b = b; + this.modulo = modulo; + this.type = type; + } + + public Point Add(Point p1, Point p2) + { +#if SAFE_MATH + CheckOnCurve(p1); + CheckOnCurve(p2); +#endif + + // Special cases + if (p1 == Point.POINT_AT_INFINITY && p2 == Point.POINT_AT_INFINITY) return Point.POINT_AT_INFINITY; + else if (p1 == Point.POINT_AT_INFINITY) return p2; + else if (p2 == Point.POINT_AT_INFINITY) return p1; + else if (p1.X == p2.X && p1.Y == Inverse(p2).Y) return Point.POINT_AT_INFINITY; + + BigInteger x3 = 0, y3 = 0; + if (type == CurveType.Weierstrass) + { + BigInteger slope = p1.X == p2.X && p1.Y == p2.Y ? Mod((3 * p1.X * p1.X + a) * MulInverse(2 * p1.Y)) : Mod(Mod(p2.Y - p1.Y) * MulInverse(p2.X - p1.X)); + x3 = Mod((slope * slope) - p1.X - p2.X); + y3 = Mod(-((slope * x3) + p1.Y - (slope * p1.X))); + } + else if (type == CurveType.Montgomery) + { + if ((p1.X == p2.X && p1.Y == p2.Y)) + { + BigInteger q = 3 * p1.X; + BigInteger w = q * p1.X; + + BigInteger e = 2 * a; + BigInteger r = e * p1.X; + + BigInteger t = 2 * b; + BigInteger y = t * p1.Y; + + BigInteger u = MulInverse(y); + + BigInteger o = w + e + 1; + BigInteger p = o * u; + } + BigInteger co = p1.X == p2.X && p1.Y == p2.Y ? Mod((3 * p1.X * p1.X + 2 * a * p1.X + 1) * MulInverse(2 * b * p1.Y)) : Mod(Mod(p2.Y - p1.Y) * MulInverse(p2.X - p1.X)); // Compute a commonly used coefficient + x3 = Mod(b * co * co - a - p1.X - p2.X); + y3 = Mod(((2 * p1.X + p2.X + a) * co) - (b * co * co * co) - p1.Y); + } + + return new Point(x3, y3); + } + + public Point Multiply(Point p, BigInteger scalar) + { + if (scalar <= 0) throw new Exception("Cannot multiply by a scalar which is <= 0"); + if (p == Point.POINT_AT_INFINITY) return Point.POINT_AT_INFINITY; + + Point p1 = new Point(p.X, p.Y); + long high_bit = scalar.HighestBit() - 1; + + // Double-and-add method + while (high_bit >= 0) + { + p1 = Add(p1, p1); // Double + if ((scalar.BitAt(high_bit))) + p1 = Add(p1, p); // Add + --high_bit; + } + + return p1; + } + + protected BigInteger MulInverse(BigInteger eq) => MulInverse(eq, modulo); + public static BigInteger MulInverse(BigInteger eq, BigInteger modulo) + { + eq = Mod(eq, modulo); + Stack collect = new Stack(); + BigInteger v = modulo; // Copy modulo + BigInteger m; + while ((m = v % eq) != 0) + { + collect.Push(-(v/eq)); + v = eq; + eq = m; + } + if (collect.Count == 0) return 1; + v = 1; + m = collect.Pop(); + while (collect.Count > 0) + { + eq = m; + m = v + (m * collect.Pop()); + v = eq; + } + return Mod(m, modulo); + } + + public Point Inverse(Point p) => Inverse(p, modulo); + protected static Point Inverse(Point p, BigInteger modulo) => new Point(p.X, Mod(-p.Y, modulo)); + + public bool IsOnCurve(Point p) + { + try { CheckOnCurve(p); } + catch { return false; } + return true; + } + protected void CheckOnCurve(Point p) + { + if ( + p != Point.POINT_AT_INFINITY && // The point at infinity is asserted to be on the curve + (type == CurveType.Weierstrass && Mod(p.Y * p.Y) != Mod((p.X * p.X * p.X) + (p.X * a) + b)) || // Weierstrass formula + (type == CurveType.Montgomery && Mod(b * p.Y * p.Y) != Mod((p.X * p.X * p.X) + (p.X * p.X * a) + p.X)) // Montgomery formula + ) throw new Exception("Point is not on curve"); + } + + protected BigInteger Mod(BigInteger b) => Mod(b, modulo); + + private static BigInteger Mod(BigInteger x, BigInteger m) + { + BigInteger r; ; + if (x.Abs() >= m) r = x % m; + else r = x; + return r < 0 ? r + m : r; + } + + // Efficient modular square root function + public static BigInteger ShanksTonelli(BigInteger a, BigInteger prime) + { + if (prime < 3 || ModPow(a, (prime - 1) / 2, prime) != 1) return 0; + Random rand = new Random(); + int e = 0; + while ((prime & 1) != 1) + { + prime >>= 1; + e += 1; + } + BigInteger s = prime / BigInteger.Pow(2, e); + return 0; + } + + protected static BigInteger ModPow(BigInteger x, BigInteger power, BigInteger prime) + { + BigInteger result = 1; + bool setBit = false; + while (power > 0) + { + x %= prime; + setBit = (power & 1) == 1; + power >>= 1; + if (setBit) result *= x; + x *= x; + } + + return result; + } + } + +} diff --git a/Common/Cryptography/KeyExchange/DiffieHellman.cs b/Common/Cryptography/KeyExchange/DiffieHellman.cs new file mode 100644 index 0000000..d12c210 --- /dev/null +++ b/Common/Cryptography/KeyExchange/DiffieHellman.cs @@ -0,0 +1,34 @@ +using System; +using System.Collections.Generic; +using System.Linq; +using System.Numerics; +using System.Text; +using System.Threading.Tasks; +using Tofvesson.Crypto; + +namespace Common.Cryptography.KeyExchange +{ + public sealed class DiffieHellman : IKeyExchange + { + private static readonly BigInteger EPHEMERAL_MAX = BigInteger.One << 2048; + private static readonly RandomProvider provider = new CryptoRandomProvider(); + private BigInteger priv, p, q; + private readonly BigInteger pub; + + public DiffieHellman(BigInteger p, BigInteger q) : this(provider.GenerateRandom(EPHEMERAL_MAX), p, q) { } + public DiffieHellman(BigInteger priv, BigInteger p, BigInteger q) + { + this.priv = priv; + this.p = p; + this.q = q; + pub = Support.ModExp(p, priv, q); + } + + public byte[] GetPublicKey() => pub.ToByteArray(); + + public byte[] GetSharedSecret(byte[] p) { + BigInteger pub = new BigInteger(p); + return (pub <= 0 ? (BigInteger) 0 : Support.ModExp(pub, priv, q)).ToByteArray(); + } + } +} diff --git a/Common/Cryptography/KeyExchange/EllipticDiffieHellman.cs b/Common/Cryptography/KeyExchange/EllipticDiffieHellman.cs new file mode 100644 index 0000000..7a29724 --- /dev/null +++ b/Common/Cryptography/KeyExchange/EllipticDiffieHellman.cs @@ -0,0 +1,177 @@ +using System; +using System.Collections.Generic; +using System.Linq; +using System.Numerics; +using System.Text; +using System.Threading.Tasks; +using Tofvesson.Crypto; + +namespace Common.Cryptography.KeyExchange +{ + public class EllipticDiffieHellman : IKeyExchange + { + private static readonly BigInteger c_25519_prime = (BigInteger.One << 255) - 19; + private static readonly BigInteger c_25519_order = (BigInteger.One << 252) + BigInteger.Parse("27742317777372353535851937790883648493"); // 27_742_317_777_372_353_535_851_937_790_883_648_493 + private static readonly EllipticCurve c_25519 = new EllipticCurve(486662, 1, c_25519_prime, EllipticCurve.CurveType.Montgomery); + private static readonly Point c_25519_gen = new Point(9, BigInteger.Parse("14781619447589544791020593568409986887264606134616475288964881837755586237401")); + + protected static readonly Random rand = new Random(); + + protected readonly EllipticCurve curve; + public readonly BigInteger priv; + protected readonly Point generator, pub; + + + public EllipticDiffieHellman(EllipticCurve curve, Point generator, BigInteger order, byte[] priv = null) + { + this.curve = curve; + this.generator = generator; + + // Generate private key + if (priv == null) + { + byte[] max = order.ToByteArray(); + do + { + byte[] p1 = new byte[5 /*rand.Next(max.Length) + 1*/]; + + rand.NextBytes(p1); + + if (p1.Length == max.Length) p1[p1.Length - 1] %= max[max.Length - 1]; + else p1[p1.Length - 1] &= 127; + + this.priv = new BigInteger(p1); + } while (this.priv < 2); + } + else this.priv = new BigInteger(priv); + + // Generate public key + pub = curve.Multiply(generator, this.priv); + } + + public byte[] GetPublicKey() + { + byte[] p1 = pub.X.ToByteArray(); + byte[] p2 = pub.Y.ToByteArray(); + + byte[] ser = new byte[4 + p1.Length + p2.Length]; + ser[0] = (byte)(p1.Length & 255); + ser[1] = (byte)((p1.Length >> 8) & 255); + ser[2] = (byte)((p1.Length >> 16) & 255); + ser[3] = (byte)((p1.Length >> 24) & 255); + Array.Copy(p1, 0, ser, 4, p1.Length); + Array.Copy(p2, 0, ser, 4 + p1.Length, p2.Length); + + return ser; + } + + public byte[] GetPrivateKey() => priv.ToByteArray(); + + public byte[] GetSharedSecret(byte[] pK) + { + byte[] p1 = new byte[pK[0] | (pK[1] << 8) | (pK[2] << 16) | (pK[3] << 24)]; // Reconstruct x-axis size + byte[] p2 = new byte[pK.Length - p1.Length - 4]; + Array.Copy(pK, 4, p1, 0, p1.Length); + Array.Copy(pK, 4 + p1.Length, p2, 0, p2.Length); + + Point remotePublic = new Point(new BigInteger(p1), new BigInteger(p2)); + + byte[] secret = curve.Multiply(remotePublic, priv).X.ToByteArray(); // Use the x-coordinate as the shared secret + + // SHA-1 (Common shared secret generation method) + + // Initialize buffers + uint h0 = 0x67452301; + uint h1 = 0xEFCDAB89; + uint h2 = 0x98BADCFE; + uint h3 = 0x10325476; + uint h4 = 0xC3D2E1F0; + + // Pad message + int ml = secret.Length + 1; + byte[] msg = new byte[ml + ((960 - (ml * 8 % 512)) % 512) / 8 + 8]; + Array.Copy(secret, msg, secret.Length); + msg[secret.Length] = 0x80; + long len = secret.Length * 8; + for (int i = 0; i < 8; ++i) msg[msg.Length - 1 - i] = (byte)((len >> (i * 8)) & 255); + //Support.WriteToArray(msg, message.Length * 8, msg.Length - 8); + //for (int i = 0; i <4; ++i) msg[msg.Length - 5 - i] = (byte)(((message.Length*8) >> (i * 8)) & 255); + + int chunks = msg.Length / 64; + + // Perform hashing for each 512-bit block + for (int i = 0; i < chunks; ++i) + { + + // Split block into words + uint[] w = new uint[80]; + for (int j = 0; j < 16; ++j) + w[j] |= (uint)((msg[i * 64 + j * 4] << 24) | (msg[i * 64 + j * 4 + 1] << 16) | (msg[i * 64 + j * 4 + 2] << 8) | (msg[i * 64 + j * 4 + 3] << 0)); + + // Expand words + for (int j = 16; j < 80; ++j) + w[j] = Rot(w[j - 3] ^ w[j - 8] ^ w[j - 14] ^ w[j - 16], 1); + + // Initialize chunk-hash + uint + a = h0, + b = h1, + c = h2, + d = h3, + e = h4; + + // Do hash rounds + for (int t = 0; t < 80; ++t) + { + uint tmp = ((a << 5) | (a >> (27))) + + ( // Round-function + t < 20 ? (b & c) | ((~b) & d) : + t < 40 ? b ^ c ^ d : + t < 60 ? (b & c) | (b & d) | (c & d) : + /*t<80*/ b ^ c ^ d + ) + + e + + ( // K-function + t < 20 ? 0x5A827999 : + t < 40 ? 0x6ED9EBA1 : + t < 60 ? 0x8F1BBCDC : + /*t<80*/ 0xCA62C1D6 + ) + + w[t]; + e = d; + d = c; + c = Rot(b, 30); + b = a; + a = tmp; + } + h0 += a; + h1 += b; + h2 += c; + h3 += d; + h4 += e; + } + + return WriteContiguous(new byte[20], 0, SwapEndian(h0), SwapEndian(h1), SwapEndian(h2), SwapEndian(h3), SwapEndian(h4)); + } + private static uint Rot(uint val, int by) => (val << by) | (val >> (32 - by)); + + // Swap endianness of a given integer + private static uint SwapEndian(uint value) => (uint)(((value >> 24) & (255 << 0)) | ((value >> 8) & (255 << 8)) | ((value << 8) & (255 << 16)) | ((value << 24) & (255 << 24))); + + private static byte[] WriteToArray(byte[] target, uint data, int offset) + { + for (int i = 0; i < 4; ++i) + target[i + offset] = (byte)((data >> (i * 8)) & 255); + return target; + } + + private static byte[] WriteContiguous(byte[] target, int offset, params uint[] data) + { + for (int i = 0; i < data.Length; ++i) WriteToArray(target, data[i], offset + i * 4); + return target; + } + + public static EllipticDiffieHellman Curve25519(BigInteger priv) => new EllipticDiffieHellman(c_25519, c_25519_gen, c_25519_order, priv.ToByteArray()); + public static BigInteger Curve25519_GeneratePrivate(RandomProvider provider) => Support.GenerateRandom(provider, c_25519_order - 2) + 2; + } +} diff --git a/Common/Cryptography/KeyExchange/IKeyExchange.cs b/Common/Cryptography/KeyExchange/IKeyExchange.cs new file mode 100644 index 0000000..598dd92 --- /dev/null +++ b/Common/Cryptography/KeyExchange/IKeyExchange.cs @@ -0,0 +1,15 @@ +using System; +using System.Collections.Generic; +using System.Linq; +using System.Numerics; +using System.Text; +using System.Threading.Tasks; + +namespace Common.Cryptography.KeyExchange +{ + public interface IKeyExchange + { + byte[] GetPublicKey(); + byte[] GetSharedSecret(byte[] pub); + } +} diff --git a/Common/RSA.cs b/Common/Cryptography/RSA.cs similarity index 100% rename from Common/RSA.cs rename to Common/Cryptography/RSA.cs diff --git a/Common/Net.cs b/Common/Net.cs deleted file mode 100644 index 846fe12..0000000 --- a/Common/Net.cs +++ /dev/null @@ -1,505 +0,0 @@ -using System; -using System.Collections.Generic; -using System.Diagnostics; -using System.Net; -using System.Net.Sockets; -using System.Numerics; -using System.Runtime.CompilerServices; -using System.Text; -using System.Threading; -using System.Threading.Tasks; - -namespace Tofvesson.Crypto -{ - public delegate string OnMessageRecieved(string request, Dictionary associations, ref bool stayAlive); - public delegate void OnClientConnectStateChanged(NetClient client, bool connect); - public sealed class NetServer - { - private readonly short port; - private readonly object state_lock = new object(); - private readonly List clients = new List(); - private readonly OnMessageRecieved callback; - private readonly OnClientConnectStateChanged onConn; - private readonly IPAddress ipAddress; - private Socket listener; - private readonly RSA crypto; - private readonly byte[] ser_cache; - private readonly int bufSize; - - private bool state_running = false; - private Thread listenerThread; - - - public int Count - { - get - { - return clients.Count; - } - } - - public bool Running - { - get - { - lock (state_lock) return state_running; - } - - private set - { - lock (state_lock) state_running = value; - } - } - - public NetServer(RSA crypto, short port, OnMessageRecieved callback, OnClientConnectStateChanged onConn, int bufSize = 16384) - { - this.callback = callback; - this.onConn = onConn; - this.bufSize = bufSize; - this.crypto = crypto; - this.port = port; - this.ser_cache = crypto.Serialize(); // Keep this here so we don't wastefully re-serialize every time we get a new client - - IPHostEntry ipHostInfo = Dns.GetHostEntry(Dns.GetHostName()); - this.ipAddress = ipHostInfo.GetIPV4(); - if (ipAddress == null) - ipAddress = IPAddress.Parse("127.0.0.1"); // If there was no IPv4 result in dns lookup, use loopback address - } - - public void StartListening() - { - bool isAlive = false; - object lock_await = new object(); - if(!Running && (listenerThread==null || !listenerThread.IsAlive)) - { - Running = true; - listenerThread = new Thread(() => - { - - this.listener = new Socket(ipAddress.AddressFamily, SocketType.Stream, ProtocolType.Tcp) - { - Blocking = false // When calling Accept() with no queued sockets, listener throws an exception - }; - IPEndPoint localEndPoint = new IPEndPoint(ipAddress, port); - listener.Bind(localEndPoint); - listener.Listen(100); - - byte[] buffer = new byte[bufSize]; - lock (lock_await) isAlive = true; - Stopwatch limiter = new Stopwatch(); - while (Running) - { - limiter.Start(); - // Accept clients - try - { - Socket s = listener.Accept(); - s.Blocking = false; - clients.Add(new ClientStateObject(new NetClient(s, crypto, callback, onConn), buffer)); - } - catch (Exception) - { - if(clients.Count==0) - Thread.Sleep(25); // Wait a bit before trying to accept another client - } - - // Update clients - foreach (ClientStateObject cli in clients.ToArray()) - // Ensure we are still connected to client - if (!(cli.IsConnected() && !cli.Update())) - { - cli.client.onConn(cli.client, false); - clients.Remove(cli); - continue; - } - limiter.Stop(); - if (limiter.ElapsedMilliseconds < 125) Thread.Sleep(250); // If loading data wasn't heavy, take a break - limiter.Reset(); - } - }) - { - Priority = ThreadPriority.Highest, - Name = $"NetServer-${port}" - }; - listenerThread.Start(); - } - - bool rd; - do - { - Thread.Sleep(25); - lock (lock_await) rd = isAlive; - } while (!rd); - } - - public Task StopRunning() - { - Running = false; - - return new TaskFactory().StartNew(() => - { - listenerThread.Join(); - return null; - }); - } - - private class ClientStateObject - { - internal NetClient client; - private bool hasCrypto = false; // Whether or not encrypted communication has been etablished - private Queue buffer = new Queue(); // Incoming data buffer - private int expectedSize = 0; // Expected size of next message - private readonly byte[] buf; - - public ClientStateObject(NetClient client, byte[] buf) - { - this.client = client; - this.buf = buf; - } - - public bool Update() - { - bool stop = client.SyncListener(ref hasCrypto, ref expectedSize, out bool read, buffer, buf); - return stop; - } - public bool IsConnected() => client.IsConnected; - } - } - - public class NetClient - { - private static readonly RandomProvider rp = new CryptoRandomProvider(); - - // Thread state lock for primitive values - private readonly object state_lock = new object(); - - // Primitive state values - private bool state_running = false; - - // Socket event listener - private Thread eventListener; - - // Communication parameters - protected readonly Queue messageBuffer = new Queue(); - public readonly Dictionary assignedValues = new Dictionary(); - protected readonly OnMessageRecieved handler; - protected internal readonly OnClientConnectStateChanged onConn; - protected readonly IPAddress target; - protected readonly int bufSize; - protected readonly RSA decrypt; - protected internal long lastComm = DateTime.Now.Ticks; // Latest comunication event (in ticks) - public RSA RemoteCrypto { get => decrypt; } - - // Connection to peer - protected Socket Connection { get; private set; } - - // State/connection parameters - protected Rijndael128 Crypto { get; private set; } - protected GenericCBC CBC { get; private set; } - public short Port { get; } - protected bool Running - { - get - { - lock (state_lock) return state_running; - } - private set - { - lock (state_lock) state_running = value; - } - } - - protected internal bool IsConnected - { - get - { - return Connection != null && Connection.Connected && !(Connection.Poll(1, SelectMode.SelectRead) && Connection.Available == 0); - } - } - - public bool IsAlive - { - get - { - return Running || (Connection != null && Connection.Connected) || (eventListener != null && eventListener.IsAlive); - } - } - - protected bool ServerSide { get; private set; } - - - public NetClient(Rijndael128 crypto, IPAddress target, short port, OnMessageRecieved handler, OnClientConnectStateChanged onConn, int bufSize = 16384) - { -#pragma warning disable CS0618 // Type or member is obsolete - if (target.AddressFamily==AddressFamily.InterNetwork && target.Address == 16777343) -#pragma warning restore CS0618 // Type or member is obsolete - { - IPAddress addr = Dns.GetHostEntry(Dns.GetHostName()).GetIPV4(); - if (addr != null) target = addr; - } - this.target = target; - Crypto = crypto; - if(crypto!=null) CBC = new PCBC(crypto, rp); - this.bufSize = bufSize; - this.handler = handler; - this.onConn = onConn; - Port = port; - ServerSide = false; - } - - internal NetClient(Socket sock, RSA crypto, OnMessageRecieved handler, OnClientConnectStateChanged onConn) - : this(null, ((IPEndPoint)sock.RemoteEndPoint).Address, (short) ((IPEndPoint)sock.RemoteEndPoint).Port, handler, onConn, -1) - { - decrypt = crypto; - Connection = sock; - Running = true; - ServerSide = true; - - // Initiate crypto-handshake by sending public keys - Connection.Send(NetSupport.WithHeader(crypto.Serialize())); - } - - public virtual void Connect() - { - if (ServerSide) throw new SystemException("Serverside socket cannot connect to a remote peer!"); - NetSupport.DoStateCheck(IsAlive || (eventListener != null && eventListener.IsAlive), false); - Connection = new Socket(SocketType.Stream, ProtocolType.Tcp); - Connection.Connect(target, Port); - Running = true; - eventListener = new Thread(() => - { - bool cryptoEstablished = false; - int mLen = 0; - Queue ibuf = new Queue(); - byte[] buffer = new byte[bufSize]; - Stopwatch limiter = new Stopwatch(); - while (Running) - { - limiter.Start(); - if (SyncListener(ref cryptoEstablished, ref mLen, out bool _, ibuf, buffer)) - break; - if (cryptoEstablished && DateTime.Now.Ticks >= lastComm + (5 * TimeSpan.TicksPerSecond)) - try - { - Connection.Send(NetSupport.WithHeader(new byte[0])); // Send a test packet. (Will just send an empty header to the peer) - lastComm = DateTime.Now.Ticks; - } - catch - { - break; // Connection died - } - limiter.Stop(); - if (limiter.ElapsedMilliseconds < 125) Thread.Sleep(250); // If loading data wasn't heavy, take a break - limiter.Reset(); - } - if (ibuf.Count != 0) Debug.WriteLine("Client socket closed with unread data!"); - onConn(this, false); - }) - { - Priority = ThreadPriority.Highest, - Name = $"NetClient-${target}:${Port}" - }; - eventListener.Start(); - } - - protected internal bool SyncListener(ref bool cryptoEstablished, ref int mLen, out bool acceptedData, Queue ibuf, byte[] buffer) - { - if (cryptoEstablished) - { - lock (messageBuffer) - { - foreach (byte[] message in messageBuffer) Connection.Send(NetSupport.WithHeader(message)); - if(messageBuffer.Count > 0) lastComm = DateTime.Now.Ticks; - messageBuffer.Clear(); - } - } - if (acceptedData = Connection.Available > 0) - { - int read = Connection.Receive(buffer); - ibuf.EnqueueAll(buffer, 0, read); - if (read > 0) lastComm = DateTime.Now.Ticks; - } - if (mLen == 0 && ibuf.Count >= 4) - mLen = Support.ReadInt(ibuf.Dequeue(4), 0); - if (mLen != 0 && ibuf.Count >= mLen) - { - // Got a full message. Parse! - byte[] message = ibuf.Dequeue(mLen); - lastComm = DateTime.Now.Ticks; - - if (!cryptoEstablished) - { - if (ServerSide) - { - var nonceText = new string(Encoding.UTF8.GetChars(message)); - byte[] sign; - if(nonceText.StartsWith("Nonce:") && BigInteger.TryParse(nonceText.Substring(6), out BigInteger parse) && (sign=parse.ToByteArray()).Length <= 512) - { - Connection.Send(NetSupport.WithHeader(decrypt.Encrypt(parse.ToByteArray(), null, true))); - Disconnect(); - return true; - } - - if (Crypto == null) - { - byte[] m = decrypt.Decrypt(message); - if (m.Length == 0) return false; - Crypto = Rijndael128.Deserialize(m, out int _); - } - else - { - byte[] m = decrypt.Decrypt(message); - if (m.Length == 0) return false; - CBC = new PCBC(Crypto, m); - onConn(this, true); - } - } - else - { - // Reconstruct RSA object from remote public keys and use it to encrypt our serialized AES key/iv - RSA asymm = RSA.Deserialize(message, out int _); - Connection.Send(NetSupport.WithHeader(asymm.Encrypt(Crypto.Serialize()))); - Connection.Send(NetSupport.WithHeader(asymm.Encrypt(CBC.IV))); - onConn(this, true); - } - if (CBC != null) - cryptoEstablished = true; - } - else - { - // Decrypt the incoming message - byte[] read = Crypto.Decrypt(message); - - // Read the decrypted message length - int mlenInner = Support.ReadInt(read, 0); - if (mlenInner == 0) return false; // Got a ping packet - - // Send the message to the handler and get a response - bool live = true; - string response = handler(read.SubArray(4, 4+mlenInner).ToUTF8String(), assignedValues, ref live); - - // Send the response (if given one) and drop the connection if the handler tells us to - if (response != null) Connection.Send(NetSupport.WithHeader(Crypto.Encrypt(NetSupport.WithHeader(response.ToUTF8Bytes())))); - if (!live) - { - Running = false; - try - { - Connection.Close(); - } - catch (Exception) { } - return true; - } - } - - // Reset expexted message length - mLen = 0; - } - return false; - } - - /// - /// Disconnect from server - /// - /// - public virtual async Task Disconnect() - { - NetSupport.DoStateCheck(IsAlive, true); - Running = false; - - - return await new TaskFactory().StartNew(() => { eventListener.Join(); return null; }); - } - - // Methods for sending data to the server - public bool TrySend(string message) => TrySend(Encoding.UTF8.GetBytes(message)); - public bool TrySend(byte[] message) - { - try - { - Send(message); - return true; - } - catch (InvalidOperationException) { return false; } - } - public virtual void Send(string message) => Send(Encoding.UTF8.GetBytes(message)); - public virtual void Send(byte[] message) { - NetSupport.DoStateCheck(IsAlive, true); - lock (messageBuffer) messageBuffer.Enqueue(Crypto.Encrypt(NetSupport.WithHeader(message))); - } - - public static RSA CheckServerIdentity(string host, short port, RandomProvider provider, long timeout = 10000) - { - Socket sock = new Socket(SocketType.Stream, ProtocolType.Tcp) - { - ReceiveTimeout = 5000, - SendTimeout = 5000 - }; - sock.Blocking = false; - sock.Connect(host, port); - List read = new List(); - byte[] buf = new byte[1024]; - - if (!Read(sock, read, buf, timeout)) return null; - read.RemoveRange(0, 4); - RSA remote; - try - { - remote = RSA.Deserialize(read.ToArray(), out int _); - } - catch { return null; } - BigInteger cmp; - sock.Send(NetSupport.WithHeader(Encoding.UTF8.GetBytes("Nonce:"+(cmp=BigInteger.Abs(new BigInteger(provider.GetBytes(128))))))); - Thread.Sleep(250); // Give the server ample time to compute the signature - read.Clear(); - if (!Read(sock, read, buf, timeout)) return null; - read.RemoveRange(0, 4); - try - { - if (!cmp.Equals(new BigInteger(remote.Encrypt(read.ToArray())))) return null; - } - catch { return null; } - return remote; // Passed signature check - } - - private static bool Read(Socket sock, List read, byte[] buf, long timeout) - { - Stopwatch sw = new Stopwatch(); - int len = -1; - sw.Start(); - while ((len == -1 || read.Count < 4) && (sw.ElapsedTicks / 10000) < timeout) - { - if (len == -1 && read.Count > 4) - len = Support.ReadInt(read, 0); - - try - { - int r = sock.Receive(buf); - read.AddRange(buf.SubArray(0, r)); - } - catch { } - } - sw.Stop(); - return read.Count - 4 == len && len>0; - } - } - - // Helper methods. WithHeader() should really just be in Support.cs - public static class NetSupport - { - public static byte[] WithHeader(string message) => WithHeader(Encoding.UTF8.GetBytes(message)); - public static byte[] WithHeader(byte[] message) - { - byte[] nmsg = new byte[message.Length + 4]; - Support.WriteToArray(nmsg, message.Length, 0); - Array.Copy(message, 0, nmsg, 4, message.Length); - return nmsg; - } - - public static byte[] FromHeaded(byte[] msg, int offset) => msg.SubArray(offset + 4, offset + 4 + Support.ReadInt(msg, offset)); - - internal static void DoStateCheck(bool state, bool target) { - if (state != target) throw new InvalidOperationException("Bad state!"); - } - } -} diff --git a/Common/NetClient.cs b/Common/NetClient.cs new file mode 100644 index 0000000..0fcff21 --- /dev/null +++ b/Common/NetClient.cs @@ -0,0 +1,271 @@ +using Common.Cryptography.KeyExchange; +using System; +using System.Collections.Generic; +using System.Diagnostics; +using System.Linq; +using System.Net; +using System.Net.Sockets; +using System.Numerics; +using System.Text; +using System.Threading; +using System.Threading.Tasks; +using Tofvesson.Crypto; + +namespace Common +{ + public delegate string OnMessageRecieved(string request, Dictionary associations, ref bool stayAlive); + public delegate void OnClientConnectStateChanged(NetClient client, bool connect); + + public class NetClient + { + private static readonly RandomProvider rp = new CryptoRandomProvider(); + + // Thread state lock for primitive values + private readonly object state_lock = new object(); + + // Primitive state values + private bool state_running = false; + + // Socket event listener + private Thread eventListener; + + // Communication parameters + protected readonly Queue messageBuffer = new Queue(); + public readonly Dictionary assignedValues = new Dictionary(); + protected readonly OnMessageRecieved handler; + protected internal readonly OnClientConnectStateChanged onConn; + protected readonly IPAddress target; + protected readonly int bufSize; + protected readonly IKeyExchange exchange; + protected internal long lastComm = DateTime.Now.Ticks; // Latest comunication event (in ticks) + + // Connection to peer + protected Socket Connection { get; private set; } + + // State/connection parameters + protected Rijndael128 Crypto { get; private set; } + protected GenericCBC CBC { get; private set; } + public short Port { get; } + protected bool Running + { + get + { + lock (state_lock) return state_running; + } + private set + { + lock (state_lock) state_running = value; + } + } + + protected internal bool IsConnected + { + get + { + return Connection != null && Connection.Connected && !(Connection.Poll(1, SelectMode.SelectRead) && Connection.Available == 0); + } + } + + public bool IsAlive + { + get + { + return Running || (Connection != null && Connection.Connected) || (eventListener != null && eventListener.IsAlive); + } + } + + protected bool ServerSide { get; private set; } + + + public NetClient(IKeyExchange exchange, IPAddress target, short port, OnMessageRecieved handler, OnClientConnectStateChanged onConn, int bufSize = 16384) + { +#pragma warning disable CS0618 // Type or member is obsolete + if (target.AddressFamily == AddressFamily.InterNetwork && target.Address == 16777343) +#pragma warning restore CS0618 // Type or member is obsolete + { + IPAddress addr = Dns.GetHostEntry(Dns.GetHostName()).GetIPV4(); + if (addr != null) target = addr; + } + this.target = target; + this.exchange = exchange; + this.bufSize = bufSize; + this.handler = handler; + this.onConn = onConn; + Port = port; + ServerSide = false; + } + + internal NetClient(IKeyExchange exchange, Socket sock, OnMessageRecieved handler, OnClientConnectStateChanged onConn) + : this(exchange, ((IPEndPoint)sock.RemoteEndPoint).Address, (short)((IPEndPoint)sock.RemoteEndPoint).Port, handler, onConn, -1) + { + Connection = sock; + Running = true; + ServerSide = true; + + // Initiate crypto-handshake by sending public keys + Connection.Send(NetSupport.WithHeader(exchange.GetPublicKey())); + } + + public virtual void Connect() + { + if (ServerSide) throw new SystemException("Serverside socket cannot connect to a remote peer!"); + NetSupport.DoStateCheck(IsAlive || (eventListener != null && eventListener.IsAlive), false); + Connection = new Socket(SocketType.Stream, ProtocolType.Tcp); + Connection.Connect(target, Port); + Running = true; + eventListener = new Thread(() => + { + bool cryptoEstablished = false; + int mLen = 0; + Queue ibuf = new Queue(); + byte[] buffer = new byte[bufSize]; + Stopwatch limiter = new Stopwatch(); + while (Running) + { + limiter.Start(); + if (SyncListener(ref cryptoEstablished, ref mLen, out bool _, ibuf, buffer)) + break; + if (cryptoEstablished && DateTime.Now.Ticks >= lastComm + (5 * TimeSpan.TicksPerSecond)) + try + { + Connection.Send(NetSupport.WithHeader(new byte[0])); // Send a test packet. (Will just send an empty header to the peer) + lastComm = DateTime.Now.Ticks; + } + catch + { + break; // Connection died + } + limiter.Stop(); + if (limiter.ElapsedMilliseconds < 125) Thread.Sleep(250); // If loading data wasn't heavy, take a break + limiter.Reset(); + } + if (ibuf.Count != 0) Debug.WriteLine("Client socket closed with unread data!"); + onConn(this, false); + }) + { + Priority = ThreadPriority.Highest, + Name = $"NetClient-${target}:${Port}" + }; + eventListener.Start(); + } + + protected internal bool SyncListener(ref bool cryptoEstablished, ref int mLen, out bool acceptedData, Queue ibuf, byte[] buffer) + { + if (cryptoEstablished) + { + lock (messageBuffer) + { + foreach (byte[] message in messageBuffer) Connection.Send(NetSupport.WithHeader(message)); + if (messageBuffer.Count > 0) lastComm = DateTime.Now.Ticks; + messageBuffer.Clear(); + } + } + if (acceptedData = Connection.Available > 0) + { + int read = Connection.Receive(buffer); + ibuf.EnqueueAll(buffer, 0, read); + if (read > 0) lastComm = DateTime.Now.Ticks; + } + if (mLen == 0 && ibuf.Count >= 4) + mLen = Support.ReadInt(ibuf.Dequeue(4), 0); + if (mLen != 0 && ibuf.Count >= mLen) + { + // Got a full message. Parse! + byte[] message = ibuf.Dequeue(mLen); + lastComm = DateTime.Now.Ticks; + + if (!cryptoEstablished) + { + if (!ServerSide) Connection.Send(NetSupport.WithHeader(exchange.GetPublicKey())); + if (message.Length == 0) return false; + Crypto = new Rijndael128(exchange.GetSharedSecret(message).ToHexString()); + CBC = new PCBC(Crypto, rp); + cryptoEstablished = true; + onConn(this, true); + } + else + { + // Decrypt the incoming message + byte[] read = Crypto.Decrypt(message); + + // Read the decrypted message length + int mlenInner = Support.ReadInt(read, 0); + if (mlenInner == 0) return false; // Got a ping packet + + // Send the message to the handler and get a response + bool live = true; + string response = handler(read.SubArray(4, 4 + mlenInner).ToUTF8String(), assignedValues, ref live); + + // Send the response (if given one) and drop the connection if the handler tells us to + if (response != null) Connection.Send(NetSupport.WithHeader(Crypto.Encrypt(NetSupport.WithHeader(response.ToUTF8Bytes())))); + if (!live) + { + Running = false; + try + { + Connection.Close(); + } + catch (Exception) { } + return true; + } + } + + // Reset expexted message length + mLen = 0; + } + return false; + } + + /// + /// Disconnect from server + /// + /// + public virtual async Task Disconnect() + { + NetSupport.DoStateCheck(IsAlive, true); + Running = false; + + + return await new TaskFactory().StartNew(() => { eventListener.Join(); return null; }); + } + + // Methods for sending data to the server + public bool TrySend(string message) => TrySend(Encoding.UTF8.GetBytes(message)); + public bool TrySend(byte[] message) + { + try + { + Send(message); + return true; + } + catch (InvalidOperationException) { return false; } + } + public virtual void Send(string message) => Send(Encoding.UTF8.GetBytes(message)); + public virtual void Send(byte[] message) + { + NetSupport.DoStateCheck(IsAlive, true); + lock (messageBuffer) messageBuffer.Enqueue(Crypto.Encrypt(NetSupport.WithHeader(message))); + } + + private static bool Read(Socket sock, List read, byte[] buf, long timeout) + { + Stopwatch sw = new Stopwatch(); + int len = -1; + sw.Start(); + while ((len == -1 || read.Count < 4) && (sw.ElapsedTicks / 10000) < timeout) + { + if (len == -1 && read.Count > 4) + len = Support.ReadInt(read, 0); + + try + { + int r = sock.Receive(buf); + read.AddRange(buf.SubArray(0, r)); + } + catch { } + } + sw.Stop(); + return read.Count - 4 == len && len > 0; + } + } +} diff --git a/Common/NetServer.cs b/Common/NetServer.cs new file mode 100644 index 0000000..0cd397d --- /dev/null +++ b/Common/NetServer.cs @@ -0,0 +1,165 @@ +using Common.Cryptography.KeyExchange; +using System; +using System.Collections.Generic; +using System.Diagnostics; +using System.Linq; +using System.Net; +using System.Net.Sockets; +using System.Text; +using System.Threading; +using System.Threading.Tasks; +using Tofvesson.Crypto; + +namespace Common +{ + public sealed class NetServer + { + private readonly short port; + private readonly object state_lock = new object(); + private readonly List clients = new List(); + private readonly OnMessageRecieved callback; + private readonly OnClientConnectStateChanged onConn; + private readonly IPAddress ipAddress; + private Socket listener; + private readonly IKeyExchange exchange; + private readonly int bufSize; + + private bool state_running = false; + private Thread listenerThread; + + + public int Count + { + get + { + return clients.Count; + } + } + + public bool Running + { + get + { + lock (state_lock) return state_running; + } + + private set + { + lock (state_lock) state_running = value; + } + } + + public NetServer(IKeyExchange exchange, short port, OnMessageRecieved callback, OnClientConnectStateChanged onConn, int bufSize = 16384) + { + this.callback = callback; + this.onConn = onConn; + this.bufSize = bufSize; + this.exchange = exchange; + this.port = port; + + IPHostEntry ipHostInfo = Dns.GetHostEntry(Dns.GetHostName()); + this.ipAddress = ipHostInfo.GetIPV4(); + if (ipAddress == null) + ipAddress = IPAddress.Parse("127.0.0.1"); // If there was no IPv4 result in dns lookup, use loopback address + } + + public void StartListening() + { + bool isAlive = false; + object lock_await = new object(); + if (!Running && (listenerThread == null || !listenerThread.IsAlive)) + { + Running = true; + listenerThread = new Thread(() => + { + + this.listener = new Socket(ipAddress.AddressFamily, SocketType.Stream, ProtocolType.Tcp) + { + Blocking = false // When calling Accept() with no queued sockets, listener throws an exception + }; + IPEndPoint localEndPoint = new IPEndPoint(ipAddress, port); + listener.Bind(localEndPoint); + listener.Listen(100); + + byte[] buffer = new byte[bufSize]; + lock (lock_await) isAlive = true; + Stopwatch limiter = new Stopwatch(); + while (Running) + { + limiter.Start(); + // Accept clients + try + { + Socket s = listener.Accept(); + s.Blocking = false; + clients.Add(new ClientStateObject(new NetClient(exchange, s, callback, onConn), buffer)); + } + catch (Exception) + { + if (clients.Count == 0) + Thread.Sleep(25); // Wait a bit before trying to accept another client + } + + // Update clients + foreach (ClientStateObject cli in clients.ToArray()) + // Ensure we are still connected to client + if (!(cli.IsConnected() && !cli.Update())) + { + cli.client.onConn(cli.client, false); + clients.Remove(cli); + continue; + } + limiter.Stop(); + if (limiter.ElapsedMilliseconds < 125) Thread.Sleep(250); // If loading data wasn't heavy, take a break + limiter.Reset(); + } + }) + { + Priority = ThreadPriority.Highest, + Name = $"NetServer-${port}" + }; + listenerThread.Start(); + } + + bool rd; + do + { + Thread.Sleep(25); + lock (lock_await) rd = isAlive; + } while (!rd); + } + + public Task StopRunning() + { + Running = false; + + return new TaskFactory().StartNew(() => + { + listenerThread.Join(); + return null; + }); + } + + private class ClientStateObject + { + internal NetClient client; + private bool hasCrypto = false; // Whether or not encrypted communication has been etablished + private Queue buffer = new Queue(); // Incoming data buffer + private int expectedSize = 0; // Expected size of next message + private readonly byte[] buf; + + public ClientStateObject(NetClient client, byte[] buf) + { + this.client = client; + this.buf = buf; + } + + public bool Update() + { + bool stop = client.SyncListener(ref hasCrypto, ref expectedSize, out bool read, buffer, buf); + return stop; + } + public bool IsConnected() => client.IsConnected; + } + } +} diff --git a/Common/NetSupport.cs b/Common/NetSupport.cs new file mode 100644 index 0000000..d84bee7 --- /dev/null +++ b/Common/NetSupport.cs @@ -0,0 +1,29 @@ +using System; +using System.Collections.Generic; +using System.Linq; +using System.Text; +using System.Threading.Tasks; +using Tofvesson.Crypto; + +namespace Common +{ + // Helper methods. WithHeader() should really just be in Support.cs + public static class NetSupport + { + public static byte[] WithHeader(string message) => WithHeader(Encoding.UTF8.GetBytes(message)); + public static byte[] WithHeader(byte[] message) + { + byte[] nmsg = new byte[message.Length + 4]; + Support.WriteToArray(nmsg, message.Length, 0); + Array.Copy(message, 0, nmsg, 4, message.Length); + return nmsg; + } + + public static byte[] FromHeaded(byte[] msg, int offset) => msg.SubArray(offset + 4, offset + 4 + Support.ReadInt(msg, offset)); + + internal static void DoStateCheck(bool state, bool target) + { + if (state != target) throw new InvalidOperationException("Bad state!"); + } + } +} diff --git a/Common/Support.cs b/Common/Support.cs index d8f5124..2b7fa55 100644 --- a/Common/Support.cs +++ b/Common/Support.cs @@ -18,15 +18,6 @@ namespace Tofvesson.Crypto public static class Support { // -- Math -- - public static BigInteger Invert(BigInteger b) - { - byte[] arr = b.ToByteArray(); - for (int i = 0; i < arr.Length; ++i) arr[i] ^= 255; - BigInteger integer = new BigInteger(arr); - integer += 1; - return integer; - } - public static BigInteger ModExp(BigInteger b, BigInteger e, BigInteger m) { int count = e.ToByteArray().Length * 8; @@ -42,6 +33,36 @@ namespace Tofvesson.Crypto return result; } + public static BigInteger GenerateRandom(this RandomProvider provider, BigInteger bound) + { + byte[] b = bound.ToByteArray(); + if (b.Length == 0) return 0; + byte b1 = b[b.Length - 1]; + + provider.GetBytes(b); + b[b.Length - 1] %= b1; + return new BigInteger(b); + } + + public static long HighestBit(this BigInteger b) + { + byte[] b1 = b.ToByteArray(); + for (int i = b1.Length - 1; i >= 0; --i) + if (b1[i] != 0) + for (int j = 7; j >= 0; --j) + if ((b1[i] & (1 << j)) != 0) + return i * 8 + j; + return -1; + } + + public static bool BitAt(this BigInteger b, long idx) + { + byte[] b1 = b.ToByteArray(); + return (b1[(int)(idx / 8)] & (1 << ((int)(idx % 8)))) != 0; + } + + public static BigInteger Abs(this BigInteger b) => b < 0 ? -b : b; + /// /// Uses the fermat test a given amount of times to test whether or not a supplied interger is probably prime. /// @@ -280,6 +301,17 @@ namespace Tofvesson.Crypto return result; } + public static string ToHexString(this byte[] value, bool bigEndian = true) + { + StringBuilder builder = new StringBuilder(); + for (int i = bigEndian ? value.Length - 1 : 0; (bigEndian && i >= 0) || (!bigEndian && i < value.Length); i += bigEndian ? -1 : 1) + { + builder.Append((char)((((value[i] >> 4) < 10) ? 48 : 87) + (value[i] >> 4))); + builder.Append((char)((((value[i] & 15) < 10) ? 48 : 87) + (value[i] & 15))); + } + return builder.ToString(); + } + public static void ArrayCopy(IEnumerable source, int sourceOffset, T[] destination, int offset, int length) { for (int i = 0; i < length; ++i) destination[i + offset] = source.ElementAt(i+sourceOffset); diff --git a/Server/Program.cs b/Server/Program.cs index a511251..d3164db 100644 --- a/Server/Program.cs +++ b/Server/Program.cs @@ -1,4 +1,5 @@ using Common; +using Common.Cryptography.KeyExchange; using Server.Properties; using System; using System.Collections.Generic; @@ -24,19 +25,19 @@ namespace Server CryptoRandomProvider random = new CryptoRandomProvider(); - RSA rsa = null;// new RSA(Resources.e_0x200, Resources.n_0x200, Resources.d_0x200); - if (rsa == null) - { - Console.ForegroundColor = ConsoleColor.Red; - Console.Error.WriteLine("No RSA keys available! Server identity will not be verifiable!"); - Console.ForegroundColor = ConsoleColor.Gray; - Console.WriteLine("Generating session-specific RSA-keys..."); - rsa = new RSA(64, 8, 8, 5); - Console.WriteLine("Done!"); - } + //RSA rsa = null;// new RSA(Resources.e_0x200, Resources.n_0x200, Resources.d_0x200); + //if (rsa == null) + //{ + // Console.ForegroundColor = ConsoleColor.Red; + // Console.Error.WriteLine("No RSA keys available! Server identity will not be verifiable!"); + // Console.ForegroundColor = ConsoleColor.Gray; + // Console.WriteLine("Generating session-specific RSA-keys..."); + // rsa = new RSA(64, 8, 8, 5); + // Console.WriteLine("Done!"); + //} NetServer server = new NetServer( - rsa, + EllipticDiffieHellman.Curve25519(EllipticDiffieHellman.Curve25519_GeneratePrivate(random)), 80, (string r, Dictionary associations, ref bool s) => { diff --git a/Server/Server.csproj b/Server/Server.csproj index 3fbdbe6..1baaa49 100644 --- a/Server/Server.csproj +++ b/Server/Server.csproj @@ -34,6 +34,7 @@ +