[Go 教學] graceful shutdown with multiple workers

golang logo

在閱讀本文章之前請先預習『用 Go 語言 buffered channel 實作 Job Queue』,本篇會針對投影片 p.26 到 p.56 做詳細的介紹,教大家如何從無到有寫一個簡單的 multiple worker,以及如何處理 graceful shutdown with workers,為什麼要處理 graceful shutdown? 原因是中途手動執行 ctrl + c 或者是部署新版程式都會遇到該如何確保 job 執行完成後才結束 main 函式。

教學影片

教學影片會之後放上,如果對於課程內容有興趣,可以參考底下課程。

關閉 Channel

通常會開一個 Channel 搭配多個 worker 才能達到平行處理,那該如何正確關閉 Channel? 底下看個例子:

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
func main() {
    ch := make(chan int, 2)
    go func() {
        ch <- 1
        ch <- 2
    }()

    for n := range ch {
        fmt.Println(n)
    }
}

執行上述程式你會發現出現了

fatal error: all goroutines are asleep - deadlock!

原因在於沒有關閉 channel,造成 main 函式一直讀取 channel,但是 channle 裡面已經不會再有值了,就造成主程式 deadlock,避免此問題很簡單

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
func main() {
    ch := make(chan int, 2)
    go func() {
        ch <- 1
        ch <- 2
        close(ch)
    }()

    for n := range ch {
        fmt.Println(n)
    }
}

除了 close(ch) 之外,另一個方式就將讀取 channel 也丟到 goroutine 內

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
func main() {
    ch := make(chan int, 2)
    go func() {
        ch <- 1
        ch <- 2
    }()

    go func() {
        for n := range ch {
            fmt.Println(n)
        }
    }()

    time.Sleep(1 * time.Second)
}

了解上述 channel 觀念後,可以來實作底下 consumer 流程

實作 consumer

底下會創建兩個 channel 來實作 consumer,其中 jobsChan 後面會有多個 worker 串接。

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
// Consumer struct
type Consumer struct {
    inputChan chan int
    jobsChan  chan int
}

func main() {
    // create the consumer
    consumer := Consumer{
        inputChan: make(chan int, 10),
        jobsChan:  make(chan int, poolSize),
    }
}

接著實現 worker 模組

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
func (c *Consumer) queue(input int) {
    select {
    case c.inputChan <- input:
        log.Println("already send input value:", input)
        return true
    default:
        return false
    }
}

func (c *Consumer) process(num, job int) {
    n := getRandomTime()
    log.Printf("Sleeping %d seconds...\n", n)
    time.Sleep(time.Duration(n) * time.Second)
    log.Println("worker:", num, " job value:", job)
}

func (c *Consumer) worker(num int) {
    log.Println("start the worker", num)
    for {
        select {
        case job := <-c.jobsChan:
            c.process(num, job)
        }
    }
}

func (c Consumer) startConsumer(ctx context.Context) {
    for {
        select {
        case job := <-c.inputChan:
            c.jobsChan <- job
        }
    }
}

const poolSize = 2

func main() {
    // create the consumer
    consumer := Consumer{
        inputChan: make(chan int, 10),
        jobsChan:  make(chan int, poolSize),
    }

    for i := 0; i < poolSize; i++ {
        go consumer.worker(i)
    }

    go consumer.startConsumer(ctx)

    consumer.queue(1)
    consumer.queue(2)
    consumer.queue(3)
    consumer.queue(4)
    consumer.queue(5)
}

由上述程式碼可以看到,都會透過 for select 方式來對 channel 進行讀寫動作。其中 queue 用來將資料丟入 input channel。

Shutdown with Sigterm Handling

接著處理當使用者按下 ctrl + c 或者是容器被移除時 (restart) 該如何接到此訊號?

這時候就需要用到 context

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
func withContextFunc(ctx context.Context, f func()) context.Context {
    ctx, cancel := context.WithCancel(ctx)
    go func() {
        c := make(chan os.Signal)
        signal.Notify(c, syscall.SIGINT, syscall.SIGTERM)
        defer signal.Stop(c)

        select {
        case <-ctx.Done():
        case <-c:
            cancel()
            f()
        }
    }()

    return ctx
}

其中 syscall.SIGINT, syscall.SIGTERM 用來偵測使用者是否按下 ctrl+c 或者是容器被移除時就會執行。所以當開發者按下 ctrl+c 就會直接觸發 cancel(),所以在最前面會使用 context.WithCancel,之後有機會再詳細介紹 context 的使用方式。

由於使用了 context,這樣就可以在每個 func 帶入客製化的 context。需要變動的有 startConsumerworker

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
func (c Consumer) startConsumer(ctx context.Context) {
    for {
        select {
        case job := <-c.inputChan:
            if ctx.Err() != nil {
                close(c.jobsChan)
                return
            }
            c.jobsChan <- job
        case <-ctx.Done():
            close(c.jobsChan)
            return
        }
    }
}

func (c *Consumer) worker(ctx context.Context, num int) {
    log.Println("start the worker", num)
    for {
        select {
        case job := <-c.jobsChan:
            if ctx.Err() != nil {
                log.Println("get next job", job, "and close the worker", num)
                return
            }
            c.process(num, job)
        case <-ctx.Done():
            log.Println("close the worker", num)
            return
        }
    }
}

這邊要注意的是,當我們按下 ctrl+c 終止 worker 時,理論上會直接到 case <-ctx.Done() 但是實際狀況是有時候會直接在繼續讀取 channel 下一個值。這時候就需要在讀取 channel 後判斷 context 是否已經取消。在 main 最後通常會放一個 channel 來判斷是否需要中斷 main 函式。

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
func main() {
    finished := make(chan bool)

    ctx := withContextFunc(context.Background(), func() {
        log.Println("cancel from ctrl+c event")
        close(finished)
    })

    <-finished
}

上述完成後,按下 ctrl + c 後,就可以直接執行 close channel,整個主程式都停止,但是這不是我們預期得結果,預期的是需要等到全部的 worker 把正在處理的 Job 完成後,才進行停止才是。

Graceful shutdown with worker

要用什麼方式才可以等到 worker 處理完畢後才結束 main 函式呢?這時候需要用到 sync.WaitGroup

1
2
3
4
5
6
7
const poolSize = 2

func main() {
    finished := make(chan bool)
    wg := &sync.WaitGroup{}
    wg.Add(poolSize)
}

其中 poolSize 代表的是 worker 數量,接著調整 worker 函式

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
func (c *Consumer) worker(ctx context.Context, num int, wg *sync.WaitGroup) {
    defer wg.Done()
    log.Println("start the worker", num)
    for {
        select {
        case job := <-c.jobsChan:
            if ctx.Err() != nil {
                log.Println("get next job", job, "and close the worker", num)
                return
            }
            c.process(num, job)
        case <-ctx.Done():
            log.Println("close the worker", num)
            return
        }
    }
}

只有在最前面加上 defer wg.Done(),接著修正 context 的 callback 函式,增加 wg.Wait() 讓 main 函式等到所有的 worker 處理完畢後才關閉 finished channel。

1
2
3
4
5
    ctx := withContextFunc(context.Background(), func() {
        log.Println("cancel from ctrl+c event")
        wg.Wait()
        close(finished)
    })

最後在主程式後面加上 <-finished 即可。

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
const poolSize = 2

func main() {
    finished := make(chan bool)
    wg := &sync.WaitGroup{}
    wg.Add(poolSize)
    // create the consumer
    consumer := Consumer{
        inputChan: make(chan int, 10),
        jobsChan:  make(chan int, poolSize),
    }

    ctx := withContextFunc(context.Background(), func() {
        log.Println("cancel from ctrl+c event")
        wg.Wait()
        close(finished)
    })

    for i := 0; i < poolSize; i++ {
        go consumer.worker(ctx, i, wg)
    }

    <-finished
    log.Println("Game over")
}

最後附上完整的程式碼讓大家測試:

  1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
package main

import (
    "context"
    "log"
    "math/rand"
    "os"
    "os/signal"
    "sync"
    "syscall"
    "time"
)

// Consumer struct
type Consumer struct {
    inputChan chan int
    jobsChan  chan int
}

func getRandomTime() int {
    rand.Seed(time.Now().UnixNano())
    return rand.Intn(10)
}

func withContextFunc(ctx context.Context, f func()) context.Context {
    ctx, cancel := context.WithCancel(ctx)
    go func() {
        c := make(chan os.Signal)
        signal.Notify(c, syscall.SIGINT, syscall.SIGTERM)
        defer signal.Stop(c)

        select {
        case <-ctx.Done():
        case <-c:
            cancel()
            f()
        }
    }()

    return ctx
}

func (c *Consumer) queue(input int) bool {
    select {
    case c.inputChan <- input:
        log.Println("already send input value:", input)
        return true
    default:
        return false
    }
}

func (c Consumer) startConsumer(ctx context.Context) {
    for {
        select {
        case job := <-c.inputChan:
            if ctx.Err() != nil {
                close(c.jobsChan)
                return
            }
            c.jobsChan <- job
        case <-ctx.Done():
            close(c.jobsChan)
            return
        }
    }
}

func (c *Consumer) process(num, job int) {
    n := getRandomTime()
    log.Printf("Sleeping %d seconds...\n", n)
    time.Sleep(time.Duration(n) * time.Second)
    log.Println("worker:", num, " job value:", job)
}

func (c *Consumer) worker(ctx context.Context, num int, wg *sync.WaitGroup) {
    defer wg.Done()
    log.Println("start the worker", num)
    for {
        select {
        case job := <-c.jobsChan:
            if ctx.Err() != nil {
                log.Println("get next job", job, "and close the worker", num)
                return
            }
            c.process(num, job)
        case <-ctx.Done():
            log.Println("close the worker", num)
            return
        }
    }
}

const poolSize = 2

func main() {
    finished := make(chan bool)
    wg := &sync.WaitGroup{}
    wg.Add(poolSize)
    // create the consumer
    consumer := Consumer{
        inputChan: make(chan int, 10),
        jobsChan:  make(chan int, poolSize),
    }

    ctx := withContextFunc(context.Background(), func() {
        log.Println("cancel from ctrl+c event")
        wg.Wait()
        close(finished)
    })

    for i := 0; i < poolSize; i++ {
        go consumer.worker(ctx, i, wg)
    }

    go consumer.startConsumer(ctx)

    go func() {
        consumer.queue(1)
        consumer.queue(2)
        consumer.queue(3)
        consumer.queue(4)
        consumer.queue(5)
    }()

    <-finished
    log.Println("Game over")
}

See also