mirror of
https://github.com/simon987/Architeuthis.git
synced 2025-04-04 08:02:59 +00:00
296 lines
6.3 KiB
Go
296 lines
6.3 KiB
Go
package main
|
|
|
|
import (
|
|
"bytes"
|
|
"encoding/json"
|
|
"fmt"
|
|
"github.com/pkg/errors"
|
|
"github.com/ryanuber/go-glob"
|
|
"github.com/sirupsen/logrus"
|
|
"io/ioutil"
|
|
"os"
|
|
"reflect"
|
|
"runtime"
|
|
"strconv"
|
|
"strings"
|
|
"time"
|
|
)
|
|
|
|
func (a HostRuleAction) String() string {
|
|
switch a {
|
|
case DontRetry:
|
|
return "dont_retry"
|
|
case ForceRetry:
|
|
return "force_retry"
|
|
case ShouldRetry:
|
|
return "should_retry"
|
|
}
|
|
return "???"
|
|
}
|
|
|
|
func parseRule(raw *RawHostRule) (*HostRule, error) {
|
|
|
|
rule := &HostRule{}
|
|
|
|
switch raw.Action {
|
|
case "should_retry":
|
|
rule.Action = ShouldRetry
|
|
case "dont_retry":
|
|
rule.Action = DontRetry
|
|
case "force_retry":
|
|
rule.Action = ForceRetry
|
|
default:
|
|
return nil, errors.Errorf("Invalid argument for action: %s", raw.Action)
|
|
}
|
|
|
|
switch {
|
|
case strings.Contains(raw.Condition, "!="):
|
|
op1Str, op2Str := split(raw.Condition, "!=")
|
|
op1Func := parseOperand1(op1Str)
|
|
if op1Func == nil {
|
|
return nil, errors.Errorf("Invalid rule: %s", raw.Condition)
|
|
}
|
|
|
|
if isGlob(op2Str) {
|
|
rule.Matches = func(ctx *ResponseCtx) bool {
|
|
return !glob.Glob(op2Str, op1Func(ctx))
|
|
}
|
|
} else {
|
|
op2Str = strings.Replace(op2Str, "\\*", "*", -1)
|
|
rule.Matches = func(ctx *ResponseCtx) bool {
|
|
return op1Func(ctx) != op2Str
|
|
}
|
|
}
|
|
case strings.Contains(raw.Condition, "="):
|
|
op1Str, op2Str := split(raw.Condition, "=")
|
|
op1Func := parseOperand1(op1Str)
|
|
if op1Func == nil {
|
|
return nil, errors.Errorf("Invalid rule: %s", raw.Condition)
|
|
}
|
|
|
|
if isGlob(op2Str) {
|
|
rule.Matches = func(ctx *ResponseCtx) bool {
|
|
return glob.Glob(op2Str, op1Func(ctx))
|
|
}
|
|
} else {
|
|
op2Str = strings.Replace(op2Str, "\\*", "*", -1)
|
|
rule.Matches = func(ctx *ResponseCtx) bool {
|
|
return op1Func(ctx) == op2Str
|
|
}
|
|
}
|
|
case strings.Contains(raw.Condition, ">"):
|
|
op1Str, op2Str := split(raw.Condition, ">")
|
|
op1Func := parseOperand1(op1Str)
|
|
if op1Func == nil {
|
|
return nil, errors.Errorf("Invalid rule: %s", raw.Condition)
|
|
}
|
|
op2Num, err := parseOperand2(op1Str, op2Str)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
rule.Matches = func(ctx *ResponseCtx) bool {
|
|
op1Num, err := strconv.ParseFloat(op1Func(ctx), 64)
|
|
handleRuleErr(err)
|
|
return op1Num > op2Num
|
|
}
|
|
case strings.Contains(raw.Condition, "<"):
|
|
op1Str, op2Str := split(raw.Condition, "<")
|
|
op1Func := parseOperand1(op1Str)
|
|
if op1Func == nil {
|
|
return nil, errors.Errorf("Invalid rule: %s", raw.Condition)
|
|
}
|
|
op2Num, err := parseOperand2(op1Str, op2Str)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
rule.Matches = func(ctx *ResponseCtx) bool {
|
|
op1Num, err := strconv.ParseFloat(op1Func(ctx), 64)
|
|
handleRuleErr(err)
|
|
return op1Num < op2Num
|
|
}
|
|
}
|
|
|
|
return rule, nil
|
|
}
|
|
|
|
func handleRuleErr(err error) {
|
|
if err != nil {
|
|
logrus.WithError(err).Warn("Error computing rule")
|
|
}
|
|
}
|
|
|
|
// split cuts str around the first occurrence of subStr and returns the text
// before and the text after it. subStr itself is part of neither result.
//
// If subStr does not occur in str, str is returned unchanged as the first
// result and the second result is empty (instead of panicking on an
// out-of-range slice).
func split(str, subStr string) (string, string) {
	idx := strings.Index(str, subStr)
	if idx < 0 {
		return str, ""
	}

	return str[:idx], str[idx+len(subStr):]
}
|
|
|
|
// parseOperand2 converts the right-hand side of a numeric comparison to a
// float64.
//
// When op1 is "response_time", op2 is read as a Go duration string (e.g.
// "1500ms") and the result is that duration expressed in seconds. For every
// other op1, op2 is read as a plain floating point number.
//
// A non-nil error means op2 could not be parsed; the returned number must not
// be used in that case.
func parseOperand2(op1, op2 string) (float64, error) {
	if op1 != "response_time" {
		return strconv.ParseFloat(op2, 64)
	}

	d, err := time.ParseDuration(op2)
	if err != nil {
		return -1, err
	}

	return d.Seconds(), nil
}
|
|
|
|
func parseOperand1(op string) func(ctx *ResponseCtx) string {
|
|
switch {
|
|
case op == "body":
|
|
return func(ctx *ResponseCtx) string {
|
|
|
|
if ctx.Response == nil {
|
|
return ""
|
|
}
|
|
bodyBytes, err := ioutil.ReadAll(ctx.Response.Body)
|
|
if err != nil {
|
|
return ""
|
|
}
|
|
err = ctx.Response.Body.Close()
|
|
if err != nil {
|
|
return ""
|
|
}
|
|
ctx.Response.Body = ioutil.NopCloser(bytes.NewReader(bodyBytes))
|
|
|
|
return string(bodyBytes)
|
|
}
|
|
case op == "status":
|
|
return func(ctx *ResponseCtx) string {
|
|
if ctx.Response == nil {
|
|
return ""
|
|
}
|
|
return strconv.Itoa(ctx.Response.StatusCode)
|
|
}
|
|
case op == "response_time":
|
|
return func(ctx *ResponseCtx) string {
|
|
return strconv.FormatFloat(ctx.ResponseTime, 'f', 6, 64)
|
|
}
|
|
case strings.HasPrefix(op, "header:"):
|
|
header := op[strings.Index(op, ":")+1:]
|
|
return func(ctx *ResponseCtx) string {
|
|
if ctx.Response == nil {
|
|
return ""
|
|
}
|
|
return ctx.Response.Header.Get(header)
|
|
}
|
|
default:
|
|
return nil
|
|
}
|
|
}
|
|
|
|
// isGlob reports whether op contains at least one wildcard, i.e. a '*' that
// is not escaped by a directly preceding backslash.
func isGlob(op string) bool {
	for start := 0; start < len(op); {
		rel := strings.IndexByte(op[start:], '*')
		if rel < 0 {
			return false
		}

		pos := start + rel
		if pos == 0 || op[pos-1] != '\\' {
			return true
		}

		// This star was escaped; keep looking after it.
		start = pos + 1
	}

	return false
}
|
|
|
|
func loadConfig() error {
|
|
|
|
configFile, err := os.Open("config.json")
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
configBytes, err := ioutil.ReadAll(configFile)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
err = json.Unmarshal(configBytes, &config)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
validateConfig()
|
|
|
|
config.Timeout, err = time.ParseDuration(config.TimeoutStr)
|
|
wait, err := time.ParseDuration(config.WaitStr)
|
|
config.Wait = int64(wait)
|
|
|
|
for i, conf := range config.Hosts {
|
|
if conf.EveryStr == "" {
|
|
// Look 'upwards' for every
|
|
for _, prevConf := range config.Hosts[:i] {
|
|
if glob.Glob(prevConf.Host, conf.Host) {
|
|
conf.Every = prevConf.Every
|
|
}
|
|
}
|
|
} else {
|
|
conf.Every, err = time.ParseDuration(conf.EveryStr)
|
|
handleErr(err)
|
|
}
|
|
|
|
if conf.Burst == 0 {
|
|
// Look 'upwards' for burst
|
|
for _, prevConf := range config.Hosts[:i] {
|
|
if glob.Glob(prevConf.Host, conf.Host) {
|
|
conf.Burst = prevConf.Burst
|
|
}
|
|
}
|
|
}
|
|
if conf.Burst == 0 {
|
|
return errors.Errorf("Burst must be > 0 (Host: %s)", conf.Host)
|
|
}
|
|
|
|
for _, rawRule := range conf.RawRules {
|
|
r, err := parseRule(rawRule)
|
|
handleErr(err)
|
|
conf.Rules = append(conf.Rules, r)
|
|
|
|
logrus.WithFields(logrus.Fields{
|
|
"arg": r.Arg,
|
|
"action": r.Action,
|
|
"matchFunc": runtime.FuncForPC(reflect.ValueOf(r.Matches).Pointer()).Name(),
|
|
}).Info("Rule")
|
|
}
|
|
|
|
logrus.WithFields(logrus.Fields{
|
|
"every": conf.Every,
|
|
"burst": conf.Burst,
|
|
"headers": conf.Headers,
|
|
"host": conf.Host,
|
|
}).Info("Host")
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func validateConfig() {
|
|
|
|
for _, conf := range config.Hosts {
|
|
|
|
if conf.Host == "*" {
|
|
config.DefaultConfig = conf
|
|
}
|
|
|
|
for k := range conf.Headers {
|
|
if strings.ToLower(k) == "accept-encoding" {
|
|
panic(fmt.Sprintf("headers config for '%s':"+
|
|
" Do not set the Accept-Encoding header, it breaks goproxy", conf.Host))
|
|
}
|
|
}
|
|
}
|
|
|
|
if config.DefaultConfig == nil {
|
|
panic("config.json: You must specify a default host ('*')")
|
|
}
|
|
}
|
|
|
|
func (a *Architeuthis) reloadConfig() {
|
|
_ = loadConfig()
|
|
logrus.Info("Reloaded config")
|
|
}
|
|
|
|
// handleErr aborts by panicking with err when err is non-nil and does
// nothing otherwise.
func handleErr(err error) {
	if err == nil {
		return
	}

	panic(err)
}
|