diff --git a/Client/Context/WelcomeContext.cs b/Client/Context/WelcomeContext.cs index 064dc9d..68a87fc 100644 --- a/Client/Context/WelcomeContext.cs +++ b/Client/Context/WelcomeContext.cs @@ -44,14 +44,13 @@ namespace Client { // Authenticate against server here Show("AuthWait"); - Task prom = interactor.Authenticate(i.Inputs[0].Text, i.Inputs[1].Text); - if(!prom.IsCompleted) prom.RunSynchronously(); - promise = prom.Result; + promise = Promise.AwaitPromise(interactor.Authenticate(i.Inputs[0].Text, i.Inputs[1].Text)); + //promise = prom.Result; promise.Subscribe = response => { Hide("AuthWait"); - if (response.Value.Equals("ERROR")) + if (response.Value.StartsWith("ERROR") || response.Value.Equals("False")) // Auth failure or general error Show("AuthError"); else { @@ -122,7 +121,7 @@ namespace Client else Show("EmptyFieldError"); }; - ((InputView)views.GetNamed("Register")).InputListener = (v, c, i) => + GetView("Register").InputListener = (v, c, i) => { c.BackgroundColor = v.DefaultBackgroundColor; c.SelectBackgroundColor = v.DefaultSelectBackgroundColor; @@ -132,10 +131,11 @@ namespace Client public override void OnCreate() { - token = interactor.RegisterListener((c, s) => - { - if(!s) controller.Popup("The connection to the server was severed! ", 4500, ConsoleColor.DarkRed, () => manager.LoadContext(new NetContext(manager))); - }); + // This was set up back when the connection was persistent + //token = interactor.RegisterListener((c, s) => + //{ + // if(!s) controller.Popup("The connection to the server was severed! ", 4500, ConsoleColor.DarkRed, () => manager.LoadContext(new NetContext(manager))); + //}); // Add the initial view Show("WelcomeScreen"); diff --git a/Client/Networking.cs b/Client/Networking.cs index 4b3b2c0..0c699ca 100644 --- a/Client/Networking.cs +++ b/Client/Networking.cs @@ -262,7 +262,8 @@ namespace Client } public static Promise AwaitPromise(Task p) { - if (!p.IsCompleted) p.RunSynchronously(); + //if (!p.IsCompleted) p.RunSynchronously(); + p.Wait(); return p.Result; } } diff --git a/Common/Cryptography/KeyExchange/EllipticDiffieHellman.cs b/Common/Cryptography/KeyExchange/EllipticDiffieHellman.cs index 285c6c8..99e1153 100644 --- a/Common/Cryptography/KeyExchange/EllipticDiffieHellman.cs +++ b/Common/Cryptography/KeyExchange/EllipticDiffieHellman.cs @@ -4,6 +4,7 @@ using System.Linq; using System.Numerics; using System.Text; using System.Threading.Tasks; +using Tofvesson.Common; using Tofvesson.Crypto; namespace Common.Cryptography.KeyExchange @@ -51,30 +52,25 @@ namespace Common.Cryptography.KeyExchange public byte[] GetPublicKey() { - byte[] p1 = pub.X.ToByteArray(); - byte[] p2 = pub.Y.ToByteArray(); - - byte[] ser = new byte[4 + p1.Length + p2.Length]; - ser[0] = (byte)(p1.Length & 255); - ser[1] = (byte)((p1.Length >> 8) & 255); - ser[2] = (byte)((p1.Length >> 16) & 255); - ser[3] = (byte)((p1.Length >> 24) & 255); - Array.Copy(p1, 0, ser, 4, p1.Length); - Array.Copy(p2, 0, ser, 4 + p1.Length, p2.Length); - - return ser; + using (BitWriter writer = new BitWriter()) + { + writer.WriteByteArray(pub.X.ToByteArray()); + writer.WriteByteArray(pub.Y.ToByteArray(), true); + return writer.Finalize(); + } } public byte[] GetPrivateKey() => priv.ToByteArray(); public byte[] GetSharedSecret(byte[] pK) { - byte[] p1 = new byte[pK[0] | (pK[1] << 8) | (pK[2] << 16) | (pK[3] << 24)]; // Reconstruct x-axis size - byte[] p2 = new byte[pK.Length - p1.Length - 4]; - Array.Copy(pK, 4, p1, 0, p1.Length); - Array.Copy(pK, 4 + p1.Length, p2, 0, p2.Length); + BitReader reader = new BitReader(pK); - Point remotePublic = new Point(new BigInteger(p1), new BigInteger(p2)); + byte[] x = reader.ReadByteArray(); + Point remotePublic = new Point( + new BigInteger(x), + new BigInteger(reader.ReadByteArray(pK.Length - BinaryHelpers.VarIntSize(x.Length) - x.Length)) + ); return curve.Multiply(remotePublic, priv).X.ToByteArray(); // Use the x-coordinate as the shared secret } diff --git a/Common/NetClient.cs b/Common/NetClient.cs index 4cc0369..c2c917d 100644 --- a/Common/NetClient.cs +++ b/Common/NetClient.cs @@ -168,7 +168,19 @@ namespace Common if (read > 0) lastComm = DateTime.Now.Ticks; } if (mLen == 0 && BinaryHelpers.TryReadVarInt(ibuf, 0, out mLen)) + { ibuf.Dequeue(BinaryHelpers.VarIntSize(mLen)); + if(mLen > 65535) // Problematic message size. Just drop connection + { + Running = false; + try + { + Connection.Close(); + } + catch { } + return true; + } + } if (mLen != 0 && ibuf.Count >= mLen) { // Got a full message. Parse! @@ -179,7 +191,20 @@ namespace Common { if (!ServerSide) Connection.Send(NetSupport.WithHeader(exchange.GetPublicKey())); if (message.Length == 0) return false; - Crypto = new Rijndael128(exchange.GetSharedSecret(message).ToHexString()); + try + { + Crypto = new Rijndael128(exchange.GetSharedSecret(message).ToHexString()); + } + catch + { + Running = false; + try + { + Connection.Close(); + } + catch { } + return true; + } CBC = new PCBC(Crypto, rp); cryptoEstablished = true; onConn(this, true); @@ -187,7 +212,21 @@ namespace Common else { // Decrypt the incoming message - byte[] read = Crypto.Decrypt(message); + byte[] read; + try + { + read = Crypto.Decrypt(message); + } + catch // Presumably, something weird happened that wasn't expected. Just drop it... + { + Running = false; + try + { + Connection.Close(); + } + catch { } + return true; + } // Read the decrypted message length int mlenInner = (int) BinaryHelpers.ReadVarInt(read, 0); diff --git a/Common/SHA.cs b/Common/SHA.cs index b9aa4dd..ec0ce9c 100644 --- a/Common/SHA.cs +++ b/Common/SHA.cs @@ -31,15 +31,15 @@ namespace Tofvesson.Crypto //for (int i = 0; i <4; ++i) msg[msg.Length - 5 - i] = (byte)(((message.Length*8) >> (i * 8)) & 255); int chunks = msg.Length / 64; - - // Split block into words (allocated out here to prevent massive garbage buildup) - uint[] w = new uint[80]; - + // Perform hashing for each 512-bit block for (int i = 0; i (byte)((idx < 4 ? i0 : idx < 8 ? i1 : idx < 12 ? i2 : idx < 16 ? i3 : i4)>>(8*(idx%4))); } + private static readonly uint[] block = new uint[80]; public static SHA1Result SHA1_Opt(byte[] message) { SHA1Result result = new SHA1Result @@ -114,18 +115,26 @@ namespace Tofvesson.Crypto int chunks = max / 64; // Replaces the recurring allocation of 80 uints - uint ComputeIndex(int block, int idx) + /*uint ComputeIndex(int block, int idx) { if (idx < 16) return (uint)((GetMsg(block * 64 + idx * 4) << 24) | (GetMsg(block * 64 + idx * 4 + 1) << 16) | (GetMsg(block * 64 + idx * 4 + 2) << 8) | (GetMsg(block * 64 + idx * 4 + 3) << 0)); else return Rot(ComputeIndex(block, idx - 3) ^ ComputeIndex(block, idx - 8) ^ ComputeIndex(block, idx - 14) ^ ComputeIndex(block, idx - 16), 1); - } + }*/ // Perform hashing for each 512-bit block for (int i = 0; i < chunks; ++i) { + // Compute initial source data from padded message + for (int j = 0; j < 16; ++j) + block[j] = (uint)((GetMsg(i * 64 + j * 4) << 24) | (GetMsg(i * 64 + j * 4 + 1) << 16) | (GetMsg(i * 64 + j * 4 + 2) << 8) | (GetMsg(i * 64 + j * 4 + 3) << 0)); + + // Expand words + for (int j = 16; j < 80; ++j) + block[j] = Rot(block[j - 3] ^ block[j - 8] ^ block[j - 14] ^ block[j - 16], 1); + // Initialize chunk-hash uint a = result.i0, @@ -137,7 +146,7 @@ namespace Tofvesson.Crypto // Do hash rounds for (int t = 0; t < 80; ++t) { - uint tmp = Rot(a, 5) + func(t, b, c, d) + e + K(t) + ComputeIndex(i, t); + uint tmp = Rot(a, 5) + func(t, b, c, d) + e + K(t) + block[i]; e = d; d = c; c = Rot(b, 30); diff --git a/Common/Support.cs b/Common/Support.cs index 69206c9..97d66df 100644 --- a/Common/Support.cs +++ b/Common/Support.cs @@ -397,6 +397,7 @@ namespace Tofvesson.Crypto return array; } + public static bool EqualsIgnoreCase(this string s, string s1) => s.ToLower().Equals(s1.ToLower()); public static string ToUTF8String(this byte[] b) => new string(Encoding.UTF8.GetChars(b)); public static byte[] ToUTF8Bytes(this string s) => Encoding.UTF8.GetBytes(s); diff --git a/Server/Database.cs b/Server/Database.cs index 480d6af..d39039f 100644 --- a/Server/Database.cs +++ b/Server/Database.cs @@ -50,7 +50,6 @@ namespace Server public void AddUser(User entry) => AddUser(entry, true); private void AddUser(User entry, bool withFlush) { - entry = ToEncoded(entry); for (int i = 0; i < loadedUsers.Count; ++i) if (entry.Equals(loadedUsers[i])) loadedUsers[i] = entry; @@ -117,7 +116,7 @@ namespace Server wroteNode = false; if (reader.Name.Equals("User")) { - User u = User.Parse(ReadEntry(reader), this); + User u = FromEncoded(User.Parse(ReadEntry(reader), this)); if (u != null) { bool shouldWrite = true; @@ -176,6 +175,7 @@ namespace Server private static void WriteUser(XmlWriter writer, User u) { + u = ToEncoded(u); writer.WriteStartElement("User"); if (u.IsAdministrator) writer.WriteAttributeString("admin", "", "true"); writer.WriteElementString("Name", u.Name); @@ -238,16 +238,15 @@ namespace Server public User FirstUser(Predicate p) { if (p == null) return null; // Done to conveniently handle system insertions - User u; foreach (var entry in loadedUsers) - if (p(u=FromEncoded(entry))) - return u; + if (p(entry)) + return entry; foreach (var entry in changeList) - if (p(u=FromEncoded(entry))) + if (p(entry)) { - if (!loadedUsers.Contains(entry)) loadedUsers.Add(u); - return u; + if (!loadedUsers.Contains(entry)) loadedUsers.Add(entry); + return entry; } using (var reader = XmlReader.Create(DatabaseName)) @@ -257,8 +256,8 @@ namespace Server { if (reader.Name.Equals("User")) { - User n = User.Parse(ReadEntry(reader), this); - if (n != null && p(FromEncoded(n))) + User n = FromEncoded(User.Parse(ReadEntry(reader), this)); + if (n != null && p(n)) { if (!loadedUsers.Contains(n)) loadedUsers.Add(n); return n; @@ -288,12 +287,12 @@ namespace Server Transaction tx = new Transaction(from == null ? "System" : from.Name, to.Name, amount, message, fromAccount, toAccount); toAcc.History.Add(tx); toAcc.balance += amount; - AddUser(to); + AddUser(to, false); if (from != null) { fromAcc.History.Add(tx); fromAcc.balance -= amount; - AddUser(from); + AddUser(from, false); } return true; } @@ -301,20 +300,23 @@ namespace Server public User[] Users(Predicate p) { List l = new List(); - User u; foreach (var entry in changeList) - if (p(u=FromEncoded(entry))) + if (p(entry)) + l.Add(entry); + + foreach(var entry in loadedUsers) + if (!l.Contains(entry) && p(entry)) l.Add(entry); using (var reader = XmlReader.Create(DatabaseName)) { if (!Traverse(reader, MasterEntry)) return null; - - while (SkipSpaces(reader) && reader.NodeType != XmlNodeType.EndElement) + + while (((reader.NodeType==XmlNodeType.Element && reader.Name.Equals("User")) || SkipSpaces(reader)) && reader.NodeType != XmlNodeType.EndElement) { if (reader.NodeType == XmlNodeType.EndElement) break; User e = User.Parse(ReadEntry(reader), this); - if (e!=null && p(e=FromEncoded(e))) l.Add(e); + if (e!=null && !l.Contains(e = FromEncoded(e)) && p(e)) l.Add(e); } } return l.ToArray(); diff --git a/Server/Output.cs b/Server/Output.cs index 7ce345c..ed9c5fb 100644 --- a/Server/Output.cs +++ b/Server/Output.cs @@ -1,38 +1,81 @@ -using System; -using System.Collections.Generic; +using Common; +using System; using System.IO; -using System.Linq; -using System.Text; -using System.Threading.Tasks; namespace Server { public static class Output { - public static void WriteLine(string message, bool error = false) + // Fancy timestamped output + private static readonly TextWriter stampedError = new TimeStampWriter(Console.Error, "HH:mm:ss.fff"); + private static readonly TextWriter stampedOutput = new TimeStampWriter(Console.Out, "HH:mm:ss.fff"); + private static bool overwrite = false; + + public static Action OnNewLine { get; set; } + + public static void WriteLine(object message, bool error = false, bool timeStamp = true) { - if (error) Error(message); - else Info(message); + if (error) Error(message, true, timeStamp); + else Info(message, true, timeStamp); } - public static void Write(string message, bool error) + public static void Write(object message, bool error = false, bool timeStamp = true) { - if (error) Error(message, false); - else Info(message, false); + if (error) Error(message, false, timeStamp); + else Info(message, false, timeStamp); } - public static void Positive(string message, bool newline = true) => Write(message, ConsoleColor.DarkGreen, ConsoleColor.Black, newline, Console.Out); - public static void Info(string message, bool newline = true) => Write(message, ConsoleColor.Gray, ConsoleColor.Black, newline, Console.Out); - public static void Error(string message, bool newline = true) => Write(message, ConsoleColor.Gray, ConsoleColor.Black, newline, Console.Out); - public static void Fatal(string message, bool newline = true) => Write(message, ConsoleColor.Gray, ConsoleColor.Black, newline, Console.Error); + + public static void WriteOverwritable(string message) + { + Info(message, false, false); + overwrite = true; + } + + public static void Raw(object message, bool newline = true) => Info(message, newline, false); + public static void RawErr(object message, bool newline = true) => Error(message, newline, false); + + public static void Positive(object message, bool newline = true, bool timeStamp = true) => + Write(message == null ? "null" : message.ToString(), ConsoleColor.DarkGreen, ConsoleColor.Black, newline, timeStamp ? stampedOutput : Console.Out); + public static void Info(object message, bool newline = true, bool timeStamp = true) => + Write(message == null ? "null" : message.ToString(), ConsoleColor.Gray, ConsoleColor.Black, newline, timeStamp ? stampedOutput : Console.Out); + public static void Error(object message, bool newline = true, bool timeStamp = true) => + Write(message == null ? "null" : message.ToString(), ConsoleColor.Red, ConsoleColor.Black, newline, timeStamp ? stampedOutput : Console.Out); + public static void Fatal(object message, bool newline = true, bool timeStamp = true) => + Write(message == null ? "null" : message.ToString(), ConsoleColor.Red, ConsoleColor.White, newline, timeStamp ? stampedError : Console.Error); private static void Write(string message, ConsoleColor f, ConsoleColor b, bool newline, TextWriter writer) { + if (overwrite) ClearLine(); + overwrite = false; ConsoleColor f1 = Console.ForegroundColor, b1 = Console.BackgroundColor; Console.ForegroundColor = f; Console.BackgroundColor = b; writer.Write(message); - if (newline) writer.WriteLine(); + if (newline) + { + writer.WriteLine(); + OnNewLine?.Invoke(); + } Console.ForegroundColor = f1; Console.BackgroundColor = b1; } + + public static string ReadLine() + { + string s = Console.ReadLine(); + overwrite = false; + OnNewLine?.Invoke(); + return s; + } + + private static void ClearLine(int from = 0) + { + from = Math.Min(from, Console.WindowWidth); + int y = Console.CursorTop; + Console.SetCursorPosition(from, y); + char[] msg = new char[Console.WindowWidth - from]; + for (int i = 0; i < msg.Length; ++i) msg[i] = ' '; + Console.Write(new string(msg)); + Console.SetCursorPosition(from, y); + } } } diff --git a/Server/OutputFormatter.cs b/Server/OutputFormatter.cs new file mode 100644 index 0000000..ef7d7f8 --- /dev/null +++ b/Server/OutputFormatter.cs @@ -0,0 +1,58 @@ +using System; +using System.Collections.Generic; +using System.Linq; +using System.Text; +using System.Threading.Tasks; + +namespace Server +{ + public sealed class OutputFormatter + { + private readonly List> lines = new List>(); + private int leftLen = 0; + private readonly int minPad; + private readonly string prepend, delimiter, postpad, trail; + + + public OutputFormatter(int minPad = 1, string prepend = "", string delimiter = "", string postpad = "", string trail = "") + { + this.prepend = prepend; + this.delimiter = delimiter; + this.postpad = postpad; + this.trail = trail; + this.minPad = Math.Abs(minPad); + } + + public OutputFormatter Append(string key, string value) + { + lines.Add(new Tuple(key, value)); + leftLen = Math.Max(key.Length + minPad, leftLen); + return this; + } + + public string GetString() + { + StringBuilder builder = new StringBuilder(); + foreach (var line in lines) + builder + .Append(prepend) + .Append(line.Item1) + .Append(delimiter) + .Append(Pad(line.Item1, leftLen)) + .Append(postpad) + .Append(line.Item2) + .Append(trail) + .Append('\n'); + builder.Length -= 1; + return builder.ToString(); + } + + private static string Pad(string msg, int length) + { + if (msg.Length >= length) return ""; + char[] c = new char[length - msg.Length]; + for (int i = 0; i < c.Length; ++i) c[i] = ' '; + return new string(c); + } + } +} diff --git a/Server/Program.cs b/Server/Program.cs index cdae912..47f1ffb 100644 --- a/Server/Program.cs +++ b/Server/Program.cs @@ -17,10 +17,6 @@ namespace Server private const string VERBOSE_RESPONSE = "@string/REMOTE_"; public static void Main(string[] args) { - // Set up fancy output - Console.SetError(new TimeStampWriter(Console.Error, "HH:mm:ss.fff")); - Console.SetOut(new TimeStampWriter(Console.Out, "HH:mm:ss.fff")); - // Create a client session manager and allow sessions to remain valid for up to 5 minutes of inactivity (300 seconds) SessionManager manager = new SessionManager(300 * TimeSpan.TicksPerSecond, 20); @@ -180,7 +176,7 @@ namespace Server return GenerateResponse(id, "ERROR"); } user.accounts.Add(new Database.Account(user, 0, name)); - db.AddUser(user); // Notify database of the update + db.UpdateUser(user); // Notify database of the update return GenerateResponse(id, true); } case "Account_Transaction_Create": @@ -302,12 +298,88 @@ namespace Server (c, b) => // Called every time a client connects or disconnects (conn + dc with every command/request) { // Output.Info($"Client has {(b ? "C" : "Disc")}onnected"); - if(!b && c.assignedValues.ContainsKey("session")) - manager.Expire(c.assignedValues["session"]); + //if(!b && c.assignedValues.ContainsKey("session")) + // manager.Expire(c.assignedValues["session"]); }); server.StartListening(); - - Console.ReadLine(); + + + string commands = + new OutputFormatter(4, " ", "", "- ") + .Append("help", "Show this help menu") + .Append("stop", "Stop server") + .Append("sessions", "Show active client sessions") + .Append("list {admin}", "Show registered users. Add \"admin\" to only list admins") + .Append("admin [user] {true/false}", "Show or set admin status for a user") + .GetString(); + + Output.OnNewLine = () => Output.WriteOverwritable(">> "); + Output.OnNewLine(); + // Server command loop + while (true) + { + string cmd = Output.ReadLine(); + string[] parts = cmd.Split(); + + if (cmd.EqualsIgnoreCase("stop")) break; + else if (cmd.EqualsIgnoreCase("sessions")) + { + StringBuilder builder = new StringBuilder(); + manager.Update(); // Ensure that we don't show expired sessions (artifacts exist until it is necessary to remove them) + foreach (var session in manager.Sessions) + builder.Append(session.user.Name).Append(" : ").Append(session.sessionID).Append('\n'); + if (builder.Length == 0) builder.Append("There are no active sessions at the moment"); + else builder.Length = builder.Length - 1; + Output.Raw(builder); + } + else if (parts[0].EqualsIgnoreCase("admin")) + { + if (parts.Length == 1) Output.Raw("Usage: admin [username] {true/false}"); + else if (parts.Length == 2) + { + Database.User user = db.GetUser(parts[1]); + if (user == null) Output.RawErr($"User \"{parts[1]}\" could not be found in the databse!"); + else Output.Raw(user.IsAdministrator); + } + else if (parts.Length == 3) + { + Database.User user = db.GetUser(parts[1]); + if (user == null) Output.RawErr($"User \"{parts[1]}\" could not be found in the databse!"); + else if (!bool.TryParse(parts[2].ToLower(), out bool admin)) Output.RawErr($"Could not interpret \"{parts[2]}\""); + else + { + if (user.IsAdministrator == admin) Output.Info("The given administrator state was already set"); + else if (admin) Output.Raw("User is now an administrator"); + else Output.Raw("User is no longer an administrator"); + user.IsAdministrator = admin; + db.AddUser(user); + } + } + else Output.RawErr("Too many parameters!"); + } + else if (parts[0].EqualsIgnoreCase("list")) + { + if (parts.Length > 2) Output.RawErr("Too many parameters!"); + else + { + bool filter = parts.Length > 1, filterAdmin = filter && parts[1].EqualsIgnoreCase("admin"); + + StringBuilder builder = new StringBuilder(); + foreach (var user in db.Users(u => !filter || (filterAdmin && u.IsAdministrator))) + builder.Append(user.Name).Append('\n'); + if (builder.Length != 0) + { + builder.Length = builder.Length - 1; + Output.Raw(builder); + } + } + } + else if (cmd.EqualsIgnoreCase("help")) + { + Output.Raw("Available commands:\n" + commands); + } + else if (cmd.Length != 0) Output.RawErr("Unknown command. Use command \"help\" to view available commands"); + } server.StopRunning(); } diff --git a/Server/Server.csproj b/Server/Server.csproj index 73fdd3d..058329b 100644 --- a/Server/Server.csproj +++ b/Server/Server.csproj @@ -45,6 +45,7 @@ + diff --git a/Server/SessionManager.cs b/Server/SessionManager.cs index 268ebc2..7d823d1 100644 --- a/Server/SessionManager.cs +++ b/Server/SessionManager.cs @@ -11,6 +11,8 @@ namespace Server private readonly long timeout; private readonly int sidLength; + public List Sessions { get => sessions; } + public SessionManager(long timeout, int sidLength = 10) { this.timeout = timeout;