// A URL Shortener called Link. // Copyright (C) 2021 i@fsh.ee // // This program is free software: you can redistribute it and/or modify // it under the terms of the GNU Affero General Public License as published by // the Free Software Foundation, either version 3 of the License, or // (at your option) any later version. // // This program is distributed in the hope that it will be useful, // but WITHOUT ANY WARRANTY; without even the implied warranty of // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the // GNU Affero General Public License for more details. // // You should have received a copy of the GNU Affero General Public License // along with this program. If not, see . package main import ( "crypto/md5" _ "embed" "errors" "flag" "fmt" "hash/maphash" "html/template" "io/ioutil" "log" "net/http" "net/url" "os" "strconv" "strings" "time" "gorm.io/driver/sqlite" "gorm.io/gorm" "gorm.io/gorm/logger" ) //go:embed index.html var indexTemplate string type Retry struct { retryAttemptCount int } func NewRetry(retryAttemptCount int) (Retry, error) { if retryAttemptCount < 1 { return Retry{}, errors.New("retry attempt count must be greater than zero") } return Retry{retryAttemptCount}, nil } func (r Retry) Do(f func() error) (err error) { for i := 0; i < r.retryAttemptCount; i++ { err = f() if err == nil { return nil } } return err } type DB struct { *gorm.DB log *log.Logger hashSeed string retry Retry } func NewDB(l *log.Logger, dbFilePath, hashSeed string, retry Retry) (DB, error) { _, err := os.Stat(dbFilePath) if os.IsNotExist(err) { err := ioutil.WriteFile(dbFilePath, []byte{}, 0600) if err != nil { return DB{}, err } } db, err := gorm.Open(sqlite.Open(dbFilePath), &gorm.Config{ NowFunc: func() time.Time { return time.Now().UTC() }, Logger: logger.Default.LogMode(logger.Silent), }) if err != nil { return DB{}, err } return DB{db, l, hashSeed, retry}, db.AutoMigrate(&Link{}) } type Link struct { gorm.Model Big string Smol string `gorm:"unique"` Del string `gorm:"unique"` } func (db DB) getHashShortLink(s fmt.Stringer) (string, error) { var ( h = maphash.Hash{} _, err = h.WriteString(s.String()) ) if err != nil { return "", err } return strings.TrimSpace(strings.TrimLeft(fmt.Sprintf("%#x\n", h.Sum64()), "0x")), nil } func (db DB) getHashDeleteKey(s fmt.Stringer) string { return strings.TrimSpace(fmt.Sprintf("%x", md5.Sum([]byte(db.hashSeed+s.String()+strconv.FormatInt(time.Now().Unix(), 10))))) } func (db DB) NewLink(u *url.URL) (Link, error) { h, err := db.getHashShortLink(u) if err != nil { return Link{}, err } return db.NewLinkWithShortLink(u, h) } func (db DB) NewLinkWithShortLink(u *url.URL, hash string) (link Link, err error) { // Retry for unique errors. err = db.retry.Do(func() error { link = Link{Big: u.String(), Smol: hash, Del: db.getHashDeleteKey(u)} return db.Create(&link).Error }) return } func (db DB) GetLink(smol string) (l Link, e error) { res := db.Where(&Link{Smol: smol}).First(&l) return l, res.Error } func (db DB) DelLink(smol, del string) error { link, err := db.GetLink(smol) if err != nil { return err } res := db.Where(&Link{Del: del}).Delete(&link) if res.RowsAffected < 1 { return gorm.ErrRecordNotFound } return res.Error } type controller struct { log *log.Logger db DB demo bool url, copy string tmpl *template.Template } func NewController(logger *log.Logger, db DB, demo bool, url, copy string, tmpl *template.Template) controller { return controller{logger, db, demo, strings.TrimRight(url, "/"), copy, tmpl} } func (c controller) Err(rw http.ResponseWriter, r *http.Request, err error) { if errors.Is(err, gorm.ErrRecordNotFound) { rw.WriteHeader(http.StatusNotFound) fmt.Fprintf(rw, "%s", err) return } c.log.Println(err) rw.WriteHeader(http.StatusInternalServerError) fmt.Fprintf(rw, "%s", err) } func (c controller) ServeHTTP(rw http.ResponseWriter, r *http.Request) { switch r.Method { case http.MethodGet: st := strings.TrimRight(r.URL.Path, "/") rq := r.URL.RawQuery if rq != "" { u, err := url.Parse(rq) if err != nil { c.Err(rw, r, err) return } if u.Scheme != "http" && u.Scheme != "https" { rw.WriteHeader(http.StatusBadRequest) fmt.Fprintf(rw, "URL must contain scheme, e.g. `http://` or `https://`.") return } var ( link Link h = strings.Trim(r.URL.Path, "/") ) if h != "" { link, err = c.db.NewLinkWithShortLink(u, h) } else { link, err = c.db.NewLink(u) } if err != nil { c.Err(rw, r, err) return } rw.Header().Set("X-Delete-With", link.Del) rw.WriteHeader(http.StatusFound) fmt.Fprintf(rw, "%s/%s", c.url, link.Smol) return } else { switch st { case "": data := map[string]interface{}{ "URL": c.url, "Demo": c.demo, "Copy": c.copy, } if err := c.tmpl.Execute(rw, data); err != nil { c.Err(rw, r, err) return } return case "/favicon.ico": http.NotFound(rw, r) return default: link, err := c.db.GetLink(strings.TrimLeft(r.URL.Path, "/")) if err != nil { c.Err(rw, r, err) return } http.Redirect(rw, r, link.Big, http.StatusPermanentRedirect) return } } case http.MethodPost: b, err := ioutil.ReadAll(r.Body) if err != nil { c.Err(rw, r, err) return } u, err := url.Parse(string(b)) if err != nil { c.Err(rw, r, err) return } if u.Scheme != "http" && u.Scheme != "https" { rw.WriteHeader(http.StatusBadRequest) fmt.Fprintf(rw, "URL must contain scheme, e.g. `http://` or `https://`.") return } var ( link Link h = strings.Trim(r.URL.Path, "/") ) if h != "" { link, err = c.db.NewLinkWithShortLink(u, h) } else { link, err = c.db.NewLink(u) } if err != nil { c.Err(rw, r, err) return } rw.Header().Set("X-Delete-With", link.Del) rw.WriteHeader(http.StatusFound) fmt.Fprintf(rw, "%s/%s", c.url, link.Smol) return case http.MethodDelete: b, err := ioutil.ReadAll(r.Body) if err != nil { c.Err(rw, r, err) return } if len(b) < 1 { rw.WriteHeader(http.StatusBadRequest) fmt.Fprintf(rw, "Must include deletion key in DELETE body.") return } var ( smol = strings.TrimSpace(strings.TrimLeft(r.URL.Path, "/")) del = strings.TrimSpace(string(b)) ) if err := c.db.DelLink(smol, del); err != nil { c.Err(rw, r, err) return } rw.WriteHeader(http.StatusNoContent) return } http.NotFound(rw, r) } func main() { var ( logPrefix = "link: " startupLogger = log.New(os.Stdout, logPrefix, 0) applicationLogger = log.New(ioutil.Discard, logPrefix, 0) v = flag.Bool("v", false, "verbose logging") demo = flag.Bool("demo", false, "turn on demo mode") port = flag.Uint("port", 8080, "port to listen on") dbFilePath = flag.String("db", "", "sqlite database filepath: required") url = flag.String("url", "", "URL which the server will be running on: required") hashSeed = flag.String("seed", "", "hash seed: required") copy = flag.String("copy", "", "copyright information") ) flag.Parse() if *dbFilePath == "" || *url == "" || *hashSeed == "" { flag.Usage() return } if *v { applicationLogger = log.New(os.Stdout, logPrefix, 0) } retry, err := NewRetry(3) if err != nil { startupLogger.Fatal(err) return } db, err := NewDB(applicationLogger, *dbFilePath, *hashSeed, retry) if err != nil { startupLogger.Fatal(err) return } tmpl, err := template.New("").Parse(indexTemplate) if err != nil { startupLogger.Fatal(err) return } http.Handle("/", NewController(applicationLogger, db, *demo, *url, *copy, tmpl)) startupLogger.Println("listening on port", *port) startupLogger.Fatal(http.ListenAndServe(fmt.Sprintf(":%d", *port), nil)) }