diff --git a/DepotDownloader/CDNClientPool.cs b/DepotDownloader/CDNClientPool.cs index 596beea6..f5a57818 100644 --- a/DepotDownloader/CDNClientPool.cs +++ b/DepotDownloader/CDNClientPool.cs @@ -2,10 +2,8 @@ // in file 'LICENSE', which is part of this source code package. using System; -using System.Collections.Concurrent; using System.Collections.Generic; using System.Linq; -using System.Threading; using System.Threading.Tasks; using SteamKit2.CDN; @@ -16,139 +14,80 @@ namespace DepotDownloader /// class CDNClientPool { - private const int ServerEndpointMinimumSize = 8; - private readonly Steam3Session steamSession; private readonly uint appId; public Client CDNClient { get; } public Server ProxyServer { get; private set; } - private readonly ConcurrentStack activeConnectionPool = []; - private readonly BlockingCollection availableServerEndpoints = []; - - private readonly AutoResetEvent populatePoolEvent = new(true); - private readonly Task monitorTask; - private readonly CancellationTokenSource shutdownToken = new(); - public CancellationTokenSource ExhaustedToken { get; set; } + private readonly List servers = []; + private int nextServer; public CDNClientPool(Steam3Session steamSession, uint appId) { this.steamSession = steamSession; this.appId = appId; CDNClient = new Client(steamSession.steamClient); - - monitorTask = Task.Factory.StartNew(ConnectionPoolMonitorAsync).Unwrap(); } - public void Shutdown() + public async Task UpdateServerList() { - shutdownToken.Cancel(); - monitorTask.Wait(); - } + var servers = await this.steamSession.steamContent.GetServersForSteamPipe(); - private async Task> FetchBootstrapServerListAsync() - { - try - { - var cdnServers = await this.steamSession.steamContent.GetServersForSteamPipe(); - if (cdnServers != null) - { - return cdnServers; - } - } - catch (Exception ex) - { - Console.WriteLine("Failed to retrieve content server list: {0}", ex.Message); - } + ProxyServer = servers.Where(x => x.UseAsProxy).FirstOrDefault(); - return null; - } + var weightedCdnServers = servers + .Where(server => + { + var isEligibleForApp = server.AllowedAppIds.Length == 0 || server.AllowedAppIds.Contains(appId); + return isEligibleForApp && (server.Type == "SteamCache" || server.Type == "CDN"); + }) + .Select(server => + { + AccountSettingsStore.Instance.ContentServerPenalty.TryGetValue(server.Host, out var penalty); - private async Task ConnectionPoolMonitorAsync() - { - var didPopulate = false; + return (server, penalty); + }) + .OrderBy(pair => pair.penalty).ThenBy(pair => pair.server.WeightedLoad); - while (!shutdownToken.IsCancellationRequested) + foreach (var (server, weight) in weightedCdnServers) { - populatePoolEvent.WaitOne(TimeSpan.FromSeconds(1)); - - // We want the Steam session so we can take the CellID from the session and pass it through to the ContentServer Directory Service - if (availableServerEndpoints.Count < ServerEndpointMinimumSize && steamSession.steamClient.IsConnected) - { - var servers = await FetchBootstrapServerListAsync().ConfigureAwait(false); - - if (servers == null || servers.Count == 0) - { - ExhaustedToken?.Cancel(); - return; - } - - ProxyServer = servers.Where(x => x.UseAsProxy).FirstOrDefault(); - - var weightedCdnServers = servers - .Where(server => - { - var isEligibleForApp = server.AllowedAppIds.Length == 0 || server.AllowedAppIds.Contains(appId); - return isEligibleForApp && (server.Type == "SteamCache" || server.Type == "CDN"); - }) - .Select(server => - { - AccountSettingsStore.Instance.ContentServerPenalty.TryGetValue(server.Host, out var penalty); - - return (server, penalty); - }) - .OrderBy(pair => pair.penalty).ThenBy(pair => pair.server.WeightedLoad); - - foreach (var (server, weight) in weightedCdnServers) - { - for (var i = 0; i < server.NumEntries; i++) - { - availableServerEndpoints.Add(server); - } - } - - didPopulate = true; - } - else if (availableServerEndpoints.Count == 0 && !steamSession.steamClient.IsConnected && didPopulate) + for (var i = 0; i < server.NumEntries; i++) { - ExhaustedToken?.Cancel(); - return; + this.servers.Add(server); } } - } - private Server BuildConnection(CancellationToken token) - { - if (availableServerEndpoints.Count < ServerEndpointMinimumSize) + if (this.servers.Count == 0) { - populatePoolEvent.Set(); + throw new Exception("Failed to retrieve any download servers."); } - - return availableServerEndpoints.Take(token); } - public Server GetConnection(CancellationToken token) + public Server GetConnection() { - if (!activeConnectionPool.TryPop(out var connection)) - { - connection = BuildConnection(token); - } - - return connection; + return servers[nextServer % servers.Count]; } public void ReturnConnection(Server server) { if (server == null) return; - activeConnectionPool.Push(server); + // nothing to do, maybe remove from ContentServerPenalty? } public void ReturnBrokenConnection(Server server) { if (server == null) return; - // Broken connections are not returned to the pool + lock (servers) + { + if (servers[nextServer % servers.Count] == server) + { + nextServer++; + + // TODO: Add server to ContentServerPenalty + } + } } } } diff --git a/DepotDownloader/ContentDownloader.cs b/DepotDownloader/ContentDownloader.cs index a089f122..8ab56d49 100644 --- a/DepotDownloader/ContentDownloader.cs +++ b/DepotDownloader/ContentDownloader.cs @@ -333,12 +333,6 @@ namespace DepotDownloader public static void ShutdownSteam3() { - if (cdnPool != null) - { - cdnPool.Shutdown(); - cdnPool = null; - } - if (steam3 == null) return; @@ -660,9 +654,9 @@ namespace DepotDownloader { Ansi.Progress(Ansi.ProgressState.Indeterminate); - var cts = new CancellationTokenSource(); - cdnPool.ExhaustedToken = cts; + await cdnPool.UpdateServerList(); + var cts = new CancellationTokenSource(); var downloadCounter = new GlobalDownloadCounter(); var depotsToDownload = new List(depots.Count); var allFileNamesAllDepots = new HashSet(); @@ -759,7 +753,7 @@ namespace DepotDownloader try { - connection = cdnPool.GetConnection(cts.Token); + connection = cdnPool.GetConnection(); string cdnToken = null; if (steam3.CDNAuthTokens.TryGetValue((depot.DepotId, connection.Host), out var authTokenCallbackPromise)) @@ -1202,7 +1196,7 @@ namespace DepotDownloader try { - connection = cdnPool.GetConnection(cts.Token); + connection = cdnPool.GetConnection(); string cdnToken = null; if (steam3.CDNAuthTokens.TryGetValue((depot.DepotId, connection.Host), out var authTokenCallbackPromise))