最近看到 Go 語言 一段程式碼,認為有很大的優化空間,也將過程跟想法分享給大家。也許每個人優化的方向不同,各位讀者可以把程式碼 整個看完後,先停住,不要繼續往下看,想看看是否有優化的空間。此程式碼本身沒有任何問題,執行過程不會出現任何錯誤。
先說明底下範例在做什麼,相信大家都有聽過在 Go 語言內要實現 Worker Pools 機制相當簡單,看到 ExecuteAll
函式就是讓開發者可以自訂同時間開多少個 Goroutine 來平行執行工作,第二個參數可以自訂義工作內容是什麼。
影片教學 VIDEO
其他線上課程請參考如下
程式碼 線上測試看看
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
package main
import (
"context"
"errors"
"fmt"
"runtime"
"sync"
)
type TaskFunc func (ctx context.Context) error
func ExecuteAll (numCPU int , tasks ...TaskFunc) error {
var err error
ctx, cancel := context.WithCancel (context.Background ())
defer cancel ()
wg := sync.WaitGroup{}
wg.Add (len (tasks))
if numCPU == 0 {
numCPU = runtime.NumCPU ()
}
fmt.Println ("numCPU:" , numCPU)
queue := make (chan TaskFunc, numCPU)
// Spawn the executer
for i := 0 ; i < numCPU; i++ {
go func () {
for task := range queue {
fmt.Println ("get task" )
if err == nil {
taskErr := task (ctx)
if taskErr != nil {
err = taskErr
cancel ()
}
}
wg.Done ()
}
}()
}
// Add tasks to queue
for _, task := range tasks {
queue <- task
}
close (queue)
// wait for all task done
wg.Wait ()
return err
}
func main () {
tasks := make ([]TaskFunc, 0 , 100 )
for i := 0 ; i < 100 ; i++ {
func (val int ) {
tasks = append (tasks, func (ctx context.Context) error {
fmt.Println (val)
if val == 51 {
return errors.New ("missing" )
}
return nil
})
}(i)
}
err := ExecuteAll (0 , tasks...)
if err == nil {
fmt.Println ("missing error" )
}
}
三大優化方向 問題一 大家看完上述程式碼,是否心裡已經有想法該怎麼優化,或者是有看出什麼問題?首先我看到第一個 疑問
1
2
wg := sync.WaitGroup{}
wg.Add (len (tasks))
為什麼是從 Task 數量來放進去 WatiGroup,理論上我們是要控制開多少個 Goroutine,而不是將 Task 數量全部執行完畢,才結束程式。
問題二 第二個 問題就是這段代碼會 blocking 在最下面的讀取 Task 塞入 Queue 變數上,大家看到底下代碼,宣告的是根據想要開多少 Goroutine 的 buffer 大小 Channel。舉例假設使用 4 core,然後 100 個 Task,每個 Task 執行需要 10 秒,此時塞 4 個 Task 進去 Queue 後,會被順利讀取出來 4 個 task,接著 Queue 又被塞滿 4 個 task 後,就無法再繼續將新的 Task 放入,故程式就會被 blocking。
1
2
3
4
5
6
7
8
9
10
11
12
13
queue := make (chan TaskFunc, numCPU)
//
// 中間省略一堆代碼
//
//
// Add tasks to queue
for _, task := range tasks {
queue <- task
}
close (queue)
// wait for all task done
wg.Wait ()
問題三 先看看讀取 Task 的 goroutine for 迴圈,由於只要有一個 Task 執行錯誤,就會將錯誤設定給全域變數 err
,但是可以看到如果有 1 萬的 Task,此迴圈後續還是將每個 Task 都讀取出來,完全沒有使用到 Context 重要的 Channel 功能。更多 Context 用法可以參考這篇『用 10 分鐘了解 Go 語言 context package 使用場景及介紹 』
1
2
3
4
5
6
7
8
9
10
11
12
13
go func () {
for task := range queue {
fmt.Println ("get task" )
if err == nil {
taskErr := task (ctx)
if taskErr != nil {
err = taskErr
cancel ()
}
}
wg.Done ()
}
}()
重構流程 改寫 sync.WaitGroup 使用方式 根據上面提到的三個問題,底下來一一解決,首先這段程式碼目的是開多個平行化處理的 Goroutine,故結束前必須要等待全部 Goroutine 執行完成才讓主程式繼續往下走,所以使用 sync.WaitGroup
可以改成根據目前設定多少平行處理來決定
1
2
3
4
5
6
if numCPU == 0 {
numCPU = runtime.NumCPU ()
}
wg := sync.WaitGroup{}
wg.Add (numCPU)
改寫 buffer channel 大小 上面有提到 Channel 大小原本使用要同步處理多少工作當作 Buffer 大小,但是只要 Task 數量大於 Buffer 大小,就會出現 blocking,故這邊可以改成底下
1
2
3
4
5
6
7
queue := make (chan TaskFunc, len (tasks))
// Add tasks to queue
for _, task := range tasks {
queue <- task
}
close (queue)
將 Buffer 大小改成跟 Task 數量一致,藉此透過 for 迴圈先將 Task 塞到 Channel 內,並關閉 Channel 即可。
讀取 Task 流程 此函式目的就是平行跑多個 Task,遇到任何錯誤,就中斷流程,並返回錯誤訊息,故需要透過 Context Cancel 特性來改寫原本流程
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
for i := 0 ; i < numCPU; i++ {
go func () {
defer wg.Done ()
for {
select {
case task, ok := <-queue:
if ctx.Err () != nil || !ok {
return
}
fmt.Println ("get task" )
if e := task (ctx); e != nil {
err = e
cancel ()
}
case <-ctx.Done ():
return
}
}
}()
}
當 Task 出現錯誤時,會將錯誤訊息放到全域變數 err 內,並且執行 cancel()
,此時 for 在讀取下一個 Job 時,就可以透過 <-ctx.Done()
或 ctx.Err()
方式來終止程式執行,這樣才不會多跑了很多次迴圈
心得 Worker Pool 網路上寫法千奇百種,優化的方式每個人想的也是不一樣,透過這樣的練習可以加深自己對於 Go Channel 特性。原本的程式碼都可以正常執行沒問題,只是看到覺得有幾個地方可以優化,故寫在這邊紀錄重構想法,可以讓剛入門 Go 語言的朋友們參考。
See also