Architeuthis/main.go
2019-06-04 22:02:50 -04:00

397 lines
8.2 KiB
Go

package main
import (
"fmt"
"github.com/elazarl/goproxy"
"github.com/pkg/errors"
"github.com/ryanuber/go-glob"
"github.com/sirupsen/logrus"
"golang.org/x/time/rate"
"math/rand"
"net/http"
"net/url"
"sort"
"strings"
"sync"
"time"
)
type Balancer struct {
server *goproxy.ProxyHttpServer
proxies []*Proxy
proxyMutex *sync.RWMutex
}
type ExpiringLimiter struct {
HostGlob string
IsGlob bool
CanDelete bool
Limiter *rate.Limiter
LastRead time.Time
}
type Proxy struct {
Name string
Url *url.URL
Limiters []*ExpiringLimiter
HttpClient *http.Client
Connections int
}
type RequestCtx struct {
RequestTime time.Time
Response *http.Response
}
type ByConnectionCount []*Proxy
func (a ByConnectionCount) Len() int {
return len(a)
}
func (a ByConnectionCount) Swap(i, j int) {
a[i], a[j] = a[j], a[i]
}
func (a ByConnectionCount) Less(i, j int) bool {
return a[i].Connections < a[j].Connections
}
func (p *Proxy) getLimiter(host string) *rate.Limiter {
for _, limiter := range p.Limiters {
if limiter.IsGlob {
if glob.Glob(limiter.HostGlob, host) {
limiter.LastRead = time.Now()
return limiter.Limiter
}
} else if limiter.HostGlob == host {
limiter.LastRead = time.Now()
return limiter.Limiter
}
}
newExpiringLimiter := p.makeNewLimiter(host)
return newExpiringLimiter.Limiter
}
func (p *Proxy) makeNewLimiter(host string) *ExpiringLimiter {
newExpiringLimiter := &ExpiringLimiter{
CanDelete: false,
HostGlob: host,
IsGlob: false,
LastRead: time.Now(),
Limiter: rate.NewLimiter(rate.Every(config.DefaultConfig.Every), config.DefaultConfig.Burst),
}
p.Limiters = append([]*ExpiringLimiter{newExpiringLimiter}, p.Limiters...)
logrus.WithFields(logrus.Fields{
"host": host,
"every": config.DefaultConfig.Every,
}).Trace("New limiter")
return newExpiringLimiter
}
func simplifyHost(host string) string {
col := strings.LastIndex(host, ":")
if col > 0 {
host = host[:col]
}
return "." + host
}
func (b *Balancer) chooseProxy() *Proxy {
if len(b.proxies) == 0 {
return b.proxies[0]
}
sort.Sort(ByConnectionCount(b.proxies))
proxyWithLeastConns := b.proxies[0]
proxiesWithSameConnCount := b.getProxiesWithSameConnCountAs(proxyWithLeastConns)
if len(proxiesWithSameConnCount) > 1 {
return proxiesWithSameConnCount[rand.Intn(len(proxiesWithSameConnCount))]
} else {
return proxyWithLeastConns
}
}
func (b *Balancer) getProxiesWithSameConnCountAs(p0 *Proxy) []*Proxy {
proxiesWithSameConnCount := make([]*Proxy, 0)
for _, p := range b.proxies {
if p.Connections != p0.Connections {
break
}
proxiesWithSameConnCount = append(proxiesWithSameConnCount, p)
}
return proxiesWithSameConnCount
}
func New() *Balancer {
balancer := new(Balancer)
balancer.proxyMutex = &sync.RWMutex{}
balancer.server = goproxy.NewProxyHttpServer()
balancer.server.OnRequest().HandleConnect(goproxy.AlwaysMitm)
balancer.server.OnRequest().DoFunc(
func(r *http.Request, ctx *goproxy.ProxyCtx) (*http.Request, *http.Response) {
balancer.proxyMutex.RLock()
p := balancer.chooseProxy()
logrus.WithFields(logrus.Fields{
"proxy": p.Name,
"conns": p.Connections,
"url": r.URL,
}).Trace("Routing request")
resp, err := p.processRequest(r)
balancer.proxyMutex.RUnlock()
if err != nil {
logrus.WithError(err).Trace("Could not complete request")
return nil, goproxy.NewResponse(r, "text/plain", 500, err.Error())
}
return nil, resp
})
balancer.server.NonproxyHandler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path == "/reload" {
balancer.reloadConfig()
_, _ = fmt.Fprint(w, "Reloaded\n")
} else {
w.Header().Set("Content-Type", "application/json")
_, _ = fmt.Fprint(w, "{\"name\":\"Architeuthis\",\"version\":1.0}")
}
})
return balancer
}
func getConfsMatchingRequest(r *http.Request) []*HostConfig {
sHost := simplifyHost(r.Host)
configs := make([]*HostConfig, 0)
for _, conf := range config.Hosts {
if glob.Glob(conf.Host, sHost) {
configs = append(configs, conf)
}
}
return configs
}
func applyHeaders(r *http.Request, configs []*HostConfig) *http.Request {
for _, conf := range configs {
for k, v := range conf.Headers {
r.Header.Set(k, v)
}
}
return r
}
func computeRules(ctx *RequestCtx, configs []*HostConfig) (dontRetry, forceRetry bool,
limitMultiplier, newLimit float64, shouldRetry bool) {
dontRetry = false
forceRetry = false
shouldRetry = false
limitMultiplier = 1
for _, conf := range configs {
for _, rule := range conf.Rules {
if rule.Matches(ctx) {
switch rule.Action {
case DontRetry:
dontRetry = true
case MultiplyEvery:
limitMultiplier = rule.Arg
case SetEvery:
newLimit = rule.Arg
case ForceRetry:
forceRetry = true
case ShouldRetry:
shouldRetry = true
}
}
}
}
return
}
func (p *Proxy) processRequest(r *http.Request) (*http.Response, error) {
p.Connections += 1
defer func() {
p.Connections -= 1
}()
retries := 0
additionalRetries := 0
configs := getConfsMatchingRequest(r)
sHost := simplifyHost(r.Host)
limiter := p.getLimiter(sHost)
proxyReq := applyHeaders(cloneRequest(r), configs)
for {
p.waitRateLimit(limiter)
if retries >= config.Retries+additionalRetries || retries > config.RetriesHard {
return nil, errors.Errorf("giving up after %d retries", retries)
}
ctx := &RequestCtx{
RequestTime: time.Now(),
}
var err error
ctx.Response, err = p.HttpClient.Do(proxyReq)
if err != nil {
if isPermanentError(err) {
return nil, err
}
dontRetry, forceRetry, limitMultiplier, newLimit, _ := computeRules(ctx, configs)
if forceRetry {
additionalRetries += 1
} else if dontRetry {
return nil, errors.Errorf("Applied dont_retry rule for (%s)", err)
}
p.applyLimiterRules(newLimit, limiter, limitMultiplier)
wait := waitTime(retries)
logrus.WithError(err).WithFields(logrus.Fields{
"wait": wait,
}).Trace("Temporary error during request")
time.Sleep(wait)
retries += 1
continue
}
// Compute rules
dontRetry, forceRetry, limitMultiplier, newLimit, shouldRetry := computeRules(ctx, configs)
if forceRetry {
additionalRetries += 1
} else if dontRetry {
return nil, errors.Errorf("Applied dont_retry rule")
}
p.applyLimiterRules(newLimit, limiter, limitMultiplier)
if isHttpSuccessCode(ctx.Response.StatusCode) {
return ctx.Response, nil
} else if forceRetry || shouldRetry || shouldRetryHttpCode(ctx.Response.StatusCode) {
wait := waitTime(retries)
logrus.WithFields(logrus.Fields{
"wait": wait,
"status": ctx.Response.StatusCode,
}).Trace("HTTP error during request")
time.Sleep(wait)
retries += 1
continue
} else {
return nil, errors.Errorf("HTTP error: %d", ctx.Response.StatusCode)
}
}
}
func (p *Proxy) applyLimiterRules(newLimit float64, limiter *rate.Limiter, limitMultiplier float64) {
if newLimit != 0 {
limiter.SetLimit(rate.Limit(newLimit))
} else if limitMultiplier != 1 {
limiter.SetLimit(limiter.Limit() * rate.Limit(1/limitMultiplier))
}
}
func (b *Balancer) Run() {
//b.Verbose = true
logrus.WithFields(logrus.Fields{
"addr": config.Addr,
}).Info("Listening")
err := http.ListenAndServe(config.Addr, b.server)
logrus.Fatal(err)
}
func cloneRequest(r *http.Request) *http.Request {
proxyReq := &http.Request{
Method: r.Method,
URL: r.URL,
Proto: "HTTP/1.1",
ProtoMajor: 1,
ProtoMinor: 1,
Header: r.Header,
Body: r.Body,
Host: r.URL.Host,
}
return proxyReq
}
func NewProxy(name, stringUrl string) (*Proxy, error) {
var parsedUrl *url.URL
var err error
if stringUrl != "" {
parsedUrl, err = url.Parse(stringUrl)
if err != nil {
return nil, err
}
} else {
parsedUrl = nil
}
var httpClient *http.Client
if parsedUrl == nil {
httpClient = &http.Client{}
} else {
httpClient = &http.Client{
Transport: &http.Transport{
Proxy: http.ProxyURL(parsedUrl),
},
}
}
httpClient.Timeout = config.Timeout
return &Proxy{
Name: name,
Url: parsedUrl,
HttpClient: httpClient,
}, nil
}
func main() {
logrus.SetLevel(logrus.TraceLevel)
balancer := New()
balancer.reloadConfig()
balancer.setupGarbageCollector()
balancer.Run()
}