由於專案需求,需要開發一套 CLI 工具,讓 User 可以透過 CLI 上傳大檔案來進行 Model Training,請參考上面的流程圖。首先第一步驟會先跟 API Server 驗證使用者,驗證完畢就開始上傳資料到 AWS S3 或其他 Storage 空間,除了上傳過程需要在 CLI 顯示目前進度,另外也需要將目前上傳的進度 (速度, 進度及剩餘時間) 都上傳到 API Server,最後在 Web UI 介面透過 GraphQL Subscription 讓使用者可以即時看到上傳進度數據。
而 CLI 上傳進度部分,我們選用了一套開源套件 cheggaaa/pb,相信有在寫 Go 語言都並不會陌生。而此套件雖然可以幫助在 Terminal 顯示進度條,但是有些接口是沒有提供的,像是即時速度,上傳進度及剩餘時間。本篇教大家如何實作這些數據,及分享過程會遇到相關問題。
讀取上傳進度顯示
透過 cheggaaa/pb 提供的範例如下:
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
| package main
import (
"crypto/rand"
"io"
"io/ioutil"
"log"
"github.com/cheggaaa/pb/v3"
)
func main() {
var limit int64 = 1024 * 1024 * 10000
// we will copy 10 Gb from /dev/rand to /dev/null
reader := io.LimitReader(rand.Reader, limit)
writer := ioutil.Discard
// start new bar
bar := pb.Full.Start64(limit)
// create proxy reader
barReader := bar.NewProxyReader(reader)
// copy from proxy reader
if _, err := io.Copy(writer, barReader); err != nil {
log.Fatal(err)
}
// finish bar
bar.Finish()
}
|
很清楚可以看到透過 io.Copy
方式開始上傳模擬進度。接著需要透過 goroutine
方式讀取目前進度並上傳到 API Server。
計算上傳進度及剩餘時間
使用 pb v3 版本只有開放幾個 public 資訊,像是起始進度時間,及目前上傳了多少 bits 資料,透過這兩個資料,可以即時算出剩餘時間,目前速度及進度。
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
| package main
import (
"crypto/rand"
"fmt"
"io"
"io/ioutil"
"log"
"time"
"github.com/cheggaaa/pb/v3"
)
func main() {
var limit int64 = 1024 * 1024 * 10000
// we will copy 10 Gb from /dev/rand to /dev/null
reader := io.LimitReader(rand.Reader, limit)
writer := ioutil.Discard
// start new bar
bar := pb.Full.Start64(limit)
go func(bar *pb.ProgressBar) {
d := time.NewTicker(2 * time.Second)
startTime := bar.StartTime()
// Using for loop
for {
// Select statement
select {
// Case to print current time
case <-d.C:
if !bar.IsStarted() {
continue
}
currentTime := time.Now()
dur := currentTime.Sub(startTime)
lastSpeed := float64(bar.Current()) / dur.Seconds()
remain := float64(bar.Total() - bar.Current())
remainDur := time.Duration(remain/lastSpeed) * time.Second
fmt.Println("Progress:", float32(bar.Current())/float32(bar.Total())*100)
fmt.Println("last speed:", lastSpeed/1024/1024)
fmt.Println("remain duration:", remainDur)
// TODO: upload progress and remain duration to api server
}
}
}(bar)
// create proxy reader
barReader := bar.NewProxyReader(reader)
// copy from proxy reader
if _, err := io.Copy(writer, barReader); err != nil {
log.Fatal(err)
}
// finish bar
bar.Finish()
}
|
使用 time.NewTicker
固定每兩秒計算目前進度資料,並且上傳到 API Server,從上傳資料及使用的時間,可以算出目前 Speed 大概多少,當然這不是很準,原因是從上傳開始到現在時間計算 (總已上傳資料/目前花費時間)。
使用 Channel 結束上傳
做完上述這些功能,不難的發現有個問題,這個 goroutine 不會停止,還是會每兩秒去計算進度,這時候需要透過一個 Channel 通知 goroutine 結束。
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
| package main
import (
"crypto/rand"
"fmt"
"io"
"io/ioutil"
"log"
"time"
"github.com/cheggaaa/pb/v3"
)
func main() {
var limit int64 = 1024 * 1024 * 10000
// we will copy 10 Gb from /dev/rand to /dev/null
reader := io.LimitReader(rand.Reader, limit)
writer := ioutil.Discard
// start new bar
bar := pb.Full.Start64(limit)
finishCh := make(chan struct{})
go func(bar *pb.ProgressBar) {
d := time.NewTicker(2 * time.Second)
startTime := bar.StartTime()
// Using for loop
for {
// Select statement
select {
case <-finishCh:
d.Stop()
log.Println("finished")
return
// Case to print current time
case <-d.C:
if !bar.IsStarted() {
continue
}
currentTime := time.Now()
dur := currentTime.Sub(startTime)
lastSpeed := float64(bar.Current()) / dur.Seconds()
remain := float64(bar.Total() - bar.Current())
remainDur := time.Duration(remain/lastSpeed) * time.Second
fmt.Println("Progress:", float32(bar.Current())/float32(bar.Total())*100)
fmt.Println("last speed:", lastSpeed/1024/1024)
fmt.Println("remain suration:", remainDur)
}
}
}(bar)
// create proxy reader
barReader := bar.NewProxyReader(reader)
// copy from proxy reader
if _, err := io.Copy(writer, barReader); err != nil {
log.Fatal(err)
}
// finish bar
bar.Finish()
close(finishCh)
}
|
先宣告一個 finishCh := make(chan struct{})
,用來通知 goroutine 跳出迴圈,大家注意看一下,最後是用的是關閉 Channel,如果是用底下方法:
1
| finishCh <- strunct{}{}
|
這時候看看 switch case 有機率是同時到達,造成無法跳脫迴圈,而直接關閉 channel,可以確保 case <-finishCh
一直拿到空的資料,進而達成跳出迴圈的需求。
整合 Graceful Shutdown
最後來看看如何整合 Graceful Shutdown。當使用者按下 ctrl + c
需要停止上傳,並將狀態改成 stopped
。底下來看看加上 Graceful Shutdown 的方式:
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
| package main
import (
"context"
"crypto/rand"
"fmt"
"io"
"io/ioutil"
"log"
"os"
"os/signal"
"syscall"
"time"
"github.com/cheggaaa/pb/v3"
)
func withContextFunc(ctx context.Context, f func()) context.Context {
ctx, cancel := context.WithCancel(ctx)
go func() {
c := make(chan os.Signal, 1)
signal.Notify(c, syscall.SIGINT, syscall.SIGTERM)
defer signal.Stop(c)
select {
case <-ctx.Done():
case <-c:
f()
cancel()
}
}()
return ctx
}
func main() {
ctx := withContextFunc(
context.Background(),
func() {
// clear machine field
log.Println("interrupt received, terminating process")
},
)
var limit int64 = 1024 * 1024 * 10000
// we will copy 10 Gb from /dev/rand to /dev/null
reader := io.LimitReader(rand.Reader, limit)
writer := ioutil.Discard
// start new bar
bar := pb.Full.Start64(limit)
finishCh := make(chan struct{})
go func(ctx context.Context, bar *pb.ProgressBar) {
d := time.NewTicker(2 * time.Second)
startTime := bar.StartTime()
// Using for loop
for {
// Select statement
select {
case <-ctx.Done():
d.Stop()
log.Println("interrupt received")
return
case <-finishCh:
d.Stop()
log.Println("finished")
return
// Case to print current time
case <-d.C:
if ctx.Err() != nil {
return
}
if !bar.IsStarted() {
continue
}
currentTime := time.Now()
dur := currentTime.Sub(startTime)
lastSpeed := float64(bar.Current()) / dur.Seconds()
remain := float64(bar.Total() - bar.Current())
remainDur := time.Duration(remain/lastSpeed) * time.Second
fmt.Println("Progress:", float32(bar.Current())/float32(bar.Total())*100)
fmt.Println("last speed:", lastSpeed/1024/1024)
fmt.Println("remain suration:", remainDur)
}
}
}(ctx, bar)
// create proxy reader
barReader := bar.NewProxyReader(reader)
// copy from proxy reader
if _, err := io.Copy(writer, barReader); err != nil {
log.Fatal(err)
}
// finish bar
bar.Finish()
close(finishCh)
}
|
透過 Go 語言的 context 跟 signal.Notify 可以偵測是否有系統訊號關閉 CLI 程式,這時候就可以做後續相對應的事情,在程式碼就需要多接受 ctx.Done()
Channel,由於在 Select 多個 Channel 通道,故也是有可能同時發生,所以需要在另外的 switch case 內判斷 conetxt 的 Err 錯誤訊息,如果不等於 nil 那就是收到訊號,進而 return,必免 goroutine 在背景持續進行。大家執行上述程式後,按下 ctrl + c 可以正常看到底下訊息:
1
2
3
4
5
| ^C
2021/05/21 12:29:25 interrupt received, terminating process
2021/05/21 12:29:25 interrupt received
^C
signal: interrupt
|
可以看到要在按下一次 ctrl + c 才能結束程式,這邊的原因就是 io.Reader 還是正在上傳,並沒有停止,而系統第一次中斷訊號已經被程式用掉了,這時候解決方式就是要修改底下程式
1
2
3
4
5
| barReader := bar.NewProxyReader(reader)
// copy from proxy reader
if _, err := io.Copy(writer, barReader); err != nil {
log.Fatal(err)
}
|
io.Copy 支援 context 中斷
io.Copy
需要支援 context 中斷程式,但是我們只能從 reader 下手,,先看看原本 Reader 的 interface:
1
2
3
| type Reader interface {
Read(p []byte) (n int, err error)
}
|
現在來自己寫一份 func 來支援 context 功能:
1
2
3
4
5
6
7
8
9
10
11
12
13
14
| type readerFunc func(p []byte) (n int, err error)
func (r readerFunc) Read(p []byte) (n int, err error) { return rf(p) }
func copy(ctx context.Context, dst io.Writer, src io.Reader) error {
_, err := io.Copy(dst, readerFunc(func(p []byte) (int, error) {
select {
case <-ctx.Done():
return 0, ctx.Err()
default:
return src.Read(p)
}
}))
return err
}
|
由於 io.Reader 會把整個檔案分成多個 chunk 分別上傳,避免 Memory 直接讀取太大的檔案而爆掉,那在每個 chunk 上傳前確保沒有收到 context 中斷的訊息,這樣就可以解決無法停止上傳的行為。整體程式碼如下:
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
| package main
import (
"context"
"crypto/rand"
"fmt"
"io"
"io/ioutil"
"log"
"os"
"os/signal"
"syscall"
"time"
"github.com/cheggaaa/pb/v3"
)
type readerFunc func(p []byte) (n int, err error)
func (rf readerFunc) Read(p []byte) (n int, err error) { return rf(p) }
func copy(ctx context.Context, dst io.Writer, src io.Reader) error {
_, err := io.Copy(dst, readerFunc(func(p []byte) (int, error) {
select {
case <-ctx.Done():
return 0, ctx.Err()
default:
return src.Read(p)
}
}))
return err
}
func withContextFunc(ctx context.Context, f func()) context.Context {
ctx, cancel := context.WithCancel(ctx)
go func() {
c := make(chan os.Signal, 1)
signal.Notify(c, syscall.SIGINT, syscall.SIGTERM)
defer signal.Stop(c)
select {
case <-ctx.Done():
case <-c:
f()
cancel()
}
}()
return ctx
}
func main() {
ctx := withContextFunc(
context.Background(),
func() {
// clear machine field
log.Println("interrupt received, terminating process")
},
)
var limit int64 = 1024 * 1024 * 10000
// we will copy 10 Gb from /dev/rand to /dev/null
reader := io.LimitReader(rand.Reader, limit)
writer := ioutil.Discard
// start new bar
bar := pb.Full.Start64(limit)
finishCh := make(chan struct{})
go func(bar *pb.ProgressBar) {
d := time.NewTicker(2 * time.Second)
startTime := bar.StartTime()
// Using for loop
for {
// Select statement
select {
case <-ctx.Done():
log.Println("stop to get current process")
return
case <-finishCh:
d.Stop()
log.Println("finished")
return
// Case to print current time
case <-d.C:
if !bar.IsStarted() {
continue
}
currentTime := time.Now()
dur := currentTime.Sub(startTime)
lastSpeed := float64(bar.Current()) / dur.Seconds()
remain := float64(bar.Total() - bar.Current())
remainDur := time.Duration(remain/lastSpeed) * time.Second
fmt.Println("Progress:", float32(bar.Current())/float32(bar.Total())*100)
fmt.Println("last speed:", lastSpeed/1024/1024)
fmt.Println("remain suration:", remainDur)
}
}
}(bar)
// create proxy reader
barReader := bar.NewProxyReader(reader)
// copy from proxy reader
if err := copy(ctx, writer, barReader); err != nil {
log.Println("cancel upload data:", err.Error())
}
// finish bar
bar.Finish()
close(finishCh)
time.Sleep(1 * time.Second)
}
|
See also