using System;
using System.Net;
using System.Text;
using System.Security;
using System.Text.Json;
using System.Security.Cryptography;
using System.Runtime.CompilerServices;
using RestSharp;
using VNLib.Utils.Memory;
using VNLib.Utils.Logging;
using VNLib.Utils.Extensions;
using VNLib.Hashing;
using VNLib.Hashing.IdentityUtility;
using VNLib.Net.Http;
using VNLib.Net.Rest.Client;
using VNLib.Net.Messaging.FBM.Client;
using VNLib.Net.Messaging.FBM;
namespace VNLib.Data.Caching.Extensions
{
    /// <summary>
    /// Provides extension methods for FBM data caching using
    /// cache servers and brokers
    /// </summary>
    public static class FBMDataCacheExtensions
    {
        /// <summary>
        /// The websocket sub-protocol to use when connecting to cache servers
        /// </summary>
        public const string CACHE_WS_SUB_PROCOL = "object-cache";

        /// <summary>
        /// The default cache message header size
        /// </summary>
        public const int MAX_FBM_MESSAGE_HEADER_SIZE = 1024;

        //JWT header written into every token this library signs (broker and cache-server messages)
        private static readonly IReadOnlyDictionary<string, string?> BrokerJwtHeader = new Dictionary<string, string?>()
        {
            { "alg", "ES384" }, //Must match alg name
            { "typ", "JWT"}
        };

        //Shared RestSharp client pool used for all broker requests (10 second timeout, no redirects)
        private static readonly RestClientPool ClientPool = new(2, new RestClientOptions()
        {
            MaxTimeout = 10 * 1000,
            FollowRedirects = false,
            Encoding = Encoding.UTF8,
            AutomaticDecompression = DecompressionMethods.All,
            ThrowOnAnyError = true,
        });

        /// <summary>
        /// The default hashing algorithm used to sign and verify connection
        /// tokens
        /// </summary>
        public static readonly HashAlgorithmName CacheJwtAlgorithm = HashAlgorithmName.SHA384;

        //using the es384 algorithm for signing (friendlyname is secp384r1)
        /// <summary>
        /// The default ECCurve used by the connection library
        /// </summary>
        public static readonly ECCurve CacheCurve = ECCurve.CreateFromFriendlyName("secp384r1");

        /// <summary>
        /// Gets a preconfigured <see cref="FBMClientConfig"/> for the object caching
        /// protocol
        /// </summary>
        /// <param name="heap">The client buffer heap</param>
        /// <param name="maxMessageSize">
        /// The maximum message size (in bytes). Used for the receive and message buffer
        /// sizes; the config's <c>MaxMessageSize</c> is set to twice this value.
        /// </param>
        /// <param name="debugLog">An optional debug log</param>
        /// <returns>A preconfigured <see cref="FBMClientConfig"/> for object caching</returns>
        public static FBMClientConfig GetDefaultConfig(IUnmangedHeap heap, int maxMessageSize, ILogProvider? debugLog = null)
        {
            return new()
            {
                BufferHeap = heap,
                MaxMessageSize = maxMessageSize * 2,
                RecvBufferSize = maxMessageSize,
                MessageBufferSize = maxMessageSize,
                MaxHeaderBufferSize = MAX_FBM_MESSAGE_HEADER_SIZE,
                SubProtocol = CACHE_WS_SUB_PROCOL,
                HeaderEncoding = Helpers.DefaultEncoding,
                KeepAliveInterval = TimeSpan.FromSeconds(30),
                DebugLog = debugLog
            };
        }

        /// <summary>
        /// Per-client connection state attached to an <see cref="FBMClient"/> through
        /// a <see cref="ConditionalWeakTable{TKey, TValue}"/>
        /// </summary>
        private sealed class CacheConnectionConfig
        {
            public ECDsa ClientAlg { get; init; }
            public ECDsa BrokerAlg { get; init; }
            public string ServerChallenge { get; init; }
            public string? NodeId { get; set; }
            public Uri? BrokerAddress { get; set; }
            public bool UseTls { get; set; }
            public ActiveServer[]? BrokerServers { get; set; }

            public CacheConnectionConfig()
            {
                //Init the algorithms
                ClientAlg = ECDsa.Create(CacheCurve);
                BrokerAlg = ECDsa.Create(CacheCurve);
                //Random challenge sent to cache servers when connecting
                ServerChallenge = RandomHash.GetRandomBase32(24);
            }

            //Runs once the owning client (table key) has been collected; clears the key material
            ~CacheConnectionConfig()
            {
                ClientAlg.Clear();
                BrokerAlg.Clear();
            }
        }

        /// <summary>
        /// Contacts the cache broker to get a list of active servers to connect to
        /// </summary>
        /// <param name="brokerAddress">The broker server to connect to</param>
        /// <param name="clientPrivKey">The PKCS#8 private key used to sign messages sent to the broker</param>
        /// <param name="brokerPubKey">The subject-public-key-info broker public key used to verify broker messages</param>
        /// <param name="cancellationToken">A token to cancel the operation</param>
        /// <returns>The list of active servers</returns>
        /// <exception cref="ArgumentNullException"></exception>
        /// <exception cref="SecurityException">The broker response signature could not be verified</exception>
        /// <exception cref="CryptographicException">A key could not be imported</exception>
        public static async Task<ActiveServer[]?> ListServersAsync(Uri brokerAddress, ReadOnlyMemory<byte> clientPrivKey, ReadOnlyMemory<byte> brokerPubKey, CancellationToken cancellationToken = default)
        {
            using ECDsa client = ECDsa.Create(CacheCurve);
            using ECDsa broker = ECDsa.Create(CacheCurve);

            //Import client private key
            client.ImportPkcs8PrivateKey(clientPrivKey.Span, out _);
            //Broker public key to verify broker messages
            broker.ImportSubjectPublicKeyInfo(brokerPubKey.Span, out _);

            //Must await so the algorithms are not disposed while the request is in flight
            return await ListServersAsync(brokerAddress, client, broker, cancellationToken).ConfigureAwait(false);
        }

        /// <summary>
        /// Contacts the cache broker to get a list of active servers to connect to
        /// </summary>
        /// <param name="brokerAddress">The broker server to connect to</param>
        /// <param name="clientAlg">The signature algorithm used to sign messages to the broker</param>
        /// <param name="brokerAlg">The signature algorithm used to verify broker messages</param>
        /// <param name="cancellationToken">A token to cancel the operation</param>
        /// <returns>The list of active servers (the "servers" property of the verified response payload)</returns>
        /// <exception cref="ArgumentNullException"></exception>
        /// <exception cref="SecurityException">The broker response signature could not be verified</exception>
        public static async Task<ActiveServer[]?> ListServersAsync(Uri brokerAddress, ECDsa clientAlg, ECDsa brokerAlg, CancellationToken cancellationToken = default)
        {
            _ = brokerAddress ?? throw new ArgumentNullException(nameof(brokerAddress));
            _ = clientAlg ?? throw new ArgumentNullException(nameof(clientAlg));
            _ = brokerAlg ?? throw new ArgumentNullException(nameof(brokerAlg));

            string jwtBody;

            //Build request jwt
            using (JsonWebToken requestJwt = new())
            {
                requestJwt.WriteHeader(BrokerJwtHeader);

                //NOTE(review): iat is written in unix MILLISECONDS, not the JWT-standard seconds; presumably the broker expects this
                requestJwt.InitPayloadClaim()
                    .AddClaim("iat", DateTimeOffset.UtcNow.ToUnixTimeMilliseconds())
                    .CommitClaims();

                //sign the jwt
                requestJwt.Sign(clientAlg, in CacheJwtAlgorithm, 512);

                //Compile the jwt
                jwtBody = requestJwt.Compile();
            }

            //New list request with the jwt as a plain-text body
            RestRequest listRequest = new(brokerAddress, Method.Post);
            listRequest.AddStringBody(jwtBody, DataFormat.None);
            listRequest.AddHeader("Content-Type", HttpHelpers.GetContentTypeString(ContentType.Text));

            //Rent client
            using ClientContract client = ClientPool.Lease();

            //Exec list request
            RestResponse response = await client.Resource.ExecuteAsync(listRequest, cancellationToken).ConfigureAwait(false);

            ThrowIfFailed(response);

            //Response is jwt
            using JsonWebToken responseJwt = JsonWebToken.ParseRaw(response.RawBytes);

            //Verify the jwt
            if (!responseJwt.Verify(brokerAlg, in CacheJwtAlgorithm))
            {
                throw new SecurityException("Failed to verify the broker's challenge, cannot continue");
            }

            using JsonDocument doc = responseJwt.GetPayload();
            return doc.RootElement.GetProperty("servers").Deserialize<ActiveServer[]>();
        }

        /// <summary>
        /// Configures a connection to the remote cache server at the specified location
        /// with proper authentication.
        /// </summary>
        /// <param name="client"></param>
        /// <param name="serverUri">The server's address</param>
        /// <param name="signingKey">The PKCS#8 format EC private key used to sign the message</param>
        /// <param name="challenge">A challenge to send to the server</param>
        /// <param name="nodeId">A token used to identify the current server's event queue on the remote server</param>
        /// <param name="token">A token to cancel the connection operation</param>
        /// <param name="useTls">Enables the secure websocket protocol</param>
        /// <returns>A Task that completes when the connection has been established</returns>
        /// <exception cref="ArgumentNullException"></exception>
        public static Task ConnectAsync(this FBMClient client, string serverUri, ReadOnlyMemory<byte> signingKey, string challenge, string? nodeId, bool useTls, CancellationToken token = default)
        {
            //Sign the jwt
            using ECDsa sigAlg = ECDsa.Create(CacheCurve);

            //Import the signing key
            sigAlg.ImportPkcs8PrivateKey(signingKey.Span, out _);

            //Return without await: the private overload signs synchronously before returning, so the alg can be discarded
            return ConnectAsync(client, serverUri, sigAlg, challenge, nodeId, useTls, token);
        }

        //Builds and signs the auth jwt, sets it as the socket's Authorization header, then starts the connection
        private static Task ConnectAsync(FBMClient client, string serverUri, ECDsa sigAlg, string challenge, string? nodeId, bool useTls, CancellationToken token = default)
        {
            _ = client ?? throw new ArgumentNullException(nameof(client));
            _ = serverUri ?? throw new ArgumentNullException(nameof(serverUri));
            _ = challenge ?? throw new ArgumentNullException(nameof(challenge));

            //build ws uri
            UriBuilder uriBuilder = new(serverUri)
            {
                Scheme = useTls ? "wss" : "ws"
            };

            string jwtMessage;

            //Init jwt for connecting to server
            using (JsonWebToken jwt = new())
            {
                jwt.WriteHeader(BrokerJwtHeader);

                //Init claim
                JwtPayload claim = jwt.InitPayloadClaim();
                claim.AddClaim("challenge", challenge);

                if (!string.IsNullOrWhiteSpace(nodeId))
                {
                    /*
                     * The unique node id so the other nodes know to load the
                     * proper event queue for the current server
                     */
                    claim.AddClaim("server_id", nodeId);
                }

                claim.CommitClaims();

                //Sign jwt
                jwt.Sign(sigAlg, in CacheJwtAlgorithm, 512);

                //Compile to string
                jwtMessage = jwt.Compile();
            }

            //Set jwt as authorization header
            client.ClientSocket.Headers[HttpRequestHeader.Authorization] = jwtMessage;

            //Connect async
            return client.ConnectAsync(uriBuilder.Uri, token);
        }

        /// <summary>
        /// Registers the current server as active with the specified broker
        /// </summary>
        /// <param name="brokerAddress">The address of the broker to register with</param>
        /// <param name="signingKey">The PKCS#8 private key used to sign the message</param>
        /// <param name="serverAddress">The local address of the current server used for discovery</param>
        /// <param name="nodeId">The unique id to identify this server (for event queues)</param>
        /// <param name="keepAliveToken">A unique security token used by the broker to authenticate itself</param>
        /// <param name="cancellationToken">A token to cancel the registration request</param>
        /// <returns>A task that resolves when a successful registration is completed, raises exceptions otherwise</returns>
        /// <exception cref="ArgumentNullException"></exception>
        public static async Task ResgisterWithBrokerAsync(Uri brokerAddress, ReadOnlyMemory<byte> signingKey, string serverAddress, string nodeId, string keepAliveToken, CancellationToken cancellationToken = default)
        {
            _ = brokerAddress ?? throw new ArgumentNullException(nameof(brokerAddress));
            _ = serverAddress ?? throw new ArgumentNullException(nameof(serverAddress));
            _ = keepAliveToken ?? throw new ArgumentNullException(nameof(keepAliveToken));
            _ = nodeId ?? throw new ArgumentNullException(nameof(nodeId));

            string requestData;

            //Create the jwt for signed registration message
            using (JsonWebToken jwt = new())
            {
                //Shared jwt header
                jwt.WriteHeader(BrokerJwtHeader);

                //build jwt claim
                jwt.InitPayloadClaim()
                    .AddClaim("address", serverAddress)
                    .AddClaim("server_id", nodeId)
                    .AddClaim("token", keepAliveToken)
                    .CommitClaims();

                //Sign the jwt
                using (ECDsa sigAlg = ECDsa.Create(CacheCurve))
                {
                    //Import the signing key
                    sigAlg.ImportPkcs8PrivateKey(signingKey.Span, out _);

                    jwt.Sign(sigAlg, in CacheJwtAlgorithm, 512);
                }

                //Compile and save
                requestData = jwt.Compile();
            }

            //Create reg request message
            RestRequest regRequest = new(brokerAddress);
            regRequest.AddStringBody(requestData, DataFormat.None);
            regRequest.AddHeader("Content-Type", "text/plain");

            //Rent client
            using ClientContract client = ClientPool.Lease();

            //Exec the registration request
            RestResponse response = await client.Resource.ExecutePutAsync(regRequest, cancellationToken).ConfigureAwait(false);

            ThrowIfFailed(response);
        }

        /// <summary>
        /// Throws when a rest response was not successful. Rethrows the exception captured
        /// on the response (preserving its original stack trace) or, when none was captured,
        /// raises an <see cref="System.Net.Http.HttpRequestException"/> rather than throwing null.
        /// </summary>
        /// <param name="response">The response to inspect</param>
        private static void ThrowIfFailed(RestResponse response)
        {
            if (response.IsSuccessful)
            {
                return;
            }

            if (response.ErrorException != null)
            {
                System.Runtime.ExceptionServices.ExceptionDispatchInfo.Capture(response.ErrorException).Throw();
            }

            throw new System.Net.Http.HttpRequestException($"The broker request failed with status code {(int)response.StatusCode}");
        }

        //Connection state attached to each client without extending the client type
        private static readonly ConditionalWeakTable<FBMClient, CacheConnectionConfig> ClientCacheConfig = new();

        /// <summary>
        /// Imports the client signature algorithm's private key from its PKCS#8 binary representation
        /// </summary>
        /// <param name="client"></param>
        /// <param name="pkcs8PrivateKey">PKCS#8 format private key</param>
        /// <returns>Chainable fluent object</returns>
        /// <exception cref="ArgumentNullException"></exception>
        /// <exception cref="CryptographicException"></exception>
        public static FBMClient ImportClientPrivateKey(this FBMClient client, ReadOnlySpan<byte> pkcs8PrivateKey)
        {
            CacheConnectionConfig conf = ClientCacheConfig.GetOrCreateValue(client);
            conf.ClientAlg.ImportPkcs8PrivateKey(pkcs8PrivateKey, out _);
            return client;
        }

        /// <summary>
        /// Imports the public key used to verify broker server messages
        /// </summary>
        /// <param name="client"></param>
        /// <param name="spkiPublicKey">The subject-public-key-info formatted broker public key</param>
        /// <returns>Chainable fluent object</returns>
        /// <exception cref="ArgumentNullException"></exception>
        /// <exception cref="CryptographicException"></exception>
        public static FBMClient ImportBrokerPublicKey(this FBMClient client, ReadOnlySpan<byte> spkiPublicKey)
        {
            CacheConnectionConfig conf = ClientCacheConfig.GetOrCreateValue(client);
            conf.BrokerAlg.ImportSubjectPublicKeyInfo(spkiPublicKey, out _);
            return client;
        }

        /// <summary>
        /// Specifies if all connections should be using TLS
        /// </summary>
        /// <param name="client"></param>
        /// <param name="useTls">A value that indicates if connections should use TLS</param>
        /// <returns>Chainable fluent object</returns>
        public static FBMClient UseTls(this FBMClient client, bool useTls)
        {
            CacheConnectionConfig conf = ClientCacheConfig.GetOrCreateValue(client);
            conf.UseTls = useTls;
            return client;
        }

        /// <summary>
        /// Specifies the broker address to discover cache nodes from
        /// </summary>
        /// <param name="client"></param>
        /// <param name="brokerAddress">The address of the server broker</param>
        /// <returns>Chainable fluent object</returns>
        /// <exception cref="ArgumentNullException"></exception>
        public static FBMClient UseBroker(this FBMClient client, Uri brokerAddress)
        {
            CacheConnectionConfig conf = ClientCacheConfig.GetOrCreateValue(client);
            conf.BrokerAddress = brokerAddress ?? throw new ArgumentNullException(nameof(brokerAddress));
            return client;
        }

        /// <summary>
        /// Specifies the current server's cluster node id. If this
        /// is a server connection attempting to listen for changes on the
        /// remote server, this id must be set and unique
        /// </summary>
        /// <param name="client"></param>
        /// <param name="nodeId">The cluster node id of the current server</param>
        /// <returns>Chainable fluent object</returns>
        /// <exception cref="ArgumentNullException"></exception>
        public static FBMClient SetNodeId(this FBMClient client, string nodeId)
        {
            CacheConnectionConfig conf = ClientCacheConfig.GetOrCreateValue(client);
            conf.NodeId = nodeId ?? throw new ArgumentNullException(nameof(nodeId));
            return client;
        }

        /// <summary>
        /// Discovers cache nodes in the broker configured for the current client.
        /// </summary>
        /// <param name="client"></param>
        /// <param name="token">A token to cancel the discovery</param>
        /// <returns>A task that resolves the list of active servers on the broker server</returns>
        public static Task<ActiveServer[]?> DiscoverNodesAsync(this FBMClientWorkerBase client, CancellationToken token = default)
        {
            return client.Client.DiscoverNodesAsync(token);
        }

        /// <summary>
        /// Discovers cache nodes in the broker configured for the current client.
        /// The broker must have been configured with <see cref="UseBroker(FBMClient, Uri)"/>.
        /// </summary>
        /// <param name="client"></param>
        /// <param name="token">A token to cancel the discovery</param>
        /// <returns>A task that resolves the list of active servers on the broker server</returns>
        /// <exception cref="ArgumentNullException">No broker address has been configured</exception>
        public static async Task<ActiveServer[]?> DiscoverNodesAsync(this FBMClient client, CancellationToken token = default)
        {
            CacheConnectionConfig conf = ClientCacheConfig.GetOrCreateValue(client);

            //A missing broker address is rejected by ListServersAsync's argument check
            ActiveServer[]? servers = await ListServersAsync(conf.BrokerAddress!, conf.ClientAlg, conf.BrokerAlg, token).ConfigureAwait(false);

            //Cache the most recent discovery result on the client's config
            conf.BrokerServers = servers;
            return servers;
        }

        /// <summary>
        /// Connects the client to a remote cache server
        /// </summary>
        /// <param name="client"></param>
        /// <param name="server">The server to connect to</param>
        /// <param name="token">A token to cancel the connection and/or wait operation</param>
        /// <returns>A task that resolves when cancelled or when the connection is lost to the server</returns>
        /// <exception cref="OperationCanceledException"></exception>
        public static Task ConnectAndWaitForExitAsync(this FBMClientWorkerBase client, ActiveServer server, CancellationToken token = default)
        {
            return client.Client.ConnectAndWaitForExitAsync(server, token);
        }

        /// <summary>
        /// Connects the client to a remote cache server
        /// </summary>
        /// <param name="client"></param>
        /// <param name="server">The server to connect to</param>
        /// <param name="token">A token to cancel the connection and/or wait operation</param>
        /// <returns>A task that resolves when cancelled or when the connection is lost to the server</returns>
        /// <exception cref="ArgumentNullException"></exception>
        /// <exception cref="OperationCanceledException"></exception>
        public static async Task ConnectAndWaitForExitAsync(this FBMClient client, ActiveServer server, CancellationToken token = default)
        {
            _ = server ?? throw new ArgumentNullException(nameof(server));

            CacheConnectionConfig conf = ClientCacheConfig.GetOrCreateValue(client);

            //Connect to server, identifying this node with the configured node id (if any)
            await ConnectAsync(client, server.HostName!, conf.ClientAlg, conf.ServerChallenge, conf.NodeId, conf.UseTls, token).ConfigureAwait(false);

            /*
             * Observe cancellation through a disposable token registration instead of
             * token.WaitHandle, which would allocate the token's wait handle and leave a
             * wait registered against it after this method returns.
             */
            TaskCompletionSource<bool> cancelled = new(TaskCreationOptions.RunContinuationsAsynchronously);

            using (token.Register(static state => ((TaskCompletionSource<bool>)state!).TrySetResult(true), cancelled))
            {
                //Task for status handle
                Task run = client.ConnectionStatusHandle.WaitAsync();

                //Wait for cancellation or connection loss
                _ = await Task.WhenAny(cancelled.Task, run).ConfigureAwait(false);
            }

            //Normal try to disconnect the socket
            await client.DisconnectAsync(CancellationToken.None).ConfigureAwait(false);

            //Notify if cancelled
            token.ThrowIfCancellationRequested();
        }

        /// <summary>
        /// Selects a random server from a collection of active servers
        /// </summary>
        /// <param name="servers">The non-empty collection to select from</param>
        /// <returns>A server selected at random</returns>
        /// <exception cref="ArgumentNullException"></exception>
        /// <exception cref="ArgumentOutOfRangeException">The collection is empty</exception>
        public static ActiveServer SelectRandom(this ICollection<ActiveServer> servers)
        {
            _ = servers ?? throw new ArgumentNullException(nameof(servers));

            if (servers.Count == 0)
            {
                throw new ArgumentOutOfRangeException(nameof(servers), "The server collection must contain at least one server to select from");
            }

            //select random server using a cryptographic RNG
            int randServer = RandomNumberGenerator.GetInt32(servers.Count);
            return servers.ElementAt(randServer);
        }
    }
}