From 52b8e30437e235817ed534dec860e781bb0468c0 Mon Sep 17 00:00:00 2001 From: vnugent Date: Sat, 14 Jan 2023 16:24:28 -0500 Subject: MemoryUtil native integer size update + tests --- lib/Hashing.Portable/src/Argon2/VnArgon2.cs | 24 +- .../src/IdentityUtility/HashingExtensions.cs | 57 +- .../src/IdentityUtility/JsonWebKey.cs | 5 +- .../src/IdentityUtility/JsonWebToken.cs | 61 +- lib/Hashing.Portable/src/ManagedHash.cs | 23 +- lib/Hashing.Portable/src/RandomHash.cs | 143 +++- lib/Net.Http/src/ConnectionInfo.cs | 166 ---- lib/Net.Http/src/Core/ConnectionInfo.cs | 166 ++++ lib/Net.Http/src/Core/Request/HttpRequest.cs | 5 +- lib/Net.Http/src/Helpers/CoreBufferHelpers.cs | 14 +- lib/Net.Messaging.FBM/src/Client/FBMRequest.cs | 5 +- .../src/Accounts/AccountManager.cs | 872 ------------------- .../src/Accounts/AccountUtils.cs | 922 +++++++++++++++++++++ lib/Plugins.Essentials/src/Accounts/INonce.cs | 42 - .../src/Accounts/ISecretProvider.cs | 49 ++ .../src/Accounts/NonceExtensions.cs | 75 ++ .../src/Accounts/PasswordHashing.cs | 50 +- .../src/Extensions/JsonResponse.cs | 9 +- lib/Plugins.Essentials/src/HttpEntity.cs | 8 + lib/Plugins.Essentials/src/Sessions/SessionInfo.cs | 2 +- lib/Utils/src/Extensions/IoExtensions.cs | 2 +- lib/Utils/src/Extensions/MemoryExtensions.cs | 150 ++-- lib/Utils/src/Extensions/VnStringExtensions.cs | 87 +- lib/Utils/src/IO/InMemoryTemplate.cs | 2 +- lib/Utils/src/IO/VnMemoryStream.cs | 124 +-- lib/Utils/src/Memory/IMemoryHandle.cs | 8 +- lib/Utils/src/Memory/IUnmangedHeap.cs | 4 +- lib/Utils/src/Memory/Memory.cs | 456 ---------- lib/Utils/src/Memory/MemoryHandle.cs | 43 +- lib/Utils/src/Memory/MemoryUtil.cs | 603 ++++++++++++++ lib/Utils/src/Memory/PrivateBuffersMemoryPool.cs | 2 +- lib/Utils/src/Memory/PrivateHeap.cs | 184 ---- lib/Utils/src/Memory/PrivateStringManager.cs | 4 +- lib/Utils/src/Memory/ProcessHeap.cs | 28 +- lib/Utils/src/Memory/RpMallocPrivateHeap.cs | 57 +- lib/Utils/src/Memory/SubSequence.cs | 29 +- lib/Utils/src/Memory/SysBufferMemoryManager.cs | 6 +- lib/Utils/src/Memory/UnmanagedHeapBase.cs | 16 +- lib/Utils/src/Memory/UnsafeMemoryHandle.cs | 17 +- lib/Utils/src/Memory/VnString.cs | 73 +- lib/Utils/src/Memory/VnTable.cs | 68 +- lib/Utils/src/Memory/VnTempBuffer.cs | 18 +- lib/Utils/src/Memory/Win32PrivateHeap.cs | 191 +++++ lib/Utils/src/VnEncoding.cs | 10 +- lib/Utils/tests/Memory/MemoryHandleTest.cs | 25 +- lib/Utils/tests/Memory/MemoryTests.cs | 244 ------ lib/Utils/tests/Memory/MemoryUtilTests.cs | 333 ++++++++ lib/Utils/tests/Memory/VnTableTests.cs | 51 +- lib/Utils/tests/VnEncodingTests.cs | 6 +- 49 files changed, 3065 insertions(+), 2474 deletions(-) delete mode 100644 lib/Net.Http/src/ConnectionInfo.cs create mode 100644 lib/Net.Http/src/Core/ConnectionInfo.cs delete mode 100644 lib/Plugins.Essentials/src/Accounts/AccountManager.cs create mode 100644 lib/Plugins.Essentials/src/Accounts/AccountUtils.cs create mode 100644 lib/Plugins.Essentials/src/Accounts/ISecretProvider.cs create mode 100644 lib/Plugins.Essentials/src/Accounts/NonceExtensions.cs delete mode 100644 lib/Utils/src/Memory/Memory.cs create mode 100644 lib/Utils/src/Memory/MemoryUtil.cs delete mode 100644 lib/Utils/src/Memory/PrivateHeap.cs create mode 100644 lib/Utils/src/Memory/Win32PrivateHeap.cs delete mode 100644 lib/Utils/tests/Memory/MemoryTests.cs create mode 100644 lib/Utils/tests/Memory/MemoryUtilTests.cs diff --git a/lib/Hashing.Portable/src/Argon2/VnArgon2.cs b/lib/Hashing.Portable/src/Argon2/VnArgon2.cs index 01cfe74..7b467ba 100644 --- a/lib/Hashing.Portable/src/Argon2/VnArgon2.cs +++ b/lib/Hashing.Portable/src/Argon2/VnArgon2.cs @@ -51,7 +51,7 @@ namespace VNLib.Hashing public const string ARGON2_DEFUALT_LIB_NAME = "Argon2"; private static readonly Encoding LocEncoding = Encoding.Unicode; - private static readonly Lazy _heap = new (Memory.InitializeNewHeapForProcess, LazyThreadSafetyMode.PublicationOnly); + private static readonly Lazy _heap = new (MemoryUtil.InitializeNewHeapForProcess, LazyThreadSafetyMode.PublicationOnly); private static readonly Lazy _nativeLibrary = new(LoadNativeLib, LazyThreadSafetyMode.PublicationOnly); @@ -137,16 +137,22 @@ namespace VNLib.Hashing { //Get bytes count int saltbytes = LocEncoding.GetByteCount(salt); + //Get bytes count for password int passBytes = LocEncoding.GetByteCount(password); + //Alloc memory for salt using MemoryHandle buffer = PwHeap.Alloc(saltbytes + passBytes, true); + Span saltBuffer = buffer.AsSpan(0, saltbytes); Span passBuffer = buffer.AsSpan(passBytes); + //Encode salt with span the same size of the salt _ = LocEncoding.GetBytes(salt, saltBuffer); + //Encode password, create a new span to make sure its proper size _ = LocEncoding.GetBytes(password, passBuffer); + //Hash return Hash2id(passBuffer, saltBuffer, secret, timeCost, memCost, parallelism, hashLen); } @@ -170,10 +176,13 @@ namespace VNLib.Hashing { //Get bytes count int passBytes = LocEncoding.GetByteCount(password); + //Alloc memory for password using MemoryHandle pwdHandle = PwHeap.Alloc(passBytes, true); + //Encode password, create a new span to make sure its proper size _ = LocEncoding.GetBytes(password, pwdHandle); + //Hash return Hash2id(pwdHandle.Span, salt, secret, timeCost, memCost, parallelism, hashLen); } @@ -197,12 +206,16 @@ namespace VNLib.Hashing string hash, salts; //Alloc data for hash output using MemoryHandle hashHandle = PwHeap.Alloc(hashLen, true); + //hash the password Hash2id(password, salt, secret, hashHandle.Span, timeCost, memCost, parallelism); + //Encode hash hash = Convert.ToBase64String(hashHandle.Span); + //encode salt salts = Convert.ToBase64String(salt); + //Encode salt in base64 return $"${ID_MODE},v={(int)Argon2_version.VERSION_13},m={memCost},t={timeCost},p={parallelism},s={salts}${hash}"; } @@ -277,6 +290,7 @@ namespace VNLib.Hashing { //Alloc data for hash output using MemoryHandle outputHandle = PwHeap.Alloc(hashBytes.Length, true); + //Get pointers fixed (byte* secretptr = secret, pwd = rawPass, slptr = salt) { @@ -333,6 +347,7 @@ namespace VNLib.Hashing { throw new VnArgon2PasswordFormatException("The hash argument supplied is not a valid format and cannot be decoded"); } + Argon2PasswordEntry entry; try { @@ -343,12 +358,15 @@ namespace VNLib.Hashing { throw new VnArgon2PasswordFormatException("Password format was not recoverable", ex); } + //Calculate base64 buffer sizes int passBase64BufSize = Base64.GetMaxDecodedFromUtf8Length(entry.Hash.Length); int saltBase64BufSize = Base64.GetMaxDecodedFromUtf8Length(entry.Salt.Length); int rawPassLen = LocEncoding.GetByteCount(rawPass); + //Alloc buffer for decoded data - using MemoryHandle rawBufferHandle = Memory.Shared.Alloc(passBase64BufSize + saltBase64BufSize + rawPassLen, true); + using MemoryHandle rawBufferHandle = MemoryUtil.Shared.Alloc(passBase64BufSize + saltBase64BufSize + rawPassLen, true); + //Split buffers Span saltBuf = rawBufferHandle.Span[..saltBase64BufSize]; Span passBuf = rawBufferHandle.AsSpan(saltBase64BufSize, passBase64BufSize); @@ -362,6 +380,7 @@ namespace VNLib.Hashing //Resize pass buff passBuf = passBuf[..actualHashLen]; } + //Decode salt { if (!Convert.TryFromBase64Chars(entry.Salt, saltBuf, out int actualSaltLen)) @@ -371,6 +390,7 @@ namespace VNLib.Hashing //Resize salt buff saltBuf = saltBuf[..actualSaltLen]; } + //encode password bytes rawPassLen = LocEncoding.GetBytes(rawPass, rawPassBuf); //Verify password diff --git a/lib/Hashing.Portable/src/IdentityUtility/HashingExtensions.cs b/lib/Hashing.Portable/src/IdentityUtility/HashingExtensions.cs index f36b151..5ff37e8 100644 --- a/lib/Hashing.Portable/src/IdentityUtility/HashingExtensions.cs +++ b/lib/Hashing.Portable/src/IdentityUtility/HashingExtensions.cs @@ -45,27 +45,39 @@ namespace VNLib.Hashing.IdentityUtility /// The data to compute the hash of /// The used to encode the character buffer /// The base64 UTF8 string of the computed hash of the specified data - public static string ComputeBase64Hash(this HMAC hmac, ReadOnlySpan data, Encoding encoding = null) + /// + /// + public static string ComputeBase64Hash(this HMAC hmac, ReadOnlySpan data, Encoding? encoding = null) { + _ = hmac ?? throw new ArgumentNullException(nameof(hmac)); + encoding ??= Encoding.UTF8; + //Calc hashsize to alloc buffer int hashBufSize = (hmac.HashSize / 8); + //Calc buffer size int encBufSize = encoding.GetByteCount(data); + //Alloc buffer for encoding data - using UnsafeMemoryHandle buffer = Memory.UnsafeAlloc(encBufSize + hashBufSize); + using UnsafeMemoryHandle buffer = MemoryUtil.UnsafeAlloc(encBufSize + hashBufSize); + Span encBuffer = buffer.Span[0..encBufSize]; Span hashBuffer = buffer.Span[encBufSize..]; + //Encode data _ = encoding.GetBytes(data, encBuffer); + //compute hash if (!hmac.TryComputeHash(encBuffer, hashBuffer, out int hashBytesWritten)) { - throw new OutOfMemoryException("Hash buffer size was too small"); + throw new InternalBufferTooSmallException("Hash buffer size was too small"); } + //Convert to base64 string return Convert.ToBase64String(hashBuffer[..hashBytesWritten]); } + /// /// Computes the hash of the raw data and compares the computed hash against /// the specified base64hash @@ -77,36 +89,48 @@ namespace VNLib.Hashing.IdentityUtility /// A value indicating if the hash values match /// /// - public static bool VerifyBase64Hash(this HMAC hmac, ReadOnlySpan base64Hmac, ReadOnlySpan raw, Encoding encoding = null) + /// + public static bool VerifyBase64Hash(this HMAC hmac, ReadOnlySpan base64Hmac, ReadOnlySpan raw, Encoding? encoding = null) { _ = hmac ?? throw new ArgumentNullException(nameof(hmac)); + if (raw.IsEmpty) { throw new ArgumentException("Raw data buffer must not be empty", nameof(raw)); } + if (base64Hmac.IsEmpty) { throw new ArgumentException("Hmac buffer must not be empty", nameof(base64Hmac)); } + encoding ??= Encoding.UTF8; + //Calc buffer size int rawDataBufSize = encoding.GetByteCount(raw); + //Calc base64 buffer size int base64BufSize = base64Hmac.Length; + //Alloc buffer for encoding and raw data - using UnsafeMemoryHandle buffer = Memory.UnsafeAlloc(rawDataBufSize + base64BufSize, true); + using UnsafeMemoryHandle buffer = MemoryUtil.UnsafeAlloc(rawDataBufSize + base64BufSize, true); + Span rawDataBuf = buffer.Span[0..rawDataBufSize]; Span base64Buf = buffer.Span[rawDataBufSize..]; + //encode _ = encoding.GetBytes(raw, rawDataBuf); + //Convert to binary if(!Convert.TryFromBase64Chars(base64Hmac, base64Buf, out int base64Converted)) { - throw new OutOfMemoryException("Base64 buffer too small"); + throw new InternalBufferTooSmallException("Base64 buffer too small"); } + //Compare hash buffers return hmac.VerifyHash(base64Buf[0..base64Converted], rawDataBuf); } + /// /// Computes the hash of the raw data and compares the computed hash against /// the specified hash @@ -117,25 +141,33 @@ namespace VNLib.Hashing.IdentityUtility /// A value indicating if the hash values match /// /// + /// public static bool VerifyHash(this HMAC hmac, ReadOnlySpan hash, ReadOnlySpan raw) { + _ = hmac ?? throw new ArgumentNullException(nameof(hmac)); + if (raw.IsEmpty) { throw new ArgumentException("Raw data buffer must not be empty", nameof(raw)); } + if (hash.IsEmpty) { throw new ArgumentException("Hash buffer must not be empty", nameof(hash)); } + //Calc hashsize to alloc buffer int hashBufSize = hmac.HashSize / 8; + //Alloc buffer for hash - using UnsafeMemoryHandle buffer = Memory.UnsafeAlloc(hashBufSize); + using UnsafeMemoryHandle buffer = MemoryUtil.UnsafeAlloc(hashBufSize); + //compute hash if (!hmac.TryComputeHash(raw, buffer, out int hashBytesWritten)) { - throw new OutOfMemoryException("Hash buffer size was too small"); + throw new InternalBufferTooSmallException("Hash buffer size was too small"); } + //Compare hash buffers return CryptographicOperations.FixedTimeEquals(buffer.Span[0..hashBytesWritten], hash); } @@ -153,17 +185,22 @@ namespace VNLib.Hashing.IdentityUtility /// /// /// - public static ERRNO TryEncrypt(this RSA alg, ReadOnlySpan data, in Span output, RSAEncryptionPadding padding, Encoding enc = null) + public static ERRNO TryEncrypt(this RSA alg, ReadOnlySpan data, in Span output, RSAEncryptionPadding padding, Encoding? enc = null) { _ = alg ?? throw new ArgumentNullException(nameof(alg)); + //Default to UTF8 encoding enc ??= Encoding.UTF8; + //Alloc decode buffer int buffSize = enc.GetByteCount(data); + //Alloc buffer - using UnsafeMemoryHandle buffer = Memory.UnsafeAlloc(buffSize, true); + using UnsafeMemoryHandle buffer = MemoryUtil.UnsafeAlloc(buffSize, true); + //Encode data int converted = enc.GetBytes(data, buffer); + //Try encrypt return !alg.TryEncrypt(buffer.Span, output, padding, out int bytesWritten) ? ERRNO.E_FAIL : (ERRNO)bytesWritten; } diff --git a/lib/Hashing.Portable/src/IdentityUtility/JsonWebKey.cs b/lib/Hashing.Portable/src/IdentityUtility/JsonWebKey.cs index 54098c2..9076e5b 100644 --- a/lib/Hashing.Portable/src/IdentityUtility/JsonWebKey.cs +++ b/lib/Hashing.Portable/src/IdentityUtility/JsonWebKey.cs @@ -449,10 +449,13 @@ namespace VNLib.Hashing.IdentityUtility { return null; } + //bin buffer for temp decoding - using UnsafeMemoryHandle binBuffer = Memory.UnsafeAlloc(base64.Length + 16, false); + using UnsafeMemoryHandle binBuffer = MemoryUtil.UnsafeAlloc(base64.Length + 16, false); + //base64url decode ERRNO count = VnEncoding.Base64UrlDecode(base64, binBuffer.Span); + //Return buffer or null if failed return count ? binBuffer.AsSpan(0, count).ToArray() : null; } diff --git a/lib/Hashing.Portable/src/IdentityUtility/JsonWebToken.cs b/lib/Hashing.Portable/src/IdentityUtility/JsonWebToken.cs index 716dd4c..e3822d0 100644 --- a/lib/Hashing.Portable/src/IdentityUtility/JsonWebToken.cs +++ b/lib/Hashing.Portable/src/IdentityUtility/JsonWebToken.cs @@ -39,7 +39,7 @@ namespace VNLib.Hashing.IdentityUtility /// Provides a dynamic JSON Web Token class that will store and /// compute Base64Url encoded WebTokens /// - public class JsonWebToken : VnDisposeable, IStringSerializeable, IDisposable + public class JsonWebToken : VnDisposeable, IStringSerializeable { internal const byte SAEF_PERIOD = 0x2e; internal const byte PADDING_BYTES = 0x3d; @@ -53,15 +53,19 @@ namespace VNLib.Hashing.IdentityUtility /// /// /// - public static JsonWebToken Parse(ReadOnlySpan urlEncJwtString, IUnmangedHeap heap = null) + public static JsonWebToken Parse(ReadOnlySpan urlEncJwtString, IUnmangedHeap? heap = null) { - heap ??= Memory.Shared; + heap ??= MemoryUtil.Shared; + //Calculate the decoded size of the characters to alloc a buffer int utf8Size = Encoding.UTF8.GetByteCount(urlEncJwtString); + //Alloc bin buffer to store decode data using MemoryHandle binBuffer = heap.Alloc(utf8Size, true); + //Decode to utf8 utf8Size = Encoding.UTF8.GetBytes(urlEncJwtString, binBuffer); + //Parse and return the jwt return ParseRaw(binBuffer.Span[..utf8Size], heap); } @@ -75,14 +79,16 @@ namespace VNLib.Hashing.IdentityUtility /// /// /// - public static JsonWebToken ParseRaw(ReadOnlySpan utf8JWTData, IUnmangedHeap heap = null) + public static JsonWebToken ParseRaw(ReadOnlySpan utf8JWTData, IUnmangedHeap? heap = null) { if (utf8JWTData.IsEmpty) { throw new ArgumentException("JWT data may not be empty", nameof(utf8JWTData)); } + //Set default heap of non was specified - heap ??= Memory.Shared; + heap ??= MemoryUtil.Shared; + //Alloc the token and copy the supplied data to a new mem stream JsonWebToken jwt = new(heap, new (heap, utf8JWTData)); try @@ -144,17 +150,19 @@ namespace VNLib.Hashing.IdentityUtility { Heap = heap; DataStream = initialData; + + //Update position to the end of the initial data initialData.Position = initialData.Length; } /// /// Creates a new empty JWT instance, with an optional heap to alloc - /// buffers from. ( is used as default) + /// buffers from. ( is used as default) /// /// The to alloc buffers from - public JsonWebToken(IUnmangedHeap heap = null) + public JsonWebToken(IUnmangedHeap? heap = null) { - Heap = heap ?? Memory.Shared; + Heap = heap ?? MemoryUtil.Shared; DataStream = new(Heap, 100, true); } @@ -186,8 +194,10 @@ namespace VNLib.Hashing.IdentityUtility #endregion #region Payload + private int PayloadStart => HeaderEnd + 1; private int PayloadEnd; + /// /// The Base64URL encoded UTF8 bytes of the payload portion of the current JWT /// @@ -218,6 +228,7 @@ namespace VNLib.Hashing.IdentityUtility //Store final position PayloadEnd = ByteSize; } + /// /// Encodes the specified value and writes it to the /// internal buffer @@ -233,7 +244,7 @@ namespace VNLib.Hashing.IdentityUtility //Slice off the begiing of the buffer for the base64 encoding if(Base64.EncodeToUtf8(value, binBuffer.Span, out _, out int written) != OperationStatus.Done) { - throw new OutOfMemoryException(); + throw new InternalBufferTooSmallException("Failed to encode the specified value to base64"); } //Base64 encoded Span base64Data = binBuffer.Span[..written].Trim(PADDING_BYTES); @@ -245,8 +256,10 @@ namespace VNLib.Hashing.IdentityUtility #endregion #region Signature + private int SignatureStart => PayloadEnd + 1; private int SignatureEnd => ByteSize; + /// /// The Base64URL encoded UTF8 bytes of the signature portion of the current JWT /// @@ -266,23 +279,31 @@ namespace VNLib.Hashing.IdentityUtility public virtual void Sign(HashAlgorithm signatureAlgorithm) { Check(); + _ = signatureAlgorithm ?? throw new ArgumentNullException(nameof(signatureAlgorithm)); + //Calculate the size of the buffer to use for the current algorithm int bufferSize = signatureAlgorithm.HashSize / 8; + //Alloc buffer for signature output Span signatureBuffer = stackalloc byte[bufferSize]; + //Compute the hash of the current payload if(!signatureAlgorithm.TryComputeHash(DataBuffer, signatureBuffer, out int bytesWritten)) { - throw new OutOfMemoryException(); + throw new InternalBufferTooSmallException(); } + //Reset the stream position to the end of the payload DataStream.SetLength(PayloadEnd); + //Write leading period DataStream.WriteByte(SAEF_PERIOD); + //Write the signature data to the buffer WriteValue(signatureBuffer[..bytesWritten]); } + /// /// Use an RSA algorithm to sign the JWT message /// @@ -296,20 +317,27 @@ namespace VNLib.Hashing.IdentityUtility public virtual void Sign(RSA rsa, in HashAlgorithmName hashAlg, RSASignaturePadding padding, int hashSize) { Check(); + _ = rsa ?? throw new ArgumentNullException(nameof(rsa)); + //Calculate the size of the buffer to use for the current algorithm using UnsafeMemoryHandle sigBuffer = Heap.UnsafeAlloc(hashSize); + if(!rsa.TrySignData(HeaderAndPayload, sigBuffer.Span, hashAlg, padding, out int hashBytesWritten)) { - throw new OutOfMemoryException("Signature buffer is not large enough to store the hash"); + throw new InternalBufferTooSmallException("Signature buffer is not large enough to store the hash"); } + //Reset the stream position to the end of the payload DataStream.SetLength(PayloadEnd); + //Write leading period DataStream.WriteByte(SAEF_PERIOD); + //Write the signature data to the buffer WriteValue(sigBuffer.Span[..hashBytesWritten]); } + /// /// Use an RSA algorithm to sign the JWT message /// @@ -322,17 +350,23 @@ namespace VNLib.Hashing.IdentityUtility public virtual void Sign(ECDsa alg, in HashAlgorithmName hashAlg, int hashSize) { Check(); + _ = alg ?? throw new ArgumentNullException(nameof(alg)); + //Calculate the size of the buffer to use for the current algorithm using UnsafeMemoryHandle sigBuffer = Heap.UnsafeAlloc(hashSize); + if (!alg.TrySignData(HeaderAndPayload, sigBuffer.Span, hashAlg, out int hashBytesWritten)) { - throw new OutOfMemoryException("Signature buffer is not large enough to store the hash"); + throw new InternalBufferTooSmallException("Signature buffer is not large enough to store the hash"); } + //Reset the stream position to the end of the payload DataStream.SetLength(PayloadEnd); + //Write leading period DataStream.WriteByte(SAEF_PERIOD); + //Write the signature data to the buffer WriteValue(sigBuffer.Span[..hashBytesWritten]); } @@ -379,7 +413,6 @@ namespace VNLib.Hashing.IdentityUtility //Clear pointers, so buffer get operations just return empty instead of throwing Reset(); DataStream.Dispose(); - } - + } } } diff --git a/lib/Hashing.Portable/src/ManagedHash.cs b/lib/Hashing.Portable/src/ManagedHash.cs index 46a8cb8..b4e5f09 100644 --- a/lib/Hashing.Portable/src/ManagedHash.cs +++ b/lib/Hashing.Portable/src/ManagedHash.cs @@ -79,10 +79,13 @@ namespace VNLib.Hashing public static ERRNO ComputeHash(ReadOnlySpan data, Span buffer, HashAlg type) { int byteCount = CharEncoding.GetByteCount(data); + //Alloc buffer - using UnsafeMemoryHandle binbuf = Memory.UnsafeAlloc(byteCount, true); + using UnsafeMemoryHandle binbuf = MemoryUtil.UnsafeAlloc(byteCount, true); + //Encode data byteCount = CharEncoding.GetBytes(data, binbuf); + //hash the buffer return ComputeHash(binbuf.Span[..byteCount], buffer, type); } @@ -99,7 +102,7 @@ namespace VNLib.Hashing { int byteCount = CharEncoding.GetByteCount(data); //Alloc buffer - using UnsafeMemoryHandle binbuf = Memory.UnsafeAlloc(byteCount, true); + using UnsafeMemoryHandle binbuf = MemoryUtil.UnsafeAlloc(byteCount, true); //Encode data byteCount = CharEncoding.GetBytes(data, binbuf); //hash the buffer @@ -229,10 +232,13 @@ namespace VNLib.Hashing public static ERRNO ComputeHmac(ReadOnlySpan key, ReadOnlySpan data, Span output, HashAlg type) { int byteCount = CharEncoding.GetByteCount(data); + //Alloc buffer - using UnsafeMemoryHandle binbuf = Memory.UnsafeAlloc(byteCount, true); + using UnsafeMemoryHandle binbuf = MemoryUtil.UnsafeAlloc(byteCount, true); + //Encode data byteCount = CharEncoding.GetBytes(data, binbuf); + //hash the buffer return ComputeHmac(key, binbuf.Span[..byteCount], output, type); } @@ -249,10 +255,13 @@ namespace VNLib.Hashing public static byte[] ComputeHmac(ReadOnlySpan key, ReadOnlySpan data, HashAlg type) { int byteCount = CharEncoding.GetByteCount(data); + //Alloc buffer - using UnsafeMemoryHandle binbuf = Memory.UnsafeAlloc(byteCount, true); + using UnsafeMemoryHandle binbuf = MemoryUtil.UnsafeAlloc(byteCount, true); + //Encode data byteCount = CharEncoding.GetBytes(data, binbuf); + //hash the buffer return ComputeHmac(key, binbuf.Span[..byteCount], type); } @@ -317,12 +326,15 @@ namespace VNLib.Hashing { //Alloc hash buffer Span hashBuffer = stackalloc byte[(int)type]; + //hash the buffer ERRNO count = ComputeHmac(key, data, hashBuffer, type); + if (!count) { throw new InternalBufferTooSmallException("Failed to compute the hash of the data"); } + //Convert to hex string return mode switch { @@ -347,12 +359,15 @@ namespace VNLib.Hashing { //Alloc hash buffer Span hashBuffer = stackalloc byte[(int)type]; + //hash the buffer ERRNO count = ComputeHmac(key, data, hashBuffer, type); + if (!count) { throw new InternalBufferTooSmallException("Failed to compute the hash of the data"); } + //Convert to hex string return mode switch { diff --git a/lib/Hashing.Portable/src/RandomHash.cs b/lib/Hashing.Portable/src/RandomHash.cs index 5a4fc66..67518ad 100644 --- a/lib/Hashing.Portable/src/RandomHash.cs +++ b/lib/Hashing.Portable/src/RandomHash.cs @@ -23,6 +23,7 @@ */ using System; +using System.Runtime.CompilerServices; using System.Security.Cryptography; using VNLib.Utils; @@ -36,6 +37,8 @@ namespace VNLib.Hashing public static class RandomHash { + private const int MAX_STACK_ALLOC = 128; + /// /// Generates a cryptographic random number, computes the hash, and encodes the hash as a string. /// @@ -45,12 +48,28 @@ namespace VNLib.Hashing /// String containing hash of the random number public static string GetRandomHash(HashAlg alg, int size = 64, HashEncodingMode encoding = HashEncodingMode.Base64) { - //Get temporary buffer for storing random keys - using UnsafeMemoryHandle buffer = Memory.UnsafeAlloc(size); - //Fill with random non-zero bytes - GetRandomBytes(buffer.Span); - //Compute hash - return ManagedHash.ComputeHash(buffer.Span, alg, encoding); + if(size > MAX_STACK_ALLOC) + { + //Get temporary buffer for storing random keys + using UnsafeMemoryHandle buffer = MemoryUtil.UnsafeAlloc(size); + + //Fill with random non-zero bytes + GetRandomBytes(buffer.Span); + + //Compute hash + return ManagedHash.ComputeHash(buffer.Span, alg, encoding); + } + else + { + //Get temporary buffer for storing random keys + Span buffer = stackalloc byte[size]; + + //Fill with random non-zero bytes + GetRandomBytes(buffer); + + //Compute hash + return ManagedHash.ComputeHash(buffer, alg, encoding); + } } /// @@ -60,25 +79,27 @@ namespace VNLib.Hashing /// public static string GetGuidHash(HashAlg alg, HashEncodingMode encoding = HashEncodingMode.Base64) { - //Get temp buffer - Span buffer = stackalloc byte[16]; + //Get temp buffer, the size of the guid + Span buffer = stackalloc byte[Unsafe.SizeOf()]; + //Get a new GUID and write bytes to - if (!Guid.NewGuid().TryWriteBytes(buffer)) - { - throw new FormatException("Failed to get a guid hash"); - } - return ManagedHash.ComputeHash(buffer, alg, encoding); + return Guid.NewGuid().TryWriteBytes(buffer) + ? ManagedHash.ComputeHash(buffer, alg, encoding) + : throw new FormatException("Failed to get a guid hash"); } + /// /// Generates a secure random number and seeds a GUID object, then returns the string GUID /// /// Guid string public static Guid GetSecureGuid() { - //Get temp buffer - Span buffer = stackalloc byte[16]; + //Get temp buffer size of Guid + Span buffer = stackalloc byte[Unsafe.SizeOf()]; + //Generate non zero bytes GetRandomBytes(buffer); + //Get a GUID initialized with the key data and return the string represendation return new Guid(buffer); } @@ -90,13 +111,30 @@ namespace VNLib.Hashing /// Base64 string of the random number public static string GetRandomBase64(int size = 64) { - //Get temp buffer - using UnsafeMemoryHandle buffer = Memory.UnsafeAlloc(size); - //Generate non zero bytes - GetRandomBytes(buffer.Span); - //Convert to base 64 - return Convert.ToBase64String(buffer.Span, Base64FormattingOptions.None); + if (size > MAX_STACK_ALLOC) + { + //Get temp buffer + using UnsafeMemoryHandle buffer = MemoryUtil.UnsafeAlloc(size); + + //Generate non zero bytes + GetRandomBytes(buffer.Span); + + //Convert to base 64 + return Convert.ToBase64String(buffer.Span, Base64FormattingOptions.None); + } + else + { + //Get temp buffer + Span buffer = stackalloc byte[size]; + + //Generate non zero bytes + GetRandomBytes(buffer); + + //Convert to base 64 + return Convert.ToBase64String(buffer, Base64FormattingOptions.None); + } } + /// /// Generates a cryptographic random number and returns the hex string of that number /// @@ -104,13 +142,30 @@ namespace VNLib.Hashing /// Hex string of the random number public static string GetRandomHex(int size = 64) { - //Get temp buffer - using UnsafeMemoryHandle buffer = Memory.UnsafeAlloc(size); - //Generate non zero bytes - GetRandomBytes(buffer.Span); - //Convert to hex - return Convert.ToHexString(buffer.Span); + if (size > MAX_STACK_ALLOC) + { + //Get temp buffer + using UnsafeMemoryHandle buffer = MemoryUtil.UnsafeAlloc(size); + + //Generate non zero bytes + GetRandomBytes(buffer.Span); + + //Convert to hex + return Convert.ToHexString(buffer.Span); + } + else + { + //Get temp buffer + Span buffer = stackalloc byte[size]; + + //Generate non zero bytes + GetRandomBytes(buffer); + + //Convert to hex + return Convert.ToHexString(buffer); + } } + /// /// Generates a cryptographic random number and returns the Base32 encoded string of that number /// @@ -118,12 +173,28 @@ namespace VNLib.Hashing /// Base32 string of the random number public static string GetRandomBase32(int size = 64) { - //Get temporary buffer for storing random keys - using UnsafeMemoryHandle buffer = Memory.UnsafeAlloc(size); - //Fill with random non-zero bytes - GetRandomBytes(buffer.Span); - //Return string of encoded data - return VnEncoding.ToBase32String(buffer.Span); + if (size > MAX_STACK_ALLOC) + { + //Get temp buffer + using UnsafeMemoryHandle buffer = MemoryUtil.UnsafeAlloc(size); + + //Generate non zero bytes + GetRandomBytes(buffer.Span); + + //Convert to hex + return VnEncoding.ToBase32String(buffer.Span); + } + else + { + //Get temp buffer + Span buffer = stackalloc byte[size]; + + //Generate non zero bytes + GetRandomBytes(buffer); + + //Convert to hex + return VnEncoding.ToBase32String(buffer); + } } /// @@ -137,13 +208,11 @@ namespace VNLib.Hashing GetRandomBytes(rand); return rand; } + /// /// Fill the buffer with non-zero bytes /// /// Buffer to fill - public static void GetRandomBytes(Span data) - { - RandomNumberGenerator.Fill(data); - } + public static void GetRandomBytes(Span data) => RandomNumberGenerator.Fill(data); } } \ No newline at end of file diff --git a/lib/Net.Http/src/ConnectionInfo.cs b/lib/Net.Http/src/ConnectionInfo.cs deleted file mode 100644 index 6e1660d..0000000 --- a/lib/Net.Http/src/ConnectionInfo.cs +++ /dev/null @@ -1,166 +0,0 @@ -/* -* Copyright (c) 2022 Vaughn Nugent -* -* Library: VNLib -* Package: VNLib.Net.Http -* File: ConnectionInfo.cs -* -* ConnectionInfo.cs is part of VNLib.Net.Http which is part of the larger -* VNLib collection of libraries and utilities. -* -* VNLib.Net.Http is free software: you can redistribute it and/or modify -* it under the terms of the GNU Affero General Public License as -* published by the Free Software Foundation, either version 3 of the -* License, or (at your option) any later version. -* -* VNLib.Net.Http is distributed in the hope that it will be useful, -* but WITHOUT ANY WARRANTY; without even the implied warranty of -* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the -* GNU Affero General Public License for more details. -* -* You should have received a copy of the GNU Affero General Public License -* along with this program. If not, see https://www.gnu.org/licenses/. -*/ - -using System; -using System.Net; -using System.Linq; -using System.Text; -using System.Collections.Generic; -using System.Security.Authentication; - -using VNLib.Net.Http.Core; -using VNLib.Utils.Extensions; - -namespace VNLib.Net.Http -{ - /// - internal sealed class ConnectionInfo : IConnectionInfo - { - private HttpContext Context; - - /// - public Uri RequestUri => Context.Request.Location; - /// - public string Path => RequestUri.LocalPath; - /// - public string? UserAgent => Context.Request.UserAgent; - /// - public IHeaderCollection Headers { get; private set; } - /// - public bool CrossOrigin { get; } - /// - public bool IsWebSocketRequest { get; } - /// - public ContentType ContentType => Context.Request.ContentType; - /// - public HttpMethod Method => Context.Request.Method; - /// - public HttpVersion ProtocolVersion => Context.Request.HttpVersion; - /// - public bool IsSecure => Context.Request.EncryptionVersion != SslProtocols.None; - /// - public SslProtocols SecurityProtocol => Context.Request.EncryptionVersion; - /// - public Uri? Origin => Context.Request.Origin; - /// - public Uri? Referer => Context.Request.Referrer; - /// - public Tuple? Range => Context.Request.Range; - /// - public IPEndPoint LocalEndpoint => Context.Request.LocalEndPoint; - /// - public IPEndPoint RemoteEndpoint => Context.Request.RemoteEndPoint; - /// - public Encoding Encoding => Context.ParentServer.Config.HttpEncoding; - /// - public IReadOnlyDictionary RequestCookies => Context.Request.Cookies; - /// - public IEnumerable Accept => Context.Request.Accept; - /// - public TransportSecurityInfo? TransportSecurity => Context.GetSecurityInfo(); - - /// - public bool Accepts(ContentType type) - { - //Get the content type string from he specified content type - string contentType = HttpHelpers.GetContentTypeString(type); - return Accepts(contentType); - } - /// - public bool Accepts(string contentType) - { - if (AcceptsAny()) - { - return true; - } - - //If client accepts exact requested encoding - if (Accept.Contains(contentType)) - { - return true; - } - - //Search accept types to determine if the content type is acceptable - bool accepted = Accept - .Where(ctype => - { - //Get prinary side of mime type - ReadOnlySpan primary = contentType.AsSpan().SliceBeforeParam('/'); - ReadOnlySpan ctSubType = ctype.AsSpan().SliceBeforeParam('/'); - //See if accepts any subtype, or the primary sub-type matches - return ctSubType[0] == '*' || ctSubType.Equals(primary, StringComparison.OrdinalIgnoreCase); - }).Any(); - return accepted; - } - /// - /// Determines if the connection accepts any content type - /// - /// true if the connection accepts any content typ, false otherwise - private bool AcceptsAny() - { - //Accept any if no accept header was present, or accept all value */* - return Context.Request.Accept.Count == 0 || Accept.Where(static t => t.StartsWith("*/*", StringComparison.OrdinalIgnoreCase)).Any(); - } - /// - public void SetCookie(string name, string value, string? domain, string? path, TimeSpan Expires, CookieSameSite sameSite, bool httpOnly, bool secure) - { - //Create the new cookie - HttpCookie cookie = new(name) - { - Value = value, - Domain = domain, - Path = path, - MaxAge = Expires, - //Set the session lifetime flag if the timeout is max value - IsSession = Expires == TimeSpan.MaxValue, - //If the connection is cross origin, then we need to modify the secure and samsite values - SameSite = CrossOrigin ? CookieSameSite.None : sameSite, - Secure = secure | CrossOrigin, - HttpOnly = httpOnly - }; - //Set the cookie - Context.Response.AddCookie(cookie); - } - - internal ConnectionInfo(HttpContext ctx) - { - //Create new header collection - Headers = new VnHeaderCollection(ctx); - //set co value - CrossOrigin = ctx.Request.IsCrossOrigin(); - //Set websocket status - IsWebSocketRequest = ctx.Request.IsWebSocketRequest(); - //Update the context referrence - Context = ctx; - } - -#nullable disable - internal void Clear() - { - Context = null; - (Headers as VnHeaderCollection).Clear(); - Headers = null; - } - } -} \ No newline at end of file diff --git a/lib/Net.Http/src/Core/ConnectionInfo.cs b/lib/Net.Http/src/Core/ConnectionInfo.cs new file mode 100644 index 0000000..6e1660d --- /dev/null +++ b/lib/Net.Http/src/Core/ConnectionInfo.cs @@ -0,0 +1,166 @@ +/* +* Copyright (c) 2022 Vaughn Nugent +* +* Library: VNLib +* Package: VNLib.Net.Http +* File: ConnectionInfo.cs +* +* ConnectionInfo.cs is part of VNLib.Net.Http which is part of the larger +* VNLib collection of libraries and utilities. +* +* VNLib.Net.Http is free software: you can redistribute it and/or modify +* it under the terms of the GNU Affero General Public License as +* published by the Free Software Foundation, either version 3 of the +* License, or (at your option) any later version. +* +* VNLib.Net.Http is distributed in the hope that it will be useful, +* but WITHOUT ANY WARRANTY; without even the implied warranty of +* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +* GNU Affero General Public License for more details. +* +* You should have received a copy of the GNU Affero General Public License +* along with this program. If not, see https://www.gnu.org/licenses/. +*/ + +using System; +using System.Net; +using System.Linq; +using System.Text; +using System.Collections.Generic; +using System.Security.Authentication; + +using VNLib.Net.Http.Core; +using VNLib.Utils.Extensions; + +namespace VNLib.Net.Http +{ + /// + internal sealed class ConnectionInfo : IConnectionInfo + { + private HttpContext Context; + + /// + public Uri RequestUri => Context.Request.Location; + /// + public string Path => RequestUri.LocalPath; + /// + public string? UserAgent => Context.Request.UserAgent; + /// + public IHeaderCollection Headers { get; private set; } + /// + public bool CrossOrigin { get; } + /// + public bool IsWebSocketRequest { get; } + /// + public ContentType ContentType => Context.Request.ContentType; + /// + public HttpMethod Method => Context.Request.Method; + /// + public HttpVersion ProtocolVersion => Context.Request.HttpVersion; + /// + public bool IsSecure => Context.Request.EncryptionVersion != SslProtocols.None; + /// + public SslProtocols SecurityProtocol => Context.Request.EncryptionVersion; + /// + public Uri? Origin => Context.Request.Origin; + /// + public Uri? Referer => Context.Request.Referrer; + /// + public Tuple? Range => Context.Request.Range; + /// + public IPEndPoint LocalEndpoint => Context.Request.LocalEndPoint; + /// + public IPEndPoint RemoteEndpoint => Context.Request.RemoteEndPoint; + /// + public Encoding Encoding => Context.ParentServer.Config.HttpEncoding; + /// + public IReadOnlyDictionary RequestCookies => Context.Request.Cookies; + /// + public IEnumerable Accept => Context.Request.Accept; + /// + public TransportSecurityInfo? TransportSecurity => Context.GetSecurityInfo(); + + /// + public bool Accepts(ContentType type) + { + //Get the content type string from he specified content type + string contentType = HttpHelpers.GetContentTypeString(type); + return Accepts(contentType); + } + /// + public bool Accepts(string contentType) + { + if (AcceptsAny()) + { + return true; + } + + //If client accepts exact requested encoding + if (Accept.Contains(contentType)) + { + return true; + } + + //Search accept types to determine if the content type is acceptable + bool accepted = Accept + .Where(ctype => + { + //Get prinary side of mime type + ReadOnlySpan primary = contentType.AsSpan().SliceBeforeParam('/'); + ReadOnlySpan ctSubType = ctype.AsSpan().SliceBeforeParam('/'); + //See if accepts any subtype, or the primary sub-type matches + return ctSubType[0] == '*' || ctSubType.Equals(primary, StringComparison.OrdinalIgnoreCase); + }).Any(); + return accepted; + } + /// + /// Determines if the connection accepts any content type + /// + /// true if the connection accepts any content typ, false otherwise + private bool AcceptsAny() + { + //Accept any if no accept header was present, or accept all value */* + return Context.Request.Accept.Count == 0 || Accept.Where(static t => t.StartsWith("*/*", StringComparison.OrdinalIgnoreCase)).Any(); + } + /// + public void SetCookie(string name, string value, string? domain, string? path, TimeSpan Expires, CookieSameSite sameSite, bool httpOnly, bool secure) + { + //Create the new cookie + HttpCookie cookie = new(name) + { + Value = value, + Domain = domain, + Path = path, + MaxAge = Expires, + //Set the session lifetime flag if the timeout is max value + IsSession = Expires == TimeSpan.MaxValue, + //If the connection is cross origin, then we need to modify the secure and samsite values + SameSite = CrossOrigin ? CookieSameSite.None : sameSite, + Secure = secure | CrossOrigin, + HttpOnly = httpOnly + }; + //Set the cookie + Context.Response.AddCookie(cookie); + } + + internal ConnectionInfo(HttpContext ctx) + { + //Create new header collection + Headers = new VnHeaderCollection(ctx); + //set co value + CrossOrigin = ctx.Request.IsCrossOrigin(); + //Set websocket status + IsWebSocketRequest = ctx.Request.IsWebSocketRequest(); + //Update the context referrence + Context = ctx; + } + +#nullable disable + internal void Clear() + { + Context = null; + (Headers as VnHeaderCollection).Clear(); + Headers = null; + } + } +} \ No newline at end of file diff --git a/lib/Net.Http/src/Core/Request/HttpRequest.cs b/lib/Net.Http/src/Core/Request/HttpRequest.cs index 593275d..356c3f6 100644 --- a/lib/Net.Http/src/Core/Request/HttpRequest.cs +++ b/lib/Net.Http/src/Core/Request/HttpRequest.cs @@ -137,9 +137,12 @@ namespace VNLib.Net.Http.Core public string Compile() { //Alloc char buffer for compilation - using UnsafeMemoryHandle buffer = Memory.UnsafeAlloc(16 * 1024, true); + using UnsafeMemoryHandle buffer = MemoryUtil.UnsafeAlloc(16 * 1024, true); + ForwardOnlyWriter writer = new(buffer.Span); + Compile(ref writer); + return writer.ToString(); } diff --git a/lib/Net.Http/src/Helpers/CoreBufferHelpers.cs b/lib/Net.Http/src/Helpers/CoreBufferHelpers.cs index 5cc5ed9..15c617c 100644 --- a/lib/Net.Http/src/Helpers/CoreBufferHelpers.cs +++ b/lib/Net.Http/src/Helpers/CoreBufferHelpers.cs @@ -104,7 +104,7 @@ namespace VNLib.Net.Http.Core /// public static IUnmangedHeap HttpPrivateHeap => _lazyHeap.Value; - private static readonly Lazy _lazyHeap = new(Memory.InitializeNewHeapForProcess, LazyThreadSafetyMode.PublicationOnly); + private static readonly Lazy _lazyHeap = new(MemoryUtil.InitializeNewHeapForProcess, LazyThreadSafetyMode.PublicationOnly); /// /// Alloctes an unsafe block of memory from the internal heap, or buffer pool @@ -120,11 +120,11 @@ namespace VNLib.Net.Http.Core size = (size / 4096 + 1) * 4096; //If rpmalloc lib is loaded, use it - if (Memory.IsRpMallocLoaded) + if (MemoryUtil.IsRpMallocLoaded) { - return Memory.Shared.UnsafeAlloc(size, zero); + return MemoryUtil.Shared.UnsafeAlloc(size, zero); } - else if (size > Memory.MAX_UNSAFE_POOL_SIZE) + else if (size > MemoryUtil.MAX_UNSAFE_POOL_SIZE) { return HttpPrivateHeap.UnsafeAlloc(size, zero); } @@ -140,12 +140,12 @@ namespace VNLib.Net.Http.Core size = (size / 4096 + 1) * 4096; //If rpmalloc lib is loaded, use it - if (Memory.IsRpMallocLoaded) + if (MemoryUtil.IsRpMallocLoaded) { - return Memory.Shared.DirectAlloc(size, zero); + return MemoryUtil.Shared.DirectAlloc(size, zero); } //Avoid locking in heap unless the buffer is too large to alloc array - else if (size > Memory.MAX_UNSAFE_POOL_SIZE) + else if (size > MemoryUtil.MAX_UNSAFE_POOL_SIZE) { return HttpPrivateHeap.DirectAlloc(size, zero); } diff --git a/lib/Net.Messaging.FBM/src/Client/FBMRequest.cs b/lib/Net.Messaging.FBM/src/Client/FBMRequest.cs index f02724a..0e46582 100644 --- a/lib/Net.Messaging.FBM/src/Client/FBMRequest.cs +++ b/lib/Net.Messaging.FBM/src/Client/FBMRequest.cs @@ -276,8 +276,11 @@ namespace VNLib.Net.Messaging.FBM.Client public string Compile() { int charSize = Helpers.DefaultEncoding.GetCharCount(RequestData.Span); - using UnsafeMemoryHandle buffer = Memory.UnsafeAlloc(charSize + 128); + + using UnsafeMemoryHandle buffer = MemoryUtil.UnsafeAlloc(charSize + 128); + ERRNO count = Compile(buffer.Span); + return buffer.AsSpan(0, count).ToString(); } /// diff --git a/lib/Plugins.Essentials/src/Accounts/AccountManager.cs b/lib/Plugins.Essentials/src/Accounts/AccountManager.cs deleted file mode 100644 index f148fdb..0000000 --- a/lib/Plugins.Essentials/src/Accounts/AccountManager.cs +++ /dev/null @@ -1,872 +0,0 @@ -/* -* Copyright (c) 2022 Vaughn Nugent -* -* Library: VNLib -* Package: VNLib.Plugins.Essentials -* File: AccountManager.cs -* -* AccountManager.cs is part of VNLib.Plugins.Essentials which is part of the larger -* VNLib collection of libraries and utilities. -* -* VNLib.Plugins.Essentials is free software: you can redistribute it and/or modify -* it under the terms of the GNU Affero General Public License as -* published by the Free Software Foundation, either version 3 of the -* License, or (at your option) any later version. -* -* VNLib.Plugins.Essentials is distributed in the hope that it will be useful, -* but WITHOUT ANY WARRANTY; without even the implied warranty of -* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the -* GNU Affero General Public License for more details. -* -* You should have received a copy of the GNU Affero General Public License -* along with this program. If not, see https://www.gnu.org/licenses/. -*/ - -using System; -using System.IO; -using System.Text; -using System.Threading.Tasks; -using System.Security.Cryptography; -using System.Text.RegularExpressions; -using System.Runtime.CompilerServices; - -using VNLib.Hashing; -using VNLib.Net.Http; -using VNLib.Utils; -using VNLib.Utils.Memory; -using VNLib.Utils.Extensions; -using VNLib.Plugins.Essentials.Users; -using VNLib.Plugins.Essentials.Sessions; -using VNLib.Plugins.Essentials.Extensions; - - -#nullable enable - -namespace VNLib.Plugins.Essentials.Accounts -{ - - /// - /// Provides essential constants, static methods, and session/user extensions - /// to facilitate unified user-controls, athentication, and security - /// application-wide - /// - public static partial class AccountManager - { - public const int MAX_EMAIL_CHARS = 50; - public const int ID_FIELD_CHARS = 65; - public const int STREET_ADDR_CHARS = 150; - public const int MAX_LOGIN_COUNT = 10; - public const int MAX_FAILED_RESET_ATTEMPS = 5; - - /// - /// The maximum time in seconds for a login message to be considered valid - /// - public const double MAX_TIME_DIFF_SECS = 10.00; - /// - /// The size in bytes of the random passwords generated when invoking the - /// - public const int RANDOM_PASS_SIZE = 128; - /// - /// The name of the header that will identify a client's identiy - /// - public const string LOGIN_TOKEN_HEADER = "X-Web-Token"; - /// - /// The origin string of a local user account. This value will be set if an - /// account is created through the VNLib.Plugins.Essentials.Accounts library - /// - public const string LOCAL_ACCOUNT_ORIGIN = "local"; - /// - /// The size (in bytes) of the challenge secret - /// - public const int CHALLENGE_SIZE = 64; - /// - /// The size (in bytes) of the sesssion long user-password challenge - /// - public const int SESSION_CHALLENGE_SIZE = 128; - - //The buffer size to use when decoding the base64 public key from the user - private const int PUBLIC_KEY_BUFFER_SIZE = 1024; - /// - /// The name of the login cookie set when a user logs in - /// - public const string LOGIN_COOKIE_NAME = "VNLogin"; - /// - /// The name of the login client identifier cookie (cookie that is set fir client to use to determine if the user is logged in) - /// - public const string LOGIN_COOKIE_IDENTIFIER = "li"; - - private const int LOGIN_COOKIE_SIZE = 64; - - //Session entry keys - private const string BROWSER_ID_ENTRY = "acnt.bid"; - private const string CLIENT_PUB_KEY_ENTRY = "acnt.pbk"; - private const string CHALLENGE_HMAC_ENTRY = "acnt.cdig"; - private const string FAILED_LOGIN_ENTRY = "acnt.flc"; - private const string LOCAL_ACCOUNT_ENTRY = "acnt.ila"; - private const string ACC_ORIGIN_ENTRY = "__.org"; - //private const string CHALLENGE_HASH_ENTRY = "acnt.chl"; - - //Privlage masks - public const ulong READ_MSK = 0x0000000000000001L; - public const ulong DOWNLOAD_MSK = 0x0000000000000002L; - public const ulong WRITE_MSK = 0x0000000000000004L; - public const ulong DELETE_MSK = 0x0000000000000008L; - public const ulong ALLFILE_MSK = 0x000000000000000FL; - public const ulong OPTIONS_MSK = 0x000000000000FF00L; - public const ulong GROUP_MSK = 0x00000000FFFF0000L; - public const ulong LEVEL_MSK = 0x000000FF00000000L; - - public const byte OPTIONS_MSK_OFFSET = 0x08; - public const byte GROUP_MSK_OFFSET = 0x10; - public const byte LEVEL_MSK_OFFSET = 0x18; - - public const ulong MINIMUM_LEVEL = 0x0000000100000001L; - - //Timeouts - public static readonly TimeSpan LoginCookieLifespan = TimeSpan.FromHours(1); - public static readonly TimeSpan RegenIdPeriod = TimeSpan.FromMinutes(25); - - /// - /// The client data encryption padding. - /// - public static readonly RSAEncryptionPadding ClientEncryptonPadding = RSAEncryptionPadding.OaepSHA256; - - /// - /// The size (in bytes) of the web-token hash size - /// - private static readonly int TokenHashSize = (SHA384.Create().HashSize / 8); - - /// - /// Speical character regual expresion for basic checks - /// - public static readonly Regex SpecialCharacters = new(@"[\r\n\t\a\b\e\f#?!@$%^&*\+\-\~`|<>\{}]", RegexOptions.Compiled); - - #region Password/User helper extensions - - /// - /// Generates and sets a random password for the specified user account - /// - /// The configured to process the password update on - /// The user instance to update the password on - /// The instance to hash the random password with - /// Size (in bytes) of the generated random password - /// A value indicating the results of the event (number of rows affected, should evaluate to true) - /// - /// - /// - public static async Task SetRandomPasswordAsync(this PasswordHashing passHashing, IUserManager manager, IUser user, int size = RANDOM_PASS_SIZE) - { - _ = manager ?? throw new ArgumentNullException(nameof(manager)); - _ = user ?? throw new ArgumentNullException(nameof(user)); - _ = passHashing ?? throw new ArgumentNullException(nameof(passHashing)); - if (user.IsReleased) - { - throw new ObjectDisposedException("The specifed user object has been released"); - } - //Alloc a buffer - using IMemoryHandle buffer = Memory.SafeAlloc(size); - //Use the CGN to get a random set - RandomHash.GetRandomBytes(buffer.Span); - //Hash the new random password - using PrivateString passHash = passHashing.Hash(buffer.Span); - //Write the password to the user account - return await manager.UpdatePassAsync(user, passHash); - } - - - /// - /// Checks to see if the current user account was created - /// using a local account. - /// - /// - /// True if the account is a local account, false otherwise - [MethodImpl(MethodImplOptions.AggressiveInlining)] - public static bool IsLocalAccount(this IUser user) => LOCAL_ACCOUNT_ORIGIN.Equals(user.GetAccountOrigin(), StringComparison.Ordinal); - - /// - /// If this account was created by any means other than a local account creation. - /// Implementors can use this method to determine the origin of the account. - /// This field is not required - /// - /// The origin of the account - public static string GetAccountOrigin(this IUser ud) => ud[ACC_ORIGIN_ENTRY]; - /// - /// If this account was created by any means other than a local account creation. - /// Implementors can use this method to specify the origin of the account. This field is not required - /// - /// - /// Value of the account origin - public static void SetAccountOrigin(this IUser ud, string origin) => ud[ACC_ORIGIN_ENTRY] = origin; - - /// - /// Gets a random user-id generated from crypograhic random number - /// then hashed (SHA1) and returns a hexadecimal string - /// - /// The random string user-id - [MethodImpl(MethodImplOptions.AggressiveInlining)] - public static string GetRandomUserId() => RandomHash.GetRandomHash(HashAlg.SHA1, 64, HashEncodingMode.Hexadecimal); - - #endregion - - #region Client Auth Extensions - - /// - /// Runs necessary operations to grant authorization to the specified user of a given session and user with provided variables - /// - /// The connection and session to log-in - /// The message of the client to set the log-in status of - /// The user to log-in - /// The encrypted base64 token secret data to send to the client - /// - /// - public static string GenerateAuthorization(this HttpEntity ev, LoginMessage loginMessage, IUser user) - { - return GenerateAuthorization(ev, loginMessage.ClientPublicKey, loginMessage.ClientID, user); - } - - /// - /// Runs necessary operations to grant authorization to the specified user of a given session and user with provided variables - /// - /// The connection and session to log-in - /// The clients base64 public key - /// The browser/client id - /// The user to log-in - /// The encrypted base64 token secret data to send to the client - /// - /// - /// - public static string GenerateAuthorization(this HttpEntity ev, string base64PubKey, string clientId, IUser user) - { - if (!ev.Session.IsSet || ev.Session.SessionType != SessionType.Web) - { - throw new InvalidOperationException("The session is not set or the session is not a web-based session type"); - } - //derrive token from login data - TryGenerateToken(base64PubKey, out string base64ServerToken, out string base64ClientData); - //Clear flags - user.FailedLoginCount(0); - //Get the "local" account flag from the user object - bool localAccount = user.IsLocalAccount(); - //Set login cookie and session login hash - ev.SetLogin(localAccount); - //Store variables - ev.Session.UserID = user.UserID; - ev.Session.Privilages = user.Privilages; - //Store browserid/client id if specified - SetBrowserID(in ev.Session, clientId); - //Store the clients public key - SetBrowserPubKey(in ev.Session, base64PubKey); - //Set local account flag - ev.Session.HasLocalAccount(localAccount); - //Store the base64 server key to compute the hmac later - ev.Session.Token = base64ServerToken; - //Return the client encrypted data - return base64ClientData; - } - - /* - * Notes for RSA client token generator code below - * - * To log-in a client with the following API the calling code - * must have already determined that the client should be - * logged in (verified passwords or auth tokens). - * - * The client will send a LoginMessage object that will - * contain the following Information. - * - * - The clients RSA public key in base64 subject-key info format - * - The client browser's id hex string - * - The clients local-time - * - * The TryGenerateToken method, will generate a random-byte token, - * encrypt it using the clients RSA public key, return the encrypted - * token data to the client, and only the client will be able to - * decrypt the token data. - * - * The token data is also hashed with SHA-256 (for future use) and - * stored in the client's session store. The client must decrypt - * the token data, hash it, and return it as a header for verification. - * - * Ideally the client should sign the data and send the signature or - * hash back, but it wont prevent MITM, and for now I think it just - * adds extra overhead for every connection during the HttpEvent.TokenMatches() - * check extension method - */ - - private ref struct TokenGenBuffers - { - public readonly Span Buffer { private get; init; } - public readonly Span SignatureBuffer => Buffer[..64]; - - - - public int ClientPbkWritten; - public readonly Span ClientPublicKeyBuffer => Buffer.Slice(64, 1024); - public readonly ReadOnlySpan ClientPbkOutput => ClientPublicKeyBuffer[..ClientPbkWritten]; - - - - public int ClientEncBytesWritten; - public readonly Span ClientEncOutputBuffer => Buffer[(64 + 1024)..]; - public readonly ReadOnlySpan EncryptedOutput => ClientEncOutputBuffer[..ClientEncBytesWritten]; - } - - /// - /// Computes a random buffer, encrypts it with the client's public key, - /// computes the digest of that key and returns the base64 encoded strings - /// of those components - /// - /// The user's public key credential - /// The base64 encoded digest of the secret that was encrypted - /// The client's user-agent header value - /// A string representing a unique signed token for a given login context - /// - /// - private static void TryGenerateToken(string base64clientPublicKey, out string base64Digest, out string base64ClientData) - { - //Temporary work buffer - using IMemoryHandle buffer = Memory.SafeAlloc(4096, true); - /* - * Create a new token buffer for bin buffers. - * This buffer struct is used to break up - * a single block of memory into individual - * non-overlapping (important!) buffer windows - * for named purposes - */ - TokenGenBuffers tokenBuf = new() - { - Buffer = buffer.Span - }; - //Recover the clients public key from its base64 encoding - if (!Convert.TryFromBase64String(base64clientPublicKey, tokenBuf.ClientPublicKeyBuffer, out tokenBuf.ClientPbkWritten)) - { - throw new InternalBufferOverflowException("Failed to recover the clients RSA public key"); - } - /* - * Fill signature buffer with random data - * this signature will be stored and used to verify - * signed client messages. It will also be encryped - * using the clients RSA keys - */ - RandomHash.GetRandomBytes(tokenBuf.SignatureBuffer); - /* - * Setup a new RSA Crypto provider that is initialized with the clients - * supplied public key. RSA will be used to encrypt the server secret - * that only the client will be able to decrypt for the current connection - */ - using RSA rsa = RSA.Create(); - //Setup rsa from the users public key - rsa.ImportSubjectPublicKeyInfo(tokenBuf.ClientPbkOutput, out _); - //try to encypte output data - if (!rsa.TryEncrypt(tokenBuf.SignatureBuffer, tokenBuf.ClientEncOutputBuffer, RSAEncryptionPadding.OaepSHA256, out tokenBuf.ClientEncBytesWritten)) - { - throw new InternalBufferOverflowException("Failed to encrypt the server secret"); - } - //Compute the digest of the raw server key - base64Digest = ManagedHash.ComputeBase64Hash(tokenBuf.SignatureBuffer, HashAlg.SHA384); - /* - * The client will send a hash of the decrypted key and will be used - * as a comparison to the hash string above ^ - */ - base64ClientData = Convert.ToBase64String(tokenBuf.EncryptedOutput, Base64FormattingOptions.None); - } - - /// - /// Determines if the client sent a token header, and it maches against the current session - /// - /// true if the client set the token header, the session is loaded, and the token matches the session, false otherwise - public static bool TokenMatches(this HttpEntity ev) - { - //Get the token from the client header, the client should always sent this - string? clientDigest = ev.Server.Headers[LOGIN_TOKEN_HEADER]; - //Make sure a session is loaded - if (!ev.Session.IsSet || ev.Session.IsNew || string.IsNullOrWhiteSpace(clientDigest)) - { - return false; - } - /* - * Alloc buffer to do conversion and zero initial contents incase the - * payload size has been changed. - * - * The buffer just needs to be large enoguh for the size of the hashes - * that are stored in base64 format. - * - * The values in the buffers will be the raw hash of the client's key - * and the stored key sent during initial authorziation. If the hashes - * are equal it should mean that the client must have the private - * key that generated the public key that was sent - */ - using UnsafeMemoryHandle buffer = Memory.UnsafeAlloc(TokenHashSize * 2, true); - //Slice up buffers - Span headerBuffer = buffer.Span[..TokenHashSize]; - Span sessionBuffer = buffer.Span[TokenHashSize..]; - //Convert the header token and the session token - if (Convert.TryFromBase64String(clientDigest, headerBuffer, out int headerTokenLen) - && Convert.TryFromBase64String(ev.Session.Token, sessionBuffer, out int sessionTokenLen)) - { - //Do a fixed time equal (probably overkill, but should not matter too much) - return CryptographicOperations.FixedTimeEquals(headerBuffer[..headerTokenLen], sessionBuffer[..sessionTokenLen]); - } - return false; - } - - /// - /// Regenerates the user's login token with the public key stored - /// during initial logon - /// - /// The base64 of the newly encrypted secret - public static string? RegenerateClientToken(this HttpEntity ev) - { - if(!ev.Session.IsSet || ev.Session.SessionType != SessionType.Web) - { - return null; - } - //Get the client's stored public key - string clientPublicKey = ev.Session.GetBrowserPubKey(); - //Make sure its set - if (string.IsNullOrWhiteSpace(clientPublicKey)) - { - return null; - } - //Generate a new token using the stored public key - TryGenerateToken(clientPublicKey, out string base64Digest, out string base64ClientData); - //store the token to the user's session - ev.Session.Token = base64Digest; - //return the clients encrypted secret - return base64ClientData; - } - - /// - /// Tries to encrypt the specified data using the stored public key and store the encrypted data into - /// the output buffer. - /// - /// - /// Data to encrypt - /// The buffer to store encrypted data in - /// - /// The number of encrypted bytes written to the output buffer, - /// or false (0) if the operation failed, or if no credential is - /// stored. - /// - /// - public static ERRNO TryEncryptClientData(this in SessionInfo session, ReadOnlySpan data, in Span outputBuffer) - { - if (!session.IsSet) - { - return false; - } - //try to get the public key from the client - string base64PubKey = session.GetBrowserPubKey(); - return TryEncryptClientData(base64PubKey, data, in outputBuffer); - } - /// - /// Tries to encrypt the specified data using the specified public key - /// - /// A base64 encoded public key used to encrypt client data - /// Data to encrypt - /// The buffer to store encrypted data in - /// - /// The number of encrypted bytes written to the output buffer, - /// or false (0) if the operation failed, or if no credential is - /// specified. - /// - /// - public static ERRNO TryEncryptClientData(ReadOnlySpan base64PubKey, ReadOnlySpan data, in Span outputBuffer) - { - if (base64PubKey.IsEmpty) - { - return false; - } - //Alloc a buffer for decoding the public key - using UnsafeMemoryHandle pubKeyBuffer = Memory.UnsafeAlloc(PUBLIC_KEY_BUFFER_SIZE, true); - //Decode the public key - ERRNO pbkBytesWritten = VnEncoding.TryFromBase64Chars(base64PubKey, pubKeyBuffer); - //Try to encrypt the data - return pbkBytesWritten ? TryEncryptClientData(pubKeyBuffer.Span[..(int)pbkBytesWritten], data, in outputBuffer) : false; - } - /// - /// Tries to encrypt the specified data using the specified public key - /// - /// The raw SKI public key - /// Data to encrypt - /// The buffer to store encrypted data in - /// - /// The number of encrypted bytes written to the output buffer, - /// or false (0) if the operation failed, or if no credential is - /// specified. - /// - /// - public static ERRNO TryEncryptClientData(ReadOnlySpan rawPubKey, ReadOnlySpan data, in Span outputBuffer) - { - if (rawPubKey.IsEmpty) - { - return false; - } - //Setup new empty rsa - using RSA rsa = RSA.Create(); - //Import the public key - rsa.ImportSubjectPublicKeyInfo(rawPubKey, out _); - //Encrypt data with OaepSha256 as configured in the browser - return rsa.TryEncrypt(data, outputBuffer, ClientEncryptonPadding, out int bytesWritten) ? bytesWritten : false; - } - - /// - /// Stores the clients public key specified during login - /// - /// - /// - [MethodImpl(MethodImplOptions.AggressiveInlining)] - private static void SetBrowserPubKey(in SessionInfo session, string base64PubKey) => session[CLIENT_PUB_KEY_ENTRY] = base64PubKey; - - /// - /// Gets the clients stored public key that was specified during login - /// - /// The base64 encoded public key string specified at login - [MethodImpl(MethodImplOptions.AggressiveInlining)] - public static string GetBrowserPubKey(this in SessionInfo session) => session[CLIENT_PUB_KEY_ENTRY]; - - /// - /// Stores the login key as a cookie in the current session as long as the session exists - /// / - /// The event to log-in - /// Does the session belong to a local user account - [MethodImpl(MethodImplOptions.AggressiveInlining)] - private static void SetLogin(this HttpEntity ev, bool? localAccount = null) - { - //Make sure the session is loaded - if (!ev.Session.IsSet) - { - return; - } - string loginString = RandomHash.GetRandomBase64(LOGIN_COOKIE_SIZE); - //Set login cookie and session login hash - ev.Server.SetCookie(LOGIN_COOKIE_NAME, loginString, "", "/", LoginCookieLifespan, CookieSameSite.SameSite, true, true); - ev.Session.LoginHash = loginString; - //If not set get from session storage - localAccount ??= ev.Session.HasLocalAccount(); - //Set the client identifier cookie to a value indicating a local account - ev.Server.SetCookie(LOGIN_COOKIE_IDENTIFIER, localAccount.Value ? "1" : "2", "", "/", LoginCookieLifespan, CookieSameSite.SameSite, false, true); - } - - /// - /// Invalidates the login status of the current connection and session (if session is loaded) - /// - [MethodImpl(MethodImplOptions.AggressiveInlining)] - public static void InvalidateLogin(this HttpEntity ev) - { - //Expire the login cookie - ev.Server.ExpireCookie(LOGIN_COOKIE_NAME, sameSite: CookieSameSite.SameSite, secure: true); - //Expire the identifier cookie - ev.Server.ExpireCookie(LOGIN_COOKIE_IDENTIFIER, sameSite: CookieSameSite.SameSite, secure: true); - if (ev.Session.IsSet) - { - //Invalidate the session - ev.Session.Invalidate(); - } - } - - /// - /// Determines if the current session login cookie matches the value stored in the current session (if the session is loaded) - /// - /// True if the session is active, the cookie was properly received, and the cookie value matches the session. False otherwise - public static bool LoginCookieMatches(this HttpEntity ev) - { - //Sessions must be loaded - if (!ev.Session.IsSet) - { - return false; - } - //Try to get the login string from the request cookies - if (!ev.Server.RequestCookies.TryGetNonEmptyValue(LOGIN_COOKIE_NAME, out string? liCookie)) - { - return false; - } - /* - * Alloc buffer to do conversion and zero initial contents incase the - * payload size has been changed. - * - * Since the cookie size and the local copy should be the same size - * and equal to the LOGIN_COOKIE_SIZE constant, the buffer size should - * be 2 * LOGIN_COOKIE_SIZE, and it can be split in half and shared - * for both conversions - */ - using UnsafeMemoryHandle buffer = Memory.UnsafeAlloc(2 * LOGIN_COOKIE_SIZE, true); - //Slice up buffers - Span cookieBuffer = buffer.Span[..LOGIN_COOKIE_SIZE]; - Span sessionBuffer = buffer.Span.Slice(LOGIN_COOKIE_SIZE, LOGIN_COOKIE_SIZE); - //Convert cookie and session hash value - if (Convert.TryFromBase64String(liCookie, cookieBuffer, out _) - && Convert.TryFromBase64String(ev.Session.LoginHash, sessionBuffer, out _)) - { - //Do a fixed time equal (probably overkill, but should not matter too much) - if(CryptographicOperations.FixedTimeEquals(cookieBuffer, sessionBuffer)) - { - //If the user is "logged in" and the request is using the POST method, then we can update the cookie - if(ev.Server.Method == HttpMethod.POST && ev.Session.Created.Add(RegenIdPeriod) < DateTimeOffset.UtcNow) - { - //Regen login token - ev.SetLogin(); - ev.Session.RegenID(); - } - - return true; - } - } - return false; - } - - /// - /// Determines if the client's login cookies need to be updated - /// to reflect its state with the current session's state - /// for the client - /// - /// - public static void ReconcileCookies(this HttpEntity ev) - { - //Only handle cookies if session is loaded and is a web based session - if (!ev.Session.IsSet || ev.Session.SessionType != SessionType.Web) - { - return; - } - if (ev.Session.IsNew) - { - //If either login cookies are set on a new session, clear them - if (ev.Server.RequestCookies.ContainsKey(LOGIN_COOKIE_NAME) || ev.Server.RequestCookies.ContainsKey(LOGIN_COOKIE_IDENTIFIER)) - { - //Expire the login cookie - ev.Server.ExpireCookie(LOGIN_COOKIE_NAME, sameSite:CookieSameSite.SameSite, secure:true); - //Expire the identifier cookie - ev.Server.ExpireCookie(LOGIN_COOKIE_IDENTIFIER, sameSite: CookieSameSite.SameSite, secure: true); - } - } - //If the session is not supposed to be logged in, clear the login cookies if they were set - else if (string.IsNullOrEmpty(ev.Session.LoginHash)) - { - //If one of either cookie is not set - if (ev.Server.RequestCookies.ContainsKey(LOGIN_COOKIE_NAME)) - { - //Expire the login cookie - ev.Server.ExpireCookie(LOGIN_COOKIE_NAME, sameSite: CookieSameSite.SameSite, secure: true); - } - if (ev.Server.RequestCookies.ContainsKey(LOGIN_COOKIE_IDENTIFIER)) - { - //Expire the identifier cookie - ev.Server.ExpireCookie(LOGIN_COOKIE_IDENTIFIER, sameSite: CookieSameSite.SameSite, secure: true); - } - } - } - - - /// - /// Stores the browser's id during a login process - /// - /// - /// Browser id value to store - [MethodImpl(MethodImplOptions.AggressiveInlining)] - private static void SetBrowserID(in SessionInfo session, string browserId) => session[BROWSER_ID_ENTRY] = browserId; - - /// - /// Gets the current browser's id if it was specified during login process - /// - /// The browser's id if set, otherwise - [MethodImpl(MethodImplOptions.AggressiveInlining)] - public static string GetBrowserID(this in SessionInfo session) => session[BROWSER_ID_ENTRY]; - - /// - /// Specifies that the current session belongs to a local user-account - /// - /// - /// True for a local account, false otherwise - [MethodImpl(MethodImplOptions.AggressiveInlining)] - private static void HasLocalAccount(this in SessionInfo session, bool value) => session[LOCAL_ACCOUNT_ENTRY] = value ? "1" : null; - /// - /// Gets a value indicating if the session belongs to a local user account - /// - /// - /// True if the current user's account is a local account - [MethodImpl(MethodImplOptions.AggressiveInlining)] - public static bool HasLocalAccount(this in SessionInfo session) => int.TryParse(session[LOCAL_ACCOUNT_ENTRY], out int value) && value > 0; - - #endregion - - #region Client Challenge - - /* - * Generates a secret that is used to compute the unique hmac digest of the - * current user's password. The digest is stored in the current session - * and used to compare future requests that require password re-authentication. - * The client will compute the digest of the user's password and send the digest - * instead of the user's password - */ - - /// - /// Generates a new password challenge for the current session and specified password - /// - /// - /// The user's password to compute the hash of - /// The raw derrivation key to send to the client - public static byte[] GenPasswordChallenge(this in SessionInfo session, PrivateString password) - { - ReadOnlySpan rawPass = password; - //Calculate the password buffer size required - int passByteCount = Encoding.UTF8.GetByteCount(rawPass); - //Allocate the buffer - using UnsafeMemoryHandle bufferHandle = Memory.UnsafeAlloc(passByteCount + 64, true); - //Slice buffers - Span utf8PassBytes = bufferHandle.Span[..passByteCount]; - Span hashBuffer = bufferHandle.Span[passByteCount..]; - //Encode the password into the buffer - _ = Encoding.UTF8.GetBytes(rawPass, utf8PassBytes); - try - { - //Get random secret buffer - byte[] secretKey = RandomHash.GetRandomBytes(SESSION_CHALLENGE_SIZE); - //Compute the digest - int count = HMACSHA512.HashData(secretKey, utf8PassBytes, hashBuffer); - //Store the user's password digest - session[CHALLENGE_HMAC_ENTRY] = VnEncoding.ToBase32String(hashBuffer[..count], false); - return secretKey; - } - finally - { - //Wipe buffer - RandomHash.GetRandomBytes(utf8PassBytes); - } - } - /// - /// Verifies the stored unique digest of the user's password against - /// the client derrived password - /// - /// - /// The base64 client derrived digest of the user's password to verify - /// True if formatting was correct and the derrived passwords match, false otherwise - /// - public static bool VerifyChallenge(this in SessionInfo session, ReadOnlySpan base64PasswordDigest) - { - string base32Digest = session[CHALLENGE_HMAC_ENTRY]; - if (string.IsNullOrWhiteSpace(base32Digest)) - { - return false; - } - int bufSize = base32Digest.Length + base64PasswordDigest.Length; - //Alloc buffer - using UnsafeMemoryHandle buffer = Memory.UnsafeAlloc(bufSize); - //Split buffers - Span localBuf = buffer.Span[..base32Digest.Length]; - Span passBuf = buffer.Span[base32Digest.Length..]; - //Recover the stored base32 digest - ERRNO count = VnEncoding.TryFromBase32Chars(base32Digest, localBuf); - if (!count) - { - return false; - } - //Recover base64 bytes - if(!Convert.TryFromBase64Chars(base64PasswordDigest, passBuf, out int passBytesWritten)) - { - return false; - } - //Trim buffers - localBuf = localBuf[..(int)count]; - passBuf = passBuf[..passBytesWritten]; - //Compare and return - return CryptographicOperations.FixedTimeEquals(passBuf, localBuf); - } - - #endregion - - #region Privilage Extensions - /// - /// Compares the users privilage level against the specified level - /// - /// - /// 64bit privilage level to compare - /// true if the current user has at least the specified level or higher - [MethodImpl(MethodImplOptions.AggressiveInlining)] - public static bool HasLevel(this in SessionInfo session, byte level) => (session.Privilages & LEVEL_MSK) >= (((ulong)level << LEVEL_MSK_OFFSET) & LEVEL_MSK); - /// - /// Determines if the group ID of the current user matches the specified group - /// - /// - /// Group ID to compare - /// true if the user belongs to the group, false otherwise - [MethodImpl(MethodImplOptions.AggressiveInlining)] - public static bool HasGroup(this in SessionInfo session, ushort groupId) => (session.Privilages & GROUP_MSK) == (((ulong)groupId << GROUP_MSK_OFFSET) & GROUP_MSK); - /// - /// Determines if the current user has an equivalent option code - /// - /// - /// Option code check - /// true if the user options field equals the option - [MethodImpl(MethodImplOptions.AggressiveInlining)] - public static bool HasOption(this in SessionInfo session, byte option) => (session.Privilages & OPTIONS_MSK) == (((ulong)option << OPTIONS_MSK_OFFSET) & OPTIONS_MSK); - - /// - /// Returns the status of the user's privlage read bit - /// - /// true if the current user has the read permission, false otherwise - [MethodImpl(MethodImplOptions.AggressiveInlining)] - public static bool CanRead(this in SessionInfo session) => (session.Privilages & READ_MSK) == READ_MSK; - /// - /// Returns the status of the user's privlage write bit - /// - /// true if the current user has the write permission, false otherwise - [MethodImpl(MethodImplOptions.AggressiveInlining)] - public static bool CanWrite(this in SessionInfo session) => (session.Privilages & WRITE_MSK) == WRITE_MSK; - /// - /// Returns the status of the user's privlage delete bit - /// - /// true if the current user has the delete permission, false otherwise - [MethodImpl(MethodImplOptions.AggressiveInlining)] - public static bool CanDelete(this in SessionInfo session) => (session.Privilages & DELETE_MSK) == DELETE_MSK; - #endregion - - #region flc - - /// - /// Gets the current number of failed login attempts - /// - /// - /// The current number of failed login attempts - [MethodImpl(MethodImplOptions.AggressiveInlining)] - public static TimestampedCounter FailedLoginCount(this IUser user) - { - ulong value = user.GetValueType(FAILED_LOGIN_ENTRY); - return (TimestampedCounter)value; - } - /// - /// Sets the number of failed login attempts for the current session - /// - /// - /// The value to set the failed login attempt count - [MethodImpl(MethodImplOptions.AggressiveInlining)] - public static void FailedLoginCount(this IUser user, uint value) - { - TimestampedCounter counter = new(value); - //Cast the counter to a ulong and store as a ulong - user.SetValueType(FAILED_LOGIN_ENTRY, (ulong)counter); - } - /// - /// Sets the number of failed login attempts for the current session - /// - /// - /// The value to set the failed login attempt count - [MethodImpl(MethodImplOptions.AggressiveInlining)] - public static void FailedLoginCount(this IUser user, TimestampedCounter value) - { - //Cast the counter to a ulong and store as a ulong - user.SetValueType(FAILED_LOGIN_ENTRY, (ulong)value); - } - /// - /// Increments the failed login attempt count - /// - /// - [MethodImpl(MethodImplOptions.AggressiveInlining)] - public static void FailedLoginIncrement(this IUser user) - { - TimestampedCounter current = user.FailedLoginCount(); - user.FailedLoginCount(current.Count + 1); - } - - #endregion - } -} \ No newline at end of file diff --git a/lib/Plugins.Essentials/src/Accounts/AccountUtils.cs b/lib/Plugins.Essentials/src/Accounts/AccountUtils.cs new file mode 100644 index 0000000..610d646 --- /dev/null +++ b/lib/Plugins.Essentials/src/Accounts/AccountUtils.cs @@ -0,0 +1,922 @@ +/* +* Copyright (c) 2022 Vaughn Nugent +* +* Library: VNLib +* Package: VNLib.Plugins.Essentials +* File: AccountManager.cs +* +* AccountManager.cs is part of VNLib.Plugins.Essentials which is part of the larger +* VNLib collection of libraries and utilities. +* +* VNLib.Plugins.Essentials is free software: you can redistribute it and/or modify +* it under the terms of the GNU Affero General Public License as +* published by the Free Software Foundation, either version 3 of the +* License, or (at your option) any later version. +* +* VNLib.Plugins.Essentials is distributed in the hope that it will be useful, +* but WITHOUT ANY WARRANTY; without even the implied warranty of +* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +* GNU Affero General Public License for more details. +* +* You should have received a copy of the GNU Affero General Public License +* along with this program. If not, see https://www.gnu.org/licenses/. +*/ + +using System; +using System.IO; +using System.Text; +using System.Threading.Tasks; +using System.Security.Cryptography; +using System.Text.RegularExpressions; +using System.Runtime.CompilerServices; + +using VNLib.Hashing; +using VNLib.Net.Http; +using VNLib.Utils; +using VNLib.Utils.Memory; +using VNLib.Utils.Extensions; +using VNLib.Plugins.Essentials.Users; +using VNLib.Plugins.Essentials.Sessions; +using VNLib.Plugins.Essentials.Extensions; + +#nullable enable + +namespace VNLib.Plugins.Essentials.Accounts +{ + + /// + /// Provides essential constants, static methods, and session/user extensions + /// to facilitate unified user-controls, athentication, and security + /// application-wide + /// + public static partial class AccountUtil + { + public const int MAX_EMAIL_CHARS = 50; + public const int ID_FIELD_CHARS = 65; + public const int STREET_ADDR_CHARS = 150; + public const int MAX_LOGIN_COUNT = 10; + public const int MAX_FAILED_RESET_ATTEMPS = 5; + + /// + /// The maximum time in seconds for a login message to be considered valid + /// + public const double MAX_TIME_DIFF_SECS = 10.00; + /// + /// The size in bytes of the random passwords generated when invoking the + /// + public const int RANDOM_PASS_SIZE = 128; + /// + /// The name of the header that will identify a client's identiy + /// + public const string LOGIN_TOKEN_HEADER = "X-Web-Token"; + /// + /// The origin string of a local user account. This value will be set if an + /// account is created through the VNLib.Plugins.Essentials.Accounts library + /// + public const string LOCAL_ACCOUNT_ORIGIN = "local"; + /// + /// The size (in bytes) of the challenge secret + /// + public const int CHALLENGE_SIZE = 64; + /// + /// The size (in bytes) of the sesssion long user-password challenge + /// + public const int SESSION_CHALLENGE_SIZE = 128; + + //The buffer size to use when decoding the base64 public key from the user + private const int PUBLIC_KEY_BUFFER_SIZE = 1024; + /// + /// The name of the login cookie set when a user logs in + /// + public const string LOGIN_COOKIE_NAME = "VNLogin"; + /// + /// The name of the login client identifier cookie (cookie that is set fir client to use to determine if the user is logged in) + /// + public const string LOGIN_COOKIE_IDENTIFIER = "li"; + + private const int LOGIN_COOKIE_SIZE = 64; + + //Session entry keys + private const string BROWSER_ID_ENTRY = "acnt.bid"; + private const string CLIENT_PUB_KEY_ENTRY = "acnt.pbk"; + private const string CHALLENGE_HMAC_ENTRY = "acnt.cdig"; + private const string FAILED_LOGIN_ENTRY = "acnt.flc"; + private const string LOCAL_ACCOUNT_ENTRY = "acnt.ila"; + private const string ACC_ORIGIN_ENTRY = "__.org"; + private const string TOKEN_UPDATE_TIME_ENTRY = "acnt.tut"; + //private const string CHALLENGE_HASH_ENTRY = "acnt.chl"; + + //Privlage masks + public const ulong READ_MSK = 0x0000000000000001L; + public const ulong DOWNLOAD_MSK = 0x0000000000000002L; + public const ulong WRITE_MSK = 0x0000000000000004L; + public const ulong DELETE_MSK = 0x0000000000000008L; + public const ulong ALLFILE_MSK = 0x000000000000000FL; + public const ulong OPTIONS_MSK = 0x000000000000FF00L; + public const ulong GROUP_MSK = 0x00000000FFFF0000L; + public const ulong LEVEL_MSK = 0x000000FF00000000L; + + public const byte OPTIONS_MSK_OFFSET = 0x08; + public const byte GROUP_MSK_OFFSET = 0x10; + public const byte LEVEL_MSK_OFFSET = 0x18; + + public const ulong MINIMUM_LEVEL = 0x0000000100000001L; + + //Timeouts + public static readonly TimeSpan LoginCookieLifespan = TimeSpan.FromHours(1); + public static readonly TimeSpan RegenIdPeriod = TimeSpan.FromMinutes(25); + + /// + /// The client data encryption padding. + /// + public static readonly RSAEncryptionPadding ClientEncryptonPadding = RSAEncryptionPadding.OaepSHA256; + + /// + /// The size (in bytes) of the web-token hash size + /// + private static readonly int TokenHashSize = (SHA384.Create().HashSize / 8); + + /// + /// Speical character regual expresion for basic checks + /// + public static readonly Regex SpecialCharacters = new(@"[\r\n\t\a\b\e\f#?!@$%^&*\+\-\~`|<>\{}]", RegexOptions.Compiled); + + #region Password/User helper extensions + + /// + /// Generates and sets a random password for the specified user account + /// + /// The configured to process the password update on + /// The user instance to update the password on + /// The instance to hash the random password with + /// Size (in bytes) of the generated random password + /// A value indicating the results of the event (number of rows affected, should evaluate to true) + /// + /// + /// + public static async Task SetRandomPasswordAsync(this PasswordHashing passHashing, IUserManager manager, IUser user, int size = RANDOM_PASS_SIZE) + { + _ = manager ?? throw new ArgumentNullException(nameof(manager)); + _ = user ?? throw new ArgumentNullException(nameof(user)); + _ = passHashing ?? throw new ArgumentNullException(nameof(passHashing)); + if (user.IsReleased) + { + throw new ObjectDisposedException("The specifed user object has been released"); + } + //Alloc a buffer + using IMemoryHandle buffer = MemoryUtil.SafeAlloc(size); + //Use the CGN to get a random set + RandomHash.GetRandomBytes(buffer.Span); + //Hash the new random password + using PrivateString passHash = passHashing.Hash(buffer.Span); + //Write the password to the user account + return await manager.UpdatePassAsync(user, passHash); + } + + + /// + /// Checks to see if the current user account was created + /// using a local account. + /// + /// + /// True if the account is a local account, false otherwise + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static bool IsLocalAccount(this IUser user) => LOCAL_ACCOUNT_ORIGIN.Equals(user.GetAccountOrigin(), StringComparison.Ordinal); + + /// + /// If this account was created by any means other than a local account creation. + /// Implementors can use this method to determine the origin of the account. + /// This field is not required + /// + /// The origin of the account + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static string GetAccountOrigin(this IUser ud) => ud[ACC_ORIGIN_ENTRY]; + /// + /// If this account was created by any means other than a local account creation. + /// Implementors can use this method to specify the origin of the account. This field is not required + /// + /// + /// Value of the account origin + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static void SetAccountOrigin(this IUser ud, string origin) => ud[ACC_ORIGIN_ENTRY] = origin; + + /// + /// Gets a random user-id generated from crypograhic random number + /// then hashed (SHA1) and returns a hexadecimal string + /// + /// The random string user-id + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static string GetRandomUserId() => RandomHash.GetRandomHash(HashAlg.SHA1, 64, HashEncodingMode.Hexadecimal); + + #endregion + + #region Client Auth Extensions + + /// + /// Runs necessary operations to grant authorization to the specified user of a given session and user with provided variables + /// + /// The connection and session to log-in + /// The message of the client to set the log-in status of + /// The user to log-in + /// The encrypted base64 token secret data to send to the client + /// + /// + public static string GenerateAuthorization(this HttpEntity ev, LoginMessage loginMessage, IUser user) + { + return GenerateAuthorization(ev, loginMessage.ClientPublicKey, loginMessage.ClientID, user); + } + + /// + /// Runs necessary operations to grant authorization to the specified user of a given session and user with provided variables + /// + /// The connection and session to log-in + /// The clients base64 public key + /// The browser/client id + /// The user to log-in + /// The encrypted base64 token secret data to send to the client + /// + /// + /// + public static string GenerateAuthorization(this HttpEntity ev, string base64PubKey, string clientId, IUser user) + { + if (!ev.Session.IsSet || ev.Session.SessionType != SessionType.Web) + { + throw new InvalidOperationException("The session is not set or the session is not a web-based session type"); + } + //Update session-id for "upgrade" + ev.Session.RegenID(); + //derrive token from login data + TryGenerateToken(base64PubKey, out string base64ServerToken, out string base64ClientData); + //Clear flags + user.FailedLoginCount(0); + //Get the "local" account flag from the user object + bool localAccount = user.IsLocalAccount(); + //Set login cookie and session login hash + ev.SetLogin(localAccount); + //Store variables + ev.Session.UserID = user.UserID; + ev.Session.Privilages = user.Privilages; + //Store browserid/client id if specified + SetBrowserID(in ev.Session, clientId); + //Store the clients public key + SetBrowserPubKey(in ev.Session, base64PubKey); + //Set local account flag + ev.Session.HasLocalAccount(localAccount); + //Store the base64 server key to compute the hmac later + ev.Session.Token = base64ServerToken; + //Update the last token upgrade time + ev.Session.LastTokenUpgrade(ev.RequestedTimeUtc); + //Return the client encrypted data + return base64ClientData; + } + + /* + * Notes for RSA client token generator code below + * + * To log-in a client with the following API the calling code + * must have already determined that the client should be + * logged in (verified passwords or auth tokens). + * + * The client will send a LoginMessage object that will + * contain the following Information. + * + * - The clients RSA public key in base64 subject-key info format + * - The client browser's id hex string + * - The clients local-time + * + * The TryGenerateToken method, will generate a random-byte token, + * encrypt it using the clients RSA public key, return the encrypted + * token data to the client, and only the client will be able to + * decrypt the token data. + * + * The token data is also hashed with SHA-256 (for future use) and + * stored in the client's session store. The client must decrypt + * the token data, hash it, and return it as a header for verification. + * + * Ideally the client should sign the data and send the signature or + * hash back, but it wont prevent MITM, and for now I think it just + * adds extra overhead for every connection during the HttpEvent.TokenMatches() + * check extension method + */ + + private ref struct TokenGenBuffers + { + public readonly Span Buffer { private get; init; } + public readonly Span SignatureBuffer => Buffer[..64]; + + + + public int ClientPbkWritten; + public readonly Span ClientPublicKeyBuffer => Buffer.Slice(64, 1024); + public readonly ReadOnlySpan ClientPbkOutput => ClientPublicKeyBuffer[..ClientPbkWritten]; + + + + public int ClientEncBytesWritten; + public readonly Span ClientEncOutputBuffer => Buffer[(64 + 1024)..]; + public readonly ReadOnlySpan EncryptedOutput => ClientEncOutputBuffer[..ClientEncBytesWritten]; + } + + /// + /// Computes a random buffer, encrypts it with the client's public key, + /// computes the digest of that key and returns the base64 encoded strings + /// of those components + /// + /// The user's public key credential + /// The base64 encoded digest of the secret that was encrypted + /// The client's user-agent header value + /// A string representing a unique signed token for a given login context + /// + /// + private static void TryGenerateToken(string base64clientPublicKey, out string base64Digest, out string base64ClientData) + { + //Temporary work buffer + using IMemoryHandle buffer = MemoryUtil.SafeAlloc(4096, true); + /* + * Create a new token buffer for bin buffers. + * This buffer struct is used to break up + * a single block of memory into individual + * non-overlapping (important!) buffer windows + * for named purposes + */ + TokenGenBuffers tokenBuf = new() + { + Buffer = buffer.Span + }; + //Recover the clients public key from its base64 encoding + if (!Convert.TryFromBase64String(base64clientPublicKey, tokenBuf.ClientPublicKeyBuffer, out tokenBuf.ClientPbkWritten)) + { + throw new InternalBufferOverflowException("Failed to recover the clients RSA public key"); + } + /* + * Fill signature buffer with random data + * this signature will be stored and used to verify + * signed client messages. It will also be encryped + * using the clients RSA keys + */ + RandomHash.GetRandomBytes(tokenBuf.SignatureBuffer); + /* + * Setup a new RSA Crypto provider that is initialized with the clients + * supplied public key. RSA will be used to encrypt the server secret + * that only the client will be able to decrypt for the current connection + */ + using RSA rsa = RSA.Create(); + //Setup rsa from the users public key + rsa.ImportSubjectPublicKeyInfo(tokenBuf.ClientPbkOutput, out _); + //try to encypte output data + if (!rsa.TryEncrypt(tokenBuf.SignatureBuffer, tokenBuf.ClientEncOutputBuffer, RSAEncryptionPadding.OaepSHA256, out tokenBuf.ClientEncBytesWritten)) + { + throw new InternalBufferOverflowException("Failed to encrypt the server secret"); + } + //Compute the digest of the raw server key + base64Digest = ManagedHash.ComputeBase64Hash(tokenBuf.SignatureBuffer, HashAlg.SHA384); + /* + * The client will send a hash of the decrypted key and will be used + * as a comparison to the hash string above ^ + */ + base64ClientData = Convert.ToBase64String(tokenBuf.EncryptedOutput, Base64FormattingOptions.None); + } + + /// + /// Determines if the client sent a token header, and it maches against the current session + /// + /// true if the client set the token header, the session is loaded, and the token matches the session, false otherwise + public static bool TokenMatches(this HttpEntity ev) + { + //Get the token from the client header, the client should always sent this + string? clientDigest = ev.Server.Headers[LOGIN_TOKEN_HEADER]; + //Make sure a session is loaded + if (!ev.Session.IsSet || ev.Session.IsNew || string.IsNullOrWhiteSpace(clientDigest)) + { + return false; + } + /* + * Alloc buffer to do conversion and zero initial contents incase the + * payload size has been changed. + * + * The buffer just needs to be large enoguh for the size of the hashes + * that are stored in base64 format. + * + * The values in the buffers will be the raw hash of the client's key + * and the stored key sent during initial authorziation. If the hashes + * are equal it should mean that the client must have the private + * key that generated the public key that was sent + */ + using UnsafeMemoryHandle buffer = MemoryUtil.UnsafeAlloc(TokenHashSize * 2, true); + //Slice up buffers + Span headerBuffer = buffer.Span[..TokenHashSize]; + Span sessionBuffer = buffer.Span[TokenHashSize..]; + //Convert the header token and the session token + if (Convert.TryFromBase64String(clientDigest, headerBuffer, out int headerTokenLen) + && Convert.TryFromBase64String(ev.Session.Token, sessionBuffer, out int sessionTokenLen)) + { + //Do a fixed time equal (probably overkill, but should not matter too much) + if(CryptographicOperations.FixedTimeEquals(headerBuffer[..headerTokenLen], sessionBuffer[..sessionTokenLen])) + { + return true; + } + } + + /* + * If the token does not match, or cannot be found, check if the client + * has login cookies set, if not remove them. + * + * This does not affect the session, but allows for a web client to update + * its login state if its no-longer logged in + */ + + //Expire login cookie if set + if (ev.Server.RequestCookies.ContainsKey(LOGIN_COOKIE_NAME)) + { + ev.Server.ExpireCookie(LOGIN_COOKIE_NAME, sameSite: CookieSameSite.SameSite); + } + //Expire the LI cookie if set + if (ev.Server.RequestCookies.ContainsKey(LOGIN_COOKIE_IDENTIFIER)) + { + ev.Server.ExpireCookie(LOGIN_COOKIE_IDENTIFIER, sameSite: CookieSameSite.SameSite); + } + + return false; + } + + /// + /// Regenerates the user's login token with the public key stored + /// during initial logon + /// + /// The base64 of the newly encrypted secret + public static string? RegenerateClientToken(this HttpEntity ev) + { + if(!ev.Session.IsSet || ev.Session.SessionType != SessionType.Web) + { + return null; + } + //Get the client's stored public key + string clientPublicKey = ev.Session.GetBrowserPubKey(); + //Make sure its set + if (string.IsNullOrWhiteSpace(clientPublicKey)) + { + return null; + } + //Generate a new token using the stored public key + TryGenerateToken(clientPublicKey, out string base64Digest, out string base64ClientData); + //store the token to the user's session + ev.Session.Token = base64Digest; + //Update the last token upgrade time + ev.Session.LastTokenUpgrade(ev.RequestedTimeUtc); + //return the clients encrypted secret + return base64ClientData; + } + + /// + /// Tries to encrypt the specified data using the stored public key and store the encrypted data into + /// the output buffer. + /// + /// + /// Data to encrypt + /// The buffer to store encrypted data in + /// + /// The number of encrypted bytes written to the output buffer, + /// or false (0) if the operation failed, or if no credential is + /// stored. + /// + /// + public static ERRNO TryEncryptClientData(this in SessionInfo session, ReadOnlySpan data, in Span outputBuffer) + { + if (!session.IsSet) + { + return false; + } + //try to get the public key from the client + string base64PubKey = session.GetBrowserPubKey(); + return TryEncryptClientData(base64PubKey, data, in outputBuffer); + } + /// + /// Tries to encrypt the specified data using the specified public key + /// + /// A base64 encoded public key used to encrypt client data + /// Data to encrypt + /// The buffer to store encrypted data in + /// + /// The number of encrypted bytes written to the output buffer, + /// or false (0) if the operation failed, or if no credential is + /// specified. + /// + /// + public static ERRNO TryEncryptClientData(ReadOnlySpan base64PubKey, ReadOnlySpan data, in Span outputBuffer) + { + if (base64PubKey.IsEmpty) + { + return false; + } + //Alloc a buffer for decoding the public key + using UnsafeMemoryHandle pubKeyBuffer = MemoryUtil.UnsafeAlloc(PUBLIC_KEY_BUFFER_SIZE, true); + //Decode the public key + ERRNO pbkBytesWritten = VnEncoding.TryFromBase64Chars(base64PubKey, pubKeyBuffer); + //Try to encrypt the data + return pbkBytesWritten ? TryEncryptClientData(pubKeyBuffer.Span[..(int)pbkBytesWritten], data, in outputBuffer) : false; + } + /// + /// Tries to encrypt the specified data using the specified public key + /// + /// The raw SKI public key + /// Data to encrypt + /// The buffer to store encrypted data in + /// + /// The number of encrypted bytes written to the output buffer, + /// or false (0) if the operation failed, or if no credential is + /// specified. + /// + /// + public static ERRNO TryEncryptClientData(ReadOnlySpan rawPubKey, ReadOnlySpan data, in Span outputBuffer) + { + if (rawPubKey.IsEmpty) + { + return false; + } + //Setup new empty rsa + using RSA rsa = RSA.Create(); + //Import the public key + rsa.ImportSubjectPublicKeyInfo(rawPubKey, out _); + //Encrypt data with OaepSha256 as configured in the browser + return rsa.TryEncrypt(data, outputBuffer, ClientEncryptonPadding, out int bytesWritten) ? bytesWritten : false; + } + + /// + /// Stores the clients public key specified during login + /// + /// + /// + [MethodImpl(MethodImplOptions.AggressiveInlining)] + private static void SetBrowserPubKey(in SessionInfo session, string base64PubKey) => session[CLIENT_PUB_KEY_ENTRY] = base64PubKey; + + /// + /// Gets the clients stored public key that was specified during login + /// + /// The base64 encoded public key string specified at login + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static string GetBrowserPubKey(this in SessionInfo session) => session[CLIENT_PUB_KEY_ENTRY]; + + /// + /// Stores the login key as a cookie in the current session as long as the session exists + /// / + /// The event to log-in + /// Does the session belong to a local user account + [MethodImpl(MethodImplOptions.AggressiveInlining)] + private static void SetLogin(this HttpEntity ev, bool? localAccount = null) + { + //Make sure the session is loaded + if (!ev.Session.IsSet) + { + return; + } + string loginString = RandomHash.GetRandomBase64(LOGIN_COOKIE_SIZE); + //Set login cookie and session login hash + ev.Server.SetCookie(LOGIN_COOKIE_NAME, loginString, "", "/", LoginCookieLifespan, CookieSameSite.SameSite, true, true); + ev.Session.LoginHash = loginString; + //If not set get from session storage + localAccount ??= ev.Session.HasLocalAccount(); + //Set the client identifier cookie to a value indicating a local account + ev.Server.SetCookie(LOGIN_COOKIE_IDENTIFIER, localAccount.Value ? "1" : "2", "", "/", LoginCookieLifespan, CookieSameSite.SameSite, false, true); + } + + /// + /// Invalidates the login status of the current connection and session (if session is loaded) + /// + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static void InvalidateLogin(this HttpEntity ev) + { + //Expire the login cookie + ev.Server.ExpireCookie(LOGIN_COOKIE_NAME, sameSite: CookieSameSite.SameSite, secure: true); + //Expire the identifier cookie + ev.Server.ExpireCookie(LOGIN_COOKIE_IDENTIFIER, sameSite: CookieSameSite.SameSite, secure: true); + if (ev.Session.IsSet) + { + //Invalidate the session + ev.Session.Invalidate(); + } + } + + /// + /// Determines if the current session login cookie matches the value stored in the current session (if the session is loaded) + /// + /// True if the session is active, the cookie was properly received, and the cookie value matches the session. False otherwise + public static bool LoginCookieMatches(this HttpEntity ev) + { + //Sessions must be loaded + if (!ev.Session.IsSet) + { + return false; + } + //Try to get the login string from the request cookies + if (!ev.Server.RequestCookies.TryGetNonEmptyValue(LOGIN_COOKIE_NAME, out string? liCookie)) + { + return false; + } + /* + * Alloc buffer to do conversion and zero initial contents incase the + * payload size has been changed. + * + * Since the cookie size and the local copy should be the same size + * and equal to the LOGIN_COOKIE_SIZE constant, the buffer size should + * be 2 * LOGIN_COOKIE_SIZE, and it can be split in half and shared + * for both conversions + */ + using UnsafeMemoryHandle buffer = MemoryUtil.UnsafeAlloc(2 * LOGIN_COOKIE_SIZE, true); + //Slice up buffers + Span cookieBuffer = buffer.Span[..LOGIN_COOKIE_SIZE]; + Span sessionBuffer = buffer.Span.Slice(LOGIN_COOKIE_SIZE, LOGIN_COOKIE_SIZE); + //Convert cookie and session hash value + if (Convert.TryFromBase64String(liCookie, cookieBuffer, out _) + && Convert.TryFromBase64String(ev.Session.LoginHash, sessionBuffer, out _)) + { + //Do a fixed time equal (probably overkill, but should not matter too much) + if(CryptographicOperations.FixedTimeEquals(cookieBuffer, sessionBuffer)) + { + //If the user is "logged in" and the request is using the POST method, then we can update the cookie + if(ev.Server.Method == HttpMethod.POST && ev.Session.Created.Add(RegenIdPeriod) < ev.RequestedTimeUtc) + { + //Regen login token + ev.SetLogin(); + ev.Session.RegenID(); + } + + return true; + } + } + return false; + } + + /// + /// Determines if the client's login cookies need to be updated + /// to reflect its state with the current session's state + /// for the client + /// + /// + public static void ReconcileCookies(this HttpEntity ev) + { + //Only handle cookies if session is loaded and is a web based session + if (!ev.Session.IsSet || ev.Session.SessionType != SessionType.Web) + { + return; + } + if (ev.Session.IsNew) + { + //If either login cookies are set on a new session, clear them + if (ev.Server.RequestCookies.ContainsKey(LOGIN_COOKIE_NAME) || ev.Server.RequestCookies.ContainsKey(LOGIN_COOKIE_IDENTIFIER)) + { + //Expire the login cookie + ev.Server.ExpireCookie(LOGIN_COOKIE_NAME, sameSite:CookieSameSite.SameSite, secure:true); + //Expire the identifier cookie + ev.Server.ExpireCookie(LOGIN_COOKIE_IDENTIFIER, sameSite: CookieSameSite.SameSite, secure: true); + } + } + //If the session is not supposed to be logged in, clear the login cookies if they were set + else if (string.IsNullOrEmpty(ev.Session.LoginHash)) + { + //If one of either cookie is not set + if (ev.Server.RequestCookies.ContainsKey(LOGIN_COOKIE_NAME)) + { + //Expire the login cookie + ev.Server.ExpireCookie(LOGIN_COOKIE_NAME, sameSite: CookieSameSite.SameSite, secure: true); + } + if (ev.Server.RequestCookies.ContainsKey(LOGIN_COOKIE_IDENTIFIER)) + { + //Expire the identifier cookie + ev.Server.ExpireCookie(LOGIN_COOKIE_IDENTIFIER, sameSite: CookieSameSite.SameSite, secure: true); + } + } + } + + /// + /// Gets the last time the session token was set + /// + /// + /// The last time the token was updated/generated, or if not set + public static DateTimeOffset LastTokenUpgrade(this in SessionInfo session) + { + //Get the serialized time value + string timeString = session[TOKEN_UPDATE_TIME_ENTRY]; + return long.TryParse(timeString, out long time) ? DateTimeOffset.FromUnixTimeSeconds(time) : DateTimeOffset.MinValue; + } + + /// + /// Updates the last time the session token was set + /// + /// + /// The UTC time the last token was set + private static void LastTokenUpgrade(this in SessionInfo session, DateTimeOffset updated) + => session[TOKEN_UPDATE_TIME_ENTRY] = updated.ToUnixTimeSeconds().ToString(); + + /// + /// Stores the browser's id during a login process + /// + /// + /// Browser id value to store + [MethodImpl(MethodImplOptions.AggressiveInlining)] + private static void SetBrowserID(in SessionInfo session, string browserId) => session[BROWSER_ID_ENTRY] = browserId; + + /// + /// Gets the current browser's id if it was specified during login process + /// + /// The browser's id if set, otherwise + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static string GetBrowserID(this in SessionInfo session) => session[BROWSER_ID_ENTRY]; + + /// + /// Specifies that the current session belongs to a local user-account + /// + /// + /// True for a local account, false otherwise + [MethodImpl(MethodImplOptions.AggressiveInlining)] + private static void HasLocalAccount(this in SessionInfo session, bool value) => session[LOCAL_ACCOUNT_ENTRY] = value ? "1" : null; + /// + /// Gets a value indicating if the session belongs to a local user account + /// + /// + /// True if the current user's account is a local account + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static bool HasLocalAccount(this in SessionInfo session) => int.TryParse(session[LOCAL_ACCOUNT_ENTRY], out int value) && value > 0; + + #endregion + + #region Client Challenge + + /* + * Generates a secret that is used to compute the unique hmac digest of the + * current user's password. The digest is stored in the current session + * and used to compare future requests that require password re-authentication. + * The client will compute the digest of the user's password and send the digest + * instead of the user's password + */ + + /// + /// Generates a new password challenge for the current session and specified password + /// + /// + /// The user's password to compute the hash of + /// The raw derrivation key to send to the client + public static byte[] GenPasswordChallenge(this in SessionInfo session, PrivateString password) + { + ReadOnlySpan rawPass = password; + //Calculate the password buffer size required + int passByteCount = Encoding.UTF8.GetByteCount(rawPass); + //Allocate the buffer + using UnsafeMemoryHandle bufferHandle = MemoryUtil.UnsafeAlloc(passByteCount + 64, true); + //Slice buffers + Span utf8PassBytes = bufferHandle.Span[..passByteCount]; + Span hashBuffer = bufferHandle.Span[passByteCount..]; + //Encode the password into the buffer + _ = Encoding.UTF8.GetBytes(rawPass, utf8PassBytes); + try + { + //Get random secret buffer + byte[] secretKey = RandomHash.GetRandomBytes(SESSION_CHALLENGE_SIZE); + //Compute the digest + int count = HMACSHA512.HashData(secretKey, utf8PassBytes, hashBuffer); + //Store the user's password digest + session[CHALLENGE_HMAC_ENTRY] = VnEncoding.ToBase32String(hashBuffer[..count], false); + return secretKey; + } + finally + { + //Wipe buffer + RandomHash.GetRandomBytes(utf8PassBytes); + } + } + /// + /// Verifies the stored unique digest of the user's password against + /// the client derrived password + /// + /// + /// The base64 client derrived digest of the user's password to verify + /// True if formatting was correct and the derrived passwords match, false otherwise + /// + public static bool VerifyChallenge(this in SessionInfo session, ReadOnlySpan base64PasswordDigest) + { + string base32Digest = session[CHALLENGE_HMAC_ENTRY]; + if (string.IsNullOrWhiteSpace(base32Digest)) + { + return false; + } + int bufSize = base32Digest.Length + base64PasswordDigest.Length; + //Alloc buffer + using UnsafeMemoryHandle buffer = MemoryUtil.UnsafeAlloc(bufSize); + //Split buffers + Span localBuf = buffer.Span[..base32Digest.Length]; + Span passBuf = buffer.Span[base32Digest.Length..]; + //Recover the stored base32 digest + ERRNO count = VnEncoding.TryFromBase32Chars(base32Digest, localBuf); + if (!count) + { + return false; + } + //Recover base64 bytes + if(!Convert.TryFromBase64Chars(base64PasswordDigest, passBuf, out int passBytesWritten)) + { + return false; + } + //Trim buffers + localBuf = localBuf[..(int)count]; + passBuf = passBuf[..passBytesWritten]; + //Compare and return + return CryptographicOperations.FixedTimeEquals(passBuf, localBuf); + } + + #endregion + + #region Privilage Extensions + /// + /// Compares the users privilage level against the specified level + /// + /// + /// 64bit privilage level to compare + /// true if the current user has at least the specified level or higher + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static bool HasLevel(this in SessionInfo session, byte level) => (session.Privilages & LEVEL_MSK) >= (((ulong)level << LEVEL_MSK_OFFSET) & LEVEL_MSK); + /// + /// Determines if the group ID of the current user matches the specified group + /// + /// + /// Group ID to compare + /// true if the user belongs to the group, false otherwise + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static bool HasGroup(this in SessionInfo session, ushort groupId) => (session.Privilages & GROUP_MSK) == (((ulong)groupId << GROUP_MSK_OFFSET) & GROUP_MSK); + /// + /// Determines if the current user has an equivalent option code + /// + /// + /// Option code check + /// true if the user options field equals the option + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static bool HasOption(this in SessionInfo session, byte option) => (session.Privilages & OPTIONS_MSK) == (((ulong)option << OPTIONS_MSK_OFFSET) & OPTIONS_MSK); + + /// + /// Returns the status of the user's privlage read bit + /// + /// true if the current user has the read permission, false otherwise + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static bool CanRead(this in SessionInfo session) => (session.Privilages & READ_MSK) == READ_MSK; + /// + /// Returns the status of the user's privlage write bit + /// + /// true if the current user has the write permission, false otherwise + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static bool CanWrite(this in SessionInfo session) => (session.Privilages & WRITE_MSK) == WRITE_MSK; + /// + /// Returns the status of the user's privlage delete bit + /// + /// true if the current user has the delete permission, false otherwise + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static bool CanDelete(this in SessionInfo session) => (session.Privilages & DELETE_MSK) == DELETE_MSK; + #endregion + + #region flc + + /// + /// Gets the current number of failed login attempts + /// + /// + /// The current number of failed login attempts + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static TimestampedCounter FailedLoginCount(this IUser user) + { + ulong value = user.GetValueType(FAILED_LOGIN_ENTRY); + return (TimestampedCounter)value; + } + /// + /// Sets the number of failed login attempts for the current session + /// + /// + /// The value to set the failed login attempt count + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static void FailedLoginCount(this IUser user, uint value) + { + TimestampedCounter counter = new(value); + //Cast the counter to a ulong and store as a ulong + user.SetValueType(FAILED_LOGIN_ENTRY, (ulong)counter); + } + /// + /// Sets the number of failed login attempts for the current session + /// + /// + /// The value to set the failed login attempt count + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static void FailedLoginCount(this IUser user, TimestampedCounter value) + { + //Cast the counter to a ulong and store as a ulong + user.SetValueType(FAILED_LOGIN_ENTRY, (ulong)value); + } + /// + /// Increments the failed login attempt count + /// + /// + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static void FailedLoginIncrement(this IUser user) + { + TimestampedCounter current = user.FailedLoginCount(); + user.FailedLoginCount(current.Count + 1); + } + + #endregion + } +} \ No newline at end of file diff --git a/lib/Plugins.Essentials/src/Accounts/INonce.cs b/lib/Plugins.Essentials/src/Accounts/INonce.cs index 7d53183..3a1b779 100644 --- a/lib/Plugins.Essentials/src/Accounts/INonce.cs +++ b/lib/Plugins.Essentials/src/Accounts/INonce.cs @@ -24,9 +24,6 @@ using System; -using VNLib.Utils; -using VNLib.Utils.Memory; - namespace VNLib.Plugins.Essentials.Accounts { /// @@ -48,43 +45,4 @@ namespace VNLib.Plugins.Essentials.Accounts /// True if the nonce values are equal, flase otherwise bool VerifyNonce(ReadOnlySpan nonceBytes); } - - /// - /// Provides INonce extensions for computing/verifying nonce values - /// - public static class NonceExtensions - { - /// - /// Computes a base32 nonce of the specified size and returns a string - /// representation - /// - /// - /// The size (in bytes) of the nonce - /// The base32 string of the computed nonce - public static string ComputeNonce(this T nonce, int size) where T: INonce - { - //Alloc bin buffer - using UnsafeMemoryHandle buffer = Memory.UnsafeAlloc(size); - //Compute nonce - nonce.ComputeNonce(buffer.Span); - //Return base32 string - return VnEncoding.ToBase32String(buffer.Span, false); - } - /// - /// Compares the base32 encoded nonce value against the previously - /// generated nonce - /// - /// - /// The base32 encoded nonce string - /// True if the nonce values are equal, flase otherwise - public static bool VerifyNonce(this T nonce, ReadOnlySpan base32Nonce) where T : INonce - { - //Alloc bin buffer - using UnsafeMemoryHandle buffer = Memory.UnsafeAlloc(base32Nonce.Length); - //Decode base32 nonce - ERRNO count = VnEncoding.TryFromBase32Chars(base32Nonce, buffer.Span); - //Verify nonce - return nonce.VerifyNonce(buffer.Span[..(int)count]); - } - } } diff --git a/lib/Plugins.Essentials/src/Accounts/ISecretProvider.cs b/lib/Plugins.Essentials/src/Accounts/ISecretProvider.cs new file mode 100644 index 0000000..41fb44d --- /dev/null +++ b/lib/Plugins.Essentials/src/Accounts/ISecretProvider.cs @@ -0,0 +1,49 @@ +/* +* Copyright (c) 2022 Vaughn Nugent +* +* Library: VNLib +* Package: VNLib.Plugins.Essentials +* File: ISecretProvider.cs +* +* ISecretProvider.cs is part of VNLib.Plugins.Essentials which is part of the larger +* VNLib collection of libraries and utilities. +* +* VNLib.Plugins.Essentials is free software: you can redistribute it and/or modify +* it under the terms of the GNU Affero General Public License as +* published by the Free Software Foundation, either version 3 of the +* License, or (at your option) any later version. +* +* VNLib.Plugins.Essentials is distributed in the hope that it will be useful, +* but WITHOUT ANY WARRANTY; without even the implied warranty of +* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +* GNU Affero General Public License for more details. +* +* You should have received a copy of the GNU Affero General Public License +* along with this program. If not, see https://www.gnu.org/licenses/. +*/ + +using System; + +using VNLib.Utils; + +namespace VNLib.Plugins.Essentials.Accounts +{ + /// + /// Provides a password hashing secret aka pepper. + /// + public interface ISecretProvider + { + /// + /// The size of the buffer to use when retrieving the secret + /// + int BufferSize { get; } + + /// + /// Writes the secret to the buffer and returns the number of bytes + /// written to the buffer + /// + /// The buffer to write the secret data to + /// The number of secret bytes written to the buffer + ERRNO GetSecret(Span buffer); + } +} \ No newline at end of file diff --git a/lib/Plugins.Essentials/src/Accounts/NonceExtensions.cs b/lib/Plugins.Essentials/src/Accounts/NonceExtensions.cs new file mode 100644 index 0000000..5a40d29 --- /dev/null +++ b/lib/Plugins.Essentials/src/Accounts/NonceExtensions.cs @@ -0,0 +1,75 @@ +/* +* Copyright (c) 2022 Vaughn Nugent +* +* Library: VNLib +* Package: VNLib.Plugins.Essentials +* File: NonceExtensions.cs +* +* NonceExtensions.cs is part of VNLib.Plugins.Essentials which is part of the larger +* VNLib collection of libraries and utilities. +* +* VNLib.Plugins.Essentials is free software: you can redistribute it and/or modify +* it under the terms of the GNU Affero General Public License as +* published by the Free Software Foundation, either version 3 of the +* License, or (at your option) any later version. +* +* VNLib.Plugins.Essentials is distributed in the hope that it will be useful, +* but WITHOUT ANY WARRANTY; without even the implied warranty of +* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +* GNU Affero General Public License for more details. +* +* You should have received a copy of the GNU Affero General Public License +* along with this program. If not, see https://www.gnu.org/licenses/. +*/ + +using System; + +using VNLib.Utils; +using VNLib.Utils.Memory; + +namespace VNLib.Plugins.Essentials.Accounts +{ + /// + /// Provides INonce extensions for computing/verifying nonce values + /// + public static class NonceExtensions + { + /// + /// Computes a base32 nonce of the specified size and returns a string + /// representation + /// + /// + /// The size (in bytes) of the nonce + /// The base32 string of the computed nonce + public static string ComputeNonce(this T nonce, int size) where T: INonce + { + //Alloc bin buffer + using UnsafeMemoryHandle buffer = MemoryUtil.UnsafeAlloc(size); + + //Compute nonce + nonce.ComputeNonce(buffer.Span); + + //Return base32 string + return VnEncoding.ToBase32String(buffer.Span, false); + } + + /// + /// Compares the base32 encoded nonce value against the previously + /// generated nonce + /// + /// + /// The base32 encoded nonce string + /// True if the nonce values are equal, flase otherwise + public static bool VerifyNonce(this T nonce, ReadOnlySpan base32Nonce) where T : INonce + { + //Alloc bin buffer + using UnsafeMemoryHandle buffer = MemoryUtil.UnsafeAlloc(base32Nonce.Length); + + //Decode base32 nonce + ERRNO count = VnEncoding.TryFromBase32Chars(base32Nonce, buffer.Span); + + //Verify nonce + return nonce.VerifyNonce(buffer.Span[..(int)count]); + } + } +} diff --git a/lib/Plugins.Essentials/src/Accounts/PasswordHashing.cs b/lib/Plugins.Essentials/src/Accounts/PasswordHashing.cs index 1c3770b..9dc3ea1 100644 --- a/lib/Plugins.Essentials/src/Accounts/PasswordHashing.cs +++ b/lib/Plugins.Essentials/src/Accounts/PasswordHashing.cs @@ -32,21 +32,12 @@ using VNLib.Utils.Memory; namespace VNLib.Plugins.Essentials.Accounts { /// - /// A delegate method to recover a temporary copy of the secret/pepper - /// for a request - /// - /// The buffer to write the pepper to - /// The number of bytes written to the buffer - public delegate ERRNO SecretAction(Span buffer); - - /// - /// Provides a structrued password hashing system implementing the library + /// Provides a structured password hashing system implementing the library /// with fixed time comparison /// public sealed class PasswordHashing { - private readonly SecretAction _getter; - private readonly int _secretSize; + private readonly ISecretProvider _secret; private readonly uint TimeCost; private readonly uint MemoryCost; @@ -57,23 +48,20 @@ namespace VNLib.Plugins.Essentials.Accounts /// /// Initalizes the class /// - /// - /// The expected size of the secret (the size of the buffer to alloc for a copy) + /// The password secret provider /// A positive integer for the size of the random salt used during the hashing proccess /// The Argon2 time cost parameter /// The Argon2 memory cost parameter /// The size of the hash to produce during hashing operations /// /// The Argon2 parallelism parameter (the number of threads to use for hasing) - /// (default = 0 - the number of processors) + /// (default = 0 - defaults to the number of logical processors) /// /// - /// - public PasswordHashing(SecretAction getter, int secreteSize, int saltLen = 32, uint timeCost = 4, uint memoryCost = UInt16.MaxValue, uint parallism = 0, uint hashLen = 128) + public PasswordHashing(ISecretProvider secret, int saltLen = 32, uint timeCost = 4, uint memoryCost = UInt16.MaxValue, uint parallism = 0, uint hashLen = 128) { //Store getter - _getter = getter ?? throw new ArgumentNullException(nameof(getter)); - _secretSize = secreteSize; + _secret = secret ?? throw new ArgumentNullException(nameof(secret)); //Store parameters HashLen = hashLen; @@ -114,18 +102,18 @@ namespace VNLib.Plugins.Essentials.Accounts return false; } //alloc secret buffer - using UnsafeMemoryHandle secretBuffer = Memory.UnsafeAlloc(_secretSize, true); + using UnsafeMemoryHandle secretBuffer = MemoryUtil.UnsafeAlloc(_secret.BufferSize, true); try { //Get the secret from the callback - ERRNO count = _getter(secretBuffer.Span); + ERRNO count = _secret.GetSecret(secretBuffer.Span); //Verify return VnArgon2.Verify2id(password, passHash, secretBuffer.Span[..(int)count]); } finally { //Erase secret buffer - Memory.InitializeBlock(secretBuffer.Span); + MemoryUtil.InitializeBlock(secretBuffer.Span); } } /// @@ -140,7 +128,7 @@ namespace VNLib.Plugins.Essentials.Accounts public bool Verify(ReadOnlySpan hash, ReadOnlySpan salt, ReadOnlySpan password) { //Alloc a buffer with the same size as the hash - using UnsafeMemoryHandle hashBuf = Memory.UnsafeAlloc(hash.Length, true); + using UnsafeMemoryHandle hashBuf = MemoryUtil.UnsafeAlloc(hash.Length, true); //Hash the password with the current config Hash(password, salt, hashBuf.Span); //Compare the hashed password to the specified hash and return results @@ -164,7 +152,7 @@ namespace VNLib.Plugins.Essentials.Accounts public PrivateString Hash(ReadOnlySpan password) { //Alloc shared buffer for the salt and secret buffer - using UnsafeMemoryHandle buffer = Memory.UnsafeAlloc(SaltLen + _secretSize, true); + using UnsafeMemoryHandle buffer = MemoryUtil.UnsafeAlloc(SaltLen + _secret.BufferSize, true); try { //Split buffers @@ -175,14 +163,14 @@ namespace VNLib.Plugins.Essentials.Accounts RandomHash.GetRandomBytes(saltBuf); //recover the secret - ERRNO count = _getter(secretBuf); + ERRNO count = _secret.GetSecret(secretBuf); //Hashes a password, with the current parameters return (PrivateString)VnArgon2.Hash2id(password, saltBuf, secretBuf[..(int)count], TimeCost, MemoryCost, Parallelism, HashLen); } finally { - Memory.InitializeBlock(buffer.Span); + MemoryUtil.InitializeBlock(buffer.Span); } } @@ -194,7 +182,7 @@ namespace VNLib.Plugins.Essentials.Accounts /// A of the hashed and encoded password public PrivateString Hash(ReadOnlySpan password) { - using UnsafeMemoryHandle buffer = Memory.UnsafeAlloc(SaltLen + _secretSize, true); + using UnsafeMemoryHandle buffer = MemoryUtil.UnsafeAlloc(SaltLen + _secret.BufferSize, true); try { //Split buffers @@ -205,14 +193,14 @@ namespace VNLib.Plugins.Essentials.Accounts RandomHash.GetRandomBytes(saltBuf); //recover the secret - ERRNO count = _getter(secretBuf); + ERRNO count = _secret.GetSecret(secretBuf); //Hashes a password, with the current parameters return (PrivateString)VnArgon2.Hash2id(password, saltBuf, secretBuf[..(int)count], TimeCost, MemoryCost, Parallelism, HashLen); } finally { - Memory.InitializeBlock(buffer.Span); + MemoryUtil.InitializeBlock(buffer.Span); } } /// @@ -226,18 +214,18 @@ namespace VNLib.Plugins.Essentials.Accounts public void Hash(ReadOnlySpan password, ReadOnlySpan salt, Span hashOutput) { //alloc secret buffer - using UnsafeMemoryHandle secretBuffer = Memory.UnsafeAlloc(_secretSize, true); + using UnsafeMemoryHandle secretBuffer = MemoryUtil.UnsafeAlloc(_secret.BufferSize, true); try { //Get the secret from the callback - ERRNO count = _getter(secretBuffer.Span); + ERRNO count = _secret.GetSecret(secretBuffer.Span); //Hashes a password, with the current parameters VnArgon2.Hash2id(password, salt, secretBuffer.Span[..(int)count], hashOutput, TimeCost, MemoryCost, Parallelism); } finally { //Erase secret buffer - Memory.InitializeBlock(secretBuffer.Span); + MemoryUtil.InitializeBlock(secretBuffer.Span); } } } diff --git a/lib/Plugins.Essentials/src/Extensions/JsonResponse.cs b/lib/Plugins.Essentials/src/Extensions/JsonResponse.cs index 22cccd9..d087c06 100644 --- a/lib/Plugins.Essentials/src/Extensions/JsonResponse.cs +++ b/lib/Plugins.Essentials/src/Extensions/JsonResponse.cs @@ -49,12 +49,19 @@ namespace VNLib.Plugins.Essentials.Extensions internal JsonResponse(IObjectRental pool) { + /* + * I am breaking the memoryhandle rules by referrencing the same + * memory handle in two different wrappers. + */ + _pool = pool; //Alloc buffer - _handle = Memory.Shared.Alloc(4096, false); + _handle = MemoryUtil.Shared.Alloc(4096, false); + //Consume handle for stream, but make sure not to dispose the stream _asStream = VnMemoryStream.ConsumeHandle(_handle, 0, false); + //Get memory owner from handle _memoryOwner = _handle.ToMemoryManager(false); } diff --git a/lib/Plugins.Essentials/src/HttpEntity.cs b/lib/Plugins.Essentials/src/HttpEntity.cs index ffad607..416b004 100644 --- a/lib/Plugins.Essentials/src/HttpEntity.cs +++ b/lib/Plugins.Essentials/src/HttpEntity.cs @@ -77,6 +77,9 @@ namespace VNLib.Plugins.Essentials IsLocalConnection = entity.Server.LocalEndpoint.Address.IsLocalSubnet(TrustedRemoteIp); //Cache value IsSecure = entity.Server.IsSecure(IsBehindDownStreamServer); + + //Cache current time + RequestedTimeUtc = DateTimeOffset.UtcNow; } /// @@ -100,6 +103,11 @@ namespace VNLib.Plugins.Essentials /// or behind a trusted downstream server that is using tls. /// public readonly bool IsSecure; + /// + /// Caches a that was created when the connection was created. + /// The approximate current UTC time + /// + public readonly DateTimeOffset RequestedTimeUtc; /// /// The connection info object assocated with the entity diff --git a/lib/Plugins.Essentials/src/Sessions/SessionInfo.cs b/lib/Plugins.Essentials/src/Sessions/SessionInfo.cs index 13e2a84..6a974e0 100644 --- a/lib/Plugins.Essentials/src/Sessions/SessionInfo.cs +++ b/lib/Plugins.Essentials/src/Sessions/SessionInfo.cs @@ -106,7 +106,7 @@ namespace VNLib.Plugins.Essentials.Sessions /// public readonly Uri SpecifiedOrigin; /// - /// Privilages associated with user specified during login + /// The time the session was created /// public readonly DateTimeOffset Created; /// diff --git a/lib/Utils/src/Extensions/IoExtensions.cs b/lib/Utils/src/Extensions/IoExtensions.cs index baba7dc..637cfab 100644 --- a/lib/Utils/src/Extensions/IoExtensions.cs +++ b/lib/Utils/src/Extensions/IoExtensions.cs @@ -33,7 +33,7 @@ using System.Runtime.CompilerServices; using VNLib.Utils.IO; using VNLib.Utils.Memory; -using static VNLib.Utils.Memory.Memory; +using static VNLib.Utils.Memory.MemoryUtil; namespace VNLib.Utils.Extensions { diff --git a/lib/Utils/src/Extensions/MemoryExtensions.cs b/lib/Utils/src/Extensions/MemoryExtensions.cs index c8ee5ef..17ad79d 100644 --- a/lib/Utils/src/Extensions/MemoryExtensions.cs +++ b/lib/Utils/src/Extensions/MemoryExtensions.cs @@ -124,7 +124,7 @@ namespace VNLib.Utils.Extensions } /// - /// Allows direct allocation of a fixed size from a instance + /// Allows direct allocation of a fixed size from a instance /// of the specified number of elements /// /// The unmanaged data type @@ -133,13 +133,39 @@ namespace VNLib.Utils.Extensions /// Optionally zeros conents of the block when allocated /// The wrapper around the block of memory [MethodImpl(MethodImplOptions.AggressiveInlining)] - public static MemoryManager DirectAlloc(this IUnmangedHeap heap, ulong size, bool zero = false) where T : unmanaged + public static MemoryManager DirectAlloc(this IUnmangedHeap heap, nuint size, bool zero = false) where T : unmanaged { return new SysBufferMemoryManager(heap, size, zero); } /// - /// Allows direct allocation of a fixed size from a instance + /// Gets the integer length (number of elements) of the + /// + /// + /// + /// + /// The integer length of the handle, or throws if + /// the platform is 64bit and the handle is larger than + /// + /// + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static int GetIntLength(this IMemoryHandle handle) => Convert.ToInt32(handle.Length); + + /// + /// Gets the integer length (number of elements) of the + /// + /// The unmanaged type + /// + /// + /// The integer length of the handle, or throws if + /// the platform is 64bit and the handle is larger than + /// + //Method only exists for consistancy since unsafe handles are always 32bit + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static int GetIntLength(this in UnsafeMemoryHandle handle) where T: unmanaged => handle.IntLength; + + /// + /// Allows direct allocation of a fixed size from a instance /// of the specified number of elements /// /// The unmanaged data type @@ -147,11 +173,12 @@ namespace VNLib.Utils.Extensions /// The number of elements to allocate on the heap /// Optionally zeros conents of the block when allocated /// The wrapper around the block of memory + /// /// [MethodImpl(MethodImplOptions.AggressiveInlining)] - public static MemoryManager DirectAlloc(this IUnmangedHeap heap, long size, bool zero = false) where T : unmanaged + public static MemoryManager DirectAlloc(this IUnmangedHeap heap, nint size, bool zero = false) where T : unmanaged { - return size < 0 ? throw new ArgumentOutOfRangeException(nameof(size)) : DirectAlloc(heap, (ulong)size, zero); + return size >= 0 ? DirectAlloc(heap, (nuint)size, zero) : throw new ArgumentOutOfRangeException(nameof(size), "The size paramter must be a positive integer"); } /// /// Gets an offset pointer from the base postion to the number of bytes specified. Performs bounds checks @@ -161,11 +188,10 @@ namespace VNLib.Utils.Extensions /// /// /// pointer to the memory offset specified - /// [MethodImpl(MethodImplOptions.AggressiveInlining)] [MethodImpl(MethodImplOptions.AggressiveInlining)] - public static unsafe T* GetOffset(this MemoryHandle memory, long elements) where T : unmanaged + public static unsafe T* GetOffset(this MemoryHandle memory, nint elements) where T : unmanaged { - return elements < 0 ? throw new ArgumentOutOfRangeException(nameof(elements)) : memory.GetOffset((ulong)elements); + return elements >= 0 ? memory.GetOffset((nuint)elements) : throw new ArgumentOutOfRangeException(nameof(elements), "The elements paramter must be a positive integer"); } /// /// Resizes the current handle on the heap @@ -177,13 +203,13 @@ namespace VNLib.Utils.Extensions /// /// [MethodImpl(MethodImplOptions.AggressiveInlining)] - public static void Resize(this MemoryHandle memory, long elements) where T : unmanaged + public static void Resize(this MemoryHandle memory, nint elements) where T : unmanaged { if (elements < 0) { throw new ArgumentOutOfRangeException(nameof(elements)); } - memory.Resize((ulong)elements); + memory.Resize((nuint)elements); } /// @@ -197,13 +223,13 @@ namespace VNLib.Utils.Extensions /// /// [MethodImpl(MethodImplOptions.AggressiveInlining)] - public static void ResizeIfSmaller(this MemoryHandle handle, long count) where T : unmanaged + public static void ResizeIfSmaller(this MemoryHandle handle, nint count) where T : unmanaged { if(count < 0) { throw new ArgumentOutOfRangeException(nameof(count)); } - ResizeIfSmaller(handle, (ulong)count); + ResizeIfSmaller(handle, (nuint)count); } /// @@ -217,7 +243,7 @@ namespace VNLib.Utils.Extensions /// /// [MethodImpl(MethodImplOptions.AggressiveInlining)] - public static void ResizeIfSmaller(this MemoryHandle handle, ulong count) where T : unmanaged + public static void ResizeIfSmaller(this MemoryHandle handle, nuint count) where T : unmanaged { //Check handle size if(handle.Length < count) @@ -227,7 +253,7 @@ namespace VNLib.Utils.Extensions } } -#if TARGET_64_BIT + /// /// Gets a 64bit friendly span offset for the current /// @@ -238,9 +264,10 @@ namespace VNLib.Utils.Extensions /// The offset span /// [MethodImpl(MethodImplOptions.AggressiveInlining)] - public static unsafe Span GetOffsetSpan(this MemoryHandle block, ulong offset, int size) where T: unmanaged + public static unsafe Span GetOffsetSpan(this MemoryHandle block, nuint offset, int size) where T: unmanaged { _ = block ?? throw new ArgumentNullException(nameof(block)); + if(size < 0) { throw new ArgumentOutOfRangeException(nameof(size)); @@ -249,14 +276,13 @@ namespace VNLib.Utils.Extensions { return Span.Empty; } - //Make sure the offset size is within the size of the block - if(offset + (ulong)size <= block.Length) - { - //Get long offset from the destination handle - void* ofPtr = block.GetOffset(offset); - return new Span(ofPtr, size); - } - throw new ArgumentOutOfRangeException(nameof(size)); + + //Check bounds + MemoryUtil.CheckBounds(block, offset, (nuint)size); + + //Get long offset from the destination handle + void* ofPtr = block.GetOffset(offset); + return new Span(ofPtr, size); } /// /// Gets a 64bit friendly span offset for the current @@ -268,12 +294,11 @@ namespace VNLib.Utils.Extensions /// The offset span /// [MethodImpl(MethodImplOptions.AggressiveInlining)] - public static unsafe Span GetOffsetSpan(this MemoryHandle block, long offset, int size) where T : unmanaged + public static unsafe Span GetOffsetSpan(this MemoryHandle block, nint offset, int size) where T : unmanaged { - return offset < 0 ? throw new ArgumentOutOfRangeException(nameof(offset)) : block.GetOffsetSpan((ulong)offset, size); + return offset >= 0 ? block.GetOffsetSpan((nuint)offset, size) : throw new ArgumentOutOfRangeException(nameof(offset)); } - /// /// Gets a window within the current block /// @@ -283,12 +308,8 @@ namespace VNLib.Utils.Extensions /// The size of the window /// The new within the block [MethodImpl(MethodImplOptions.AggressiveInlining)] - public static SubSequence GetSubSequence(this MemoryHandle block, ulong offset, int size) where T : unmanaged - { - return new SubSequence(block, offset, size); - } -#else - + public static SubSequence GetSubSequence(this MemoryHandle block, nuint offset, int size) where T : unmanaged => new (block, offset, size); + /// /// Gets a window within the current block /// @@ -298,29 +319,11 @@ namespace VNLib.Utils.Extensions /// The size of the window /// The new within the block [MethodImpl(MethodImplOptions.AggressiveInlining)] - public static SubSequence GetSubSequence(this MemoryHandle block, int offset, int size) where T : unmanaged + public static SubSequence GetSubSequence(this MemoryHandle block, nint offset, int size) where T : unmanaged { - return new SubSequence(block, offset, size); + return offset >= 0 ? new (block, (nuint)offset, size) : throw new ArgumentOutOfRangeException(nameof(offset)); } - /// - /// Gets a 64bit friendly span offset for the current - /// - /// - /// - /// The offset (in elements) from the begining of the block - /// The size of the block (in elements) - /// The offset span - /// - /// - [MethodImpl(MethodImplOptions.AggressiveInlining)] - public static unsafe Span GetOffsetSpan(this MemoryHandle block, long offset, int size) where T : unmanaged - { - //TODO fix 32bit/64 bit, this is a safe lazy workaround - return block.Span.Slice(checked((int) offset), size); - } -#endif - /// /// Wraps the current instance with a wrapper /// to allow System.Memory buffer rentals. @@ -346,10 +349,11 @@ namespace VNLib.Utils.Extensions public static unsafe T* StructAlloc(this IUnmangedHeap heap) where T : unmanaged { //Allocate the struct on the heap and zero memory it points to - IntPtr handle = heap.Alloc(1, (uint)sizeof(T), true); + IntPtr handle = heap.Alloc(1, (nuint)sizeof(T), true); //returns the handle return (T*)handle; } + /// /// Frees a structure at the specified address from the this heap. /// This must be the same heap the structure was allocated from @@ -366,6 +370,7 @@ namespace VNLib.Utils.Extensions //Clear ref *structPtr = default; } + /// /// Allocates a block of unmanaged memory of the number of elements to store of an unmanged type /// @@ -378,17 +383,18 @@ namespace VNLib.Utils.Extensions /// /// [MethodImpl(MethodImplOptions.AggressiveInlining)] - public static unsafe MemoryHandle Alloc(this IUnmangedHeap heap, ulong elements, bool zero = false) where T : unmanaged + public static unsafe MemoryHandle Alloc(this IUnmangedHeap heap, nuint elements, bool zero = false) where T : unmanaged { //Minimum of one element elements = Math.Max(elements, 1); //Get element size - uint elementSize = (uint)sizeof(T); + nuint elementSize = (nuint)sizeof(T); //If zero flag is set then specify zeroing memory IntPtr block = heap.Alloc(elements, elementSize, zero); //Return handle wrapper return new MemoryHandle(heap, block, elements, zero); } + /// /// Allocates a block of unmanaged memory of the number of elements to store of an unmanged type /// @@ -401,10 +407,11 @@ namespace VNLib.Utils.Extensions /// /// [MethodImpl(MethodImplOptions.AggressiveInlining)] - public static MemoryHandle Alloc(this IUnmangedHeap heap, long elements, bool zero = false) where T : unmanaged + public static MemoryHandle Alloc(this IUnmangedHeap heap, nint elements, bool zero = false) where T : unmanaged { - return elements < 0 ? throw new ArgumentOutOfRangeException(nameof(elements)) : Alloc(heap, (ulong)elements, zero); + return elements >= 0 ? Alloc(heap, (nuint)elements, zero) : throw new ArgumentOutOfRangeException(nameof(elements)); } + /// /// Allocates a buffer from the current heap and initialzies it by copying the initial data buffer /// @@ -417,8 +424,12 @@ namespace VNLib.Utils.Extensions [MethodImpl(MethodImplOptions.AggressiveInlining)] public static MemoryHandle AllocAndCopy(this IUnmangedHeap heap, ReadOnlySpan initialData) where T:unmanaged { + //Aloc block MemoryHandle handle = heap.Alloc(initialData.Length); - Memory.Copy(initialData, handle, 0); + + //Copy initial data + MemoryUtil.Copy(initialData, handle, 0); + return handle; } @@ -435,12 +446,13 @@ namespace VNLib.Utils.Extensions public static void WriteAndResize(this MemoryHandle handle, ReadOnlySpan input) where T: unmanaged { handle.Resize(input.Length); - Memory.Copy(input, handle, 0); + MemoryUtil.Copy(input, handle, 0); } /// /// Allocates a block of unamanged memory of the number of elements of an unmanaged type, and - /// returns the that must be used cautiously + /// returns the that must be used cautiously. + /// If elements is less than 1 an empty handle is returned /// /// The unamanged value type /// The heap to allocate block from @@ -455,14 +467,16 @@ namespace VNLib.Utils.Extensions { if (elements < 1) { - throw new ArgumentException("Elements must be greater than 0", nameof(elements)); + //Return an empty handle + return new UnsafeMemoryHandle(); } - //Minimum of one element - elements = Math.Max(elements, 1); + //Get element size - uint elementSize = (uint)sizeof(T); - //If zero flag is set then specify zeroing memory - IntPtr block = heap.Alloc((uint)elements, elementSize, zero); + nuint elementSize = (nuint)sizeof(T); + + //If zero flag is set then specify zeroing memory (safe case because of the above check) + IntPtr block = heap.Alloc((nuint)elements, elementSize, zero); + //handle wrapper return new (heap, block, elements); } @@ -560,8 +574,6 @@ namespace VNLib.Utils.Extensions buffer.Advance(charsWritten); } - - /// /// Encodes a set of characters in the input characters span and any characters /// in the internal buffer into a sequence of bytes that are stored in the input @@ -644,11 +656,13 @@ namespace VNLib.Utils.Extensions /// Converts the buffer data to a /// /// A instance that owns the underlying string memory + [MethodImpl(MethodImplOptions.AggressiveInlining)] public static PrivateString ToPrivate(this ref ForwardOnlyWriter buffer) => new(buffer.ToString(), true); /// /// Gets a over the modified section of the internal buffer /// /// A over the modified data + [MethodImpl(MethodImplOptions.AggressiveInlining)] public static Span AsSpan(this ref ForwardOnlyWriter buffer) => buffer.Buffer[..buffer.Written]; diff --git a/lib/Utils/src/Extensions/VnStringExtensions.cs b/lib/Utils/src/Extensions/VnStringExtensions.cs index 285fc4f..329c7a6 100644 --- a/lib/Utils/src/Extensions/VnStringExtensions.cs +++ b/lib/Utils/src/Extensions/VnStringExtensions.cs @@ -25,13 +25,17 @@ using System; using System.Linq; using System.Collections.Generic; -using System.Diagnostics.CodeAnalysis; using VNLib.Utils.Memory; +using System.Runtime.CompilerServices; + +#pragma warning disable CA1062 // Validate arguments of public methods namespace VNLib.Utils.Extensions { - [SuppressMessage("Design", "CA1062:Validate arguments of public methods", Justification = "")] + /// + /// A collection of extensions for + /// public static class VnStringExtensions { /// @@ -41,7 +45,9 @@ namespace VNLib.Utils.Extensions /// The value to find /// True if the character exists within the instance /// + [MethodImpl(MethodImplOptions.AggressiveInlining)] public static bool Contains(this VnString str, char value) => str.AsSpan().Contains(value); + /// /// Derermines if the sequence exists within the instance /// @@ -50,9 +56,10 @@ namespace VNLib.Utils.Extensions /// /// True if the character exists within the instance /// - + [MethodImpl(MethodImplOptions.AggressiveInlining)] public static bool Contains(this VnString str, ReadOnlySpan value, StringComparison stringComparison) => str.AsSpan().Contains(value, stringComparison); + /// /// Searches for the first occurrance of the specified character within the current instance /// @@ -60,7 +67,9 @@ namespace VNLib.Utils.Extensions /// The character to search for within the instance /// The 0 based index of the occurance, -1 if the character was not found /// + [MethodImpl(MethodImplOptions.AggressiveInlining)] public static int IndexOf(this VnString str, char value) => str.IsEmpty ? -1 : str.AsSpan().IndexOf(value); + /// /// Searches for the first occurrance of the specified sequence within the current instance /// @@ -68,12 +77,9 @@ namespace VNLib.Utils.Extensions /// The sequence to search for /// The 0 based index of the occurance, -1 if the sequence was not found /// - public static int IndexOf(this VnString str, ReadOnlySpan search) - { - //Using spans to avoid memory leaks... - ReadOnlySpan self = str.AsSpan(); - return self.IndexOf(search); - } + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static int IndexOf(this VnString str, ReadOnlySpan search) => str.AsSpan().IndexOf(search); + /// /// Searches for the first occurrance of the specified sequence within the current instance /// @@ -82,12 +88,9 @@ namespace VNLib.Utils.Extensions /// The type to use in searchr /// The 0 based index of the occurance, -1 if the sequence was not found /// - public static int IndexOf(this VnString str, ReadOnlySpan search, StringComparison comparison) - { - //Using spans to avoid memory leaks... - ReadOnlySpan self = str.AsSpan(); - return self.IndexOf(search, comparison); - } + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static int IndexOf(this VnString str, ReadOnlySpan search, StringComparison comparison) => str.AsSpan().IndexOf(search, comparison); + /// /// Searches for the 0 based index of the first occurance of the search parameter after the start index. /// @@ -136,11 +139,13 @@ namespace VNLib.Utils.Extensions /// The trimmed instance as a child of the original entry /// /// + [MethodImpl(MethodImplOptions.AggressiveInlining)] public static VnString AbsoluteTrim(this VnString data, int start, int end) { AbsoluteTrim(data, ref start, ref end); return data[start..end]; } + /// /// Finds whitespace characters within the sequence defined between start and end parameters /// and adjusts the specified window to "trim" whitespace @@ -175,6 +180,7 @@ namespace VNLib.Utils.Extensions end--; } } + /// /// Allows for trimming whitespace characters in a realtive sequence from /// within a buffer and returning the trimmed entry. @@ -184,13 +190,16 @@ namespace VNLib.Utils.Extensions /// The trimmed instance as a child of the original entry /// /// + [MethodImpl(MethodImplOptions.AggressiveInlining)] public static VnString AbsoluteTrim(this VnString data, int start) => AbsoluteTrim(data, start, data.Length); + /// /// Trims leading or trailing whitespace characters and returns a new child instance /// without leading or trailing whitespace /// /// A child of the current instance without leading or trailing whitespaced /// + [MethodImpl(MethodImplOptions.AggressiveInlining)] public static VnString RelativeTirm(this VnString data) => AbsoluteTrim(data, 0); /// @@ -348,19 +357,6 @@ namespace VNLib.Utils.Extensions return data[start..end]; } - /// - /// Unoptimized character enumerator. You should use to enumerate the unerlying data. - /// - /// The next character in the sequence - /// - public static IEnumerator GetEnumerator(this VnString data) - { - int index = 0; - while (index < data.Length) - { - yield return data[index++]; - } - } /// /// Converts the current handle to a , a zero-alloc immutable wrapper /// for a memory handle @@ -370,14 +366,9 @@ namespace VNLib.Utils.Extensions /// The new wrapper /// /// - public static VnString ToVnString(this MemoryHandle handle, int length) - { - if(handle.Length > int.MaxValue) - { - throw new OverflowException("The handle is larger than 2GB in size"); - } - return VnString.ConsumeHandle(handle, 0, length); - } + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static VnString ToVnString(this MemoryHandle handle, int length) => VnString.ConsumeHandle(handle, 0, length); + /// /// Converts the current handle to a , a zero-alloc immutable wrapper /// for a memory handle @@ -386,10 +377,9 @@ namespace VNLib.Utils.Extensions /// The new wrapper /// /// - public static VnString ToVnString(this MemoryHandle handle) - { - return VnString.ConsumeHandle(handle, 0, handle.IntLength); - } + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static VnString ToVnString(this MemoryHandle handle) => VnString.ConsumeHandle(handle, 0, handle.GetIntLength()); + /// /// Converts the current handle to a , a zero-alloc immutable wrapper /// for a memory handle @@ -398,21 +388,8 @@ namespace VNLib.Utils.Extensions /// The offset in characters that represents the begining of the string /// The number of characters from the handle to reference (length of the string) /// The new wrapper - /// /// - public static VnString ToVnString(this MemoryHandle handle, -#if TARGET_64_BIT - ulong offset, -#else - int offset, -#endif - int length) - { - if (handle.Length > int.MaxValue) - { - throw new OverflowException("The handle is larger than 2GB in size"); - } - return VnString.ConsumeHandle(handle, offset, length); - } + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static VnString ToVnString(this MemoryHandle handle, nuint offset, int length) => VnString.ConsumeHandle(handle, offset, length); } } \ No newline at end of file diff --git a/lib/Utils/src/IO/InMemoryTemplate.cs b/lib/Utils/src/IO/InMemoryTemplate.cs index ae8bf79..5a4a799 100644 --- a/lib/Utils/src/IO/InMemoryTemplate.cs +++ b/lib/Utils/src/IO/InMemoryTemplate.cs @@ -165,7 +165,7 @@ namespace VNLib.Utils.IO try { //Copy async - await fs.CopyToAsync(newBuf, 8192, Memory.Memory.Shared, cancellationToken); + await fs.CopyToAsync(newBuf, 8192, Memory.MemoryUtil.Shared, cancellationToken); } catch { diff --git a/lib/Utils/src/IO/VnMemoryStream.cs b/lib/Utils/src/IO/VnMemoryStream.cs index 4e8a2b3..e984cc1 100644 --- a/lib/Utils/src/IO/VnMemoryStream.cs +++ b/lib/Utils/src/IO/VnMemoryStream.cs @@ -28,24 +28,23 @@ using System.Threading; using System.Threading.Tasks; using System.Runtime.InteropServices; +using VNLib.Utils.Memory; using VNLib.Utils.Extensions; namespace VNLib.Utils.IO { - - using Utils.Memory; - /// /// Provides an unmanaged memory stream. Desigend to help reduce garbage collector load for /// high frequency memory operations. Similar to /// public sealed class VnMemoryStream : Stream, ICloneable { - private long _position; - private long _length; + private nint _position; + private nint _length; + private bool _isReadonly; + //Memory private readonly MemoryHandle _buffer; - private bool IsReadonly; //Default owns handle private readonly bool OwnsHandle = true; @@ -57,7 +56,7 @@ namespace VNLib.Utils.IO /// Should the stream be readonly? /// /// A wrapper to access the handle data - public static VnMemoryStream ConsumeHandle(MemoryHandle handle, Int64 length, bool readOnly) + public static VnMemoryStream ConsumeHandle(MemoryHandle handle, nint length, bool readOnly) { handle.ThrowIfClosed(); return new VnMemoryStream(handle, length, readOnly, true); @@ -71,7 +70,7 @@ namespace VNLib.Utils.IO public static VnMemoryStream CreateReadonly(VnMemoryStream stream) { //Set the readonly flag - stream.IsReadonly = true; + stream._isReadonly = true; //Return the stream return stream; } @@ -79,11 +78,12 @@ namespace VNLib.Utils.IO /// /// Creates a new memory stream /// - public VnMemoryStream() : this(Memory.Shared) { } + public VnMemoryStream() : this(MemoryUtil.Shared) { } + /// /// Create a new memory stream where buffers will be allocated from the specified heap /// - /// to allocate memory from + /// to allocate memory from /// /// public VnMemoryStream(IUnmangedHeap heap) : this(heap, 0, false) { } @@ -92,13 +92,13 @@ namespace VNLib.Utils.IO /// Creates a new memory stream and pre-allocates the internal /// buffer of the specified size on the specified heap to avoid resizing. /// - /// to allocate memory from + /// to allocate memory from /// Number of bytes (length) of the stream if known /// Zero memory allocations during buffer expansions /// /// /// - public VnMemoryStream(IUnmangedHeap heap, long bufferSize, bool zero) + public VnMemoryStream(IUnmangedHeap heap, nuint bufferSize, bool zero) { _ = heap ?? throw new ArgumentNullException(nameof(heap)); _buffer = heap.Alloc(bufferSize, zero); @@ -107,7 +107,7 @@ namespace VNLib.Utils.IO /// /// Creates a new memory stream from the data provided /// - /// to allocate memory from + /// to allocate memory from /// Initial data public VnMemoryStream(IUnmangedHeap heap, ReadOnlySpan data) { @@ -116,8 +116,7 @@ namespace VNLib.Utils.IO _buffer = heap.AllocAndCopy(data); //Set length _length = data.Length; - //Position will default to 0 cuz its dotnet :P - return; + _position = 0; } /// @@ -127,18 +126,18 @@ namespace VNLib.Utils.IO /// The length property of the stream /// Is the stream readonly (should mostly be true!) /// Does the new stream own the memory -> - private VnMemoryStream(MemoryHandle buffer, long length, bool readOnly, bool ownsHandle) + private VnMemoryStream(MemoryHandle buffer, nint length, bool readOnly, bool ownsHandle) { OwnsHandle = ownsHandle; _buffer = buffer; //Consume the handle _length = length; //Store length of the buffer - IsReadonly = readOnly; + _isReadonly = readOnly; } /// /// UNSAFE Number of bytes between position and length. Never negative /// - private long LenToPosDiff => Math.Max(_length - _position, 0); + private nint LenToPosDiff => Math.Max(_length - _position, 0); /// /// If the current stream is a readonly stream, creates an unsafe shallow copy for reading only. @@ -148,7 +147,7 @@ namespace VNLib.Utils.IO public VnMemoryStream GetReadonlyShallowCopy() { //Create a new readonly copy (stream does not own the handle) - return !IsReadonly + return !_isReadonly ? throw new NotSupportedException("This stream is not readonly. Cannot create shallow copy on a mutable stream") : new VnMemoryStream(_buffer, _length, true, false); } @@ -163,6 +162,10 @@ namespace VNLib.Utils.IO public override void CopyTo(Stream destination, int bufferSize) { _ = destination ?? throw new ArgumentNullException(nameof(destination)); + if(bufferSize < 1) + { + throw new ArgumentOutOfRangeException(nameof(bufferSize), "Buffer size must be greater than 0"); + } if (!destination.CanWrite) { @@ -250,7 +253,7 @@ namespace VNLib.Utils.IO /// True unless the stream is (or has been converted to) a readonly /// stream. /// - public override bool CanWrite => !IsReadonly; + public override bool CanWrite => !_isReadonly; /// public override long Length => _length; /// @@ -279,7 +282,7 @@ namespace VNLib.Utils.IO /// public override Task FlushAsync(CancellationToken cancellationToken) => Task.CompletedTask; /// - public override int Read(byte[] buffer, int offset, int count) => Read(new Span(buffer, offset, count)); + public override int Read(byte[] buffer, int offset, int count) => Read(buffer.AsSpan(offset, count)); /// public override int Read(Span buffer) { @@ -288,12 +291,14 @@ namespace VNLib.Utils.IO return 0; } //Number of bytes to read from memory buffer - int bytesToRead = checked((int)Math.Min(LenToPosDiff, buffer.Length)); + int bytesToRead = (int)Math.Min(LenToPosDiff, buffer.Length); + //Copy bytes to buffer - Memory.Copy(_buffer, _position, buffer, 0, bytesToRead); + MemoryUtil.Copy(_buffer, _position, buffer, 0, bytesToRead); + //Increment buffer position _position += bytesToRead; - //Bytestoread should never be larger than int.max because span length is an integer + return bytesToRead; } @@ -322,22 +327,27 @@ namespace VNLib.Utils.IO { throw new ArgumentOutOfRangeException(nameof(offset), "Offset cannot be less than 0"); } + if(offset > nint.MaxValue) + { + throw new ArgumentOutOfRangeException(nameof(offset), "Offset cannot be less than nint.MaxValue"); + } + + //safe cast to nint + nint _offset = (nint)offset; + switch (origin) { case SeekOrigin.Begin: - //Length will never be greater than int.Max so output will never exceed int.max - _position = Math.Min(_length, offset); - return _position; + //Length will never be greater than nint.Max so output will never exceed nint.max + return _position = Math.Min(_length, _offset); case SeekOrigin.Current: - long newPos = _position + offset; - //Length will never be greater than int.Max so output will never exceed length - _position = Math.Min(_length, newPos); - return newPos; + //Calc new seek position from current position + nint newPos = _position + _offset; + return _position = Math.Min(_length, newPos); case SeekOrigin.End: - long real_index = _length - offset; - //If offset moves the position negative, just set the position to 0 and continue - _position = Math.Min(real_index, 0); - return real_index; + //Calc new seek position from end of stream, should be len -1 so 0 can be specified from the end + nint realIndex = _length - (_offset - 1); + return _position = Math.Min(realIndex, 0); default: throw new ArgumentException("Stream operation is not supported on current stream"); } @@ -356,7 +366,7 @@ namespace VNLib.Utils.IO /// public override void SetLength(long value) { - if (IsReadonly) + if (_isReadonly) { throw new NotSupportedException("This stream is readonly"); } @@ -364,25 +374,33 @@ namespace VNLib.Utils.IO { throw new ArgumentOutOfRangeException(nameof(value), "Value cannot be less than 0"); } + if(value > nint.MaxValue) + { + throw new ArgumentOutOfRangeException(nameof(value), "Value cannot be greater than nint.MaxValue"); + } + + nint _value = (nint)value; + //Resize the buffer to the specified length - _buffer.Resize(value); + _buffer.Resize(_value); + //Set length - _length = value; - //Make sure the position is not pointing outside of the buffer + _length = _value; + + //Make sure the position is not pointing outside of the buffer after resize _position = Math.Min(_position, _length); - return; } /// - public override void Write(byte[] buffer, int offset, int count) => Write(new ReadOnlySpan(buffer, offset, count)); + public override void Write(byte[] buffer, int offset, int count) => Write(buffer.AsSpan(offset, count)); /// public override void Write(ReadOnlySpan buffer) { - if (IsReadonly) + if (_isReadonly) { throw new NotSupportedException("Write operation is not allowed on readonly stream!"); } //Calculate the new final position - long newPos = (_position + buffer.Length); + nint newPos = (_position + buffer.Length); //Determine if the buffer needs to be expanded if (buffer.Length > LenToPosDiff) { @@ -392,10 +410,9 @@ namespace VNLib.Utils.IO _length = newPos; } //Copy the input buffer to the internal buffer - Memory.Copy(buffer, _buffer, _position); + MemoryUtil.Copy(buffer, _buffer, (nuint)_position); //Update the position _position = newPos; - return; } /// public override Task WriteAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken) @@ -423,16 +440,17 @@ namespace VNLib.Utils.IO /// /// Copy of internal buffer /// - /// public byte[] ToArray() { - //Alloc a new array of the size of the internal buffer + //Alloc a new array of the size of the internal buffer, may be 64 bit large block byte[] data = new byte[_length]; - //Copy data from the internal buffer to the output buffer - _buffer.Span.CopyTo(data); + + //Copy the internal buffer to the new array + MemoryUtil.Copy(_buffer, 0, data, 0, (nuint)_length); + return data; - } + /// /// Returns a window over the data within the entire stream /// @@ -440,8 +458,10 @@ namespace VNLib.Utils.IO /// public ReadOnlySpan AsSpan() { - ReadOnlySpan output = _buffer.Span; - return output[..(int)_length]; + //Get 32bit length or throw + int len = Convert.ToInt32(_length); + //Get span with no offset + return _buffer.AsSpan(0, len); } /// diff --git a/lib/Utils/src/Memory/IMemoryHandle.cs b/lib/Utils/src/Memory/IMemoryHandle.cs index 75d1cce..cf19ce9 100644 --- a/lib/Utils/src/Memory/IMemoryHandle.cs +++ b/lib/Utils/src/Memory/IMemoryHandle.cs @@ -33,16 +33,10 @@ namespace VNLib.Utils.Memory /// The type this handle represents public interface IMemoryHandle : IDisposable, IPinnable { - /// - /// The size of the block as an integer - /// - /// - int IntLength { get; } - /// /// The number of elements in the block /// - ulong Length { get; } + nuint Length { get; } /// /// Gets the internal block as a span diff --git a/lib/Utils/src/Memory/IUnmangedHeap.cs b/lib/Utils/src/Memory/IUnmangedHeap.cs index 5d8f4bf..94f34c8 100644 --- a/lib/Utils/src/Memory/IUnmangedHeap.cs +++ b/lib/Utils/src/Memory/IUnmangedHeap.cs @@ -38,7 +38,7 @@ namespace VNLib.Utils.Memory /// The number of elements to allocate /// An optional parameter to zero the block of memory /// - IntPtr Alloc(UInt64 elements, UInt64 size, bool zero); + IntPtr Alloc(nuint elements, nuint size, bool zero); /// /// Resizes the allocated block of memory to the new size @@ -47,7 +47,7 @@ namespace VNLib.Utils.Memory /// The new number of elements /// The size (in bytes) of the type /// An optional parameter to zero the block of memory - void Resize(ref IntPtr block, UInt64 elements, UInt64 size, bool zero); + void Resize(ref IntPtr block, nuint elements, nuint size, bool zero); /// /// Free's a previously allocated block of memory diff --git a/lib/Utils/src/Memory/Memory.cs b/lib/Utils/src/Memory/Memory.cs deleted file mode 100644 index e04c386..0000000 --- a/lib/Utils/src/Memory/Memory.cs +++ /dev/null @@ -1,456 +0,0 @@ -/* -* Copyright (c) 2022 Vaughn Nugent -* -* Library: VNLib -* Package: VNLib.Utils -* File: Memory.cs -* -* Memory.cs is part of VNLib.Utils which is part of the larger -* VNLib collection of libraries and utilities. -* -* VNLib.Utils is free software: you can redistribute it and/or modify -* it under the terms of the GNU General Public License as published -* by the Free Software Foundation, either version 2 of the License, -* or (at your option) any later version. -* -* VNLib.Utils is distributed in the hope that it will be useful, -* but WITHOUT ANY WARRANTY; without even the implied warranty of -* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU -* General Public License for more details. -* -* You should have received a copy of the GNU General Public License -* along with VNLib.Utils. If not, see http://www.gnu.org/licenses/. -*/ - -using System; -using System.IO; -using System.Buffers; -using System.Security; -using System.Threading; -using System.Runtime.InteropServices; -using System.Runtime.CompilerServices; - -using VNLib.Utils.Extensions; - -namespace VNLib.Utils.Memory -{ - /// - /// Provides optimized cross-platform maanged/umanaged safe/unsafe memory operations - /// - [SecurityCritical] - [ComVisible(false)] - public unsafe static class Memory - { - public const string SHARED_HEAP_TYPE_ENV= "VNLIB_SHARED_HEAP_TYPE"; - public const string SHARED_HEAP_INTIAL_SIZE_ENV = "VNLIB_SHARED_HEAP_SIZE"; - - /// - /// Initial shared heap size (bytes) - /// - public const ulong SHARED_HEAP_INIT_SIZE = 20971520; - - public const int MAX_BUF_SIZE = 2097152; - public const int MIN_BUF_SIZE = 16000; - - /// - /// The maximum buffer size requested by - /// that will use the array pool before falling back to the . - /// heap. - /// - public const int MAX_UNSAFE_POOL_SIZE = 500 * 1024; - - /// - /// Provides a shared heap instance for the process to allocate memory from. - /// - /// - /// The backing heap - /// is determined by the OS type and process environment varibles. - /// - public static IUnmangedHeap Shared => _sharedHeap.Value; - - private static readonly Lazy _sharedHeap; - - static Memory() - { - _sharedHeap = new Lazy(() => InitHeapInternal(true), LazyThreadSafetyMode.PublicationOnly); - //Cleanup the heap on process exit - AppDomain.CurrentDomain.DomainUnload += DomainUnloaded; - } - - private static void DomainUnloaded(object? sender, EventArgs e) - { - //Dispose the heap if allocated - if (_sharedHeap.IsValueCreated) - { - _sharedHeap.Value.Dispose(); - } - } - - /// - /// Initializes a new determined by compilation/runtime flags - /// and operating system type for the current proccess. - /// - /// An for the current process - /// - /// - public static IUnmangedHeap InitializeNewHeapForProcess() => InitHeapInternal(false); - - private static IUnmangedHeap InitHeapInternal(bool isShared) - { - bool IsWindows = OperatingSystem.IsWindows(); - //Get environment varable - string heapType = Environment.GetEnvironmentVariable(SHARED_HEAP_TYPE_ENV); - //Get inital size - string sharedSize = Environment.GetEnvironmentVariable(SHARED_HEAP_INTIAL_SIZE_ENV); - //Try to parse the shared size from the env - if (!ulong.TryParse(sharedSize, out ulong defaultSize)) - { - defaultSize = SHARED_HEAP_INIT_SIZE; - } - //Gen the private heap from its type or default - switch (heapType) - { - case "win32": - if (!IsWindows) - { - throw new PlatformNotSupportedException("Win32 private heaps are not supported on non-windows platforms"); - } - return PrivateHeap.Create(defaultSize); - case "rpmalloc": - //If the shared heap is being allocated, then return a lock free global heap - return isShared ? RpMallocPrivateHeap.GlobalHeap : new RpMallocPrivateHeap(false); - default: - return IsWindows ? PrivateHeap.Create(defaultSize) : new ProcessHeap(); - } - } - - /// - /// Gets a value that indicates if the Rpmalloc native library is loaded - /// - public static bool IsRpMallocLoaded { get; } = Environment.GetEnvironmentVariable(SHARED_HEAP_TYPE_ENV) == "rpmalloc"; - - #region Zero - /// - /// Zeros a block of memory of umanged type. If Windows is detected at runtime, calls RtlSecureZeroMemory Win32 function - /// - /// Unmanged datatype - /// Block of memory to be cleared - [MethodImpl(MethodImplOptions.NoInlining | MethodImplOptions.NoOptimization)] - public static void UnsafeZeroMemory(ReadOnlySpan block) where T : unmanaged - { - if (!block.IsEmpty) - { - checked - { - fixed (void* ptr = &MemoryMarshal.GetReference(block)) - { - //Calls memset - Unsafe.InitBlock(ptr, 0, (uint)(block.Length * sizeof(T))); - } - } - } - } - /// - /// Zeros a block of memory of umanged type. If Windows is detected at runtime, calls RtlSecureZeroMemory Win32 function - /// - /// Unmanged datatype - /// Block of memory to be cleared - [MethodImpl(MethodImplOptions.NoInlining | MethodImplOptions.NoOptimization)] - public static void UnsafeZeroMemory(ReadOnlyMemory block) where T : unmanaged - { - if (!block.IsEmpty) - { - checked - { - //Pin memory and get pointer - using MemoryHandle handle = block.Pin(); - //Calls memset - Unsafe.InitBlock(handle.Pointer, 0, (uint)(block.Length * sizeof(T))); - } - } - } - - /// - /// Initializes a block of memory with zeros - /// - /// The unmanaged - /// The block of memory to initialize - public static void InitializeBlock(Span block) where T : unmanaged => UnsafeZeroMemory(block); - /// - /// Initializes a block of memory with zeros - /// - /// The unmanaged - /// The block of memory to initialize - public static void InitializeBlock(Memory block) where T : unmanaged => UnsafeZeroMemory(block); - - /// - /// Zeroes a block of memory pointing to the structure - /// - /// The structure type - /// The pointer to the allocated structure - public static void ZeroStruct(IntPtr block) - { - //get thes size of the structure - int size = Unsafe.SizeOf(); - //Zero block - Unsafe.InitBlock(block.ToPointer(), 0, (uint)size); - } - /// - /// Zeroes a block of memory pointing to the structure - /// - /// The structure type - /// The pointer to the allocated structure - public static void ZeroStruct(void* structPtr) where T: unmanaged - { - //get thes size of the structure - int size = Unsafe.SizeOf(); - //Zero block - Unsafe.InitBlock(structPtr, 0, (uint)size); - } - /// - /// Zeroes a block of memory pointing to the structure - /// - /// The structure type - /// The pointer to the allocated structure - public static void ZeroStruct(T* structPtr) where T : unmanaged - { - //get thes size of the structure - int size = Unsafe.SizeOf(); - //Zero block - Unsafe.InitBlock(structPtr, 0, (uint)size); - } - - #endregion - - #region Copy - /// - /// Copies data from source memory to destination memory of an umanged data type - /// - /// Unmanged type - /// Source data - /// Destination - /// Dest offset - /// - public static void Copy(ReadOnlySpan source, MemoryHandle dest, Int64 destOffset) where T : unmanaged - { - if (source.IsEmpty) - { - return; - } - if (dest.Length < (ulong)(destOffset + source.Length)) - { - throw new ArgumentException("Source data is larger than the dest data block", nameof(source)); - } - //Get long offset from the destination handle - T* offset = dest.GetOffset(destOffset); - fixed(void* src = &MemoryMarshal.GetReference(source)) - { - int byteCount = checked(source.Length * sizeof(T)); - Unsafe.CopyBlock(offset, src, (uint)byteCount); - } - } - /// - /// Copies data from source memory to destination memory of an umanged data type - /// - /// Unmanged type - /// Source data - /// Destination - /// Dest offset - /// - public static void Copy(ReadOnlyMemory source, MemoryHandle dest, Int64 destOffset) where T : unmanaged - { - if (source.IsEmpty) - { - return; - } - if (dest.Length < (ulong)(destOffset + source.Length)) - { - throw new ArgumentException("Dest constraints are larger than the dest data block", nameof(source)); - } - //Get long offset from the destination handle - T* offset = dest.GetOffset(destOffset); - //Pin the source memory - using MemoryHandle srcHandle = source.Pin(); - int byteCount = checked(source.Length * sizeof(T)); - //Copy block using unsafe class - Unsafe.CopyBlock(offset, srcHandle.Pointer, (uint)byteCount); - } - /// - /// Copies data from source memory to destination memory of an umanged data type - /// - /// Unmanged type - /// Source data - /// Number of elements to offset source data - /// Destination - /// Dest offset - /// Number of elements to copy - /// - public static void Copy(MemoryHandle source, Int64 sourceOffset, Span dest, int destOffset, int count) where T : unmanaged - { - if (count <= 0) - { - return; - } - if (source.Length < (ulong)(sourceOffset + count)) - { - throw new ArgumentException("Source constraints are larger than the source data block", nameof(count)); - } - if (dest.Length < destOffset + count) - { - throw new ArgumentOutOfRangeException(nameof(destOffset), "Destination offset range cannot exceed the size of the destination buffer"); - } - //Get offset to allow large blocks of memory - T* src = source.GetOffset(sourceOffset); - fixed(T* dst = &MemoryMarshal.GetReference(dest)) - { - //Cacl offset - T* dstoffset = dst + destOffset; - int byteCount = checked(count * sizeof(T)); - //Aligned copy - Unsafe.CopyBlock(dstoffset, src, (uint)byteCount); - } - } - /// - /// Copies data from source memory to destination memory of an umanged data type - /// - /// Unmanged type - /// Source data - /// Number of elements to offset source data - /// Destination - /// Dest offset - /// Number of elements to copy - /// - public static void Copy(MemoryHandle source, Int64 sourceOffset, Memory dest, int destOffset, int count) where T : unmanaged - { - if (count == 0) - { - return; - } - if (source.Length < (ulong)(sourceOffset + count)) - { - throw new ArgumentException("Source constraints are larger than the source data block", nameof(count)); - } - if(dest.Length < destOffset + count) - { - throw new ArgumentOutOfRangeException(nameof(destOffset), "Destination offset range cannot exceed the size of the destination buffer"); - } - //Get offset to allow large blocks of memory - T* src = source.GetOffset(sourceOffset); - //Pin the memory handle - using MemoryHandle handle = dest.Pin(); - //Byte count - int byteCount = checked(count * sizeof(T)); - //Dest offset - T* dst = ((T*)handle.Pointer) + destOffset; - //Aligned copy - Unsafe.CopyBlock(dst, src, (uint)byteCount); - } - #endregion - - #region Streams - /// - /// Copies data from one stream to another in specified blocks - /// - /// Source memory - /// Source offset - /// Destination memory - /// Destination offset - /// Number of elements to copy - public static void Copy(Stream source, Int64 srcOffset, Stream dest, Int64 destOffst, Int64 count) - { - if (count == 0) - { - return; - } - if (count < 0) - { - throw new ArgumentException("Count must be a positive integer", nameof(count)); - } - //Seek streams - _ = source.Seek(srcOffset, SeekOrigin.Begin); - _ = dest.Seek(destOffst, SeekOrigin.Begin); - //Create new buffer - using IMemoryHandle buffer = Shared.Alloc(count); - Span buf = buffer.Span; - int total = 0; - do - { - //read from source - int read = source.Read(buf); - //guard - if (read == 0) - { - break; - } - //write read slice to dest - dest.Write(buf[..read]); - //update total read - total += read; - } while (total < count); - } - #endregion - - #region alloc - - /// - /// Allocates a block of unmanaged, or pooled manaaged memory depending on - /// compilation flags and runtime unamanged allocators. - /// - /// The unamanged type to allocate - /// The number of elements of the type within the block - /// Flag to zero elements during allocation before the method returns - /// A handle to the block of memory - /// - /// - public static UnsafeMemoryHandle UnsafeAlloc(int elements, bool zero = false) where T : unmanaged - { - if (elements < 0) - { - throw new ArgumentException("Number of elements must be a positive integer", nameof(elements)); - } - if(elements > MAX_UNSAFE_POOL_SIZE || IsRpMallocLoaded) - { - // Alloc from heap - IntPtr block = Shared.Alloc((uint)elements, (uint)sizeof(T), zero); - //Init new handle - return new(Shared, block, elements); - } - else - { - return new(ArrayPool.Shared, elements, zero); - } - } - - /// - /// Allocates a block of unmanaged, or pooled manaaged memory depending on - /// compilation flags and runtime unamanged allocators. - /// - /// The unamanged type to allocate - /// The number of elements of the type within the block - /// Flag to zero elements during allocation before the method returns - /// A handle to the block of memory - /// - /// - public static IMemoryHandle SafeAlloc(int elements, bool zero = false) where T: unmanaged - { - if (elements < 0) - { - throw new ArgumentException("Number of elements must be a positive integer", nameof(elements)); - } - - //If the element count is larger than max pool size, alloc from shared heap - if (elements > MAX_UNSAFE_POOL_SIZE) - { - //Alloc from shared heap - return Shared.Alloc(elements, zero); - } - else - { - //Get temp buffer from shared buffer pool - return new VnTempBuffer(elements, zero); - } - } - - #endregion - } -} \ No newline at end of file diff --git a/lib/Utils/src/Memory/MemoryHandle.cs b/lib/Utils/src/Memory/MemoryHandle.cs index a09edea..df2792b 100644 --- a/lib/Utils/src/Memory/MemoryHandle.cs +++ b/lib/Utils/src/Memory/MemoryHandle.cs @@ -34,7 +34,7 @@ using VNLib.Utils.Extensions; namespace VNLib.Utils.Memory { /// - /// Provides a wrapper for using umanged memory handles from an assigned for types + /// Provides a wrapper for using umanged memory handles from an assigned for types /// /// /// Handles are configured to address blocks larger than 2GB, @@ -72,31 +72,21 @@ namespace VNLib.Utils.Memory get { this.ThrowIfClosed(); - return _length == 0 ? Span.Empty : new Span(Base, IntLength); + int len = Convert.ToInt32(_length); + return _length == 0 ? Span.Empty : new Span(Base, len); } } private readonly bool ZeroMemory; private readonly IUnmangedHeap Heap; - private ulong _length; + private nuint _length; - /// - /// Number of elements allocated to the current instance - /// - public ulong Length + /// + public nuint Length { [MethodImpl(MethodImplOptions.AggressiveInlining)] get => _length; } - /// - /// Number of elements in the memory block casted to an integer - /// - /// - public int IntLength - { - [MethodImpl(MethodImplOptions.AggressiveInlining)] - get => checked((int)_length); - } /// /// Number of bytes allocated to the current instance @@ -106,7 +96,7 @@ namespace VNLib.Utils.Memory { //Check for overflows when converting to bytes (should run out of memory before this is an issue, but just incase) [MethodImpl(MethodImplOptions.AggressiveInlining)] - get => checked(_length * (UInt64)sizeof(T)); + get => MemoryUtil.ByteCount(_length); } /// @@ -116,7 +106,7 @@ namespace VNLib.Utils.Memory /// Number of elements to allocate /// Zero all memory during allocations from heap /// The initial block of allocated memory to wrap - internal MemoryHandle(IUnmangedHeap heap, IntPtr initial, ulong elements, bool zero) : base(true) + internal MemoryHandle(IUnmangedHeap heap, IntPtr initial, nuint elements, bool zero) : base(true) { //Set element size (always allocate at least 1 object) _length = elements; @@ -133,7 +123,7 @@ namespace VNLib.Utils.Memory /// /// /// - public unsafe void Resize(ulong elements) + public unsafe void Resize(nuint elements) { this.ThrowIfClosed(); //Update size (should never be less than inital size) @@ -141,7 +131,7 @@ namespace VNLib.Utils.Memory //Re-alloc (Zero if required) try { - Heap.Resize(ref handle, Length, (ulong)sizeof(T), ZeroMemory); + Heap.Resize(ref handle, Length, (nuint)sizeof(T), ZeroMemory); } //Catch the disposed exception so we can invalidate the current ptr catch (ObjectDisposedException) @@ -161,7 +151,7 @@ namespace VNLib.Utils.Memory /// /// pointer to the memory offset specified [MethodImpl(MethodImplOptions.AggressiveInlining)] - public unsafe T* GetOffset(ulong elements) + public unsafe T* GetOffset(nuint elements) { if (elements >= _length) { @@ -176,10 +166,14 @@ namespace VNLib.Utils.Memory /// /// /// + /// + ///Calling this method increments the handle's referrence count. + ///Disposing the returned handle decrements the handle count. + /// public unsafe MemoryHandle Pin(int elementIndex) { //Get ptr and guard checks before adding the referrence - T* ptr = GetOffset((ulong)elementIndex); + T* ptr = GetOffset((nuint)elementIndex); bool addRef = false; //use the pinned field as success val @@ -198,15 +192,12 @@ namespace VNLib.Utils.Memory DangerousRelease(); } - /// protected override bool ReleaseHandle() { //Return result of free return Heap.Free(ref handle); - } - - + } /// /// Determines if the memory blocks are equal by comparing their base addresses. diff --git a/lib/Utils/src/Memory/MemoryUtil.cs b/lib/Utils/src/Memory/MemoryUtil.cs new file mode 100644 index 0000000..410db6b --- /dev/null +++ b/lib/Utils/src/Memory/MemoryUtil.cs @@ -0,0 +1,603 @@ +/* +* Copyright (c) 2022 Vaughn Nugent +* +* Library: VNLib +* Package: VNLib.Utils +* File: Memory.cs +* +* Memory.cs is part of VNLib.Utils which is part of the larger +* VNLib collection of libraries and utilities. +* +* VNLib.Utils is free software: you can redistribute it and/or modify +* it under the terms of the GNU General Public License as published +* by the Free Software Foundation, either version 2 of the License, +* or (at your option) any later version. +* +* VNLib.Utils is distributed in the hope that it will be useful, +* but WITHOUT ANY WARRANTY; without even the implied warranty of +* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU +* General Public License for more details. +* +* You should have received a copy of the GNU General Public License +* along with VNLib.Utils. If not, see http://www.gnu.org/licenses/. +*/ + +using System; +using System.Buffers; +using System.Security; +using System.Threading; +using System.Runtime.InteropServices; +using System.Runtime.CompilerServices; + +using VNLib.Utils.Extensions; + +namespace VNLib.Utils.Memory +{ + /// + /// Provides optimized cross-platform maanged/umanaged safe/unsafe memory operations + /// + [SecurityCritical] + [ComVisible(false)] + public unsafe static class MemoryUtil + { + public const string SHARED_HEAP_TYPE_ENV= "VNLIB_SHARED_HEAP_TYPE"; + public const string SHARED_HEAP_INTIAL_SIZE_ENV = "VNLIB_SHARED_HEAP_SIZE"; + + /// + /// Initial shared heap size (bytes) + /// + public const nuint SHARED_HEAP_INIT_SIZE = 20971520; + + public const int MAX_BUF_SIZE = 2097152; + public const int MIN_BUF_SIZE = 16000; + + /// + /// The maximum buffer size requested by + /// that will use the array pool before falling back to the . + /// heap. + /// + public const int MAX_UNSAFE_POOL_SIZE = 500 * 1024; + + /// + /// Provides a shared heap instance for the process to allocate memory from. + /// + /// + /// The backing heap + /// is determined by the OS type and process environment varibles. + /// + public static IUnmangedHeap Shared => _sharedHeap.Value; + + + private static readonly Lazy _sharedHeap = InitHeapInternal(); + + //Avoiding statit initializer + private static Lazy InitHeapInternal() + { + Lazy heap = new (() => InitHeapInternal(true), LazyThreadSafetyMode.PublicationOnly); + //Cleanup the heap on process exit + AppDomain.CurrentDomain.DomainUnload += DomainUnloaded; + return heap; + } + + + private static void DomainUnloaded(object? sender, EventArgs e) + { + //Dispose the heap if allocated + if (_sharedHeap.IsValueCreated) + { + _sharedHeap.Value.Dispose(); + } + } + + /// + /// Initializes a new determined by compilation/runtime flags + /// and operating system type for the current proccess. + /// + /// An for the current process + /// + /// + public static IUnmangedHeap InitializeNewHeapForProcess() => InitHeapInternal(false); + + private static IUnmangedHeap InitHeapInternal(bool isShared) + { + bool IsWindows = OperatingSystem.IsWindows(); + //Get environment varable + string? heapType = Environment.GetEnvironmentVariable(SHARED_HEAP_TYPE_ENV); + //Get inital size + string? sharedSize = Environment.GetEnvironmentVariable(SHARED_HEAP_INTIAL_SIZE_ENV); + //Try to parse the shared size from the env + if (!nuint.TryParse(sharedSize, out nuint defaultSize)) + { + defaultSize = SHARED_HEAP_INIT_SIZE; + } + //Gen the private heap from its type or default + switch (heapType) + { + case "win32": + if (!IsWindows) + { + throw new PlatformNotSupportedException("Win32 private heaps are not supported on non-windows platforms"); + } + return Win32PrivateHeap.Create(defaultSize); + case "rpmalloc": + //If the shared heap is being allocated, then return a lock free global heap + return isShared ? RpMallocPrivateHeap.GlobalHeap : new RpMallocPrivateHeap(false); + default: + return IsWindows ? Win32PrivateHeap.Create(defaultSize) : new ProcessHeap(); + } + } + + /// + /// Gets a value that indicates if the Rpmalloc native library is loaded + /// + public static bool IsRpMallocLoaded { get; } = Environment.GetEnvironmentVariable(SHARED_HEAP_TYPE_ENV) == "rpmalloc"; + + #region Zero + + /// + /// Zeros a block of memory of umanged type. If Windows is detected at runtime, calls RtlSecureZeroMemory Win32 function + /// + /// Unmanged datatype + /// Block of memory to be cleared + [MethodImpl(MethodImplOptions.NoInlining | MethodImplOptions.NoOptimization)] + public static void UnsafeZeroMemory(ReadOnlySpan block) where T : unmanaged + { + if (!block.IsEmpty) + { + checked + { + fixed (void* ptr = &MemoryMarshal.GetReference(block)) + { + //Calls memset + Unsafe.InitBlock(ptr, 0, (uint)(block.Length * sizeof(T))); + } + } + } + } + + /// + /// Zeros a block of memory of umanged type. If Windows is detected at runtime, calls RtlSecureZeroMemory Win32 function + /// + /// Unmanged datatype + /// Block of memory to be cleared + [MethodImpl(MethodImplOptions.NoInlining | MethodImplOptions.NoOptimization)] + public static void UnsafeZeroMemory(ReadOnlyMemory block) where T : unmanaged + { + if (!block.IsEmpty) + { + checked + { + //Pin memory and get pointer + using MemoryHandle handle = block.Pin(); + //Calls memset + Unsafe.InitBlock(handle.Pointer, 0, (uint)(block.Length * sizeof(T))); + } + } + } + + /// + /// Initializes a block of memory with zeros + /// + /// The unmanaged + /// The block of memory to initialize + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static void InitializeBlock(Span block) where T : unmanaged => UnsafeZeroMemory(block); + + /// + /// Initializes a block of memory with zeros + /// + /// The unmanaged + /// The block of memory to initialize + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static void InitializeBlock(Memory block) where T : unmanaged => UnsafeZeroMemory(block); + + /// + /// Zeroes a block of memory pointing to the structure + /// + /// The structure type + /// The pointer to the allocated structure + public static void ZeroStruct(IntPtr block) + { + //get thes size of the structure + int size = Unsafe.SizeOf(); + //Zero block + Unsafe.InitBlock(block.ToPointer(), 0, (uint)size); + } + + /// + /// Zeroes a block of memory pointing to the structure + /// + /// The structure type + /// The pointer to the allocated structure + public static void ZeroStruct(void* structPtr) where T: unmanaged + { + //get thes size of the structure + int size = Unsafe.SizeOf(); + //Zero block + Unsafe.InitBlock(structPtr, 0, (uint)size); + } + + /// + /// Zeroes a block of memory pointing to the structure + /// + /// The structure type + /// The pointer to the allocated structure + public static void ZeroStruct(T* structPtr) where T : unmanaged + { + //get thes size of the structure + int size = Unsafe.SizeOf(); + //Zero block + Unsafe.InitBlock(structPtr, 0, (uint)size); + } + + #endregion + + #region Copy + + /// + /// Copies data from source memory to destination memory of an umanged data type + /// + /// Unmanged type + /// Source data + /// Destination + /// Dest offset + /// + public static void Copy(ReadOnlySpan source, MemoryHandle dest, nuint destOffset) where T : unmanaged + { + if (dest is null) + { + throw new ArgumentNullException(nameof(dest)); + } + + if (source.IsEmpty) + { + return; + } + + //Get long offset from the destination handle (also checks bounds) + Span dst = dest.GetOffsetSpan(destOffset, source.Length); + + //Copy data + source.CopyTo(dst); + } + + /// + /// Copies data from source memory to destination memory of an umanged data type + /// + /// Unmanged type + /// Source data + /// Destination + /// Dest offset + /// + public static void Copy(ReadOnlyMemory source, MemoryHandle dest, nuint destOffset) where T : unmanaged + { + if (dest is null) + { + throw new ArgumentNullException(nameof(dest)); + } + + if (source.IsEmpty) + { + return; + } + + //Get long offset from the destination handle (also checks bounds) + Span dst = dest.GetOffsetSpan(destOffset, source.Length); + + //Copy data + source.Span.CopyTo(dst); + } + + /// + /// Copies data from source memory to destination memory of an umanged data type + /// + /// Unmanged type + /// Source data + /// Number of elements to offset source data + /// Destination + /// Dest offset + /// Number of elements to copy + /// + public static void Copy(MemoryHandle source, nint sourceOffset, Span dest, int destOffset, int count) where T : unmanaged + { + //Validate source/dest/count + ValidateArgs(sourceOffset, destOffset, count); + + //Check count last for debug reasons + if (count == 0) + { + return; + } + + //Get offset span, also checks bounts + Span src = source.GetOffsetSpan(sourceOffset, count); + + //slice the dest span + Span dst = dest.Slice(destOffset, count); + + //Copy data + src.CopyTo(dst); + } + + /// + /// Copies data from source memory to destination memory of an umanged data type + /// + /// Unmanged type + /// Source data + /// Number of elements to offset source data + /// Destination + /// Dest offset + /// Number of elements to copy + /// + /// + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static void Copy(MemoryHandle source, nint sourceOffset, Memory dest, int destOffset, int count) where T : unmanaged + { + //Call copy method with dest as span + Copy(source, sourceOffset, dest.Span, destOffset, count); + } + + private static void ValidateArgs(nint sourceOffset, nint destOffset, nint count) + { + if(sourceOffset < 0) + { + throw new ArgumentOutOfRangeException(nameof(sourceOffset), "Source offset must be a postive integer"); + } + + if(destOffset < 0) + { + throw new ArgumentOutOfRangeException(nameof(destOffset), "Destination offset must be a positive integer"); + } + + if(count < 0) + { + throw new ArgumentOutOfRangeException(nameof(count), "Count parameter must be a postitive integer"); + } + } + + /// + /// 32/64 bit large block copy + /// + /// + /// The source memory handle to copy data from + /// The element offset to begin reading from + /// The destination array to write data to + /// + /// The number of elements to copy + /// + /// + public static void Copy(IMemoryHandle source, nuint offset, T[] dest, nuint destOffset, nuint count) where T : unmanaged + { + if (source is null) + { + throw new ArgumentNullException(nameof(source)); + } + + if (dest is null) + { + throw new ArgumentNullException(nameof(dest)); + } + + if (count == 0) + { + return; + } + + //Check source bounds + CheckBounds(source, offset, count); + + //Check dest bounts + CheckBounds(dest, destOffset, count); + + +#if TARGET_64BIT + //Get the number of bytes to copy + nuint byteCount = ByteCount(count); + + //Get memory handle from source + using MemoryHandle srcHandle = source.Pin(0); + + //get source offset + T* src = (T*)srcHandle.Pointer + offset; + + //pin array + fixed (T* dst = dest) + { + //Offset dest ptr + T* dstOffset = dst + destOffset; + + //Copy src to set + Buffer.MemoryCopy(src, dstOffset, byteCount, byteCount); + } +#else + //If 32bit its safe to use spans + + Span src = source.Span.Slice((int)offset, (int)count); + Span dst = dest.AsSpan((int)destOffset, (int)count); + //Copy + src.CopyTo(dst); +#endif + } + + #endregion + + #region Validation + + /// + /// Gets the size in bytes of the handle + /// + /// + /// The handle to get the byte size of + /// The number of bytes pointed to by the handle + /// + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static nuint ByteSize(IMemoryHandle handle) + { + _ = handle ?? throw new ArgumentNullException(nameof(handle)); + return checked(handle.Length * (nuint)Unsafe.SizeOf()); + } + + /// + /// Gets the size in bytes of the handle + /// + /// + /// The handle to get the byte size of + /// The number of bytes pointed to by the handle + /// + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static nuint ByteSize(in UnsafeMemoryHandle handle) where T : unmanaged => checked(handle.Length * (nuint)sizeof(T)); + + /// + /// Gets the byte multiple of the length parameter + /// + /// The type to get the byte offset of + /// The number of elements to get the byte count of + /// The byte multiple of the number of elments + /// + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static nuint ByteCount(nuint elementCount) => checked(elementCount * (nuint)Unsafe.SizeOf()); + /// + /// Gets the byte multiple of the length parameter + /// + /// The type to get the byte offset of + /// The number of elements to get the byte count of + /// The byte multiple of the number of elments + /// + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static uint ByteCount(uint elementCount) => checked(elementCount * (uint)Unsafe.SizeOf()); + + /// + /// Checks if the offset/count paramters for the given memory handle + /// point outside the block wrapped in the handle + /// + /// + /// The handle to check bounds of + /// The base offset to add + /// The number of bytes expected to be assigned or dereferrenced + /// + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static void CheckBounds(IMemoryHandle handle, nuint offset, nuint count) + { + if (offset + count > handle.Length) + { + throw new ArgumentException("The offset or count is outside of the range of the block of memory"); + } + } + + /// + /// Checks if the offset/count paramters for the given block + /// point outside the block wrapped in the handle + /// + /// + /// The handle to check bounds of + /// The base offset to add + /// The number of bytes expected to be assigned or dereferrenced + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static void CheckBounds(ReadOnlySpan block, int offset, int count) + { + //Call slice and discard to raise exception + _ = block.Slice(offset, count); + } + + /// + /// Checks if the offset/count paramters for the given block + /// point outside the block wrapped in the handle + /// + /// + /// The handle to check bounds of + /// The base offset to add + /// The number of bytes expected to be assigned or dereferrenced + /// + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static void CheckBounds(Span block, int offset, int count) + { + //Call slice and discard to raise exception + _ = block.Slice(offset, count); + } + + /// + /// Checks if the offset/count paramters for the given block + /// point outside the block bounds + /// + /// + /// The handle to check bounds of + /// The base offset to add + /// The number of bytes expected to be assigned or dereferrenced + /// + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static void CheckBounds(T[] block, nuint offset, nuint count) + { + if (((nuint)block.LongLength - offset) <= count) + { + throw new ArgumentException("The offset or count is outside of the range of the block of memory"); + } + } + + #endregion + + + #region alloc + + /// + /// Allocates a block of unmanaged, or pooled manaaged memory depending on + /// compilation flags and runtime unamanged allocators. + /// + /// The unamanged type to allocate + /// The number of elements of the type within the block + /// Flag to zero elements during allocation before the method returns + /// A handle to the block of memory + /// + /// + public static UnsafeMemoryHandle UnsafeAlloc(int elements, bool zero = false) where T : unmanaged + { + if (elements < 0) + { + throw new ArgumentException("Number of elements must be a positive integer", nameof(elements)); + } + + if(elements > MAX_UNSAFE_POOL_SIZE || IsRpMallocLoaded) + { + // Alloc from heap + IntPtr block = Shared.Alloc((uint)elements, (uint)sizeof(T), zero); + //Init new handle + return new(Shared, block, elements); + } + else + { + return new(ArrayPool.Shared, elements, zero); + } + } + + /// + /// Allocates a block of unmanaged, or pooled manaaged memory depending on + /// compilation flags and runtime unamanged allocators. + /// + /// The unamanged type to allocate + /// The number of elements of the type within the block + /// Flag to zero elements during allocation before the method returns + /// A handle to the block of memory + /// + /// + public static IMemoryHandle SafeAlloc(int elements, bool zero = false) where T: unmanaged + { + if (elements < 0) + { + throw new ArgumentException("Number of elements must be a positive integer", nameof(elements)); + } + + //If the element count is larger than max pool size, alloc from shared heap + if (elements > MAX_UNSAFE_POOL_SIZE) + { + //Alloc from shared heap + return Shared.Alloc(elements, zero); + } + else + { + //Get temp buffer from shared buffer pool + return new VnTempBuffer(elements, zero); + } + } + + #endregion + } +} \ No newline at end of file diff --git a/lib/Utils/src/Memory/PrivateBuffersMemoryPool.cs b/lib/Utils/src/Memory/PrivateBuffersMemoryPool.cs index 1e85207..e73a26f 100644 --- a/lib/Utils/src/Memory/PrivateBuffersMemoryPool.cs +++ b/lib/Utils/src/Memory/PrivateBuffersMemoryPool.cs @@ -28,7 +28,7 @@ using System.Buffers; namespace VNLib.Utils.Memory { /// - /// Provides a wrapper for using unmanged s + /// Provides a wrapper for using unmanged s /// /// Unamanged memory type to provide data memory instances from public sealed class PrivateBuffersMemoryPool : MemoryPool where T : unmanaged diff --git a/lib/Utils/src/Memory/PrivateHeap.cs b/lib/Utils/src/Memory/PrivateHeap.cs deleted file mode 100644 index 5d97506..0000000 --- a/lib/Utils/src/Memory/PrivateHeap.cs +++ /dev/null @@ -1,184 +0,0 @@ -/* -* Copyright (c) 2022 Vaughn Nugent -* -* Library: VNLib -* Package: VNLib.Utils -* File: PrivateHeap.cs -* -* PrivateHeap.cs is part of VNLib.Utils which is part of the larger -* VNLib collection of libraries and utilities. -* -* VNLib.Utils is free software: you can redistribute it and/or modify -* it under the terms of the GNU General Public License as published -* by the Free Software Foundation, either version 2 of the License, -* or (at your option) any later version. -* -* VNLib.Utils is distributed in the hope that it will be useful, -* but WITHOUT ANY WARRANTY; without even the implied warranty of -* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU -* General Public License for more details. -* -* You should have received a copy of the GNU General Public License -* along with VNLib.Utils. If not, see http://www.gnu.org/licenses/. -*/ - -using System; -using System.Diagnostics; -using System.Runtime.Versioning; -using System.Runtime.InteropServices; - -using DWORD = System.Int64; -using SIZE_T = System.UInt64; -using LPVOID = System.IntPtr; - -namespace VNLib.Utils.Memory -{ - /// - /// - /// Provides a win32 private heap managed wrapper class - /// - /// - /// - /// implements and tracks allocated blocks by its - /// referrence counter. Allocations increment the count, and free's decrement the count, so the heap may - /// be disposed safely - /// - [ComVisible(false)] - [SupportedOSPlatform("Windows")] - public sealed class PrivateHeap : UnmanagedHeapBase - { - private const string KERNEL_DLL = "Kernel32"; - - #region Extern - //Heap flags - public const DWORD HEAP_NO_FLAGS = 0x00; - public const DWORD HEAP_GENERATE_EXCEPTIONS = 0x04; - public const DWORD HEAP_NO_SERIALIZE = 0x01; - public const DWORD HEAP_REALLOC_IN_PLACE_ONLY = 0x10; - public const DWORD HEAP_ZERO_MEMORY = 0x08; - - [DllImport(KERNEL_DLL, SetLastError = true, ExactSpelling = true)] - [DefaultDllImportSearchPaths(DllImportSearchPath.System32)] - private static extern LPVOID HeapAlloc(IntPtr hHeap, DWORD flags, SIZE_T dwBytes); - [DllImport(KERNEL_DLL, SetLastError = true, ExactSpelling = true)] - [DefaultDllImportSearchPaths(DllImportSearchPath.System32)] - private static extern LPVOID HeapReAlloc(IntPtr hHeap, DWORD dwFlags, LPVOID lpMem, SIZE_T dwBytes); - [DllImport(KERNEL_DLL, SetLastError = true, ExactSpelling = true)] - [DefaultDllImportSearchPaths(DllImportSearchPath.System32)] - [return: MarshalAs(UnmanagedType.Bool)] - private static extern bool HeapFree(IntPtr hHeap, DWORD dwFlags, LPVOID lpMem); - - [DllImport(KERNEL_DLL, SetLastError = true, ExactSpelling = true)] - [DefaultDllImportSearchPaths(DllImportSearchPath.System32)] - private static extern LPVOID HeapCreate(DWORD flOptions, SIZE_T dwInitialSize, SIZE_T dwMaximumSize); - [DllImport(KERNEL_DLL, SetLastError = true, ExactSpelling = true)] - [DefaultDllImportSearchPaths(DllImportSearchPath.System32)] - [return: MarshalAs(UnmanagedType.Bool)] - private static extern bool HeapDestroy(IntPtr hHeap); - [DllImport(KERNEL_DLL, SetLastError = true, ExactSpelling = true)] - [return: MarshalAs(UnmanagedType.Bool)] - [DefaultDllImportSearchPaths(DllImportSearchPath.System32)] - private static extern bool HeapValidate(IntPtr hHeap, DWORD dwFlags, LPVOID lpMem); - [DllImport(KERNEL_DLL, SetLastError = true, ExactSpelling = true)] - [return: MarshalAs(UnmanagedType.U8)] - [DefaultDllImportSearchPaths(DllImportSearchPath.System32)] - private static extern SIZE_T HeapSize(IntPtr hHeap, DWORD flags, LPVOID lpMem); - - #endregion - - /// - /// Create a new with the specified sizes and flags - /// - /// Intial size of the heap - /// Maximum size allowed for the heap (disabled = 0, default) - /// Defalt heap flags to set globally for all blocks allocated by the heap (default = 0) - public static PrivateHeap Create(SIZE_T initialSize, SIZE_T maxHeapSize = 0, DWORD flags = HEAP_NO_FLAGS) - { - //Call create, throw exception if the heap falled to allocate - IntPtr heapHandle = HeapCreate(flags, initialSize, maxHeapSize); - if (heapHandle == IntPtr.Zero) - { - throw new NativeMemoryException("Heap could not be created"); - } -#if TRACE - Trace.WriteLine($"Win32 private heap {heapHandle:x} created"); -#endif - //Heap has been created so we can wrap it - return new(heapHandle); - } - /// - /// LIFETIME WARNING. Consumes a valid win32 handle and will manage it's lifetime once constructed. - /// Locking and memory blocks will attempt to be allocated from this heap handle. - /// - /// An open and valid handle to a win32 private heap - /// A wrapper around the specified heap - public static PrivateHeap ConsumeExisting(IntPtr win32HeapHandle) => new (win32HeapHandle); - - private PrivateHeap(IntPtr heapPtr) : base(false, true) => handle = heapPtr; - - /// - /// Retrieves the size of a memory block allocated from the current heap. - /// - /// The pointer to a block of memory to get the size of - /// The size of the block of memory, (SIZE_T)-1 if the operation fails - public SIZE_T HeapSize(ref LPVOID block) => HeapSize(handle, HEAP_NO_FLAGS, block); - - /// - /// Validates the specified block of memory within the current heap instance. This function will block hte - /// - /// Pointer to the block of memory to validate - /// True if the block is valid, false otherwise - public bool Validate(ref LPVOID block) - { - bool result; - //Lock the heap before validating - HeapLock.Wait(); - //validate the block on the current heap - result = HeapValidate(handle, HEAP_NO_FLAGS, block); - //Unlock the heap - HeapLock.Release(); - return result; - - } - /// - /// Validates the current heap instance. The function scans all the memory blocks in the heap and verifies that the heap control structures maintained by - /// the heap manager are in a consistent state. - /// - /// If the specified heap or memory block is valid, the return value is nonzero. - /// This can be a consuming operation which will block all allocations - public bool Validate() - { - bool result; - //Lock the heap before validating - HeapLock.Wait(); - //validate the entire heap - result = HeapValidate(handle, HEAP_NO_FLAGS, IntPtr.Zero); - //Unlock the heap - HeapLock.Release(); - return result; - } - - /// - protected override bool ReleaseHandle() - { -#if TRACE - Trace.WriteLine($"Win32 private heap {handle:x} destroyed"); -#endif - return HeapDestroy(handle) && base.ReleaseHandle(); - } - /// - protected override sealed LPVOID AllocBlock(ulong elements, ulong size, bool zero) - { - ulong bytes = checked(elements * size); - return HeapAlloc(handle, zero ? HEAP_ZERO_MEMORY : HEAP_NO_FLAGS, bytes); - } - /// - protected override sealed bool FreeBlock(LPVOID block) => HeapFree(handle, HEAP_NO_FLAGS, block); - /// - protected override sealed LPVOID ReAllocBlock(LPVOID block, ulong elements, ulong size, bool zero) - { - ulong bytes = checked(elements * size); - return HeapReAlloc(handle, zero ? HEAP_ZERO_MEMORY : HEAP_NO_FLAGS, block, bytes); - } - } -} \ No newline at end of file diff --git a/lib/Utils/src/Memory/PrivateStringManager.cs b/lib/Utils/src/Memory/PrivateStringManager.cs index 9ed8f5f..8f01e98 100644 --- a/lib/Utils/src/Memory/PrivateStringManager.cs +++ b/lib/Utils/src/Memory/PrivateStringManager.cs @@ -63,7 +63,7 @@ namespace VNLib.Utils.Memory //Clear the old value before setting the new one if (!string.IsNullOrEmpty(ProtectedElements[index])) { - Memory.UnsafeZeroMemory(ProtectedElements[index]); + MemoryUtil.UnsafeZeroMemory(ProtectedElements[index]); } //set new value ProtectedElements[index] = value; @@ -87,7 +87,7 @@ namespace VNLib.Utils.Memory if (!string.IsNullOrEmpty(ProtectedElements[i])) { //Zero the string memory - Memory.UnsafeZeroMemory(ProtectedElements[i]); + MemoryUtil.UnsafeZeroMemory(ProtectedElements[i]); //Set to null ProtectedElements[i] = null; } diff --git a/lib/Utils/src/Memory/ProcessHeap.cs b/lib/Utils/src/Memory/ProcessHeap.cs index 4f06d52..7afe4b1 100644 --- a/lib/Utils/src/Memory/ProcessHeap.cs +++ b/lib/Utils/src/Memory/ProcessHeap.cs @@ -48,20 +48,23 @@ namespace VNLib.Utils.Memory /// /// /// - public IntPtr Alloc(ulong elements, ulong size, bool zero) + public IntPtr Alloc(nuint elements, nuint size, bool zero) { return zero - ? (IntPtr)NativeMemory.AllocZeroed((nuint)elements, (nuint)size) - : (IntPtr)NativeMemory.Alloc((nuint)elements, (nuint)size); + ? (IntPtr)NativeMemory.AllocZeroed(elements, size) + : (IntPtr)NativeMemory.Alloc(elements, size); } /// public bool Free(ref IntPtr block) { //Free native mem from ptr NativeMemory.Free(block.ToPointer()); + block = IntPtr.Zero; + return true; } + /// protected override void Free() { @@ -69,14 +72,25 @@ namespace VNLib.Utils.Memory Trace.WriteLine($"Default heap instnace disposed {GetHashCode():x}"); #endif } + /// /// /// - public void Resize(ref IntPtr block, ulong elements, ulong size, bool zero) + public void Resize(ref IntPtr block, nuint elements, nuint size, bool zero) { - nuint bytes = checked((nuint)(elements * size)); - IntPtr old = block; - block = (IntPtr)NativeMemory.Realloc(old.ToPointer(), bytes); + nuint bytes = checked(elements * size); + + //Alloc + void* newBlock = NativeMemory.Realloc(block.ToPointer(), bytes); + + //Check + if (newBlock == null) + { + throw new NativeMemoryOutOfMemoryException("Failed to resize the allocated block"); + } + + //Assign block ptr + block = (IntPtr)newBlock; } } } diff --git a/lib/Utils/src/Memory/RpMallocPrivateHeap.cs b/lib/Utils/src/Memory/RpMallocPrivateHeap.cs index 8ed79b6..f9b7db6 100644 --- a/lib/Utils/src/Memory/RpMallocPrivateHeap.cs +++ b/lib/Utils/src/Memory/RpMallocPrivateHeap.cs @@ -28,7 +28,6 @@ using System.Diagnostics; using System.Runtime.InteropServices; using System.Runtime.CompilerServices; -using size_t = System.UInt64; using LPVOID = System.IntPtr; using LPHEAPHANDLE = System.IntPtr; @@ -59,22 +58,22 @@ namespace VNLib.Utils.Memory static extern void rpmalloc_heap_release(LPHEAPHANDLE heap); [DllImport(DLL_NAME, ExactSpelling = true)] [DefaultDllImportSearchPaths(DllImportSearchPath.SafeDirectories)] - static extern LPVOID rpmalloc_heap_alloc(LPHEAPHANDLE heap, size_t size); + static extern LPVOID rpmalloc_heap_alloc(LPHEAPHANDLE heap, nuint size); [DllImport(DLL_NAME, ExactSpelling = true)] [DefaultDllImportSearchPaths(DllImportSearchPath.SafeDirectories)] - static extern LPVOID rpmalloc_heap_aligned_alloc(LPHEAPHANDLE heap, size_t alignment, size_t size); + static extern LPVOID rpmalloc_heap_aligned_alloc(LPHEAPHANDLE heap, nuint alignment, nuint size); [DllImport(DLL_NAME, ExactSpelling = true)] [DefaultDllImportSearchPaths(DllImportSearchPath.SafeDirectories)] - static extern LPVOID rpmalloc_heap_calloc(LPHEAPHANDLE heap, size_t num, size_t size); + static extern LPVOID rpmalloc_heap_calloc(LPHEAPHANDLE heap, nuint num, nuint size); [DllImport(DLL_NAME, ExactSpelling = true)] [DefaultDllImportSearchPaths(DllImportSearchPath.SafeDirectories)] - static extern LPVOID rpmalloc_heap_aligned_calloc(LPHEAPHANDLE heap, size_t alignment, size_t num, size_t size); + static extern LPVOID rpmalloc_heap_aligned_calloc(LPHEAPHANDLE heap, nuint alignment, nuint num, nuint size); [DllImport(DLL_NAME, ExactSpelling = true)] [DefaultDllImportSearchPaths(DllImportSearchPath.SafeDirectories)] - static extern LPVOID rpmalloc_heap_realloc(LPHEAPHANDLE heap, LPVOID ptr, size_t size, nuint flags); + static extern LPVOID rpmalloc_heap_realloc(LPHEAPHANDLE heap, LPVOID ptr, nuint size, nuint flags); [DllImport(DLL_NAME, ExactSpelling = true)] [DefaultDllImportSearchPaths(DllImportSearchPath.SafeDirectories)] - static extern LPVOID rpmalloc_heap_aligned_realloc(LPHEAPHANDLE heap, LPVOID ptr, size_t alignment, size_t size, nuint flags); + static extern LPVOID rpmalloc_heap_aligned_realloc(LPHEAPHANDLE heap, LPVOID ptr, nuint alignment, nuint size, nuint flags); [DllImport(DLL_NAME, ExactSpelling = true)] [DefaultDllImportSearchPaths(DllImportSearchPath.SafeDirectories)] static extern void rpmalloc_heap_free(LPHEAPHANDLE heap, LPVOID ptr); @@ -96,13 +95,13 @@ namespace VNLib.Utils.Memory static extern void rpmalloc_thread_finalize(int release_caches); [DllImport(DLL_NAME, ExactSpelling = true)] [DefaultDllImportSearchPaths(DllImportSearchPath.SafeDirectories)] - static extern LPVOID rpmalloc(size_t size); + static extern LPVOID rpmalloc(nuint size); [DllImport(DLL_NAME, ExactSpelling = true)] [DefaultDllImportSearchPaths(DllImportSearchPath.SafeDirectories)] - static extern LPVOID rpcalloc(size_t num, size_t size); + static extern LPVOID rpcalloc(nuint num, nuint size); [DllImport(DLL_NAME, ExactSpelling = true)] [DefaultDllImportSearchPaths(DllImportSearchPath.SafeDirectories)] - static extern LPVOID rprealloc(LPVOID ptr, size_t size); + static extern LPVOID rprealloc(LPVOID ptr, nuint size); [DllImport(DLL_NAME, ExactSpelling = true)] [DefaultDllImportSearchPaths(DllImportSearchPath.SafeDirectories)] static extern void rpfree(LPVOID ptr); @@ -111,9 +110,9 @@ namespace VNLib.Utils.Memory private sealed class RpMallocGlobalHeap : IUnmangedHeap { - IntPtr IUnmangedHeap.Alloc(ulong elements, ulong size, bool zero) + IntPtr IUnmangedHeap.Alloc(nuint elements, nuint size, bool zero) { - return RpMalloc(elements, (nuint)size, zero); + return RpMalloc(elements, size, zero); } //Global heap does not need to be disposed @@ -127,17 +126,13 @@ namespace VNLib.Utils.Memory return true; } - void IUnmangedHeap.Resize(ref IntPtr block, ulong elements, ulong size, bool zero) + void IUnmangedHeap.Resize(ref IntPtr block, nuint elements, nuint size, bool zero) { //Try to resize the block - IntPtr resize = RpRealloc(block, elements, (nuint)size); - - if (resize == IntPtr.Zero) - { - throw new NativeMemoryOutOfMemoryException("Failed to resize the block"); - } + IntPtr resize = RpRealloc(block, elements, size); + //assign ptr - block = resize; + block = resize != IntPtr.Zero ? resize : throw new NativeMemoryOutOfMemoryException("Failed to resize the block"); } } @@ -164,7 +159,7 @@ namespace VNLib.Utils.Memory /// The number of bytes per element type (aligment) /// Zero the block of memory before returning /// A pointer to the block, (zero if failed) - public static LPVOID RpMalloc(size_t elements, nuint size, bool zero) + public static LPVOID RpMalloc(nuint elements, nuint size, bool zero) { //See if the current thread has been initialized if (rpmalloc_is_thread_initialized() == 0) @@ -172,8 +167,10 @@ namespace VNLib.Utils.Memory //Initialize the current thread rpmalloc_thread_initialize(); } + //Alloc block LPVOID block; + if (zero) { block = rpcalloc(elements, size); @@ -181,7 +178,8 @@ namespace VNLib.Utils.Memory else { //Calculate the block size - ulong blockSize = checked(elements * size); + nuint blockSize = checked(elements * size); + block = rpmalloc(blockSize); } return block; @@ -209,14 +207,16 @@ namespace VNLib.Utils.Memory /// A pointer to the new block if the reallocation succeeded, null if the resize failed /// /// - public static LPVOID RpRealloc(LPVOID block, size_t elements, nuint size) + public static LPVOID RpRealloc(LPVOID block, nuint elements, nuint size) { if(block == IntPtr.Zero) { throw new ArgumentException("The supplied block is not valid", nameof(block)); } + //Calc new block size - size_t blockSize = checked(elements * size); + nuint blockSize = checked(elements * size); + return rprealloc(block, blockSize); } @@ -239,6 +239,7 @@ namespace VNLib.Utils.Memory Trace.WriteLine($"RPMalloc heap {handle:x} created"); #endif } + /// protected override bool ReleaseHandle() { @@ -252,13 +253,15 @@ namespace VNLib.Utils.Memory //Release base return base.ReleaseHandle(); } + /// [MethodImpl(MethodImplOptions.AggressiveInlining)] - protected sealed override LPVOID AllocBlock(ulong elements, ulong size, bool zero) + protected sealed override LPVOID AllocBlock(nuint elements, nuint size, bool zero) { //Alloc or calloc and initalize return zero ? rpmalloc_heap_calloc(handle, elements, size) : rpmalloc_heap_alloc(handle, checked(size * elements)); } + /// [MethodImpl(MethodImplOptions.AggressiveInlining)] protected sealed override bool FreeBlock(LPVOID block) @@ -267,13 +270,15 @@ namespace VNLib.Utils.Memory rpmalloc_heap_free(handle, block); return true; } + /// [MethodImpl(MethodImplOptions.AggressiveInlining)] - protected sealed override LPVOID ReAllocBlock(LPVOID block, ulong elements, ulong size, bool zero) + protected sealed override LPVOID ReAllocBlock(LPVOID block, nuint elements, nuint size, bool zero) { //Realloc return rpmalloc_heap_realloc(handle, block, checked(elements * size), 0); } + #endregion } } diff --git a/lib/Utils/src/Memory/SubSequence.cs b/lib/Utils/src/Memory/SubSequence.cs index 3800fb5..87e369b 100644 --- a/lib/Utils/src/Memory/SubSequence.cs +++ b/lib/Utils/src/Memory/SubSequence.cs @@ -35,10 +35,7 @@ namespace VNLib.Utils.Memory public readonly struct SubSequence : IEquatable> where T: unmanaged { private readonly MemoryHandle _handle; - /// - /// The number of elements in the current window - /// - public readonly int Size { get; } + private readonly nuint _offset; /// /// Creates a new to the handle to get a window of the block @@ -46,33 +43,23 @@ namespace VNLib.Utils.Memory /// /// /// -#if TARGET_64_BIT - public SubSequence(MemoryHandle block, ulong offset, int size) -#else - public SubSequence(MemoryHandle block, int offset, int size) -#endif + public SubSequence(MemoryHandle block, nuint offset, int size) { _offset = offset; Size = size >= 0 ? size : throw new ArgumentOutOfRangeException(nameof(size)); _handle = block ?? throw new ArgumentNullException(nameof(block)); } + /// + /// The number of elements in the current window + /// + public readonly int Size { get; } -#if TARGET_64_BIT - private readonly ulong _offset; -#else - private readonly int _offset; -#endif /// /// Gets a that is offset from the base of the handle /// /// - -#if TARGET_64_BIT - public readonly Span Span => Size > 0 ? _handle.GetOffsetSpan(_offset, Size) : Span.Empty; -#else - public readonly Span Span => Size > 0 ? _handle.Span.Slice(_offset, Size) : Span.Empty; -#endif + public readonly Span Span => Size > 0 ? _handle.GetOffsetSpan(_offset, Size) : Span.Empty; /// /// Slices the current sequence into a smaller @@ -80,7 +67,7 @@ namespace VNLib.Utils.Memory /// The relative offset from the current window offset /// The size of the block /// A of the current sequence - public readonly SubSequence Slice(uint offset, int size) => new (_handle, _offset + checked((int)offset), size); + public readonly SubSequence Slice(nuint offset, int size) => new (_handle, checked(_offset + offset), size); /// /// Returns the signed 32-bit hashcode diff --git a/lib/Utils/src/Memory/SysBufferMemoryManager.cs b/lib/Utils/src/Memory/SysBufferMemoryManager.cs index 040467f..aca2543 100644 --- a/lib/Utils/src/Memory/SysBufferMemoryManager.cs +++ b/lib/Utils/src/Memory/SysBufferMemoryManager.cs @@ -40,7 +40,7 @@ namespace VNLib.Utils.Memory private readonly bool _ownsHandle; /// - /// Consumes an exisitng to provide wrappers. + /// Consumes an exisitng to provide wrappers. /// The handle should no longer be referrenced directly /// /// The existing handle to consume @@ -52,12 +52,12 @@ namespace VNLib.Utils.Memory } /// - /// Allocates a fized size buffer from the specified unmanaged + /// Allocates a fized size buffer from the specified unmanaged /// /// The heap to perform allocations from /// The number of elements to allocate /// Zero allocations - public SysBufferMemoryManager(IUnmangedHeap heap, ulong elements, bool zero) + public SysBufferMemoryManager(IUnmangedHeap heap, nuint elements, bool zero) { BackingMemory = heap.Alloc(elements, zero); _ownsHandle = true; diff --git a/lib/Utils/src/Memory/UnmanagedHeapBase.cs b/lib/Utils/src/Memory/UnmanagedHeapBase.cs index 5c92aff..1f7dc7f 100644 --- a/lib/Utils/src/Memory/UnmanagedHeapBase.cs +++ b/lib/Utils/src/Memory/UnmanagedHeapBase.cs @@ -28,7 +28,6 @@ using System.Runtime.InteropServices; using Microsoft.Win32.SafeHandles; -using size_t = System.UInt64; using LPVOID = System.IntPtr; namespace VNLib.Utils.Memory @@ -43,6 +42,7 @@ namespace VNLib.Utils.Memory /// The heap synchronization handle /// protected readonly SemaphoreSlim HeapLock; + /// /// The global heap zero flag /// @@ -63,7 +63,7 @@ namespace VNLib.Utils.Memory ///Increments the handle count /// /// - public LPVOID Alloc(size_t elements, size_t size, bool zero) + public LPVOID Alloc(nuint elements, nuint size, bool zero) { //Force zero if global flag is set zero |= GlobalZero; @@ -93,6 +93,7 @@ namespace VNLib.Utils.Memory throw; } } + /// ///Decrements the handle count public bool Free(ref LPVOID block) @@ -116,10 +117,11 @@ namespace VNLib.Utils.Memory block = IntPtr.Zero; return result; } + /// /// /// - public void Resize(ref LPVOID block, size_t elements, size_t size, bool zero) + public void Resize(ref LPVOID block, nuint elements, nuint size, bool zero) { //wait for lock HeapLock.Wait(); @@ -155,12 +157,14 @@ namespace VNLib.Utils.Memory /// The size of the element type (in bytes) /// A flag to zero the allocated block /// A pointer to the allocated block - protected abstract LPVOID AllocBlock(size_t elements, size_t size, bool zero); + protected abstract LPVOID AllocBlock(nuint elements, nuint size, bool zero); + /// /// Frees a previously allocated block of memory /// /// The block to free protected abstract bool FreeBlock(LPVOID block); + /// /// Resizes the previously allocated block of memory on the current heap /// @@ -173,9 +177,11 @@ namespace VNLib.Utils.Memory /// Heap base relies on the block pointer to remain unchanged if the resize fails so the /// block is still valid, and the return value is used to determine if the resize was successful /// - protected abstract LPVOID ReAllocBlock(LPVOID block, size_t elements, size_t size, bool zero); + protected abstract LPVOID ReAllocBlock(LPVOID block, nuint elements, nuint size, bool zero); + /// public override int GetHashCode() => handle.GetHashCode(); + /// public override bool Equals(object? obj) { diff --git a/lib/Utils/src/Memory/UnsafeMemoryHandle.cs b/lib/Utils/src/Memory/UnsafeMemoryHandle.cs index b05ad40..72edb26 100644 --- a/lib/Utils/src/Memory/UnsafeMemoryHandle.cs +++ b/lib/Utils/src/Memory/UnsafeMemoryHandle.cs @@ -40,7 +40,7 @@ namespace VNLib.Utils.Memory [StructLayout(LayoutKind.Sequential)] public readonly struct UnsafeMemoryHandle : IMemoryHandle, IEquatable> where T : unmanaged { - private enum HandleType + private enum HandleType : byte { None, Pool, @@ -60,10 +60,12 @@ namespace VNLib.Utils.Memory [MethodImpl(MethodImplOptions.AggressiveInlining)] get => _handleType == HandleType.Pool ? _poolArr.AsSpan(0, IntLength) : new (_memoryPtr.ToPointer(), IntLength); } - /// + /// + /// Gets the integer number of elements of the block of memory pointed to by this handle + /// public readonly int IntLength => _length; /// - public readonly ulong Length => (ulong)_length; + public readonly nuint Length => (nuint)_length; /// /// Creates an empty @@ -153,11 +155,18 @@ namespace VNLib.Utils.Memory /// public readonly unsafe MemoryHandle Pin(int elementIndex) { - //Guard + //guard empty handle + if (_handleType == HandleType.None) + { + throw new InvalidOperationException("The handle is empty, and cannot be pinned"); + } + + //Guard size if (elementIndex < 0 || elementIndex >= IntLength) { throw new ArgumentOutOfRangeException(nameof(elementIndex)); } + if (_handleType == HandleType.Pool) { diff --git a/lib/Utils/src/Memory/VnString.cs b/lib/Utils/src/Memory/VnString.cs index 7fa0c5a..8bb0bb6 100644 --- a/lib/Utils/src/Memory/VnString.cs +++ b/lib/Utils/src/Memory/VnString.cs @@ -52,6 +52,7 @@ namespace VNLib.Utils.Memory /// The number of unicode characters the current instance can reference /// public int Length => _stringSequence.Size; + /// /// Gets a value indicating if the current instance is empty /// @@ -62,14 +63,7 @@ namespace VNLib.Utils.Memory _stringSequence = sequence; } - private VnString( - MemoryHandle handle, -#if TARGET_64_BIT - ulong start, -#else - int start, -#endif - int length) + private VnString(MemoryHandle handle, nuint start, int length) { Handle = handle ?? throw new ArgumentNullException(nameof(handle)); //get sequence @@ -83,6 +77,7 @@ namespace VNLib.Utils.Memory { //Default string sequence is empty and does not hold any memory } + /// /// Creates a new around a or a of data /// @@ -90,13 +85,13 @@ namespace VNLib.Utils.Memory /// public VnString(ReadOnlySpan data) { - //Create new handle with enough size (heap) - Handle = Memory.Shared.Alloc(data.Length); - //Copy - Memory.Copy(data, Handle, 0); + //Create new handle and copy incoming data to it + Handle = MemoryUtil.Shared.AllocAndCopy(data); + //Get subsequence over the whole copy of data _stringSequence = Handle.GetSubSequence(0, data.Length); } + /// /// Allocates a temporary buffer to read data from the stream until the end of the stream is reached. /// Decodes data from the user-specified encoding @@ -122,7 +117,7 @@ namespace VNLib.Utils.Memory //Get the number of characters int numChars = encoding.GetCharCount(vnms.AsSpan()); //New handle - MemoryHandle charBuffer = Memory.Shared.Alloc(numChars); + MemoryHandle charBuffer = MemoryUtil.Shared.Alloc(numChars); try { //Write characters to character buffer @@ -141,9 +136,9 @@ namespace VNLib.Utils.Memory else { //Create a new char bufer that will expand dyanmically - MemoryHandle charBuffer = Memory.Shared.Alloc(bufferSize); + MemoryHandle charBuffer = MemoryUtil.Shared.Alloc(bufferSize); //Allocate a binary buffer - MemoryHandle binBuffer = Memory.Shared.Alloc(bufferSize); + MemoryHandle binBuffer = MemoryUtil.Shared.Alloc(bufferSize); try { int length = 0; @@ -194,6 +189,7 @@ namespace VNLib.Utils.Memory } } } + /// /// Creates a new Vnstring from the buffer provided. This function "consumes" /// a handle, meaning it now takes ownsership of the the memory it points to. @@ -203,27 +199,24 @@ namespace VNLib.Utils.Memory /// The number of characters this string points to /// The new /// - public static VnString ConsumeHandle( - MemoryHandle handle, - -#if TARGET_64_BIT - ulong start, -#else - int start, -#endif - - int length) + public static VnString ConsumeHandle(MemoryHandle handle, nuint start, int length) { - if(length < 0) + if (handle is null) { - throw new ArgumentOutOfRangeException(nameof(length)); + throw new ArgumentNullException(nameof(handle)); } - if((uint)length > handle.Length) + + if (length < 0) { throw new ArgumentOutOfRangeException(nameof(length)); } + + //Check handle bounts + MemoryUtil.CheckBounds(handle, start, (nuint)length); + return new VnString(handle, start, length); } + /// /// Asynchronously reads data from the specified stream and uses the specified encoding /// to decode the binary data to a new heap character buffer. @@ -354,10 +347,11 @@ namespace VNLib.Utils.Memory throw new ArgumentOutOfRangeException(nameof(count)); } //get sub-sequence slice for the current string - SubSequence sub = _stringSequence.Slice((uint)start, count); + SubSequence sub = _stringSequence.Slice((nuint)start, count); //Create new string with offsets pointing to same internal referrence return new VnString(sub); } + /// /// Creates a that is a window within the current string, /// the referrence points to the same memory as the first instnace. @@ -403,13 +397,12 @@ namespace VNLib.Utils.Memory /// /// representation of internal data /// - public override unsafe string ToString() + public override string ToString() { - //Check - Check(); //Create a new return AsSpan().ToString(); } + /// /// Gets the value of the character at the specified index /// @@ -474,7 +467,21 @@ namespace VNLib.Utils.Memory /// a character span etc /// /// - public override int GetHashCode() => string.GetHashCode(AsSpan()); + public override int GetHashCode() => GetHashCode(StringComparison.Ordinal); + + /// + /// Gets a hashcode for the underyling string by using the .NET + /// method on the character representation of the data + /// + /// The string comperison mode + /// + /// + /// It is safe to compare hashcodes of to the class or + /// a character span etc + /// + /// + public int GetHashCode(StringComparison stringComparison) => string.GetHashCode(AsSpan(), stringComparison); + /// protected override void Free() { diff --git a/lib/Utils/src/Memory/VnTable.cs b/lib/Utils/src/Memory/VnTable.cs index 1d5c0a6..2c6ce74 100644 --- a/lib/Utils/src/Memory/VnTable.cs +++ b/lib/Utils/src/Memory/VnTable.cs @@ -35,33 +35,38 @@ namespace VNLib.Utils.Memory public sealed class VnTable : VnDisposeable, IIndexable where T : unmanaged { private readonly MemoryHandle? BufferHandle; + /// /// A value that indicates if the table does not contain any values /// public bool Empty { get; } + /// /// The number of rows in the table /// - public int Rows { get; } + public uint Rows { get; } + /// /// The nuber of columns in the table /// - public int Cols { get; } + public uint Cols { get; } + /// - /// Creates a new 2 dimensional table in unmanaged heap memory, using the heap. + /// Creates a new 2 dimensional table in unmanaged heap memory, using the heap. /// User should dispose of the table when no longer in use /// /// Number of rows in the table /// Number of columns in the table - public VnTable(int rows, int cols) : this(Memory.Shared, rows, cols) { } + public VnTable(uint rows, uint cols) : this(MemoryUtil.Shared, rows, cols) { } + /// /// Creates a new 2 dimensional table in unmanaged heap memory, using the specified heap. /// User should dispose of the table when no longer in use /// - /// to allocate table memory from + /// to allocate table memory from /// Number of rows in the table /// Number of columns in the table - public VnTable(IUnmangedHeap heap, int rows, int cols) + public VnTable(IUnmangedHeap heap, uint rows, uint cols) { if (rows < 0 || cols < 0) { @@ -71,19 +76,28 @@ namespace VNLib.Utils.Memory if (rows == 0 && cols == 0) { Empty = true; - return; } + else + { + _ = heap ?? throw new ArgumentNullException(nameof(heap)); - _ = heap ?? throw new ArgumentNullException(nameof(heap)); + this.Rows = rows; + this.Cols = cols; - this.Rows = rows; - this.Cols = cols; + ulong tableSize = checked((ulong) rows * (ulong) cols); - long tableSize = Math.BigMul(rows, cols); + if (tableSize > nuint.MaxValue) + { +#pragma warning disable CA2201 // Do not raise reserved exception types + throw new OutOfMemoryException("Table size is too large"); +#pragma warning restore CA2201 // Do not raise reserved exception types + } - //Alloc a buffer with zero memory enabled, with Rows * Cols number of elements - BufferHandle = heap.Alloc(tableSize, true); + //Alloc a buffer with zero memory enabled, with Rows * Cols number of elements + BufferHandle = heap.Alloc((nuint)tableSize, true); + } } + /// /// Gets the value of an item in the table at the given indexes /// @@ -93,7 +107,7 @@ namespace VNLib.Utils.Memory /// /// /// - public T Get(int row, int col) + public T Get(uint row, uint col) { Check(); if (this.Empty) @@ -114,15 +128,18 @@ namespace VNLib.Utils.Memory } //Calculate the address in memory for the item //Calc row offset - long address = Cols * row; + ulong address = checked(row * this.Cols); + //Calc column offset address += col; + unsafe { //Get the value item - return *(BufferHandle!.GetOffset(address)); + return *(BufferHandle!.GetOffset((nuint)address)); } } + /// /// Sets the value of an item in the table at the given address /// @@ -133,7 +150,7 @@ namespace VNLib.Utils.Memory /// /// /// - public void Set(int row, int col, T item) + public void Set(uint row, uint col, T item) { Check(); if (this.Empty) @@ -152,28 +169,34 @@ namespace VNLib.Utils.Memory { throw new ArgumentOutOfRangeException(nameof(col), "Column address out of range of current table"); } + //Calculate the address in memory for the item + //Calc row offset - long address = Cols * row; + ulong address = checked(Cols * row); + //Calc column offset address += col; + //Set the value item unsafe { - *BufferHandle!.GetOffset(address) = item; + *BufferHandle!.GetOffset((nuint)address) = item; } } + /// - /// Equivalent to and + /// Equivalent to and /// /// Row address of item /// Column address of item /// The value of the item - public T this[int row, int col] + public T this[uint row, uint col] { get => Get(row, col); set => Set(row, col, value); } + /// /// Allows for direct addressing in the table. /// @@ -200,10 +223,11 @@ namespace VNLib.Utils.Memory *(BufferHandle!.GetOffset(index)) = value; } } + /// protected override void Free() { - if (!this.Empty) + if (!Empty) { //Dispose the buffer BufferHandle!.Dispose(); diff --git a/lib/Utils/src/Memory/VnTempBuffer.cs b/lib/Utils/src/Memory/VnTempBuffer.cs index 7726fe1..1d8e42f 100644 --- a/lib/Utils/src/Memory/VnTempBuffer.cs +++ b/lib/Utils/src/Memory/VnTempBuffer.cs @@ -28,7 +28,6 @@ using System.Runtime.InteropServices; using System.Runtime.CompilerServices; using VNLib.Utils.Extensions; -using System.Security.Cryptography; namespace VNLib.Utils.Memory { @@ -52,12 +51,7 @@ namespace VNLib.Utils.Memory /// /// Actual length of internal buffer /// - public ulong Length => (ulong)Buffer.LongLength; - - /// - /// Actual length of internal buffer - /// - public int IntLength => Buffer.Length; + public nuint Length => (nuint)Buffer.LongLength; /// /// @@ -77,6 +71,7 @@ namespace VNLib.Utils.Memory /// Set the zero memory flag on close public VnTempBuffer(int minSize, bool zero = false) :this(ArrayPool.Shared, minSize, zero) {} + /// /// Allocates a new with a new buffer from specified array-pool /// @@ -89,6 +84,7 @@ namespace VNLib.Utils.Memory Buffer = pool.Rent(minSize, zero); InitSize = minSize; } + /// /// Gets an offset wrapper around the current buffer /// @@ -101,6 +97,7 @@ namespace VNLib.Utils.Memory //Let arraysegment throw exceptions for checks return new ArraySegment(Buffer, offset, count); } + /// public T this[int index] { @@ -127,6 +124,7 @@ namespace VNLib.Utils.Memory Check(); return new Memory(Buffer, 0, InitSize); } + /// /// Gets a memory structure around the internal buffer /// @@ -140,6 +138,7 @@ namespace VNLib.Utils.Memory Check(); return new Memory(Buffer, start, count); } + /// /// Gets a memory structure around the internal buffer /// @@ -181,17 +180,20 @@ namespace VNLib.Utils.Memory unsafe MemoryHandle IPinnable.Pin(int elementIndex) { //Guard - if (elementIndex < 0 || elementIndex >= IntLength) + if (elementIndex < 0 || elementIndex >= Buffer.Length) { throw new ArgumentOutOfRangeException(nameof(elementIndex)); } //Pin the array GCHandle arrHandle = GCHandle.Alloc(Buffer, GCHandleType.Pinned); + //Get array base address void* basePtr = (void*)arrHandle.AddrOfPinnedObject(); + //Get element offset void* indexOffet = Unsafe.Add(basePtr, elementIndex); + return new(indexOffet, arrHandle, this); } diff --git a/lib/Utils/src/Memory/Win32PrivateHeap.cs b/lib/Utils/src/Memory/Win32PrivateHeap.cs new file mode 100644 index 0000000..fe214f4 --- /dev/null +++ b/lib/Utils/src/Memory/Win32PrivateHeap.cs @@ -0,0 +1,191 @@ +/* +* Copyright (c) 2022 Vaughn Nugent +* +* Library: VNLib +* Package: VNLib.Utils +* File: PrivateHeap.cs +* +* PrivateHeap.cs is part of VNLib.Utils which is part of the larger +* VNLib collection of libraries and utilities. +* +* VNLib.Utils is free software: you can redistribute it and/or modify +* it under the terms of the GNU General Public License as published +* by the Free Software Foundation, either version 2 of the License, +* or (at your option) any later version. +* +* VNLib.Utils is distributed in the hope that it will be useful, +* but WITHOUT ANY WARRANTY; without even the implied warranty of +* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU +* General Public License for more details. +* +* You should have received a copy of the GNU General Public License +* along with VNLib.Utils. If not, see http://www.gnu.org/licenses/. +*/ + +using System; +using System.Diagnostics; +using System.Runtime.Versioning; +using System.Runtime.InteropServices; + +using DWORD = System.Int64; +using LPVOID = System.IntPtr; + +namespace VNLib.Utils.Memory +{ + /// + /// + /// Provides a win32 private heap managed wrapper class + /// + /// + /// + /// implements and tracks allocated blocks by its + /// referrence counter. Allocations increment the count, and free's decrement the count, so the heap may + /// be disposed safely + /// + [ComVisible(false)] + [SupportedOSPlatform("Windows")] + public sealed class Win32PrivateHeap : UnmanagedHeapBase + { + private const string KERNEL_DLL = "Kernel32"; + + #region Extern + //Heap flags + public const DWORD HEAP_NO_FLAGS = 0x00; + public const DWORD HEAP_GENERATE_EXCEPTIONS = 0x04; + public const DWORD HEAP_NO_SERIALIZE = 0x01; + public const DWORD HEAP_REALLOC_IN_PLACE_ONLY = 0x10; + public const DWORD HEAP_ZERO_MEMORY = 0x08; + + [DllImport(KERNEL_DLL, SetLastError = true, ExactSpelling = true)] + [DefaultDllImportSearchPaths(DllImportSearchPath.System32)] + private static extern LPVOID HeapAlloc(IntPtr hHeap, DWORD flags, nuint dwBytes); + + [DllImport(KERNEL_DLL, SetLastError = true, ExactSpelling = true)] + [DefaultDllImportSearchPaths(DllImportSearchPath.System32)] + private static extern LPVOID HeapReAlloc(IntPtr hHeap, DWORD dwFlags, LPVOID lpMem, nuint dwBytes); + + [DllImport(KERNEL_DLL, SetLastError = true, ExactSpelling = true)] + [DefaultDllImportSearchPaths(DllImportSearchPath.System32)] + [return: MarshalAs(UnmanagedType.Bool)] + private static extern bool HeapFree(IntPtr hHeap, DWORD dwFlags, LPVOID lpMem); + + [DllImport(KERNEL_DLL, SetLastError = true, ExactSpelling = true)] + [DefaultDllImportSearchPaths(DllImportSearchPath.System32)] + private static extern LPVOID HeapCreate(DWORD flOptions, nuint dwInitialSize, nuint dwMaximumSize); + + [DllImport(KERNEL_DLL, SetLastError = true, ExactSpelling = true)] + [DefaultDllImportSearchPaths(DllImportSearchPath.System32)] + [return: MarshalAs(UnmanagedType.Bool)] + private static extern bool HeapDestroy(IntPtr hHeap); + + [DllImport(KERNEL_DLL, SetLastError = true, ExactSpelling = true)] + [DefaultDllImportSearchPaths(DllImportSearchPath.System32)] + [return: MarshalAs(UnmanagedType.Bool)] + private static extern bool HeapValidate(IntPtr hHeap, DWORD dwFlags, LPVOID lpMem); + + [DllImport(KERNEL_DLL, SetLastError = true, ExactSpelling = true)] + [DefaultDllImportSearchPaths(DllImportSearchPath.System32)] + private static extern nuint HeapSize(IntPtr hHeap, DWORD flags, LPVOID lpMem); + + #endregion + + /// + /// Create a new with the specified sizes and flags + /// + /// Intial size of the heap + /// Maximum size allowed for the heap (disabled = 0, default) + /// Defalt heap flags to set globally for all blocks allocated by the heap (default = 0) + public static Win32PrivateHeap Create(nuint initialSize, nuint maxHeapSize = 0, DWORD flags = HEAP_NO_FLAGS) + { + //Call create, throw exception if the heap falled to allocate + IntPtr heapHandle = HeapCreate(flags, initialSize, maxHeapSize); + + if (heapHandle == IntPtr.Zero) + { + throw new NativeMemoryException("Heap could not be created"); + } +#if TRACE + Trace.WriteLine($"Win32 private heap {heapHandle:x} created"); +#endif + //Heap has been created so we can wrap it + return new(heapHandle); + } + /// + /// LIFETIME WARNING. Consumes a valid win32 handle and will manage it's lifetime once constructed. + /// Locking and memory blocks will attempt to be allocated from this heap handle. + /// + /// An open and valid handle to a win32 private heap + /// A wrapper around the specified heap + public static Win32PrivateHeap ConsumeExisting(IntPtr win32HeapHandle) => new (win32HeapHandle); + + private Win32PrivateHeap(IntPtr heapPtr) : base(false, true) => handle = heapPtr; + + /// + /// Retrieves the size of a memory block allocated from the current heap. + /// + /// The pointer to a block of memory to get the size of + /// The size of the block of memory, (SIZE_T)-1 if the operation fails + public nuint HeapSize(ref LPVOID block) => HeapSize(handle, HEAP_NO_FLAGS, block); + + /// + /// Validates the specified block of memory within the current heap instance. This function will block hte + /// + /// Pointer to the block of memory to validate + /// True if the block is valid, false otherwise + public bool Validate(ref LPVOID block) + { + bool result; + //Lock the heap before validating + HeapLock.Wait(); + //validate the block on the current heap + result = HeapValidate(handle, HEAP_NO_FLAGS, block); + //Unlock the heap + HeapLock.Release(); + return result; + + } + /// + /// Validates the current heap instance. The function scans all the memory blocks in the heap and verifies that the heap control structures maintained by + /// the heap manager are in a consistent state. + /// + /// If the specified heap or memory block is valid, the return value is nonzero. + /// This can be a consuming operation which will block all allocations + public bool Validate() + { + bool result; + //Lock the heap before validating + HeapLock.Wait(); + //validate the entire heap + result = HeapValidate(handle, HEAP_NO_FLAGS, IntPtr.Zero); + //Unlock the heap + HeapLock.Release(); + return result; + } + + /// + protected override bool ReleaseHandle() + { +#if TRACE + Trace.WriteLine($"Win32 private heap {handle:x} destroyed"); +#endif + return HeapDestroy(handle) && base.ReleaseHandle(); + } + /// + protected override sealed LPVOID AllocBlock(nuint elements, nuint size, bool zero) + { + nuint bytes = checked(elements * size); + + return HeapAlloc(handle, zero ? HEAP_ZERO_MEMORY : HEAP_NO_FLAGS, bytes); + } + /// + protected override sealed bool FreeBlock(LPVOID block) => HeapFree(handle, HEAP_NO_FLAGS, block); + + /// + protected override sealed LPVOID ReAllocBlock(LPVOID block, nuint elements, nuint size, bool zero) + { + nuint bytes = checked(elements * size); + + return HeapReAlloc(handle, zero ? HEAP_ZERO_MEMORY : HEAP_NO_FLAGS, block, bytes); + } + } +} \ No newline at end of file diff --git a/lib/Utils/src/VnEncoding.cs b/lib/Utils/src/VnEncoding.cs index 94d8a1a..8359f8f 100644 --- a/lib/Utils/src/VnEncoding.cs +++ b/lib/Utils/src/VnEncoding.cs @@ -61,7 +61,7 @@ namespace VNLib.Utils //get number of bytes int byteCount = encoding.GetByteCount(data); //resize the handle to fit the data - handle = Memory.Memory.Shared.Alloc(byteCount); + handle = Memory.MemoryUtil.Shared.Alloc(byteCount); //encode int size = encoding.GetBytes(data, handle); //Consume the handle into a new vnmemstream and return it @@ -479,7 +479,7 @@ namespace VNLib.Utils //Calculate the base32 entropy to alloc an appropriate buffer (minium buffer of 2 chars) int entropy = Base32CalcMaxBufferSize(binBuffer.Length); //Alloc buffer for enough size (2*long bytes) is not an issue - using (UnsafeMemoryHandle charBuffer = Memory.Memory.UnsafeAlloc(entropy)) + using (UnsafeMemoryHandle charBuffer = Memory.MemoryUtil.UnsafeAlloc(entropy)) { //Encode ERRNO encoded = TryToBase32Chars(binBuffer, charBuffer.Span); @@ -512,7 +512,7 @@ namespace VNLib.Utils //calc size of bin buffer int size = base32.Length; //Rent a bin buffer - using UnsafeMemoryHandle binBuffer = Memory.Memory.UnsafeAlloc(size); + using UnsafeMemoryHandle binBuffer = Memory.MemoryUtil.UnsafeAlloc(size); //Try to decode the data ERRNO decoded = TryFromBase32Chars(base32, binBuffer.Span); //Marshal back to a struct @@ -532,7 +532,7 @@ namespace VNLib.Utils return null; } //Buffer size of the base32 string will always be enough buffer space - using UnsafeMemoryHandle tempBuffer = Memory.Memory.UnsafeAlloc(base32.Length); + using UnsafeMemoryHandle tempBuffer = Memory.MemoryUtil.UnsafeAlloc(base32.Length); //Try to decode the data ERRNO decoded = TryFromBase32Chars(base32, tempBuffer.Span); @@ -903,7 +903,7 @@ namespace VNLib.Utils int decodedSize = encoding.GetByteCount(chars); //alloc buffer - using UnsafeMemoryHandle decodeHandle = Memory.Memory.UnsafeAlloc(decodedSize); + using UnsafeMemoryHandle decodeHandle = Memory.MemoryUtil.UnsafeAlloc(decodedSize); //Get the utf8 binary data int count = encoding.GetBytes(chars, decodeHandle); return Base64UrlDecode(decodeHandle.Span[..count], output); diff --git a/lib/Utils/tests/Memory/MemoryHandleTest.cs b/lib/Utils/tests/Memory/MemoryHandleTest.cs index 02ef1f1..34dbb60 100644 --- a/lib/Utils/tests/Memory/MemoryHandleTest.cs +++ b/lib/Utils/tests/Memory/MemoryHandleTest.cs @@ -25,10 +25,9 @@ using Microsoft.VisualStudio.TestTools.UnitTesting; -using VNLib.Utils; using VNLib.Utils.Extensions; -using static VNLib.Utils.Memory.Memory; +using static VNLib.Utils.Memory.MemoryUtil; namespace VNLib.Utils.Memory.Tests { @@ -43,7 +42,7 @@ namespace VNLib.Utils.Memory.Tests Assert.ThrowsException(() => Shared.Alloc(-1)); //Make sure over-alloc throws - Assert.ThrowsException(() => Shared.Alloc(ulong.MaxValue, false)); + Assert.ThrowsException(() => Shared.Alloc(nuint.MaxValue, false)); } #if TARGET_64_BIT [TestMethod] @@ -54,9 +53,9 @@ namespace VNLib.Utils.Memory.Tests using MemoryHandle handle = Shared.Alloc(bigHandleSize); //verify size - Assert.AreEqual(handle.ByteLength, (ulong)bigHandleSize); + Assert.IsTrue(handle.ByteLength, (ulong)bigHandleSize); //Since handle is byte, should also match - Assert.AreEqual(handle.Length, (ulong)bigHandleSize); + Assert.IsTrue(handle.Length, (ulong)bigHandleSize); //Should throw overflow Assert.ThrowsException(() => _ = handle.Span); @@ -68,8 +67,6 @@ namespace VNLib.Utils.Memory.Tests Assert.ThrowsException(() => _ = handle.GetOffsetSpan((long)int.MaxValue + 1, 1024)); } -#else - #endif [TestMethod] @@ -77,15 +74,15 @@ namespace VNLib.Utils.Memory.Tests { using MemoryHandle handle = Shared.Alloc(128, true); - Assert.AreEqual(handle.IntLength, 128); + Assert.IsTrue(handle.Length == 128); - Assert.AreEqual(handle.Length, (ulong)128); + Assert.IsTrue(handle.Length == 128); //Check span against base pointer deref handle.Span[120] = 10; - Assert.AreEqual(*handle.GetOffset(120), 10); + Assert.IsTrue(*handle.GetOffset(120) == 10); } @@ -153,14 +150,14 @@ namespace VNLib.Utils.Memory.Tests { using MemoryHandle handle = Shared.Alloc(1024); - Assert.AreEqual(handle.IntLength, 1024); + Assert.IsTrue(handle.Length == 1024); Assert.ThrowsException(() => handle.Resize(-1)); //Resize the handle handle.Resize(2048); - Assert.AreEqual(handle.IntLength, 2048); + Assert.IsTrue(handle.Length == 2048); Assert.IsTrue(handle.AsSpan(2048).IsEmpty); @@ -173,11 +170,11 @@ namespace VNLib.Utils.Memory.Tests //test resize handle.ResizeIfSmaller(100); //Handle should be unmodified - Assert.AreEqual(handle.IntLength, 2048); + Assert.IsTrue(handle.Length == 2048); //test working handle.ResizeIfSmaller(4096); - Assert.AreEqual(handle.IntLength, 4096); + Assert.IsTrue(handle.Length == 4096); } } } diff --git a/lib/Utils/tests/Memory/MemoryTests.cs b/lib/Utils/tests/Memory/MemoryTests.cs deleted file mode 100644 index 5b68cf5..0000000 --- a/lib/Utils/tests/Memory/MemoryTests.cs +++ /dev/null @@ -1,244 +0,0 @@ -/* -* Copyright (c) 2022 Vaughn Nugent -* -* Library: VNLib -* Package: VNLib.UtilsTests -* File: MemoryTests.cs -* -* MemoryTests.cs is part of VNLib.UtilsTests which is part of the larger -* VNLib collection of libraries and utilities. -* -* VNLib.UtilsTests is free software: you can redistribute it and/or modify -* it under the terms of the GNU General Public License as published -* by the Free Software Foundation, either version 2 of the License, -* or (at your option) any later version. -* -* VNLib.UtilsTests is distributed in the hope that it will be useful, -* but WITHOUT ANY WARRANTY; without even the implied warranty of -* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU -* General Public License for more details. -* -* You should have received a copy of the GNU General Public License -* along with VNLib.UtilsTests. If not, see http://www.gnu.org/licenses/. -*/ - -using Microsoft.VisualStudio.TestTools.UnitTesting; -using System.Runtime.InteropServices; - -using VNLib.Utils.Extensions; - -namespace VNLib.Utils.Memory.Tests -{ - [TestClass()] - public class MemoryTests - { - [TestMethod] - public void MemorySharedHeapLoadedTest() - { - Assert.IsNotNull(Memory.Shared); - } - - [TestMethod()] - public void UnsafeAllocTest() - { - //test against negative number - Assert.ThrowsException(() => Memory.UnsafeAlloc(-1)); - - //Alloc large block test (100mb) - const int largTestSize = 100000 * 1024; - //Alloc super small block - const int smallTestSize = 5; - - using (UnsafeMemoryHandle buffer = Memory.UnsafeAlloc(largTestSize, false)) - { - Assert.AreEqual(largTestSize, buffer.IntLength); - Assert.AreEqual(largTestSize, buffer.Span.Length); - - buffer.Span[0] = 254; - Assert.AreEqual(buffer.Span[0], 254); - } - - using (UnsafeMemoryHandle buffer = Memory.UnsafeAlloc(smallTestSize, false)) - { - Assert.AreEqual(smallTestSize, buffer.IntLength); - Assert.AreEqual(smallTestSize, buffer.Span.Length); - - buffer.Span[0] = 254; - Assert.AreEqual(buffer.Span[0], 254); - } - - //Different data type - - using(UnsafeMemoryHandle buffer = Memory.UnsafeAlloc(largTestSize, false)) - { - Assert.AreEqual(largTestSize, buffer.IntLength); - Assert.AreEqual(largTestSize, buffer.Span.Length); - - buffer.Span[0] = long.MaxValue; - Assert.AreEqual(buffer.Span[0], long.MaxValue); - } - - using (UnsafeMemoryHandle buffer = Memory.UnsafeAlloc(smallTestSize, false)) - { - Assert.AreEqual(smallTestSize, buffer.IntLength); - Assert.AreEqual(smallTestSize, buffer.Span.Length); - - buffer.Span[0] = long.MaxValue; - Assert.AreEqual(buffer.Span[0], long.MaxValue); - } - } - - [TestMethod()] - public void UnsafeZeroMemoryAsSpanTest() - { - //Alloc test buffer - Span test = new byte[1024]; - test.Fill(0); - //test other empty span - Span verify = new byte[1024]; - verify.Fill(0); - - //Fill test buffer with random values - Random.Shared.NextBytes(test); - - //make sure buffers are not equal - Assert.IsFalse(test.SequenceEqual(verify)); - - //Zero buffer - Memory.UnsafeZeroMemory(test); - - //Make sure buffers are equal - Assert.IsTrue(test.SequenceEqual(verify)); - } - - [TestMethod()] - public void UnsafeZeroMemoryAsMemoryTest() - { - //Alloc test buffer - Memory test = new byte[1024]; - test.Span.Fill(0); - //test other empty span - Memory verify = new byte[1024]; - verify.Span.Fill(0); - - //Fill test buffer with random values - Random.Shared.NextBytes(test.Span); - - //make sure buffers are not equal - Assert.IsFalse(test.Span.SequenceEqual(verify.Span)); - - //Zero buffer - Memory.UnsafeZeroMemory(test); - - //Make sure buffers are equal - Assert.IsTrue(test.Span.SequenceEqual(verify.Span)); - } - - [TestMethod()] - public void InitializeBlockAsSpanTest() - { - //Alloc test buffer - Span test = new byte[1024]; - test.Fill(0); - //test other empty span - Span verify = new byte[1024]; - verify.Fill(0); - - //Fill test buffer with random values - Random.Shared.NextBytes(test); - - //make sure buffers are not equal - Assert.IsFalse(test.SequenceEqual(verify)); - - //Zero buffer - Memory.InitializeBlock(test); - - //Make sure buffers are equal - Assert.IsTrue(test.SequenceEqual(verify)); - } - - [TestMethod()] - public void InitializeBlockMemoryTest() - { - //Alloc test buffer - Memory test = new byte[1024]; - test.Span.Fill(0); - //test other empty span - Memory verify = new byte[1024]; - verify.Span.Fill(0); - - //Fill test buffer with random values - Random.Shared.NextBytes(test.Span); - - //make sure buffers are not equal - Assert.IsFalse(test.Span.SequenceEqual(verify.Span)); - - //Zero buffer - Memory.InitializeBlock(test); - - //Make sure buffers are equal - Assert.IsTrue(test.Span.SequenceEqual(verify.Span)); - } - - #region structmemory tests - - [StructLayout(LayoutKind.Sequential)] - struct TestStruct - { - public int X; - public int Y; - } - - [TestMethod()] - public unsafe void ZeroStructAsPointerTest() - { - TestStruct* s = Memory.Shared.StructAlloc(); - s->X = 10; - s->Y = 20; - Assert.AreEqual(10, s->X); - Assert.AreEqual(20, s->Y); - //zero struct - Memory.ZeroStruct(s); - //Verify data was zeroed - Assert.AreEqual(0, s->X); - Assert.AreEqual(0, s->Y); - //Free struct - Memory.Shared.StructFree(s); - } - - [TestMethod()] - public unsafe void ZeroStructAsVoidPointerTest() - { - TestStruct* s = Memory.Shared.StructAlloc(); - s->X = 10; - s->Y = 20; - Assert.AreEqual(10, s->X); - Assert.AreEqual(20, s->Y); - //zero struct - Memory.ZeroStruct((void*)s); - //Verify data was zeroed - Assert.AreEqual(0, s->X); - Assert.AreEqual(0, s->Y); - //Free struct - Memory.Shared.StructFree(s); - } - - [TestMethod()] - public unsafe void ZeroStructAsIntPtrTest() - { - TestStruct* s = Memory.Shared.StructAlloc(); - s->X = 10; - s->Y = 20; - Assert.AreEqual(10, s->X); - Assert.AreEqual(20, s->Y); - //zero struct - Memory.ZeroStruct((IntPtr)s); - //Verify data was zeroed - Assert.AreEqual(0, s->X); - Assert.AreEqual(0, s->Y); - //Free struct - Memory.Shared.StructFree(s); - } - #endregion - } -} \ No newline at end of file diff --git a/lib/Utils/tests/Memory/MemoryUtilTests.cs b/lib/Utils/tests/Memory/MemoryUtilTests.cs new file mode 100644 index 0000000..fb3700e --- /dev/null +++ b/lib/Utils/tests/Memory/MemoryUtilTests.cs @@ -0,0 +1,333 @@ +using System.Buffers; +using System.Runtime.InteropServices; +using System.Security.Cryptography; + +using Microsoft.VisualStudio.TestTools.UnitTesting; + +using VNLib.Utils.Extensions; + +namespace VNLib.Utils.Memory.Tests +{ + [TestClass()] + public class MemoryUtilTests + { + const int ZERO_TEST_LOOP_ITERATIONS = 1000000; + const int ZERO_TEST_MAX_BUFFER_SIZE = 10 * 1024; + + [TestMethod()] + public void InitializeNewHeapForProcessTest() + { + //Check if rpmalloc is loaded + if (MemoryUtil.IsRpMallocLoaded) + { + //Initialize the heap + using IUnmangedHeap heap = MemoryUtil.InitializeNewHeapForProcess(); + + //Confirm that the heap is actually a rpmalloc heap + Assert.IsInstanceOfType(heap, typeof(RpMallocPrivateHeap)); + } + else + { + //Confirm that Rpmalloc will throw DLLNotFound if the lib is not loaded + Assert.ThrowsException(() => _ = RpMallocPrivateHeap.GlobalHeap.Alloc(1, 1, false)); + } + } + + [TestMethod()] + public void UnsafeZeroMemoryTest() + { + //Get random data buffer as a readonly span + ReadOnlyMemory buffer = RandomNumberGenerator.GetBytes(1024); + + //confirm buffer is not all zero + Assert.IsFalse(AllZero(buffer.Span)); + + //Zero readonly memory + MemoryUtil.UnsafeZeroMemory(buffer); + + //Confirm all zero + Assert.IsTrue(AllZero(buffer.Span)); + } + + private static bool AllZero(ReadOnlySpan span) + { + for (int i = 0; i < span.Length; i++) + { + if (span[i] != 0) + { + return false; + } + } + return true; + } + + [TestMethod()] + public void UnsafeZeroMemoryTest1() + { + //Get random data buffer as a readonly span + ReadOnlySpan buffer = RandomNumberGenerator.GetBytes(1024); + + //confirm buffer is not all zero + Assert.IsFalse(AllZero(buffer)); + + //Zero readonly span + MemoryUtil.UnsafeZeroMemory(buffer); + + //Confirm all zero + Assert.IsTrue(AllZero(buffer)); + } + + + [TestMethod()] + public void InitializeBlockAsSpanTest() + { + //Get random data buffer as a readonly span + Span buffer = RandomNumberGenerator.GetBytes(1024); + + //confirm buffer is not all zero + Assert.IsFalse(AllZero(buffer)); + + //Zero readonly span + MemoryUtil.InitializeBlock(buffer); + + //Confirm all zero + Assert.IsTrue(AllZero(buffer)); + } + + [TestMethod()] + public void InitializeBlockMemoryTest() + { + //Get random data buffer as a readonly span + Memory buffer = RandomNumberGenerator.GetBytes(1024); + + //confirm buffer is not all zero + Assert.IsFalse(AllZero(buffer.Span)); + + //Zero readonly span + MemoryUtil.InitializeBlock(buffer); + + //Confirm all zero + Assert.IsTrue(AllZero(buffer.Span)); + } + + + [TestMethod()] + public unsafe void UnsafeAllocTest() + { + //No fail + using (UnsafeMemoryHandle handle = MemoryUtil.UnsafeAlloc(1024)) + { + _ = handle.Span; + _ = handle.Length; + _ = handle.IntLength; + + //Test span pointer against pinned handle + using (MemoryHandle pinned = handle.Pin(0)) + { + fixed (void* ptr = &MemoryMarshal.GetReference(handle.Span)) + { + Assert.IsTrue(ptr == pinned.Pointer); + } + } + + //Test negative pin + Assert.ThrowsException(() => _ = handle.Pin(-1)); + + //Test pinned outsie handle size + Assert.ThrowsException(() => _ = handle.Pin(1024)); + } + + //test against negative number + Assert.ThrowsException(() => MemoryUtil.UnsafeAlloc(-1)); + + //Alloc large block test (100mb) + const int largTestSize = 100000 * 1024; + //Alloc super small block + const int smallTestSize = 5; + + using (UnsafeMemoryHandle buffer = MemoryUtil.UnsafeAlloc(largTestSize, false)) + { + Assert.IsTrue(largTestSize == buffer.IntLength); + Assert.IsTrue(largTestSize == buffer.Span.Length); + + buffer.Span[0] = 254; + Assert.IsTrue(buffer.Span[0] == 254); + } + + using (UnsafeMemoryHandle buffer = MemoryUtil.UnsafeAlloc(smallTestSize, false)) + { + Assert.IsTrue(smallTestSize == buffer.IntLength); + Assert.IsTrue(smallTestSize == buffer.Span.Length); + + buffer.Span[0] = 254; + Assert.IsTrue(buffer.Span[0] == 254); + } + + //Different data type + using (UnsafeMemoryHandle buffer = MemoryUtil.UnsafeAlloc(largTestSize, false)) + { + Assert.IsTrue(largTestSize == buffer.IntLength); + Assert.IsTrue(largTestSize == buffer.Span.Length); + + buffer.Span[0] = long.MaxValue; + Assert.IsTrue(buffer.Span[0] == long.MaxValue); + } + + using (UnsafeMemoryHandle buffer = MemoryUtil.UnsafeAlloc(smallTestSize, false)) + { + Assert.IsTrue(smallTestSize == buffer.IntLength); + Assert.IsTrue(smallTestSize == buffer.Span.Length); + + buffer.Span[0] = long.MaxValue; + Assert.IsTrue(buffer.Span[0] == long.MaxValue); + } + + //Test empty handle + using (UnsafeMemoryHandle empty = new()) + { + Assert.IsTrue(0 == empty.Length); + Assert.IsTrue(0 == empty.IntLength); + + //Test pinning while empty + Assert.ThrowsException(() => _ = empty.Pin(0)); + } + + //Negative value + Assert.ThrowsException(() => _ = MemoryUtil.UnsafeAlloc(-1)); + + + /* + * Alloc random sized blocks in a loop, confirm they are empty + * then fill the block with random data before freeing it back to + * the pool. This confirms that if blocks are allocated from a shared + * pool are properly zeroed when requestd + */ + + for (int i = 0; i < ZERO_TEST_LOOP_ITERATIONS; i++) + { + int randBufferSize = Random.Shared.Next(1024, ZERO_TEST_MAX_BUFFER_SIZE); + + //Alloc block, check if all zero, then free + using UnsafeMemoryHandle handle = MemoryUtil.UnsafeAlloc(randBufferSize, true); + + //Confirm all zero + Assert.IsTrue(AllZero(handle.Span)); + + //Fill with random data + Random.Shared.NextBytes(handle.Span); + } + } + + [TestMethod()] + public unsafe void SafeAllocTest() + { + //No fail + using (IMemoryHandle handle = MemoryUtil.SafeAlloc(1024)) + { + _ = handle.Span; + _ = handle.Length; + _ = handle.GetIntLength(); + + //Test span pointer against pinned handle + using (MemoryHandle pinned = handle.Pin(0)) + { + fixed (void* ptr = &MemoryMarshal.GetReference(handle.Span)) + { + Assert.IsTrue(ptr == pinned.Pointer); + } + } + + //Test negative pin + Assert.ThrowsException(() => _ = handle.Pin(-1)); + + //Test pinned outsie handle size + Assert.ThrowsException(() => _ = handle.Pin(1024)); + } + + + //Negative value + Assert.ThrowsException(() => _ = MemoryUtil.SafeAlloc(-1)); + + + /* + * Alloc random sized blocks in a loop, confirm they are empty + * then fill the block with random data before freeing it back to + * the pool. This confirms that if blocks are allocated from a shared + * pool are properly zeroed when requestd + */ + + for (int i = 0; i < ZERO_TEST_LOOP_ITERATIONS; i++) + { + int randBufferSize = Random.Shared.Next(1024, ZERO_TEST_MAX_BUFFER_SIZE); + + //Alloc block, check if all zero, then free + using IMemoryHandle handle = MemoryUtil.SafeAlloc(randBufferSize, true); + + //Confirm all zero + Assert.IsTrue(AllZero(handle.Span)); + + //Fill with random data + Random.Shared.NextBytes(handle.Span); + } + } + + + [StructLayout(LayoutKind.Sequential)] + struct TestStruct + { + public int X; + public int Y; + } + + [TestMethod()] + public unsafe void ZeroStructAsPointerTest() + { + TestStruct* s = MemoryUtil.Shared.StructAlloc(); + s->X = 10; + s->Y = 20; + Assert.IsTrue(10 == s->X); + Assert.IsTrue(20 == s->Y); + //zero struct + MemoryUtil.ZeroStruct(s); + //Verify data was zeroed + Assert.IsTrue(0 == s->X); + Assert.IsTrue(0 == s->Y); + //Free struct + MemoryUtil.Shared.StructFree(s); + } + + [TestMethod()] + public unsafe void ZeroStructAsVoidPointerTest() + { + TestStruct* s = MemoryUtil.Shared.StructAlloc(); + s->X = 10; + s->Y = 20; + Assert.IsTrue(10 == s->X); + Assert.IsTrue(20 == s->Y); + //zero struct + MemoryUtil.ZeroStruct((void*)s); + //Verify data was zeroed + Assert.IsTrue(0 == s->X); + Assert.IsTrue(0 == s->Y); + //Free struct + MemoryUtil.Shared.StructFree(s); + } + + [TestMethod()] + public unsafe void ZeroStructAsIntPtrTest() + { + TestStruct* s = MemoryUtil.Shared.StructAlloc(); + s->X = 10; + s->Y = 20; + Assert.IsTrue(10 == s->X); + Assert.IsTrue(20 == s->Y); + //zero struct + MemoryUtil.ZeroStruct((IntPtr)s); + //Verify data was zeroed + Assert.IsTrue(0 == s->X); + Assert.IsTrue(0 == s->Y); + //Free struct + MemoryUtil.Shared.StructFree(s); + } + } +} \ No newline at end of file diff --git a/lib/Utils/tests/Memory/VnTableTests.cs b/lib/Utils/tests/Memory/VnTableTests.cs index 11350d4..c9f99ea 100644 --- a/lib/Utils/tests/Memory/VnTableTests.cs +++ b/lib/Utils/tests/Memory/VnTableTests.cs @@ -33,26 +33,13 @@ namespace VNLib.Utils.Memory.Tests [TestMethod()] public void VnTableTest() { - Assert.ThrowsException(() => - { - using VnTable table = new(-1, 0); - }); - Assert.ThrowsException(() => - { - using VnTable table = new(0, -1); - }); - Assert.ThrowsException(() => - { - using VnTable table = new(-1, -1); - }); - //Empty table using (VnTable empty = new(0, 0)) { Assert.IsTrue(empty.Empty); //Test 0 rows/cols - Assert.AreEqual(0, empty.Rows); - Assert.AreEqual(0, empty.Cols); + Assert.IsTrue(0 == empty.Rows); + Assert.IsTrue(0 == empty.Cols); } using (VnTable table = new(40000, 10000)) @@ -60,8 +47,8 @@ namespace VNLib.Utils.Memory.Tests Assert.IsFalse(table.Empty); //Test table size - Assert.AreEqual(40000, table.Rows); - Assert.AreEqual(10000, table.Cols); + Assert.IsTrue(40000 == table.Rows); + Assert.IsTrue(10000 == table.Cols); } @@ -89,41 +76,41 @@ namespace VNLib.Utils.Memory.Tests [TestMethod()] public void GetSetTest() { - static void TestIndexAt(VnTable table, int row, int col, int value) + static void TestIndexAt(VnTable table, uint row, uint col, int value) { table[row, col] = value; - Assert.AreEqual(value, table[row, col]); - Assert.AreEqual(value, table.Get(row, col)); + Assert.IsTrue(value == table[row, col]); + Assert.IsTrue(value == table.Get(row, col)); } - static void TestSetAt(VnTable table, int row, int col, int value) + static void TestSetAt(VnTable table, uint row, uint col, int value) { table.Set(row, col, value); - Assert.AreEqual(value, table[row, col]); - Assert.AreEqual(value, table.Get(row, col)); + Assert.IsTrue(value == table[row, col]); + Assert.IsTrue(value == table.Get(row, col)); } - static void TestSetDirectAccess(VnTable table, int row, int col, int value) + static void TestSetDirectAccess(VnTable table, uint row, uint col, int value) { - int address = row * table.Cols + col; - table[(uint)address] = value; + uint address = row * table.Cols + col; + table[address] = value; //Get value using indexer - Assert.AreEqual(value, table[row, col]); + Assert.IsTrue(value == table[row, col]); } - static void TestGetDirectAccess(VnTable table, int row, int col, int value) + static void TestGetDirectAccess(VnTable table, uint row, uint col, int value) { table[row, col] = value; - int address = row * table.Cols + col; + uint address = row * table.Cols + col; //Test direct access - Assert.AreEqual(value, table[(uint)address]); + Assert.IsTrue(value == table[address]); //Get value using indexer - Assert.AreEqual(value, table[row, col]); - Assert.AreEqual(value, table.Get(row, col)); + Assert.IsTrue(value == table[row, col]); + Assert.IsTrue(value == table.Get(row, col)); } diff --git a/lib/Utils/tests/VnEncodingTests.cs b/lib/Utils/tests/VnEncodingTests.cs index a4e52f0..373b834 100644 --- a/lib/Utils/tests/VnEncodingTests.cs +++ b/lib/Utils/tests/VnEncodingTests.cs @@ -23,16 +23,12 @@ */ using System; +using System.Text; using System.Buffers; using System.Buffers.Text; -using System.Collections.Generic; -using System.Linq; using System.Security.Cryptography; -using System.Text; -using System.Threading.Tasks; using Microsoft.VisualStudio.TestTools.UnitTesting; -using VNLib.Utils; namespace VNLib.Utils.Tests { -- cgit