From ac97c011077f8476beee540f04c8a3e220744233 Mon Sep 17 00:00:00 2001 From: Ryan Kistner Date: Sat, 31 Oct 2020 19:41:49 -0600 Subject: [PATCH] Process download chunks as tasks to increase concurrency --- DepotDownloader/CDNClientPool.cs | 82 +-- DepotDownloader/ContentDownloader.cs | 847 ++++++++++++++++----------- DepotDownloader/Program.cs | 6 +- DepotDownloader/Steam3Session.cs | 16 +- DepotDownloader/Util.cs | 34 ++ README.md | 4 +- 6 files changed, 597 insertions(+), 392 deletions(-) diff --git a/DepotDownloader/CDNClientPool.cs b/DepotDownloader/CDNClientPool.cs index 654d3f54..120bb73f 100644 --- a/DepotDownloader/CDNClientPool.cs +++ b/DepotDownloader/CDNClientPool.cs @@ -17,10 +17,11 @@ namespace DepotDownloader private const int ServerEndpointMinimumSize = 8; private readonly Steam3Session steamSession; + private readonly uint appId; public CDNClient CDNClient { get; } - private readonly ConcurrentBag activeConnectionPool; + private readonly ConcurrentStack activeConnectionPool; private readonly BlockingCollection availableServerEndpoints; private readonly AutoResetEvent populatePoolEvent; @@ -28,12 +29,13 @@ namespace DepotDownloader private readonly CancellationTokenSource shutdownToken; public CancellationTokenSource ExhaustedToken { get; set; } - public CDNClientPool(Steam3Session steamSession) + public CDNClientPool(Steam3Session steamSession, uint appId) { this.steamSession = steamSession; + this.appId = appId; CDNClient = new CDNClient(steamSession.steamClient); - activeConnectionPool = new ConcurrentBag(); + activeConnectionPool = new ConcurrentStack(); availableServerEndpoints = new BlockingCollection(); populatePoolEvent = new AutoResetEvent(true); @@ -97,12 +99,23 @@ namespace DepotDownloader return; } - var weightedCdnServers = servers.Where(x => x.Type == "SteamCache" || x.Type == "CDN").Select(x => - { - AccountSettingsStore.Instance.ContentServerPenalty.TryGetValue(x.Host, out var penalty); + var weightedCdnServers = servers + .Where(x => + { +#if STEAMKIT_UNRELEASED + var isEligibleForApp = x.AllowedAppIds == null || x.AllowedAppIds.Contains(appId); + return isEligibleForApp && (x.Type == "SteamCache" || x.Type == "CDN"); +#else + return x.Type == "SteamCache" || x.Type == "CDN"; +#endif + }) + .Select(x => + { + AccountSettingsStore.Instance.ContentServerPenalty.TryGetValue(x.Host, out var penalty); - return Tuple.Create(x, penalty); - }).OrderBy(x => x.Item2).ThenBy(x => x.Item1.WeightedLoad); + return Tuple.Create(x, penalty); + }) + .OrderBy(x => x.Item2).ThenBy(x => x.Item1.WeightedLoad); foreach (var (server, weight) in weightedCdnServers) { @@ -122,24 +135,6 @@ namespace DepotDownloader } } - private async Task AuthenticateConnection(uint appId, uint depotId, CDNClient.Server server) - { - var host = steamSession.ResolveCDNTopLevelHost(server.Host); - var cdnKey = $"{depotId:D}:{host}"; - - steamSession.RequestCDNAuthToken(appId, depotId, host, cdnKey); - - if (steamSession.CDNAuthTokens.TryGetValue(cdnKey, out var authTokenCallbackPromise)) - { - var result = await authTokenCallbackPromise.Task; - return result.Token; - } - else - { - throw new Exception($"Failed to retrieve CDN token for server {server.Host} depot {depotId}"); - } - } - private CDNClient.Server BuildConnection(CancellationToken token) { if (availableServerEndpoints.Count < ServerEndpointMinimumSize) @@ -150,29 +145,42 @@ namespace DepotDownloader return availableServerEndpoints.Take(token); } - public async Task> GetConnectionForDepot(uint appId, uint depotId, CancellationToken token) + public CDNClient.Server GetConnection(CancellationToken token) { - // Take a free connection from the connection pool - // If there were no free connections, create a new one from the server list - if (!activeConnectionPool.TryTake(out var server)) + if (!activeConnectionPool.TryPop(out var connection)) { - server = BuildConnection(token); + connection = BuildConnection(token); } - // If we don't have a CDN token yet for this server and depot, fetch one now - var cdnToken = await AuthenticateConnection(appId, depotId, server); + return connection; + } + + public async Task AuthenticateConnection(uint appId, uint depotId, CDNClient.Server server) + { + var host = steamSession.ResolveCDNTopLevelHost(server.Host); + var cdnKey = $"{depotId:D}:{host}"; + + steamSession.RequestCDNAuthToken(appId, depotId, host, cdnKey); - return Tuple.Create(server, cdnToken); + if (steamSession.CDNAuthTokens.TryGetValue(cdnKey, out var authTokenCallbackPromise)) + { + var result = await authTokenCallbackPromise.Task; + return result.Token; + } + else + { + throw new Exception($"Failed to retrieve CDN token for server {server.Host} depot {depotId}"); + } } - public void ReturnConnection(Tuple server) + public void ReturnConnection(CDNClient.Server server) { if (server == null) return; - activeConnectionPool.Add(server.Item1); + activeConnectionPool.Push(server); } - public void ReturnBrokenConnection(Tuple server) + public void ReturnBrokenConnection(CDNClient.Server server) { if (server == null) return; diff --git a/DepotDownloader/ContentDownloader.cs b/DepotDownloader/ContentDownloader.cs index 49345ff8..fd277937 100644 --- a/DepotDownloader/ContentDownloader.cs +++ b/DepotDownloader/ContentDownloader.cs @@ -1,5 +1,6 @@ using SteamKit2; using System; +using System.Collections.Concurrent; using System.Collections.Generic; using System.IO; using System.Linq; @@ -361,7 +362,6 @@ namespace DepotDownloader return false; } - cdnPool = new CDNClientPool( steam3 ); return true; } @@ -396,6 +396,8 @@ namespace DepotDownloader public static async Task DownloadAppAsync( uint appId, uint depotId, ulong manifestId, string branch, string os, string arch, string language, bool lv, bool isUgc ) { + cdnPool = new CDNClientPool(steam3, appId); + // Load our configuration data containing the depots currently installed string configPath = ContentDownloader.Config.InstallDirectory; if (string.IsNullOrWhiteSpace(configPath)) @@ -595,454 +597,607 @@ namespace DepotDownloader public ProtoManifest.ChunkData NewChunk { get; private set; } } - private static async Task DownloadSteam3Async( uint appId, List depots ) + private class DepotFilesData + { + public DepotDownloadInfo depotDownloadInfo; + public DepotDownloadCounter depotCounter; + public string stagingDir; + public ProtoManifest manifest; + public ProtoManifest previousManifest; + public List filteredFiles; + public HashSet allFileNames; + } + + private class FileStreamData { - ulong TotalBytesCompressed = 0; - ulong TotalBytesUncompressed = 0; - var previousFiles = new List(); + public FileStream fileStream; + public SemaphoreSlim fileLock; + public int chunksToDownload; + } + + private class GlobalDownloadCounter + { + public ulong TotalBytesCompressed; + public ulong TotalBytesUncompressed; + } + + private class DepotDownloadCounter + { + public ulong CompleteDownloadSize; + public ulong SizeDownloaded; + public ulong DepotBytesCompressed; + public ulong DepotBytesUncompressed; + + } + + private static async Task DownloadSteam3Async(uint appId, List depots) + { + CancellationTokenSource cts = new CancellationTokenSource(); + cdnPool.ExhaustedToken = cts; + + GlobalDownloadCounter downloadCounter = new GlobalDownloadCounter(); + var depotsToDownload = new List(depots.Count); + var allFileNames = new HashSet(); - foreach ( var depot in depots ) + // First, fetch all the manifests for each depot (including previous manifests) and perform the initial setup + foreach (var depot in depots) { - ulong DepotBytesCompressed = 0; - ulong DepotBytesUncompressed = 0; + var depotFileData = await ProcessDepotManifestAndFiles(cts, appId, depot); - Console.WriteLine( "Downloading depot {0} - {1}", depot.id, depot.contentName ); + if (depotFileData != null) + { + depotsToDownload.Add(depotFileData); + allFileNames.UnionWith(depotFileData.allFileNames); + } + + cts.Token.ThrowIfCancellationRequested(); + } - CancellationTokenSource cts = new CancellationTokenSource(); - cdnPool.ExhaustedToken = cts; + foreach (var depotFileData in depotsToDownload) + { + await DownloadSteam3AsyncDepotFiles(cts, appId, downloadCounter, depotFileData, allFileNames); + } - ProtoManifest oldProtoManifest = null; - ProtoManifest newProtoManifest = null; - string configDir = Path.Combine( depot.installDir, CONFIG_DIR ); + Console.WriteLine("Total downloaded: {0} bytes ({1} bytes uncompressed) from {2} depots", + downloadCounter.TotalBytesCompressed, downloadCounter.TotalBytesUncompressed, depots.Count); + } - ulong lastManifestId = INVALID_MANIFEST_ID; - DepotConfigStore.Instance.InstalledManifestIDs.TryGetValue( depot.id, out lastManifestId ); + private static async Task ProcessDepotManifestAndFiles(CancellationTokenSource cts, + uint appId, DepotDownloadInfo depot) + { + DepotDownloadCounter depotCounter = new DepotDownloadCounter(); - // In case we have an early exit, this will force equiv of verifyall next run. - DepotConfigStore.Instance.InstalledManifestIDs[ depot.id ] = INVALID_MANIFEST_ID; - DepotConfigStore.Save(); + Console.WriteLine("Processing depot {0} - {1}", depot.id, depot.contentName); - if ( lastManifestId != INVALID_MANIFEST_ID ) + ProtoManifest oldProtoManifest = null; + ProtoManifest newProtoManifest = null; + string configDir = Path.Combine(depot.installDir, CONFIG_DIR); + + ulong lastManifestId = INVALID_MANIFEST_ID; + DepotConfigStore.Instance.InstalledManifestIDs.TryGetValue(depot.id, out lastManifestId); + + // In case we have an early exit, this will force equiv of verifyall next run. + DepotConfigStore.Instance.InstalledManifestIDs[depot.id] = INVALID_MANIFEST_ID; + DepotConfigStore.Save(); + + if (lastManifestId != INVALID_MANIFEST_ID) + { + var oldManifestFileName = Path.Combine(configDir, string.Format("{0}_{1}.bin", depot.id, lastManifestId)); + + if (File.Exists(oldManifestFileName)) { - var oldManifestFileName = Path.Combine( configDir, string.Format( "{0}_{1}.bin", depot.id, lastManifestId ) ); + byte[] expectedChecksum, currentChecksum; - if (File.Exists(oldManifestFileName)) + try + { + expectedChecksum = File.ReadAllBytes(oldManifestFileName + ".sha"); + } + catch (IOException) { - byte[] expectedChecksum, currentChecksum; + expectedChecksum = null; + } - try - { - expectedChecksum = File.ReadAllBytes(oldManifestFileName + ".sha"); - } - catch (IOException) - { - expectedChecksum = null; - } + oldProtoManifest = ProtoManifest.LoadFromFile(oldManifestFileName, out currentChecksum); - oldProtoManifest = ProtoManifest.LoadFromFile(oldManifestFileName, out currentChecksum); + if (expectedChecksum == null || !expectedChecksum.SequenceEqual(currentChecksum)) + { + // We only have to show this warning if the old manifest ID was different + if (lastManifestId != depot.manifestId) + Console.WriteLine("Manifest {0} on disk did not match the expected checksum.", lastManifestId); + oldProtoManifest = null; + } + } + } - if (expectedChecksum == null || !expectedChecksum.SequenceEqual(currentChecksum)) - { - // We only have to show this warning if the old manifest ID was different - if (lastManifestId != depot.manifestId) - Console.WriteLine("Manifest {0} on disk did not match the expected checksum.", lastManifestId); - oldProtoManifest = null; - } + if (lastManifestId == depot.manifestId && oldProtoManifest != null) + { + newProtoManifest = oldProtoManifest; + Console.WriteLine("Already have manifest {0} for depot {1}.", depot.manifestId, depot.id); + } + else + { + var newManifestFileName = Path.Combine(configDir, string.Format("{0}_{1}.bin", depot.id, depot.manifestId)); + if (newManifestFileName != null) + { + byte[] expectedChecksum, currentChecksum; + + try + { + expectedChecksum = File.ReadAllBytes(newManifestFileName + ".sha"); + } + catch (IOException) + { + expectedChecksum = null; + } + + newProtoManifest = ProtoManifest.LoadFromFile(newManifestFileName, out currentChecksum); + + if (newProtoManifest != null && (expectedChecksum == null || !expectedChecksum.SequenceEqual(currentChecksum))) + { + Console.WriteLine("Manifest {0} on disk did not match the expected checksum.", depot.manifestId); + newProtoManifest = null; } } - if ( lastManifestId == depot.manifestId && oldProtoManifest != null ) + if (newProtoManifest != null) { - newProtoManifest = oldProtoManifest; - Console.WriteLine( "Already have manifest {0} for depot {1}.", depot.manifestId, depot.id ); + Console.WriteLine("Already have manifest {0} for depot {1}.", depot.manifestId, depot.id); } else { - var newManifestFileName = Path.Combine( configDir, string.Format( "{0}_{1}.bin", depot.id, depot.manifestId ) ); - if ( newManifestFileName != null ) + Console.Write("Downloading depot manifest..."); + + DepotManifest depotManifest = null; + + do { - byte[] expectedChecksum, currentChecksum; + cts.Token.ThrowIfCancellationRequested(); + + CDNClient.Server connection = null; try { - expectedChecksum = File.ReadAllBytes(newManifestFileName + ".sha"); + connection = cdnPool.GetConnection(cts.Token); + var cdnToken = await cdnPool.AuthenticateConnection(appId, depot.id, connection); + + depotManifest = await cdnPool.CDNClient.DownloadManifestAsync(depot.id, depot.manifestId, + connection, cdnToken, depot.depotKey).ConfigureAwait(false); + + cdnPool.ReturnConnection(connection); } - catch (IOException) + catch (TaskCanceledException) { - expectedChecksum = null; + Console.WriteLine("Connection timeout downloading depot manifest {0} {1}", depot.id, depot.manifestId); } + catch (SteamKitWebRequestException e) + { + cdnPool.ReturnBrokenConnection(connection); - newProtoManifest = ProtoManifest.LoadFromFile(newManifestFileName, out currentChecksum); - - if (newProtoManifest != null && (expectedChecksum == null || !expectedChecksum.SequenceEqual(currentChecksum))) + if (e.StatusCode == HttpStatusCode.Unauthorized || e.StatusCode == HttpStatusCode.Forbidden) + { + Console.WriteLine("Encountered 401 for depot manifest {0} {1}. Aborting.", depot.id, depot.manifestId); + break; + } + else + { + Console.WriteLine("Encountered error downloading depot manifest {0} {1}: {2}", depot.id, depot.manifestId, e.StatusCode); + } + } + catch (OperationCanceledException) + { + break; + } + catch (Exception e) { - Console.WriteLine("Manifest {0} on disk did not match the expected checksum.", depot.manifestId); - newProtoManifest = null; + cdnPool.ReturnBrokenConnection(connection); + Console.WriteLine("Encountered error downloading manifest for depot {0} {1}: {2}", depot.id, depot.manifestId, e.Message); } } + while (depotManifest == null); - if ( newProtoManifest != null ) + if (depotManifest == null) { - Console.WriteLine( "Already have manifest {0} for depot {1}.", depot.manifestId, depot.id ); + Console.WriteLine("\nUnable to download manifest {0} for depot {1}", depot.manifestId, depot.id); + cts.Cancel(); } - else - { - Console.Write( "Downloading depot manifest..." ); - DepotManifest depotManifest = null; + // Throw the cancellation exception if requested so that this task is marked failed + cts.Token.ThrowIfCancellationRequested(); - while ( depotManifest == null ) - { - Tuple connection = null; - try - { - connection = await cdnPool.GetConnectionForDepot( appId, depot.id, CancellationToken.None ); + byte[] checksum; - depotManifest = await cdnPool.CDNClient.DownloadManifestAsync( depot.id, depot.manifestId, - connection.Item1, connection.Item2, depot.depotKey ).ConfigureAwait(false); + newProtoManifest = new ProtoManifest(depotManifest, depot.manifestId); + newProtoManifest.SaveToFile(newManifestFileName, out checksum); + File.WriteAllBytes(newManifestFileName + ".sha", checksum); - cdnPool.ReturnConnection( connection ); - } - catch ( SteamKitWebRequestException e ) - { - cdnPool.ReturnBrokenConnection( connection ); + Console.WriteLine(" Done!"); + } + } - if ( e.StatusCode == HttpStatusCode.Unauthorized || e.StatusCode == HttpStatusCode.Forbidden ) - { - Console.WriteLine( "Encountered 401 for depot manifest {0} {1}. Aborting.", depot.id, depot.manifestId ); - break; - } - else - { - Console.WriteLine( "Encountered error downloading depot manifest {0} {1}: {2}", depot.id, depot.manifestId, e.StatusCode ); - } - } - catch ( Exception e ) - { - cdnPool.ReturnBrokenConnection( connection ); - Console.WriteLine( "Encountered error downloading manifest for depot {0} {1}: {2}", depot.id, depot.manifestId, e.Message ); - } - } + newProtoManifest.Files.Sort((x, y) => string.Compare(x.FileName, y.FileName, StringComparison.Ordinal)); - if ( depotManifest == null ) - { - Console.WriteLine( "\nUnable to download manifest {0} for depot {1}", depot.manifestId, depot.id ); - return; - } + Console.WriteLine("Manifest {0} ({1})", depot.manifestId, newProtoManifest.CreationTime); - byte[] checksum; + if (Config.DownloadManifestOnly) + { + StringBuilder manifestBuilder = new StringBuilder(); + string txtManifest = Path.Combine(depot.installDir, string.Format("manifest_{0}_{1}.txt", depot.id, depot.manifestId)); + manifestBuilder.Append(string.Format("{0}\n\n", newProtoManifest.CreationTime)); - newProtoManifest = new ProtoManifest( depotManifest, depot.manifestId ); - newProtoManifest.SaveToFile( newManifestFileName, out checksum ); - File.WriteAllBytes( newManifestFileName + ".sha", checksum ); + foreach (var file in newProtoManifest.Files) + { + if (file.Flags.HasFlag(EDepotFileFlag.Directory)) + continue; - Console.WriteLine( " Done!" ); - } + manifestBuilder.Append(string.Format("{0}\n", file.FileName)); + manifestBuilder.Append(string.Format("\t{0}\n", file.TotalSize)); + manifestBuilder.Append(string.Format("\t{0}\n", BitConverter.ToString(file.FileHash).Replace("-", ""))); } - newProtoManifest.Files.Sort( ( x, y ) => string.Compare( x.FileName, y.FileName, StringComparison.Ordinal ) ); + File.WriteAllText(txtManifest, manifestBuilder.ToString()); + return null; + } - Console.WriteLine( "Manifest {0} ({1})", depot.manifestId, newProtoManifest.CreationTime ); + string stagingDir = Path.Combine(depot.installDir, STAGING_DIR); - if ( Config.DownloadManifestOnly ) - { - StringBuilder manifestBuilder = new StringBuilder(); - string txtManifest = Path.Combine( depot.installDir, string.Format( "manifest_{0}_{1}.txt", depot.id, depot.manifestId ) ); - manifestBuilder.Append( string.Format( "{0}\n\n", newProtoManifest.CreationTime ) ); + var filesAfterExclusions = newProtoManifest.Files.AsParallel().Where(f => TestIsFileIncluded(f.FileName)).ToList(); + var allFileNames = new HashSet(filesAfterExclusions.Count); - foreach ( var file in newProtoManifest.Files ) - { - if ( file.Flags.HasFlag( EDepotFileFlag.Directory ) ) - continue; + // Pre-process + filesAfterExclusions.ForEach(file => + { + allFileNames.Add(file.FileName); - manifestBuilder.Append( string.Format( "{0}\n", file.FileName ) ); - manifestBuilder.Append( string.Format( "\t{0}\n", file.TotalSize ) ); - manifestBuilder.Append( string.Format( "\t{0}\n", BitConverter.ToString( file.FileHash ).Replace( "-", "" ) ) ); - } + var fileFinalPath = Path.Combine(depot.installDir, file.FileName); + var fileStagingPath = Path.Combine(stagingDir, file.FileName); - File.WriteAllText( txtManifest, manifestBuilder.ToString() ); - continue; + if (file.Flags.HasFlag(EDepotFileFlag.Directory)) + { + Directory.CreateDirectory(fileFinalPath); + Directory.CreateDirectory(fileStagingPath); } + else + { + // Some manifests don't explicitly include all necessary directories + Directory.CreateDirectory(Path.GetDirectoryName(fileFinalPath)); + Directory.CreateDirectory(Path.GetDirectoryName(fileStagingPath)); - ulong complete_download_size = 0; - ulong size_downloaded = 0; - string stagingDir = Path.Combine( depot.installDir, STAGING_DIR ); + depotCounter.CompleteDownloadSize += file.TotalSize; + } + }); - var filesAfterExclusions = newProtoManifest.Files.AsParallel().Where( f => TestIsFileIncluded( f.FileName ) ).ToList(); + return new DepotFilesData + { + depotDownloadInfo = depot, + depotCounter = depotCounter, + stagingDir = stagingDir, + manifest = newProtoManifest, + previousManifest = oldProtoManifest, + filteredFiles = filesAfterExclusions, + allFileNames = allFileNames + }; + } + + private static async Task DownloadSteam3AsyncDepotFiles(CancellationTokenSource cts, uint appId, + GlobalDownloadCounter downloadCounter, DepotFilesData depotFilesData, HashSet allFileNames) + { + var depot = depotFilesData.depotDownloadInfo; + var depotCounter = depotFilesData.depotCounter; + + Console.WriteLine("Downloading depot {0} - {1}", depot.id, depot.contentName); + + var files = depotFilesData.filteredFiles.Where(f => !f.Flags.HasFlag(EDepotFileFlag.Directory)).ToArray(); + var networkChunkQueue = new ConcurrentQueue>(); + + await Util.InvokeAsync( + files.Select(file => new Func(async () => + await Task.Run(() => DownloadSteam3AsyncDepotFile(cts, depotFilesData, file, networkChunkQueue)))), + maxDegreeOfParallelism: Config.MaxDownloads + ); + + await Util.InvokeAsync( + networkChunkQueue.Select((x) => new Func(async () => + await Task.Run(() => DownloadSteam3AsyncDepotFileChunk(cts, appId, downloadCounter, depotFilesData, + x.Item2, x.Item1, x.Item3)))), + maxDegreeOfParallelism: Config.MaxDownloads + ); - // Pre-process - filesAfterExclusions.ForEach( file => + // Check for deleted files if updating the depot. + if (depotFilesData.previousManifest != null) + { + var previousFilteredFiles = depotFilesData.previousManifest.Files.AsParallel().Where(f => TestIsFileIncluded(f.FileName)).Select(f => f.FileName).ToHashSet(); + + // Of the list of files in the previous manifest, remove any file names that exist in the current set of all file names across all depots being downloaded + previousFilteredFiles.ExceptWith(allFileNames); + + foreach(var existingFileName in previousFilteredFiles) { - var fileFinalPath = Path.Combine( depot.installDir, file.FileName ); - var fileStagingPath = Path.Combine( stagingDir, file.FileName ); + string fileFinalPath = Path.Combine(depot.installDir, existingFileName); - if ( file.Flags.HasFlag( EDepotFileFlag.Directory ) ) - { - Directory.CreateDirectory( fileFinalPath ); - Directory.CreateDirectory( fileStagingPath ); - } - else - { - // Some manifests don't explicitly include all necessary directories - Directory.CreateDirectory( Path.GetDirectoryName( fileFinalPath ) ); - Directory.CreateDirectory( Path.GetDirectoryName( fileStagingPath ) ); + if (!File.Exists(fileFinalPath)) + continue; - complete_download_size += file.TotalSize; - } - } ); + File.Delete(fileFinalPath); + Console.WriteLine("Deleted {0}", fileFinalPath); + } + } + + DepotConfigStore.Instance.InstalledManifestIDs[depot.id] = depot.manifestId; + DepotConfigStore.Save(); + + Console.WriteLine("Depot {0} - Downloaded {1} bytes ({2} bytes uncompressed)", depot.id, depotCounter.DepotBytesCompressed, depotCounter.DepotBytesUncompressed); + } + + private static void DownloadSteam3AsyncDepotFile( + CancellationTokenSource cts, + DepotFilesData depotFilesData, + ProtoManifest.FileData file, + ConcurrentQueue> networkChunkQueue) + { + cts.Token.ThrowIfCancellationRequested(); + + var depot = depotFilesData.depotDownloadInfo; + var stagingDir = depotFilesData.stagingDir; + var depotDownloadCounter = depotFilesData.depotCounter; + var oldProtoManifest = depotFilesData.previousManifest; + + string fileFinalPath = Path.Combine(depot.installDir, file.FileName); + string fileStagingPath = Path.Combine(stagingDir, file.FileName); + + // This may still exist if the previous run exited before cleanup + if (File.Exists(fileStagingPath)) + { + File.Delete(fileStagingPath); + } + + FileStream fs = null; + List neededChunks; + FileInfo fi = new FileInfo(fileFinalPath); + if (!fi.Exists) + { + Console.WriteLine("Pre-allocating {0}", fileFinalPath); - var semaphore = new SemaphoreSlim( Config.MaxDownloads ); - var files = filesAfterExclusions.Where( f => !f.Flags.HasFlag( EDepotFileFlag.Directory ) ).ToArray(); - var tasks = new Task[ files.Length ]; - for ( var i = 0; i < files.Length; i++ ) + // create new file. need all chunks + fs = File.Create(fileFinalPath); + fs.SetLength((long)file.TotalSize); + neededChunks = new List(file.Chunks); + } + else + { + // open existing + ProtoManifest.FileData oldManifestFile = null; + if (oldProtoManifest != null) { - var file = files[ i ]; - var task = Task.Run( async () => + oldManifestFile = oldProtoManifest.Files.SingleOrDefault(f => f.FileName == file.FileName); + } + + if (oldManifestFile != null) + { + neededChunks = new List(); + + if (Config.VerifyAll || !oldManifestFile.FileHash.SequenceEqual(file.FileHash)) { - cts.Token.ThrowIfCancellationRequested(); - - try + // we have a version of this file, but it doesn't fully match what we want + if (Config.VerifyAll) { - await semaphore.WaitAsync().ConfigureAwait( false ); - cts.Token.ThrowIfCancellationRequested(); + Console.WriteLine("Validating {0}", fileFinalPath); + } - string fileFinalPath = Path.Combine( depot.installDir, file.FileName ); - string fileStagingPath = Path.Combine( stagingDir, file.FileName ); + var matchingChunks = new List(); - // This may still exist if the previous run exited before cleanup - if ( File.Exists( fileStagingPath ) ) - { - File.Delete( fileStagingPath ); - } - - FileStream fs = null; - List neededChunks; - FileInfo fi = new FileInfo( fileFinalPath ); - if ( !fi.Exists ) + foreach (var chunk in file.Chunks) + { + var oldChunk = oldManifestFile.Chunks.FirstOrDefault(c => c.ChunkID.SequenceEqual(chunk.ChunkID)); + if (oldChunk != null) { - // create new file. need all chunks - fs = File.Create( fileFinalPath ); - fs.SetLength( ( long )file.TotalSize ); - neededChunks = new List( file.Chunks ); + matchingChunks.Add(new ChunkMatch(oldChunk, chunk)); } else { - // open existing - ProtoManifest.FileData oldManifestFile = null; - if ( oldProtoManifest != null ) - { - oldManifestFile = oldProtoManifest.Files.SingleOrDefault( f => f.FileName == file.FileName ); - } + neededChunks.Add(chunk); + } + } - if ( oldManifestFile != null ) - { - neededChunks = new List(); - - if ( Config.VerifyAll || !oldManifestFile.FileHash.SequenceEqual( file.FileHash ) ) - { - // we have a version of this file, but it doesn't fully match what we want - - var matchingChunks = new List(); - - foreach ( var chunk in file.Chunks ) - { - var oldChunk = oldManifestFile.Chunks.FirstOrDefault( c => c.ChunkID.SequenceEqual( chunk.ChunkID ) ); - if ( oldChunk != null ) - { - matchingChunks.Add( new ChunkMatch( oldChunk, chunk ) ); - } - else - { - neededChunks.Add( chunk ); - } - } - - File.Move( fileFinalPath, fileStagingPath ); - - fs = File.Open( fileFinalPath, FileMode.Create ); - fs.SetLength( ( long )file.TotalSize ); - - using ( var fsOld = File.Open( fileStagingPath, FileMode.Open ) ) - { - foreach ( var match in matchingChunks ) - { - fsOld.Seek( ( long )match.OldChunk.Offset, SeekOrigin.Begin ); - - byte[] tmp = new byte[ match.OldChunk.UncompressedLength ]; - fsOld.Read( tmp, 0, tmp.Length ); - - byte[] adler = Util.AdlerHash( tmp ); - if ( !adler.SequenceEqual( match.OldChunk.Checksum ) ) - { - neededChunks.Add( match.NewChunk ); - } - else - { - fs.Seek( ( long )match.NewChunk.Offset, SeekOrigin.Begin ); - fs.Write( tmp, 0, tmp.Length ); - } - } - } - - File.Delete( fileStagingPath ); - } - } - else - { - // No old manifest or file not in old manifest. We must validate. + var orderedChunks = matchingChunks.OrderBy(x => x.OldChunk.Offset); - fs = File.Open( fileFinalPath, FileMode.Open ); - if ( ( ulong )fi.Length != file.TotalSize ) - { - fs.SetLength( ( long )file.TotalSize ); - } + File.Move(fileFinalPath, fileStagingPath); - neededChunks = Util.ValidateSteam3FileChecksums( fs, file.Chunks.OrderBy( x => x.Offset ).ToArray() ); - } + fs = File.Open(fileFinalPath, FileMode.Create); + fs.SetLength((long)file.TotalSize); + + using (var fsOld = File.Open(fileStagingPath, FileMode.Open)) + { + foreach (var match in orderedChunks) + { + fsOld.Seek((long)match.OldChunk.Offset, SeekOrigin.Begin); + + byte[] tmp = new byte[match.OldChunk.UncompressedLength]; + fsOld.Read(tmp, 0, tmp.Length); - if ( neededChunks.Count() == 0 ) + byte[] adler = Util.AdlerHash(tmp); + if (!adler.SequenceEqual(match.OldChunk.Checksum)) { - size_downloaded += file.TotalSize; - Console.WriteLine( "{0,6:#00.00}% {1}", ( ( float )size_downloaded / ( float )complete_download_size ) * 100.0f, fileFinalPath ); - if ( fs != null ) - fs.Dispose(); - return; + neededChunks.Add(match.NewChunk); } else { - size_downloaded += ( file.TotalSize - ( ulong )neededChunks.Select( x => ( long )x.UncompressedLength ).Sum() ); + fs.Seek((long)match.NewChunk.Offset, SeekOrigin.Begin); + fs.Write(tmp, 0, tmp.Length); } } + } - foreach ( var chunk in neededChunks ) - { - if ( cts.IsCancellationRequested ) break; + File.Delete(fileStagingPath); + } + } + else + { + // No old manifest or file not in old manifest. We must validate. - string chunkID = Util.EncodeHexString( chunk.ChunkID ); - CDNClient.DepotChunk chunkData = null; + fs = File.Open(fileFinalPath, FileMode.Open); + if ((ulong)fi.Length != file.TotalSize) + { + fs.SetLength((long)file.TotalSize); + } - while ( !cts.IsCancellationRequested ) - { - Tuple connection; - try - { - connection = await cdnPool.GetConnectionForDepot( appId, depot.id, cts.Token ); - } - catch ( OperationCanceledException ) - { - break; - } - - DepotManifest.ChunkData data = new DepotManifest.ChunkData(); - data.ChunkID = chunk.ChunkID; - data.Checksum = chunk.Checksum; - data.Offset = chunk.Offset; - data.CompressedLength = chunk.CompressedLength; - data.UncompressedLength = chunk.UncompressedLength; - - try - { - chunkData = await cdnPool.CDNClient.DownloadDepotChunkAsync( depot.id, data, - connection.Item1, connection.Item2, depot.depotKey ).ConfigureAwait( false ); - cdnPool.ReturnConnection( connection ); - break; - } - catch ( SteamKitWebRequestException e ) - { - cdnPool.ReturnBrokenConnection( connection ); - - if ( e.StatusCode == HttpStatusCode.Unauthorized || e.StatusCode == HttpStatusCode.Forbidden ) - { - Console.WriteLine( "Encountered 401 for chunk {0}. Aborting.", chunkID ); - cts.Cancel(); - break; - } - else - { - Console.WriteLine( "Encountered error downloading chunk {0}: {1}", chunkID, e.StatusCode ); - } - } - catch ( TaskCanceledException ) - { - Console.WriteLine( "Connection timeout downloading chunk {0}", chunkID ); - } - catch ( Exception e ) - { - cdnPool.ReturnBrokenConnection( connection ); - Console.WriteLine( "Encountered unexpected error downloading chunk {0}: {1}", chunkID, e.Message ); - } - } + Console.WriteLine("Validating {0}", fileFinalPath); + neededChunks = Util.ValidateSteam3FileChecksums(fs, file.Chunks.OrderBy(x => x.Offset).ToArray()); + } - if ( chunkData == null ) - { - Console.WriteLine( "Failed to find any server with chunk {0} for depot {1}. Aborting.", chunkID, depot.id ); - cts.Cancel(); - } + if (neededChunks.Count() == 0) + { + lock (depotDownloadCounter) + { + depotDownloadCounter.SizeDownloaded += (ulong)file.TotalSize; + Console.WriteLine("{0,6:#00.00}% {1}", ((float)depotDownloadCounter.SizeDownloaded / (float)depotDownloadCounter.CompleteDownloadSize) * 100.0f, fileFinalPath); + } - // Throw the cancellation exception if requested so that this task is marked failed - cts.Token.ThrowIfCancellationRequested(); + if (fs != null) + fs.Dispose(); + return; + } + else + { + var sizeOnDisk = (file.TotalSize - (ulong)neededChunks.Select(x => (long)x.UncompressedLength).Sum()); + lock (depotDownloadCounter) + { + depotDownloadCounter.SizeDownloaded += sizeOnDisk; + } + } + } - TotalBytesCompressed += chunk.CompressedLength; - DepotBytesCompressed += chunk.CompressedLength; - TotalBytesUncompressed += chunk.UncompressedLength; - DepotBytesUncompressed += chunk.UncompressedLength; + FileStreamData fileStreamData = new FileStreamData + { + fileStream = fs, + fileLock = new SemaphoreSlim(1), + chunksToDownload = neededChunks.Count + }; - fs.Seek( ( long )chunk.Offset, SeekOrigin.Begin ); - fs.Write( chunkData.Data, 0, chunkData.Data.Length ); + foreach (var chunk in neededChunks) + { + networkChunkQueue.Enqueue(Tuple.Create(fileStreamData, file, chunk)); + } + } - size_downloaded += chunk.UncompressedLength; - } + private static async Task DownloadSteam3AsyncDepotFileChunk( + CancellationTokenSource cts, uint appId, + GlobalDownloadCounter downloadCounter, + DepotFilesData depotFilesData, + ProtoManifest.FileData file, + FileStreamData fileStreamData, + ProtoManifest.ChunkData chunk) + { + cts.Token.ThrowIfCancellationRequested(); - fs.Dispose(); + var depot = depotFilesData.depotDownloadInfo; + var depotDownloadCounter = depotFilesData.depotCounter; - Console.WriteLine( "{0,6:#00.00}% {1}", ( ( float )size_downloaded / ( float )complete_download_size ) * 100.0f, fileFinalPath ); - } - finally - { - semaphore.Release(); - } - } ); + string chunkID = Util.EncodeHexString(chunk.ChunkID); - tasks[ i ] = task; - } + DepotManifest.ChunkData data = new DepotManifest.ChunkData(); + data.ChunkID = chunk.ChunkID; + data.Checksum = chunk.Checksum; + data.Offset = chunk.Offset; + data.CompressedLength = chunk.CompressedLength; + data.UncompressedLength = chunk.UncompressedLength; - await Task.WhenAll( tasks ).ConfigureAwait( false ); + CDNClient.DepotChunk chunkData = null; - // Check for deleted files if updating the depot. - if ( oldProtoManifest != null ) - { - var oldfilesAfterExclusions = oldProtoManifest.Files.AsParallel().Where( f => TestIsFileIncluded( f.FileName ) ).ToList(); + do + { + cts.Token.ThrowIfCancellationRequested(); - foreach ( var file in oldfilesAfterExclusions ) - { - // Delete it if it's in the old manifest AND not in the new manifest AND not in any of the previous depots. - var newManifestFile = filesAfterExclusions.SingleOrDefault( f => f.FileName == file.FileName ); - if ( newManifestFile == null ) - continue; + CDNClient.Server connection = null; - var previousFile = previousFiles.SingleOrDefault( f => f.FileName == file.FileName ); - if ( previousFile == null ) - continue; + try + { + connection = cdnPool.GetConnection(cts.Token); + var cdnToken = await cdnPool.AuthenticateConnection(appId, depot.id, connection); - string fileFinalPath = Path.Combine( depot.installDir, file.FileName ); - if ( !File.Exists( fileFinalPath ) ) - continue; + chunkData = await cdnPool.CDNClient.DownloadDepotChunkAsync(depot.id, data, + connection, cdnToken, depot.depotKey).ConfigureAwait(false); - File.Delete( fileFinalPath ); - Console.WriteLine( "Deleted {0}", fileFinalPath ); + cdnPool.ReturnConnection(connection); + } + catch (TaskCanceledException) + { + Console.WriteLine("Connection timeout downloading chunk {0}", chunkID); + } + catch (SteamKitWebRequestException e) + { + cdnPool.ReturnBrokenConnection(connection); + + if (e.StatusCode == HttpStatusCode.Unauthorized || e.StatusCode == HttpStatusCode.Forbidden) + { + Console.WriteLine("Encountered 401 for chunk {0}. Aborting.", chunkID); + break; + } + else + { + Console.WriteLine("Encountered error downloading chunk {0}: {1}", chunkID, e.StatusCode); } } + catch (OperationCanceledException) + { + break; + } + catch (Exception e) + { + cdnPool.ReturnBrokenConnection(connection); + Console.WriteLine("Encountered unexpected error downloading chunk {0}: {1}", chunkID, e.Message); + } + } + while (chunkData == null); - // Remember files we processed for later. - previousFiles.AddRange( filesAfterExclusions ); + if (chunkData == null) + { + Console.WriteLine("Failed to find any server with chunk {0} for depot {1}. Aborting.", chunkID, depot.id); + cts.Cancel(); + } - DepotConfigStore.Instance.InstalledManifestIDs[ depot.id ] = depot.manifestId; - DepotConfigStore.Save(); + // Throw the cancellation exception if requested so that this task is marked failed + cts.Token.ThrowIfCancellationRequested(); - Console.WriteLine( "Depot {0} - Downloaded {1} bytes ({2} bytes uncompressed)", depot.id, DepotBytesCompressed, DepotBytesUncompressed ); + try + { + await fileStreamData.fileLock.WaitAsync().ConfigureAwait(false); + + fileStreamData.fileStream.Seek((long)chunkData.ChunkInfo.Offset, SeekOrigin.Begin); + await fileStreamData.fileStream.WriteAsync(chunkData.Data, 0, chunkData.Data.Length); + } + finally + { + fileStreamData.fileLock.Release(); + } + + int remainingChunks = Interlocked.Decrement(ref fileStreamData.chunksToDownload); + if (remainingChunks == 0) + { + fileStreamData.fileStream.Dispose(); + fileStreamData.fileLock.Dispose(); + } + + ulong sizeDownloaded = 0; + lock (depotDownloadCounter) + { + sizeDownloaded = depotDownloadCounter.SizeDownloaded + (ulong)chunkData.Data.Length; + depotDownloadCounter.SizeDownloaded = sizeDownloaded; + depotDownloadCounter.DepotBytesCompressed += chunk.CompressedLength; + depotDownloadCounter.DepotBytesUncompressed += chunk.UncompressedLength; + } + + lock (downloadCounter) + { + downloadCounter.TotalBytesCompressed += chunk.CompressedLength; + downloadCounter.TotalBytesUncompressed += chunk.UncompressedLength; + } + + if (remainingChunks == 0) + { + var fileFinalPath = Path.Combine(depot.installDir, file.FileName); + Console.WriteLine("{0,6:#00.00}% {1}", ((float)sizeDownloaded / (float)depotDownloadCounter.CompleteDownloadSize) * 100.0f, fileFinalPath); } - Console.WriteLine( "Total downloaded: {0} bytes ({1} bytes uncompressed) from {2} depots", TotalBytesCompressed, TotalBytesUncompressed, depots.Count ); } } } diff --git a/DepotDownloader/Program.cs b/DepotDownloader/Program.cs index b3b9b737..d2a77266 100644 --- a/DepotDownloader/Program.cs +++ b/DepotDownloader/Program.cs @@ -108,7 +108,7 @@ namespace DepotDownloader ContentDownloader.Config.VerifyAll = HasParameter( args, "-verify-all" ) || HasParameter( args, "-verify_all" ) || HasParameter( args, "-validate" ); ContentDownloader.Config.MaxServers = GetParameter( args, "-max-servers", 20 ); - ContentDownloader.Config.MaxDownloads = GetParameter( args, "-max-downloads", 4 ); + ContentDownloader.Config.MaxDownloads = GetParameter( args, "-max-downloads", 8 ); ContentDownloader.Config.MaxServers = Math.Max( ContentDownloader.Config.MaxServers, ContentDownloader.Config.MaxDownloads ); ContentDownloader.Config.LoginID = HasParameter( args, "-loginid" ) ? (uint?)GetParameter( args, "-loginid" ) : null; @@ -341,8 +341,8 @@ namespace DepotDownloader Console.WriteLine(); Console.WriteLine( "\t-manifest-only\t\t\t- downloads a human readable manifest for any depots that would be downloaded." ); Console.WriteLine( "\t-cellid <#>\t\t\t\t- the overridden CellID of the content server to download from." ); - Console.WriteLine( "\t-max-servers <#>\t\t- maximum number of content servers to use. (default: 8)." ); - Console.WriteLine( "\t-max-downloads <#>\t\t- maximum number of chunks to download concurrently. (default: 4)." ); + Console.WriteLine( "\t-max-servers <#>\t\t- maximum number of content servers to use. (default: 20)." ); + Console.WriteLine( "\t-max-downloads <#>\t\t- maximum number of chunks to download concurrently. (default: 8)." ); Console.WriteLine( "\t-loginid <#>\t\t- a unique 32-bit integer Steam LogonID in decimal, required if running multiple instances of DepotDownloader concurrently." ); } } diff --git a/DepotDownloader/Steam3Session.cs b/DepotDownloader/Steam3Session.cs index f622bca1..7f33b65e 100644 --- a/DepotDownloader/Steam3Session.cs +++ b/DepotDownloader/Steam3Session.cs @@ -129,16 +129,24 @@ namespace DepotDownloader } public delegate bool WaitCondition(); + private object steamLock = new object(); + public bool WaitUntilCallback( Action submitter, WaitCondition waiter ) { while ( !bAborted && !waiter() ) { - submitter(); + lock (steamLock) + { + submitter(); + } int seq = this.seq; do { - WaitForCallbacks(); + lock (steamLock) + { + WaitForCallbacks(); + } } while ( !bAborted && this.seq == seq && !waiter() ); } @@ -473,7 +481,7 @@ namespace DepotDownloader public void TryWaitForLoginKey() { - if ( logonDetails.Username == null || !ContentDownloader.Config.RememberPassword ) return; + if ( logonDetails.Username == null || !credentials.LoggedOn || !ContentDownloader.Config.RememberPassword ) return; var totalWaitPeriod = DateTime.Now.AddSeconds( 3 ); @@ -584,7 +592,7 @@ namespace DepotDownloader } else { - Console.WriteLine( "Login key was expired. Please enter your password: " ); + Console.Write( "Login key was expired. Please enter your password: " ); logonDetails.Password = Util.ReadPassword(); } } diff --git a/DepotDownloader/Util.cs b/DepotDownloader/Util.cs index 1c08138b..7d866487 100644 --- a/DepotDownloader/Util.cs +++ b/DepotDownloader/Util.cs @@ -5,6 +5,7 @@ using System.Linq; using System.Runtime.InteropServices; using System.Security.Cryptography; using System.Text; +using System.Threading.Tasks; namespace DepotDownloader { @@ -132,5 +133,38 @@ namespace DepotDownloader ( sb, v ) => sb.Append( v.ToString( "x2" ) ) ).ToString(); } + + public static async Task InvokeAsync(IEnumerable> taskFactories, int maxDegreeOfParallelism) + { + if (taskFactories == null) throw new ArgumentNullException(nameof(taskFactories)); + if (maxDegreeOfParallelism <= 0) throw new ArgumentException(nameof(maxDegreeOfParallelism)); + + Func[] queue = taskFactories.ToArray(); + + if (queue.Length == 0) + { + return; + } + + List tasksInFlight = new List(maxDegreeOfParallelism); + int index = 0; + + do + { + while (tasksInFlight.Count < maxDegreeOfParallelism && index < queue.Length) + { + Func taskFactory = queue[index++]; + + tasksInFlight.Add(taskFactory()); + } + + Task completedTask = await Task.WhenAny(tasksInFlight).ConfigureAwait(false); + + await completedTask.ConfigureAwait(false); + + tasksInFlight.Remove(completedTask); + } + while (index < queue.Length || tasksInFlight.Count != 0); + } } } diff --git a/README.md b/README.md index 118f79ff..e71b5519 100644 --- a/README.md +++ b/README.md @@ -50,6 +50,6 @@ Parameter | Description -validate | Include checksum verification of files already downloaded -manifest-only | downloads a human readable manifest for any depots that would be downloaded. -cellid \<#> | the overridden CellID of the content server to download from. --max-servers \<#> | maximum number of content servers to use. (default: 8). --max-downloads \<#> | maximum number of chunks to download concurrently. (default: 4). +-max-servers \<#> | maximum number of content servers to use. (default: 20). +-max-downloads \<#> | maximum number of chunks to download concurrently. (default: 8). -loginid \<#> | a unique 32-bit integer Steam LogonID in decimal, required if running multiple instances of DepotDownloader concurrently.