diff --git a/config.go b/config.go index b8c15ef..495a45d 100644 --- a/config.go +++ b/config.go @@ -4,6 +4,7 @@ import ( "bufio" "fmt" "github.com/sirupsen/logrus" + "github.com/spf13/pflag" "github.com/spf13/viper" "io" "os" @@ -26,6 +27,8 @@ var config struct { JobBufferSize int } +var onlineMode bool + const ( ConfServerUrl = "server.url" ConfToken = "server.token" @@ -43,6 +46,7 @@ const ( ConfDialTimeout = "crawl.dial_timeout" ConfTimeout = "crawl.timeout" ConfJobBufferSize = "crawl.job_buffer" + ConfResume = "crawl.resume" ConfCrawlStats = "output.crawl_stats" ConfAllocStats = "output.resource_stats" @@ -54,8 +58,58 @@ const ( func prepareConfig() { pf := rootCmd.PersistentFlags() - bind := func(s string) { - if err := viper.BindPFlag(s, pf.Lookup(s)); err != nil { + pf.SortFlags = false + pf.StringVar(&configFile, "config", "", "Config file") + configFile = os.Getenv("OD_CONFIG") + + pf.String(ConfServerUrl, "http://od-db.the-eye.eu/api", "OD-DB server URL") + + pf.String(ConfToken, "", "OD-DB access token (env OD_SERVER_TOKEN)") + + pf.Duration(ConfServerTimeout, 60 * time.Second, "OD-DB request timeout") + + pf.Duration(ConfRecheck, 1 * time.Second, "OD-DB: Poll interval for new jobs") + + pf.Duration(ConfCooldown, 30 * time.Second, "OD-DB: Time to wait after a server-side error") + + pf.String(ConfChunkSize, "1 MB", "OD-DB: Result upload chunk size") + + pf.Uint(ConfUploadRetries, 10, "OD-DB: Max upload retries") + + pf.Duration(ConfUploadRetryInterval, 30 * time.Second, "OD-DB: Time to wait between upload retries") + + pf.Uint(ConfTasks, 100, "Crawler: Max concurrent tasks") + + pf.Uint(ConfWorkers, 4, "Crawler: Connections per server") + + pf.Uint(ConfRetries, 5, "Crawler: Request retries") + + pf.Duration(ConfDialTimeout, 10 * time.Second, "Crawler: Handshake timeout") + + pf.Duration(ConfTimeout, 30 * time.Second, "Crawler: Request timeout") + + pf.String(ConfUserAgent, "Mozilla/5.0 (X11; od-database-crawler) Gecko/20100101 Firefox/52.0", "Crawler: User-Agent") + + pf.Uint(ConfJobBufferSize, 5000, "Crawler: Task queue cache size") + + pf.Duration(ConfResume, 72 * time.Hour, "Crawler: Resume tasks not older than x") + + pf.Duration(ConfCrawlStats, time.Second, "Log: Crawl stats interval") + + pf.Duration(ConfAllocStats, 10 * time.Second, "Log: Resource stats interval") + + pf.Bool(ConfVerbose, false, "Log: Print every listed dir") + + pf.Bool(ConfPrintHTTP, false, "Log: Print HTTP client errors") + + pf.String(ConfLogFile, "crawler.log", "Log file") + + // Bind all flags to Viper + pf.VisitAll(func(flag *pflag.Flag) { + s := flag.Name + s = strings.TrimLeft(s, "-") + + if err := viper.BindPFlag(s, flag); err != nil { panic(err) } var envKey string @@ -65,71 +119,7 @@ func prepareConfig() { if err := viper.BindEnv(s, envKey); err != nil { panic(err) } - } - - pf.SortFlags = false - pf.StringVar(&configFile, "config", "", "Config file") - configFile = os.Getenv("OD_CONFIG") - - pf.String(ConfServerUrl, "http://od-db.the-eye.eu/api", "OD-DB server URL") - bind(ConfServerUrl) - - pf.String(ConfToken, "", "OD-DB access token (env OD_SERVER_TOKEN)") - bind(ConfToken) - - pf.Duration(ConfServerTimeout, 60 * time.Second, "OD-DB request timeout") - bind(ConfServerTimeout) - - pf.Duration(ConfRecheck, 1 * time.Second, "OD-DB: Poll interval for new jobs") - bind(ConfRecheck) - - pf.Duration(ConfCooldown, 30 * time.Second, "OD-DB: Time to wait after a server-side error") - bind(ConfCooldown) - - pf.String(ConfChunkSize, "1 MB", "OD-DB: Result upload chunk size") - bind(ConfChunkSize) - - 
pf.Uint(ConfUploadRetries, 10, "OD-DB: Max upload retries") - bind(ConfUploadRetries) - - pf.Duration(ConfUploadRetryInterval, 30 * time.Second, "OD-DB: Time to wait between upload retries") - bind(ConfUploadRetryInterval) - - pf.Uint(ConfTasks, 100, "Crawler: Max concurrent tasks") - bind(ConfTasks) - - pf.Uint(ConfWorkers, 4, "Crawler: Connections per server") - bind(ConfWorkers) - - pf.Uint(ConfRetries, 5, "Crawler: Request retries") - bind(ConfRetries) - - pf.Duration(ConfDialTimeout, 10 * time.Second, "Crawler: Handshake timeout") - bind(ConfDialTimeout) - - pf.Duration(ConfTimeout, 30 * time.Second, "Crawler: Request timeout") - bind(ConfTimeout) - - pf.String(ConfUserAgent, "Mozilla/5.0 (X11; od-database-crawler) Gecko/20100101 Firefox/52.0", "Crawler: User-Agent") - bind(ConfUserAgent) - - pf.Uint(ConfJobBufferSize, 5000, "Crawler: Task queue cache size") - bind(ConfJobBufferSize) - - pf.Duration(ConfCrawlStats, time.Second, "Log: Crawl stats interval") - bind(ConfCrawlStats) - - pf.Duration(ConfAllocStats, 10 * time.Second, "Log: Resource stats interval") - bind(ConfAllocStats) - - pf.Bool(ConfVerbose, false, "Log: Print every listed dir") - bind(ConfVerbose) - - pf.Bool(ConfPrintHTTP, false, "Log: Print HTTP client errors") - bind(ConfPrintHTTP) - - pf.String(ConfLogFile, "crawler.log", "Log file") - bind(ConfLogFile) + }) } func readConfig() { @@ -157,15 +147,17 @@ func readConfig() { } } - config.ServerUrl = viper.GetString(ConfServerUrl) - if config.ServerUrl == "" { - configMissing(ConfServerUrl) - } - config.ServerUrl = strings.TrimRight(config.ServerUrl, "/") + if onlineMode { + config.ServerUrl = viper.GetString(ConfServerUrl) + if config.ServerUrl == "" { + configMissing(ConfServerUrl) + } + config.ServerUrl = strings.TrimRight(config.ServerUrl, "/") - config.Token = viper.GetString(ConfToken) - if config.Token == "" { - configMissing(ConfToken) + config.Token = viper.GetString(ConfToken) + if config.Token == "" { + configMissing(ConfToken) + } } config.ServerTimeout = viper.GetDuration(ConfServerTimeout) diff --git a/ds/redblackhash/redblack.go b/ds/redblackhash/redblack.go index 95084c2..6b5b1ed 100644 --- a/ds/redblackhash/redblack.go +++ b/ds/redblackhash/redblack.go @@ -15,7 +15,10 @@ package redblackhash import ( "bytes" + "encoding/binary" + "encoding/hex" "fmt" + "io" "sync" ) @@ -43,6 +46,13 @@ type Node struct { Parent *Node } +type nodeHeader struct { + Key *Key + Color color +} + +var o = binary.BigEndian + func (k *Key) Compare(o *Key) int { return bytes.Compare(k[:], o[:]) } @@ -233,7 +243,7 @@ func (tree *Tree) String() string { } func (node *Node) String() string { - return fmt.Sprintf("%v", node.Key) + return hex.EncodeToString(node.Key[:16]) + "..." 
 }
 
 func output(node *Node, prefix string, isTail bool, str *string) {
@@ -481,6 +491,122 @@ func (tree *Tree) deleteCase6(node *Node) {
 	}
 }
 
+func (tree *Tree) Marshal(w io.Writer) (err error) {
+	tree.Lock()
+	defer tree.Unlock()
+
+	err = binary.Write(w, o, uint64(0x617979797979790A))
+	if err != nil { return err }
+
+	err = marshal(tree.Root, w)
+	if err != nil { return err }
+
+	err = binary.Write(w, o, uint64(0x6C6D616F6F6F6F0A))
+	if err != nil { return err }
+
+	return nil
+}
+
+func marshal(n *Node, w io.Writer) (err error) {
+	if n == nil {
+		err = binary.Write(w, o, uint64(0x796565656565740A))
+		return err
+	}
+
+	err = binary.Write(w, o, uint64(0xF09F85B1EFB88F0A))
+	if err != nil { return err }
+
+	_, err = w.Write(n.Key[:])
+	if err != nil { return err }
+
+	var colorI uint64
+	if n.color {
+		colorI = 0x7468652D6579657C
+	} else {
+		colorI = 0x6865782B7465727C
+	}
+
+	err = binary.Write(w, o, colorI)
+	if err != nil { return err }
+
+	err = marshal(n.Left, w)
+	if err != nil { return err }
+
+	err = marshal(n.Right, w)
+	if err != nil { return err }
+
+	return nil
+}
+
+func (tree *Tree) Unmarshal(r io.Reader) (err error) {
+	tree.Lock()
+	defer tree.Unlock()
+
+	var sof uint64
+	err = binary.Read(r, o, &sof)
+	if err != nil { return err }
+	if sof != 0x617979797979790A {
+		return fmt.Errorf("redblack: wrong format")
+	}
+
+	tree.Root, tree.size, err = unmarshal(r)
+	if err != nil { return err }
+
+	var eof uint64
+	err = binary.Read(r, o, &eof)
+	if err != nil { return err }
+	if eof != 0x6C6D616F6F6F6F0A {
+		return fmt.Errorf("redblack: end of file missing")
+	}
+
+	return nil
+}
+
+func unmarshal(r io.Reader) (n *Node, size int, err error) {
+	var head uint64
+	err = binary.Read(r, o, &head)
+	if err != nil { return nil, 0, err }
+
+	size = 1
+
+	switch head {
+	case 0x796565656565740A:
+		return nil, 0, nil
+	case 0xF09F85B1EFB88F0A:
+		n = new(Node)
+
+		_, err = io.ReadFull(r, n.Key[:])
+		if err != nil { return nil, 0, err }
+
+		var colorInt uint64
+		err = binary.Read(r, o, &colorInt)
+		if err != nil { return nil, 0, err }
+		switch colorInt {
+		case 0x7468652D6579657C:
+			n.color = true
+		case 0x6865782B7465727C:
+			n.color = false
+		default:
+			return nil, 0, fmt.Errorf("redblack: corrupt node color")
+		}
+	default:
+		return nil, 0, fmt.Errorf("redblack: corrupt node info")
+	}
+
+	var s2 int
+	n.Left, s2, err = unmarshal(r)
+	size += s2
+	if err != nil { return nil, 0, err }
+	n.Right, s2, err = unmarshal(r)
+	size += s2
+	if err != nil { return nil, 0, err }
+	// Restore parent links (not serialized) so the tree stays valid for later inserts
+	if n.Left != nil { n.Left.Parent = n }
+	if n.Right != nil { n.Right.Parent = n }
+
+	return n, size, nil
+}
+
 func nodeColor(node *Node) color {
 	if node == nil {
 		return black
diff --git a/ds/redblackhash/redblack_test.go b/ds/redblackhash/redblack_test.go
new file mode 100644
index 0000000..9944880
--- /dev/null
+++ b/ds/redblackhash/redblack_test.go
@@ -0,0 +1,47 @@
+package redblackhash
+
+import (
+	"bytes"
+	"math/rand"
+	"testing"
+)
+
+func TestTree_Marshal(t *testing.T) {
+	var t1, t2 Tree
+
+	// Generate 1000 random values to insert
+	for i := 0; i < 1000; i++ {
+		var key Key
+		rand.Read(key[:])
+		t1.Put(&key)
+	}
+
+	// Marshal tree
+	var wr bytes.Buffer
+	err := t1.Marshal(&wr)
+	if err != nil {
+		t.Error(err)
+		t.FailNow()
+	}
+
+	buf := wr.Bytes()
+	rd := bytes.NewBuffer(buf)
+
+	// Unmarshal tree
+	err = t2.Unmarshal(rd)
+	if err != nil {
+		t.Error(err)
+		t.FailNow()
+	}
+
+	if !compare(t1.Root, t2.Root) {
+		t.Error("trees are not equal")
+		t.FailNow()
+	}
+}
+
+func compare(n1, n2 *Node) bool {
+	return n1.Key.Compare(&n2.Key) == 0 &&
+		(n1.Left == nil || compare(n1.Left, n2.Left)) &&
+		
(n1.Right == nil || compare(n1.Right, n2.Right)) +} diff --git a/main.go b/main.go index e2d67ac..c13cef1 100644 --- a/main.go +++ b/main.go @@ -8,6 +8,7 @@ import ( "github.com/spf13/viper" "github.com/terorie/od-database-crawler/fasturl" "os" + "os/signal" "strings" "sync/atomic" "time" @@ -61,8 +62,6 @@ func preRun(cmd *cobra.Command, args []string) error { if err := os.MkdirAll("queue", 0755); err != nil { panic(err) } - readConfig() - return nil } @@ -75,25 +74,32 @@ func main() { } func cmdBase(_ *cobra.Command, _ []string) { - // TODO Graceful shutdown - appCtx := context.Background() - forceCtx := context.Background() + onlineMode = true + readConfig() + + appCtx, soft := context.WithCancel(context.Background()) + forceCtx, hard := context.WithCancel(context.Background()) + go hardShutdown(forceCtx) + go listenCtrlC(soft, hard) inRemotes := make(chan *OD) - go Schedule(forceCtx, inRemotes) + go LoadResumeTasks(inRemotes) + go Schedule(appCtx, inRemotes) ticker := time.NewTicker(config.Recheck) defer ticker.Stop() for { select { case <-appCtx.Done(): - return + goto shutdown case <-ticker.C: t, err := FetchTask() if err != nil { logrus.WithError(err). Error("Failed to get new task") - time.Sleep(viper.GetDuration(ConfCooldown)) + if !sleep(viper.GetDuration(ConfCooldown), appCtx) { + goto shutdown + } continue } if t == nil { @@ -126,9 +132,15 @@ func cmdBase(_ *cobra.Command, _ []string) { ScheduleTask(inRemotes, t, &baseUri) } } + + shutdown: + globalWait.Wait() } func cmdCrawler(_ *cobra.Command, args []string) error { + onlineMode = false + readConfig() + arg := args[0] // https://github.com/golang/go/issues/19779 if !strings.Contains(arg, "://") { @@ -161,3 +173,30 @@ func cmdCrawler(_ *cobra.Command, args []string) error { return nil } + +func listenCtrlC(soft, hard context.CancelFunc) { + c := make(chan os.Signal) + signal.Notify(c, os.Interrupt) + + <-c + logrus.Info(">>> Shutting down crawler... <<<") + soft() + + <-c + logrus.Warning(">>> Force shutdown! <<<") + hard() +} + +func hardShutdown(c context.Context) { + <-c.Done() + os.Exit(1) +} + +func sleep(d time.Duration, c context.Context) bool { + select { + case <-time.After(d): + return true + case <-c.Done(): + return false + } +} diff --git a/model.go b/model.go index 0b24e91..7a0a907 100644 --- a/model.go +++ b/model.go @@ -3,7 +3,6 @@ package main import ( "github.com/terorie/od-database-crawler/ds/redblackhash" "github.com/terorie/od-database-crawler/fasturl" - "sync" "time" ) @@ -30,12 +29,19 @@ type Job struct { } type OD struct { - Task Task - Result TaskResult - Wait sync.WaitGroup - BaseUri fasturl.URL - WCtx WorkerContext - Scanned redblackhash.Tree + Task Task + Result TaskResult + InProgress int64 + BaseUri fasturl.URL + WCtx WorkerContext + Scanned redblackhash.Tree +} + +type PausedOD struct { + Task *Task + Result *TaskResult + BaseUri *fasturl.URL + InProgress int64 } type File struct { diff --git a/resume.go b/resume.go new file mode 100644 index 0000000..e5da8bf --- /dev/null +++ b/resume.go @@ -0,0 +1,270 @@ +package main + +import ( + "bytes" + "encoding/binary" + "encoding/gob" + "fmt" + "github.com/beeker1121/goque" + "github.com/sirupsen/logrus" + "github.com/spf13/viper" + "io" + "os" + "path/filepath" + "strconv" + "sync/atomic" + "time" +) + +func init() { + gob.Register(&PausedOD{}) +} + +func LoadResumeTasks(inRemotes chan<- *OD) { + resumed, err := ResumeTasks() + if err != nil { + logrus.WithError(err). + Error("Failed to resume queued tasks. 
" + + "/queue is probably corrupt") + err = nil + } + + for _, remote := range resumed { + inRemotes <- remote + } +} + +func ResumeTasks() (tasks []*OD, err error) { + // Get files in /queue + var queueF *os.File + var entries []os.FileInfo + queueF, err = os.Open("queue") + if err != nil { return nil, err } + defer queueF.Close() + entries, err = queueF.Readdir(-1) + if err != nil { return nil, err } + + resumeDur := viper.GetDuration(ConfResume) + + for _, entry := range entries { + if !entry.IsDir() { continue } + + // Check if name is a number + var id uint64 + if id, err = strconv.ParseUint(entry.Name(), 10, 64); err != nil { + continue + } + + // Too old to be resumed + timeDelta := time.Since(entry.ModTime()) + if resumeDur >= 0 && timeDelta > resumeDur { + removeOldQueue(id) + continue + } + + // Load queue + var od *OD + if od, err = resumeQueue(id); err != nil { + logrus.WithError(err). + WithField("id", id). + Warning("Failed to load paused task") + continue + } else if od == nil { + removeOldQueue(id) + continue + } + + tasks = append(tasks, od) + } + + return tasks, nil +} + +func SaveTask(od *OD) (err error) { + dir := filepath.Join("queue", + strconv.FormatUint(od.Task.WebsiteId, 10)) + + fPath := filepath.Join(dir, "PAUSED") + + err = os.Mkdir(dir, 0777) + if err != nil { return err } + + // Open pause file + pausedF, err := os.OpenFile(fPath, os.O_CREATE | os.O_WRONLY | os.O_TRUNC, 0666) + if err != nil { return err } + defer pausedF.Close() + + err = writePauseFile(od, pausedF) + if err != nil { return err } + + return nil +} + +func resumeQueue(id uint64) (od *OD, err error) { + logrus.WithField("id", id). + Info("Found unfinished") + + fPath := filepath.Join("queue", strconv.FormatUint(id, 10)) + + // Try to find pause file + pausedF, err := os.Open(filepath.Join(fPath, "PAUSED")) + if os.IsNotExist(err) { + // No PAUSED file => not paused + // not paused => no error + return nil, nil + } else if err != nil { + return nil, err + } + defer pausedF.Close() + + od = new(OD) + od.WCtx.OD = od + + err = readPauseFile(od, pausedF) + if err != nil { return nil, err } + + // Open queue + bq, err := OpenQueue(fPath) + if err != nil { return nil, err } + + od.WCtx.Queue = bq + + logrus.WithField("id", id). + Info("Resuming task") + + return od, nil +} + +func removeOldQueue(id uint64) { + if id == 0 { + // TODO Make custom crawl less of an ugly hack + return + } + + logrus.WithField("id", id). + Warning("Deleting & returning old task") + + name := strconv.FormatUint(id, 10) + + fPath := filepath.Join("queue", name) + + // Acquire old queue + q, err := goque.OpenQueue(fPath) + if err != nil { + // Queue lock exists, don't delete + logrus.WithField("err", err). + WithField("path", fPath). + Error("Failed to acquire old task") + return + } + + // Delete old queue from disk + err = q.Drop() + if err != nil { + // Queue lock exists, don't delete + logrus.WithField("err", err). + WithField("path", fPath). + Error("Failed to delete old task") + return + } + + // Delete old crawl result from disk + _ = os.Remove(filepath.Join("crawled", name + ".json")) + + // Return task to server + if err := CancelTask(id); err != nil { + // Queue lock exists, don't delete + logrus.WithField("err", err). + WithField("id", id). 
+ Warning("Failed to return unfinished task to server") + return + } +} + +func writePauseFile(od *OD, w io.Writer) (err error) { + // Write pause file version + _, err = w.Write([]byte("ODPAUSE-")) + if err != nil { return err } + + // Create save state + paused := PausedOD { + Task: &od.Task, + Result: &od.Result, + BaseUri: &od.BaseUri, + InProgress: atomic.LoadInt64(&od.InProgress), + } + + // Prepare pause settings + var b bytes.Buffer + pauseEnc := gob.NewEncoder(&b) + err = pauseEnc.Encode(&paused) + if err != nil { return err } + + // Write length of pause settings + err = binary.Write(w, binary.LittleEndian, uint64(b.Len())) + if err != nil { return err } + + // Write pause settings + _, err = w.Write(b.Bytes()) + if err != nil { return err } + + // Write pause scan state + err = od.Scanned.Marshal(w) + if err != nil { return err } + + // Save mark + _, err = w.Write([]byte("--------")) + if err != nil { return err } + + return nil +} + +func readPauseFile(od *OD, r io.Reader) (err error) { + // Make the paused struct point to OD fields + // So gob loads values into the OD struct + paused := PausedOD { + Task: &od.Task, + Result: &od.Result, + BaseUri: &od.BaseUri, + } + + var version [8]byte + _, err = io.ReadFull(r, version[:]) + if err != nil { return err } + if !bytes.Equal(version[:], []byte("ODPAUSE-")) { + return fmt.Errorf("unsupported pause file") + } + + // Read pause settings len + var pauseSettingsLen uint64 + err = binary.Read(r, binary.LittleEndian, &pauseSettingsLen) + + // Read pause settings + pauseDec := gob.NewDecoder(io.LimitReader(r, int64(pauseSettingsLen))) + err = pauseDec.Decode(&paused) + if err != nil { return err } + atomic.StoreInt64(&od.InProgress, paused.InProgress) + + err = readPauseStateTree(od, r) + if err != nil { + return fmt.Errorf("failed to read state tree: %s", err) + } + + return nil +} + +func readPauseStateTree(od *OD, r io.Reader) (err error) { + // Read pause scan state + err = od.Scanned.Unmarshal(r) + if err != nil { return err } + + // Check mark + var mark [8]byte + _, err = io.ReadFull(r, mark[:]) + if err != nil { return err } + if !bytes.Equal(mark[:], []byte("--------")) { + return fmt.Errorf("corrupt pause file") + } + + return nil +} diff --git a/resume_test.go b/resume_test.go new file mode 100644 index 0000000..ce737bb --- /dev/null +++ b/resume_test.go @@ -0,0 +1,48 @@ +package main + +import ( + "bytes" + "github.com/terorie/od-database-crawler/fasturl" + "testing" + "time" +) + +func TestResumeTasks_Empty(t *testing.T) { + start := time.Now().Add(-1 * time.Minute) + od := OD { + Task: Task { + WebsiteId: 213, + Url: "https://the-eye.eu/public/", + }, + Result: TaskResult { + StartTime: start, + StartTimeUnix: start.Unix(), + EndTimeUnix: time.Now().Unix(), + WebsiteId: 213, + }, + InProgress: 0, + BaseUri: fasturl.URL { + Scheme: fasturl.SchemeHTTPS, + Host: "the-eye.eu", + Path: "/public/", + }, + } + od.WCtx.OD = &od + + var b bytes.Buffer + var err error + err = writePauseFile(&od, &b) + if err != nil { + t.Fatal(err) + } + + buf := b.Bytes() + + var od2 OD + + b2 := bytes.NewBuffer(buf) + err = readPauseFile(&od2, b2) + if err != nil { + t.Fatal(err) + } +} diff --git a/scheduler.go b/scheduler.go index 9abe491..6a00420 100644 --- a/scheduler.go +++ b/scheduler.go @@ -22,54 +22,57 @@ func Schedule(c context.Context, remotes <-chan *OD) { go Stats(c) for remote := range remotes { - logrus.WithField("url", remote.BaseUri.String()). 
-			Info("Starting crawler")
-
-		// Collect results
-		results := make(chan File)
-
-		remote.WCtx.OD = remote
-
-		// Get queue path
-		queuePath := path.Join("queue", fmt.Sprintf("%d", remote.Task.WebsiteId))
-
-		// Delete existing queue
-		if err := os.RemoveAll(queuePath);
-			err != nil { panic(err) }
-
-		// Start new queue
-		var err error
-		remote.WCtx.Queue, err = OpenQueue(queuePath)
-		if err != nil { panic(err) }
-
-		// Spawn workers
-		for i := 0; i < config.Workers; i++ {
-			go remote.WCtx.Worker(results)
-		}
-
-		// Enqueue initial job
-		atomic.AddInt32(&numActiveTasks, 1)
-		remote.WCtx.queueJob(Job{
-			Uri:    remote.BaseUri,
-			UriStr: remote.BaseUri.String(),
-			Fails:  0,
-		})
-
-		// Upload result when ready
-		go remote.Watch(results)
-
-		// Sleep if max number of tasks are active
-		for atomic.LoadInt32(&numActiveTasks) > config.Tasks {
-			select {
-			case <-c.Done():
-				return
-			case <-time.After(time.Second):
-				continue
-			}
+		if !scheduleNewTask(c, remote) {
+			return
 		}
 	}
 }
 
+func scheduleNewTask(c context.Context, remote *OD) bool {
+	logrus.WithField("url", remote.BaseUri.String()).
+		Info("Starting crawler")
+
+	// Collect results
+	results := make(chan File)
+
+	remote.WCtx.OD = remote
+
+	// Get queue path
+	queuePath := path.Join("queue", fmt.Sprintf("%d", remote.Task.WebsiteId))
+
+	// Delete existing queue
+	if err := os.RemoveAll(queuePath);
+		err != nil { panic(err) }
+
+	// Start new queue
+	var err error
+	remote.WCtx.Queue, err = OpenQueue(queuePath)
+	if err != nil { panic(err) }
+
+	// Spawn workers
+	remote.WCtx.SpawnWorkers(c, results, config.Workers)
+
+	// Enqueue initial job
+	atomic.AddInt32(&numActiveTasks, 1)
+	remote.WCtx.queueJob(Job{
+		Uri:    remote.BaseUri,
+		UriStr: remote.BaseUri.String(),
+		Fails:  0,
+	})
+
+	// Upload result when ready
+	go remote.Watch(results)
+
+	// Wait while the maximum number of tasks is active; stop scheduling on shutdown
+	for atomic.LoadInt32(&numActiveTasks) > config.Tasks {
+		if !sleep(time.Second, c) {
+			return false
+		}
+	}
+
+	return true
+}
+
 func ScheduleTask(remotes chan<- *OD, t *Task, u *fasturl.URL) {
 	if !t.register() {
 		return
@@ -117,7 +120,7 @@ func (o *OD) Watch(results chan File) {
 	// Open crawl results file
 	f, err := os.OpenFile(
 		filePath,
-		os.O_CREATE | os.O_RDWR | os.O_TRUNC,
+		os.O_CREATE | os.O_RDWR | os.O_APPEND,
 		0644,
 	)
 	if err != nil {
@@ -159,16 +162,33 @@ func (o *OD) handleCollect(results chan File, f *os.File, collectErrC chan error
 	defer close(results)
 
 	// Wait for all jobs on remote to finish
-	o.Wait.Wait()
+	for {
+		// Natural finish
+		if atomic.LoadInt64(&o.InProgress) == 0 {
+			o.onTaskFinished()
+			return
+		}
+		// Abort
+		if atomic.LoadInt32(&o.WCtx.aborted) != 0 {
+			// Wait for all workers to finish
+			o.WCtx.workers.Wait()
+			o.onTaskPaused()
+			return
+		}
+
+		time.Sleep(500 * time.Millisecond)
+	}
+}
+
+func (o *OD) onTaskFinished() {
+	defer atomic.AddInt32(&numActiveTasks, -1)
 
 	// Close queue
 	if err := o.WCtx.Queue.Close(); err != nil {
 		panic(err)
 	}
-	atomic.AddInt32(&numActiveTasks, -1)
 
 	// Log finish
-
 	logrus.WithFields(logrus.Fields{
 		"id":  o.Task.WebsiteId,
 		"url": o.BaseUri.String(),
@@ -191,6 +211,37 @@ func (o *OD) handleCollect(results chan File, f *os.File, collectErrC chan error
 	}
 }
 
+func (o *OD) onTaskPaused() {
+	defer atomic.AddInt32(&numActiveTasks, -1)
+
+	// Close queue
+	if err := o.WCtx.Queue.Close(); err != nil {
+		panic(err)
+	}
+
+	// Set current end time
+	o.Result.EndTimeUnix = time.Now().Unix()
+
+	// Save task metadata
+	err := SaveTask(o)
+	if err != nil {
+		// Log failure to persist the crawl state
+		logrus.WithFields(logrus.Fields{
+			"err": err.Error(),
+			"id":  o.Task.WebsiteId,
+			"url": o.BaseUri.String(),
+		}).Error("Failed to save crawler state")
+		return
+	}
+
+	// Log pause
+	logrus.WithFields(logrus.Fields{
+		"id":       o.Task.WebsiteId,
+		"url":      o.BaseUri.String(),
+		"duration": time.Since(o.Result.StartTime),
+	}).Info("Crawler paused")
+}
+
 func (t *Task) Collect(results chan File, f *os.File, errC chan<- error) {
 	err := t.collect(results, f)
 	if err != nil {
diff --git a/worker.go b/worker.go
index 118e28f..5c07a15 100644
--- a/worker.go
+++ b/worker.go
@@ -1,6 +1,7 @@
 package main
 
 import (
+	"context"
 	"github.com/beeker1121/goque"
 	"github.com/sirupsen/logrus"
 	"math"
@@ -18,10 +19,29 @@ type WorkerContext struct {
 	Queue         *BufferedQueue
 	lastRateLimit time.Time
 	numRateLimits int
+	workers       sync.WaitGroup
+	aborted       int32
 }
 
-func (w *WorkerContext) Worker(results chan<- File) {
+func (w *WorkerContext) SpawnWorkers(c context.Context, results chan<- File, n int) {
+	w.workers.Add(n)
+	for i := 0; i < n; i++ {
+		go w.Worker(c, results)
+	}
+}
+
+func (w *WorkerContext) Worker(c context.Context, results chan<- File) {
+	defer w.workers.Done()
+
 	for {
+		select {
+		case <-c.Done():
+			// Shutdown requested: mark the task as aborted and stop this worker
+			atomic.StoreInt32(&w.aborted, 1)
+			return
+		default:
+		}
+
 		job, err := w.Queue.Dequeue()
 		switch err {
 		case goque.ErrEmpty:
@@ -156,7 +176,7 @@ func (w *WorkerContext) DoJob(job *Job, f *File) (newJobs []Job, err error) {
 }
 
 func (w *WorkerContext) queueJob(job Job) {
-	w.OD.Wait.Add(1)
+	atomic.AddInt64(&w.OD.InProgress, 1)
 
 	if w.numRateLimits > 0 {
 		if time.Since(w.lastRateLimit) > 5 * time.Second {
@@ -173,7 +193,7 @@ func (w *WorkerContext) queueJob(job Job) {
 }
 
 func (w *WorkerContext) finishJob() {
-	w.OD.Wait.Done()
+	atomic.AddInt64(&w.OD.InProgress, -1)
 }
 
 func isErrSilent(err error) bool {