Fixed some issues with asynchronous code

Updated some old code to use new methods
ECDH implementation now uses BitWriter/BitReader to serialize/deserialize
Added error handling to NetClient for predictable cases of error
Fixed regular SHA1 implementation
Partially remade optimized SHA1 implementation as a hybrid implementation between minimum allocation and minimum processor overhead
Fixed how Databse manages serialization/deserialization of Users
Updated Output class to support any type
Output now supports overwritable lines
Added OutputFormatter to simplify creating-column output
Sessions keys are no longer invalidated when client-server connection is closed
Added command system to server for easier administration
This commit is contained in:
Gabriel Tofvesson 2018-04-26 23:51:41 +02:00
parent 100f5a32be
commit 98a6557000
12 changed files with 303 additions and 79 deletions

View File

@ -44,14 +44,13 @@ namespace Client
{
// Authenticate against server here
Show("AuthWait");
Task<Promise> prom = interactor.Authenticate(i.Inputs[0].Text, i.Inputs[1].Text);
if(!prom.IsCompleted) prom.RunSynchronously();
promise = prom.Result;
promise = Promise.AwaitPromise(interactor.Authenticate(i.Inputs[0].Text, i.Inputs[1].Text));
//promise = prom.Result;
promise.Subscribe =
response =>
{
Hide("AuthWait");
if (response.Value.Equals("ERROR"))
if (response.Value.StartsWith("ERROR") || response.Value.Equals("False")) // Auth failure or general error
Show("AuthError");
else
{
@ -122,7 +121,7 @@ namespace Client
else Show("EmptyFieldError");
};
((InputView)views.GetNamed("Register")).InputListener = (v, c, i) =>
GetView<InputView>("Register").InputListener = (v, c, i) =>
{
c.BackgroundColor = v.DefaultBackgroundColor;
c.SelectBackgroundColor = v.DefaultSelectBackgroundColor;
@ -132,10 +131,11 @@ namespace Client
public override void OnCreate()
{
token = interactor.RegisterListener((c, s) =>
{
if(!s) controller.Popup("The connection to the server was severed! ", 4500, ConsoleColor.DarkRed, () => manager.LoadContext(new NetContext(manager)));
});
// This was set up back when the connection was persistent
//token = interactor.RegisterListener((c, s) =>
//{
// if(!s) controller.Popup("The connection to the server was severed! ", 4500, ConsoleColor.DarkRed, () => manager.LoadContext(new NetContext(manager)));
//});
// Add the initial view
Show("WelcomeScreen");

View File

@ -262,7 +262,8 @@ namespace Client
}
public static Promise AwaitPromise(Task<Promise> p)
{
if (!p.IsCompleted) p.RunSynchronously();
//if (!p.IsCompleted) p.RunSynchronously();
p.Wait();
return p.Result;
}
}

View File

@ -4,6 +4,7 @@ using System.Linq;
using System.Numerics;
using System.Text;
using System.Threading.Tasks;
using Tofvesson.Common;
using Tofvesson.Crypto;
namespace Common.Cryptography.KeyExchange
@ -51,30 +52,25 @@ namespace Common.Cryptography.KeyExchange
public byte[] GetPublicKey()
{
byte[] p1 = pub.X.ToByteArray();
byte[] p2 = pub.Y.ToByteArray();
byte[] ser = new byte[4 + p1.Length + p2.Length];
ser[0] = (byte)(p1.Length & 255);
ser[1] = (byte)((p1.Length >> 8) & 255);
ser[2] = (byte)((p1.Length >> 16) & 255);
ser[3] = (byte)((p1.Length >> 24) & 255);
Array.Copy(p1, 0, ser, 4, p1.Length);
Array.Copy(p2, 0, ser, 4 + p1.Length, p2.Length);
return ser;
using (BitWriter writer = new BitWriter())
{
writer.WriteByteArray(pub.X.ToByteArray());
writer.WriteByteArray(pub.Y.ToByteArray(), true);
return writer.Finalize();
}
}
public byte[] GetPrivateKey() => priv.ToByteArray();
public byte[] GetSharedSecret(byte[] pK)
{
byte[] p1 = new byte[pK[0] | (pK[1] << 8) | (pK[2] << 16) | (pK[3] << 24)]; // Reconstruct x-axis size
byte[] p2 = new byte[pK.Length - p1.Length - 4];
Array.Copy(pK, 4, p1, 0, p1.Length);
Array.Copy(pK, 4 + p1.Length, p2, 0, p2.Length);
BitReader reader = new BitReader(pK);
Point remotePublic = new Point(new BigInteger(p1), new BigInteger(p2));
byte[] x = reader.ReadByteArray();
Point remotePublic = new Point(
new BigInteger(x),
new BigInteger(reader.ReadByteArray(pK.Length - BinaryHelpers.VarIntSize(x.Length) - x.Length))
);
return curve.Multiply(remotePublic, priv).X.ToByteArray(); // Use the x-coordinate as the shared secret
}

View File

@ -168,7 +168,19 @@ namespace Common
if (read > 0) lastComm = DateTime.Now.Ticks;
}
if (mLen == 0 && BinaryHelpers.TryReadVarInt(ibuf, 0, out mLen))
{
ibuf.Dequeue(BinaryHelpers.VarIntSize(mLen));
if(mLen > 65535) // Problematic message size. Just drop connection
{
Running = false;
try
{
Connection.Close();
}
catch { }
return true;
}
}
if (mLen != 0 && ibuf.Count >= mLen)
{
// Got a full message. Parse!
@ -179,7 +191,20 @@ namespace Common
{
if (!ServerSide) Connection.Send(NetSupport.WithHeader(exchange.GetPublicKey()));
if (message.Length == 0) return false;
Crypto = new Rijndael128(exchange.GetSharedSecret(message).ToHexString());
try
{
Crypto = new Rijndael128(exchange.GetSharedSecret(message).ToHexString());
}
catch
{
Running = false;
try
{
Connection.Close();
}
catch { }
return true;
}
CBC = new PCBC(Crypto, rp);
cryptoEstablished = true;
onConn(this, true);
@ -187,7 +212,21 @@ namespace Common
else
{
// Decrypt the incoming message
byte[] read = Crypto.Decrypt(message);
byte[] read;
try
{
read = Crypto.Decrypt(message);
}
catch // Presumably, something weird happened that wasn't expected. Just drop it...
{
Running = false;
try
{
Connection.Close();
}
catch { }
return true;
}
// Read the decrypted message length
int mlenInner = (int) BinaryHelpers.ReadVarInt(read, 0);

View File

@ -31,15 +31,15 @@ namespace Tofvesson.Crypto
//for (int i = 0; i <4; ++i) msg[msg.Length - 5 - i] = (byte)(((message.Length*8) >> (i * 8)) & 255);
int chunks = msg.Length / 64;
// Split block into words (allocated out here to prevent massive garbage buildup)
uint[] w = new uint[80];
// Perform hashing for each 512-bit block
for (int i = 0; i<chunks; ++i)
{
// Split block into words (allocated out here to prevent massive garbage buildup)
uint[] w = new uint[80];
// Compute initial source data from padded message
for(int j = 0; j<16; ++j)
for (int j = 0; j<16; ++j)
w[j] |= (uint) ((msg[i * 64 + j * 4] << 24) | (msg[i * 64 + j * 4 + 1] << 16) | (msg[i * 64 + j * 4 + 2] << 8) | (msg[i * 64 + j * 4 + 3] << 0));
// Expand words
@ -81,6 +81,7 @@ namespace Tofvesson.Crypto
public uint i0, i1, i2, i3, i4;
public byte Get(int idx) => (byte)((idx < 4 ? i0 : idx < 8 ? i1 : idx < 12 ? i2 : idx < 16 ? i3 : i4)>>(8*(idx%4)));
}
private static readonly uint[] block = new uint[80];
public static SHA1Result SHA1_Opt(byte[] message)
{
SHA1Result result = new SHA1Result
@ -114,18 +115,26 @@ namespace Tofvesson.Crypto
int chunks = max / 64;
// Replaces the recurring allocation of 80 uints
uint ComputeIndex(int block, int idx)
/*uint ComputeIndex(int block, int idx)
{
if (idx < 16)
return (uint)((GetMsg(block * 64 + idx * 4) << 24) | (GetMsg(block * 64 + idx * 4 + 1) << 16) | (GetMsg(block * 64 + idx * 4 + 2) << 8) | (GetMsg(block * 64 + idx * 4 + 3) << 0));
else
return Rot(ComputeIndex(block, idx - 3) ^ ComputeIndex(block, idx - 8) ^ ComputeIndex(block, idx - 14) ^ ComputeIndex(block, idx - 16), 1);
}
}*/
// Perform hashing for each 512-bit block
for (int i = 0; i < chunks; ++i)
{
// Compute initial source data from padded message
for (int j = 0; j < 16; ++j)
block[j] = (uint)((GetMsg(i * 64 + j * 4) << 24) | (GetMsg(i * 64 + j * 4 + 1) << 16) | (GetMsg(i * 64 + j * 4 + 2) << 8) | (GetMsg(i * 64 + j * 4 + 3) << 0));
// Expand words
for (int j = 16; j < 80; ++j)
block[j] = Rot(block[j - 3] ^ block[j - 8] ^ block[j - 14] ^ block[j - 16], 1);
// Initialize chunk-hash
uint
a = result.i0,
@ -137,7 +146,7 @@ namespace Tofvesson.Crypto
// Do hash rounds
for (int t = 0; t < 80; ++t)
{
uint tmp = Rot(a, 5) + func(t, b, c, d) + e + K(t) + ComputeIndex(i, t);
uint tmp = Rot(a, 5) + func(t, b, c, d) + e + K(t) + block[i];
e = d;
d = c;
c = Rot(b, 30);

View File

@ -397,6 +397,7 @@ namespace Tofvesson.Crypto
return array;
}
public static bool EqualsIgnoreCase(this string s, string s1) => s.ToLower().Equals(s1.ToLower());
public static string ToUTF8String(this byte[] b) => new string(Encoding.UTF8.GetChars(b));
public static byte[] ToUTF8Bytes(this string s) => Encoding.UTF8.GetBytes(s);

View File

@ -50,7 +50,6 @@ namespace Server
public void AddUser(User entry) => AddUser(entry, true);
private void AddUser(User entry, bool withFlush)
{
entry = ToEncoded(entry);
for (int i = 0; i < loadedUsers.Count; ++i)
if (entry.Equals(loadedUsers[i]))
loadedUsers[i] = entry;
@ -117,7 +116,7 @@ namespace Server
wroteNode = false;
if (reader.Name.Equals("User"))
{
User u = User.Parse(ReadEntry(reader), this);
User u = FromEncoded(User.Parse(ReadEntry(reader), this));
if (u != null)
{
bool shouldWrite = true;
@ -176,6 +175,7 @@ namespace Server
private static void WriteUser(XmlWriter writer, User u)
{
u = ToEncoded(u);
writer.WriteStartElement("User");
if (u.IsAdministrator) writer.WriteAttributeString("admin", "", "true");
writer.WriteElementString("Name", u.Name);
@ -238,16 +238,15 @@ namespace Server
public User FirstUser(Predicate<User> p)
{
if (p == null) return null; // Done to conveniently handle system insertions
User u;
foreach (var entry in loadedUsers)
if (p(u=FromEncoded(entry)))
return u;
if (p(entry))
return entry;
foreach (var entry in changeList)
if (p(u=FromEncoded(entry)))
if (p(entry))
{
if (!loadedUsers.Contains(entry)) loadedUsers.Add(u);
return u;
if (!loadedUsers.Contains(entry)) loadedUsers.Add(entry);
return entry;
}
using (var reader = XmlReader.Create(DatabaseName))
@ -257,8 +256,8 @@ namespace Server
{
if (reader.Name.Equals("User"))
{
User n = User.Parse(ReadEntry(reader), this);
if (n != null && p(FromEncoded(n)))
User n = FromEncoded(User.Parse(ReadEntry(reader), this));
if (n != null && p(n))
{
if (!loadedUsers.Contains(n)) loadedUsers.Add(n);
return n;
@ -288,12 +287,12 @@ namespace Server
Transaction tx = new Transaction(from == null ? "System" : from.Name, to.Name, amount, message, fromAccount, toAccount);
toAcc.History.Add(tx);
toAcc.balance += amount;
AddUser(to);
AddUser(to, false);
if (from != null)
{
fromAcc.History.Add(tx);
fromAcc.balance -= amount;
AddUser(from);
AddUser(from, false);
}
return true;
}
@ -301,20 +300,23 @@ namespace Server
public User[] Users(Predicate<User> p)
{
List<User> l = new List<User>();
User u;
foreach (var entry in changeList)
if (p(u=FromEncoded(entry)))
if (p(entry))
l.Add(entry);
foreach(var entry in loadedUsers)
if (!l.Contains(entry) && p(entry))
l.Add(entry);
using (var reader = XmlReader.Create(DatabaseName))
{
if (!Traverse(reader, MasterEntry)) return null;
while (SkipSpaces(reader) && reader.NodeType != XmlNodeType.EndElement)
while (((reader.NodeType==XmlNodeType.Element && reader.Name.Equals("User")) || SkipSpaces(reader)) && reader.NodeType != XmlNodeType.EndElement)
{
if (reader.NodeType == XmlNodeType.EndElement) break;
User e = User.Parse(ReadEntry(reader), this);
if (e!=null && p(e=FromEncoded(e))) l.Add(e);
if (e!=null && !l.Contains(e = FromEncoded(e)) && p(e)) l.Add(e);
}
}
return l.ToArray();

View File

@ -1,38 +1,81 @@
using System;
using System.Collections.Generic;
using Common;
using System;
using System.IO;
using System.Linq;
using System.Text;
using System.Threading.Tasks;
namespace Server
{
public static class Output
{
public static void WriteLine(string message, bool error = false)
// Fancy timestamped output
private static readonly TextWriter stampedError = new TimeStampWriter(Console.Error, "HH:mm:ss.fff");
private static readonly TextWriter stampedOutput = new TimeStampWriter(Console.Out, "HH:mm:ss.fff");
private static bool overwrite = false;
public static Action OnNewLine { get; set; }
public static void WriteLine(object message, bool error = false, bool timeStamp = true)
{
if (error) Error(message);
else Info(message);
if (error) Error(message, true, timeStamp);
else Info(message, true, timeStamp);
}
public static void Write(string message, bool error)
public static void Write(object message, bool error = false, bool timeStamp = true)
{
if (error) Error(message, false);
else Info(message, false);
if (error) Error(message, false, timeStamp);
else Info(message, false, timeStamp);
}
public static void Positive(string message, bool newline = true) => Write(message, ConsoleColor.DarkGreen, ConsoleColor.Black, newline, Console.Out);
public static void Info(string message, bool newline = true) => Write(message, ConsoleColor.Gray, ConsoleColor.Black, newline, Console.Out);
public static void Error(string message, bool newline = true) => Write(message, ConsoleColor.Gray, ConsoleColor.Black, newline, Console.Out);
public static void Fatal(string message, bool newline = true) => Write(message, ConsoleColor.Gray, ConsoleColor.Black, newline, Console.Error);
public static void WriteOverwritable(string message)
{
Info(message, false, false);
overwrite = true;
}
public static void Raw(object message, bool newline = true) => Info(message, newline, false);
public static void RawErr(object message, bool newline = true) => Error(message, newline, false);
public static void Positive(object message, bool newline = true, bool timeStamp = true) =>
Write(message == null ? "null" : message.ToString(), ConsoleColor.DarkGreen, ConsoleColor.Black, newline, timeStamp ? stampedOutput : Console.Out);
public static void Info(object message, bool newline = true, bool timeStamp = true) =>
Write(message == null ? "null" : message.ToString(), ConsoleColor.Gray, ConsoleColor.Black, newline, timeStamp ? stampedOutput : Console.Out);
public static void Error(object message, bool newline = true, bool timeStamp = true) =>
Write(message == null ? "null" : message.ToString(), ConsoleColor.Red, ConsoleColor.Black, newline, timeStamp ? stampedOutput : Console.Out);
public static void Fatal(object message, bool newline = true, bool timeStamp = true) =>
Write(message == null ? "null" : message.ToString(), ConsoleColor.Red, ConsoleColor.White, newline, timeStamp ? stampedError : Console.Error);
private static void Write(string message, ConsoleColor f, ConsoleColor b, bool newline, TextWriter writer)
{
if (overwrite) ClearLine();
overwrite = false;
ConsoleColor f1 = Console.ForegroundColor, b1 = Console.BackgroundColor;
Console.ForegroundColor = f;
Console.BackgroundColor = b;
writer.Write(message);
if (newline) writer.WriteLine();
if (newline)
{
writer.WriteLine();
OnNewLine?.Invoke();
}
Console.ForegroundColor = f1;
Console.BackgroundColor = b1;
}
public static string ReadLine()
{
string s = Console.ReadLine();
overwrite = false;
OnNewLine?.Invoke();
return s;
}
private static void ClearLine(int from = 0)
{
from = Math.Min(from, Console.WindowWidth);
int y = Console.CursorTop;
Console.SetCursorPosition(from, y);
char[] msg = new char[Console.WindowWidth - from];
for (int i = 0; i < msg.Length; ++i) msg[i] = ' ';
Console.Write(new string(msg));
Console.SetCursorPosition(from, y);
}
}
}

58
Server/OutputFormatter.cs Normal file
View File

@ -0,0 +1,58 @@
using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using System.Threading.Tasks;
namespace Server
{
public sealed class OutputFormatter
{
private readonly List<Tuple<string, string>> lines = new List<Tuple<string, string>>();
private int leftLen = 0;
private readonly int minPad;
private readonly string prepend, delimiter, postpad, trail;
public OutputFormatter(int minPad = 1, string prepend = "", string delimiter = "", string postpad = "", string trail = "")
{
this.prepend = prepend;
this.delimiter = delimiter;
this.postpad = postpad;
this.trail = trail;
this.minPad = Math.Abs(minPad);
}
public OutputFormatter Append(string key, string value)
{
lines.Add(new Tuple<string, string>(key, value));
leftLen = Math.Max(key.Length + minPad, leftLen);
return this;
}
public string GetString()
{
StringBuilder builder = new StringBuilder();
foreach (var line in lines)
builder
.Append(prepend)
.Append(line.Item1)
.Append(delimiter)
.Append(Pad(line.Item1, leftLen))
.Append(postpad)
.Append(line.Item2)
.Append(trail)
.Append('\n');
builder.Length -= 1;
return builder.ToString();
}
private static string Pad(string msg, int length)
{
if (msg.Length >= length) return "";
char[] c = new char[length - msg.Length];
for (int i = 0; i < c.Length; ++i) c[i] = ' ';
return new string(c);
}
}
}

View File

@ -17,10 +17,6 @@ namespace Server
private const string VERBOSE_RESPONSE = "@string/REMOTE_";
public static void Main(string[] args)
{
// Set up fancy output
Console.SetError(new TimeStampWriter(Console.Error, "HH:mm:ss.fff"));
Console.SetOut(new TimeStampWriter(Console.Out, "HH:mm:ss.fff"));
// Create a client session manager and allow sessions to remain valid for up to 5 minutes of inactivity (300 seconds)
SessionManager manager = new SessionManager(300 * TimeSpan.TicksPerSecond, 20);
@ -180,7 +176,7 @@ namespace Server
return GenerateResponse(id, "ERROR");
}
user.accounts.Add(new Database.Account(user, 0, name));
db.AddUser(user); // Notify database of the update
db.UpdateUser(user); // Notify database of the update
return GenerateResponse(id, true);
}
case "Account_Transaction_Create":
@ -302,12 +298,88 @@ namespace Server
(c, b) => // Called every time a client connects or disconnects (conn + dc with every command/request)
{
// Output.Info($"Client has {(b ? "C" : "Disc")}onnected");
if(!b && c.assignedValues.ContainsKey("session"))
manager.Expire(c.assignedValues["session"]);
//if(!b && c.assignedValues.ContainsKey("session"))
// manager.Expire(c.assignedValues["session"]);
});
server.StartListening();
Console.ReadLine();
string commands =
new OutputFormatter(4, " ", "", "- ")
.Append("help", "Show this help menu")
.Append("stop", "Stop server")
.Append("sessions", "Show active client sessions")
.Append("list {admin}", "Show registered users. Add \"admin\" to only list admins")
.Append("admin [user] {true/false}", "Show or set admin status for a user")
.GetString();
Output.OnNewLine = () => Output.WriteOverwritable(">> ");
Output.OnNewLine();
// Server command loop
while (true)
{
string cmd = Output.ReadLine();
string[] parts = cmd.Split();
if (cmd.EqualsIgnoreCase("stop")) break;
else if (cmd.EqualsIgnoreCase("sessions"))
{
StringBuilder builder = new StringBuilder();
manager.Update(); // Ensure that we don't show expired sessions (artifacts exist until it is necessary to remove them)
foreach (var session in manager.Sessions)
builder.Append(session.user.Name).Append(" : ").Append(session.sessionID).Append('\n');
if (builder.Length == 0) builder.Append("There are no active sessions at the moment");
else builder.Length = builder.Length - 1;
Output.Raw(builder);
}
else if (parts[0].EqualsIgnoreCase("admin"))
{
if (parts.Length == 1) Output.Raw("Usage: admin [username] {true/false}");
else if (parts.Length == 2)
{
Database.User user = db.GetUser(parts[1]);
if (user == null) Output.RawErr($"User \"{parts[1]}\" could not be found in the databse!");
else Output.Raw(user.IsAdministrator);
}
else if (parts.Length == 3)
{
Database.User user = db.GetUser(parts[1]);
if (user == null) Output.RawErr($"User \"{parts[1]}\" could not be found in the databse!");
else if (!bool.TryParse(parts[2].ToLower(), out bool admin)) Output.RawErr($"Could not interpret \"{parts[2]}\"");
else
{
if (user.IsAdministrator == admin) Output.Info("The given administrator state was already set");
else if (admin) Output.Raw("User is now an administrator");
else Output.Raw("User is no longer an administrator");
user.IsAdministrator = admin;
db.AddUser(user);
}
}
else Output.RawErr("Too many parameters!");
}
else if (parts[0].EqualsIgnoreCase("list"))
{
if (parts.Length > 2) Output.RawErr("Too many parameters!");
else
{
bool filter = parts.Length > 1, filterAdmin = filter && parts[1].EqualsIgnoreCase("admin");
StringBuilder builder = new StringBuilder();
foreach (var user in db.Users(u => !filter || (filterAdmin && u.IsAdministrator)))
builder.Append(user.Name).Append('\n');
if (builder.Length != 0)
{
builder.Length = builder.Length - 1;
Output.Raw(builder);
}
}
}
else if (cmd.EqualsIgnoreCase("help"))
{
Output.Raw("Available commands:\n" + commands);
}
else if (cmd.Length != 0) Output.RawErr("Unknown command. Use command \"help\" to view available commands");
}
server.StopRunning();
}

View File

@ -45,6 +45,7 @@
<ItemGroup>
<Compile Include="Database.cs" />
<Compile Include="Output.cs" />
<Compile Include="OutputFormatter.cs" />
<Compile Include="Program.cs" />
<Compile Include="Properties\AssemblyInfo.cs" />
<Compile Include="Properties\Resources.Designer.cs">

View File

@ -11,6 +11,8 @@ namespace Server
private readonly long timeout;
private readonly int sidLength;
public List<Session> Sessions { get => sessions; }
public SessionManager(long timeout, int sidLength = 10)
{
this.timeout = timeout;