package web import ( "compress/flate" "context" "crypto/tls" "errors" "fmt" "net/http" "os" "os/signal" "strconv" "time" "github.com/creasty/defaults" "github.com/go-chi/chi/v5" "github.com/go-chi/chi/v5/middleware" "github.com/go-chi/cors" "github.com/go-chi/render" "scm.yoorie.de/go-lib/certs" log "scm.yoorie.de/go-lib/gelf" ) type WebServerConfiguration struct { Host string `yaml:"host"` Port int `default:"7080" yaml:"port"` SslPort int `yaml:"sslport"` SslCert string `yaml:"sslcert"` SslKey string `yaml:"sslkey"` HealthCheckPeriod int `default:"30" yaml:"healthcheckperiod"` CertificateGenerationParams TLSCertificateGenerationParams `yaml:"tlsgenerationparams"` } type TLSCertificateGenerationParams struct { Organization string `default:"acme.com" yaml:"organization"` Host string `default:"localhost" yaml:"host"` ValidForDays time.Duration `default:"3650" yaml:"validfor"` // Value in days } type WebServer struct { serviceConfig *WebServerConfiguration sslsrv *http.Server srv *http.Server healthChecker *healthChecker router *chi.Mux healthRouter *chi.Mux mounts map[string]http.Handler HealthCheck func() (bool, string) } func NewWebServer(config *WebServerConfiguration) (*WebServer, error) { if config == nil { return nil, errors.New("config may not be null") } defaults.Set(config) return &WebServer{ serviceConfig: config, mounts: make(map[string]http.Handler), }, nil } func (server *WebServer) healthRoutes() *chi.Mux { compressor := middleware.NewCompressor(flate.DefaultCompression) router := chi.NewRouter() router.Use( render.SetContentType(render.ContentTypeJSON), compressor.Handler, middleware.Recoverer, ) router.Get("/healthz", server.healthChecker.healthyEndpoint) router.Get("/readyz", server.healthChecker.readinessEndpoint) return router } func (server *WebServer) isSsl() bool { return server.serviceConfig.SslPort > 0 } func (server *WebServer) Mount(pattern string, handler http.Handler) { server.mounts[pattern] = handler } func (server *WebServer) routes() *chi.Mux { compressor := middleware.NewCompressor(flate.DefaultCompression) router := chi.NewRouter() router.Use( compressor.Handler, middleware.Recoverer, cors.Handler(cors.Options{ AllowedOrigins: []string{"https://*", "http://*"}, AllowedMethods: []string{"GET", "POST", "PUT", "DELETE", "OPTIONS"}, AllowedHeaders: []string{"Accept", "Authorization", "Content-Type", "X-CSRF-Token"}, ExposedHeaders: []string{"Link"}, AllowCredentials: false, MaxAge: 300, // Maximum value not ignored by any of major browsers }), ) router.Mount("/health", server.healthRoutes()) // Custom routes for pattern, handler := range server.mounts { router.Mount(pattern, handler) } // clean map server.mounts = make(map[string]http.Handler) return router } func (server *WebServer) getTLSConfig() (*tls.Config, error) { cert := server.serviceConfig.SslCert key := server.serviceConfig.SslKey if cert != "" && key != "" { log.Infof("Loading SSL certificate %s %s", cert, key) cer, err := tls.LoadX509KeyPair(cert, key) if err != nil { return nil, fmt.Errorf("Cannot load SSL certificat: %v", err) } return &tls.Config{Certificates: []tls.Certificate{cer}}, nil } log.Info("Generating SSL certificate") params := server.serviceConfig.CertificateGenerationParams gc := certs.GenerateCertificate{ Organization: params.Organization, Host: params.Host, ValidFor: params.ValidForDays * 24 * time.Hour, IsCA: false, EcdsaCurve: "P256", Ed25519Key: true, } return gc.GenerateTLSConfig() } func (server *WebServer) DebugRoutes(routerName string, router *chi.Mux) { walkFunc := func(method string, route string, handler http.Handler, middlewares ...func(http.Handler) http.Handler) error { log.Infof("%s %s", method, route) return nil } log.Infof("Routes of scope %s \n", routerName) if err := chi.Walk(router, walkFunc); err != nil { log.Alertf("Logging err: %s", err.Error()) } } // BuildHostAddress Builds a host address from configuration func (c *WebServerConfiguration) BuildHostAddress(ssl bool) string { var address string if "" == c.Host { address += "0.0.0.0" } else { address += c.Host } address += ":" if ssl { address += strconv.Itoa(c.SslPort) } else { address += strconv.Itoa(c.Port) } return address } func (c *WebServerConfiguration) GetBindings() string { bindingStrings := "" if c.SslPort > 0 { bindingStrings += "https://" + c.BuildHostAddress(true) + "," } bindingStrings += "http://" + c.BuildHostAddress(false) return bindingStrings } func (server *WebServer) performHealthCheck() (bool, string) { if server.HealthCheck != nil { return server.HealthCheck() } return true, "" } func (server *WebServer) Start() error { if server.serviceConfig == nil { return errors.New("use NewWebServer(config) for initialising the web server") } if len(server.mounts) == 0 { return errors.New("No mounts points added") } server.healthChecker = newHealthChecker(server.performHealthCheck) server.healthChecker.start(server.serviceConfig.HealthCheckPeriod) ssl := server.isSsl() if ssl { log.Debugf("Running in SSL mode") } server.healthRouter = server.healthRoutes() server.router = server.routes() server.DebugRoutes("Main", server.router) server.DebugRoutes("Health", server.healthRouter) if ssl { err := server.setupSsl() if err != nil { return err } } else { // own http server for the healthchecks httpAddress := server.serviceConfig.BuildHostAddress(false) server.srv = &http.Server{ Addr: httpAddress, WriteTimeout: time.Second * 15, ReadTimeout: time.Second * 15, IdleTimeout: time.Second * 60, Handler: server.router, } go func() { log.Infof("Starting http server on address: %s", server.srv.Addr) if err := server.srv.ListenAndServe(); err != nil && err != http.ErrServerClosed { log.Alertf("error starting server: %s", err.Error()) } }() } return nil } func (server *WebServer) setupHttp() error { // own http server for the healthchecks httpAddress := server.serviceConfig.BuildHostAddress(false) server.srv = &http.Server{ Addr: httpAddress, WriteTimeout: time.Second * 15, ReadTimeout: time.Second * 15, IdleTimeout: time.Second * 60, Handler: server.router, } go func() { log.Infof("Starting http server on address: %s", server.srv.Addr) if err := server.srv.ListenAndServe(); err != nil && err != http.ErrServerClosed { log.Alertf("error starting server: %s", err.Error()) } }() return nil } func (server *WebServer) setupSsl() error { tlsConfig, err := server.getTLSConfig() if err != nil { return err } sslAddress := server.serviceConfig.BuildHostAddress(true) server.sslsrv = &http.Server{ Addr: sslAddress, WriteTimeout: time.Second * 15, ReadTimeout: time.Second * 15, IdleTimeout: time.Second * 60, Handler: server.router, TLSConfig: tlsConfig, } go func() { log.Infof("Starting https server on address: %s", server.sslsrv.Addr) if err := server.sslsrv.ListenAndServeTLS("", ""); err != nil && err != http.ErrServerClosed { log.Alertf("error starting server: %s", err.Error()) } }() httpAddress := server.serviceConfig.BuildHostAddress(false) server.srv = &http.Server{ Addr: httpAddress, WriteTimeout: time.Second * 15, ReadTimeout: time.Second * 15, IdleTimeout: time.Second * 60, Handler: server.healthRouter, } go func() { log.Infof("Starting http server on address: %s", server.srv.Addr) if err := server.srv.ListenAndServe(); err != nil && err != http.ErrServerClosed { log.Alertf("error starting server: %s", err.Error()) } }() return nil } func (server *WebServer) Join() { c := make(chan os.Signal, 1) signal.Notify(c, os.Interrupt) <-c } func (server *WebServer) Stop() { server.healthChecker.stop() log.Info("Waiting for clients ...") ctx, cancel := context.WithTimeout(context.Background(), time.Second*15) defer cancel() log.Info("Shutting down server ...") server.srv.Shutdown(ctx) if server.isSsl() { server.sslsrv.Shutdown(ctx) } log.Info("Server has been shutted down") }