package web

import (
	"compress/flate"
	"context"
	"crypto/tls"
	"errors"
	"fmt"
	"net"
	"net/http"
	"os"
	"os/signal"
	"strconv"
	"syscall"
	"time"

	"github.com/creasty/defaults"
	"github.com/go-chi/chi/v5"
	"github.com/go-chi/chi/v5/middleware"
	"github.com/go-chi/cors"
	"github.com/go-chi/render"
	"scm.yoorie.de/go-lib/certs"
	log "scm.yoorie.de/go-lib/gelf"
)

// WebServerConfiguration holds the listener, TLS and health-check settings of a
// WebServer; the yaml tags name the keys used when it is loaded from YAML.
// NewWebServer applies the `default` tags to it via creasty/defaults.
type WebServerConfiguration struct {
	// Host is the address to bind to; an empty Host means "0.0.0.0".
	Host                        string                         `yaml:"host"`
	// Port is the plain-HTTP port. In SSL mode (SslPort > 0) it serves only the
	// health endpoints; otherwise it serves every route.
	Port                        int                            `default:"7080" yaml:"port"`
	// SslPort enables HTTPS on this port when it is greater than zero.
	SslPort                     int                            `yaml:"sslport"`
	// SslCert and SslKey are file paths. When both are set they are loaded with
	// tls.LoadX509KeyPair; when both are empty a certificate is generated from
	// CertificateGenerationParams.
	SslCert                     string                         `yaml:"sslcert"`
	SslKey                      string                         `yaml:"sslkey"`
	// HealthCheckPeriod is handed to the health checker's start method.
	// NOTE(review): the unit is not visible here (presumably seconds) — confirm.
	HealthCheckPeriod           int                            `default:"30" yaml:"healthcheckperiod"`
	// CertificateGenerationParams configures the generated certificate used when
	// no SslCert/SslKey pair is loaded.
	CertificateGenerationParams TLSCertificateGenerationParams `yaml:"tlsgenerationparams"`
}

// TLSCertificateGenerationParams configures the certificate that getTLSConfig
// generates (via certs.GenerateCertificate) when it does not load one from files.
type TLSCertificateGenerationParams struct {
	// Organization is passed to certs.GenerateCertificate.Organization.
	Organization string        `default:"acme.com" yaml:"organization"`
	// Host is passed to certs.GenerateCertificate.Host; presumably the name the
	// certificate is issued for — TODO confirm against the certs package.
	Host         string        `default:"localhost" yaml:"host"`
	// ValidForDays is a number of days despite its time.Duration type:
	// getTLSConfig multiplies the raw value by 24*time.Hour. NOTE(review): the
	// default "3650" presumably relies on creasty/defaults falling back to integer
	// parsing, and YAML decoding of a plain integer into this field should be
	// checked — confirm both.
	ValidForDays time.Duration `default:"3650" yaml:"validfor"` // Value in days
}

// WebServer serves the handlers registered via Mount together with the health
// endpoints /healthz and /readyz. Create it with NewWebServer, register handlers
// with Mount, then call Start; Join blocks until a signal arrives and Stop shuts
// the servers down.
type WebServer struct {
	// serviceConfig is the configuration passed to NewWebServer.
	serviceConfig *WebServerConfiguration
	// sslsrv is the HTTPS server on SslPort; it is only created in SSL mode.
	sslsrv        *http.Server
	// srv is the plain-HTTP server on Port (all routes, or only health routes in SSL mode).
	srv           *http.Server
	// healthChecker is created by Start and backs the /healthz and /readyz handlers.
	healthChecker *healthChecker
	// router is the main router: mounted handlers plus the health routes under /health.
	router        *chi.Mux
	// healthRouter serves only /healthz and /readyz (used on Port in SSL mode).
	healthRouter  *chi.Mux
	// mounts holds the Mount registrations until Start consumes them.
	mounts        map[string]http.Handler
	// HealthCheck is an optional custom check whose (bool, string) result is fed
	// to the health checker through performHealthCheck; when nil, (true, "") is
	// reported. Presumably (healthy, detail message) — confirm in healthChecker.
	HealthCheck   func() (bool, string)
}

// NewWebServer creates a WebServer for config.
//
// The `default` struct tags of config (e.g. Port 7080, HealthCheckPeriod 30) are
// applied to config in place with defaults.Set, so the caller's struct is
// modified. It returns an error when config is nil or when the defaults cannot
// be applied.
func NewWebServer(config *WebServerConfiguration) (*WebServer, error) {
	if config == nil {
		return nil, errors.New("config may not be null")
	}
	// defaults.Set can fail on a malformed `default` tag; surface that instead of
	// starting with a half-initialised configuration.
	if err := defaults.Set(config); err != nil {
		return nil, fmt.Errorf("applying configuration defaults: %w", err)
	}
	return &WebServer{
		serviceConfig: config,
		mounts:        make(map[string]http.Handler),
	}, nil
}

// healthRoutes builds a router that exposes the liveness (/healthz) and
// readiness (/readyz) handlers of server.healthChecker. Responses are marked as
// JSON and compressed, and panics in the handlers are recovered by
// middleware.Recoverer. server.healthChecker must be set before this is called.
func (server *WebServer) healthRoutes() *chi.Mux {
	checker := server.healthChecker
	mux := chi.NewRouter()
	mux.Use(render.SetContentType(render.ContentTypeJSON))
	mux.Use(middleware.NewCompressor(flate.DefaultCompression).Handler)
	mux.Use(middleware.Recoverer)
	mux.Get("/healthz", checker.healthyEndpoint)
	mux.Get("/readyz", checker.readinessEndpoint)
	return mux
}

// isSsl reports whether HTTPS is enabled, i.e. a positive SslPort is configured.
func (server *WebServer) isSsl() bool {
	sslPort := server.serviceConfig.SslPort
	return sslPort >= 1
}

// Mount registers handler to be mounted under pattern on the main router.
//
// The handler is attached when Start builds the router. Registering the same
// pattern twice keeps only the last handler. Start consumes the registrations,
// so handlers added after Start are not served by the router it already built.
func (server *WebServer) Mount(pattern string, handler http.Handler) {
	// Create the map lazily so that a zero-value WebServer does not panic here;
	// Start still rejects servers that were not created with NewWebServer.
	if server.mounts == nil {
		server.mounts = make(map[string]http.Handler)
	}
	server.mounts[pattern] = handler
}

// routes builds the main router. The router has compression, panic recovery and
// a CORS policy that accepts any http(s) origin without credentials. It serves
// the health routes under /health and every handler registered via Mount.
// Afterwards the registered mount points are discarded by replacing the map
// with an empty one.
func (server *WebServer) routes() *chi.Mux {
	corsPolicy := cors.Options{
		AllowedOrigins:   []string{"https://*", "http://*"},
		AllowedMethods:   []string{"GET", "POST", "PUT", "DELETE", "OPTIONS"},
		AllowedHeaders:   []string{"Accept", "Authorization", "Content-Type", "X-CSRF-Token"},
		ExposedHeaders:   []string{"Link"},
		AllowCredentials: false,
		MaxAge:           300, // seconds; kept low because major browsers cap larger values
	}

	mux := chi.NewRouter()
	mux.Use(middleware.NewCompressor(flate.DefaultCompression).Handler)
	mux.Use(middleware.Recoverer)
	mux.Use(cors.Handler(corsPolicy))

	mux.Mount("/health", server.healthRoutes())
	for pattern, handler := range server.mounts {
		mux.Mount(pattern, handler)
	}
	// The registrations have been consumed; start over with an empty set.
	server.mounts = make(map[string]http.Handler)
	return mux
}

// getTLSConfig returns the TLS configuration for the HTTPS listener.
//
// If both SslCert and SslKey are configured, the key pair is loaded from those
// files with tls.LoadX509KeyPair. If neither is configured, a certificate is
// generated from CertificateGenerationParams; ValidForDays is a day count and is
// converted to a duration here. Configuring only one of the two is reported as
// an error instead of silently falling back to a generated certificate.
func (server *WebServer) getTLSConfig() (*tls.Config, error) {
	cert := server.serviceConfig.SslCert
	key := server.serviceConfig.SslKey
	switch {
	case cert != "" && key != "":
		log.Infof("Loading SSL certificate %s %s", cert, key)
		cer, err := tls.LoadX509KeyPair(cert, key)
		if err != nil {
			return nil, fmt.Errorf("cannot load SSL certificate: %w", err)
		}
		return &tls.Config{Certificates: []tls.Certificate{cer}}, nil
	case cert != "" || key != "":
		// Half a key pair is almost certainly a misconfiguration.
		return nil, errors.New("both sslcert and sslkey must be configured to load an SSL certificate")
	}

	log.Info("Generating SSL certificate")
	params := server.serviceConfig.CertificateGenerationParams
	gc := certs.GenerateCertificate{
		Organization: params.Organization,
		Host:         params.Host,
		ValidFor:     params.ValidForDays * 24 * time.Hour, // day count -> duration
		IsCA:         false,
		EcdsaCurve:   "P256",
		Ed25519Key:   true,
	}
	return gc.GenerateTLSConfig()
}

// DebugRoutes logs at info level a header naming routerName, followed by one
// "METHOD route" line for every route registered on router. If walking the
// router fails, the error is logged rather than returned.
func (server *WebServer) DebugRoutes(routerName string, router *chi.Mux) {
	log.Infof("Routes of scope %s \n", routerName)
	err := chi.Walk(router, func(method string, route string, _ http.Handler, _ ...func(http.Handler) http.Handler) error {
		log.Infof("%s %s", method, route)
		return nil
	})
	if err != nil {
		log.Alertf("Logging err: %s", err.Error())
	}
}

// BuildHostAddress returns the "host:port" listen address for the plain-HTTP
// Port, or for SslPort when ssl is true.
//
// An empty Host binds to "0.0.0.0". The address is assembled with
// net.JoinHostPort, so an IPv6 literal such as "::1" becomes "[::1]:7080" instead
// of the unparsable "::1:7080". A Host that is already bracketed ("[::1]") is
// accepted too.
func (c *WebServerConfiguration) BuildHostAddress(ssl bool) string {
	host := c.Host
	if host == "" {
		host = "0.0.0.0"
	}
	// Strip user-supplied brackets; JoinHostPort adds them back where needed.
	if n := len(host); n >= 2 && host[0] == '[' && host[n-1] == ']' {
		host = host[1 : n-1]
	}
	port := c.Port
	if ssl {
		port = c.SslPort
	}
	return net.JoinHostPort(host, strconv.Itoa(port))
}

// GetBindings describes the configured listeners as a comma-separated list of
// URLs. It always includes the plain-HTTP address, preceded by the HTTPS address
// when SslPort is positive.
func (c *WebServerConfiguration) GetBindings() string {
	plain := "http://" + c.BuildHostAddress(false)
	if c.SslPort <= 0 {
		return plain
	}
	return "https://" + c.BuildHostAddress(true) + "," + plain
}

// performHealthCheck runs the user-supplied HealthCheck and returns its result.
// Without a HealthCheck it always reports (true, "").
func (server *WebServer) performHealthCheck() (bool, string) {
	check := server.HealthCheck
	if check == nil {
		return true, ""
	}
	return check()
}

// Start starts the health checker, builds the routers and launches the servers.
//
// Without SSL (SslPort <= 0), a plain-HTTP server on Port serves the mounted
// handlers plus the health routes under /health. With SSL, an HTTPS server on
// SslPort serves that same router, and a plain-HTTP server on Port serves only
// /healthz and /readyz.
//
// Start consumes the mount points registered with Mount; at least one is
// required. If server setup fails, the error is returned and the health checker
// is stopped again so that no background work is left behind.
func (server *WebServer) Start() error {
	if server.serviceConfig == nil {
		return errors.New("use NewWebServer(config) for initialising the web server")
	}
	if len(server.mounts) == 0 {
		return errors.New("no mount points added")
	}

	// The health checker must exist before the routers are built: healthRoutes
	// binds its endpoint methods.
	server.healthChecker = newHealthChecker(server.performHealthCheck)
	server.healthChecker.start(server.serviceConfig.HealthCheckPeriod)

	ssl := server.isSsl()
	if ssl {
		log.Debugf("Running in SSL mode")
	}

	server.healthRouter = server.healthRoutes()
	server.router = server.routes()
	server.DebugRoutes("Main", server.router)
	server.DebugRoutes("Health", server.healthRouter)

	var err error
	if ssl {
		err = server.setupSsl()
	} else {
		err = server.setupHttp()
	}
	if err != nil {
		server.healthChecker.stop()
		// Forget the stopped checker so that a later Stop does not stop it twice.
		server.healthChecker = nil
		return err
	}
	return nil
}

// setupHttp starts the plain-HTTP server on Port, serving the main router
// (mounted handlers plus /health). The listener is bound before returning, so
// bind failures such as "address already in use" are returned to the caller.
// Requests are then served on a background goroutine.
func (server *WebServer) setupHttp() error {
	server.srv = server.newServer(server.serviceConfig.BuildHostAddress(false), server.router)
	return server.serveInBackground(server.srv, false)
}

// setupSsl starts an HTTPS server on SslPort that serves the main router, and a
// plain-HTTP server on Port that serves only the health endpoints. Both
// listeners are bound before returning. If the HTTP listener cannot be bound,
// the already running HTTPS server is closed again and the error is returned.
func (server *WebServer) setupSsl() error {
	tlsConfig, err := server.getTLSConfig()
	if err != nil {
		return err
	}

	server.sslsrv = server.newServer(server.serviceConfig.BuildHostAddress(true), server.router)
	server.sslsrv.TLSConfig = tlsConfig
	if err = server.serveInBackground(server.sslsrv, true); err != nil {
		return err
	}

	// Own plain-HTTP server for the health checks.
	server.srv = server.newServer(server.serviceConfig.BuildHostAddress(false), server.healthRouter)
	if err = server.serveInBackground(server.srv, false); err != nil {
		// Best effort cleanup: the bind error is what matters to the caller, and
		// a Close failure would add nothing actionable to it.
		_ = server.sslsrv.Close()
		return err
	}
	return nil
}

// newServer returns an http.Server for addr and handler, using the timeouts
// shared by every listener of the WebServer.
func (server *WebServer) newServer(addr string, handler http.Handler) *http.Server {
	return &http.Server{
		Addr:         addr,
		WriteTimeout: time.Second * 15,
		ReadTimeout:  time.Second * 15,
		IdleTimeout:  time.Second * 60,
		Handler:      handler,
	}
}

// serveInBackground binds srv.Addr synchronously (returning any bind error) and
// then serves srv on that listener in a new goroutine. With useTLS, the
// listener is served with srv.TLSConfig. Errors other than http.ErrServerClosed
// are logged.
func (server *WebServer) serveInBackground(srv *http.Server, useTLS bool) error {
	ln, err := net.Listen("tcp", srv.Addr)
	if err != nil {
		return fmt.Errorf("listening on %s: %w", srv.Addr, err)
	}
	scheme := "http"
	if useTLS {
		scheme = "https"
	}
	go func() {
		log.Infof("Starting %s server on address: %s", scheme, srv.Addr)
		var serveErr error
		if useTLS {
			// Empty file names: the certificate comes from srv.TLSConfig.
			serveErr = srv.ServeTLS(ln, "", "")
		} else {
			serveErr = srv.Serve(ln)
		}
		if serveErr != nil && !errors.Is(serveErr, http.ErrServerClosed) {
			log.Alertf("error starting server: %s", serveErr.Error())
		}
	}()
	return nil
}

// Join blocks until the process receives an interrupt (os.Interrupt, i.e.
// SIGINT/Ctrl+C) or SIGTERM. SIGTERM is what container orchestrators send
// before killing a process; handling it lets callers run Stop for a graceful
// shutdown. Join does not stop the server itself. The signal registration is
// removed again before returning.
func (server *WebServer) Join() {
	c := make(chan os.Signal, 1)
	signal.Notify(c, os.Interrupt, syscall.SIGTERM)
	defer signal.Stop(c)
	<-c
}

// Stop stops the health checker and gracefully shuts down the plain-HTTP server
// and, if one was started, the HTTPS server. In-flight requests get a shared
// budget of 15 seconds.
//
// Components that were never created (Start not called or failed part-way) are
// skipped instead of causing a nil-pointer panic. Shutdown errors, such as the
// deadline expiring, are logged.
func (server *WebServer) Stop() {
	if server.healthChecker != nil {
		server.healthChecker.stop()
	}
	log.Info("Waiting for clients ...")
	ctx, cancel := context.WithTimeout(context.Background(), time.Second*15)
	defer cancel()

	log.Info("Shutting down server ...")
	if server.srv != nil {
		if err := server.srv.Shutdown(ctx); err != nil {
			log.Alertf("error shutting down http server: %s", err.Error())
		}
	}
	if server.sslsrv != nil {
		if err := server.sslsrv.Shutdown(ctx); err != nil {
			log.Alertf("error shutting down https server: %s", err.Error())
		}
	}
	log.Info("Server has been shutted down")
}