Architeuthis/config.go

296 lines
6.3 KiB
Go

package main
import (
"bytes"
"encoding/json"
"fmt"
"github.com/pkg/errors"
"github.com/ryanuber/go-glob"
"github.com/sirupsen/logrus"
"io/ioutil"
"os"
"reflect"
"runtime"
"strconv"
"strings"
"time"
)
// String returns the config-file spelling of the action ("dont_retry",
// "force_retry" or "should_retry"), or "???" for any other value.
func (a HostRuleAction) String() string {
	name := "???"
	switch a {
	case DontRetry:
		name = "dont_retry"
	case ForceRetry:
		name = "force_retry"
	case ShouldRetry:
		name = "should_retry"
	}
	return name
}
func parseRule(raw *RawHostRule) (*HostRule, error) {
rule := &HostRule{}
switch raw.Action {
case "should_retry":
rule.Action = ShouldRetry
case "dont_retry":
rule.Action = DontRetry
case "force_retry":
rule.Action = ForceRetry
default:
return nil, errors.Errorf("Invalid argument for action: %s", raw.Action)
}
switch {
case strings.Contains(raw.Condition, "!="):
op1Str, op2Str := split(raw.Condition, "!=")
op1Func := parseOperand1(op1Str)
if op1Func == nil {
return nil, errors.Errorf("Invalid rule: %s", raw.Condition)
}
if isGlob(op2Str) {
rule.Matches = func(ctx *ResponseCtx) bool {
return !glob.Glob(op2Str, op1Func(ctx))
}
} else {
op2Str = strings.Replace(op2Str, "\\*", "*", -1)
rule.Matches = func(ctx *ResponseCtx) bool {
return op1Func(ctx) != op2Str
}
}
case strings.Contains(raw.Condition, "="):
op1Str, op2Str := split(raw.Condition, "=")
op1Func := parseOperand1(op1Str)
if op1Func == nil {
return nil, errors.Errorf("Invalid rule: %s", raw.Condition)
}
if isGlob(op2Str) {
rule.Matches = func(ctx *ResponseCtx) bool {
return glob.Glob(op2Str, op1Func(ctx))
}
} else {
op2Str = strings.Replace(op2Str, "\\*", "*", -1)
rule.Matches = func(ctx *ResponseCtx) bool {
return op1Func(ctx) == op2Str
}
}
case strings.Contains(raw.Condition, ">"):
op1Str, op2Str := split(raw.Condition, ">")
op1Func := parseOperand1(op1Str)
if op1Func == nil {
return nil, errors.Errorf("Invalid rule: %s", raw.Condition)
}
op2Num, err := parseOperand2(op1Str, op2Str)
if err != nil {
return nil, err
}
rule.Matches = func(ctx *ResponseCtx) bool {
op1Num, err := strconv.ParseFloat(op1Func(ctx), 64)
handleRuleErr(err)
return op1Num > op2Num
}
case strings.Contains(raw.Condition, "<"):
op1Str, op2Str := split(raw.Condition, "<")
op1Func := parseOperand1(op1Str)
if op1Func == nil {
return nil, errors.Errorf("Invalid rule: %s", raw.Condition)
}
op2Num, err := parseOperand2(op1Str, op2Str)
if err != nil {
return nil, err
}
rule.Matches = func(ctx *ResponseCtx) bool {
op1Num, err := strconv.ParseFloat(op1Func(ctx), 64)
handleRuleErr(err)
return op1Num < op2Num
}
}
return rule, nil
}
// handleRuleErr logs a non-nil error raised while evaluating a rule at warn
// level. A nil error is ignored. Evaluation of the rule is not aborted
// either way.
func handleRuleErr(err error) {
	if err == nil {
		return
	}
	logrus.WithError(err).Warn("Error computing rule")
}
// split cuts str around the first occurrence of subStr. It returns the text
// before and the text after the separator. If subStr does not occur, it
// returns str and "" instead of panicking on a -1 slice index.
func split(str, subStr string) (string, string) {
	i := strings.Index(str, subStr)
	if i < 0 {
		return str, ""
	}
	return str[:i], str[i+len(subStr):]
}
// parseOperand2 converts the right-hand side of a numeric rule condition to
// a float64.
//
// When op1 is "response_time", op2 is parsed as a Go duration (e.g. "1.5s"
// or "200ms") and returned in seconds. A parse failure yields -1 and the
// error. For any other op1, op2 is parsed as a plain 64-bit float.
func parseOperand2(op1, op2 string) (float64, error) {
	if op1 != "response_time" {
		return strconv.ParseFloat(op2, 64)
	}
	d, err := time.ParseDuration(op2)
	if err != nil {
		return -1, err
	}
	return d.Seconds(), nil
}
// parseOperand1 maps the left-hand side of a rule condition to a function
// that extracts the corresponding value from a ResponseCtx as a string.
//
// Recognised operands:
//   - "body": the full response body. The body is buffered and put back on
//     ctx.Response.Body so later readers still see it.
//   - "status": the response status code in decimal.
//   - "response_time": ctx.ResponseTime formatted with 6 decimal places.
//   - "header:<Name>": the value of the named response header.
//
// Extractors that need ctx.Response return "" when it is nil. parseOperand1
// returns nil for any other operand.
func parseOperand1(op string) func(ctx *ResponseCtx) string {
	switch {
	case op == "body":
		return func(ctx *ResponseCtx) string {
			if ctx.Response == nil || ctx.Response.Body == nil {
				return ""
			}
			bodyBytes, readErr := ioutil.ReadAll(ctx.Response.Body)
			// Once the body is buffered, a close error does not affect the
			// data we hold, so it is deliberately ignored.
			_ = ctx.Response.Body.Close()
			// Always put back whatever was read, even after a read error.
			// Otherwise later rules and the proxied response would get a
			// drained body.
			ctx.Response.Body = ioutil.NopCloser(bytes.NewReader(bodyBytes))
			if readErr != nil {
				return ""
			}
			return string(bodyBytes)
		}
	case op == "status":
		return func(ctx *ResponseCtx) string {
			if ctx.Response == nil {
				return ""
			}
			return strconv.Itoa(ctx.Response.StatusCode)
		}
	case op == "response_time":
		return func(ctx *ResponseCtx) string {
			return strconv.FormatFloat(ctx.ResponseTime, 'f', 6, 64)
		}
	case strings.HasPrefix(op, "header:"):
		header := strings.TrimPrefix(op, "header:")
		return func(ctx *ResponseCtx) string {
			if ctx.Response == nil {
				return ""
			}
			return ctx.Response.Header.Get(header)
		}
	default:
		return nil
	}
}
// isGlob reports whether op contains at least one unescaped '*' wildcard.
// A backslash immediately followed by '*' ("\*") is an escaped literal
// asterisk and does not make op a glob.
func isGlob(op string) bool {
	for i := 0; i < len(op); i++ {
		switch {
		case op[i] == '\\' && i+1 < len(op) && op[i+1] == '*':
			i++ // skip the escaped asterisk
		case op[i] == '*':
			return true
		}
	}
	return false
}
// loadConfig reads config.json from the working directory and decodes it
// into the package-level config. It validates the result with
// validateConfig, which panics on invalid input, and then fills in the
// derived fields:
//
//   - config.Timeout and config.Wait come from TimeoutStr and WaitStr. An
//     empty string yields 0; a malformed one is an error.
//   - Each host's Every comes from EveryStr. When EveryStr is empty, it is
//     inherited from the last earlier host entry whose Host glob matches
//     this host.
//   - Each host's Burst is inherited the same way when it is 0. A Burst that
//     is still 0 afterwards is an error.
//   - Each host's Rules are rebuilt from RawRules via parseRule.
//
// config is updated in place, so after an error it may be partially updated.
// NOTE(review): nothing here synchronizes with concurrent readers of config
// during a reload — confirm that this is safe for callers.
func loadConfig() error {
	// ReadFile opens and closes the file; the previous os.Open leaked the handle.
	configBytes, err := ioutil.ReadFile("config.json")
	if err != nil {
		return err
	}
	if err = json.Unmarshal(configBytes, &config); err != nil {
		return err
	}
	validateConfig()

	if config.Timeout, err = parseOptionalDuration(config.TimeoutStr); err != nil {
		return errors.Wrapf(err, "invalid timeout %q", config.TimeoutStr)
	}
	wait, err := parseOptionalDuration(config.WaitStr)
	if err != nil {
		return errors.Wrapf(err, "invalid wait %q", config.WaitStr)
	}
	config.Wait = int64(wait)

	for i, conf := range config.Hosts {
		earlier := config.Hosts[:i]
		if conf.EveryStr == "" {
			// Look 'upwards': the last earlier entry whose glob matches wins.
			for _, prevConf := range earlier {
				if glob.Glob(prevConf.Host, conf.Host) {
					conf.Every = prevConf.Every
				}
			}
		} else if conf.Every, err = time.ParseDuration(conf.EveryStr); err != nil {
			return errors.Wrapf(err, "invalid every %q (Host: %s)", conf.EveryStr, conf.Host)
		}
		if conf.Burst == 0 {
			// Look 'upwards' for burst, same rule as for every.
			for _, prevConf := range earlier {
				if glob.Glob(prevConf.Host, conf.Host) {
					conf.Burst = prevConf.Burst
				}
			}
		}
		if conf.Burst == 0 {
			return errors.Errorf("Burst must be > 0 (Host: %s)", conf.Host)
		}
		// Rebuild from scratch so a reload never appends onto rules left
		// over from an earlier load (e.g. if decoding reused a host struct).
		conf.Rules = nil
		for _, rawRule := range conf.RawRules {
			r, err := parseRule(rawRule)
			if err != nil {
				return errors.Wrapf(err, "invalid rule (Host: %s)", conf.Host)
			}
			conf.Rules = append(conf.Rules, r)
			logrus.WithFields(logrus.Fields{
				"arg":       r.Arg,
				"action":    r.Action,
				"matchFunc": runtime.FuncForPC(reflect.ValueOf(r.Matches).Pointer()).Name(),
			}).Info("Rule")
		}
		logrus.WithFields(logrus.Fields{
			"every":   conf.Every,
			"burst":   conf.Burst,
			"headers": conf.Headers,
			"host":    conf.Host,
		}).Info("Host")
	}
	return nil
}

// parseOptionalDuration parses s with time.ParseDuration, but treats an empty
// string as a zero duration rather than an error.
func parseOptionalDuration(s string) (time.Duration, error) {
	if s == "" {
		return 0, nil
	}
	return time.ParseDuration(s)
}
// validateConfig checks the freshly decoded config and panics if it is
// unusable:
//   - Some host entry must have Host "*". It becomes config.DefaultConfig;
//     if there are several, the last one wins.
//   - No host may set an Accept-Encoding header, compared case-insensitively.
//     The panic message says this breaks goproxy.
func validateConfig() {
	foundDefault := false
	for _, conf := range config.Hosts {
		if conf.Host == "*" {
			config.DefaultConfig = conf
			foundDefault = true
		}
		for k := range conf.Headers {
			if strings.EqualFold(k, "accept-encoding") {
				panic(fmt.Sprintf("headers config for '%s':"+
					" Do not set the Accept-Encoding header, it breaks goproxy", conf.Host))
			}
		}
	}
	// Check what this config contains rather than DefaultConfig != nil. On a
	// reload the field still holds the previous load's default, which would
	// hide a missing "*" entry.
	if !foundDefault {
		panic("config.json: You must specify a default host ('*')")
	}
}
// reloadConfig re-runs loadConfig. It logs a failure instead of silently
// discarding it, and logs "Reloaded config" only when the load succeeded.
// Panics raised by validateConfig are not recovered here.
func (a *Architeuthis) reloadConfig() {
	if err := loadConfig(); err != nil {
		logrus.WithError(err).Error("Error reloading config")
		return
	}
	logrus.Info("Reloaded config")
}
// handleErr aborts with a panic carrying err when err is non-nil. A nil
// error is a no-op.
func handleErr(err error) {
	if err == nil {
		return
	}
	panic(err)
}