From d1ee735a4aacb80d9c3c4c34fc4317c6eef6718a Mon Sep 17 00:00:00 2001
From: "brian m. carlson" <bk2204@github.com>
Date: Wed, 28 Aug 2019 21:02:26 +0000
Subject: [PATCH] Avoid deadlock when transfer queue fails

In 1412d6e4 ("Don't fail if we lack objects the server has",
2019-04-30), we changed the code to abort later if a missing object
occurs.  In doing so, we had to consider the case where the transfer
queue aborts early for some reason and ensure that the sync.WaitGroup
does not unnecessarily block due to outstanding objects never getting
processed.

However, the approach we used, which was to explicitly subtract the
number of items we skipped processing, was error-prone and didn't cover
all cases.  Notably, a DNS failure could randomly cause a hang during a
push.  Solve this by creating an abortable wait group type and simply
aborting it if we encounter an error, preventing any deadlocks caused
by miscounting the number of items.
---
 tq/transfer_queue.go | 55 ++++++++++++++++++++++++++++++++++----------
 1 file changed, 43 insertions(+), 12 deletions(-)

diff --git a/tq/transfer_queue.go b/tq/transfer_queue.go
index 89296a646..7d39fe581 100644
--- a/tq/transfer_queue.go
+++ b/tq/transfer_queue.go
@@ -123,6 +123,74 @@ func (b batch) Len() int           { return len(b) }
 func (b batch) Less(i, j int) bool { return b[i].Size < b[j].Size }
 func (b batch) Swap(i, j int)      { b[i], b[j] = b[j], b[i] }
 
+// abortableWaitGroup is a sync.WaitGroup wrapper that can be released early.
+// It mirrors every change made to the underlying wait group in counter so
+// that Abort can cancel whatever is still outstanding and unblock Wait.
+type abortableWaitGroup struct {
+	wq      sync.WaitGroup
+	counter int
+	mu      sync.Mutex
+
+	// aborted is set by Abort. Once set, Add and Done are no-ops so that
+	// late callers cannot drive the underlying counter below zero.
+	aborted bool
+}
+
+// newAbortableWaitQueue returns an abortableWaitGroup with no pending items.
+func newAbortableWaitQueue() *abortableWaitGroup {
+	return &abortableWaitGroup{}
+}
+
+// Add adds delta, which may be negative, to the number of outstanding items.
+// It does nothing once the group has been aborted.
+func (q *abortableWaitGroup) Add(delta int) {
+	q.mu.Lock()
+	defer q.mu.Unlock()
+
+	if q.aborted {
+		return
+	}
+	q.counter += delta
+	q.wq.Add(delta)
+}
+
+// Done marks one outstanding item as finished. It does nothing once the group
+// has been aborted, because Abort already released that item.
+func (q *abortableWaitGroup) Done() {
+	q.mu.Lock()
+	defer q.mu.Unlock()
+
+	if q.aborted {
+		return
+	}
+	q.counter--
+	q.wq.Done()
+}
+
+// Abort cancels every item that is still outstanding so that Wait returns.
+// Calling it more than once is harmless.
+func (q *abortableWaitGroup) Abort() {
+	q.mu.Lock()
+	defer q.mu.Unlock()
+
+	if q.aborted {
+		return
+	}
+	q.aborted = true
+
+	// sync.WaitGroup panics if its counter goes negative, so release exactly
+	// what is outstanding and zero the tally; Add and Done ignore anything
+	// that arrives afterwards instead of applying it a second time.
+	q.wq.Add(-q.counter)
+	q.counter = 0
+}
+
+// Wait blocks until nothing is outstanding, either because every item was
+// marked Done or because the group was aborted.
+func (q *abortableWaitGroup) Wait() {
+	q.wq.Wait()
+}
+
 // TransferQueue organises the wider process of uploading and downloading,
 // including calling the API, passing the actual transfer request to transfer
 // adapters, and dealing with progress, errors and retries.
@@ -150,7 +187,7 @@ type TransferQueue struct {
 	// wait is used to keep track of pending transfers. It is incremented
 	// once per unique OID on Add(), and is decremented when that transfer
 	// is marked as completed or failed, but not retried.
-	wait     sync.WaitGroup
+	wait     *abortableWaitGroup
 	manifest *Manifest
 	rc       *retryCounter
 
@@ -250,6 +287,7 @@ func NewTransferQueue(dir Direction, manifest *Manifest, remote string, options
 		trMutex:   &sync.Mutex{},
 		manifest:  manifest,
 		rc:        newRetryCounter(),
+		wait:      newAbortableWaitQueue(),
 	}
 
 	for _, opt := range options {
@@ -401,8 +439,11 @@ func (q *TransferQueue) collectBatches() {
 		collected, closing = q.collectPendingUntil(done)
 
 		// If we've encountered a serious error here, abort immediately;
-		// don't process further batches.
+		// don't process further batches.  Abort the wait queue so that
+		// we don't deadlock waiting for objects to complete when they
+		// never will.
 		if err != nil {
+			q.wait.Abort()
 			break
 		}
 
@@ -497,11 +538,6 @@ func (q *TransferQueue) enqueueAndCollectRetriesFor(batch batch) (batch, error)
 				}
 			}
 
-			if err != nil && bRes != nil {
-				// Avoid a hang if we return early.
-				q.wait.Add(-len(bRes.Objects))
-			}
-
 			return next, err
 		}
 	}
@@ -521,11 +557,6 @@ func (q *TransferQueue) enqueueAndCollectRetriesFor(batch batch) (batch, error)
 			// missing in that case, since we don't need to upload
 			// it.
 			if o.Missing && len(o.Actions) != 0 {
-				// Indicate that we've handled these objects, in
-				// this case by ignoring them and aborting
-				// early. Failing to do this means we deadlock
-				// on this WaitGroup.
-				q.wait.Add(-len(bRes.Objects))
 				return nil, errors.Errorf("Unable to find source for object %v (try running git lfs fetch --all)", o.Oid)
 			}
 		}