Vastly simplify CDNClientPool

pull/610/head
Pavel Djundik 11 months ago
parent 14c6a6dafa
commit 0150b7eff4

@ -2,10 +2,8 @@
// in file 'LICENSE', which is part of this source code package. // in file 'LICENSE', which is part of this source code package.
using System; using System;
using System.Collections.Concurrent;
using System.Collections.Generic; using System.Collections.Generic;
using System.Linq; using System.Linq;
using System.Threading;
using System.Threading.Tasks; using System.Threading.Tasks;
using SteamKit2.CDN; using SteamKit2.CDN;
@ -16,72 +14,24 @@ namespace DepotDownloader
/// </summary> /// </summary>
class CDNClientPool class CDNClientPool
{ {
private const int ServerEndpointMinimumSize = 8;
private readonly Steam3Session steamSession; private readonly Steam3Session steamSession;
private readonly uint appId; private readonly uint appId;
public Client CDNClient { get; } public Client CDNClient { get; }
public Server ProxyServer { get; private set; } public Server ProxyServer { get; private set; }
private readonly ConcurrentStack<Server> activeConnectionPool = []; private readonly List<Server> servers = [];
private readonly BlockingCollection<Server> availableServerEndpoints = []; private int nextServer;
private readonly AutoResetEvent populatePoolEvent = new(true);
private readonly Task monitorTask;
private readonly CancellationTokenSource shutdownToken = new();
public CancellationTokenSource ExhaustedToken { get; set; }
public CDNClientPool(Steam3Session steamSession, uint appId) public CDNClientPool(Steam3Session steamSession, uint appId)
{ {
this.steamSession = steamSession; this.steamSession = steamSession;
this.appId = appId; this.appId = appId;
CDNClient = new Client(steamSession.steamClient); CDNClient = new Client(steamSession.steamClient);
monitorTask = Task.Factory.StartNew(ConnectionPoolMonitorAsync).Unwrap();
}
public void Shutdown()
{
shutdownToken.Cancel();
monitorTask.Wait();
}
private async Task<IReadOnlyCollection<Server>> FetchBootstrapServerListAsync()
{
try
{
var cdnServers = await this.steamSession.steamContent.GetServersForSteamPipe();
if (cdnServers != null)
{
return cdnServers;
}
}
catch (Exception ex)
{
Console.WriteLine("Failed to retrieve content server list: {0}", ex.Message);
}
return null;
} }
private async Task ConnectionPoolMonitorAsync() public async Task UpdateServerList()
{
var didPopulate = false;
while (!shutdownToken.IsCancellationRequested)
{
populatePoolEvent.WaitOne(TimeSpan.FromSeconds(1));
// We want the Steam session so we can take the CellID from the session and pass it through to the ContentServer Directory Service
if (availableServerEndpoints.Count < ServerEndpointMinimumSize && steamSession.steamClient.IsConnected)
{ {
var servers = await FetchBootstrapServerListAsync().ConfigureAwait(false); var servers = await this.steamSession.steamContent.GetServersForSteamPipe();
if (servers == null || servers.Count == 0)
{
ExhaustedToken?.Cancel();
return;
}
ProxyServer = servers.Where(x => x.UseAsProxy).FirstOrDefault(); ProxyServer = servers.Where(x => x.UseAsProxy).FirstOrDefault();
@ -103,52 +53,41 @@ namespace DepotDownloader
{ {
for (var i = 0; i < server.NumEntries; i++) for (var i = 0; i < server.NumEntries; i++)
{ {
availableServerEndpoints.Add(server); this.servers.Add(server);
} }
} }
didPopulate = true; if (this.servers.Count == 0)
}
else if (availableServerEndpoints.Count == 0 && !steamSession.steamClient.IsConnected && didPopulate)
{ {
ExhaustedToken?.Cancel(); throw new Exception("Failed to retrieve any download servers.");
return;
}
} }
} }
private Server BuildConnection(CancellationToken token) public Server GetConnection()
{ {
if (availableServerEndpoints.Count < ServerEndpointMinimumSize) return servers[nextServer % servers.Count];
{
populatePoolEvent.Set();
}
return availableServerEndpoints.Take(token);
}
public Server GetConnection(CancellationToken token)
{
if (!activeConnectionPool.TryPop(out var connection))
{
connection = BuildConnection(token);
}
return connection;
} }
public void ReturnConnection(Server server) public void ReturnConnection(Server server)
{ {
if (server == null) return; if (server == null) return;
activeConnectionPool.Push(server); // nothing to do, maybe remove from ContentServerPenalty?
} }
public void ReturnBrokenConnection(Server server) public void ReturnBrokenConnection(Server server)
{ {
if (server == null) return; if (server == null) return;
// Broken connections are not returned to the pool lock (servers)
{
if (servers[nextServer % servers.Count] == server)
{
nextServer++;
// TODO: Add server to ContentServerPenalty
}
}
} }
} }
} }

@ -333,12 +333,6 @@ namespace DepotDownloader
public static void ShutdownSteam3() public static void ShutdownSteam3()
{ {
if (cdnPool != null)
{
cdnPool.Shutdown();
cdnPool = null;
}
if (steam3 == null) if (steam3 == null)
return; return;
@ -660,9 +654,9 @@ namespace DepotDownloader
{ {
Ansi.Progress(Ansi.ProgressState.Indeterminate); Ansi.Progress(Ansi.ProgressState.Indeterminate);
var cts = new CancellationTokenSource(); await cdnPool.UpdateServerList();
cdnPool.ExhaustedToken = cts;
var cts = new CancellationTokenSource();
var downloadCounter = new GlobalDownloadCounter(); var downloadCounter = new GlobalDownloadCounter();
var depotsToDownload = new List<DepotFilesData>(depots.Count); var depotsToDownload = new List<DepotFilesData>(depots.Count);
var allFileNamesAllDepots = new HashSet<string>(); var allFileNamesAllDepots = new HashSet<string>();
@ -759,7 +753,7 @@ namespace DepotDownloader
try try
{ {
connection = cdnPool.GetConnection(cts.Token); connection = cdnPool.GetConnection();
string cdnToken = null; string cdnToken = null;
if (steam3.CDNAuthTokens.TryGetValue((depot.DepotId, connection.Host), out var authTokenCallbackPromise)) if (steam3.CDNAuthTokens.TryGetValue((depot.DepotId, connection.Host), out var authTokenCallbackPromise))
@ -1202,7 +1196,7 @@ namespace DepotDownloader
try try
{ {
connection = cdnPool.GetConnection(cts.Token); connection = cdnPool.GetConnection();
string cdnToken = null; string cdnToken = null;
if (steam3.CDNAuthTokens.TryGetValue((depot.DepotId, connection.Host), out var authTokenCallbackPromise)) if (steam3.CDNAuthTokens.TryGetValue((depot.DepotId, connection.Host), out var authTokenCallbackPromise))

Loading…
Cancel
Save