diff --git a/api/main.go b/api/main.go index 4cbe2bf..97c3bf3 100644 --- a/api/main.go +++ b/api/main.go @@ -109,6 +109,7 @@ func New() *WebAPI { api.router.POST("/project/reclaim_assigned_tasks/:id", LogRequestMiddleware(api.ReclaimAssignedTasks)) api.router.POST("/task/submit", LogRequestMiddleware(api.SubmitTask)) + api.router.POST("/task/bulk_submit", LogRequestMiddleware(api.BulkSubmitTask)) api.router.GET("/task/get/:project", LogRequestMiddleware(api.GetTaskFromProject)) api.router.POST("/task/release", LogRequestMiddleware(api.ReleaseTask)) diff --git a/api/models.go b/api/models.go index ea372c6..d41041c 100644 --- a/api/models.go +++ b/api/models.go @@ -222,6 +222,28 @@ func (req *SubmitTaskRequest) IsValid() bool { return true } +type BulkSubmitTaskRequest struct { + Requests []SubmitTaskRequest `json:"requests"` +} + +func (reqs BulkSubmitTaskRequest) IsValid() bool { + + if reqs.Requests == nil { + return false + } + + if len(reqs.Requests) == 0 { + return false + } + + for _, req := range reqs.Requests { + if !req.IsValid() { + return false + } + } + return true +} + type ReleaseTaskRequest struct { TaskId int64 `json:"task_id"` Result storage.TaskResult `json:"result"` diff --git a/api/rate.go b/api/rate.go index c74a9da..66fc12b 100644 --- a/api/rate.go +++ b/api/rate.go @@ -5,7 +5,7 @@ import ( "time" ) -func (api *WebAPI) ReserveSubmit(pid int64) *rate.Reservation { +func (api *WebAPI) ReserveSubmit(pid int64, count int) *rate.Reservation { limiter, ok := api.SubmitLimiters.Load(pid) if !ok { @@ -18,7 +18,7 @@ func (api *WebAPI) ReserveSubmit(pid int64) *rate.Reservation { api.SubmitLimiters.Store(pid, limiter) } - return limiter.(*rate.Limiter).ReserveN(time.Now(), 1) + return limiter.(*rate.Limiter).ReserveN(time.Now(), count) } func (api *WebAPI) ReserveAssign(pid int64) *rate.Reservation { diff --git a/api/task.go b/api/task.go index 48694c2..d481e1f 100644 --- a/api/task.go +++ b/api/task.go @@ -32,6 +32,18 @@ func (api *WebAPI) SubmitTask(r *Request) { }, 400) return } + + if !createReq.IsValid() { + logrus.WithFields(logrus.Fields{ + "req": createReq, + }).Warn("Invalid task") + r.Json(JsonResponse{ + Ok: false, + Message: "Invalid task", + }, 400) + return + } + task := &storage.Task{ MaxRetries: createReq.MaxRetries, Recipe: createReq.Recipe, @@ -41,18 +53,7 @@ func (api *WebAPI) SubmitTask(r *Request) { VerificationCount: createReq.VerificationCount, } - if !createReq.IsValid() { - logrus.WithFields(logrus.Fields{ - "task": task, - }).Warn("Invalid task") - r.Json(JsonResponse{ - Ok: false, - Message: "Invalid task", - }, 400) - return - } - - reservation := api.ReserveSubmit(createReq.Project) + reservation := api.ReserveSubmit(createReq.Project, 1) if reservation == nil { r.Json(JsonResponse{ Ok: false, @@ -91,6 +92,103 @@ func (api *WebAPI) SubmitTask(r *Request) { }) } +func (api *WebAPI) BulkSubmitTask(r *Request) { + + worker, err := api.validateSecret(r) + if err != nil { + r.Json(JsonResponse{ + Ok: false, + Message: err.Error(), + }, 401) + return + } + + createReq := &BulkSubmitTaskRequest{} + err = json.Unmarshal(r.Ctx.Request.Body(), createReq) + if err != nil || createReq.Requests == nil || len(createReq.Requests) == 0 { + r.Json(JsonResponse{ + Ok: false, + Message: "Could not parse request", + }, 400) + return + } + if !createReq.IsValid() { + logrus.WithFields(logrus.Fields{ + "req": createReq, + }).Warn("Invalid request") + r.Json(JsonResponse{ + Ok: false, + Message: "Invalid request", + }, 400) + return + } + + saveRequests := make([]storage.SaveTaskRequest, len(createReq.Requests)) + projectId := createReq.Requests[0].Project + for i, req := range createReq.Requests { + + if req.Project != projectId { + r.Json(JsonResponse{ + Ok: false, + Message: "All the tasks in a bulk submit must be of the same project", + }, 400) + return + } + + if req.UniqueString != "" { + req.Hash64 = int64(siphash.Hash(1, 2, []byte(req.UniqueString))) + } + + saveRequests[i] = storage.SaveTaskRequest{ + Task: &storage.Task{ + MaxRetries: req.MaxRetries, + Recipe: req.Recipe, + Priority: req.Priority, + AssignTime: 0, + MaxAssignTime: req.MaxAssignTime, + VerificationCount: req.VerificationCount, + }, + Project: projectId, + WorkerId: worker.Id, + Hash64: req.Hash64, + } + } + + reservation := api.ReserveSubmit(projectId, len(saveRequests)) + if reservation == nil { + r.Json(JsonResponse{ + Ok: false, + Message: "Project not found", + }, 404) + return + } + delay := reservation.DelayFrom(time.Now()).Seconds() + if delay > 0 { + r.Json(JsonResponse{ + Ok: false, + Message: "Too many requests", + RateLimitDelay: delay, + }, 429) + reservation.Cancel() + return + } + + saveErrors := api.Database.BulkSaveTask(saveRequests) + + if saveErrors == nil { + r.Json(JsonResponse{ + Ok: false, + Message: "Fatal error during bulk insert, see server logs", + }, 400) + reservation.Cancel() + return + } + + r.OkJson(JsonResponse{ + Ok: true, + }) +} + func (api *WebAPI) GetTaskFromProject(r *Request) { worker, err := api.validateSecret(r) diff --git a/storage/task.go b/storage/task.go index ff2ad7a..fdee8ad 100644 --- a/storage/task.go +++ b/storage/task.go @@ -35,6 +35,13 @@ const ( TR_SKIP TaskResult = 2 ) +type SaveTaskRequest struct { + Task *Task + Project int64 + Hash64 int64 + WorkerId int64 +} + func (database *Database) SaveTask(task *Task, project int64, hash64 int64, wid int64) error { db := database.getDB() @@ -67,6 +74,38 @@ func (database *Database) SaveTask(task *Task, project int64, hash64 int64, wid return nil } +func (database Database) BulkSaveTask(bulkSaveTaskReqs []SaveTaskRequest) []error { + + db := database.getDB() + + tx, err := db.Begin() + if err != nil { + handleErr(err) + return nil + } + + errs := make([]error, len(bulkSaveTaskReqs)) + + for i, req := range bulkSaveTaskReqs { + res, err := tx.Exec(fmt.Sprintf(` + INSERT INTO task (project, max_retries, recipe, priority, max_assign_time, hash64,verification_count) + SELECT $1,$2,$3,$4,$5,NULLIF(%d, 0),$6 FROM worker_access + WHERE role_submit AND NOT request AND worker=$7 AND project=$1`, req.Hash64), + req.Project, req.Task.MaxRetries, req.Task.Recipe, req.Task.Priority, + req.Task.MaxAssignTime, req.Task.VerificationCount, + req.WorkerId) + errs[i] = err + + rowsAffected, _ := res.RowsAffected() + if rowsAffected == 0 { + errs[i] = errors.New("unauthorized task submit") + } + } + _ = tx.Commit() + + return errs +} + func (database Database) ReleaseTask(id int64, workerId int64, result TaskResult, verification int64) bool { db := database.getDB() diff --git a/test/api_task_test.go b/test/api_task_test.go index 51f209d..9287cb6 100644 --- a/test/api_task_test.go +++ b/test/api_task_test.go @@ -920,6 +920,107 @@ func TestTaskSubmitInvalidDoesntGiveRateLimit(t *testing.T) { t.Error() } } + +func TestBulkTaskSubmitValid(t *testing.T) { + + proj := createProjectAsAdmin(api.CreateProjectRequest{ + Name: "testbulksubmit", + CloneUrl: "testbulkprojectsubmit", + GitRepo: "testbulkprojectsubmit", + }).Content.Id + + r := bulkSubmitTask(api.BulkSubmitTaskRequest{ + Requests: []api.SubmitTaskRequest{ + { + Recipe: "1234", + Project: proj, + }, + { + Recipe: "1234", + Project: proj, + }, + { + Recipe: "1234", + Project: proj, + }, + }, + }, testWorker) + + if r.Ok != true { + t.Error() + } +} + +func TestBulkTaskSubmitNotTheSameProject(t *testing.T) { + + proj := createProjectAsAdmin(api.CreateProjectRequest{ + Name: "testbulksubmitnotprj", + CloneUrl: "testbulkprojectsubmitnotprj", + GitRepo: "testbulkprojectsubmitnotprj", + }).Content.Id + + r := bulkSubmitTask(api.BulkSubmitTaskRequest{ + Requests: []api.SubmitTaskRequest{ + { + Recipe: "1234", + Project: proj, + }, + { + Recipe: "1234", + Project: 348729, + }, + }, + }, testWorker) + + if r.Ok != false { + t.Error() + } +} + +func TestBulkTaskSubmitInvalid(t *testing.T) { + + proj := createProjectAsAdmin(api.CreateProjectRequest{ + Name: "testbulksubmitinvalid", + CloneUrl: "testbulkprojectsubmitinvalid", + GitRepo: "testbulkprojectsubmitinvalid", + }).Content.Id + + r := bulkSubmitTask(api.BulkSubmitTaskRequest{ + Requests: []api.SubmitTaskRequest{ + { + Recipe: "1234", + Project: proj, + }, + { + + Recipe: "", + Project: proj, + }, + }, + }, testWorker) + + if r.Ok != false { + t.Error() + } +} + +func TestBulkTaskSubmitInvalid2(t *testing.T) { + + r := bulkSubmitTask(api.BulkSubmitTaskRequest{ + Requests: []api.SubmitTaskRequest{}, + }, testWorker) + + if r.Ok != false { + t.Error() + } +} + +func bulkSubmitTask(request api.BulkSubmitTaskRequest, worker *storage.Worker) (ar api.JsonResponse) { + r := Post("/task/bulk_submit", request, worker, nil) + UnmarshalResponse(r, &ar) + return +} + func createTask(request api.SubmitTaskRequest, worker *storage.Worker) (ar api.JsonResponse) { r := Post("/task/submit", request, worker, nil) UnmarshalResponse(r, &ar)