Implementing CountDownLatch Functionality in Go Inspired by Java

Posted on Apr 9, 2021

When working with concurrent applications, synchronization primitives are essential tools. Go provides sync.WaitGroup, but its Wait method blocks until the count reaches zero and offers no timeout. Java’s CountDownLatch supports a timed wait through await(timeout, unit), so let’s implement something similar in Go.

The Implementation

Our CountDownLatch combines Go’s sync.WaitGroup with atomic operations for thread-safe counting:

type CountDownLatch struct {
   sync.WaitGroup
   counter uint64
}

The struct embeds sync.WaitGroup and adds a counter field using uint64 to store two 32-bit counters in one atomic value.

Key Components

  1. Counter Management
func splitUint64For(c uint64) (uint32, uint32) {
   return uint32(c >> 32), uint32(c)
}

func combineUInt64For(x, y uint32) uint64 {
   return uint64(x)<<32 | uint64(y)
}

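These functions pack two 32-bit values into a single 64-bit value and unpack them again:

  • High 32 bits store the active count, which is what Done decrements and what a timeout resets to zero
  • Low 32 bits hold a parallel count that a timeout leaves untouched, so a non-zero value at the end of Await signals that the latch timed out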
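As a quick sanity check of the packing scheme, the sketch below reuses the two helpers from above and round-trips a pair of counts. The values 3 and 3 are arbitrary and chosen only for illustration.

```go
package main

import "fmt"

// splitUint64For returns the high and low 32-bit halves of c.
func splitUint64For(c uint64) (uint32, uint32) {
	return uint32(c >> 32), uint32(c)
}

// combineUInt64For packs x into the high half and y into the low half.
func combineUInt64For(x, y uint32) uint64 {
	return uint64(x)<<32 | uint64(y)
}

func main() {
	c := combineUInt64For(3, 3)
	fmt.Printf("%016x\n", c) // 0000000300000003

	hi, lo := splitUint64For(c)
	fmt.Println(hi, lo) // 3 3
}
```

Because both halves live in one uint64, a single atomic load always sees a matching pair of counts.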
  2. Adding Tasks
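// Add registers delta pending tasks. delta must be non-negative: a negative
// value wraps in the uint32 conversion and desynchronizes the two halves.
func (cdl *CountDownLatch) Add(delta int) {
   cdl.WaitGroup.Add(delta)
   d := uint32(delta)
   // One atomic add raises the high and low halves together.
   atomic.AddUint64(&cdl.counter, combineUInt64For(d, d))
}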

The Add method:

  • Updates the embedded WaitGroup’s count
  • Atomically adds delta to both halves of the counter
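To see why one atomic add is enough for both halves, here is a small sketch of the arithmetic. The helper addBoth is hypothetical and exists only for this illustration; it mirrors the counter update inside Add without the WaitGroup. The last call shows what happens with a negative delta, which is why Add should only be called with non-negative values.

```go
package main

import (
	"fmt"
	"sync/atomic"
)

func splitUint64For(c uint64) (uint32, uint32) {
	return uint32(c >> 32), uint32(c)
}

func combineUInt64For(x, y uint32) uint64 {
	return uint64(x)<<32 | uint64(y)
}

// addBoth (hypothetical helper) applies the same update as Add's counter
// step and returns the resulting high and low halves.
func addBoth(counter *uint64, delta int) (uint32, uint32) {
	d := uint32(delta)
	return splitUint64For(atomic.AddUint64(counter, combineUInt64For(d, d)))
}

func main() {
	var counter uint64

	hi, lo := addBoth(&counter, 2)
	fmt.Println(hi, lo) // 2 2

	hi, lo = addBoth(&counter, 3)
	fmt.Println(hi, lo) // 5 5

	// uint32(-1) is 0xFFFFFFFF, so the packed delta is all ones, which is
	// the same as subtracting 1 from the whole 64-bit word: only the low
	// half drops and the halves no longer match.
	hi, lo = addBoth(&counter, -1)
	fmt.Println(hi, lo) // 5 4
}
```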
  3. Completing Tasks
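// Done marks one task as finished. It does nothing once the high half is zero,
// which is the case after a timeout or when every task has already completed.
func (cdl *CountDownLatch) Done() {
   c := atomic.LoadUint64(&cdl.counter)
   hc, lc := splitUint64For(c)
   for hc > 0 {
      // Decrement both halves in one CAS; on success, release the WaitGroup.
      if atomic.CompareAndSwapUint64(&cdl.counter, c, combineUInt64For(hc-1, lc-1)) {
         cdl.WaitGroup.Done()
         return
      }
      // Lost a race with another goroutine: reload and retry.
      c = atomic.LoadUint64(&cdl.counter)
      hc, lc = splitUint64For(c)
   }
}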

The Done method:

  • Atomically loads current counter
  • Uses CAS operations for thread-safe decrements
  • Updates WaitGroup when successful
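The load, compare-and-swap, retry pattern used by Done can be shown in isolation on a plain counter. The sketch below is not the latch itself; decrementIfPositive is a hypothetical helper that only decrements while the counter is above zero, which is the same guard Done applies to the high half.

```go
package main

import (
	"fmt"
	"sync"
	"sync/atomic"
)

// decrementIfPositive (hypothetical helper) lowers *counter by one unless it
// is already zero. It reports whether a decrement happened.
func decrementIfPositive(counter *uint64) bool {
	for {
		old := atomic.LoadUint64(counter)
		if old == 0 {
			return false // nothing left to decrement
		}
		if atomic.CompareAndSwapUint64(counter, old, old-1) {
			return true
		}
		// Another goroutine changed the counter first; reload and retry.
	}
}

func main() {
	counter := uint64(100)
	var succeeded int64
	var wg sync.WaitGroup

	// 150 goroutines compete for 100 available decrements.
	for i := 0; i < 150; i++ {
		wg.Add(1)
		go func() {
			defer wg.Done()
			if decrementIfPositive(&counter) {
				atomic.AddInt64(&succeeded, 1)
			}
		}()
	}
	wg.Wait()

	fmt.Println(atomic.LoadUint64(&counter), atomic.LoadInt64(&succeeded)) // 0 100
}
```

The counter never goes below zero, because a decrement is only committed if the value is still the one that was checked.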
  4. Timeout Support
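func (cdl *CountDownLatch) Await(t time.Duration) bool {
   // On timeout, zero the high half and release the WaitGroup by the same amount.
   // The low half is left as-is so the final check can detect the timeout.
   timer := time.AfterFunc(t, func() {
      c := atomic.LoadUint64(&cdl.counter)
      hc, lc := splitUint64For(c)
      for hc > 0 {
         if atomic.CompareAndSwapUint64(&cdl.counter, c, combineUInt64For(0, lc)) {
            cdl.WaitGroup.Add(-int(hc))
            break
         }
         c = atomic.LoadUint64(&cdl.counter)
         hc, lc = splitUint64For(c)
      }
   })
   // Stop the timer so it cannot fire after Await returns and disturb a reused latch.
   defer timer.Stop()
   cdl.Wait()
   return atomic.LoadUint64(&cdl.counter) == 0
}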

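The Await method:

  • Schedules a timeout handler with time.AfterFunc
  • On timeout, atomically resets the high half to zero and releases the WaitGroup by the same amount
  • Returns true if the count reached zero before the timeout, false otherwise

Note that a timeout leaves the low half non-zero, so a latch that has timed out keeps reporting false from Await. Create a new latch rather than reusing one that has timed out.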
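For comparison, a different and commonly seen way to wait on a plain sync.WaitGroup with a timeout is to run Wait in a helper goroutine and select on a channel. This is not the approach taken by the latch above; waitTimeout and demo are hypothetical names used only in this sketch. The trade-off is that, on timeout, the helper goroutine stays blocked until the WaitGroup eventually reaches zero.

```go
package main

import (
	"fmt"
	"sync"
	"time"
)

// waitTimeout (hypothetical helper) reports true if wg finished within d
// and false if the timeout expired first.
func waitTimeout(wg *sync.WaitGroup, d time.Duration) bool {
	done := make(chan struct{})
	go func() {
		wg.Wait()
		close(done)
	}()

	timer := time.NewTimer(d)
	defer timer.Stop()

	select {
	case <-done:
		return true
	case <-timer.C:
		return false
	}
}

// demo runs one wait that completes in time and one that times out.
func demo() (first, second bool) {
	var wg sync.WaitGroup
	wg.Add(1)
	go func() {
		defer wg.Done()
		time.Sleep(10 * time.Millisecond)
	}()
	first = waitTimeout(&wg, time.Second)

	var stuck sync.WaitGroup
	stuck.Add(1) // nobody calls Done before the timeout
	second = waitTimeout(&stuck, 20*time.Millisecond)
	stuck.Done() // lets the helper goroutine inside waitTimeout exit

	return first, second
}

func main() {
	fmt.Println(demo()) // true false
}
```

The latch in this article avoids the extra goroutine by forcing the WaitGroup to zero from the timer callback instead.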

Performance Characteristics

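  1. Lock-Free Operations
  • The latch’s own counter uses atomic operations rather than a mutex
  • A failed CAS simply reloads and retries, so no goroutine holds a lock while others wait
  • Handles large numbers of goroutines (the test below uses 10,000)
  2. Memory Efficiency
  • Packs two counters into a single 64-bit word
  • Adds only one uint64 on top of the embedded WaitGroup
  • Both counters are read and updated with one atomic operation
  3. Timeout Handling
  • The timeout handler runs from a timer callback and never blocks
  • A timeout releases every goroutine blocked in Await
  • The high half and the WaitGroup count are always adjusted by the same amount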

Usage Example

import (
    "sync/atomic"
    "testing"
    "time"
)

func TestCountDownLatch(t *testing.T) {
    t.Run("Basic functionality", func(t *testing.T) {
        cdl := &CountDownLatch{}
        cdl.Add(1)

        if cdl.Await(time.Millisecond * 100) {
            t.Error("Should timeout when count not zero")
        }

        // A timed-out latch keeps its low counter, and Done is a no-op once
        // the high counter is zero, so check the success path on a fresh latch.
        cdl = &CountDownLatch{}
        cdl.Add(1)
        cdl.Done()
        if !cdl.Await(time.Millisecond * 100) {
            t.Error("Should complete when count is zero")
        }
    })

    t.Run("High concurrency", func(t *testing.T) {
        cdl := &CountDownLatch{}
        const numGoroutines = 10000
        cdl.Add(numGoroutines)
        
        for i := 0; i < numGoroutines; i++ {
            go func() {
                cdl.Done()
            }()
        }
        
        if !cdl.Await(time.Second * 2) {
            t.Error("Failed to complete in time")
        }
    })

    t.Run("Multiple adds", func(t *testing.T) {
        cdl := &CountDownLatch{}
        cdl.Add(2)
        cdl.Add(3)
        
        initial := atomic.LoadUint64(&cdl.counter)
        hc, lc := splitUint64For(initial)
        if hc != 5 || lc != 5 {
            t.Errorf("Expected count 5,5 got %d,%d", hc, lc)
        }
    })

    t.Run("Timeout behavior", func(t *testing.T) {
        cdl := &CountDownLatch{}
        cdl.Add(5)
        
        start := time.Now()
        success := cdl.Await(time.Millisecond * 100)
        elapsed := time.Since(start)
        
        if success {
            t.Error("Should timeout")
        }
        if elapsed < time.Millisecond*100 {
            t.Error("Returned before timeout")
        }
    })

    t.Run("Zero initial state", func(t *testing.T) {
        cdl := &CountDownLatch{}
        if !cdl.Await(time.Millisecond * 100) {
            t.Error("Should complete immediately when count is 0")
        }
    })

    t.Run("Counter consistency", func(t *testing.T) {
        cdl := &CountDownLatch{}
        cdl.Add(3)
        
        cdl.Done()
        c := atomic.LoadUint64(&cdl.counter)
        hc, lc := splitUint64For(c)
        if hc != 2 || lc != 2 {
            t.Errorf("Expected count 2,2 got %d,%d", hc, lc)
        }
    })

    t.Run("Stress test", func(t *testing.T) {
        cdl := &CountDownLatch{}
        const iterations = 1000
        const goroutines = 10
        
        for i := 0; i < iterations; i++ {
            cdl.Add(goroutines)
            for j := 0; j < goroutines; j++ {
                go func() {
                    time.Sleep(time.Microsecond)
                    cdl.Done()
                }()
            }
            if !cdl.Await(time.Second) {
                t.Fatalf("Failed at iteration %d", i)
            }
        }
    })
}

func BenchmarkCountDownLatch(b *testing.B) {
    b.Run("Sequential", func(b *testing.B) {
        cdl := &CountDownLatch{}
        b.ResetTimer()
        for i := 0; i < b.N; i++ {
            cdl.Add(1)
            cdl.Done()
        }
    })

    b.Run("Concurrent", func(b *testing.B) {
        cdl := &CountDownLatch{}
        b.ResetTimer()
        for i := 0; i < b.N; i++ {
            const n = 100
            cdl.Add(n)
            for j := 0; j < n; j++ {
                go func() {
                    cdl.Done()
                }()
            }
            if !cdl.Await(time.Second) {
                b.Fatal("Await timed out")
            }
        }
    })
}

Best Practices

  1. Always check Await’s return value to handle timeouts
  2. Use defer for Done() calls to prevent leaks
  3. Initialize counter with Add() before spawning goroutines
  4. Consider timeout duration carefully based on workload
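
The sketch below puts these practices together in one small program. The latch is a condensed, lightly adapted copy of the one in this article (shortened helper names, and the timer is stopped when Await returns), repeated only so the example compiles on its own. The runWorkers helper and its durations are assumptions made for illustration.

```go
package main

import (
	"fmt"
	"sync"
	"sync/atomic"
	"time"
)

// CountDownLatch is a condensed copy of the article's latch.
type CountDownLatch struct {
	sync.WaitGroup
	counter uint64
}

func split(c uint64) (uint32, uint32) { return uint32(c >> 32), uint32(c) }
func combine(x, y uint32) uint64      { return uint64(x)<<32 | uint64(y) }

func (cdl *CountDownLatch) Add(delta int) {
	cdl.WaitGroup.Add(delta)
	d := uint32(delta)
	atomic.AddUint64(&cdl.counter, combine(d, d))
}

func (cdl *CountDownLatch) Done() {
	for {
		c := atomic.LoadUint64(&cdl.counter)
		hc, lc := split(c)
		if hc == 0 {
			return // already complete or timed out
		}
		if atomic.CompareAndSwapUint64(&cdl.counter, c, combine(hc-1, lc-1)) {
			cdl.WaitGroup.Done()
			return
		}
	}
}

func (cdl *CountDownLatch) Await(t time.Duration) bool {
	timer := time.AfterFunc(t, func() {
		for {
			c := atomic.LoadUint64(&cdl.counter)
			hc, lc := split(c)
			if hc == 0 {
				return
			}
			if atomic.CompareAndSwapUint64(&cdl.counter, c, combine(0, lc)) {
				cdl.WaitGroup.Add(-int(hc))
				return
			}
		}
	})
	defer timer.Stop()
	cdl.Wait()
	return atomic.LoadUint64(&cdl.counter) == 0
}

// runWorkers (hypothetical helper) starts n workers that each take `work`
// to finish and waits at most `timeout` for all of them.
func runWorkers(n int, work, timeout time.Duration) bool {
	latch := &CountDownLatch{} // fresh latch for every batch
	latch.Add(n)               // register all tasks before spawning goroutines
	for i := 0; i < n; i++ {
		go func() {
			defer latch.Done() // runs on every return path of the worker
			time.Sleep(work)   // stand-in for real work
		}()
	}
	return latch.Await(timeout) // the caller decides how to handle a timeout
}

func main() {
	fmt.Println(runWorkers(4, 5*time.Millisecond, time.Second))           // true
	fmt.Println(runWorkers(4, 200*time.Millisecond, 20*time.Millisecond)) // false
}
```

In the second call the workers are still running when Await returns false; their later Done calls are no-ops because the high counter has already been reset.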

Conclusion

This implementation provides a robust, high-performance CountDownLatch for Go applications. It combines the reliability of WaitGroup with atomic operations for thread-safety and adds crucial timeout support. The use of bit manipulation for counter management makes it both memory-efficient and fast under high concurrency.

While more complex than a basic WaitGroup, the added timeout capability makes it invaluable for real-world applications where operations must complete within specific time constraints.