如何取得上傳進度條 progress bar 相關數據及實作 Graceful Shutdown

由於專案需求,需要開發一套 CLI 工具,讓 User 可以透過 CLI 上傳大檔案來進行 Model Training,請參考上面的流程圖。首先第一步驟會先跟 API Server 驗證使用者,驗證完畢就開始上傳資料到 AWS S3 或其他 Storage 空間,除了上傳過程需要在 CLI 顯示目前進度,另外也需要將目前上傳的進度 (速度, 進度及剩餘時間) 都上傳到 API Server,最後在 Web UI 介面透過 GraphQL Subscription 讓使用者可以即時看到上傳進度數據。

而 CLI 上傳進度部分,我們選用了一套開源套件 cheggaaa/pb,相信有在寫 Go 語言都並不會陌生。而此套件雖然可以幫助在 Terminal 顯示進度條,但是有些接口是沒有提供的,像是即時速度,上傳進度及剩餘時間。本篇教大家如何實作這些數據,及分享過程會遇到相關問題。

讀取上傳進度顯示

透過 cheggaaa/pb 提供的範例如下:

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
package main

import (
    "crypto/rand"
    "io"
    "io/ioutil"
    "log"

    "github.com/cheggaaa/pb/v3"
)

func main() {

    var limit int64 = 1024 * 1024 * 10000
    // we will copy 10 Gb from /dev/rand to /dev/null
    reader := io.LimitReader(rand.Reader, limit)
    writer := ioutil.Discard

    // start new bar
    bar := pb.Full.Start64(limit)
    // create proxy reader
    barReader := bar.NewProxyReader(reader)
    // copy from proxy reader
    if _, err := io.Copy(writer, barReader); err != nil {
        log.Fatal(err)
    }
    // finish bar
    bar.Finish()
}

很清楚可以看到透過 io.Copy 方式開始上傳模擬進度。接著需要透過 goroutine 方式讀取目前進度並上傳到 API Server。

計算上傳進度及剩餘時間

使用 pb v3 版本只有開放幾個 public 資訊,像是起始進度時間,及目前上傳了多少 bits 資料,透過這兩個資料,可以即時算出剩餘時間,目前速度及進度。

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
package main

import (
    "crypto/rand"
    "fmt"
    "io"
    "io/ioutil"
    "log"
    "time"

    "github.com/cheggaaa/pb/v3"
)

func main() {
    var limit int64 = 1024 * 1024 * 10000
    // we will copy 10 Gb from /dev/rand to /dev/null
    reader := io.LimitReader(rand.Reader, limit)
    writer := ioutil.Discard

    // start new bar
    bar := pb.Full.Start64(limit)
    go func(bar *pb.ProgressBar) {
        d := time.NewTicker(2 * time.Second)
        startTime := bar.StartTime()
        // Using for loop
        for {
            // Select statement
            select {
            // Case to print current time
            case <-d.C:
                if !bar.IsStarted() {
                    continue
                }
                currentTime := time.Now()
                dur := currentTime.Sub(startTime)
                lastSpeed := float64(bar.Current()) / dur.Seconds()
                remain := float64(bar.Total() - bar.Current())
                remainDur := time.Duration(remain/lastSpeed) * time.Second
                fmt.Println("Progress:", float32(bar.Current())/float32(bar.Total())*100)
                fmt.Println("last speed:", lastSpeed/1024/1024)
                fmt.Println("remain duration:", remainDur)

                // TODO: upload progress and remain duration to api server
            }
        }
    }(bar)
    // create proxy reader
    barReader := bar.NewProxyReader(reader)
    // copy from proxy reader
    if _, err := io.Copy(writer, barReader); err != nil {
        log.Fatal(err)
    }
    // finish bar
    bar.Finish()
}

使用 time.NewTicker 固定每兩秒計算目前進度資料,並且上傳到 API Server,從上傳資料及使用的時間,可以算出目前 Speed 大概多少,當然這不是很準,原因是從上傳開始到現在時間計算 (總已上傳資料/目前花費時間)。

使用 Channel 結束上傳

做完上述這些功能,不難的發現有個問題,這個 goroutine 不會停止,還是會每兩秒去計算進度,這時候需要透過一個 Channel 通知 goroutine 結束。

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
package main

import (
    "crypto/rand"
    "fmt"
    "io"
    "io/ioutil"
    "log"
    "time"

    "github.com/cheggaaa/pb/v3"
)

func main() {
    var limit int64 = 1024 * 1024 * 10000
    // we will copy 10 Gb from /dev/rand to /dev/null
    reader := io.LimitReader(rand.Reader, limit)
    writer := ioutil.Discard

    // start new bar
    bar := pb.Full.Start64(limit)
    finishCh := make(chan struct{})
    go func(bar *pb.ProgressBar) {
        d := time.NewTicker(2 * time.Second)
        startTime := bar.StartTime()
        // Using for loop
        for {
            // Select statement
            select {
            case <-finishCh:
                d.Stop()
                log.Println("finished")
                return
            // Case to print current time
            case <-d.C:
                if !bar.IsStarted() {
                    continue
                }
                currentTime := time.Now()
                dur := currentTime.Sub(startTime)
                lastSpeed := float64(bar.Current()) / dur.Seconds()
                remain := float64(bar.Total() - bar.Current())
                remainDur := time.Duration(remain/lastSpeed) * time.Second
                fmt.Println("Progress:", float32(bar.Current())/float32(bar.Total())*100)
                fmt.Println("last speed:", lastSpeed/1024/1024)
                fmt.Println("remain suration:", remainDur)
            }
        }
    }(bar)
    // create proxy reader
    barReader := bar.NewProxyReader(reader)
    // copy from proxy reader
    if _, err := io.Copy(writer, barReader); err != nil {
        log.Fatal(err)
    }
    // finish bar
    bar.Finish()
    close(finishCh)
}

先宣告一個 finishCh := make(chan struct{}),用來通知 goroutine 跳出迴圈,大家注意看一下,最後是用的是關閉 Channel,如果是用底下方法:

1
finishCh <- strunct{}{}

這時候看看 switch case 有機率是同時到達,造成無法跳脫迴圈,而直接關閉 channel,可以確保 case <-finishCh 一直拿到空的資料,進而達成跳出迴圈的需求。

整合 Graceful Shutdown

最後來看看如何整合 Graceful Shutdown。當使用者按下 ctrl + c 需要停止上傳,並將狀態改成 stopped。底下來看看加上 Graceful Shutdown 的方式:

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
package main

import (
    "context"
    "crypto/rand"
    "fmt"
    "io"
    "io/ioutil"
    "log"
    "os"
    "os/signal"
    "syscall"
    "time"

    "github.com/cheggaaa/pb/v3"
)

func withContextFunc(ctx context.Context, f func()) context.Context {
    ctx, cancel := context.WithCancel(ctx)
    go func() {
        c := make(chan os.Signal, 1)
        signal.Notify(c, syscall.SIGINT, syscall.SIGTERM)
        defer signal.Stop(c)

        select {
        case <-ctx.Done():
        case <-c:
            f()
            cancel()
        }
    }()

    return ctx
}

func main() {

    ctx := withContextFunc(
        context.Background(),
        func() {
            // clear machine field
            log.Println("interrupt received, terminating process")
        },
    )

    var limit int64 = 1024 * 1024 * 10000
    // we will copy 10 Gb from /dev/rand to /dev/null
    reader := io.LimitReader(rand.Reader, limit)
    writer := ioutil.Discard

    // start new bar
    bar := pb.Full.Start64(limit)
    finishCh := make(chan struct{})
    go func(ctx context.Context, bar *pb.ProgressBar) {
        d := time.NewTicker(2 * time.Second)
        startTime := bar.StartTime()
        // Using for loop
        for {
            // Select statement
            select {
            case <-ctx.Done():
                d.Stop()
                log.Println("interrupt received")
                return
            case <-finishCh:
                d.Stop()
                log.Println("finished")
                return
            // Case to print current time
            case <-d.C:
                if ctx.Err() != nil {
                    return
                }
                if !bar.IsStarted() {
                    continue
                }
                currentTime := time.Now()
                dur := currentTime.Sub(startTime)
                lastSpeed := float64(bar.Current()) / dur.Seconds()
                remain := float64(bar.Total() - bar.Current())
                remainDur := time.Duration(remain/lastSpeed) * time.Second
                fmt.Println("Progress:", float32(bar.Current())/float32(bar.Total())*100)
                fmt.Println("last speed:", lastSpeed/1024/1024)
                fmt.Println("remain suration:", remainDur)
            }
        }
    }(ctx, bar)
    // create proxy reader
    barReader := bar.NewProxyReader(reader)
    // copy from proxy reader
    if _, err := io.Copy(writer, barReader); err != nil {
        log.Fatal(err)
    }
    // finish bar
    bar.Finish()
    close(finishCh)
}

透過 Go 語言的 context 跟 signal.Notify 可以偵測是否有系統訊號關閉 CLI 程式,這時候就可以做後續相對應的事情,在程式碼就需要多接受 ctx.Done() Channel,由於在 Select 多個 Channel 通道,故也是有可能同時發生,所以需要在另外的 switch case 內判斷 conetxt 的 Err 錯誤訊息,如果不等於 nil 那就是收到訊號,進而 return,必免 goroutine 在背景持續進行。大家執行上述程式後,按下 ctrl + c 可以正常看到底下訊息:

1
2
3
4
5
^C
2021/05/21 12:29:25 interrupt received, terminating process
2021/05/21 12:29:25 interrupt received
^C
signal: interrupt

可以看到要在按下一次 ctrl + c 才能結束程式,這邊的原因就是 io.Reader 還是正在上傳,並沒有停止,而系統第一次中斷訊號已經被程式用掉了,這時候解決方式就是要修改底下程式

1
2
3
4
5
    barReader := bar.NewProxyReader(reader)
    // copy from proxy reader
    if _, err := io.Copy(writer, barReader); err != nil {
        log.Fatal(err)
    }

io.Copy 支援 context 中斷

io.Copy 需要支援 context 中斷程式,但是我們只能從 reader 下手,,先看看原本 Reader 的 interface:

1
2
3
type Reader interface {
    Read(p []byte) (n int, err error)
}

現在來自己寫一份 func 來支援 context 功能:

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
type readerFunc func(p []byte) (n int, err error)

func (r readerFunc) Read(p []byte) (n int, err error) { return rf(p) }
func copy(ctx context.Context, dst io.Writer, src io.Reader) error {
    _, err := io.Copy(dst, readerFunc(func(p []byte) (int, error) {
        select {
        case <-ctx.Done():
            return 0, ctx.Err()
        default:
            return src.Read(p)
        }
    }))
    return err
}

由於 io.Reader 會把整個檔案分成多個 chunk 分別上傳,避免 Memory 直接讀取太大的檔案而爆掉,那在每個 chunk 上傳前確保沒有收到 context 中斷的訊息,這樣就可以解決無法停止上傳的行為。整體程式碼如下:

  1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
package main

import (
    "context"
    "crypto/rand"
    "fmt"
    "io"
    "io/ioutil"
    "log"
    "os"
    "os/signal"
    "syscall"
    "time"

    "github.com/cheggaaa/pb/v3"
)

type readerFunc func(p []byte) (n int, err error)

func (rf readerFunc) Read(p []byte) (n int, err error) { return rf(p) }

func copy(ctx context.Context, dst io.Writer, src io.Reader) error {
    _, err := io.Copy(dst, readerFunc(func(p []byte) (int, error) {
        select {
        case <-ctx.Done():
            return 0, ctx.Err()
        default:
            return src.Read(p)
        }
    }))
    return err
}

func withContextFunc(ctx context.Context, f func()) context.Context {
    ctx, cancel := context.WithCancel(ctx)
    go func() {
        c := make(chan os.Signal, 1)
        signal.Notify(c, syscall.SIGINT, syscall.SIGTERM)
        defer signal.Stop(c)

        select {
        case <-ctx.Done():
        case <-c:
            f()
            cancel()
        }
    }()

    return ctx
}

func main() {

    ctx := withContextFunc(
        context.Background(),
        func() {
            // clear machine field
            log.Println("interrupt received, terminating process")
        },
    )

    var limit int64 = 1024 * 1024 * 10000
    // we will copy 10 Gb from /dev/rand to /dev/null
    reader := io.LimitReader(rand.Reader, limit)
    writer := ioutil.Discard

    // start new bar
    bar := pb.Full.Start64(limit)
    finishCh := make(chan struct{})
    go func(bar *pb.ProgressBar) {
        d := time.NewTicker(2 * time.Second)
        startTime := bar.StartTime()
        // Using for loop
        for {
            // Select statement
            select {
            case <-ctx.Done():
                log.Println("stop to get current process")
                return
            case <-finishCh:
                d.Stop()
                log.Println("finished")
                return
            // Case to print current time
            case <-d.C:
                if !bar.IsStarted() {
                    continue
                }
                currentTime := time.Now()
                dur := currentTime.Sub(startTime)
                lastSpeed := float64(bar.Current()) / dur.Seconds()
                remain := float64(bar.Total() - bar.Current())
                remainDur := time.Duration(remain/lastSpeed) * time.Second
                fmt.Println("Progress:", float32(bar.Current())/float32(bar.Total())*100)
                fmt.Println("last speed:", lastSpeed/1024/1024)
                fmt.Println("remain suration:", remainDur)
            }
        }
    }(bar)
    // create proxy reader
    barReader := bar.NewProxyReader(reader)
    // copy from proxy reader
    if err := copy(ctx, writer, barReader); err != nil {
        log.Println("cancel upload data:", err.Error())
    }
    // finish bar
    bar.Finish()
    close(finishCh)
    time.Sleep(1 * time.Second)
}

See also