From e919323169fe2f5105bb7edb70347d7ffb551e36 Mon Sep 17 00:00:00 2001 From: Richard Patel Date: Sun, 3 Feb 2019 16:24:18 +0100 Subject: [PATCH] Resume tests --- resume.go | 154 ++++++++++++++++++++++++++++--------------------- resume_test.go | 48 +++++++++++++++ 2 files changed, 137 insertions(+), 65 deletions(-) create mode 100644 resume_test.go diff --git a/resume.go b/resume.go index 3926896..2fdf6e6 100644 --- a/resume.go +++ b/resume.go @@ -89,33 +89,7 @@ func SaveTask(od *OD) (err error) { if err != nil { return err } defer pausedF.Close() - // Write pause file version - _, err = pausedF.Write([]byte("ODPAUSE-")) - if err != nil { return err } - - // Create save state - paused := PausedOD { - Task: &od.Task, - Result: &od.Result, - BaseUri: &od.BaseUri, - InProgress: atomic.LoadInt64(&od.InProgress), - } - - // Write pause settings - pauseEnc := gob.NewEncoder(pausedF) - err = pauseEnc.Encode(&paused) - if err != nil { return err } - - // Save mark - _, err = pausedF.Write([]byte("--------")) - if err != nil { return err } - - // Write pause scan state - err = od.Scanned.Marshal(pausedF) - if err != nil { return err } - - // Save mark - _, err = pausedF.Write([]byte("--------")) + err = writePauseFile(od, pausedF) if err != nil { return err } return nil @@ -141,45 +115,8 @@ func resumeQueue(id uint64) (od *OD, err error) { od = new(OD) od.WCtx.OD = od - // Make the paused struct point to OD fields - // So gob loads values into the OD struct - paused := PausedOD { - Task: &od.Task, - Result: &od.Result, - BaseUri: &od.BaseUri, - } - - var version [8]byte - _, err = io.ReadFull(pausedF, version[:]) + err = readPauseFile(od, pausedF) if err != nil { return nil, err } - if !bytes.Equal(version[:], []byte("ODPAUSE-")) { - return nil, fmt.Errorf("unsupported pause file") - } - - // Read pause settings - pauseDec := gob.NewDecoder(pausedF) - err = pauseDec.Decode(&paused) - if err != nil { return nil, err } - atomic.StoreInt64(&od.InProgress, paused.InProgress) - - // Check mark - var mark [8]byte - _, err = io.ReadFull(pausedF, mark[:]) - if err != nil { return nil, err } - if !bytes.Equal(mark[:], []byte("--------")) { - return nil, fmt.Errorf("corrupt pause file") - } - - // Read pause scan state - err = od.Scanned.Unmarshal(pausedF) - if err != nil { return nil, err } - - // Check mark - _, err = io.ReadFull(pausedF, mark[:]) - if err != nil { return nil, err } - if !bytes.Equal(mark[:], []byte("--------")) { - return nil, fmt.Errorf("corrupt pause file") - } // Open queue bq, err := OpenQueue(fPath) @@ -238,3 +175,90 @@ func removeOldQueue(id uint64) { return } } + +func writePauseFile(od *OD, w io.Writer) (err error) { + // Write pause file version + _, err = w.Write([]byte("ODPAUSE-")) + if err != nil { return err } + + // Create save state + paused := PausedOD { + Task: &od.Task, + Result: &od.Result, + BaseUri: &od.BaseUri, + InProgress: atomic.LoadInt64(&od.InProgress), + } + + // Write pause settings + pauseEnc := gob.NewEncoder(w) + err = pauseEnc.Encode(&paused) + if err != nil { return err } + + // Save mark + _, err = w.Write([]byte("--------")) + if err != nil { return err } + + // Write pause scan state + err = od.Scanned.Marshal(w) + if err != nil { return err } + + // Save mark + _, err = w.Write([]byte("--------")) + if err != nil { return err } + + return nil +} + +func readPauseFile(od *OD, r io.Reader) (err error) { + // Make the paused struct point to OD fields + // So gob loads values into the OD struct + paused := PausedOD { + Task: &od.Task, + Result: &od.Result, + BaseUri: &od.BaseUri, + } + + var version [8]byte + _, err = io.ReadFull(r, version[:]) + if err != nil { return err } + if !bytes.Equal(version[:], []byte("ODPAUSE-")) { + return fmt.Errorf("unsupported pause file") + } + + // Read pause settings + pauseDec := gob.NewDecoder(r) + err = pauseDec.Decode(&paused) + if err != nil { return err } + atomic.StoreInt64(&od.InProgress, paused.InProgress) + + // Check mark + var mark [8]byte + _, err = io.ReadFull(r, mark[:]) + if err != nil { return err } + if !bytes.Equal(mark[:], []byte("--------")) { + return fmt.Errorf("corrupt pause file") + } + + err = readPauseStateTree(od, r) + if err != nil { + return fmt.Errorf("failed to read state tree: %s", err) + } + + return nil +} + +func readPauseStateTree(od *OD, r io.Reader) (err error) { + // Read pause scan state + err = od.Scanned.Unmarshal(r) + if err != nil { return err } + + // Check mark + var mark [8]byte + _, err = io.ReadFull(r, mark[:]) + if err != nil { return err } + if !bytes.Equal(mark[:], []byte("--------")) { + return fmt.Errorf("corrupt pause file") + } + + return nil +} diff --git a/resume_test.go b/resume_test.go new file mode 100644 index 0000000..ce737bb --- /dev/null +++ b/resume_test.go @@ -0,0 +1,48 @@ +package main + +import ( + "bytes" + "github.com/terorie/od-database-crawler/fasturl" + "testing" + "time" +) + +func TestResumeTasks_Empty(t *testing.T) { + start := time.Now().Add(-1 * time.Minute) + od := OD { + Task: Task { + WebsiteId: 213, + Url: "https://the-eye.eu/public/", + }, + Result: TaskResult { + StartTime: start, + StartTimeUnix: start.Unix(), + EndTimeUnix: time.Now().Unix(), + WebsiteId: 213, + }, + InProgress: 0, + BaseUri: fasturl.URL { + Scheme: fasturl.SchemeHTTPS, + Host: "the-eye.eu", + Path: "/public/", + }, + } + od.WCtx.OD = &od + + var b bytes.Buffer + var err error + err = writePauseFile(&od, &b) + if err != nil { + t.Fatal(err) + } + + buf := b.Bytes() + + var od2 OD + + b2 := bytes.NewBuffer(buf) + err = readPauseFile(&od2, b2) + if err != nil { + t.Fatal(err) + } +}