using System;
using System.Net;
using System.Net.Http;
using System.Text;
using System.Security;
using System.Text.Json;
using System.Security.Cryptography;
using System.Runtime.CompilerServices;

using RestSharp;

using VNLib.Utils.Memory;
using VNLib.Utils.Logging;
using VNLib.Utils.Extensions;
using VNLib.Hashing;
using VNLib.Hashing.IdentityUtility;
using VNLib.Net.Http;
using VNLib.Net.Rest.Client;
using VNLib.Net.Messaging.FBM.Client;
using VNLib.Net.Messaging.FBM;

namespace VNLib.Data.Caching.Extensions
{
    /// <summary>
    /// Provides extension methods for FBM data caching using
    /// cache servers and brokers
    /// </summary>
    public static class FBMDataCacheExtensions
    {
        /// <summary>
        /// The websocket sub-protocol to use when connecting to cache servers
        /// </summary>
        public const string CACHE_WS_SUB_PROCOL = "object-cache";

        /// <summary>
        /// The default cache message header size
        /// </summary>
        public const int MAX_FBM_MESSAGE_HEADER_SIZE = 1024;

        //JWT header shared by every signed message sent to brokers and cache servers
        private static readonly IReadOnlyDictionary<string, string?> BrokerJwtHeader = new Dictionary<string, string?>()
        {
            { "alg", "ES384" }, //Must match alg name
            { "typ", "JWT" }
        };

        //Shared rest client pool that all broker requests lease their client from
        private static readonly RestClientPool ClientPool = new(2, new RestClientOptions()
        {
            MaxTimeout = 10 * 1000,
            FollowRedirects = false,
            Encoding = Encoding.UTF8,
            AutomaticDecompression = DecompressionMethods.All,
            ThrowOnAnyError = true,
        });

        /// <summary>
        /// The default hashing algorithm used to sign and verify connection
        /// tokens
        /// </summary>
        public static readonly HashAlgorithmName CacheJwtAlgorithm = HashAlgorithmName.SHA384;

        //using the es384 algorithm for signing (friendlyname is secp384r1)
        /// <summary>
        /// The default ECCurve used by the connection library
        /// </summary>
        public static readonly ECCurve CacheCurve = ECCurve.CreateFromFriendlyName("secp384r1");

        /// <summary>
        /// Gets a preconfigured <see cref="FBMClientConfig"/> for the object caching
        /// protocol
        /// </summary>
        /// <param name="heap">The client buffer heap</param>
        /// <param name="maxMessageSize">
        /// The maximum message size (in bytes). The receive and message buffers are sized to this
        /// value, while the config's maximum message size is set to twice this value.
        /// </param>
        /// <param name="debugLog">An optional debug log</param>
        /// <returns>A preconfigured <see cref="FBMClientConfig"/> for object caching</returns>
        /// <exception cref="ArgumentOutOfRangeException">
        /// <paramref name="maxMessageSize"/> is less than 1, or too large to be doubled without overflowing
        /// </exception>
        public static FBMClientConfig GetDefaultConfig(IUnmangedHeap heap, int maxMessageSize, ILogProvider? debugLog = null)
        {
            //The size is doubled below, guard against a silent (unchecked) integer overflow
            if (maxMessageSize < 1 || maxMessageSize > int.MaxValue / 2)
            {
                throw new ArgumentOutOfRangeException(nameof(maxMessageSize), "The maximum message size must be a positive value no greater than half of int.MaxValue");
            }

            return new()
            {
                BufferHeap = heap,
                MaxMessageSize = maxMessageSize * 2,
                RecvBufferSize = maxMessageSize,
                MessageBufferSize = maxMessageSize,
                MaxHeaderBufferSize = MAX_FBM_MESSAGE_HEADER_SIZE,
                SubProtocol = CACHE_WS_SUB_PROCOL,
                HeaderEncoding = Helpers.DefaultEncoding,
                KeepAliveInterval = TimeSpan.FromSeconds(30),
                DebugLog = debugLog
            };
        }

        /// <summary>
        /// Per-client connection state attached to an <see cref="FBMClient"/> through
        /// <see cref="ClientCacheConfig"/>
        /// </summary>
        private sealed class CacheConnectionConfig
        {
            /// <summary>
            /// Signs messages sent by this client. Holds a freshly generated key until a
            /// private key is imported.
            /// </summary>
            public ECDsa ClientAlg { get; }

            /// <summary>
            /// Verifies broker messages. Holds a freshly generated key until the broker's
            /// public key is imported.
            /// </summary>
            public ECDsa BrokerAlg { get; }

            /// <summary>
            /// Random challenge sent to cache servers when connecting
            /// </summary>
            public string ServerChallenge { get; }

            public string? NodeId { get; set; }

            public Uri? BrokerAddress { get; set; }

            public bool UseTls { get; set; }

            public ActiveServer[]? BrokerServers { get; set; }

            public CacheConnectionConfig()
            {
                //Init the algorithms
                ClientAlg = ECDsa.Create(CacheCurve);
                BrokerAlg = ECDsa.Create(CacheCurve);
                ServerChallenge = RandomHash.GetRandomBase32(24);
            }

            //The config lives only as long as its FBMClient key in the ConditionalWeakTable,
            //so release the key material when it is collected
            ~CacheConnectionConfig()
            {
                ClientAlg.Clear();
                BrokerAlg.Clear();
            }
        }

        /// <summary>
        /// Throws if a broker response was not successful, preferring the exception
        /// captured by RestSharp and falling back to an <see cref="HttpRequestException"/>
        /// so a null exception is never thrown
        /// </summary>
        /// <param name="response">The response to check</param>
        /// <exception cref="HttpRequestException"></exception>
        private static void ThrowIfUnsuccessful(RestResponse response)
        {
            if (!response.IsSuccessful)
            {
                throw response.ErrorException
                    ?? new HttpRequestException($"The broker request failed with status code {response.StatusCode}");
            }
        }

        /// <summary>
        /// Contacts the cache broker to get a list of active servers to connect to
        /// </summary>
        /// <param name="brokerAddress">The broker server to connect to</param>
        /// <param name="clientPrivKey">The pkcs8 format private key used to sign messages sent to the broker</param>
        /// <param name="brokerPubKey">The subject-public-key-info format broker public key used to verify broker messages</param>
        /// <param name="cancellationToken">A token to cancel the operation</param>
        /// <returns>The list of active servers</returns>
        /// <exception cref="SecurityException"></exception>
        /// <exception cref="ArgumentNullException"></exception>
        /// <exception cref="CryptographicException"></exception>
        public static async Task<ActiveServer[]?> ListServersAsync(Uri brokerAddress, ReadOnlyMemory<byte> clientPrivKey, ReadOnlyMemory<byte> brokerPubKey, CancellationToken cancellationToken = default)
        {
            using ECDsa client = ECDsa.Create(CacheCurve);
            using ECDsa broker = ECDsa.Create(CacheCurve);

            //Import client private key
            client.ImportPkcs8PrivateKey(clientPrivKey.Span, out _);
            //Broker public key to verify broker messages
            broker.ImportSubjectPublicKeyInfo(brokerPubKey.Span, out _);

            //Must await so the algorithms are not disposed before the request completes
            return await ListServersAsync(brokerAddress, client, broker, cancellationToken).ConfigureAwait(false);
        }

        /// <summary>
        /// Contacts the cache broker to get a list of active servers to connect to
        /// </summary>
        /// <param name="brokerAddress">The broker server to connect to</param>
        /// <param name="clientAlg">The signature algorithm used to sign messages to the broker</param>
        /// <param name="brokerAlg">The signature algorithm used to verify broker messages</param>
        /// <param name="cancellationToken">A token to cancel the operation</param>
        /// <returns>The list of active servers, read from the "servers" property of the verified response payload</returns>
        /// <exception cref="SecurityException">The broker's response signature could not be verified</exception>
        /// <exception cref="ArgumentNullException"></exception>
        /// <exception cref="HttpRequestException">The broker request was not successful</exception>
        public static async Task<ActiveServer[]?> ListServersAsync(Uri brokerAddress, ECDsa clientAlg, ECDsa brokerAlg, CancellationToken cancellationToken = default)
        {
            _ = brokerAddress ?? throw new ArgumentNullException(nameof(brokerAddress));
            _ = clientAlg ?? throw new ArgumentNullException(nameof(clientAlg));
            _ = brokerAlg ?? throw new ArgumentNullException(nameof(brokerAlg));

            string jwtBody;

            //Build request jwt
            using (JsonWebToken requestJwt = new())
            {
                requestJwt.WriteHeader(BrokerJwtHeader);
                requestJwt.InitPayloadClaim()
                    .AddClaim("iat", DateTimeOffset.UtcNow.ToUnixTimeMilliseconds())
                    .CommitClaims();

                //sign the jwt
                requestJwt.Sign(clientAlg, in CacheJwtAlgorithm, 512);

                //Compile the jwt
                jwtBody = requestJwt.Compile();
            }

            //New list request
            RestRequest listRequest = new(brokerAddress, Method.Post);

            //Add the jwt as a string to the request body
            listRequest.AddStringBody(jwtBody, DataFormat.None);
            listRequest.AddHeader("Content-Type", HttpHelpers.GetContentTypeString(ContentType.Text));

            //Rent client
            using ClientContract client = ClientPool.Lease();

            //Exec list request
            RestResponse response = await client.Resource.ExecuteAsync(listRequest, cancellationToken).ConfigureAwait(false);

            ThrowIfUnsuccessful(response);

            //Response is jwt
            using JsonWebToken responseJwt = JsonWebToken.ParseRaw(response.RawBytes);

            //Verify the jwt
            if (!responseJwt.Verify(brokerAlg, in CacheJwtAlgorithm))
            {
                throw new SecurityException("Failed to verify the broker's challenge, cannot continue");
            }

            using JsonDocument doc = responseJwt.GetPayload();
            return doc.RootElement.GetProperty("servers").Deserialize<ActiveServer[]>();
        }

        /// <summary>
        /// Configures a connection to the remote cache server at the specified location
        /// with proper authentication.
        /// </summary>
        /// <param name="client"></param>
        /// <param name="serverUri">The server's address</param>
        /// <param name="signingKey">The pkcs8 format EC private key used to sign the message</param>
        /// <param name="challenge">A challenge to send to the server</param>
        /// <param name="nodeId">A token used to identify the current server's event queue on the remote server</param>
        /// <param name="useTls">Enables the secure websocket protocol</param>
        /// <param name="token">A token to cancel the connection operation</param>
        /// <returns>A Task that completes when the connection has been established</returns>
        /// <exception cref="ArgumentNullException"></exception>
        public static Task ConnectAsync(this FBMClient client, string serverUri, ReadOnlyMemory<byte> signingKey, string challenge, string? nodeId, bool useTls, CancellationToken token = default)
        {
            //Sign the jwt
            using ECDsa sigAlg = ECDsa.Create(CacheCurve);

            //Import the signing key
            sigAlg.ImportPkcs8PrivateKey(signingKey.Span, out _);

            /*
             * Return without await: the private overload signs the jwt synchronously
             * before returning its task, so the alg can be disposed when this method returns
             */
            return ConnectAsync(client, serverUri, sigAlg, challenge, nodeId, useTls, token);
        }

        //Signs the connection jwt synchronously, sets it as the authorization header, then starts the connection
        private static Task ConnectAsync(FBMClient client, string serverUri, ECDsa sigAlg, string challenge, string? nodeId, bool useTls, CancellationToken token = default)
        {
            _ = serverUri ?? throw new ArgumentNullException(nameof(serverUri));
            _ = challenge ?? throw new ArgumentNullException(nameof(challenge));

            //build ws uri (the scheme setter expects a bare scheme name)
            UriBuilder uriBuilder = new(serverUri)
            {
                Scheme = useTls ? "wss" : "ws"
            };

            string jwtMessage;

            //Init jwt for connecting to server
            using (JsonWebToken jwt = new())
            {
                jwt.WriteHeader(BrokerJwtHeader);

                //Init claim
                JwtPayload claim = jwt.InitPayloadClaim();

                claim.AddClaim("challenge", challenge);

                if (!string.IsNullOrWhiteSpace(nodeId))
                {
                    /*
                     * The unique node id so the other nodes know to load the
                     * proper event queue for the current server
                     */
                    claim.AddClaim("server_id", nodeId);
                }

                claim.CommitClaims();

                //Sign jwt
                jwt.Sign(sigAlg, in CacheJwtAlgorithm, 512);

                //Compile to string
                jwtMessage = jwt.Compile();
            }

            //Set jwt as authorization header
            client.ClientSocket.Headers[HttpRequestHeader.Authorization] = jwtMessage;

            //Connect async
            return client.ConnectAsync(uriBuilder.Uri, token);
        }

        /// <summary>
        /// Registers the current server as active with the specified broker
        /// </summary>
        /// <param name="brokerAddress">The address of the broker to register with</param>
        /// <param name="signingKey">The pkcs8 format private key used to sign the message</param>
        /// <param name="serverAddress">The local address of the current server used for discovery</param>
        /// <param name="nodeId">The unique id to identify this server (for event queues)</param>
        /// <param name="keepAliveToken">A unique security token used by the broker to authenticate itself</param>
        /// <returns>A task that resolves when a successful registration is completed, raises exceptions otherwise</returns>
        /// <exception cref="ArgumentNullException"></exception>
        /// <exception cref="HttpRequestException">The registration request was not successful</exception>
        public static async Task ResgisterWithBrokerAsync(Uri brokerAddress, ReadOnlyMemory<byte> signingKey, string serverAddress, string nodeId, string keepAliveToken)
        {
            _ = brokerAddress ?? throw new ArgumentNullException(nameof(brokerAddress));
            _ = serverAddress ?? throw new ArgumentNullException(nameof(serverAddress));
            _ = keepAliveToken ?? throw new ArgumentNullException(nameof(keepAliveToken));
            _ = nodeId ?? throw new ArgumentNullException(nameof(nodeId));

            string requestData;

            //Create the jwt for signed registration message
            using (JsonWebToken jwt = new())
            {
                //Shared jwt header
                jwt.WriteHeader(BrokerJwtHeader);

                //build jwt claim
                jwt.InitPayloadClaim()
                    .AddClaim("address", serverAddress)
                    .AddClaim("server_id", nodeId)
                    .AddClaim("token", keepAliveToken)
                    .CommitClaims();

                //Sign the jwt
                using (ECDsa sigAlg = ECDsa.Create(CacheCurve))
                {
                    //Import the signing key
                    sigAlg.ImportPkcs8PrivateKey(signingKey.Span, out _);

                    jwt.Sign(sigAlg, in CacheJwtAlgorithm, 512);
                }

                //Compile and save
                requestData = jwt.Compile();
            }

            //Create reg request message
            RestRequest regRequest = new(brokerAddress);
            regRequest.AddStringBody(requestData, DataFormat.None);
            regRequest.AddHeader("Content-Type", "text/plain");

            //Rent client
            using ClientContract client = ClientPool.Lease();

            //Exec the registration request
            RestResponse response = await client.Resource.ExecutePutAsync(regRequest).ConfigureAwait(false);

            ThrowIfUnsuccessful(response);
        }

        //Attaches connection state to clients without holding them alive
        private static readonly ConditionalWeakTable<FBMClient, CacheConnectionConfig> ClientCacheConfig = new();

        /// <summary>
        /// Imports the client signature algorithm's private key from its pkcs8 binary representation
        /// </summary>
        /// <param name="client"></param>
        /// <param name="pkcs8PrivateKey">Pkcs8 format private key</param>
        /// <returns>Chainable fluent object</returns>
        /// <exception cref="ArgumentException"></exception>
        /// <exception cref="CryptographicException"></exception>
        public static FBMClient ImportClientPrivateKey(this FBMClient client, ReadOnlySpan<byte> pkcs8PrivateKey)
        {
            CacheConnectionConfig conf = ClientCacheConfig.GetOrCreateValue(client);
            conf.ClientAlg.ImportPkcs8PrivateKey(pkcs8PrivateKey, out _);
            return client;
        }

        /// <summary>
        /// Imports the public key used to verify broker server messages
        /// </summary>
        /// <param name="client"></param>
        /// <param name="spkiPublicKey">The subject-public-key-info formatted broker public key</param>
        /// <returns>Chainable fluent object</returns>
        /// <exception cref="ArgumentException"></exception>
        /// <exception cref="CryptographicException"></exception>
        public static FBMClient ImportBrokerPublicKey(this FBMClient client, ReadOnlySpan<byte> spkiPublicKey)
        {
            CacheConnectionConfig conf = ClientCacheConfig.GetOrCreateValue(client);
            conf.BrokerAlg.ImportSubjectPublicKeyInfo(spkiPublicKey, out _);
            return client;
        }

        /// <summary>
        /// Specifies if all connections should be using TLS
        /// </summary>
        /// <param name="client"></param>
        /// <param name="useTls">A value that indicates if connections should use TLS</param>
        /// <returns>Chainable fluent object</returns>
        public static FBMClient UseTls(this FBMClient client, bool useTls)
        {
            CacheConnectionConfig conf = ClientCacheConfig.GetOrCreateValue(client);
            conf.UseTls = useTls;
            return client;
        }

        /// <summary>
        /// Specifies the broker address to discover cache nodes from
        /// </summary>
        /// <param name="client"></param>
        /// <param name="brokerAddress">The address of the server broker</param>
        /// <returns>Chainable fluent object</returns>
        /// <exception cref="ArgumentNullException"></exception>
        public static FBMClient UseBroker(this FBMClient client, Uri brokerAddress)
        {
            CacheConnectionConfig conf = ClientCacheConfig.GetOrCreateValue(client);
            conf.BrokerAddress = brokerAddress ?? throw new ArgumentNullException(nameof(brokerAddress));
            return client;
        }

        /// <summary>
        /// Specifies the current server's cluster node id. If this
        /// is a server connection attempting to listen for changes on the
        /// remote server, this id must be set and unique
        /// </summary>
        /// <param name="client"></param>
        /// <param name="nodeId">The cluster node id of the current server</param>
        /// <returns>Chainable fluent object</returns>
        /// <exception cref="ArgumentNullException"></exception>
        public static FBMClient SetNodeId(this FBMClient client, string nodeId)
        {
            CacheConnectionConfig conf = ClientCacheConfig.GetOrCreateValue(client);
            conf.NodeId = nodeId ?? throw new ArgumentNullException(nameof(nodeId));
            return client;
        }

        /// <summary>
        /// Discovers cache nodes in the broker configured for the current client.
        /// </summary>
        /// <param name="client"></param>
        /// <param name="token">A token to cancel the discovery</param>
        /// <returns>A task that resolves the list of active servers on the broker server</returns>
        public static Task<ActiveServer[]?> DiscoverNodesAsync(this FBMClientWorkerBase client, CancellationToken token = default)
        {
            return client.Client.DiscoverNodesAsync(token);
        }

        /// <summary>
        /// Discovers cache nodes in the broker configured for the current client.
        /// The broker must have been set with <see cref="UseBroker(FBMClient, Uri)"/>, and the
        /// client and broker keys imported, otherwise randomly generated keys are used.
        /// </summary>
        /// <param name="client"></param>
        /// <param name="token">A token to cancel the discovery</param>
        /// <returns>A task that resolves the list of active servers on the broker server</returns>
        /// <exception cref="ArgumentNullException">No broker address has been configured for the client</exception>
        public static async Task<ActiveServer[]?> DiscoverNodesAsync(this FBMClient client, CancellationToken token = default)
        {
            CacheConnectionConfig conf = ClientCacheConfig.GetOrCreateValue(client);

            //List servers async
            ActiveServer[]? servers = await ListServersAsync(conf.BrokerAddress!, conf.ClientAlg, conf.BrokerAlg, token).ConfigureAwait(false);

            conf.BrokerServers = servers;

            return servers;
        }

        /// <summary>
        /// Connects the client to a remote cache server
        /// </summary>
        /// <param name="client"></param>
        /// <param name="server">The server to connect to</param>
        /// <param name="token">A token to cancel the connection and/or wait operation</param>
        /// <returns>A task that resolves when cancelled or when the connection is lost to the server</returns>
        /// <exception cref="OperationCanceledException"></exception>
        public static Task ConnectAndWaitForExitAsync(this FBMClientWorkerBase client, ActiveServer server, CancellationToken token = default)
        {
            return client.Client.ConnectAndWaitForExitAsync(server, token);
        }

        /// <summary>
        /// Connects the client to a remote cache server, waits until either the token is
        /// cancelled or the client's connection status handle is signaled, then disconnects
        /// </summary>
        /// <param name="client"></param>
        /// <param name="server">The server to connect to</param>
        /// <param name="token">A token to cancel the connection and/or wait operation</param>
        /// <returns>A task that resolves when cancelled or when the connection is lost to the server</returns>
        /// <exception cref="OperationCanceledException">The token was cancelled</exception>
        public static async Task ConnectAndWaitForExitAsync(this FBMClient client, ActiveServer server, CancellationToken token = default)
        {
            CacheConnectionConfig conf = ClientCacheConfig.GetOrCreateValue(client);

            //Connect to server, the configured node id (if any) is sent so the server can load this node's event queue
            await ConnectAsync(client, server.HostName!, conf.ClientAlg, conf.ServerChallenge, conf.NodeId, conf.UseTls, token).ConfigureAwait(false);

            //Get task for cancellation
            Task cancellation = token.WaitHandle.WaitAsync();

            //Task for status handle (presumably signaled when the connection closes)
            Task run = client.ConnectionStatusHandle.WaitAsync();

            //Wait for cancellation or connection exit
            _ = await Task.WhenAny(cancellation, run).ConfigureAwait(false);

            //Normal try to disconnect the socket
            await client.DisconnectAsync(CancellationToken.None).ConfigureAwait(false);

            //Notify if cancelled
            token.ThrowIfCancellationRequested();
        }

        /// <summary>
        /// Selects a random server from a collection of active servers
        /// using a cryptographically secure random number generator
        /// </summary>
        /// <param name="servers">The non-empty collection of servers to choose from</param>
        /// <returns>A server selected at random</returns>
        /// <exception cref="ArgumentNullException"></exception>
        /// <exception cref="ArgumentException">The collection is empty</exception>
        public static ActiveServer SelectRandom(this ICollection<ActiveServer> servers)
        {
            _ = servers ?? throw new ArgumentNullException(nameof(servers));

            if (servers.Count == 0)
            {
                throw new ArgumentException("Cannot select a server from an empty collection", nameof(servers));
            }

            //select random server
            int randServer = RandomNumberGenerator.GetInt32(0, servers.Count);
            return servers.ElementAt(randServer);
        }
    }
}