diff --git a/api/git.go b/api/git.go index d74d6ad..a7e85d3 100644 --- a/api/git.go +++ b/api/git.go @@ -54,12 +54,16 @@ func (api *WebAPI) ReceiveGitWebHook(r *Request) { } payload := &GitPayload{} - if r.GetJson(payload) { - logrus.WithFields(logrus.Fields{ - "payload": payload, - }).Info("Received git WebHook") + err := json.Unmarshal(r.Ctx.Request.Body(), payload) + if err != nil { + r.Ctx.SetStatusCode(400) + return } + logrus.WithFields(logrus.Fields{ + "payload": payload, + }).Info("Received git WebHook") + if !isProductionBranch(payload) { return } @@ -72,7 +76,7 @@ func (api *WebAPI) ReceiveGitWebHook(r *Request) { version := getVersion(payload) project.Version = version - err := api.Database.UpdateProject(project) + err = api.Database.UpdateProject(project) handleErr(err, r) } diff --git a/api/helper.go b/api/helper.go index 60010f1..8771eaa 100644 --- a/api/helper.go +++ b/api/helper.go @@ -37,11 +37,3 @@ func (r *Request) Json(object interface{}, code int) { } } - -func (r *Request) GetJson(x interface{}) bool { - - err := json.Unmarshal(r.Ctx.Request.Body(), x) - handleErr(err, r) - - return err == nil -} diff --git a/api/log.go b/api/log.go index 5a41e40..025de6b 100644 --- a/api/log.go +++ b/api/log.go @@ -1,6 +1,7 @@ package api import ( + "encoding/json" "errors" "github.com/Sirupsen/logrus" "github.com/valyala/fasthttp" @@ -20,6 +21,7 @@ type LogRequest struct { Scope string `json:"scope"` Message string `json:"Message"` TimeStamp int64 `json:"timestamp"` + worker *storage.Worker } type GetLogResponse struct { @@ -52,86 +54,134 @@ func (api *WebAPI) SetupLogger() { api.Database.SetupLoggerHook() } -func parseLogEntry(r *Request) *LogRequest { +func (api *WebAPI) parseLogEntry(r *Request) (*LogRequest, error) { + + worker, err := api.validateSignature(r) + if err != nil { + return nil, err + } entry := LogRequest{} - if r.GetJson(&entry) { - if len(entry.Message) == 0 { - handleErr(errors.New("invalid message"), r) - } else if len(entry.Scope) == 0 { - handleErr(errors.New("invalid scope"), r) - } else if entry.TimeStamp <= 0 { - handleErr(errors.New("invalid timestamp"), r) - } + err = json.Unmarshal(r.Ctx.Request.Body(), &entry) + if err != nil { + return nil, err } - return &entry + if len(entry.Message) == 0 { + return nil, errors.New("invalid message") + } else if len(entry.Scope) == 0 { + return nil, errors.New("invalid scope") + } else if entry.TimeStamp <= 0 { + return nil, errors.New("invalid timestamp") + } + + entry.worker = worker + + return &entry, nil } -func LogTrace(r *Request) { +func (api *WebAPI) LogTrace(r *Request) { - entry := parseLogEntry(r) + entry, err := api.parseLogEntry(r) + if err != nil { + r.Json(GetLogResponse{ + Ok: false, + Message: "Could not parse request", + }, 400) + return + } logrus.WithFields(logrus.Fields{ - "scope": entry.Scope, + "scope": entry.Scope, + "worker": entry.worker.Id, }).WithTime(entry.Time()).Trace(entry.Message) } -func LogInfo(r *Request) { +func (api *WebAPI) LogInfo(r *Request) { - entry := parseLogEntry(r) + entry, err := api.parseLogEntry(r) + if err != nil { + r.Json(GetLogResponse{ + Ok: false, + Message: "Could not parse request", + }, 400) + return + } logrus.WithFields(logrus.Fields{ - "scope": entry.Scope, + "scope": entry.Scope, + "worker": entry.worker.Id, }).WithTime(entry.Time()).Info(entry.Message) } -func LogWarn(r *Request) { +func (api *WebAPI) LogWarn(r *Request) { - entry := parseLogEntry(r) + entry, err := api.parseLogEntry(r) + if err != nil { + r.Json(GetLogResponse{ + Ok: false, + Message: "Could not parse request", + }, 400) + return + } logrus.WithFields(logrus.Fields{ - "scope": entry.Scope, + "scope": entry.Scope, + "worker": entry.worker.Id, }).WithTime(entry.Time()).Warn(entry.Message) } -func LogError(r *Request) { +func (api *WebAPI) LogError(r *Request) { - entry := parseLogEntry(r) + entry, err := api.parseLogEntry(r) + if err != nil { + r.Json(GetLogResponse{ + Ok: false, + Message: "Could not parse request", + }, 400) + return + } logrus.WithFields(logrus.Fields{ - "scope": entry.Scope, + "scope": entry.Scope, + "worker": entry.worker.Id, }).WithTime(entry.Time()).Error(entry.Message) } func (api *WebAPI) GetLog(r *Request) { req := &GetLogRequest{} - if r.GetJson(req) { - if req.isValid() { + err := json.Unmarshal(r.Ctx.Request.Body(), req) + if err != nil { + r.Json(GetLogResponse{ + Ok: false, + Message: "Could not parse request", + }, 400) + return + } + if req.isValid() { - logs := api.Database.GetLogs(req.Since, req.Level) + logs := api.Database.GetLogs(req.Since, req.Level) - logrus.WithFields(logrus.Fields{ - "getLogRequest": req, - "logCount": len(*logs), - }).Trace("Get log request") + logrus.WithFields(logrus.Fields{ + "getLogRequest": req, + "logCount": len(*logs), + }).Trace("Get log request") - r.OkJson(GetLogResponse{ - Ok: true, - Logs: logs, - }) - } else { - logrus.WithFields(logrus.Fields{ - "getLogRequest": req, - }).Warn("Invalid log request") + r.OkJson(GetLogResponse{ + Ok: true, + Logs: logs, + }) + } else { + logrus.WithFields(logrus.Fields{ + "getLogRequest": req, + }).Warn("Invalid log request") - r.Json(GetLogResponse{ - Ok: false, - Message: "Invalid log request", - }, 400) - } + r.Json(GetLogResponse{ + Ok: false, + Message: "Invalid log request", + }, 400) } } diff --git a/api/main.go b/api/main.go index 06cbfcd..64f0787 100644 --- a/api/main.go +++ b/api/main.go @@ -43,10 +43,10 @@ func New() *WebAPI { api.router.GET("/", LogRequestMiddleware(Index)) - api.router.POST("/log/trace", LogRequestMiddleware(LogTrace)) - api.router.POST("/log/info", LogRequestMiddleware(LogInfo)) - api.router.POST("/log/warn", LogRequestMiddleware(LogWarn)) - api.router.POST("/log/error", LogRequestMiddleware(LogError)) + api.router.POST("/log/trace", LogRequestMiddleware(api.LogTrace)) + api.router.POST("/log/info", LogRequestMiddleware(api.LogInfo)) + api.router.POST("/log/warn", LogRequestMiddleware(api.LogWarn)) + api.router.POST("/log/error", LogRequestMiddleware(api.LogError)) api.router.POST("/worker/create", LogRequestMiddleware(api.WorkerCreate)) api.router.POST("/worker/update", LogRequestMiddleware(api.WorkerUpdate)) diff --git a/api/project.go b/api/project.go index 8be3d58..6a7c4a6 100644 --- a/api/project.go +++ b/api/project.go @@ -1,6 +1,7 @@ package api import ( + "encoding/json" "github.com/Sirupsen/logrus" "src/task_tracker/storage" "strconv" @@ -57,99 +58,114 @@ type GetAllProjectsStatsResponse struct { func (api *WebAPI) ProjectCreate(r *Request) { createReq := &CreateProjectRequest{} - if r.GetJson(createReq) { + err := json.Unmarshal(r.Ctx.Request.Body(), createReq) + if err != nil { + r.Json(CreateProjectResponse{ + Ok: false, + Message: "Could not parse request", + }, 400) + return + } + project := &storage.Project{ + Name: createReq.Name, + Version: createReq.Version, + CloneUrl: createReq.CloneUrl, + GitRepo: createReq.GitRepo, + Priority: createReq.Priority, + Motd: createReq.Motd, + Public: createReq.Public, + } - project := &storage.Project{ - Name: createReq.Name, - Version: createReq.Version, - CloneUrl: createReq.CloneUrl, - GitRepo: createReq.GitRepo, - Priority: createReq.Priority, - Motd: createReq.Motd, - Public: createReq.Public, - } - - if isValidProject(project) { - id, err := api.Database.SaveProject(project) - - if err != nil { - r.Json(CreateProjectResponse{ - Ok: false, - Message: err.Error(), - }, 500) - } else { - r.OkJson(CreateProjectResponse{ - Ok: true, - Id: id, - }) - logrus.WithFields(logrus.Fields{ - "project": project, - }).Debug("Created project") - } - } else { - logrus.WithFields(logrus.Fields{ - "project": project, - }).Warn("Invalid project") + if isValidProject(project) { + id, err := api.Database.SaveProject(project) + if err != nil { r.Json(CreateProjectResponse{ Ok: false, - Message: "Invalid project", - }, 400) + Message: err.Error(), + }, 500) + } else { + r.OkJson(CreateProjectResponse{ + Ok: true, + Id: id, + }) + logrus.WithFields(logrus.Fields{ + "project": project, + }).Debug("Created project") } + } else { + logrus.WithFields(logrus.Fields{ + "project": project, + }).Warn("Invalid project") + r.Json(CreateProjectResponse{ + Ok: false, + Message: "Invalid project", + }, 400) } } func (api *WebAPI) ProjectUpdate(r *Request) { id, err := strconv.ParseInt(r.Ctx.UserValue("id").(string), 10, 64) - handleErr(err, r) //todo handle invalid id + if err != nil || id <= 0 { + r.Json(CreateProjectResponse{ + Ok: false, + Message: "Invalid project id", + }, 400) + return + } updateReq := &UpdateProjectRequest{} - if r.GetJson(updateReq) { + err = json.Unmarshal(r.Ctx.Request.Body(), updateReq) + if err != nil { + r.Json(CreateProjectResponse{ + Ok: false, + Message: "Could not parse request", + }, 400) + return + } + project := &storage.Project{ + Id: id, + Name: updateReq.Name, + CloneUrl: updateReq.CloneUrl, + GitRepo: updateReq.GitRepo, + Priority: updateReq.Priority, + Motd: updateReq.Motd, + Public: updateReq.Public, + } - project := &storage.Project{ - Id: id, - Name: updateReq.Name, - CloneUrl: updateReq.CloneUrl, - GitRepo: updateReq.GitRepo, - Priority: updateReq.Priority, - Motd: updateReq.Motd, - Public: updateReq.Public, - } - - if isValidProject(project) { - err := api.Database.UpdateProject(project) - - if err != nil { - r.Json(CreateProjectResponse{ - Ok: false, - Message: err.Error(), - }, 500) - - logrus.WithError(err).WithFields(logrus.Fields{ - "project": project, - }).Warn("Error during project update") - } else { - r.OkJson(UpdateProjectResponse{ - Ok: true, - }) - - logrus.WithFields(logrus.Fields{ - "project": project, - }).Debug("Updated project") - } - - } else { - logrus.WithFields(logrus.Fields{ - "project": project, - }).Warn("Invalid project") + if isValidProject(project) { + err := api.Database.UpdateProject(project) + if err != nil { r.Json(CreateProjectResponse{ Ok: false, - Message: "Invalid project", - }, 400) + Message: err.Error(), + }, 500) + + logrus.WithError(err).WithFields(logrus.Fields{ + "project": project, + }).Warn("Error during project update") + } else { + r.OkJson(UpdateProjectResponse{ + Ok: true, + }) + + logrus.WithFields(logrus.Fields{ + "project": project, + }).Debug("Updated project") } + + } else { + logrus.WithFields(logrus.Fields{ + "project": project, + }).Warn("Invalid project") + + r.Json(CreateProjectResponse{ + Ok: false, + Message: "Invalid project", + }, 400) } } @@ -187,7 +203,13 @@ func (api *WebAPI) ProjectGet(r *Request) { func (api *WebAPI) ProjectGetStats(r *Request) { id, err := strconv.ParseInt(r.Ctx.UserValue("id").(string), 10, 64) - handleErr(err, r) + if err != nil { + r.Json(GetProjectStatsResponse{ + Ok: false, + Message: "Could not parse request", + }, 400) + return + } stats := api.Database.GetProjectStats(id) diff --git a/api/task.go b/api/task.go index d7fefab..2cf79b0 100644 --- a/api/task.go +++ b/api/task.go @@ -5,6 +5,7 @@ import ( "crypto" "crypto/hmac" "encoding/hex" + "encoding/json" "errors" "github.com/Sirupsen/logrus" "github.com/dchest/siphash" @@ -45,45 +46,50 @@ type GetTaskResponse struct { func (api *WebAPI) TaskCreate(r *Request) { - var createReq CreateTaskRequest - if r.GetJson(&createReq) { + createReq := &CreateTaskRequest{} + err := json.Unmarshal(r.Ctx.Request.Body(), createReq) + if err != nil { + r.Json(CreateProjectResponse{ + Ok: false, + Message: "Could not parse request", + }, 400) + return + } + task := &storage.Task{ + MaxRetries: createReq.MaxRetries, + Recipe: createReq.Recipe, + Priority: createReq.Priority, + AssignTime: 0, + MaxAssignTime: createReq.MaxAssignTime, + } - task := &storage.Task{ - MaxRetries: createReq.MaxRetries, - Recipe: createReq.Recipe, - Priority: createReq.Priority, - AssignTime: 0, - MaxAssignTime: createReq.MaxAssignTime, + if createReq.IsValid() && isTaskValid(task) { + + if createReq.UniqueString != "" { + //TODO: Load key from config + createReq.Hash64 = int64(siphash.Hash(1, 2, []byte(createReq.UniqueString))) } - if createReq.IsValid() && isTaskValid(task) { + err := api.Database.SaveTask(task, createReq.Project, createReq.Hash64) - if createReq.UniqueString != "" { - //TODO: Load key from config - createReq.Hash64 = int64(siphash.Hash(1, 2, []byte(createReq.UniqueString))) - } - - err := api.Database.SaveTask(task, createReq.Project, createReq.Hash64) - - if err != nil { - r.Json(CreateTaskResponse{ - Ok: false, - Message: err.Error(), //todo: hide sensitive error? - }, 500) - } else { - r.OkJson(CreateTaskResponse{ - Ok: true, - }) - } - } else { - logrus.WithFields(logrus.Fields{ - "task": task, - }).Warn("Invalid task") + if err != nil { r.Json(CreateTaskResponse{ Ok: false, - Message: "Invalid task", - }, 400) + Message: err.Error(), //todo: hide sensitive error? + }, 500) + } else { + r.OkJson(CreateTaskResponse{ + Ok: true, + }) } + } else { + logrus.WithFields(logrus.Fields{ + "task": task, + }).Warn("Invalid task") + r.Json(CreateTaskResponse{ + Ok: false, + Message: "Invalid task", + }, 400) } } @@ -215,30 +221,34 @@ func (api *WebAPI) TaskRelease(r *Request) { return } - var req ReleaseTaskRequest - if r.GetJson(&req) { - - res := api.Database.ReleaseTask(req.TaskId, worker.Id, req.Success) - - response := ReleaseTaskResponse{ - Ok: res, - } - - if !res { - response.Message = "Could not find a task with the specified Id assigned to this workerId" - - logrus.WithFields(logrus.Fields{ - "releaseTaskRequest": req, - "taskUpdated": res, - }).Warn("Release task: NOT FOUND") - } else { - - logrus.WithFields(logrus.Fields{ - "releaseTaskRequest": req, - "taskUpdated": res, - }).Trace("Release task") - } - - r.OkJson(response) + req := &ReleaseTaskRequest{} + err = json.Unmarshal(r.Ctx.Request.Body(), req) + if err != nil { + r.Json(CreateProjectResponse{ + Ok: false, + Message: "Could not parse request", + }, 400) } + res := api.Database.ReleaseTask(req.TaskId, worker.Id, req.Success) + + response := ReleaseTaskResponse{ + Ok: res, + } + + if !res { + response.Message = "Could not find a task with the specified Id assigned to this workerId" + + logrus.WithFields(logrus.Fields{ + "releaseTaskRequest": req, + "taskUpdated": res, + }).Warn("Release task: NOT FOUND") + } else { + + logrus.WithFields(logrus.Fields{ + "releaseTaskRequest": req, + "taskUpdated": res, + }).Trace("Release task") + } + + r.OkJson(response) } diff --git a/api/worker.go b/api/worker.go index 90b10c0..2df4f2f 100644 --- a/api/worker.go +++ b/api/worker.go @@ -1,6 +1,7 @@ package api import ( + "encoding/json" "github.com/Sirupsen/logrus" "math/rand" "src/task_tracker/storage" @@ -46,12 +47,12 @@ type WorkerAccessResponse struct { func (api *WebAPI) WorkerCreate(r *Request) { workerReq := &CreateWorkerRequest{} - if !r.GetJson(workerReq) { + err := json.Unmarshal(r.Ctx.Request.Body(), workerReq) + if err != nil { return } identity := getIdentity(r) - if !canCreateWorker(r, workerReq, identity) { logrus.WithFields(logrus.Fields{ @@ -123,40 +124,51 @@ func (api *WebAPI) WorkerGet(r *Request) { func (api *WebAPI) WorkerGrantAccess(r *Request) { req := &WorkerAccessRequest{} - if r.GetJson(req) { + err := json.Unmarshal(r.Ctx.Request.Body(), req) + if err != nil { + r.Json(GetTaskResponse{ + Ok: false, + Message: "Could not parse request", + }, 400) + return + } - ok := api.Database.GrantAccess(req.WorkerId, req.ProjectId) + ok := api.Database.GrantAccess(req.WorkerId, req.ProjectId) - if ok { - r.OkJson(WorkerAccessResponse{ - Ok: true, - }) - } else { - r.OkJson(WorkerAccessResponse{ - Ok: false, - Message: "Worker already has access to this project", - }) - } + if ok { + r.OkJson(WorkerAccessResponse{ + Ok: true, + }) + } else { + r.OkJson(WorkerAccessResponse{ + Ok: false, + Message: "Worker already has access to this project", + }) } } func (api *WebAPI) WorkerRemoveAccess(r *Request) { req := &WorkerAccessRequest{} - if r.GetJson(req) { + err := json.Unmarshal(r.Ctx.Request.Body(), req) + if err != nil { + r.Json(GetTaskResponse{ + Ok: false, + Message: "Could not parse request", + }, 400) + return + } + ok := api.Database.RemoveAccess(req.WorkerId, req.ProjectId) - ok := api.Database.RemoveAccess(req.WorkerId, req.ProjectId) - - if ok { - r.OkJson(WorkerAccessResponse{ - Ok: true, - }) - } else { - r.OkJson(WorkerAccessResponse{ - Ok: false, - Message: "Worker did not have access to this project", - }) - } + if ok { + r.OkJson(WorkerAccessResponse{ + Ok: true, + }) + } else { + r.OkJson(WorkerAccessResponse{ + Ok: false, + Message: "Worker did not have access to this project", + }) } } @@ -172,22 +184,27 @@ func (api *WebAPI) WorkerUpdate(r *Request) { } req := &UpdateWorkerRequest{} - if r.GetJson(req) { + err = json.Unmarshal(r.Ctx.Request.Body(), req) + if err != nil { + r.Json(GetTaskResponse{ + Ok: false, + Message: "Could not parse request", + }, 400) + return + } + worker.Alias = req.Alias - worker.Alias = req.Alias + ok := api.Database.UpdateWorker(worker) - ok := api.Database.UpdateWorker(worker) - - if ok { - r.OkJson(UpdateWorkerResponse{ - Ok: true, - }) - } else { - r.OkJson(UpdateWorkerResponse{ - Ok: false, - Message: "Could not update worker", - }) - } + if ok { + r.OkJson(UpdateWorkerResponse{ + Ok: true, + }) + } else { + r.OkJson(UpdateWorkerResponse{ + Ok: false, + Message: "Could not update worker", + }) } } diff --git a/test/api_log_test.go b/test/api_log_test.go index 29f3aec..373a85f 100644 --- a/test/api_log_test.go +++ b/test/api_log_test.go @@ -12,11 +12,12 @@ import ( func TestTraceValid(t *testing.T) { + w := genWid() r := Post("/log/trace", api.LogRequest{ Scope: "test", Message: "This is a test message", TimeStamp: time.Now().Unix(), - }, nil) + }, w) if r.StatusCode != 200 { t.Fail() @@ -24,64 +25,68 @@ func TestTraceValid(t *testing.T) { } func TestTraceInvalidScope(t *testing.T) { + w := genWid() r := Post("/log/trace", api.LogRequest{ Message: "this is a test message", TimeStamp: time.Now().Unix(), - }, nil) + }, w) - if r.StatusCode != 500 { - t.Fail() + if r.StatusCode == 200 { + t.Error() } r = Post("/log/trace", api.LogRequest{ Scope: "", Message: "this is a test message", TimeStamp: time.Now().Unix(), - }, nil) + }, w) - if r.StatusCode != 500 { - t.Fail() + if r.StatusCode == 200 { + t.Error() } - if GenericJson(r.Body)["message"] != "invalid scope" { - t.Fail() + if len(GenericJson(r.Body)["message"].(string)) <= 0 { + t.Error() } } func TestTraceInvalidMessage(t *testing.T) { + w := genWid() r := Post("/log/trace", api.LogRequest{ Scope: "test", Message: "", TimeStamp: time.Now().Unix(), - }, nil) + }, w) - if r.StatusCode != 500 { - t.Fail() + if r.StatusCode == 200 { + t.Error() } - if GenericJson(r.Body)["message"] != "invalid message" { - t.Fail() + if len(GenericJson(r.Body)["message"].(string)) <= 0 { + t.Error() } } func TestTraceInvalidTime(t *testing.T) { + w := genWid() r := Post("/log/trace", api.LogRequest{ Scope: "test", Message: "test", - }, nil) - if r.StatusCode != 500 { - t.Fail() + }, w) + if r.StatusCode == 200 { + t.Error() } - if GenericJson(r.Body)["message"] != "invalid timestamp" { - t.Fail() + if len(GenericJson(r.Body)["message"].(string)) <= 0 { + t.Error() } } func TestWarnValid(t *testing.T) { + w := genWid() r := Post("/log/warn", api.LogRequest{ Scope: "test", Message: "test", TimeStamp: time.Now().Unix(), - }, nil) + }, w) if r.StatusCode != 200 { t.Fail() } @@ -89,11 +94,12 @@ func TestWarnValid(t *testing.T) { func TestInfoValid(t *testing.T) { + w := genWid() r := Post("/log/info", api.LogRequest{ Scope: "test", Message: "test", TimeStamp: time.Now().Unix(), - }, nil) + }, w) if r.StatusCode != 200 { t.Fail() } @@ -101,11 +107,12 @@ func TestInfoValid(t *testing.T) { func TestErrorValid(t *testing.T) { + w := genWid() r := Post("/log/error", api.LogRequest{ Scope: "test", Message: "test", TimeStamp: time.Now().Unix(), - }, nil) + }, w) if r.StatusCode != 200 { t.Fail() }