Umit Unal

Blog

CountDownLatch Implementation in Golang

Posted at — Apr 9, 2021

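Go's `sync.WaitGroup` has no `Wait` variant that accepts a timeout, so if you need one you have to implement it yourself. Java's `CountDownLatch` provides exactly this through `await(long timeout, TimeUnit unit)`; the type below mimics it on top of a `WaitGroup`.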

// CountDownLatch is a sync.WaitGroup whose wait can be bounded by a timeout,
// similar to Java's CountDownLatch.await(timeout, unit). The zero value is
// ready to use; like sync.WaitGroup it must not be copied after first use.
//
// All bookkeeping lives in counter, which packs two 32-bit counts into one
// word so that both can be updated by a single atomic operation:
//   - high 32 bits: the count still held by the embedded WaitGroup. It drops
//     to zero either when every Done has arrived or when an Await timeout
//     force-releases the WaitGroup.
//   - low 32 bits: the count as seen by callers. A timeout does not touch it,
//     so it stays non-zero after a timeout.
//
// counter == 0 therefore means "everything finished and no timeout cut the
// wait short". Done calls that arrive after a timeout are ignored, so a latch
// that has timed out cannot report success again.
type CountDownLatch struct {
	// counter is deliberately the first field: sync/atomic's 64-bit functions
	// require 64-bit alignment, which on 386, ARM and 32-bit MIPS is only
	// guaranteed for the first word of an allocated struct.
	counter uint64
	sync.WaitGroup
}

// Add adds delta, which may be negative, to the latch count, mirroring
// sync.WaitGroup.Add. Both packed halves of counter change by delta.
func (cdl *CountDownLatch) Add(delta int) {
	// delta*(2^32+1) == delta<<32 + delta. Unsigned arithmetic wraps modulo
	// 2^64, so for a negative delta adding this value subtracts |delta| from
	// both halves. Packing uint32(delta) into each half only subtracted from
	// the low half.
	packed := uint64(int64(delta)) * (1<<32 + 1)
	if delta < 0 {
		// Update the counter before the WaitGroup so that a Wait unblocked by
		// this call already observes the decremented value in Await.
		atomic.AddUint64(&cdl.counter, packed)
		cdl.WaitGroup.Add(delta)
		return
	}
	// Grow the WaitGroup first: a concurrent Done that sees the new counter
	// value must never drive the WaitGroup negative.
	cdl.WaitGroup.Add(delta)
	atomic.AddUint64(&cdl.counter, packed)
}

// Done decrements the latch count by one. It is a no-op when the WaitGroup
// half of the count is already zero, which includes calls arriving after an
// Await timeout has released the latch.
func (cdl *CountDownLatch) Done() {
	for {
		c := atomic.LoadUint64(&cdl.counter)
		hc, lc := splitUint64For(c)
		if hc == 0 {
			return
		}
		if atomic.CompareAndSwapUint64(&cdl.counter, c, combineUInt64For(hc-1, lc-1)) {
			// The counter is updated before the WaitGroup so that Await sees
			// the final value as soon as Wait returns.
			cdl.WaitGroup.Done()
			return
		}
	}
}

// Await blocks until the count reaches zero or t elapses, whichever happens
// first. It returns true only if every outstanding Done arrived in time.
//
// On timeout it zeroes the WaitGroup half of the count and removes the same
// amount from the embedded WaitGroup. That unblocks this call and any other
// Wait/Await callers, and Await then returns false.
func (cdl *CountDownLatch) Await(t time.Duration) bool {
	timer := time.AfterFunc(t, func() {
		for {
			c := atomic.LoadUint64(&cdl.counter)
			hc, lc := splitUint64For(c)
			if hc == 0 {
				return
			}
			// Keep the low half untouched: a non-zero low half is what makes
			// Await report false after a timeout.
			if atomic.CompareAndSwapUint64(&cdl.counter, c, combineUInt64For(0, lc)) {
				cdl.WaitGroup.Add(-int(hc))
				return
			}
		}
	})
	// Stop the timer once Wait returns. Otherwise a timer left over from an
	// Await that completed early would later release a latch re-armed by Add.
	defer timer.Stop()
	cdl.Wait()
	return atomic.LoadUint64(&cdl.counter) == 0
}

// splitUint64For unpacks c into its high and low 32-bit halves.
func splitUint64For(c uint64) (uint32, uint32) {
	return uint32(c >> 32), uint32(c)
}

// combineUInt64For packs x into the high 32 bits and y into the low 32 bits.
func combineUInt64For(x, y uint32) uint64 {
	return uint64(x)<<32 | uint64(y)
}
Usage example (each greeter sleeps 10 seconds but Await gives up after 5 seconds, so the program prints `Await status: false`):
import (
	"fmt"
	"time"
)

// main demonstrates the timeout path of CountDownLatch. It starts
// numGreeters goroutines that each sleep for 10 seconds before greeting, but
// waits on the latch for only 5 seconds, so Await is expected to time out and
// the program prints a false status.
func main() {
	const numGreeters = 10

	latch := &CountDownLatch{}
	latch.Add(numGreeters)

	for id := 0; id < numGreeters; id++ {
		go func(id int) {
			defer latch.Done()
			time.Sleep(10 * time.Second)
			fmt.Printf("Hello from %v!\n", id)
		}(id)
	}

	status := latch.Await(5 * time.Second)
	fmt.Printf("Await status: %t \n", status)
}