Refactored download Tasks and use InvokeAsync to throttle Tasks

pull/132/head
Ryan Kistner 6 years ago
parent 609a66f280
commit d02aec256c

@ -20,7 +20,7 @@ namespace DepotDownloader
public CDNClient CDNClient { get; } public CDNClient CDNClient { get; }
private readonly ConcurrentBag<CDNClient.Server> activeConnectionPool; private readonly ConcurrentStack<CDNClient.Server> activeConnectionPool;
private readonly BlockingCollection<CDNClient.Server> availableServerEndpoints; private readonly BlockingCollection<CDNClient.Server> availableServerEndpoints;
private readonly AutoResetEvent populatePoolEvent; private readonly AutoResetEvent populatePoolEvent;
@ -33,7 +33,7 @@ namespace DepotDownloader
this.steamSession = steamSession; this.steamSession = steamSession;
CDNClient = new CDNClient(steamSession.steamClient); CDNClient = new CDNClient(steamSession.steamClient);
activeConnectionPool = new ConcurrentBag<CDNClient.Server>(); activeConnectionPool = new ConcurrentStack<CDNClient.Server>();
availableServerEndpoints = new BlockingCollection<CDNClient.Server>(); availableServerEndpoints = new BlockingCollection<CDNClient.Server>();
populatePoolEvent = new AutoResetEvent(true); populatePoolEvent = new AutoResetEvent(true);
@ -122,24 +122,6 @@ namespace DepotDownloader
} }
} }
private async Task<string> AuthenticateConnection(uint appId, uint depotId, CDNClient.Server server)
{
var host = steamSession.ResolveCDNTopLevelHost(server.Host);
var cdnKey = $"{depotId:D}:{host}";
steamSession.RequestCDNAuthToken(appId, depotId, host, cdnKey);
if (steamSession.CDNAuthTokens.TryGetValue(cdnKey, out var authTokenCallbackPromise))
{
var result = await authTokenCallbackPromise.Task;
return result.Token;
}
else
{
throw new Exception($"Failed to retrieve CDN token for server {server.Host} depot {depotId}");
}
}
private CDNClient.Server BuildConnection(CancellationToken token) private CDNClient.Server BuildConnection(CancellationToken token)
{ {
if (availableServerEndpoints.Count < ServerEndpointMinimumSize) if (availableServerEndpoints.Count < ServerEndpointMinimumSize)
@ -150,29 +132,42 @@ namespace DepotDownloader
return availableServerEndpoints.Take(token); return availableServerEndpoints.Take(token);
} }
public async Task<Tuple<CDNClient.Server, string>> GetConnectionForDepot(uint appId, uint depotId, CancellationToken token) public CDNClient.Server GetConnection(CancellationToken token)
{ {
// Take a free connection from the connection pool if (!activeConnectionPool.TryPop(out var connection))
// If there were no free connections, create a new one from the server list
if (!activeConnectionPool.TryTake(out var server))
{ {
server = BuildConnection(token); connection = BuildConnection(token);
} }
// If we don't have a CDN token yet for this server and depot, fetch one now return connection;
var cdnToken = await AuthenticateConnection(appId, depotId, server); }
public async Task<string> AuthenticateConnection(uint appId, uint depotId, CDNClient.Server server)
{
var host = steamSession.ResolveCDNTopLevelHost(server.Host);
var cdnKey = $"{depotId:D}:{host}";
steamSession.RequestCDNAuthToken(appId, depotId, host, cdnKey);
return Tuple.Create(server, cdnToken); if (steamSession.CDNAuthTokens.TryGetValue(cdnKey, out var authTokenCallbackPromise))
{
var result = await authTokenCallbackPromise.Task;
return result.Token;
}
else
{
throw new Exception($"Failed to retrieve CDN token for server {server.Host} depot {depotId}");
}
} }
public void ReturnConnection(Tuple<CDNClient.Server, string> server) public void ReturnConnection(CDNClient.Server server)
{ {
if (server == null) return; if (server == null) return;
activeConnectionPool.Add(server.Item1); activeConnectionPool.Push(server);
} }
public void ReturnBrokenConnection(Tuple<CDNClient.Server, string> server) public void ReturnBrokenConnection(CDNClient.Server server)
{ {
if (server == null) return; if (server == null) return;

@ -1,5 +1,6 @@
using SteamKit2; using SteamKit2;
using System; using System;
using System.Collections.Concurrent;
using System.Collections.Generic; using System.Collections.Generic;
using System.IO; using System.IO;
using System.Linq; using System.Linq;
@ -592,35 +593,68 @@ namespace DepotDownloader
public ProtoManifest.ChunkData NewChunk { get; private set; } public ProtoManifest.ChunkData NewChunk { get; private set; }
} }
private static async Task DownloadSteam3Async( uint appId, List<DepotDownloadInfo> depots ) private class FileStreamData
{ {
ulong TotalBytesCompressed = 0; public FileStream fileStream;
ulong TotalBytesUncompressed = 0; public SemaphoreSlim fileLock;
public int chunksToDownload;
}
foreach ( var depot in depots ) private class GlobalDownloadCounter
{ {
ulong DepotBytesCompressed = 0; public long TotalBytesCompressed;
ulong DepotBytesUncompressed = 0; public long TotalBytesUncompressed;
}
Console.WriteLine( "Downloading depot {0} - {1}", depot.id, depot.contentName ); private class DepotDownloadCounter
{
public long CompleteDownloadSize;
public long SizeDownloaded;
public long DepotBytesCompressed;
public long DepotBytesUncompressed;
}
private static async Task DownloadSteam3Async(uint appId, List<DepotDownloadInfo> depots)
{
CancellationTokenSource cts = new CancellationTokenSource(); CancellationTokenSource cts = new CancellationTokenSource();
cdnPool.ExhaustedToken = cts; cdnPool.ExhaustedToken = cts;
GlobalDownloadCounter downloadCounter = new GlobalDownloadCounter();
foreach (var depot in depots)
{
await DownloadSteam3AsyncDepot(cts, downloadCounter, appId, depot);
cts.Token.ThrowIfCancellationRequested();
}
Console.WriteLine("Total downloaded: {0} bytes ({1} bytes uncompressed) from {2} depots",
downloadCounter.TotalBytesCompressed, downloadCounter.TotalBytesUncompressed, depots.Count);
}
private static async Task DownloadSteam3AsyncDepot(CancellationTokenSource cts,
GlobalDownloadCounter downloadCounter,
uint appId, DepotDownloadInfo depot)
{
DepotDownloadCounter depotCounter = new DepotDownloadCounter();
Console.WriteLine("Downloading depot {0} - {1}", depot.id, depot.contentName);
ProtoManifest oldProtoManifest = null; ProtoManifest oldProtoManifest = null;
ProtoManifest newProtoManifest = null; ProtoManifest newProtoManifest = null;
string configDir = Path.Combine( depot.installDir, CONFIG_DIR ); string configDir = Path.Combine(depot.installDir, CONFIG_DIR);
ulong lastManifestId = INVALID_MANIFEST_ID; ulong lastManifestId = INVALID_MANIFEST_ID;
DepotConfigStore.Instance.InstalledManifestIDs.TryGetValue( depot.id, out lastManifestId ); DepotConfigStore.Instance.InstalledManifestIDs.TryGetValue(depot.id, out lastManifestId);
// In case we have an early exit, this will force equiv of verifyall next run. // In case we have an early exit, this will force equiv of verifyall next run.
DepotConfigStore.Instance.InstalledManifestIDs[ depot.id ] = INVALID_MANIFEST_ID; DepotConfigStore.Instance.InstalledManifestIDs[depot.id] = INVALID_MANIFEST_ID;
DepotConfigStore.Save(); DepotConfigStore.Save();
if ( lastManifestId != INVALID_MANIFEST_ID ) if (lastManifestId != INVALID_MANIFEST_ID)
{ {
var oldManifestFileName = Path.Combine( configDir, string.Format( "{0}.bin", lastManifestId ) ); var oldManifestFileName = Path.Combine(configDir, string.Format("{0}_{1}.bin", depot.id, lastManifestId));
if (File.Exists(oldManifestFileName)) if (File.Exists(oldManifestFileName))
{ {
@ -647,15 +681,15 @@ namespace DepotDownloader
} }
} }
if ( lastManifestId == depot.manifestId && oldProtoManifest != null ) if (lastManifestId == depot.manifestId && oldProtoManifest != null)
{ {
newProtoManifest = oldProtoManifest; newProtoManifest = oldProtoManifest;
Console.WriteLine( "Already have manifest {0} for depot {1}.", depot.manifestId, depot.id ); Console.WriteLine("Already have manifest {0} for depot {1}.", depot.manifestId, depot.id);
} }
else else
{ {
var newManifestFileName = Path.Combine( configDir, string.Format( "{0}_{1}.bin", depot.id, depot.manifestId ) ); var newManifestFileName = Path.Combine(configDir, string.Format("{0}_{1}.bin", depot.id, depot.manifestId));
if ( newManifestFileName != null ) if (newManifestFileName != null)
{ {
byte[] expectedChecksum, currentChecksum; byte[] expectedChecksum, currentChecksum;
@ -677,271 +711,305 @@ namespace DepotDownloader
} }
} }
if ( newProtoManifest != null ) if (newProtoManifest != null)
{ {
Console.WriteLine( "Already have manifest {0} for depot {1}.", depot.manifestId, depot.id ); Console.WriteLine("Already have manifest {0} for depot {1}.", depot.manifestId, depot.id);
} }
else else
{ {
Console.Write( "Downloading depot manifest..." ); Console.Write("Downloading depot manifest...");
DepotManifest depotManifest = null; DepotManifest depotManifest = null;
while ( depotManifest == null ) do
{ {
Tuple<CDNClient.Server, string> connection = null; cts.Token.ThrowIfCancellationRequested();
CDNClient.Server connection = null;
try try
{ {
connection = await cdnPool.GetConnectionForDepot( appId, depot.id, CancellationToken.None ); connection = cdnPool.GetConnection(cts.Token);
var cdnToken = await cdnPool.AuthenticateConnection(appId, depot.id, connection);
depotManifest = await cdnPool.CDNClient.DownloadManifestAsync( depot.id, depot.manifestId, depotManifest = await cdnPool.CDNClient.DownloadManifestAsync(depot.id, depot.manifestId,
connection.Item1, connection.Item2, depot.depotKey ).ConfigureAwait(false); connection, cdnToken, depot.depotKey).ConfigureAwait(false);
cdnPool.ReturnConnection( connection ); cdnPool.ReturnConnection(connection);
} }
catch ( SteamKitWebRequestException e ) catch (TaskCanceledException)
{
Console.WriteLine("Connection timeout downloading depot manifest {0} {1}", depot.id, depot.manifestId);
}
catch (SteamKitWebRequestException e)
{ {
cdnPool.ReturnBrokenConnection( connection ); cdnPool.ReturnBrokenConnection(connection);
if ( e.StatusCode == HttpStatusCode.Unauthorized || e.StatusCode == HttpStatusCode.Forbidden ) if (e.StatusCode == HttpStatusCode.Unauthorized || e.StatusCode == HttpStatusCode.Forbidden)
{ {
Console.WriteLine( "Encountered 401 for depot manifest {0} {1}. Aborting.", depot.id, depot.manifestId ); Console.WriteLine("Encountered 401 for depot manifest {0} {1}. Aborting.", depot.id, depot.manifestId);
break; break;
} }
else else
{ {
Console.WriteLine( "Encountered error downloading depot manifest {0} {1}: {2}", depot.id, depot.manifestId, e.StatusCode ); Console.WriteLine("Encountered error downloading depot manifest {0} {1}: {2}", depot.id, depot.manifestId, e.StatusCode);
} }
} }
catch ( Exception e ) catch (OperationCanceledException)
{ {
cdnPool.ReturnBrokenConnection( connection ); break;
Console.WriteLine( "Encountered error downloading manifest for depot {0} {1}: {2}", depot.id, depot.manifestId, e.Message ); }
catch (Exception e)
{
cdnPool.ReturnBrokenConnection(connection);
Console.WriteLine("Encountered error downloading manifest for depot {0} {1}: {2}", depot.id, depot.manifestId, e.Message);
} }
} }
while (depotManifest == null);
if ( depotManifest == null ) if (depotManifest == null)
{ {
Console.WriteLine( "\nUnable to download manifest {0} for depot {1}", depot.manifestId, depot.id ); Console.WriteLine("\nUnable to download manifest {0} for depot {1}", depot.manifestId, depot.id);
return; cts.Cancel();
} }
// Throw the cancellation exception if requested so that this task is marked failed
cts.Token.ThrowIfCancellationRequested();
byte[] checksum; byte[] checksum;
newProtoManifest = new ProtoManifest( depotManifest, depot.manifestId ); newProtoManifest = new ProtoManifest(depotManifest, depot.manifestId);
newProtoManifest.SaveToFile( newManifestFileName, out checksum ); newProtoManifest.SaveToFile(newManifestFileName, out checksum);
File.WriteAllBytes( newManifestFileName + ".sha", checksum ); File.WriteAllBytes(newManifestFileName + ".sha", checksum);
Console.WriteLine( " Done!" ); Console.WriteLine(" Done!");
} }
} }
newProtoManifest.Files.Sort( ( x, y ) => string.Compare( x.FileName, y.FileName, StringComparison.Ordinal ) ); newProtoManifest.Files.Sort((x, y) => string.Compare(x.FileName, y.FileName, StringComparison.Ordinal));
Console.WriteLine( "Manifest {0} ({1})", depot.manifestId, newProtoManifest.CreationTime ); Console.WriteLine("Manifest {0} ({1})", depot.manifestId, newProtoManifest.CreationTime);
if ( Config.DownloadManifestOnly ) if (Config.DownloadManifestOnly)
{ {
StringBuilder manifestBuilder = new StringBuilder(); StringBuilder manifestBuilder = new StringBuilder();
string txtManifest = Path.Combine( depot.installDir, string.Format( "manifest_{0}_{1}.txt", depot.id, depot.manifestId ) ); string txtManifest = Path.Combine(depot.installDir, string.Format("manifest_{0}_{1}.txt", depot.id, depot.manifestId));
manifestBuilder.Append( string.Format( "{0}\n\n", newProtoManifest.CreationTime ) ); manifestBuilder.Append(string.Format("{0}\n\n", newProtoManifest.CreationTime));
foreach ( var file in newProtoManifest.Files ) foreach (var file in newProtoManifest.Files)
{ {
if ( file.Flags.HasFlag( EDepotFileFlag.Directory ) ) if (file.Flags.HasFlag(EDepotFileFlag.Directory))
continue; continue;
manifestBuilder.Append( string.Format( "{0}\n", file.FileName ) ); manifestBuilder.Append(string.Format("{0}\n", file.FileName));
manifestBuilder.Append( string.Format( "\t{0}\n", file.TotalSize ) ); manifestBuilder.Append(string.Format("\t{0}\n", file.TotalSize));
manifestBuilder.Append( string.Format( "\t{0}\n", BitConverter.ToString( file.FileHash ).Replace( "-", "" ) ) ); manifestBuilder.Append(string.Format("\t{0}\n", BitConverter.ToString(file.FileHash).Replace("-", "")));
} }
File.WriteAllText( txtManifest, manifestBuilder.ToString() ); File.WriteAllText(txtManifest, manifestBuilder.ToString());
continue; return;
} }
ulong complete_download_size = 0; string stagingDir = Path.Combine(depot.installDir, STAGING_DIR);
ulong size_downloaded = 0;
string stagingDir = Path.Combine( depot.installDir, STAGING_DIR );
var filesAfterExclusions = newProtoManifest.Files.AsParallel().Where( f => TestIsFileIncluded( f.FileName ) ).ToList(); var filesAfterExclusions = newProtoManifest.Files.AsParallel().Where(f => TestIsFileIncluded(f.FileName)).ToList();
// Pre-process // Pre-process
filesAfterExclusions.ForEach( file => filesAfterExclusions.ForEach(file =>
{ {
var fileFinalPath = Path.Combine( depot.installDir, file.FileName ); var fileFinalPath = Path.Combine(depot.installDir, file.FileName);
var fileStagingPath = Path.Combine( stagingDir, file.FileName ); var fileStagingPath = Path.Combine(stagingDir, file.FileName);
if ( file.Flags.HasFlag( EDepotFileFlag.Directory ) ) if (file.Flags.HasFlag(EDepotFileFlag.Directory))
{ {
Directory.CreateDirectory( fileFinalPath ); Directory.CreateDirectory(fileFinalPath);
Directory.CreateDirectory( fileStagingPath ); Directory.CreateDirectory(fileStagingPath);
} }
else else
{ {
// Some manifests don't explicitly include all necessary directories // Some manifests don't explicitly include all necessary directories
Directory.CreateDirectory( Path.GetDirectoryName( fileFinalPath ) ); Directory.CreateDirectory(Path.GetDirectoryName(fileFinalPath));
Directory.CreateDirectory( Path.GetDirectoryName( fileStagingPath ) ); Directory.CreateDirectory(Path.GetDirectoryName(fileStagingPath));
complete_download_size += file.TotalSize; Interlocked.Add(ref depotCounter.CompleteDownloadSize, (long)file.TotalSize);
} }
} ); });
var ioSemaphore = new SemaphoreSlim( Config.MaxDownloads ); var files = filesAfterExclusions.Where(f => !f.Flags.HasFlag(EDepotFileFlag.Directory)).ToArray();
var networkSemaphore = new SemaphoreSlim( Config.MaxDownloads ); var networkChunkQueue = new ConcurrentQueue<Tuple<FileStreamData, ProtoManifest.FileData, ProtoManifest.ChunkData>>();
var files = filesAfterExclusions.Where( f => !f.Flags.HasFlag( EDepotFileFlag.Directory ) ).ToArray(); await Util.InvokeAsync(
var ioTasks = new Task[ files.Length ]; files.Select(f => new Func<Task>(async () =>
for ( var i = 0; i < files.Length; i++ ) await Task.Run(() => DownloadSteam3AsyncDepotFile(cts, downloadCounter, depotCounter,
{ stagingDir, oldProtoManifest, newProtoManifest,
var file = files[ i ]; networkChunkQueue,
var task = Task.Run( async () => appId, depot, f)))),
{ maxDegreeOfParallelism: Config.MaxDownloads
cts.Token.ThrowIfCancellationRequested(); );
try await Util.InvokeAsync(
networkChunkQueue.Select((x) => new Func<Task>(async () =>
await Task.Run(() => DownloadSteam3AsyncDepotFileChunk(cts, downloadCounter, depotCounter,
stagingDir, oldProtoManifest, newProtoManifest,
appId, depot, x.Item2, x.Item1, x.Item3)))),
maxDegreeOfParallelism: Config.MaxDownloads
);
DepotConfigStore.Instance.InstalledManifestIDs[depot.id] = depot.manifestId;
DepotConfigStore.Save();
Console.WriteLine("Depot {0} - Downloaded {1} bytes ({2} bytes uncompressed)", depot.id, depotCounter.DepotBytesCompressed, depotCounter.DepotBytesUncompressed);
}
private static void DownloadSteam3AsyncDepotFile(
CancellationTokenSource cts,
GlobalDownloadCounter downloadCounter, DepotDownloadCounter depotDownloadCounter,
string stagingDir, ProtoManifest oldProtoManifest, ProtoManifest newProtoManifest,
ConcurrentQueue<Tuple<FileStreamData, ProtoManifest.FileData, ProtoManifest.ChunkData>> networkChunkQueue,
uint appId, DepotDownloadInfo depot, ProtoManifest.FileData file)
{ {
await ioSemaphore.WaitAsync().ConfigureAwait( false );
cts.Token.ThrowIfCancellationRequested(); cts.Token.ThrowIfCancellationRequested();
string fileFinalPath = Path.Combine( depot.installDir, file.FileName ); string fileFinalPath = Path.Combine(depot.installDir, file.FileName);
string fileStagingPath = Path.Combine( stagingDir, file.FileName ); string fileStagingPath = Path.Combine(stagingDir, file.FileName);
// This may still exist if the previous run exited before cleanup // This may still exist if the previous run exited before cleanup
if ( File.Exists( fileStagingPath ) ) if (File.Exists(fileStagingPath))
{ {
File.Delete( fileStagingPath ); File.Delete(fileStagingPath);
} }
FileStream fs = null; FileStream fs = null;
List<ProtoManifest.ChunkData> neededChunks; List<ProtoManifest.ChunkData> neededChunks;
FileInfo fi = new FileInfo( fileFinalPath ); FileInfo fi = new FileInfo(fileFinalPath);
if ( !fi.Exists ) if (!fi.Exists)
{ {
// create new file. need all chunks // create new file. need all chunks
fs = File.Create( fileFinalPath ); fs = File.Create(fileFinalPath);
fs.SetLength( ( long )file.TotalSize ); fs.SetLength((long)file.TotalSize);
neededChunks = new List<ProtoManifest.ChunkData>( file.Chunks ); neededChunks = new List<ProtoManifest.ChunkData>(file.Chunks);
} }
else else
{ {
// open existing // open existing
ProtoManifest.FileData oldManifestFile = null; ProtoManifest.FileData oldManifestFile = null;
if ( oldProtoManifest != null ) if (oldProtoManifest != null)
{ {
oldManifestFile = oldProtoManifest.Files.SingleOrDefault( f => f.FileName == file.FileName ); oldManifestFile = oldProtoManifest.Files.SingleOrDefault(f => f.FileName == file.FileName);
} }
if ( oldManifestFile != null ) if (oldManifestFile != null)
{ {
neededChunks = new List<ProtoManifest.ChunkData>(); neededChunks = new List<ProtoManifest.ChunkData>();
if ( Config.VerifyAll || !oldManifestFile.FileHash.SequenceEqual( file.FileHash ) ) if (Config.VerifyAll || !oldManifestFile.FileHash.SequenceEqual(file.FileHash))
{ {
// we have a version of this file, but it doesn't fully match what we want // we have a version of this file, but it doesn't fully match what we want
if (Config.VerifyAll)
{
Console.WriteLine("Validating {0}", fileFinalPath);
}
var matchingChunks = new List<ChunkMatch>(); var matchingChunks = new List<ChunkMatch>();
foreach ( var chunk in file.Chunks ) foreach (var chunk in file.Chunks)
{ {
var oldChunk = oldManifestFile.Chunks.FirstOrDefault( c => c.ChunkID.SequenceEqual( chunk.ChunkID ) ); var oldChunk = oldManifestFile.Chunks.FirstOrDefault(c => c.ChunkID.SequenceEqual(chunk.ChunkID));
if ( oldChunk != null ) if (oldChunk != null)
{ {
matchingChunks.Add( new ChunkMatch( oldChunk, chunk ) ); matchingChunks.Add(new ChunkMatch(oldChunk, chunk));
} }
else else
{ {
neededChunks.Add( chunk ); neededChunks.Add(chunk);
} }
} }
File.Move( fileFinalPath, fileStagingPath ); var orderedChunks = matchingChunks.OrderBy(x => x.OldChunk.Offset);
File.Move(fileFinalPath, fileStagingPath);
fs = File.Open( fileFinalPath, FileMode.Create ); fs = File.Open(fileFinalPath, FileMode.Create);
fs.SetLength( ( long )file.TotalSize ); fs.SetLength((long)file.TotalSize);
using ( var fsOld = File.Open( fileStagingPath, FileMode.Open ) ) using (var fsOld = File.Open(fileStagingPath, FileMode.Open))
{ {
foreach ( var match in matchingChunks ) foreach (var match in orderedChunks)
{ {
fsOld.Seek( ( long )match.OldChunk.Offset, SeekOrigin.Begin ); fsOld.Seek((long)match.OldChunk.Offset, SeekOrigin.Begin);
byte[] tmp = new byte[ match.OldChunk.UncompressedLength ]; byte[] tmp = new byte[match.OldChunk.UncompressedLength];
fsOld.Read( tmp, 0, tmp.Length ); fsOld.Read(tmp, 0, tmp.Length);
byte[] adler = Util.AdlerHash( tmp ); byte[] adler = Util.AdlerHash(tmp);
if ( !adler.SequenceEqual( match.OldChunk.Checksum ) ) if (!adler.SequenceEqual(match.OldChunk.Checksum))
{ {
neededChunks.Add( match.NewChunk ); neededChunks.Add(match.NewChunk);
} }
else else
{ {
fs.Seek( ( long )match.NewChunk.Offset, SeekOrigin.Begin ); fs.Seek((long)match.NewChunk.Offset, SeekOrigin.Begin);
fs.Write( tmp, 0, tmp.Length ); fs.Write(tmp, 0, tmp.Length);
} }
} }
} }
File.Delete( fileStagingPath ); File.Delete(fileStagingPath);
} }
} }
else else
{ {
// No old manifest or file not in old manifest. We must validate. // No old manifest or file not in old manifest. We must validate.
fs = File.Open( fileFinalPath, FileMode.Open ); fs = File.Open(fileFinalPath, FileMode.Open);
if ( ( ulong )fi.Length != file.TotalSize ) if ((ulong)fi.Length != file.TotalSize)
{ {
fs.SetLength( ( long )file.TotalSize ); fs.SetLength((long)file.TotalSize);
} }
neededChunks = Util.ValidateSteam3FileChecksums( fs, file.Chunks.OrderBy( x => x.Offset ).ToArray() ); Console.WriteLine("Validating {0}", fileFinalPath);
neededChunks = Util.ValidateSteam3FileChecksums(fs, file.Chunks.OrderBy(x => x.Offset).ToArray());
} }
if ( neededChunks.Count() == 0 ) if (neededChunks.Count() == 0)
{ {
size_downloaded += file.TotalSize; var sizeDownloaded = Interlocked.Add(ref depotDownloadCounter.SizeDownloaded, (long)file.TotalSize);
Console.WriteLine( "{0,6:#00.00}% {1}", ( ( float )size_downloaded / ( float )complete_download_size ) * 100.0f, fileFinalPath ); Console.WriteLine("{0,6:#00.00}% {1}", ((float)sizeDownloaded / (float)depotDownloadCounter.CompleteDownloadSize) * 100.0f, fileFinalPath);
if ( fs != null )
if (fs != null)
fs.Dispose(); fs.Dispose();
return; return;
} }
else else
{ {
size_downloaded += ( file.TotalSize - ( ulong )neededChunks.Select( x => ( long )x.UncompressedLength ).Sum() ); var sizeDownloaded = ((long)file.TotalSize - (long)neededChunks.Select(x => (long)x.UncompressedLength).Sum());
Interlocked.Add(ref depotDownloadCounter.SizeDownloaded, sizeDownloaded);
} }
} }
var fileSemaphore = new SemaphoreSlim(1); FileStreamData fileStreamData = new FileStreamData
var downloadTasks = new Task<DepotManifest.ChunkData>[neededChunks.Count];
for (var x = 0; x < neededChunks.Count; x++)
{ {
var chunk = neededChunks[x]; fileStream = fs,
fileLock = new SemaphoreSlim(1),
chunksToDownload = neededChunks.Count
};
var downloadTask = Task.Run(async () => foreach (var chunk in neededChunks)
{ {
cts.Token.ThrowIfCancellationRequested(); networkChunkQueue.Enqueue(Tuple.Create(fileStreamData, file, chunk));
}
}
try private static async Task DownloadSteam3AsyncDepotFileChunk(
CancellationTokenSource cts,
GlobalDownloadCounter downloadCounter, DepotDownloadCounter depotDownloadCounter,
string stagingDir, ProtoManifest oldProtoManifest, ProtoManifest newProtoManifest,
uint appId, DepotDownloadInfo depot, ProtoManifest.FileData file,
FileStreamData fileStreamData, ProtoManifest.ChunkData chunk)
{ {
await networkSemaphore.WaitAsync().ConfigureAwait(false);
cts.Token.ThrowIfCancellationRequested(); cts.Token.ThrowIfCancellationRequested();
string chunkID = Util.EncodeHexString(chunk.ChunkID); string chunkID = Util.EncodeHexString(chunk.ChunkID);
CDNClient.DepotChunk chunkData = null;
while (!cts.IsCancellationRequested)
{
Tuple<CDNClient.Server, string> connection;
try
{
connection = await cdnPool.GetConnectionForDepot(appId, depot.id, cts.Token);
}
catch (OperationCanceledException)
{
break;
}
DepotManifest.ChunkData data = new DepotManifest.ChunkData(); DepotManifest.ChunkData data = new DepotManifest.ChunkData();
data.ChunkID = chunk.ChunkID; data.ChunkID = chunk.ChunkID;
@ -950,13 +1018,27 @@ namespace DepotDownloader
data.CompressedLength = chunk.CompressedLength; data.CompressedLength = chunk.CompressedLength;
data.UncompressedLength = chunk.UncompressedLength; data.UncompressedLength = chunk.UncompressedLength;
CDNClient.DepotChunk chunkData = null;
do
{
cts.Token.ThrowIfCancellationRequested();
CDNClient.Server connection = null;
try try
{ {
connection = cdnPool.GetConnection(cts.Token);
var cdnToken = await cdnPool.AuthenticateConnection(appId, depot.id, connection);
chunkData = await cdnPool.CDNClient.DownloadDepotChunkAsync(depot.id, data, chunkData = await cdnPool.CDNClient.DownloadDepotChunkAsync(depot.id, data,
connection.Item1, connection.Item2, depot.depotKey).ConfigureAwait(false); connection, cdnToken, depot.depotKey).ConfigureAwait(false);
cdnPool.ReturnConnection(connection);
break; cdnPool.ReturnConnection(connection);
}
catch (TaskCanceledException)
{
Console.WriteLine("Connection timeout downloading chunk {0}", chunkID);
} }
catch (SteamKitWebRequestException e) catch (SteamKitWebRequestException e)
{ {
@ -965,7 +1047,6 @@ namespace DepotDownloader
if (e.StatusCode == HttpStatusCode.Unauthorized || e.StatusCode == HttpStatusCode.Forbidden) if (e.StatusCode == HttpStatusCode.Unauthorized || e.StatusCode == HttpStatusCode.Forbidden)
{ {
Console.WriteLine("Encountered 401 for chunk {0}. Aborting.", chunkID); Console.WriteLine("Encountered 401 for chunk {0}. Aborting.", chunkID);
cts.Cancel();
break; break;
} }
else else
@ -973,9 +1054,9 @@ namespace DepotDownloader
Console.WriteLine("Encountered error downloading chunk {0}: {1}", chunkID, e.StatusCode); Console.WriteLine("Encountered error downloading chunk {0}: {1}", chunkID, e.StatusCode);
} }
} }
catch (TaskCanceledException) catch (OperationCanceledException)
{ {
Console.WriteLine("Connection timeout downloading chunk {0}", chunkID); break;
} }
catch (Exception e) catch (Exception e)
{ {
@ -983,6 +1064,7 @@ namespace DepotDownloader
Console.WriteLine("Encountered unexpected error downloading chunk {0}: {1}", chunkID, e.Message); Console.WriteLine("Encountered unexpected error downloading chunk {0}: {1}", chunkID, e.Message);
} }
} }
while (chunkData == null);
if (chunkData == null) if (chunkData == null)
{ {
@ -995,61 +1077,34 @@ namespace DepotDownloader
try try
{ {
await fileSemaphore.WaitAsync().ConfigureAwait(false); await fileStreamData.fileLock.WaitAsync().ConfigureAwait(false);
fs.Seek((long)chunkData.ChunkInfo.Offset, SeekOrigin.Begin);
fs.Write(chunkData.Data, 0, chunkData.Data.Length);
return chunkData.ChunkInfo;
}
finally
{
fileSemaphore.Release();
}
}
finally
{
networkSemaphore.Release();
}
});
downloadTasks[x] = downloadTask;
}
var completedDownloads = await Task.WhenAll(downloadTasks).ConfigureAwait(false);
fs.Dispose(); fileStreamData.fileStream.Seek((long)chunkData.ChunkInfo.Offset, SeekOrigin.Begin);
fileStreamData.fileStream.Write(chunkData.Data, 0, chunkData.Data.Length);
foreach (var chunkInfo in completedDownloads)
{
TotalBytesCompressed += chunkInfo.CompressedLength;
DepotBytesCompressed += chunkInfo.CompressedLength;
TotalBytesUncompressed += chunkInfo.UncompressedLength;
DepotBytesUncompressed += chunkInfo.UncompressedLength;
size_downloaded += chunkInfo.UncompressedLength;
}
Console.WriteLine("{0,6:#00.00}% {1}", ((float)size_downloaded / (float)complete_download_size) * 100.0f, fileFinalPath);
} }
finally finally
{ {
ioSemaphore.Release(); fileStreamData.fileLock.Release();
} }
} );
ioTasks[ i ] = task; Interlocked.Add(ref downloadCounter.TotalBytesCompressed, chunk.CompressedLength);
} Interlocked.Add(ref depotDownloadCounter.DepotBytesCompressed, chunk.CompressedLength);
Interlocked.Add(ref downloadCounter.TotalBytesUncompressed, chunk.UncompressedLength);
Interlocked.Add(ref depotDownloadCounter.DepotBytesUncompressed, chunk.UncompressedLength);
await Task.WhenAll(ioTasks).ConfigureAwait(false); var sizeDownloaded = Interlocked.Add(ref depotDownloadCounter.SizeDownloaded, chunkData.Data.Length);
DepotConfigStore.Instance.InstalledManifestIDs[ depot.id ] = depot.manifestId; int remainingChunks = Interlocked.Decrement(ref fileStreamData.chunksToDownload);
DepotConfigStore.Save(); if (remainingChunks == 0)
{
fileStreamData.fileStream.Dispose();
Console.WriteLine( "Depot {0} - Downloaded {1} bytes ({2} bytes uncompressed)", depot.id, DepotBytesCompressed, DepotBytesUncompressed ); var fileFinalPath = Path.Combine(depot.installDir, file.FileName);
Console.WriteLine("{0,6:#00.00}% {1}", ((float)sizeDownloaded / (float)depotDownloadCounter.CompleteDownloadSize) * 100.0f, fileFinalPath);
} }
Console.WriteLine( "Total downloaded: {0} bytes ({1} bytes uncompressed) from {2} depots", TotalBytesCompressed, TotalBytesUncompressed, depots.Count );
} }
} }
} }

@ -5,6 +5,7 @@ using System.Linq;
using System.Runtime.InteropServices; using System.Runtime.InteropServices;
using System.Security.Cryptography; using System.Security.Cryptography;
using System.Text; using System.Text;
using System.Threading.Tasks;
namespace DepotDownloader namespace DepotDownloader
{ {
@ -132,5 +133,38 @@ namespace DepotDownloader
( sb, v ) => sb.Append( v.ToString( "x2" ) ) ( sb, v ) => sb.Append( v.ToString( "x2" ) )
).ToString(); ).ToString();
} }
public static async Task InvokeAsync(IEnumerable<Func<Task>> taskFactories, int maxDegreeOfParallelism)
{
if (taskFactories == null) throw new ArgumentNullException(nameof(taskFactories));
if (maxDegreeOfParallelism <= 0) throw new ArgumentException(nameof(maxDegreeOfParallelism));
Func<Task>[] queue = taskFactories.ToArray();
if (queue.Length == 0)
{
return;
}
List<Task> tasksInFlight = new List<Task>(maxDegreeOfParallelism);
int index = 0;
do
{
while (tasksInFlight.Count < maxDegreeOfParallelism && index < queue.Length)
{
Func<Task> taskFactory = queue[index++];
tasksInFlight.Add(taskFactory());
}
Task completedTask = await Task.WhenAny(tasksInFlight).ConfigureAwait(false);
await completedTask.ConfigureAwait(false);
tasksInFlight.Remove(completedTask);
}
while (index < queue.Length || tasksInFlight.Count != 0);
}
} }
} }

Loading…
Cancel
Save