From 6048cfbebc37b2c892590800c92bf27294adfc04 Mon Sep 17 00:00:00 2001 From: simon987 Date: Sat, 9 Mar 2019 09:20:51 -0500 Subject: [PATCH] minimum viable (excluding auth) --- .gitignore | 2 + api/api.go | 136 +++++++++++++++++++++++++++++++++ api/auth.go | 62 +++++++++++++++ api/models.go | 71 +++++++++++++++++ api/slot.go | 180 ++++++++++++++++++++++++++++++++++++++++++++ main.go | 17 +++++ test/auth_test.go | 24 ++++++ test/common.go | 63 ++++++++++++++++ test/main_test.go | 25 ++++++ test/slot_test.go | 53 +++++++++++++ test/upload_test.go | 179 +++++++++++++++++++++++++++++++++++++++++++ 11 files changed, 812 insertions(+) create mode 100644 api/api.go create mode 100644 api/auth.go create mode 100644 api/models.go create mode 100644 api/slot.go create mode 100644 main.go create mode 100644 test/auth_test.go create mode 100644 test/common.go create mode 100644 test/main_test.go create mode 100644 test/slot_test.go create mode 100644 test/upload_test.go diff --git a/.gitignore b/.gitignore index f1c181e..ea5c68e 100644 --- a/.gitignore +++ b/.gitignore @@ -10,3 +10,5 @@ # Output of the go coverage tool, specifically when used with LiteIDE *.out + +.idea/ diff --git a/api/api.go b/api/api.go new file mode 100644 index 0000000..12a12af --- /dev/null +++ b/api/api.go @@ -0,0 +1,136 @@ +package api + +import ( + "encoding/json" + "github.com/buaazp/fasthttprouter" + "github.com/fasthttp/websocket" + "github.com/jinzhu/gorm" + _ "github.com/jinzhu/gorm/dialects/postgres" + "github.com/sirupsen/logrus" + "github.com/valyala/fasthttp" + "os" + "path/filepath" +) + +var WorkDir, _ = filepath.Abs("./data/") + +type Info struct { + Name string `json:"name"` + Version string `json:"version"` +} + +var info = Info{ + Name: "ws_bucket", + Version: "1.0", +} + +var motd = WebsocketMotd{ + Info: info, + Motd: "Hello, world", +} + +type WebApi struct { + server fasthttp.Server + db *gorm.DB + MotdMessage *websocket.PreparedMessage +} + +func Index(ctx *fasthttp.RequestCtx) { + Json(info, ctx) +} + +func Json(object interface{}, ctx *fasthttp.RequestCtx) { + + resp, err := json.Marshal(object) + if err != nil { + panic(err) + } + + ctx.Response.Header.Set("Content-Type", "application/json") + _, err = ctx.Write(resp) + if err != nil { + panic(err) + } +} + +func LogRequestMiddleware(h fasthttp.RequestHandler) fasthttp.RequestHandler { + return fasthttp.RequestHandler(func(ctx *fasthttp.RequestCtx) { + + logrus.WithFields(logrus.Fields{ + "path": string(ctx.Path()), + "header": ctx.Request.Header.String(), + }).Trace(string(ctx.Method())) + + h(ctx) + }) +} + +func New(db *gorm.DB) *WebApi { + + api := &WebApi{} + + logrus.SetLevel(getLogLevel()) + + router := fasthttprouter.New() + router.GET("/", LogRequestMiddleware(Index)) + + router.POST("/client", LogRequestMiddleware(api.CreateClient)) + + router.POST("/slot", LogRequestMiddleware(api.AllocateUploadSlot)) + router.GET("/slot", LogRequestMiddleware(api.ReadUploadSlot)) + router.GET("/upload", LogRequestMiddleware(api.Upload)) + + api.server = fasthttp.Server{ + Handler: router.Handler, + Name: "ws_bucket", + } + + api.db = db + db.AutoMigrate(&Client{}) + db.AutoMigrate(&UploadSlot{}) + + api.setupMotd() + + return api +} + +func (api *WebApi) setupMotd() { + var data []byte + data, _ = json.Marshal(motd) + motdMsg, _ := websocket.NewPreparedMessage(websocket.TextMessage, data) + api.MotdMessage = motdMsg +} + +func (api *WebApi) Run() { + address := GetServerAddress() + + logrus.WithFields(logrus.Fields{ + "addr": address, + }).Info("Starting web server") + + err := api.server.ListenAndServe(address) + if err != nil { + logrus.Fatalf("Error in ListenAndServe: %s", err) + } +} + +func GetServerAddress() string { + serverAddress := os.Getenv("WS_BUCKET_ADDR") + if serverAddress == "" { + serverAddress = "0.0.0.0:3020" + } + return serverAddress +} + +func getLogLevel() logrus.Level { + levelStr := os.Getenv("WS_BUCKET_LOGLEVEL") + if levelStr == "" { + return logrus.TraceLevel + } else { + level, err := logrus.ParseLevel(levelStr) + if err != nil { + panic(err) + } + return level + } +} diff --git a/api/auth.go b/api/auth.go new file mode 100644 index 0000000..36d013a --- /dev/null +++ b/api/auth.go @@ -0,0 +1,62 @@ +package api + +import ( + "encoding/json" + "github.com/sirupsen/logrus" + "github.com/valyala/fasthttp" + "math/rand" +) + +func (api *WebApi) CreateClient(ctx *fasthttp.RequestCtx) { + + //TODO: auth + + req := &CreateClientRequest{} + err := json.Unmarshal(ctx.Request.Body(), req) + if err != nil { + ctx.Response.SetStatusCode(400) + Json(CreateClientResponse{ + Ok: false, + }, ctx) + return + } + + if !req.IsValid() { + ctx.Response.SetStatusCode(400) + Json(CreateClientResponse{ + Ok: false, + }, ctx) + return + } + + client := api.createClient(req) + + Json(CreateClientResponse{ + Ok: true, + Secret: client.Secret, + }, ctx) +} + +func (api *WebApi) createClient(req *CreateClientRequest) *Client { + + client := &Client{ + Alias: req.Alias, + Secret: genSecret(), + } + + api.db.Create(client) + + logrus.WithFields(logrus.Fields{ + "client": client, + }).Info("Created client") + + return client +} + +func genSecret() string { + bytes := make([]byte, 32) + for i := 0; i < 32; i++ { + bytes[i] = byte(48 + rand.Intn(122-48)) + } + return string(bytes) +} diff --git a/api/models.go b/api/models.go new file mode 100644 index 0000000..8704286 --- /dev/null +++ b/api/models.go @@ -0,0 +1,71 @@ +package api + +import ( + "path/filepath" + "strings" +) + +type GenericResponse struct { + Ok bool `json:"ok"` +} + +type CreateClientRequest struct { + Alias string `json:"alias"` +} + +func (req *CreateClientRequest) IsValid() bool { + return len(req.Alias) > 3 +} + +type CreateClientResponse struct { + Ok bool `json:"ok"` + Secret string `json:"secret,omitempty"` +} + +type Client struct { + ID int64 + Alias string `json:"alias"` + Secret string `json:"secret"` +} + +type AllocateUploadSlotRequest struct { + Token string `json:"token"` + MaxSize int64 `json:"max_size"` + FileName string `json:"file_name"` +} + +func (req *AllocateUploadSlotRequest) IsValid() bool { + if len(req.Token) < 3 { + return false + } + + if len(req.FileName) <= 0 { + return false + } + + path := filepath.Join(WorkDir, req.FileName) + pathAbs, err := filepath.Abs(path) + if err != nil { + return false + } + if !strings.HasPrefix(pathAbs, WorkDir) { + return false + } + + if req.MaxSize < 0 { + return false + } + + return true +} + +type UploadSlot struct { + MaxSize int64 `json:"max_size"` + Token string `gorm:"primary_key",json:"token"` + FileName string `json:"file_name"` +} + +type WebsocketMotd struct { + Info Info `json:"info"` + Motd string `json:"motd"` +} diff --git a/api/slot.go b/api/slot.go new file mode 100644 index 0000000..2e8153c --- /dev/null +++ b/api/slot.go @@ -0,0 +1,180 @@ +package api + +import ( + "encoding/json" + "github.com/fasthttp/websocket" + "github.com/sirupsen/logrus" + "github.com/valyala/fasthttp" + "io" + "os" + "path/filepath" + "sync" +) + +const WsBufferSize = 4096 + +var Mutexes sync.Map +var upgrader = websocket.FastHTTPUpgrader{ + ReadBufferSize: WsBufferSize, + WriteBufferSize: WsBufferSize, + EnableCompression: true, +} + +func (api *WebApi) AllocateUploadSlot(ctx *fasthttp.RequestCtx) { + + req := &AllocateUploadSlotRequest{} + err := json.Unmarshal(ctx.Request.Body(), req) + if err != nil { + ctx.Response.SetStatusCode(400) + Json(GenericResponse{ + Ok: false, + }, ctx) + return + } + + if !req.IsValid() { + ctx.Response.SetStatusCode(400) + Json(CreateClientResponse{ + Ok: false, + }, ctx) + return + } + + api.allocateUploadSlot(req) + + Json(CreateClientResponse{ + Ok: true, + }, ctx) +} + +func (api *WebApi) Upload(ctx *fasthttp.RequestCtx) { + + token := string(ctx.Request.Header.Peek("X-Upload-Token")) + slot := UploadSlot{} + err := api.db.Where("token=?", token).First(&slot).Error + if err != nil { + ctx.Response.Header.SetStatusCode(400) + logrus.WithFields(logrus.Fields{ + "token": token, + }).Warning("Upload slot not found") + return + } + + logrus.WithFields(logrus.Fields{ + "slot": slot, + }).Info("Upgrading connection") + + err = upgrader.Upgrade(ctx, func(ws *websocket.Conn) { + defer ws.Close() + + err := ws.WritePreparedMessage(api.MotdMessage) + if err != nil { + panic(err) + } + + mt, reader, err := ws.NextReader() + if err != nil { + panic(err) + } + if mt != websocket.BinaryMessage { + return + } + + mu, _ := Mutexes.LoadOrStore(slot.Token, &sync.RWMutex{}) + mu.(*sync.RWMutex).Lock() + path := filepath.Join(WorkDir, slot.FileName) + fp, err := os.OpenFile(path, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, 0600) + if err != nil { + panic(err) + } + + buf := make([]byte, WsBufferSize) + totalRead := int64(0) + for totalRead < slot.MaxSize { + read, err := reader.Read(buf) + + var toWrite int + if totalRead+int64(read) > slot.MaxSize { + toWrite = int(slot.MaxSize - totalRead) + } else { + toWrite = read + } + + _, _ = fp.Write(buf[:toWrite]) + if err == io.EOF { + break + } + totalRead += int64(read) + } + + logrus.WithFields(logrus.Fields{ + "totalRead": totalRead, + }).Info("Finished reading") + err = fp.Close() + if err != nil { + panic(err) + } + mu.(*sync.RWMutex).Unlock() + }) + if err != nil { + panic(err) + } +} + +func (api *WebApi) ReadUploadSlot(ctx *fasthttp.RequestCtx) { + + tokenStr := string(ctx.Request.Header.Peek("X-Upload-Token")) + + slot := UploadSlot{} + err := api.db.Where("token=?", tokenStr).First(&slot).Error + + if err != nil { + ctx.Response.Header.SetStatusCode(404) + logrus.WithFields(logrus.Fields{ + "token": tokenStr, + }).Warning("Upload slot not found") + return + } + + logrus.WithFields(logrus.Fields{ + "slot": slot, + }).Info("Reading") + + path := filepath.Join(WorkDir, slot.FileName) + + mu, _ := Mutexes.LoadOrStore(slot.Token, &sync.RWMutex{}) + mu.(*sync.RWMutex).RLock() + fp, err := os.OpenFile(path, os.O_RDONLY, 0600) + if err != nil { + panic(err) + } + + buf := make([]byte, WsBufferSize) + response := ctx.Response.BodyWriter() + for { + read, err := fp.Read(buf) + _, _ = response.Write(buf[:read]) + if err == io.EOF { + break + } + if err != nil { + panic(err) + } + } + mu.(*sync.RWMutex).RUnlock() +} + +func (api *WebApi) allocateUploadSlot(req *AllocateUploadSlotRequest) { + + slot := &UploadSlot{ + MaxSize: req.MaxSize, + FileName: req.FileName, + Token: req.Token, + } + + logrus.WithFields(logrus.Fields{ + "slot": slot, + }).Info("Allocated new upload slot") + + api.db.Create(slot) +} diff --git a/main.go b/main.go new file mode 100644 index 0000000..7d5d74b --- /dev/null +++ b/main.go @@ -0,0 +1,17 @@ +package main + +import ( + "github.com/jinzhu/gorm" + "github.com/simon987/ws_bucket/api" +) + +func main() { + + db, err := gorm.Open("postgres", "host=localhost user=ws_bucket dbname=ws_bucket password=ws_bucket sslmode=disable") + if err != nil { + panic(err) + } + + a := api.New(db) + a.Run() +} diff --git a/test/auth_test.go b/test/auth_test.go new file mode 100644 index 0000000..ec3259e --- /dev/null +++ b/test/auth_test.go @@ -0,0 +1,24 @@ +package test + +import ( + "github.com/simon987/ws_bucket/api" + "testing" +) + +func TestCreateClient(t *testing.T) { + + r := createClient(api.CreateClientRequest{ + Alias: "testcreateclient", + }) + + if r.Ok != true { + t.Error() + } +} + +func createClient(request api.CreateClientRequest) (ar *api.CreateClientResponse) { + + resp := Post("/client", request) + UnmarshalResponse(resp, &ar) + return +} diff --git a/test/common.go b/test/common.go new file mode 100644 index 0000000..50fe86b --- /dev/null +++ b/test/common.go @@ -0,0 +1,63 @@ +package test + +import ( + "bytes" + "encoding/json" + "fmt" + "github.com/simon987/ws_bucket/api" + "io/ioutil" + "net/http" +) + +func Post(path string, x interface{}) *http.Response { + + s := http.Client{} + + body, err := json.Marshal(x) + buf := bytes.NewBuffer(body) + + req, err := http.NewRequest("POST", "http://"+api.GetServerAddress()+path, buf) + handleErr(err) + + //ts := time.Now().Format(time.RFC1123) + // + //mac := hmac.New(crypto.SHA256.New, worker.Secret) + //mac.Write(body) + //mac.Write([]byte(ts)) + //sig := hex.EncodeToString(mac.Sum(nil)) + // + //req.Header.Add("X-Worker-Id", strconv.FormatInt(worker.Id, 10)) + //req.Header.Add("X-Signature", sig) + //req.Header.Add("Timestamp", ts) + + r, err := s.Do(req) + handleErr(err) + + return r +} + +func Get(path string, token string) *http.Response { + + s := http.Client{} + + req, err := http.NewRequest("GET", "http://"+api.GetServerAddress()+path, nil) + handleErr(err) + + req.Header.Set("X-Upload-Token", token) + + r, err := s.Do(req) + return r +} + +func UnmarshalResponse(r *http.Response, result interface{}) { + data, err := ioutil.ReadAll(r.Body) + fmt.Println(string(data)) + err = json.Unmarshal(data, result) + handleErr(err) +} + +func handleErr(err error) { + if err != nil { + panic(err) + } +} diff --git a/test/main_test.go b/test/main_test.go new file mode 100644 index 0000000..1dd0687 --- /dev/null +++ b/test/main_test.go @@ -0,0 +1,25 @@ +package test + +import ( + "github.com/jinzhu/gorm" + _ "github.com/jinzhu/gorm/dialects/sqlite" + "github.com/simon987/ws_bucket/api" + "testing" + "time" +) + +func TestMain(m *testing.M) { + + //db, err := gorm.Open("postgres", "host=localhost user=ws_bucket dbname=ws_bucket password=ws_bucket sslmode=disable") + db, err := gorm.Open("sqlite3", ":memory:") + if err != nil { + panic(err) + } + + a := api.New(db) + go a.Run() + + time.Sleep(time.Millisecond * 100) + + m.Run() +} diff --git a/test/slot_test.go b/test/slot_test.go new file mode 100644 index 0000000..f7583e0 --- /dev/null +++ b/test/slot_test.go @@ -0,0 +1,53 @@ +package test + +import ( + "github.com/simon987/ws_bucket/api" + "testing" +) + +func TestAllocateUploadInvalidMaxSize(t *testing.T) { + + if allocateUploadSlot(api.AllocateUploadSlotRequest{ + FileName: "valid", + Token: "valid", + MaxSize: -1, + }).Ok != false { + t.Error() + } +} + +func TestAllocateUploadSlotInvalidToken(t *testing.T) { + + if allocateUploadSlot(api.AllocateUploadSlotRequest{ + FileName: "valid", + Token: "", + MaxSize: 100, + }).Ok != false { + t.Error() + } +} + +func TestAllocateUploadSlotUnsafePath(t *testing.T) { + + if allocateUploadSlot(api.AllocateUploadSlotRequest{ + FileName: "../test.png", + Token: "valid", + MaxSize: 100, + }).Ok != false { + t.Error() + } + + if allocateUploadSlot(api.AllocateUploadSlotRequest{ + FileName: "test/../../test.png", + Token: "valid", + MaxSize: 100, + }).Ok != false { + t.Error() + } +} + +func allocateUploadSlot(request api.AllocateUploadSlotRequest) (ar *api.GenericResponse) { + resp := Post("/slot", request) + UnmarshalResponse(resp, &ar) + return +} diff --git a/test/upload_test.go b/test/upload_test.go new file mode 100644 index 0000000..342b396 --- /dev/null +++ b/test/upload_test.go @@ -0,0 +1,179 @@ +package test + +import ( + "bytes" + "fmt" + "github.com/fasthttp/websocket" + "github.com/google/uuid" + "github.com/simon987/ws_bucket/api" + "io/ioutil" + "math" + "net/http" + "net/url" + "testing" +) + +func TestWebsocketReturnsMotd(t *testing.T) { + + id := uuid.New() + allocateUploadSlot(api.AllocateUploadSlotRequest{ + FileName: "testmotd", + MaxSize: 0, + Token: id.String(), + }) + + c := ws(id.String()) + motd := &api.WebsocketMotd{} + err := c.ReadJSON(&motd) + handleErr(err) + + if len(motd.Motd) <= 0 { + t.Error() + } + if len(motd.Info.Version) <= 0 { + t.Error() + } +} + +func TestWebSocketUploadSmallFile(t *testing.T) { + + id := uuid.New() + + allocateUploadSlot(api.AllocateUploadSlotRequest{ + FileName: "testfile", + Token: id.String(), + MaxSize: math.MaxInt64, + }) + + c := ws(id.String()) + _, _, err := c.ReadMessage() + handleErr(err) + + err = c.WriteMessage(websocket.BinaryMessage, []byte("testuploadsmallfile")) + handleErr(err) + + err = c.Close() + handleErr(err) + + resp := readUploadSlot(id.String()) + + if bytes.Compare(resp, []byte("testuploadsmallfile")) != 0 { + t.Error() + } +} + +func TestWebSocketUploadOverwritesFile(t *testing.T) { + + id := uuid.New() + + allocateUploadSlot(api.AllocateUploadSlotRequest{ + FileName: "testuploadoverwrites", + Token: id.String(), + MaxSize: math.MaxInt64, + }) + + c := ws(id.String()) + _, _, err := c.ReadMessage() + handleErr(err) + + err = c.WriteMessage(websocket.BinaryMessage, []byte("testuploadsmallfile")) + handleErr(err) + + err = c.Close() + handleErr(err) + + c1 := ws(id.String()) + _, _, err = c1.ReadMessage() + handleErr(err) + + err = c1.WriteMessage(websocket.BinaryMessage, []byte("newvalue")) + handleErr(err) + + err = c1.Close() + handleErr(err) + + resp := readUploadSlot(id.String()) + + if bytes.Compare(resp, []byte("newvalue")) != 0 { + t.Error() + } +} + +func TestWebSocketUploadLargeFile(t *testing.T) { + + id := uuid.New() + + allocateUploadSlot(api.AllocateUploadSlotRequest{ + FileName: "testlargefile", + Token: id.String(), + MaxSize: math.MaxInt64, + }) + + c := ws(id.String()) + _, _, err := c.ReadMessage() + handleErr(err) + + chunk := make([]byte, 100000) + _ = copy(chunk, "test") + _ = c.WriteMessage(websocket.BinaryMessage, chunk) + + err = c.Close() + handleErr(err) + + resp := readUploadSlot(id.String()) + + if bytes.Compare(resp, chunk) != 0 { + t.Error() + } +} + +func TestWebSocketUploadMaxSize(t *testing.T) { + + id := uuid.New() + + allocateUploadSlot(api.AllocateUploadSlotRequest{ + FileName: "testmaxsize", + Token: id.String(), + MaxSize: 10, + }) + + c := ws(id.String()) + _, _, err := c.ReadMessage() + handleErr(err) + + chunk := make([]byte, 100000) + _ = copy(chunk, "test") + _ = c.WriteMessage(websocket.BinaryMessage, chunk) + + err = c.Close() + handleErr(err) + + resp := readUploadSlot(id.String()) + + if len(resp) != 10 { + t.Error() + } +} + +func readUploadSlot(token string) []byte { + + r := Get("/slot", token) + + data, err := ioutil.ReadAll(r.Body) + handleErr(err) + + return data +} + +func ws(slot string) *websocket.Conn { + + u := url.URL{Scheme: "ws", Host: "localhost:3021", Path: "/upload"} + fmt.Printf("Connecting to %s", u.String()) + + header := http.Header{} + header.Add("X-Upload-Token", slot) + c, _, err := websocket.DefaultDialer.Dial(u.String(), header) + handleErr(err) + + return c +}