diff --git a/DepotDownloader/CDNClientPool.cs b/DepotDownloader/CDNClientPool.cs index 654d3f54..380919fc 100644 --- a/DepotDownloader/CDNClientPool.cs +++ b/DepotDownloader/CDNClientPool.cs @@ -20,7 +20,7 @@ namespace DepotDownloader public CDNClient CDNClient { get; } - private readonly ConcurrentBag activeConnectionPool; + private readonly ConcurrentStack activeConnectionPool; private readonly BlockingCollection availableServerEndpoints; private readonly AutoResetEvent populatePoolEvent; @@ -33,7 +33,7 @@ namespace DepotDownloader this.steamSession = steamSession; CDNClient = new CDNClient(steamSession.steamClient); - activeConnectionPool = new ConcurrentBag(); + activeConnectionPool = new ConcurrentStack(); availableServerEndpoints = new BlockingCollection(); populatePoolEvent = new AutoResetEvent(true); @@ -122,24 +122,6 @@ namespace DepotDownloader } } - private async Task AuthenticateConnection(uint appId, uint depotId, CDNClient.Server server) - { - var host = steamSession.ResolveCDNTopLevelHost(server.Host); - var cdnKey = $"{depotId:D}:{host}"; - - steamSession.RequestCDNAuthToken(appId, depotId, host, cdnKey); - - if (steamSession.CDNAuthTokens.TryGetValue(cdnKey, out var authTokenCallbackPromise)) - { - var result = await authTokenCallbackPromise.Task; - return result.Token; - } - else - { - throw new Exception($"Failed to retrieve CDN token for server {server.Host} depot {depotId}"); - } - } - private CDNClient.Server BuildConnection(CancellationToken token) { if (availableServerEndpoints.Count < ServerEndpointMinimumSize) @@ -150,29 +132,42 @@ namespace DepotDownloader return availableServerEndpoints.Take(token); } - public async Task> GetConnectionForDepot(uint appId, uint depotId, CancellationToken token) + public CDNClient.Server GetConnection(CancellationToken token) { - // Take a free connection from the connection pool - // If there were no free connections, create a new one from the server list - if (!activeConnectionPool.TryTake(out var server)) + if (!activeConnectionPool.TryPop(out var connection)) { - server = BuildConnection(token); + connection = BuildConnection(token); } - // If we don't have a CDN token yet for this server and depot, fetch one now - var cdnToken = await AuthenticateConnection(appId, depotId, server); + return connection; + } + + public async Task AuthenticateConnection(uint appId, uint depotId, CDNClient.Server server) + { + var host = steamSession.ResolveCDNTopLevelHost(server.Host); + var cdnKey = $"{depotId:D}:{host}"; + + steamSession.RequestCDNAuthToken(appId, depotId, host, cdnKey); - return Tuple.Create(server, cdnToken); + if (steamSession.CDNAuthTokens.TryGetValue(cdnKey, out var authTokenCallbackPromise)) + { + var result = await authTokenCallbackPromise.Task; + return result.Token; + } + else + { + throw new Exception($"Failed to retrieve CDN token for server {server.Host} depot {depotId}"); + } } - public void ReturnConnection(Tuple server) + public void ReturnConnection(CDNClient.Server server) { if (server == null) return; - activeConnectionPool.Add(server.Item1); + activeConnectionPool.Push(server); } - public void ReturnBrokenConnection(Tuple server) + public void ReturnBrokenConnection(CDNClient.Server server) { if (server == null) return; diff --git a/DepotDownloader/ContentDownloader.cs b/DepotDownloader/ContentDownloader.cs index aa40d00e..734d1779 100644 --- a/DepotDownloader/ContentDownloader.cs +++ b/DepotDownloader/ContentDownloader.cs @@ -1,5 +1,6 @@ using SteamKit2; using System; +using System.Collections.Concurrent; using System.Collections.Generic; using System.IO; using System.Linq; @@ -592,464 +593,518 @@ namespace DepotDownloader public ProtoManifest.ChunkData NewChunk { get; private set; } } - private static async Task DownloadSteam3Async( uint appId, List depots ) + private class FileStreamData { - ulong TotalBytesCompressed = 0; - ulong TotalBytesUncompressed = 0; + public FileStream fileStream; + public SemaphoreSlim fileLock; + public int chunksToDownload; + } + + private class GlobalDownloadCounter + { + public long TotalBytesCompressed; + public long TotalBytesUncompressed; + } - foreach ( var depot in depots ) + private class DepotDownloadCounter + { + public long CompleteDownloadSize; + public long SizeDownloaded; + public long DepotBytesCompressed; + public long DepotBytesUncompressed; + + } + + private static async Task DownloadSteam3Async(uint appId, List depots) + { + CancellationTokenSource cts = new CancellationTokenSource(); + cdnPool.ExhaustedToken = cts; + + GlobalDownloadCounter downloadCounter = new GlobalDownloadCounter(); + + foreach (var depot in depots) { - ulong DepotBytesCompressed = 0; - ulong DepotBytesUncompressed = 0; + await DownloadSteam3AsyncDepot(cts, downloadCounter, appId, depot); - Console.WriteLine( "Downloading depot {0} - {1}", depot.id, depot.contentName ); + cts.Token.ThrowIfCancellationRequested(); + } - CancellationTokenSource cts = new CancellationTokenSource(); - cdnPool.ExhaustedToken = cts; + Console.WriteLine("Total downloaded: {0} bytes ({1} bytes uncompressed) from {2} depots", + downloadCounter.TotalBytesCompressed, downloadCounter.TotalBytesUncompressed, depots.Count); + } - ProtoManifest oldProtoManifest = null; - ProtoManifest newProtoManifest = null; - string configDir = Path.Combine( depot.installDir, CONFIG_DIR ); + private static async Task DownloadSteam3AsyncDepot(CancellationTokenSource cts, + GlobalDownloadCounter downloadCounter, + uint appId, DepotDownloadInfo depot) + { + DepotDownloadCounter depotCounter = new DepotDownloadCounter(); + + Console.WriteLine("Downloading depot {0} - {1}", depot.id, depot.contentName); + + ProtoManifest oldProtoManifest = null; + ProtoManifest newProtoManifest = null; + string configDir = Path.Combine(depot.installDir, CONFIG_DIR); - ulong lastManifestId = INVALID_MANIFEST_ID; - DepotConfigStore.Instance.InstalledManifestIDs.TryGetValue( depot.id, out lastManifestId ); + ulong lastManifestId = INVALID_MANIFEST_ID; + DepotConfigStore.Instance.InstalledManifestIDs.TryGetValue(depot.id, out lastManifestId); - // In case we have an early exit, this will force equiv of verifyall next run. - DepotConfigStore.Instance.InstalledManifestIDs[ depot.id ] = INVALID_MANIFEST_ID; - DepotConfigStore.Save(); + // In case we have an early exit, this will force equiv of verifyall next run. + DepotConfigStore.Instance.InstalledManifestIDs[depot.id] = INVALID_MANIFEST_ID; + DepotConfigStore.Save(); - if ( lastManifestId != INVALID_MANIFEST_ID ) + if (lastManifestId != INVALID_MANIFEST_ID) + { + var oldManifestFileName = Path.Combine(configDir, string.Format("{0}_{1}.bin", depot.id, lastManifestId)); + + if (File.Exists(oldManifestFileName)) { - var oldManifestFileName = Path.Combine( configDir, string.Format( "{0}.bin", lastManifestId ) ); + byte[] expectedChecksum, currentChecksum; - if (File.Exists(oldManifestFileName)) + try + { + expectedChecksum = File.ReadAllBytes(oldManifestFileName + ".sha"); + } + catch (IOException) { - byte[] expectedChecksum, currentChecksum; + expectedChecksum = null; + } - try - { - expectedChecksum = File.ReadAllBytes(oldManifestFileName + ".sha"); - } - catch (IOException) - { - expectedChecksum = null; - } + oldProtoManifest = ProtoManifest.LoadFromFile(oldManifestFileName, out currentChecksum); - oldProtoManifest = ProtoManifest.LoadFromFile(oldManifestFileName, out currentChecksum); + if (expectedChecksum == null || !expectedChecksum.SequenceEqual(currentChecksum)) + { + // We only have to show this warning if the old manifest ID was different + if (lastManifestId != depot.manifestId) + Console.WriteLine("Manifest {0} on disk did not match the expected checksum.", lastManifestId); + oldProtoManifest = null; + } + } + } - if (expectedChecksum == null || !expectedChecksum.SequenceEqual(currentChecksum)) - { - // We only have to show this warning if the old manifest ID was different - if (lastManifestId != depot.manifestId) - Console.WriteLine("Manifest {0} on disk did not match the expected checksum.", lastManifestId); - oldProtoManifest = null; - } + if (lastManifestId == depot.manifestId && oldProtoManifest != null) + { + newProtoManifest = oldProtoManifest; + Console.WriteLine("Already have manifest {0} for depot {1}.", depot.manifestId, depot.id); + } + else + { + var newManifestFileName = Path.Combine(configDir, string.Format("{0}_{1}.bin", depot.id, depot.manifestId)); + if (newManifestFileName != null) + { + byte[] expectedChecksum, currentChecksum; + + try + { + expectedChecksum = File.ReadAllBytes(newManifestFileName + ".sha"); + } + catch (IOException) + { + expectedChecksum = null; + } + + newProtoManifest = ProtoManifest.LoadFromFile(newManifestFileName, out currentChecksum); + + if (newProtoManifest != null && (expectedChecksum == null || !expectedChecksum.SequenceEqual(currentChecksum))) + { + Console.WriteLine("Manifest {0} on disk did not match the expected checksum.", depot.manifestId); + newProtoManifest = null; } } - if ( lastManifestId == depot.manifestId && oldProtoManifest != null ) + if (newProtoManifest != null) { - newProtoManifest = oldProtoManifest; - Console.WriteLine( "Already have manifest {0} for depot {1}.", depot.manifestId, depot.id ); + Console.WriteLine("Already have manifest {0} for depot {1}.", depot.manifestId, depot.id); } else { - var newManifestFileName = Path.Combine( configDir, string.Format( "{0}_{1}.bin", depot.id, depot.manifestId ) ); - if ( newManifestFileName != null ) + Console.Write("Downloading depot manifest..."); + + DepotManifest depotManifest = null; + + do { - byte[] expectedChecksum, currentChecksum; + cts.Token.ThrowIfCancellationRequested(); + + CDNClient.Server connection = null; try { - expectedChecksum = File.ReadAllBytes(newManifestFileName + ".sha"); + connection = cdnPool.GetConnection(cts.Token); + var cdnToken = await cdnPool.AuthenticateConnection(appId, depot.id, connection); + + depotManifest = await cdnPool.CDNClient.DownloadManifestAsync(depot.id, depot.manifestId, + connection, cdnToken, depot.depotKey).ConfigureAwait(false); + + cdnPool.ReturnConnection(connection); } - catch (IOException) + catch (TaskCanceledException) { - expectedChecksum = null; + Console.WriteLine("Connection timeout downloading depot manifest {0} {1}", depot.id, depot.manifestId); } + catch (SteamKitWebRequestException e) + { + cdnPool.ReturnBrokenConnection(connection); - newProtoManifest = ProtoManifest.LoadFromFile(newManifestFileName, out currentChecksum); - - if (newProtoManifest != null && (expectedChecksum == null || !expectedChecksum.SequenceEqual(currentChecksum))) + if (e.StatusCode == HttpStatusCode.Unauthorized || e.StatusCode == HttpStatusCode.Forbidden) + { + Console.WriteLine("Encountered 401 for depot manifest {0} {1}. Aborting.", depot.id, depot.manifestId); + break; + } + else + { + Console.WriteLine("Encountered error downloading depot manifest {0} {1}: {2}", depot.id, depot.manifestId, e.StatusCode); + } + } + catch (OperationCanceledException) + { + break; + } + catch (Exception e) { - Console.WriteLine("Manifest {0} on disk did not match the expected checksum.", depot.manifestId); - newProtoManifest = null; + cdnPool.ReturnBrokenConnection(connection); + Console.WriteLine("Encountered error downloading manifest for depot {0} {1}: {2}", depot.id, depot.manifestId, e.Message); } } + while (depotManifest == null); - if ( newProtoManifest != null ) + if (depotManifest == null) { - Console.WriteLine( "Already have manifest {0} for depot {1}.", depot.manifestId, depot.id ); + Console.WriteLine("\nUnable to download manifest {0} for depot {1}", depot.manifestId, depot.id); + cts.Cancel(); } - else - { - Console.Write( "Downloading depot manifest..." ); - DepotManifest depotManifest = null; + // Throw the cancellation exception if requested so that this task is marked failed + cts.Token.ThrowIfCancellationRequested(); - while ( depotManifest == null ) - { - Tuple connection = null; - try - { - connection = await cdnPool.GetConnectionForDepot( appId, depot.id, CancellationToken.None ); + byte[] checksum; - depotManifest = await cdnPool.CDNClient.DownloadManifestAsync( depot.id, depot.manifestId, - connection.Item1, connection.Item2, depot.depotKey ).ConfigureAwait(false); + newProtoManifest = new ProtoManifest(depotManifest, depot.manifestId); + newProtoManifest.SaveToFile(newManifestFileName, out checksum); + File.WriteAllBytes(newManifestFileName + ".sha", checksum); - cdnPool.ReturnConnection( connection ); - } - catch ( SteamKitWebRequestException e ) - { - cdnPool.ReturnBrokenConnection( connection ); + Console.WriteLine(" Done!"); + } + } - if ( e.StatusCode == HttpStatusCode.Unauthorized || e.StatusCode == HttpStatusCode.Forbidden ) - { - Console.WriteLine( "Encountered 401 for depot manifest {0} {1}. Aborting.", depot.id, depot.manifestId ); - break; - } - else - { - Console.WriteLine( "Encountered error downloading depot manifest {0} {1}: {2}", depot.id, depot.manifestId, e.StatusCode ); - } - } - catch ( Exception e ) - { - cdnPool.ReturnBrokenConnection( connection ); - Console.WriteLine( "Encountered error downloading manifest for depot {0} {1}: {2}", depot.id, depot.manifestId, e.Message ); - } - } + newProtoManifest.Files.Sort((x, y) => string.Compare(x.FileName, y.FileName, StringComparison.Ordinal)); - if ( depotManifest == null ) - { - Console.WriteLine( "\nUnable to download manifest {0} for depot {1}", depot.manifestId, depot.id ); - return; - } + Console.WriteLine("Manifest {0} ({1})", depot.manifestId, newProtoManifest.CreationTime); - byte[] checksum; + if (Config.DownloadManifestOnly) + { + StringBuilder manifestBuilder = new StringBuilder(); + string txtManifest = Path.Combine(depot.installDir, string.Format("manifest_{0}_{1}.txt", depot.id, depot.manifestId)); + manifestBuilder.Append(string.Format("{0}\n\n", newProtoManifest.CreationTime)); - newProtoManifest = new ProtoManifest( depotManifest, depot.manifestId ); - newProtoManifest.SaveToFile( newManifestFileName, out checksum ); - File.WriteAllBytes( newManifestFileName + ".sha", checksum ); + foreach (var file in newProtoManifest.Files) + { + if (file.Flags.HasFlag(EDepotFileFlag.Directory)) + continue; - Console.WriteLine( " Done!" ); - } + manifestBuilder.Append(string.Format("{0}\n", file.FileName)); + manifestBuilder.Append(string.Format("\t{0}\n", file.TotalSize)); + manifestBuilder.Append(string.Format("\t{0}\n", BitConverter.ToString(file.FileHash).Replace("-", ""))); } - newProtoManifest.Files.Sort( ( x, y ) => string.Compare( x.FileName, y.FileName, StringComparison.Ordinal ) ); + File.WriteAllText(txtManifest, manifestBuilder.ToString()); + return; + } - Console.WriteLine( "Manifest {0} ({1})", depot.manifestId, newProtoManifest.CreationTime ); + string stagingDir = Path.Combine(depot.installDir, STAGING_DIR); - if ( Config.DownloadManifestOnly ) - { - StringBuilder manifestBuilder = new StringBuilder(); - string txtManifest = Path.Combine( depot.installDir, string.Format( "manifest_{0}_{1}.txt", depot.id, depot.manifestId ) ); - manifestBuilder.Append( string.Format( "{0}\n\n", newProtoManifest.CreationTime ) ); + var filesAfterExclusions = newProtoManifest.Files.AsParallel().Where(f => TestIsFileIncluded(f.FileName)).ToList(); - foreach ( var file in newProtoManifest.Files ) - { - if ( file.Flags.HasFlag( EDepotFileFlag.Directory ) ) - continue; + // Pre-process + filesAfterExclusions.ForEach(file => + { + var fileFinalPath = Path.Combine(depot.installDir, file.FileName); + var fileStagingPath = Path.Combine(stagingDir, file.FileName); - manifestBuilder.Append( string.Format( "{0}\n", file.FileName ) ); - manifestBuilder.Append( string.Format( "\t{0}\n", file.TotalSize ) ); - manifestBuilder.Append( string.Format( "\t{0}\n", BitConverter.ToString( file.FileHash ).Replace( "-", "" ) ) ); - } + if (file.Flags.HasFlag(EDepotFileFlag.Directory)) + { + Directory.CreateDirectory(fileFinalPath); + Directory.CreateDirectory(fileStagingPath); + } + else + { + // Some manifests don't explicitly include all necessary directories + Directory.CreateDirectory(Path.GetDirectoryName(fileFinalPath)); + Directory.CreateDirectory(Path.GetDirectoryName(fileStagingPath)); - File.WriteAllText( txtManifest, manifestBuilder.ToString() ); - continue; + Interlocked.Add(ref depotCounter.CompleteDownloadSize, (long)file.TotalSize); } + }); + + var files = filesAfterExclusions.Where(f => !f.Flags.HasFlag(EDepotFileFlag.Directory)).ToArray(); + var networkChunkQueue = new ConcurrentQueue>(); + + await Util.InvokeAsync( + files.Select(f => new Func(async () => + await Task.Run(() => DownloadSteam3AsyncDepotFile(cts, downloadCounter, depotCounter, + stagingDir, oldProtoManifest, newProtoManifest, + networkChunkQueue, + appId, depot, f)))), + maxDegreeOfParallelism: Config.MaxDownloads + ); - ulong complete_download_size = 0; - ulong size_downloaded = 0; - string stagingDir = Path.Combine( depot.installDir, STAGING_DIR ); + await Util.InvokeAsync( + networkChunkQueue.Select((x) => new Func(async () => + await Task.Run(() => DownloadSteam3AsyncDepotFileChunk(cts, downloadCounter, depotCounter, + stagingDir, oldProtoManifest, newProtoManifest, + appId, depot, x.Item2, x.Item1, x.Item3)))), + maxDegreeOfParallelism: Config.MaxDownloads + ); - var filesAfterExclusions = newProtoManifest.Files.AsParallel().Where( f => TestIsFileIncluded( f.FileName ) ).ToList(); + DepotConfigStore.Instance.InstalledManifestIDs[depot.id] = depot.manifestId; + DepotConfigStore.Save(); - // Pre-process - filesAfterExclusions.ForEach( file => - { - var fileFinalPath = Path.Combine( depot.installDir, file.FileName ); - var fileStagingPath = Path.Combine( stagingDir, file.FileName ); + Console.WriteLine("Depot {0} - Downloaded {1} bytes ({2} bytes uncompressed)", depot.id, depotCounter.DepotBytesCompressed, depotCounter.DepotBytesUncompressed); + } - if ( file.Flags.HasFlag( EDepotFileFlag.Directory ) ) - { - Directory.CreateDirectory( fileFinalPath ); - Directory.CreateDirectory( fileStagingPath ); - } - else - { - // Some manifests don't explicitly include all necessary directories - Directory.CreateDirectory( Path.GetDirectoryName( fileFinalPath ) ); - Directory.CreateDirectory( Path.GetDirectoryName( fileStagingPath ) ); + private static void DownloadSteam3AsyncDepotFile( + CancellationTokenSource cts, + GlobalDownloadCounter downloadCounter, DepotDownloadCounter depotDownloadCounter, + string stagingDir, ProtoManifest oldProtoManifest, ProtoManifest newProtoManifest, + ConcurrentQueue> networkChunkQueue, + uint appId, DepotDownloadInfo depot, ProtoManifest.FileData file) + { + cts.Token.ThrowIfCancellationRequested(); - complete_download_size += file.TotalSize; - } - } ); + string fileFinalPath = Path.Combine(depot.installDir, file.FileName); + string fileStagingPath = Path.Combine(stagingDir, file.FileName); + + // This may still exist if the previous run exited before cleanup + if (File.Exists(fileStagingPath)) + { + File.Delete(fileStagingPath); + } - var ioSemaphore = new SemaphoreSlim( Config.MaxDownloads ); - var networkSemaphore = new SemaphoreSlim( Config.MaxDownloads ); + FileStream fs = null; + List neededChunks; + FileInfo fi = new FileInfo(fileFinalPath); + if (!fi.Exists) + { + // create new file. need all chunks + fs = File.Create(fileFinalPath); + fs.SetLength((long)file.TotalSize); + neededChunks = new List(file.Chunks); + } + else + { + // open existing + ProtoManifest.FileData oldManifestFile = null; + if (oldProtoManifest != null) + { + oldManifestFile = oldProtoManifest.Files.SingleOrDefault(f => f.FileName == file.FileName); + } - var files = filesAfterExclusions.Where( f => !f.Flags.HasFlag( EDepotFileFlag.Directory ) ).ToArray(); - var ioTasks = new Task[ files.Length ]; - for ( var i = 0; i < files.Length; i++ ) + if (oldManifestFile != null) { - var file = files[ i ]; - var task = Task.Run( async () => + neededChunks = new List(); + + if (Config.VerifyAll || !oldManifestFile.FileHash.SequenceEqual(file.FileHash)) { - cts.Token.ThrowIfCancellationRequested(); - - try + // we have a version of this file, but it doesn't fully match what we want + if (Config.VerifyAll) { - await ioSemaphore.WaitAsync().ConfigureAwait( false ); - cts.Token.ThrowIfCancellationRequested(); - - string fileFinalPath = Path.Combine( depot.installDir, file.FileName ); - string fileStagingPath = Path.Combine( stagingDir, file.FileName ); + Console.WriteLine("Validating {0}", fileFinalPath); + } - // This may still exist if the previous run exited before cleanup - if ( File.Exists( fileStagingPath ) ) - { - File.Delete( fileStagingPath ); - } + var matchingChunks = new List(); - FileStream fs = null; - List neededChunks; - FileInfo fi = new FileInfo( fileFinalPath ); - if ( !fi.Exists ) + foreach (var chunk in file.Chunks) + { + var oldChunk = oldManifestFile.Chunks.FirstOrDefault(c => c.ChunkID.SequenceEqual(chunk.ChunkID)); + if (oldChunk != null) { - // create new file. need all chunks - fs = File.Create( fileFinalPath ); - fs.SetLength( ( long )file.TotalSize ); - neededChunks = new List( file.Chunks ); + matchingChunks.Add(new ChunkMatch(oldChunk, chunk)); } else { - // open existing - ProtoManifest.FileData oldManifestFile = null; - if ( oldProtoManifest != null ) - { - oldManifestFile = oldProtoManifest.Files.SingleOrDefault( f => f.FileName == file.FileName ); - } + neededChunks.Add(chunk); + } + } - if ( oldManifestFile != null ) - { - neededChunks = new List(); - - if ( Config.VerifyAll || !oldManifestFile.FileHash.SequenceEqual( file.FileHash ) ) - { - // we have a version of this file, but it doesn't fully match what we want - - var matchingChunks = new List(); - - foreach ( var chunk in file.Chunks ) - { - var oldChunk = oldManifestFile.Chunks.FirstOrDefault( c => c.ChunkID.SequenceEqual( chunk.ChunkID ) ); - if ( oldChunk != null ) - { - matchingChunks.Add( new ChunkMatch( oldChunk, chunk ) ); - } - else - { - neededChunks.Add( chunk ); - } - } - - File.Move( fileFinalPath, fileStagingPath ); - - fs = File.Open( fileFinalPath, FileMode.Create ); - fs.SetLength( ( long )file.TotalSize ); - - using ( var fsOld = File.Open( fileStagingPath, FileMode.Open ) ) - { - foreach ( var match in matchingChunks ) - { - fsOld.Seek( ( long )match.OldChunk.Offset, SeekOrigin.Begin ); - - byte[] tmp = new byte[ match.OldChunk.UncompressedLength ]; - fsOld.Read( tmp, 0, tmp.Length ); - - byte[] adler = Util.AdlerHash( tmp ); - if ( !adler.SequenceEqual( match.OldChunk.Checksum ) ) - { - neededChunks.Add( match.NewChunk ); - } - else - { - fs.Seek( ( long )match.NewChunk.Offset, SeekOrigin.Begin ); - fs.Write( tmp, 0, tmp.Length ); - } - } - } - - File.Delete( fileStagingPath ); - } - } - else - { - // No old manifest or file not in old manifest. We must validate. + var orderedChunks = matchingChunks.OrderBy(x => x.OldChunk.Offset); - fs = File.Open( fileFinalPath, FileMode.Open ); - if ( ( ulong )fi.Length != file.TotalSize ) - { - fs.SetLength( ( long )file.TotalSize ); - } + File.Move(fileFinalPath, fileStagingPath); - neededChunks = Util.ValidateSteam3FileChecksums( fs, file.Chunks.OrderBy( x => x.Offset ).ToArray() ); - } + fs = File.Open(fileFinalPath, FileMode.Create); + fs.SetLength((long)file.TotalSize); - if ( neededChunks.Count() == 0 ) + using (var fsOld = File.Open(fileStagingPath, FileMode.Open)) + { + foreach (var match in orderedChunks) + { + fsOld.Seek((long)match.OldChunk.Offset, SeekOrigin.Begin); + + byte[] tmp = new byte[match.OldChunk.UncompressedLength]; + fsOld.Read(tmp, 0, tmp.Length); + + byte[] adler = Util.AdlerHash(tmp); + if (!adler.SequenceEqual(match.OldChunk.Checksum)) { - size_downloaded += file.TotalSize; - Console.WriteLine( "{0,6:#00.00}% {1}", ( ( float )size_downloaded / ( float )complete_download_size ) * 100.0f, fileFinalPath ); - if ( fs != null ) - fs.Dispose(); - return; + neededChunks.Add(match.NewChunk); } else { - size_downloaded += ( file.TotalSize - ( ulong )neededChunks.Select( x => ( long )x.UncompressedLength ).Sum() ); + fs.Seek((long)match.NewChunk.Offset, SeekOrigin.Begin); + fs.Write(tmp, 0, tmp.Length); } } + } - var fileSemaphore = new SemaphoreSlim(1); + File.Delete(fileStagingPath); + } + } + else + { + // No old manifest or file not in old manifest. We must validate. - var downloadTasks = new Task[neededChunks.Count]; - for (var x = 0; x < neededChunks.Count; x++) - { - var chunk = neededChunks[x]; + fs = File.Open(fileFinalPath, FileMode.Open); + if ((ulong)fi.Length != file.TotalSize) + { + fs.SetLength((long)file.TotalSize); + } - var downloadTask = Task.Run(async () => - { - cts.Token.ThrowIfCancellationRequested(); - - try - { - await networkSemaphore.WaitAsync().ConfigureAwait(false); - cts.Token.ThrowIfCancellationRequested(); - - string chunkID = Util.EncodeHexString(chunk.ChunkID); - CDNClient.DepotChunk chunkData = null; - - while (!cts.IsCancellationRequested) - { - Tuple connection; - try - { - connection = await cdnPool.GetConnectionForDepot(appId, depot.id, cts.Token); - } - catch (OperationCanceledException) - { - break; - } - - DepotManifest.ChunkData data = new DepotManifest.ChunkData(); - data.ChunkID = chunk.ChunkID; - data.Checksum = chunk.Checksum; - data.Offset = chunk.Offset; - data.CompressedLength = chunk.CompressedLength; - data.UncompressedLength = chunk.UncompressedLength; - - try - { - chunkData = await cdnPool.CDNClient.DownloadDepotChunkAsync(depot.id, data, - connection.Item1, connection.Item2, depot.depotKey).ConfigureAwait(false); - cdnPool.ReturnConnection(connection); - - break; - } - catch (SteamKitWebRequestException e) - { - cdnPool.ReturnBrokenConnection(connection); - - if (e.StatusCode == HttpStatusCode.Unauthorized || e.StatusCode == HttpStatusCode.Forbidden) - { - Console.WriteLine("Encountered 401 for chunk {0}. Aborting.", chunkID); - cts.Cancel(); - break; - } - else - { - Console.WriteLine("Encountered error downloading chunk {0}: {1}", chunkID, e.StatusCode); - } - } - catch (TaskCanceledException) - { - Console.WriteLine("Connection timeout downloading chunk {0}", chunkID); - } - catch (Exception e) - { - cdnPool.ReturnBrokenConnection(connection); - Console.WriteLine("Encountered unexpected error downloading chunk {0}: {1}", chunkID, e.Message); - } - } - - if (chunkData == null) - { - Console.WriteLine("Failed to find any server with chunk {0} for depot {1}. Aborting.", chunkID, depot.id); - cts.Cancel(); - } - - // Throw the cancellation exception if requested so that this task is marked failed - cts.Token.ThrowIfCancellationRequested(); - - try - { - await fileSemaphore.WaitAsync().ConfigureAwait(false); - - fs.Seek((long)chunkData.ChunkInfo.Offset, SeekOrigin.Begin); - fs.Write(chunkData.Data, 0, chunkData.Data.Length); - - return chunkData.ChunkInfo; - } - finally - { - fileSemaphore.Release(); - } - } - finally - { - networkSemaphore.Release(); - } - }); - - downloadTasks[x] = downloadTask; - } + Console.WriteLine("Validating {0}", fileFinalPath); + neededChunks = Util.ValidateSteam3FileChecksums(fs, file.Chunks.OrderBy(x => x.Offset).ToArray()); + } + + if (neededChunks.Count() == 0) + { + var sizeDownloaded = Interlocked.Add(ref depotDownloadCounter.SizeDownloaded, (long)file.TotalSize); + Console.WriteLine("{0,6:#00.00}% {1}", ((float)sizeDownloaded / (float)depotDownloadCounter.CompleteDownloadSize) * 100.0f, fileFinalPath); + + if (fs != null) + fs.Dispose(); + return; + } + else + { + var sizeDownloaded = ((long)file.TotalSize - (long)neededChunks.Select(x => (long)x.UncompressedLength).Sum()); + Interlocked.Add(ref depotDownloadCounter.SizeDownloaded, sizeDownloaded); + } + } - var completedDownloads = await Task.WhenAll(downloadTasks).ConfigureAwait(false); + FileStreamData fileStreamData = new FileStreamData + { + fileStream = fs, + fileLock = new SemaphoreSlim(1), + chunksToDownload = neededChunks.Count + }; - fs.Dispose(); + foreach (var chunk in neededChunks) + { + networkChunkQueue.Enqueue(Tuple.Create(fileStreamData, file, chunk)); + } + } - foreach (var chunkInfo in completedDownloads) - { - TotalBytesCompressed += chunkInfo.CompressedLength; - DepotBytesCompressed += chunkInfo.CompressedLength; - TotalBytesUncompressed += chunkInfo.UncompressedLength; - DepotBytesUncompressed += chunkInfo.UncompressedLength; + private static async Task DownloadSteam3AsyncDepotFileChunk( + CancellationTokenSource cts, + GlobalDownloadCounter downloadCounter, DepotDownloadCounter depotDownloadCounter, + string stagingDir, ProtoManifest oldProtoManifest, ProtoManifest newProtoManifest, + uint appId, DepotDownloadInfo depot, ProtoManifest.FileData file, + FileStreamData fileStreamData, ProtoManifest.ChunkData chunk) + { + cts.Token.ThrowIfCancellationRequested(); - size_downloaded += chunkInfo.UncompressedLength; - } + string chunkID = Util.EncodeHexString(chunk.ChunkID); - Console.WriteLine("{0,6:#00.00}% {1}", ((float)size_downloaded / (float)complete_download_size) * 100.0f, fileFinalPath); - } - finally - { - ioSemaphore.Release(); - } - } ); + DepotManifest.ChunkData data = new DepotManifest.ChunkData(); + data.ChunkID = chunk.ChunkID; + data.Checksum = chunk.Checksum; + data.Offset = chunk.Offset; + data.CompressedLength = chunk.CompressedLength; + data.UncompressedLength = chunk.UncompressedLength; + + CDNClient.DepotChunk chunkData = null; + + do + { + cts.Token.ThrowIfCancellationRequested(); + + CDNClient.Server connection = null; + + try + { + connection = cdnPool.GetConnection(cts.Token); + var cdnToken = await cdnPool.AuthenticateConnection(appId, depot.id, connection); + + chunkData = await cdnPool.CDNClient.DownloadDepotChunkAsync(depot.id, data, + connection, cdnToken, depot.depotKey).ConfigureAwait(false); + + cdnPool.ReturnConnection(connection); + } + catch (TaskCanceledException) + { + Console.WriteLine("Connection timeout downloading chunk {0}", chunkID); + } + catch (SteamKitWebRequestException e) + { + cdnPool.ReturnBrokenConnection(connection); - ioTasks[ i ] = task; + if (e.StatusCode == HttpStatusCode.Unauthorized || e.StatusCode == HttpStatusCode.Forbidden) + { + Console.WriteLine("Encountered 401 for chunk {0}. Aborting.", chunkID); + break; + } + else + { + Console.WriteLine("Encountered error downloading chunk {0}: {1}", chunkID, e.StatusCode); + } } + catch (OperationCanceledException) + { + break; + } + catch (Exception e) + { + cdnPool.ReturnBrokenConnection(connection); + Console.WriteLine("Encountered unexpected error downloading chunk {0}: {1}", chunkID, e.Message); + } + } + while (chunkData == null); - await Task.WhenAll(ioTasks).ConfigureAwait(false); + if (chunkData == null) + { + Console.WriteLine("Failed to find any server with chunk {0} for depot {1}. Aborting.", chunkID, depot.id); + cts.Cancel(); + } - DepotConfigStore.Instance.InstalledManifestIDs[ depot.id ] = depot.manifestId; - DepotConfigStore.Save(); + // Throw the cancellation exception if requested so that this task is marked failed + cts.Token.ThrowIfCancellationRequested(); + + try + { + await fileStreamData.fileLock.WaitAsync().ConfigureAwait(false); + + fileStreamData.fileStream.Seek((long)chunkData.ChunkInfo.Offset, SeekOrigin.Begin); + fileStreamData.fileStream.Write(chunkData.Data, 0, chunkData.Data.Length); + + + } + finally + { + fileStreamData.fileLock.Release(); + } + + Interlocked.Add(ref downloadCounter.TotalBytesCompressed, chunk.CompressedLength); + Interlocked.Add(ref depotDownloadCounter.DepotBytesCompressed, chunk.CompressedLength); + Interlocked.Add(ref downloadCounter.TotalBytesUncompressed, chunk.UncompressedLength); + Interlocked.Add(ref depotDownloadCounter.DepotBytesUncompressed, chunk.UncompressedLength); + + var sizeDownloaded = Interlocked.Add(ref depotDownloadCounter.SizeDownloaded, chunkData.Data.Length); + + int remainingChunks = Interlocked.Decrement(ref fileStreamData.chunksToDownload); + if (remainingChunks == 0) + { + fileStreamData.fileStream.Dispose(); - Console.WriteLine( "Depot {0} - Downloaded {1} bytes ({2} bytes uncompressed)", depot.id, DepotBytesCompressed, DepotBytesUncompressed ); + var fileFinalPath = Path.Combine(depot.installDir, file.FileName); + Console.WriteLine("{0,6:#00.00}% {1}", ((float)sizeDownloaded / (float)depotDownloadCounter.CompleteDownloadSize) * 100.0f, fileFinalPath); } - Console.WriteLine( "Total downloaded: {0} bytes ({1} bytes uncompressed) from {2} depots", TotalBytesCompressed, TotalBytesUncompressed, depots.Count ); } } } diff --git a/DepotDownloader/Util.cs b/DepotDownloader/Util.cs index 1c08138b..176cc1d8 100644 --- a/DepotDownloader/Util.cs +++ b/DepotDownloader/Util.cs @@ -5,6 +5,7 @@ using System.Linq; using System.Runtime.InteropServices; using System.Security.Cryptography; using System.Text; +using System.Threading.Tasks; namespace DepotDownloader { @@ -132,5 +133,38 @@ namespace DepotDownloader ( sb, v ) => sb.Append( v.ToString( "x2" ) ) ).ToString(); } + + public static async Task InvokeAsync(IEnumerable> taskFactories, int maxDegreeOfParallelism) + { + if (taskFactories == null) throw new ArgumentNullException(nameof(taskFactories)); + if (maxDegreeOfParallelism <= 0) throw new ArgumentException(nameof(maxDegreeOfParallelism)); + + Func[] queue = taskFactories.ToArray(); + + if (queue.Length == 0) + { + return; + } + + List tasksInFlight = new List(maxDegreeOfParallelism); + int index = 0; + + do + { + while (tasksInFlight.Count < maxDegreeOfParallelism && index < queue.Length) + { + Func taskFactory = queue[index++]; + + tasksInFlight.Add(taskFactory()); + } + + Task completedTask = await Task.WhenAny(tasksInFlight).ConfigureAwait(false); + + await completedTask.ConfigureAwait(false); + + tasksInFlight.Remove(completedTask); + } + while (index < queue.Length || tasksInFlight.Count != 0); + } } }