diff --git a/README.md b/README.md index 4fa7afe..6640501 100644 --- a/README.md +++ b/README.md @@ -12,6 +12,7 @@ and error handling. Built for automated web scraping. * Strictly obeys configured rate-limiting for each IP & Host * Seamless exponential backoff retries on timeout or error HTTP codes * Requires no additional configuration for integration into existing programs +* Configurable per-host behavior ### Typical use case ![user_case](use_case.png) @@ -58,6 +59,51 @@ level=trace msg=Sleeping wait=433.394361ms ./reload.sh ``` +### Rules + + +Conditions + +| Left operand | Description | Allowed operators | Right operand +| :--- | :--- | :--- | :--- +| body | Contents of the response | `=`, `!=` | String w/ wildcard +| body | Contents of the response | `<`, `>` | float +| status | HTTP response code | `=`, `!=` | String w/ wildcard +| status | HTTP response code | `<`, `>` | float +| response_time | HTTP response code | `<`, `>` | duration (e.g. `20s`) +| header:`
<name>` | Response header | `=`, `!=` | String w/ wildcard
+| header:`<name>
` | Response header | `<`, `>` | float + +Note that `response_time` can never be higher than the configured `timeout` value. + +Examples: + +```json +[ + {"condition": "header:X-Test>10", "action": "..."}, + {"condition": "body=*Try again in a few minutes*", "action": "..."}, + {"condition": "response_time>10s", "action": "..."}, + {"condition": "status>500", "action": "..."}, + {"condition": "status=404", "action": "..."}, + {"condition": "status=40*", "action": "..."} +] +``` + +Actions + +| Action | Description | `arg` value | +| :--- | :--- | :--- | +| should_retry | Override default retry behavior for http errors (by default it retries on 403,408,429,444,499,>500) +| force_retry | Always retry (Up to retries_hard times) +| dont_retry | Immediately stop retrying +| multiply_every | Multiply the current limiter's 'every' value by `arg` | `1.5`(float) +| set_every | Set the current limiter's 'every' value to `arg` | `10s`(duration) + +In the event of a temporary network error, `should_retry` is ignored (it will always retry unless `dont_retry` is set) + +Note that having too many rules for one host might negatively impact performance (especially the `body` condition for large requests) + + ### Sample configuration ```json @@ -67,6 +113,7 @@ level=trace msg=Sleeping wait=433.394361ms "wait": "4s", "multiplier": 2.5, "retries": 3, + "retries_hard": 6, "proxies": [ { "name": "squid_P0", @@ -83,7 +130,7 @@ level=trace msg=Sleeping wait=433.394361ms "every": "500ms", "burst": 25, "headers": { - "User-Agent": "Some user agent", + "User-Agent": "Some user agent for all requests", "X-Test": "Will be overwritten" } }, @@ -94,6 +141,22 @@ level=trace msg=Sleeping wait=433.394361ms "headers": { "X-Test": "Will overwrite default" } + }, + { + "host": ".s3.amazonaws.com", + "every": "2s", + "burst": 30, + "rules": [ + {"condition": "status=403", "action": "dont_retry"} + ] + }, + { + "host": ".www.instagram.com", + "every": "4500ms", + "burst": 3, + "rules": [ + {"condition": 
"body=*please try again in a few minutes*", "action": "multiply_every", "arg": "2"} + ] } ] } diff --git a/config.go b/config.go index d0e518d..7938e7d 100644 --- a/config.go +++ b/config.go @@ -1,12 +1,18 @@ package main import ( + "bytes" "encoding/json" "fmt" + "github.com/pkg/errors" + "github.com/ryanuber/go-glob" "github.com/sirupsen/logrus" "golang.org/x/time/rate" "io/ioutil" "os" + "reflect" + "runtime" + "strconv" "strings" "time" ) @@ -16,7 +22,47 @@ type HostConfig struct { EveryStr string `json:"every"` Burst int `json:"burst"` Headers map[string]string `json:"headers"` + RawRules []*RawHostRule `json:"rules"` Every time.Duration + Rules []*HostRule +} + +type RawHostRule struct { + Condition string `json:"condition"` + Action string `json:"action"` + Arg string `json:"arg"` +} + +type HostRuleAction int + +const ( + DontRetry HostRuleAction = 0 + MultiplyEvery HostRuleAction = 1 + SetEvery HostRuleAction = 2 + ForceRetry HostRuleAction = 3 + ShouldRetry HostRuleAction = 4 +) + +func (a HostRuleAction) String() string { + switch a { + case DontRetry: + return "dont_retry" + case MultiplyEvery: + return "multiply_every" + case SetEvery: + return "set_every" + case ForceRetry: + return "force_retry" + case ShouldRetry: + return "should_retry" + } + return "???" 
+} + +type HostRule struct { + Matches func(r *RequestCtx) bool + Action HostRuleAction + Arg float64 } type ProxyConfig struct { @@ -30,6 +76,7 @@ var config struct { WaitStr string `json:"wait"` Multiplier float64 `json:"multiplier"` Retries int `json:"retries"` + RetriesHard int `json:"retries_hard"` Hosts []*HostConfig `json:"hosts"` Proxies []ProxyConfig `json:"proxies"` Wait int64 @@ -37,16 +84,202 @@ var config struct { DefaultConfig *HostConfig } -func loadConfig() { +func parseRule(raw *RawHostRule) (*HostRule, error) { + //TODO: for the love of god someone please refactor this func + + rule := &HostRule{} + var err error + + switch raw.Action { + case "should_retry": + rule.Action = ShouldRetry + case "dont_retry": + rule.Action = DontRetry + case "multiply_every": + rule.Action = MultiplyEvery + rule.Arg, err = strconv.ParseFloat(raw.Arg, 64) + case "set_every": + rule.Action = SetEvery + var duration time.Duration + duration, err = time.ParseDuration(raw.Arg) + if err != nil { + return nil, err + } + rule.Arg = 1 / duration.Seconds() + case "force_retry": + rule.Action = ForceRetry + default: + return nil, errors.Errorf("Invalid argument for action: %s", raw.Action) + } + + if err != nil { + return nil, err + } + + switch { + case strings.Contains(raw.Condition, "!="): + op1Str, op2Str := split(raw.Condition, "!=") + op1Func := parseOperand1(op1Str) + if op1Func == nil { + return nil, errors.Errorf("Invalid rule: %s", raw.Condition) + } + + if isGlob(op2Str) { + rule.Matches = func(ctx *RequestCtx) bool { + return !glob.Glob(op2Str, op1Func(ctx)) + } + } else { + op2Str = strings.Replace(op2Str, "\\*", "*", -1) + rule.Matches = func(ctx *RequestCtx) bool { + return op1Func(ctx) != op2Str + } + } + case strings.Contains(raw.Condition, "="): + op1Str, op2Str := split(raw.Condition, "=") + op1Func := parseOperand1(op1Str) + if op1Func == nil { + return nil, errors.Errorf("Invalid rule: %s", raw.Condition) + } + + if isGlob(op2Str) { + rule.Matches = 
func(ctx *RequestCtx) bool { + return glob.Glob(op2Str, op1Func(ctx)) + } + } else { + op2Str = strings.Replace(op2Str, "\\*", "*", -1) + rule.Matches = func(ctx *RequestCtx) bool { + return op1Func(ctx) == op2Str + } + } + case strings.Contains(raw.Condition, ">"): + op1Str, op2Str := split(raw.Condition, ">") + op1Func := parseOperand1(op1Str) + if op1Func == nil { + return nil, errors.Errorf("Invalid rule: %s", raw.Condition) + } + op2Num, err := parseOperand2(op1Str, op2Str) + if err != nil { + return nil, err + } + + rule.Matches = func(ctx *RequestCtx) bool { + op1Num, err := strconv.ParseFloat(op1Func(ctx), 64) + handleRuleErr(err) + return op1Num > op2Num + } + case strings.Contains(raw.Condition, "<"): + op1Str, op2Str := split(raw.Condition, "<") + op1Func := parseOperand1(op1Str) + if op1Func == nil { + return nil, errors.Errorf("Invalid rule: %s", raw.Condition) + } + op2Num, err := parseOperand2(op1Str, op2Str) + if err != nil { + return nil, err + } + + rule.Matches = func(ctx *RequestCtx) bool { + op1Num, err := strconv.ParseFloat(op1Func(ctx), 64) + handleRuleErr(err) + return op1Num < op2Num + } + } + + return rule, nil +} + +func handleRuleErr(err error) { + if err != nil { + logrus.WithError(err).Warn("Error computing rule") + } +} + +func split(str, subStr string) (string, string) { + + str1 := str[:strings.Index(str, subStr)] + str2 := str[strings.Index(str, subStr)+len(subStr):] + + return str1, str2 +} + +func parseOperand2(op1, op2 string) (float64, error) { + if op1 == "response_time" { + res, err := time.ParseDuration(op2) + if err != nil { + return -1, err + } + return res.Seconds(), nil + } + + return strconv.ParseFloat(op2, 64) +} + +func parseOperand1(op string) func(ctx *RequestCtx) string { + switch { + case op == "body": + return func(ctx *RequestCtx) string { + + if ctx.Response == nil { + return "" + } + bodyBytes, err := ioutil.ReadAll(ctx.Response.Body) + if err != nil { + return "" + } + err = ctx.Response.Body.Close() + if err 
!= nil { + return "" + } + ctx.Response.Body = ioutil.NopCloser(bytes.NewReader(bodyBytes)) + + return string(bodyBytes) + } + case op == "status": + return func(ctx *RequestCtx) string { + if ctx.Response == nil { + return "" + } + return strconv.Itoa(ctx.Response.StatusCode) + } + case op == "response_time": + return func(ctx *RequestCtx) string { + return strconv.FormatFloat(time.Now().Sub(ctx.RequestTime).Seconds(), 'f', 6, 64) + } + case strings.HasPrefix(op, "header:"): + header := op[strings.Index(op, ":")+1:] + return func(ctx *RequestCtx) string { + if ctx.Response == nil { + return "" + } + return ctx.Response.Header.Get(header) + } + default: + return nil + } +} + +func isGlob(op string) bool { + tmpStr := strings.Replace(op, "\\*", "_", -1) + + return strings.Contains(tmpStr, "*") +} + +func loadConfig() error { configFile, err := os.Open("config.json") - handleErr(err) + if err != nil { + return err + } configBytes, err := ioutil.ReadAll(configFile) - handleErr(err) + if err != nil { + return err + } err = json.Unmarshal(configBytes, &config) - handleErr(err) + if err != nil { + return err + } validateConfig() @@ -54,18 +287,52 @@ func loadConfig() { wait, err := time.ParseDuration(config.WaitStr) config.Wait = int64(wait) - for _, conf := range config.Hosts { + for i, conf := range config.Hosts { if conf.EveryStr == "" { - conf.Every = config.DefaultConfig.Every + // Look 'upwards' for every + for _, prevConf := range config.Hosts[:i] { + if glob.Glob(prevConf.Host, conf.Host) { + conf.Every = prevConf.Every + } + } } else { conf.Every, err = time.ParseDuration(conf.EveryStr) handleErr(err) } - if config.DefaultConfig != nil && conf.Burst == 0 { - conf.Burst = config.DefaultConfig.Burst + if conf.Burst == 0 { + // Look 'upwards' for burst + for _, prevConf := range config.Hosts[:i] { + if glob.Glob(prevConf.Host, conf.Host) { + conf.Burst = prevConf.Burst + } + } } + if conf.Burst == 0 { + return errors.Errorf("Burst must be > 0 (Host: %s)", 
conf.Host) + } + + for _, rawRule := range conf.RawRules { + r, err := parseRule(rawRule) + handleErr(err) + conf.Rules = append(conf.Rules, r) + + logrus.WithFields(logrus.Fields{ + "arg": r.Arg, + "action": r.Action, + "matchFunc": runtime.FuncForPC(reflect.ValueOf(r.Matches).Pointer()).Name(), + }).Info("Rule") + } + + logrus.WithFields(logrus.Fields{ + "every": conf.Every, + "burst": conf.Burst, + "headers": conf.Headers, + "host": conf.Host, + }).Info("Host") } + + return nil } func validateConfig() { @@ -91,18 +358,28 @@ func validateConfig() { func applyConfig(proxy *Proxy) { - for _, conf := range config.Hosts { - proxy.Limiters[conf.Host] = &ExpiringLimiter{ - rate.NewLimiter(rate.Every(conf.Every), conf.Burst), - time.Now(), - } + //Reverse order + for i := len(config.Hosts) - 1; i >= 0; i-- { + + conf := config.Hosts[i] + + proxy.Limiters = append(proxy.Limiters, &ExpiringLimiter{ + HostGlob: conf.Host, + IsGlob: isGlob(conf.Host), + Limiter: rate.NewLimiter(rate.Every(conf.Every), conf.Burst), + LastRead: time.Now(), + CanDelete: false, + }) } } func (b *Balancer) reloadConfig() { b.proxyMutex.Lock() - loadConfig() + err := loadConfig() + if err != nil { + panic(err) + } if b.proxies != nil { b.proxies = b.proxies[:0] diff --git a/config.json b/config.json index c671284..947b6b7 100644 --- a/config.json +++ b/config.json @@ -4,6 +4,7 @@ "wait": "4s", "multiplier": 2.5, "retries": 3, + "retries_hard": 6, "proxies": [ { "name": "p0", @@ -20,7 +21,10 @@ "Cache-Control": "max-age=0", "Connection": "keep-alive", "User-Agent": "Mozilla/5.0 (X11; Linux x86_64; rv:67.0) Gecko/20100101 Firefox/67.0" - } + }, + "rules": [ + {"condition": "response_time>10s", "action": "dont_retry"} + ] }, { "host": "*.reddit.com", @@ -36,13 +40,9 @@ "host": ".pbs.twimg.com", "every": "125ms" }, - { - "host": "*.cdninstagram", - "every": "250ms" - }, { "host": ".www.instagram.com", - "every": "30s", + "every": "4500ms", "burst": 3 }, { @@ -53,7 +53,10 @@ { "host": 
".s3.amazonaws.com", "every": "10s", - "burst": 3 + "burst": 1, + "rules": [ + {"condition": "status=403", "action": "dont_retry"} + ] } ] } \ No newline at end of file diff --git a/gc.go b/gc.go index f66d082..041a876 100644 --- a/gc.go +++ b/gc.go @@ -42,32 +42,20 @@ func cleanExpiredLimits(proxy *Proxy) { const ttl = time.Hour - limits := make(map[string]*ExpiringLimiter, 0) + var limits []*ExpiringLimiter now := time.Now() for host, limiter := range proxy.Limiters { - if now.Sub(limiter.LastRead) > ttl && shouldPruneLimiter(host) { + if now.Sub(limiter.LastRead) > ttl && limiter.CanDelete { logrus.WithFields(logrus.Fields{ "proxy": proxy.Name, "limiter": host, "last_read": now.Sub(limiter.LastRead), }).Trace("Pruning limiter") } else { - limits[host] = limiter + limits = append(limits, limiter) } } proxy.Limiters = limits } - -func shouldPruneLimiter(host string) bool { - - // Don't remove hosts that are coming from the config - for _, conf := range config.Hosts { - if conf.Host == host { - return false - } - } - - return true -} diff --git a/main.go b/main.go index d0d4e8c..af0ad14 100644 --- a/main.go +++ b/main.go @@ -23,18 +23,26 @@ type Balancer struct { } type ExpiringLimiter struct { - Limiter *rate.Limiter - LastRead time.Time + HostGlob string + IsGlob bool + CanDelete bool + Limiter *rate.Limiter + LastRead time.Time } type Proxy struct { Name string Url *url.URL - Limiters map[string]*ExpiringLimiter + Limiters []*ExpiringLimiter HttpClient *http.Client Connections int } +type RequestCtx struct { + RequestTime time.Time + Response *http.Response +} + type ByConnectionCount []*Proxy func (a ByConnectionCount) Len() int { @@ -51,8 +59,13 @@ func (a ByConnectionCount) Less(i, j int) bool { func (p *Proxy) getLimiter(host string) *rate.Limiter { - for hostGlob, limiter := range p.Limiters { - if glob.Glob(hostGlob, host) { + for _, limiter := range p.Limiters { + if limiter.IsGlob { + if glob.Glob(limiter.HostGlob, host) { + limiter.LastRead = time.Now() 
+ return limiter.Limiter + } + } else if limiter.HostGlob == host { limiter.LastRead = time.Now() return limiter.Limiter } @@ -65,14 +78,18 @@ func (p *Proxy) getLimiter(host string) *rate.Limiter { func (p *Proxy) makeNewLimiter(host string) *ExpiringLimiter { newExpiringLimiter := &ExpiringLimiter{ - LastRead: time.Now(), - Limiter: rate.NewLimiter(rate.Every(config.DefaultConfig.Every), config.DefaultConfig.Burst), + CanDelete: false, + HostGlob: host, + IsGlob: false, + LastRead: time.Now(), + Limiter: rate.NewLimiter(rate.Every(config.DefaultConfig.Every), config.DefaultConfig.Burst), } - p.Limiters[host] = newExpiringLimiter + p.Limiters = append([]*ExpiringLimiter{newExpiringLimiter}, p.Limiters...) logrus.WithFields(logrus.Fields{ - "host": host, + "host": host, + "every": config.DefaultConfig.Every, }).Trace("New limiter") return newExpiringLimiter @@ -96,7 +113,18 @@ func (b *Balancer) chooseProxy() *Proxy { sort.Sort(ByConnectionCount(b.proxies)) - p0 := b.proxies[0] + proxyWithLeastConns := b.proxies[0] + proxiesWithSameConnCount := b.getProxiesWithSameConnCountAs(proxyWithLeastConns) + + if len(proxiesWithSameConnCount) > 1 { + return proxiesWithSameConnCount[rand.Intn(len(proxiesWithSameConnCount))] + } else { + return proxyWithLeastConns + } +} + +func (b *Balancer) getProxiesWithSameConnCountAs(p0 *Proxy) []*Proxy { + proxiesWithSameConnCount := make([]*Proxy, 0) for _, p := range b.proxies { if p.Connections != p0.Connections { @@ -104,12 +132,7 @@ func (b *Balancer) chooseProxy() *Proxy { } proxiesWithSameConnCount = append(proxiesWithSameConnCount, p) } - - if len(proxiesWithSameConnCount) > 1 { - return proxiesWithSameConnCount[rand.Intn(len(proxiesWithSameConnCount))] - } else { - return p0 - } + return proxiesWithSameConnCount } func New() *Balancer { @@ -157,21 +180,61 @@ func New() *Balancer { return balancer } -func applyHeaders(r *http.Request) *http.Request { +func getConfsMatchingRequest(r *http.Request) []*HostConfig { sHost := 
simplifyHost(r.Host) + configs := make([]*HostConfig, 0) + for _, conf := range config.Hosts { if glob.Glob(conf.Host, sHost) { - for k, v := range conf.Headers { - r.Header.Set(k, v) - } + configs = append(configs, conf) + } + } + + return configs +} + +func applyHeaders(r *http.Request, configs []*HostConfig) *http.Request { + + for _, conf := range configs { + for k, v := range conf.Headers { + r.Header.Set(k, v) } } return r } +func computeRules(ctx *RequestCtx, configs []*HostConfig) (dontRetry, forceRetry bool, + limitMultiplier, newLimit float64, shouldRetry bool) { + dontRetry = false + forceRetry = false + shouldRetry = false + limitMultiplier = 1 + + for _, conf := range configs { + for _, rule := range conf.Rules { + if rule.Matches(ctx) { + switch rule.Action { + case DontRetry: + dontRetry = true + case MultiplyEvery: + limitMultiplier = rule.Arg + case SetEvery: + newLimit = rule.Arg + case ForceRetry: + forceRetry = true + case ShouldRetry: + shouldRetry = true + } + } + } + } + + return +} + func (p *Proxy) processRequest(r *http.Request) (*http.Response, error) { p.Connections += 1 @@ -179,25 +242,41 @@ func (p *Proxy) processRequest(r *http.Request) (*http.Response, error) { p.Connections -= 1 }() retries := 0 + additionalRetries := 0 - p.waitRateLimit(r) - proxyReq := applyHeaders(cloneRequest(r)) + configs := getConfsMatchingRequest(r) + sHost := simplifyHost(r.Host) + limiter := p.getLimiter(sHost) + + proxyReq := applyHeaders(cloneRequest(r), configs) for { + p.waitRateLimit(limiter) - if retries >= config.Retries { - return nil, errors.Errorf("giving up after %d retries", config.Retries) + if retries >= config.Retries+additionalRetries || retries > config.RetriesHard { + return nil, errors.Errorf("giving up after %d retries", retries) } - resp, err := p.HttpClient.Do(proxyReq) + ctx := &RequestCtx{ + RequestTime: time.Now(), + } + var err error + ctx.Response, err = p.HttpClient.Do(proxyReq) if err != nil { if isPermanentError(err) { return 
nil, err } - wait := waitTime(retries) + dontRetry, forceRetry, limitMultiplier, newLimit, _ := computeRules(ctx, configs) + if forceRetry { + additionalRetries += 1 + } else if dontRetry { + return nil, errors.Errorf("Applied dont_retry rule for (%s)", err) + } + p.applyLimiterRules(newLimit, limiter, limitMultiplier) + wait := waitTime(retries) logrus.WithError(err).WithFields(logrus.Fields{ "wait": wait, }).Trace("Temporary error during request") @@ -207,27 +286,45 @@ func (p *Proxy) processRequest(r *http.Request) (*http.Response, error) { continue } - if isHttpSuccessCode(resp.StatusCode) { + // Compute rules + dontRetry, forceRetry, limitMultiplier, newLimit, shouldRetry := computeRules(ctx, configs) - return resp, nil - } else if shouldRetryHttpCode(resp.StatusCode) { + if forceRetry { + additionalRetries += 1 + } else if dontRetry { + return nil, errors.Errorf("Applied dont_retry rule") + } + p.applyLimiterRules(newLimit, limiter, limitMultiplier) + + if isHttpSuccessCode(ctx.Response.StatusCode) { + return ctx.Response, nil + + } else if forceRetry || shouldRetry || shouldRetryHttpCode(ctx.Response.StatusCode) { wait := waitTime(retries) logrus.WithFields(logrus.Fields{ "wait": wait, - "status": resp.StatusCode, + "status": ctx.Response.StatusCode, }).Trace("HTTP error during request") time.Sleep(wait) retries += 1 continue } else { - return nil, errors.Errorf("HTTP error: %d", resp.StatusCode) + return nil, errors.Errorf("HTTP error: %d", ctx.Response.StatusCode) } } } +func (p *Proxy) applyLimiterRules(newLimit float64, limiter *rate.Limiter, limitMultiplier float64) { + if newLimit != 0 { + limiter.SetLimit(rate.Limit(newLimit)) + } else if limitMultiplier != 1 { + limiter.SetLimit(limiter.Limit() * rate.Limit(1/limitMultiplier)) + } +} + func (b *Balancer) Run() { //b.Verbose = true @@ -285,7 +382,6 @@ func NewProxy(name, stringUrl string) (*Proxy, error) { Name: name, Url: parsedUrl, HttpClient: httpClient, - Limiters: 
make(map[string]*ExpiringLimiter), }, nil } diff --git a/retry.go b/retry.go index ed3528b..e68f9fa 100644 --- a/retry.go +++ b/retry.go @@ -3,10 +3,10 @@ package main import ( "fmt" "github.com/sirupsen/logrus" + "golang.org/x/time/rate" "log" "math" "net" - "net/http" "net/url" "os" "syscall" @@ -80,17 +80,9 @@ func waitTime(retries int) time.Duration { return time.Duration(config.Wait * int64(math.Pow(config.Multiplier, float64(retries)))) } -func (p *Proxy) waitRateLimit(r *http.Request) { +func (p *Proxy) waitRateLimit(limiter *rate.Limiter) { - sHost := simplifyHost(r.Host) - - limiter := p.getLimiter(sHost) reservation := limiter.Reserve() - if !reservation.OK() { - logrus.WithFields(logrus.Fields{ - "host": sHost, - }).Warn("Could not get reservation, make sure that burst is > 0") - } delay := reservation.Delay() if delay > 0 { diff --git a/test/web.py b/test/web.py index e009956..5a8751e 100644 --- a/test/web.py +++ b/test/web.py @@ -1,6 +1,7 @@ -from flask import Flask, Response import time +from flask import Flask, Response + app = Flask(__name__) @@ -10,6 +11,18 @@ def slow(): return "Hello World!" +@app.route("/echo/") +def echo(text): + return text + + +@app.route("/echoh/") +def echoh(text): + return Response(response="see X-Test header", status=404, headers={ + "X-Test": text, + }) + + @app.route("/500") def e500(): return Response(status=500) @@ -22,7 +35,7 @@ def e404(): @app.route("/403") def e403(): - return Response(status=404) + return Response(status=403) if __name__ == "__main__":