From 2af37d47783aac8c650ffd1578e2297b5784c73d Mon Sep 17 00:00:00 2001 From: r Date: Tue, 28 Jan 2020 17:51:00 +0000 Subject: [PATCH] Refactor everything --- Makefile | 2 - main.go | 6 +- migrations/csrfToken/main.go | 4 +- model/app.go | 2 +- model/client.go | 7 +- model/{postContext.go => post.go} | 0 model/session.go | 2 +- renderer/model.go | 9 - renderer/renderer.go | 65 +- .../appRepository.go => repo/appRepo.go | 12 +- repo/sessionRepo.go | 42 + repository/sessionRepository.go | 64 -- service/auth.go | 314 ++--- service/logging.go | 222 ++-- service/service.go | 887 +++++++------- service/transport.go | 1016 +++++++++-------- static/custom.css | 3 - static/{main.css => style.css} | 0 templates/followers.tmpl | 2 +- templates/following.tmpl | 2 +- templates/header.tmpl | 2 +- templates/notification.tmpl | 2 +- templates/search.tmpl | 2 +- templates/timeline.tmpl | 4 +- templates/user.tmpl | 2 +- util/rand.go | 8 +- 26 files changed, 1320 insertions(+), 1361 deletions(-) rename model/{postContext.go => post.go} (100%) rename repository/appRepository.go => repo/appRepo.go (58%) create mode 100644 repo/sessionRepo.go delete mode 100644 repository/sessionRepository.go delete mode 100644 static/custom.css rename static/{main.css => style.css} (100%) diff --git a/Makefile b/Makefile index 8dcbd46..780f6e8 100644 --- a/Makefile +++ b/Makefile @@ -1,5 +1,3 @@ -.POSIX: - GO=go all: bloat diff --git a/main.go b/main.go index 4f5851d..003fe5d 100644 --- a/main.go +++ b/main.go @@ -12,7 +12,7 @@ import ( "bloat/config" "bloat/kv" "bloat/renderer" - "bloat/repository" + "bloat/repo" "bloat/service" "bloat/util" ) @@ -67,8 +67,8 @@ func main() { log.Fatal(err) } - sessionRepo := repository.NewSessionRepository(sessionDB) - appRepo := repository.NewAppRepository(appDB) + sessionRepo := repo.NewSessionRepo(sessionDB) + appRepo := repo.NewAppRepo(appDB) customCSS := config.CustomCSS if !strings.HasPrefix(customCSS, "http://") && diff --git a/migrations/csrfToken/main.go b/migrations/csrfToken/main.go index 38ffbac..3d9526d 100644 --- a/migrations/csrfToken/main.go +++ b/migrations/csrfToken/main.go @@ -59,12 +59,12 @@ func main() { sessionRepo := repository.NewSessionRepository(sessionDB) - sessionIds, err := getKeys(sessionRepoPath) + sessionIDs, err := getKeys(sessionRepoPath) if err != nil { log.Fatal(err) } - for _, id := range sessionIds { + for _, id := range sessionIDs { s, err := sessionRepo.Get(id) if err != nil { log.Println(id, err) diff --git a/model/app.go b/model/app.go index 7abc8ec..8f172c8 100644 --- a/model/app.go +++ b/model/app.go @@ -15,7 +15,7 @@ type App struct { ClientSecret string `json:"client_secret"` } -type AppRepository interface { +type AppRepo interface { Add(app App) (err error) Get(instanceDomain string) (app App, err error) } diff --git a/model/client.go b/model/client.go index ae7270e..6123b38 100644 --- a/model/client.go +++ b/model/client.go @@ -1,8 +1,13 @@ package model -import "mastodon" +import ( + "io" + + "mastodon" +) type Client struct { *mastodon.Client + Writer io.Writer Session Session } diff --git a/model/postContext.go b/model/post.go similarity index 100% rename from model/postContext.go rename to model/post.go diff --git a/model/session.go b/model/session.go index 6bc8a63..10fca6f 100644 --- a/model/session.go +++ b/model/session.go @@ -16,7 +16,7 @@ type Session struct { Settings Settings `json:"settings"` } -type SessionRepository interface { +type SessionRepo interface { Add(session Session) (err error) Get(sessionID string) (session Session, err error) } diff --git a/renderer/model.go b/renderer/model.go index 25fa0c6..8df64ab 100644 --- a/renderer/model.go +++ b/renderer/model.go @@ -47,9 +47,7 @@ type TimelineData struct { *CommonData Title string Statuses []*mastodon.Status - HasNext bool NextLink string - HasPrev bool PrevLink string PostContext model.PostContext } @@ -64,7 +62,6 @@ type ThreadData struct { type NotificationData struct { *CommonData Notifications []*mastodon.Notification - HasNext bool NextLink string DarkMode bool } @@ -73,7 +70,6 @@ type UserData struct { *CommonData User *mastodon.Account Statuses []*mastodon.Status - HasNext bool NextLink string DarkMode bool } @@ -90,28 +86,24 @@ type EmojiData struct { type LikedByData struct { *CommonData Users []*mastodon.Account - HasNext bool NextLink string } type RetweetedByData struct { *CommonData Users []*mastodon.Account - HasNext bool NextLink string } type FollowingData struct { *CommonData Users []*mastodon.Account - HasNext bool NextLink string } type FollowersData struct { *CommonData Users []*mastodon.Account - HasNext bool NextLink string } @@ -121,7 +113,6 @@ type SearchData struct { Type string Users []*mastodon.Account Statuses []*mastodon.Status - HasNext bool NextLink string } diff --git a/renderer/renderer.go b/renderer/renderer.go index 4d2c74d..2d227c4 100644 --- a/renderer/renderer.go +++ b/renderer/renderer.go @@ -1,6 +1,7 @@ package renderer import ( + "fmt" "io" "strconv" "strings" @@ -89,78 +90,100 @@ func NewRenderer(templateGlobPattern string) (r *renderer, err error) { }, nil } -func (r *renderer) RenderSigninPage(ctx *Context, writer io.Writer, signinData *SigninData) (err error) { +func (r *renderer) RenderSigninPage(ctx *Context, writer io.Writer, + signinData *SigninData) (err error) { return r.template.ExecuteTemplate(writer, "signin.tmpl", WithContext(signinData, ctx)) } -func (r *renderer) RenderErrorPage(ctx *Context, writer io.Writer, errorData *ErrorData) { +func (r *renderer) RenderErrorPage(ctx *Context, writer io.Writer, + errorData *ErrorData) { r.template.ExecuteTemplate(writer, "error.tmpl", WithContext(errorData, ctx)) return } -func (r *renderer) RenderTimelinePage(ctx *Context, writer io.Writer, data *TimelineData) (err error) { +func (r *renderer) RenderTimelinePage(ctx *Context, writer io.Writer, + data *TimelineData) (err error) { return r.template.ExecuteTemplate(writer, "timeline.tmpl", WithContext(data, ctx)) } -func (r *renderer) RenderThreadPage(ctx *Context, writer io.Writer, data *ThreadData) (err error) { +func (r *renderer) RenderThreadPage(ctx *Context, writer io.Writer, + data *ThreadData) (err error) { return r.template.ExecuteTemplate(writer, "thread.tmpl", WithContext(data, ctx)) } -func (r *renderer) RenderNotificationPage(ctx *Context, writer io.Writer, data *NotificationData) (err error) { +func (r *renderer) RenderNotificationPage(ctx *Context, writer io.Writer, + data *NotificationData) (err error) { return r.template.ExecuteTemplate(writer, "notification.tmpl", WithContext(data, ctx)) } -func (r *renderer) RenderUserPage(ctx *Context, writer io.Writer, data *UserData) (err error) { +func (r *renderer) RenderUserPage(ctx *Context, writer io.Writer, + data *UserData) (err error) { return r.template.ExecuteTemplate(writer, "user.tmpl", WithContext(data, ctx)) } -func (r *renderer) RenderAboutPage(ctx *Context, writer io.Writer, data *AboutData) (err error) { +func (r *renderer) RenderAboutPage(ctx *Context, writer io.Writer, + data *AboutData) (err error) { return r.template.ExecuteTemplate(writer, "about.tmpl", WithContext(data, ctx)) } -func (r *renderer) RenderEmojiPage(ctx *Context, writer io.Writer, data *EmojiData) (err error) { +func (r *renderer) RenderEmojiPage(ctx *Context, writer io.Writer, + data *EmojiData) (err error) { return r.template.ExecuteTemplate(writer, "emoji.tmpl", WithContext(data, ctx)) } -func (r *renderer) RenderLikedByPage(ctx *Context, writer io.Writer, data *LikedByData) (err error) { +func (r *renderer) RenderLikedByPage(ctx *Context, writer io.Writer, + data *LikedByData) (err error) { return r.template.ExecuteTemplate(writer, "likedby.tmpl", WithContext(data, ctx)) } -func (r *renderer) RenderRetweetedByPage(ctx *Context, writer io.Writer, data *RetweetedByData) (err error) { +func (r *renderer) RenderRetweetedByPage(ctx *Context, writer io.Writer, + data *RetweetedByData) (err error) { return r.template.ExecuteTemplate(writer, "retweetedby.tmpl", WithContext(data, ctx)) } -func (r *renderer) RenderFollowingPage(ctx *Context, writer io.Writer, data *FollowingData) (err error) { +func (r *renderer) RenderFollowingPage(ctx *Context, writer io.Writer, + data *FollowingData) (err error) { return r.template.ExecuteTemplate(writer, "following.tmpl", WithContext(data, ctx)) } -func (r *renderer) RenderFollowersPage(ctx *Context, writer io.Writer, data *FollowersData) (err error) { +func (r *renderer) RenderFollowersPage(ctx *Context, writer io.Writer, + data *FollowersData) (err error) { return r.template.ExecuteTemplate(writer, "followers.tmpl", WithContext(data, ctx)) } -func (r *renderer) RenderSearchPage(ctx *Context, writer io.Writer, data *SearchData) (err error) { +func (r *renderer) RenderSearchPage(ctx *Context, writer io.Writer, + data *SearchData) (err error) { return r.template.ExecuteTemplate(writer, "search.tmpl", WithContext(data, ctx)) } -func (r *renderer) RenderSettingsPage(ctx *Context, writer io.Writer, data *SettingsData) (err error) { +func (r *renderer) RenderSettingsPage(ctx *Context, writer io.Writer, + data *SettingsData) (err error) { return r.template.ExecuteTemplate(writer, "settings.tmpl", WithContext(data, ctx)) } func EmojiFilter(content string, emojis []mastodon.Emoji) string { var replacements []string + var r string for _, e := range emojis { - replacements = append(replacements, ":"+e.ShortCode+":", "\""+e.ShortCode+"\"") + r = fmt.Sprintf("\"%s\"", + e.URL, e.ShortCode, e.ShortCode) + replacements = append(replacements, ":"+e.ShortCode+":", r) } return strings.NewReplacer(replacements...).Replace(content) } -func StatusContentFilter(spoiler string, content string, emojis []mastodon.Emoji, mentions []mastodon.Mention) string { +func StatusContentFilter(spoiler string, content string, + emojis []mastodon.Emoji, mentions []mastodon.Mention) string { + + var replacements []string + var r string if len(spoiler) > 0 { content = spoiler + "
" + content } - var replacements []string for _, e := range emojis { - replacements = append(replacements, ":"+e.ShortCode+":", "\""+e.ShortCode+"\"") + r = fmt.Sprintf("\"%s\"", + e.URL, e.ShortCode, e.ShortCode) + replacements = append(replacements, ":"+e.ShortCode+":", r) } for _, m := range mentions { replacements = append(replacements, "\""+m.URL+"\"", "\"/user/"+m.ID+"\"") @@ -177,32 +200,26 @@ func DisplayInteractionCount(c int64) string { func TimeSince(t time.Time) string { dur := time.Since(t) - s := dur.Seconds() if s < 60 { return strconv.Itoa(int(s)) + "s" } - m := dur.Minutes() if m < 60 { return strconv.Itoa(int(m)) + "m" } - h := dur.Hours() if h < 24 { return strconv.Itoa(int(h)) + "h" } - d := h / 24 if d < 30 { return strconv.Itoa(int(d)) + "d" } - mo := d / 30 if mo < 12 { return strconv.Itoa(int(mo)) + "mo" } - y := mo / 12 return strconv.Itoa(int(y)) + "y" } diff --git a/repository/appRepository.go b/repo/appRepo.go similarity index 58% rename from repository/appRepository.go rename to repo/appRepo.go index 9e57fb7..6338c4a 100644 --- a/repository/appRepository.go +++ b/repo/appRepo.go @@ -1,4 +1,4 @@ -package repository +package repo import ( "encoding/json" @@ -7,17 +7,17 @@ import ( "bloat/model" ) -type appRepository struct { +type appRepo struct { db *kv.Database } -func NewAppRepository(db *kv.Database) *appRepository { - return &appRepository{ +func NewAppRepo(db *kv.Database) *appRepo { + return &appRepo{ db: db, } } -func (repo *appRepository) Add(a model.App) (err error) { +func (repo *appRepo) Add(a model.App) (err error) { data, err := json.Marshal(a) if err != nil { return @@ -26,7 +26,7 @@ func (repo *appRepository) Add(a model.App) (err error) { return } -func (repo *appRepository) Get(instanceDomain string) (a model.App, err error) { +func (repo *appRepo) Get(instanceDomain string) (a model.App, err error) { data, err := repo.db.Get(instanceDomain) if err != nil { err = model.ErrAppNotFound diff --git a/repo/sessionRepo.go b/repo/sessionRepo.go new file mode 100644 index 0000000..ce923b1 --- /dev/null +++ b/repo/sessionRepo.go @@ -0,0 +1,42 @@ +package repo + +import ( + "encoding/json" + + "bloat/kv" + "bloat/model" +) + +type sessionRepo struct { + db *kv.Database +} + +func NewSessionRepo(db *kv.Database) *sessionRepo { + return &sessionRepo{ + db: db, + } +} + +func (repo *sessionRepo) Add(s model.Session) (err error) { + data, err := json.Marshal(s) + if err != nil { + return + } + err = repo.db.Set(s.ID, data) + return +} + +func (repo *sessionRepo) Get(id string) (s model.Session, err error) { + data, err := repo.db.Get(id) + if err != nil { + err = model.ErrSessionNotFound + return + } + + err = json.Unmarshal(data, &s) + if err != nil { + return + } + + return +} diff --git a/repository/sessionRepository.go b/repository/sessionRepository.go deleted file mode 100644 index cb19f68..0000000 --- a/repository/sessionRepository.go +++ /dev/null @@ -1,64 +0,0 @@ -package repository - -import ( - "encoding/json" - - "bloat/kv" - "bloat/model" -) - -type sessionRepository struct { - db *kv.Database -} - -func NewSessionRepository(db *kv.Database) *sessionRepository { - return &sessionRepository{ - db: db, - } -} - -func (repo *sessionRepository) Add(s model.Session) (err error) { - data, err := json.Marshal(s) - if err != nil { - return - } - err = repo.db.Set(s.ID, data) - return -} - -func (repo *sessionRepository) Update(id string, accessToken string) (err error) { - data, err := repo.db.Get(id) - if err != nil { - return - } - - var s model.Session - err = json.Unmarshal(data, &s) - if err != nil { - return - } - - s.AccessToken = accessToken - - data, err = json.Marshal(s) - if err != nil { - return - } - - return repo.db.Set(id, data) -} - -func (repo *sessionRepository) Get(id string) (s model.Session, err error) { - data, err := repo.db.Get(id) - if err != nil { - err = model.ErrSessionNotFound - return - } - - err = json.Unmarshal(data, &s) - if err != nil { - return - } - - return -} diff --git a/service/auth.go b/service/auth.go index 909a9a2..78934fd 100644 --- a/service/auth.go +++ b/service/auth.go @@ -3,7 +3,6 @@ package service import ( "context" "errors" - "io" "mime/multipart" "bloat/model" @@ -11,28 +10,28 @@ import ( ) var ( - ErrInvalidSession = errors.New("invalid session") - ErrInvalidCSRFToken = errors.New("invalid csrf token") + errInvalidSession = errors.New("invalid session") + errInvalidCSRFToken = errors.New("invalid csrf token") ) -type authService struct { - sessionRepo model.SessionRepository - appRepo model.AppRepository +type as struct { + sessionRepo model.SessionRepo + appRepo model.AppRepo Service } -func NewAuthService(sessionRepo model.SessionRepository, appRepo model.AppRepository, s Service) Service { - return &authService{sessionRepo, appRepo, s} +func NewAuthService(sessionRepo model.SessionRepo, appRepo model.AppRepo, s Service) Service { + return &as{sessionRepo, appRepo, s} } -func (s *authService) getClient(ctx context.Context) (c *model.Client, err error) { +func (s *as) authenticateClient(ctx context.Context, c *model.Client) (err error) { sessionID, ok := ctx.Value("session_id").(string) if !ok || len(sessionID) < 1 { - return nil, ErrInvalidSession + return errInvalidSession } session, err := s.sessionRepo.Get(sessionID) if err != nil { - return nil, ErrInvalidSession + return errInvalidSession } client, err := s.appRepo.Get(session.InstanceDomain) if err != nil { @@ -44,31 +43,146 @@ func (s *authService) getClient(ctx context.Context) (c *model.Client, err error ClientSecret: client.ClientSecret, AccessToken: session.AccessToken, }) - c = &model.Client{Client: mc, Session: session} - return c, nil + if c == nil { + c = &model.Client{} + } + c.Client = mc + c.Session = session + return nil } func checkCSRF(ctx context.Context, c *model.Client) (err error) { - csrfToken, ok := ctx.Value("csrf_token").(string) - if !ok || csrfToken != c.Session.CSRFToken { - return ErrInvalidCSRFToken + token, ok := ctx.Value("csrf_token").(string) + if !ok || token != c.Session.CSRFToken { + return errInvalidCSRFToken } return nil } -func (s *authService) GetAuthUrl(ctx context.Context, instance string) ( - redirectUrl string, sessionID string, err error) { - return s.Service.GetAuthUrl(ctx, instance) +func (s *as) ServeErrorPage(ctx context.Context, c *model.Client, err error) { + s.authenticateClient(ctx, c) + s.Service.ServeErrorPage(ctx, c, err) } -func (s *authService) GetUserToken(ctx context.Context, sessionID string, c *model.Client, +func (s *as) ServeSigninPage(ctx context.Context, c *model.Client) (err error) { + return s.Service.ServeSigninPage(ctx, c) +} + +func (s *as) ServeTimelinePage(ctx context.Context, c *model.Client, tType string, + maxID string, minID string) (err error) { + err = s.authenticateClient(ctx, c) + if err != nil { + return + } + return s.Service.ServeTimelinePage(ctx, c, tType, maxID, minID) +} + +func (s *as) ServeThreadPage(ctx context.Context, c *model.Client, id string, reply bool) (err error) { + err = s.authenticateClient(ctx, c) + if err != nil { + return + } + return s.Service.ServeThreadPage(ctx, c, id, reply) +} + +func (s *as) ServeLikedByPage(ctx context.Context, c *model.Client, id string) (err error) { + err = s.authenticateClient(ctx, c) + if err != nil { + return + } + return s.Service.ServeLikedByPage(ctx, c, id) +} + +func (s *as) ServeRetweetedByPage(ctx context.Context, c *model.Client, id string) (err error) { + err = s.authenticateClient(ctx, c) + if err != nil { + return + } + return s.Service.ServeRetweetedByPage(ctx, c, id) +} + +func (s *as) ServeFollowingPage(ctx context.Context, c *model.Client, id string, + maxID string, minID string) (err error) { + err = s.authenticateClient(ctx, c) + if err != nil { + return + } + return s.Service.ServeFollowingPage(ctx, c, id, maxID, minID) +} + +func (s *as) ServeFollowersPage(ctx context.Context, c *model.Client, id string, + maxID string, minID string) (err error) { + err = s.authenticateClient(ctx, c) + if err != nil { + return + } + return s.Service.ServeFollowersPage(ctx, c, id, maxID, minID) +} + +func (s *as) ServeNotificationPage(ctx context.Context, c *model.Client, + maxID string, minID string) (err error) { + err = s.authenticateClient(ctx, c) + if err != nil { + return + } + return s.Service.ServeNotificationPage(ctx, c, maxID, minID) +} + +func (s *as) ServeUserPage(ctx context.Context, c *model.Client, id string, + maxID string, minID string) (err error) { + err = s.authenticateClient(ctx, c) + if err != nil { + return + } + return s.Service.ServeUserPage(ctx, c, id, maxID, minID) +} + +func (s *as) ServeAboutPage(ctx context.Context, c *model.Client) (err error) { + err = s.authenticateClient(ctx, c) + if err != nil { + return + } + return s.Service.ServeAboutPage(ctx, c) +} + +func (s *as) ServeEmojiPage(ctx context.Context, c *model.Client) (err error) { + err = s.authenticateClient(ctx, c) + if err != nil { + return + } + return s.Service.ServeEmojiPage(ctx, c) +} + +func (s *as) ServeSearchPage(ctx context.Context, c *model.Client, q string, + qType string, offset int) (err error) { + err = s.authenticateClient(ctx, c) + if err != nil { + return + } + return s.Service.ServeSearchPage(ctx, c, q, qType, offset) +} + +func (s *as) ServeSettingsPage(ctx context.Context, c *model.Client) (err error) { + err = s.authenticateClient(ctx, c) + if err != nil { + return + } + return s.Service.ServeSettingsPage(ctx, c) +} + +func (s *as) NewSession(ctx context.Context, instance string) (redirectUrl string, + sessionID string, err error) { + return s.Service.NewSession(ctx, instance) +} + +func (s *as) Signin(ctx context.Context, c *model.Client, sessionID string, code string) (token string, err error) { - c, err = s.getClient(ctx) + err = s.authenticateClient(ctx, c) if err != nil { return } - token, err = s.Service.GetUserToken(ctx, c.Session.ID, c, code) + token, err = s.Service.Signin(ctx, c, c.Session.ID, code) if err != nil { return } @@ -82,114 +196,10 @@ func (s *authService) GetUserToken(ctx context.Context, sessionID string, c *mod return } -func (s *authService) ServeErrorPage(ctx context.Context, client io.Writer, c *model.Client, err error) { - c, _ = s.getClient(ctx) - s.Service.ServeErrorPage(ctx, client, c, err) -} - -func (s *authService) ServeSigninPage(ctx context.Context, client io.Writer) (err error) { - return s.Service.ServeSigninPage(ctx, client) -} - -func (s *authService) ServeTimelinePage(ctx context.Context, client io.Writer, - c *model.Client, timelineType string, maxID string, sinceID string, minID string) (err error) { - c, err = s.getClient(ctx) - if err != nil { - return - } - return s.Service.ServeTimelinePage(ctx, client, c, timelineType, maxID, sinceID, minID) -} - -func (s *authService) ServeThreadPage(ctx context.Context, client io.Writer, c *model.Client, id string, reply bool) (err error) { - c, err = s.getClient(ctx) - if err != nil { - return - } - return s.Service.ServeThreadPage(ctx, client, c, id, reply) -} - -func (s *authService) ServeNotificationPage(ctx context.Context, client io.Writer, c *model.Client, maxID string, minID string) (err error) { - c, err = s.getClient(ctx) - if err != nil { - return - } - return s.Service.ServeNotificationPage(ctx, client, c, maxID, minID) -} - -func (s *authService) ServeUserPage(ctx context.Context, client io.Writer, c *model.Client, id string, maxID string, minID string) (err error) { - c, err = s.getClient(ctx) - if err != nil { - return - } - return s.Service.ServeUserPage(ctx, client, c, id, maxID, minID) -} - -func (s *authService) ServeAboutPage(ctx context.Context, client io.Writer, c *model.Client) (err error) { - c, err = s.getClient(ctx) - if err != nil { - return - } - return s.Service.ServeAboutPage(ctx, client, c) -} - -func (s *authService) ServeEmojiPage(ctx context.Context, client io.Writer, c *model.Client) (err error) { - c, err = s.getClient(ctx) - if err != nil { - return - } - return s.Service.ServeEmojiPage(ctx, client, c) -} - -func (s *authService) ServeLikedByPage(ctx context.Context, client io.Writer, c *model.Client, id string) (err error) { - c, err = s.getClient(ctx) - if err != nil { - return - } - return s.Service.ServeLikedByPage(ctx, client, c, id) -} - -func (s *authService) ServeRetweetedByPage(ctx context.Context, client io.Writer, c *model.Client, id string) (err error) { - c, err = s.getClient(ctx) - if err != nil { - return - } - return s.Service.ServeRetweetedByPage(ctx, client, c, id) -} - -func (s *authService) ServeFollowingPage(ctx context.Context, client io.Writer, c *model.Client, id string, maxID string, minID string) (err error) { - c, err = s.getClient(ctx) - if err != nil { - return - } - return s.Service.ServeFollowingPage(ctx, client, c, id, maxID, minID) -} - -func (s *authService) ServeFollowersPage(ctx context.Context, client io.Writer, c *model.Client, id string, maxID string, minID string) (err error) { - c, err = s.getClient(ctx) - if err != nil { - return - } - return s.Service.ServeFollowersPage(ctx, client, c, id, maxID, minID) -} - -func (s *authService) ServeSearchPage(ctx context.Context, client io.Writer, c *model.Client, q string, qType string, offset int) (err error) { - c, err = s.getClient(ctx) - if err != nil { - return - } - return s.Service.ServeSearchPage(ctx, client, c, q, qType, offset) -} - -func (s *authService) ServeSettingsPage(ctx context.Context, client io.Writer, c *model.Client) (err error) { - c, err = s.getClient(ctx) - if err != nil { - return - } - return s.Service.ServeSettingsPage(ctx, client, c) -} - -func (s *authService) SaveSettings(ctx context.Context, client io.Writer, c *model.Client, settings *model.Settings) (err error) { - c, err = s.getClient(ctx) +func (s *as) Post(ctx context.Context, c *model.Client, content string, + replyToID string, format string, visibility string, isNSFW bool, + files []*multipart.FileHeader) (id string, err error) { + err = s.authenticateClient(ctx, c) if err != nil { return } @@ -197,11 +207,11 @@ func (s *authService) SaveSettings(ctx context.Context, client io.Writer, c *mod if err != nil { return } - return s.Service.SaveSettings(ctx, client, c, settings) + return s.Service.Post(ctx, c, content, replyToID, format, visibility, isNSFW, files) } -func (s *authService) Like(ctx context.Context, client io.Writer, c *model.Client, id string) (count int64, err error) { - c, err = s.getClient(ctx) +func (s *as) Like(ctx context.Context, c *model.Client, id string) (count int64, err error) { + err = s.authenticateClient(ctx, c) if err != nil { return } @@ -209,11 +219,11 @@ func (s *authService) Like(ctx context.Context, client io.Writer, c *model.Clien if err != nil { return } - return s.Service.Like(ctx, client, c, id) + return s.Service.Like(ctx, c, id) } -func (s *authService) UnLike(ctx context.Context, client io.Writer, c *model.Client, id string) (count int64, err error) { - c, err = s.getClient(ctx) +func (s *as) UnLike(ctx context.Context, c *model.Client, id string) (count int64, err error) { + err = s.authenticateClient(ctx, c) if err != nil { return } @@ -221,11 +231,11 @@ func (s *authService) UnLike(ctx context.Context, client io.Writer, c *model.Cli if err != nil { return } - return s.Service.UnLike(ctx, client, c, id) + return s.Service.UnLike(ctx, c, id) } -func (s *authService) Retweet(ctx context.Context, client io.Writer, c *model.Client, id string) (count int64, err error) { - c, err = s.getClient(ctx) +func (s *as) Retweet(ctx context.Context, c *model.Client, id string) (count int64, err error) { + err = s.authenticateClient(ctx, c) if err != nil { return } @@ -233,11 +243,11 @@ func (s *authService) Retweet(ctx context.Context, client io.Writer, c *model.Cl if err != nil { return } - return s.Service.Retweet(ctx, client, c, id) + return s.Service.Retweet(ctx, c, id) } -func (s *authService) UnRetweet(ctx context.Context, client io.Writer, c *model.Client, id string) (count int64, err error) { - c, err = s.getClient(ctx) +func (s *as) UnRetweet(ctx context.Context, c *model.Client, id string) (count int64, err error) { + err = s.authenticateClient(ctx, c) if err != nil { return } @@ -245,11 +255,11 @@ func (s *authService) UnRetweet(ctx context.Context, client io.Writer, c *model. if err != nil { return } - return s.Service.UnRetweet(ctx, client, c, id) + return s.Service.UnRetweet(ctx, c, id) } -func (s *authService) PostTweet(ctx context.Context, client io.Writer, c *model.Client, content string, replyToID string, format string, visibility string, isNSFW bool, files []*multipart.FileHeader) (id string, err error) { - c, err = s.getClient(ctx) +func (s *as) Follow(ctx context.Context, c *model.Client, id string) (err error) { + err = s.authenticateClient(ctx, c) if err != nil { return } @@ -257,11 +267,11 @@ func (s *authService) PostTweet(ctx context.Context, client io.Writer, c *model. if err != nil { return } - return s.Service.PostTweet(ctx, client, c, content, replyToID, format, visibility, isNSFW, files) + return s.Service.Follow(ctx, c, id) } -func (s *authService) Follow(ctx context.Context, client io.Writer, c *model.Client, id string) (err error) { - c, err = s.getClient(ctx) +func (s *as) UnFollow(ctx context.Context, c *model.Client, id string) (err error) { + err = s.authenticateClient(ctx, c) if err != nil { return } @@ -269,11 +279,11 @@ func (s *authService) Follow(ctx context.Context, client io.Writer, c *model.Cli if err != nil { return } - return s.Service.Follow(ctx, client, c, id) + return s.Service.UnFollow(ctx, c, id) } -func (s *authService) UnFollow(ctx context.Context, client io.Writer, c *model.Client, id string) (err error) { - c, err = s.getClient(ctx) +func (s *as) SaveSettings(ctx context.Context, c *model.Client, settings *model.Settings) (err error) { + err = s.authenticateClient(ctx, c) if err != nil { return } @@ -281,5 +291,5 @@ func (s *authService) UnFollow(ctx context.Context, client io.Writer, c *model.C if err != nil { return } - return s.Service.UnFollow(ctx, client, c, id) + return s.Service.SaveSettings(ctx, c, settings) } diff --git a/service/logging.go b/service/logging.go index cafd815..e4f8985 100644 --- a/service/logging.go +++ b/service/logging.go @@ -2,7 +2,6 @@ package service import ( "context" - "io" "log" "mime/multipart" "time" @@ -10,206 +9,215 @@ import ( "bloat/model" ) -type loggingService struct { +type ls struct { logger *log.Logger Service } func NewLoggingService(logger *log.Logger, s Service) Service { - return &loggingService{logger, s} + return &ls{logger, s} } -func (s *loggingService) GetAuthUrl(ctx context.Context, instance string) ( - redirectUrl string, sessionID string, err error) { - defer func(begin time.Time) { - s.logger.Printf("method=%v, instance=%v, took=%v, err=%v\n", - "GetAuthUrl", instance, time.Since(begin), err) - }(time.Now()) - return s.Service.GetAuthUrl(ctx, instance) -} - -func (s *loggingService) GetUserToken(ctx context.Context, sessionID string, c *model.Client, - code string) (token string, err error) { - defer func(begin time.Time) { - s.logger.Printf("method=%v, session_id=%v, code=%v, took=%v, err=%v\n", - "GetUserToken", sessionID, code, time.Since(begin), err) - }(time.Now()) - return s.Service.GetUserToken(ctx, sessionID, c, code) -} - -func (s *loggingService) ServeErrorPage(ctx context.Context, client io.Writer, c *model.Client, err error) { +func (s *ls) ServeErrorPage(ctx context.Context, c *model.Client, err error) { defer func(begin time.Time) { s.logger.Printf("method=%v, err=%v, took=%v\n", "ServeErrorPage", err, time.Since(begin)) }(time.Now()) - s.Service.ServeErrorPage(ctx, client, c, err) + s.Service.ServeErrorPage(ctx, c, err) } -func (s *loggingService) ServeSigninPage(ctx context.Context, client io.Writer) (err error) { +func (s *ls) ServeSigninPage(ctx context.Context, c *model.Client) (err error) { defer func(begin time.Time) { s.logger.Printf("method=%v, took=%v, err=%v\n", "ServeSigninPage", time.Since(begin), err) }(time.Now()) - return s.Service.ServeSigninPage(ctx, client) + return s.Service.ServeSigninPage(ctx, c) } -func (s *loggingService) ServeTimelinePage(ctx context.Context, client io.Writer, - c *model.Client, timelineType string, maxID string, sinceID string, minID string) (err error) { +func (s *ls) ServeTimelinePage(ctx context.Context, c *model.Client, tType string, + maxID string, minID string) (err error) { defer func(begin time.Time) { - s.logger.Printf("method=%v, timeline_type=%v, max_id=%v, since_id=%v, min_id=%v, took=%v, err=%v\n", - "ServeTimelinePage", timelineType, maxID, sinceID, minID, time.Since(begin), err) + s.logger.Printf("method=%v, type=%v, took=%v, err=%v\n", + "ServeTimelinePage", tType, time.Since(begin), err) }(time.Now()) - return s.Service.ServeTimelinePage(ctx, client, c, timelineType, maxID, sinceID, minID) + return s.Service.ServeTimelinePage(ctx, c, tType, maxID, minID) } -func (s *loggingService) ServeThreadPage(ctx context.Context, client io.Writer, c *model.Client, id string, reply bool) (err error) { +func (s *ls) ServeThreadPage(ctx context.Context, c *model.Client, id string, + reply bool) (err error) { defer func(begin time.Time) { - s.logger.Printf("method=%v, id=%v, reply=%v, took=%v, err=%v\n", - "ServeThreadPage", id, reply, time.Since(begin), err) + s.logger.Printf("method=%v, id=%v, took=%v, err=%v\n", + "ServeThreadPage", id, time.Since(begin), err) }(time.Now()) - return s.Service.ServeThreadPage(ctx, client, c, id, reply) + return s.Service.ServeThreadPage(ctx, c, id, reply) } -func (s *loggingService) ServeNotificationPage(ctx context.Context, client io.Writer, c *model.Client, maxID string, minID string) (err error) { - defer func(begin time.Time) { - s.logger.Printf("method=%v, max_id=%v, min_id=%v, took=%v, err=%v\n", - "ServeNotificationPage", maxID, minID, time.Since(begin), err) - }(time.Now()) - return s.Service.ServeNotificationPage(ctx, client, c, maxID, minID) -} - -func (s *loggingService) ServeUserPage(ctx context.Context, client io.Writer, c *model.Client, id string, maxID string, minID string) (err error) { - defer func(begin time.Time) { - s.logger.Printf("method=%v, id=%v, max_id=%v, min_id=%v, took=%v, err=%v\n", - "ServeUserPage", id, maxID, minID, time.Since(begin), err) - }(time.Now()) - return s.Service.ServeUserPage(ctx, client, c, id, maxID, minID) -} - -func (s *loggingService) ServeAboutPage(ctx context.Context, client io.Writer, c *model.Client) (err error) { - defer func(begin time.Time) { - s.logger.Printf("method=%v, took=%v, err=%v\n", - "ServeAboutPage", time.Since(begin), err) - }(time.Now()) - return s.Service.ServeAboutPage(ctx, client, c) -} - -func (s *loggingService) ServeEmojiPage(ctx context.Context, client io.Writer, c *model.Client) (err error) { - defer func(begin time.Time) { - s.logger.Printf("method=%v, took=%v, err=%v\n", - "ServeEmojiPage", time.Since(begin), err) - }(time.Now()) - return s.Service.ServeEmojiPage(ctx, client, c) -} - -func (s *loggingService) ServeLikedByPage(ctx context.Context, client io.Writer, c *model.Client, id string) (err error) { +func (s *ls) ServeLikedByPage(ctx context.Context, c *model.Client, id string) (err error) { defer func(begin time.Time) { s.logger.Printf("method=%v, id=%v, took=%v, err=%v\n", "ServeLikedByPage", id, time.Since(begin), err) }(time.Now()) - return s.Service.ServeLikedByPage(ctx, client, c, id) + return s.Service.ServeLikedByPage(ctx, c, id) } -func (s *loggingService) ServeRetweetedByPage(ctx context.Context, client io.Writer, c *model.Client, id string) (err error) { +func (s *ls) ServeRetweetedByPage(ctx context.Context, c *model.Client, id string) (err error) { defer func(begin time.Time) { s.logger.Printf("method=%v, id=%v, took=%v, err=%v\n", "ServeRetweetedByPage", id, time.Since(begin), err) }(time.Now()) - return s.Service.ServeRetweetedByPage(ctx, client, c, id) + return s.Service.ServeRetweetedByPage(ctx, c, id) } -func (s *loggingService) ServeFollowingPage(ctx context.Context, client io.Writer, c *model.Client, id string, maxID string, minID string) (err error) { +func (s *ls) ServeFollowingPage(ctx context.Context, c *model.Client, id string, + maxID string, minID string) (err error) { defer func(begin time.Time) { - s.logger.Printf("method=%v, id=%v, max_id=%v, min_id=%v, took=%v, err=%v\n", - "ServeFollowingPage", id, maxID, minID, time.Since(begin), err) + s.logger.Printf("method=%v, id=%v, took=%v, err=%v\n", + "ServeFollowingPage", id, time.Since(begin), err) }(time.Now()) - return s.Service.ServeFollowingPage(ctx, client, c, id, maxID, minID) + return s.Service.ServeFollowingPage(ctx, c, id, maxID, minID) } -func (s *loggingService) ServeFollowersPage(ctx context.Context, client io.Writer, c *model.Client, id string, maxID string, minID string) (err error) { +func (s *ls) ServeFollowersPage(ctx context.Context, c *model.Client, id string, + maxID string, minID string) (err error) { defer func(begin time.Time) { - s.logger.Printf("method=%v, id=%v, max_id=%v, min_id=%v, took=%v, err=%v\n", - "ServeFollowersPage", id, maxID, minID, time.Since(begin), err) + s.logger.Printf("method=%v, id=%v, took=%v, err=%v\n", + "ServeFollowersPage", id, time.Since(begin), err) }(time.Now()) - return s.Service.ServeFollowersPage(ctx, client, c, id, maxID, minID) + return s.Service.ServeFollowersPage(ctx, c, id, maxID, minID) } -func (s *loggingService) ServeSearchPage(ctx context.Context, client io.Writer, c *model.Client, q string, qType string, offset int) (err error) { +func (s *ls) ServeNotificationPage(ctx context.Context, c *model.Client, + maxID string, minID string) (err error) { defer func(begin time.Time) { - s.logger.Printf("method=%v, q=%v, type=%v, offset=%v, took=%v, err=%v\n", - "ServeSearchPage", q, qType, offset, time.Since(begin), err) + s.logger.Printf("method=%v, took=%v, err=%v\n", + "ServeNotificationPage", time.Since(begin), err) }(time.Now()) - return s.Service.ServeSearchPage(ctx, client, c, q, qType, offset) + return s.Service.ServeNotificationPage(ctx, c, maxID, minID) } -func (s *loggingService) ServeSettingsPage(ctx context.Context, client io.Writer, c *model.Client) (err error) { +func (s *ls) ServeUserPage(ctx context.Context, c *model.Client, id string, + maxID string, minID string) (err error) { + defer func(begin time.Time) { + s.logger.Printf("method=%v, id=%v, took=%v, err=%v\n", + "ServeUserPage", id, time.Since(begin), err) + }(time.Now()) + return s.Service.ServeUserPage(ctx, c, id, maxID, minID) +} + +func (s *ls) ServeAboutPage(ctx context.Context, c *model.Client) (err error) { + defer func(begin time.Time) { + s.logger.Printf("method=%v, took=%v, err=%v\n", + "ServeAboutPage", time.Since(begin), err) + }(time.Now()) + return s.Service.ServeAboutPage(ctx, c) +} + +func (s *ls) ServeEmojiPage(ctx context.Context, c *model.Client) (err error) { + defer func(begin time.Time) { + s.logger.Printf("method=%v, took=%v, err=%v\n", + "ServeEmojiPage", time.Since(begin), err) + }(time.Now()) + return s.Service.ServeEmojiPage(ctx, c) +} + +func (s *ls) ServeSearchPage(ctx context.Context, c *model.Client, q string, + qType string, offset int) (err error) { + defer func(begin time.Time) { + s.logger.Printf("method=%v, took=%v, err=%v\n", + "ServeSearchPage", time.Since(begin), err) + }(time.Now()) + return s.Service.ServeSearchPage(ctx, c, q, qType, offset) +} + +func (s *ls) ServeSettingsPage(ctx context.Context, c *model.Client) (err error) { defer func(begin time.Time) { s.logger.Printf("method=%v, took=%v, err=%v\n", "ServeSettingsPage", time.Since(begin), err) }(time.Now()) - return s.Service.ServeSettingsPage(ctx, client, c) + return s.Service.ServeSettingsPage(ctx, c) } -func (s *loggingService) SaveSettings(ctx context.Context, client io.Writer, c *model.Client, settings *model.Settings) (err error) { +func (s *ls) NewSession(ctx context.Context, instance string) (redirectUrl string, + sessionID string, err error) { + defer func(begin time.Time) { + s.logger.Printf("method=%v, instance=%v, took=%v, err=%v\n", + "NewSession", instance, time.Since(begin), err) + }(time.Now()) + return s.Service.NewSession(ctx, instance) +} + +func (s *ls) Signin(ctx context.Context, c *model.Client, sessionID string, + code string) (token string, err error) { + defer func(begin time.Time) { + s.logger.Printf("method=%v, session_id=%v, took=%v, err=%v\n", + "Signin", sessionID, time.Since(begin), err) + }(time.Now()) + return s.Service.Signin(ctx, c, sessionID, code) +} + +func (s *ls) Post(ctx context.Context, c *model.Client, content string, + replyToID string, format string, visibility string, isNSFW bool, + files []*multipart.FileHeader) (id string, err error) { defer func(begin time.Time) { s.logger.Printf("method=%v, took=%v, err=%v\n", - "SaveSettings", time.Since(begin), err) + "Post", time.Since(begin), err) }(time.Now()) - return s.Service.SaveSettings(ctx, client, c, settings) + return s.Service.Post(ctx, c, content, replyToID, format, + visibility, isNSFW, files) } -func (s *loggingService) Like(ctx context.Context, client io.Writer, c *model.Client, id string) (count int64, err error) { +func (s *ls) Like(ctx context.Context, c *model.Client, id string) (count int64, err error) { defer func(begin time.Time) { s.logger.Printf("method=%v, id=%v, took=%v, err=%v\n", "Like", id, time.Since(begin), err) }(time.Now()) - return s.Service.Like(ctx, client, c, id) + return s.Service.Like(ctx, c, id) } -func (s *loggingService) UnLike(ctx context.Context, client io.Writer, c *model.Client, id string) (count int64, err error) { +func (s *ls) UnLike(ctx context.Context, c *model.Client, id string) (count int64, err error) { defer func(begin time.Time) { s.logger.Printf("method=%v, id=%v, took=%v, err=%v\n", "UnLike", id, time.Since(begin), err) }(time.Now()) - return s.Service.UnLike(ctx, client, c, id) + return s.Service.UnLike(ctx, c, id) } -func (s *loggingService) Retweet(ctx context.Context, client io.Writer, c *model.Client, id string) (count int64, err error) { +func (s *ls) Retweet(ctx context.Context, c *model.Client, id string) (count int64, err error) { defer func(begin time.Time) { s.logger.Printf("method=%v, id=%v, took=%v, err=%v\n", "Retweet", id, time.Since(begin), err) }(time.Now()) - return s.Service.Retweet(ctx, client, c, id) + return s.Service.Retweet(ctx, c, id) } -func (s *loggingService) UnRetweet(ctx context.Context, client io.Writer, c *model.Client, id string) (count int64, err error) { +func (s *ls) UnRetweet(ctx context.Context, c *model.Client, id string) (count int64, err error) { defer func(begin time.Time) { s.logger.Printf("method=%v, id=%v, took=%v, err=%v\n", "UnRetweet", id, time.Since(begin), err) }(time.Now()) - return s.Service.UnRetweet(ctx, client, c, id) + return s.Service.UnRetweet(ctx, c, id) } -func (s *loggingService) PostTweet(ctx context.Context, client io.Writer, c *model.Client, content string, replyToID string, format string, visibility string, isNSFW bool, files []*multipart.FileHeader) (id string, err error) { - defer func(begin time.Time) { - s.logger.Printf("method=%v, content=%v, reply_to_id=%v, format=%v, visibility=%v, is_nsfw=%v, took=%v, err=%v\n", - "PostTweet", content, replyToID, format, visibility, isNSFW, time.Since(begin), err) - }(time.Now()) - return s.Service.PostTweet(ctx, client, c, content, replyToID, format, visibility, isNSFW, files) -} - -func (s *loggingService) Follow(ctx context.Context, client io.Writer, c *model.Client, id string) (err error) { +func (s *ls) Follow(ctx context.Context, c *model.Client, id string) (err error) { defer func(begin time.Time) { s.logger.Printf("method=%v, id=%v, took=%v, err=%v\n", "Follow", id, time.Since(begin), err) }(time.Now()) - return s.Service.Follow(ctx, client, c, id) + return s.Service.Follow(ctx, c, id) } -func (s *loggingService) UnFollow(ctx context.Context, client io.Writer, c *model.Client, id string) (err error) { +func (s *ls) UnFollow(ctx context.Context, c *model.Client, id string) (err error) { defer func(begin time.Time) { s.logger.Printf("method=%v, id=%v, took=%v, err=%v\n", "UnFollow", id, time.Since(begin), err) }(time.Now()) - return s.Service.UnFollow(ctx, client, c, id) + return s.Service.UnFollow(ctx, c, id) +} + +func (s *ls) SaveSettings(ctx context.Context, c *model.Client, settings *model.Settings) (err error) { + defer func(begin time.Time) { + s.logger.Printf("method=%v, took=%v, err=%v\n", + "SaveSettings", time.Since(begin), err) + }(time.Now()) + return s.Service.SaveSettings(ctx, c, settings) } diff --git a/service/service.go b/service/service.go index c9fccb4..7ad860f 100644 --- a/service/service.go +++ b/service/service.go @@ -1,14 +1,10 @@ package service import ( - "bytes" "context" - "encoding/json" "errors" "fmt" - "io" "mime/multipart" - "net/http" "net/url" "strings" @@ -19,37 +15,35 @@ import ( ) var ( - ErrInvalidArgument = errors.New("invalid argument") - ErrInvalidToken = errors.New("invalid token") - ErrInvalidClient = errors.New("invalid client") - ErrInvalidTimeline = errors.New("invalid timeline") + errInvalidArgument = errors.New("invalid argument") ) type Service interface { - GetAuthUrl(ctx context.Context, instance string) (url string, sessionID string, err error) - GetUserToken(ctx context.Context, sessionID string, c *model.Client, token string) (accessToken string, err error) - ServeErrorPage(ctx context.Context, client io.Writer, c *model.Client, err error) - ServeSigninPage(ctx context.Context, client io.Writer) (err error) - ServeTimelinePage(ctx context.Context, client io.Writer, c *model.Client, timelineType string, maxID string, sinceID string, minID string) (err error) - ServeThreadPage(ctx context.Context, client io.Writer, c *model.Client, id string, reply bool) (err error) - ServeNotificationPage(ctx context.Context, client io.Writer, c *model.Client, maxID string, minID string) (err error) - ServeUserPage(ctx context.Context, client io.Writer, c *model.Client, id string, maxID string, minID string) (err error) - ServeAboutPage(ctx context.Context, client io.Writer, c *model.Client) (err error) - ServeEmojiPage(ctx context.Context, client io.Writer, c *model.Client) (err error) - ServeLikedByPage(ctx context.Context, client io.Writer, c *model.Client, id string) (err error) - ServeRetweetedByPage(ctx context.Context, client io.Writer, c *model.Client, id string) (err error) - ServeFollowingPage(ctx context.Context, client io.Writer, c *model.Client, id string, maxID string, minID string) (err error) - ServeFollowersPage(ctx context.Context, client io.Writer, c *model.Client, id string, maxID string, minID string) (err error) - ServeSearchPage(ctx context.Context, client io.Writer, c *model.Client, q string, qType string, offset int) (err error) - ServeSettingsPage(ctx context.Context, client io.Writer, c *model.Client) (err error) - SaveSettings(ctx context.Context, client io.Writer, c *model.Client, settings *model.Settings) (err error) - Like(ctx context.Context, client io.Writer, c *model.Client, id string) (count int64, err error) - UnLike(ctx context.Context, client io.Writer, c *model.Client, id string) (count int64, err error) - Retweet(ctx context.Context, client io.Writer, c *model.Client, id string) (count int64, err error) - UnRetweet(ctx context.Context, client io.Writer, c *model.Client, id string) (count int64, err error) - PostTweet(ctx context.Context, client io.Writer, c *model.Client, content string, replyToID string, format string, visibility string, isNSFW bool, files []*multipart.FileHeader) (id string, err error) - Follow(ctx context.Context, client io.Writer, c *model.Client, id string) (err error) - UnFollow(ctx context.Context, client io.Writer, c *model.Client, id string) (err error) + ServeErrorPage(ctx context.Context, c *model.Client, err error) + ServeSigninPage(ctx context.Context, c *model.Client) (err error) + ServeTimelinePage(ctx context.Context, c *model.Client, tType string, maxID string, minID string) (err error) + ServeThreadPage(ctx context.Context, c *model.Client, id string, reply bool) (err error) + ServeLikedByPage(ctx context.Context, c *model.Client, id string) (err error) + ServeRetweetedByPage(ctx context.Context, c *model.Client, id string) (err error) + ServeFollowingPage(ctx context.Context, c *model.Client, id string, maxID string, minID string) (err error) + ServeFollowersPage(ctx context.Context, c *model.Client, id string, maxID string, minID string) (err error) + ServeNotificationPage(ctx context.Context, c *model.Client, maxID string, minID string) (err error) + ServeUserPage(ctx context.Context, c *model.Client, id string, maxID string, minID string) (err error) + ServeAboutPage(ctx context.Context, c *model.Client) (err error) + ServeEmojiPage(ctx context.Context, c *model.Client) (err error) + ServeSearchPage(ctx context.Context, c *model.Client, q string, qType string, offset int) (err error) + ServeSettingsPage(ctx context.Context, c *model.Client) (err error) + NewSession(ctx context.Context, instance string) (redirectUrl string, sessionID string, err error) + Signin(ctx context.Context, c *model.Client, sessionID string, code string) (token string, err error) + Post(ctx context.Context, c *model.Client, content string, replyToID string, format string, + visibility string, isNSFW bool, files []*multipart.FileHeader) (id string, err error) + Like(ctx context.Context, c *model.Client, id string) (count int64, err error) + UnLike(ctx context.Context, c *model.Client, id string) (count int64, err error) + Retweet(ctx context.Context, c *model.Client, id string) (count int64, err error) + UnRetweet(ctx context.Context, c *model.Client, id string) (count int64, err error) + Follow(ctx context.Context, c *model.Client, id string) (err error) + UnFollow(ctx context.Context, c *model.Client, id string) (err error) + SaveSettings(ctx context.Context, c *model.Client, settings *model.Settings) (err error) } type service struct { @@ -59,13 +53,19 @@ type service struct { customCSS string postFormats []model.PostFormat renderer renderer.Renderer - sessionRepo model.SessionRepository - appRepo model.AppRepository + sessionRepo model.SessionRepo + appRepo model.AppRepo } -func NewService(clientName string, clientScope string, clientWebsite string, - customCSS string, postFormats []model.PostFormat, renderer renderer.Renderer, - sessionRepo model.SessionRepository, appRepo model.AppRepository) Service { +func NewService(clientName string, + clientScope string, + clientWebsite string, + customCSS string, + postFormats []model.PostFormat, + renderer renderer.Renderer, + sessionRepo model.SessionRepo, + appRepo model.AppRepo, +) Service { return &service{ clientName: clientName, clientScope: clientScope, @@ -96,137 +96,75 @@ func getRendererContext(c *model.Client) *renderer.Context { } } -func (svc *service) GetAuthUrl(ctx context.Context, instance string) ( - redirectUrl string, sessionID string, err error) { - var instanceURL string - if strings.HasPrefix(instance, "https://") { - instanceURL = instance - instance = strings.TrimPrefix(instance, "https://") - } else { - instanceURL = "https://" + instance - } - - sessionID, err = util.NewSessionId() - if err != nil { - return - } - csrfToken, err := util.NewCSRFToken() - if err != nil { - return - } - session := model.Session{ - ID: sessionID, - InstanceDomain: instance, - CSRFToken: csrfToken, - Settings: *model.NewSettings(), - } - err = svc.sessionRepo.Add(session) - if err != nil { +func addToReplyMap(m map[string][]mastodon.ReplyInfo, key interface{}, + val string, number int) { + if key == nil { return } - app, err := svc.appRepo.Get(instance) - if err != nil { - if err != model.ErrAppNotFound { - return - } - - var mastoApp *mastodon.Application - mastoApp, err = mastodon.RegisterApp(ctx, &mastodon.AppConfig{ - Server: instanceURL, - ClientName: svc.clientName, - Scopes: svc.clientScope, - Website: svc.clientWebsite, - RedirectURIs: svc.clientWebsite + "/oauth_callback", - }) - if err != nil { - return - } - - app = model.App{ - InstanceDomain: instance, - InstanceURL: instanceURL, - ClientID: mastoApp.ClientID, - ClientSecret: mastoApp.ClientSecret, - } - - err = svc.appRepo.Add(app) - if err != nil { - return - } - } - - u, err := url.Parse("/oauth/authorize") - if err != nil { + keyStr, ok := key.(string) + if !ok { return } - q := make(url.Values) - q.Set("scope", "read write follow") - q.Set("client_id", app.ClientID) - q.Set("response_type", "code") - q.Set("redirect_uri", svc.clientWebsite+"/oauth_callback") - u.RawQuery = q.Encode() + _, ok = m[keyStr] + if !ok { + m[keyStr] = []mastodon.ReplyInfo{} + } - redirectUrl = instanceURL + u.String() + m[keyStr] = append(m[keyStr], mastodon.ReplyInfo{val, number}) +} + +func (svc *service) getCommonData(ctx context.Context, c *model.Client, + title string) (data *renderer.CommonData, err error) { + + data = new(renderer.CommonData) + data.HeaderData = &renderer.HeaderData{ + Title: title + " - " + svc.clientName, + NotificationCount: 0, + CustomCSS: svc.customCSS, + } + + if c == nil || !c.Session.IsLoggedIn() { + return + } + + notifications, err := c.GetNotifications(ctx, nil) + if err != nil { + return nil, err + } + + var notificationCount int + for i := range notifications { + if notifications[i].Pleroma != nil && + !notifications[i].Pleroma.IsSeen { + notificationCount++ + } + } + + u, err := c.GetAccountCurrentUser(ctx) + if err != nil { + return nil, err + } + + data.NavbarData = &renderer.NavbarData{ + User: u, + NotificationCount: notificationCount, + } + + data.HeaderData.NotificationCount = notificationCount + data.HeaderData.CSRFToken = c.Session.CSRFToken return } -func (svc *service) GetUserToken(ctx context.Context, sessionID string, c *model.Client, - code string) (token string, err error) { - if len(code) < 1 { - err = ErrInvalidArgument - return - } - - session, err := svc.sessionRepo.Get(sessionID) - if err != nil { - return - } - - app, err := svc.appRepo.Get(session.InstanceDomain) - if err != nil { - return - } - - data := &bytes.Buffer{} - err = json.NewEncoder(data).Encode(map[string]string{ - "client_id": app.ClientID, - "client_secret": app.ClientSecret, - "grant_type": "authorization_code", - "code": code, - "redirect_uri": svc.clientWebsite + "/oauth_callback", - }) - if err != nil { - return - } - - resp, err := http.Post(app.InstanceURL+"/oauth/token", "application/json", data) - if err != nil { - return - } - defer resp.Body.Close() - - var res struct { - AccessToken string `json:"access_token"` - } - - err = json.NewDecoder(resp.Body).Decode(&res) - if err != nil { - return - } - - return res.AccessToken, nil -} - -func (svc *service) ServeErrorPage(ctx context.Context, client io.Writer, c *model.Client, err error) { +func (svc *service) ServeErrorPage(ctx context.Context, c *model.Client, err error) { var errStr string if err != nil { errStr = err.Error() } - commonData, err := svc.getCommonData(ctx, client, nil, "error") + commonData, err := svc.getCommonData(ctx, nil, "error") if err != nil { return } @@ -237,12 +175,13 @@ func (svc *service) ServeErrorPage(ctx context.Context, client io.Writer, c *mod } rCtx := getRendererContext(c) - - svc.renderer.RenderErrorPage(rCtx, client, data) + svc.renderer.RenderErrorPage(rCtx, c.Writer, data) } -func (svc *service) ServeSigninPage(ctx context.Context, client io.Writer) (err error) { - commonData, err := svc.getCommonData(ctx, client, nil, "signin") +func (svc *service) ServeSigninPage(ctx context.Context, c *model.Client) ( + err error) { + + commonData, err := svc.getCommonData(ctx, nil, "signin") if err != nil { return } @@ -252,26 +191,23 @@ func (svc *service) ServeSigninPage(ctx context.Context, client io.Writer) (err } rCtx := getRendererContext(nil) - return svc.renderer.RenderSigninPage(rCtx, client, data) + return svc.renderer.RenderSigninPage(rCtx, c.Writer, data) } -func (svc *service) ServeTimelinePage(ctx context.Context, client io.Writer, - c *model.Client, timelineType string, maxID string, sinceID string, minID string) (err error) { - - var hasNext, hasPrev bool - var nextLink, prevLink string +func (svc *service) ServeTimelinePage(ctx context.Context, c *model.Client, + tType string, maxID string, minID string) (err error) { + var nextLink, prevLink, title string + var statuses []*mastodon.Status var pg = mastodon.Pagination{ MaxID: maxID, MinID: minID, Limit: 20, } - var statuses []*mastodon.Status - var title string - switch timelineType { + switch tType { default: - return ErrInvalidTimeline + return errInvalidArgument case "home": statuses, err = c.GetTimelineHome(ctx, &pg) title = "Timeline" @@ -293,29 +229,31 @@ func (svc *service) ServeTimelinePage(ctx context.Context, client io.Writer, } if len(maxID) > 0 && len(statuses) > 0 { - hasPrev = true - prevLink = fmt.Sprintf("/timeline/$s?min_id=%s", timelineType, statuses[0].ID) + prevLink = fmt.Sprintf("/timeline/%s?min_id=%s", tType, + statuses[0].ID) } + if len(minID) > 0 && len(pg.MinID) > 0 { - newStatuses, err := c.GetTimelineHome(ctx, &mastodon.Pagination{MinID: pg.MinID, Limit: 20}) + newPg := &mastodon.Pagination{MinID: pg.MinID, Limit: 20} + newStatuses, err := c.GetTimelineHome(ctx, newPg) if err != nil { return err } - newStatusesLen := len(newStatuses) - if newStatusesLen == 20 { - hasPrev = true - prevLink = fmt.Sprintf("/timeline/%s?min_id=%s", timelineType, pg.MinID) + newLen := len(newStatuses) + if newLen == 20 { + prevLink = fmt.Sprintf("/timeline/%s?min_id=%s", + tType, pg.MinID) } else { - i := 20 - newStatusesLen - 1 + i := 20 - newLen - 1 if len(statuses) > i { - hasPrev = true - prevLink = fmt.Sprintf("/timeline/%s?min_id=%s", timelineType, statuses[i].ID) + prevLink = fmt.Sprintf("/timeline/%s?min_id=%s", + tType, statuses[i].ID) } } } + if len(pg.MaxID) > 0 { - hasNext = true - nextLink = fmt.Sprintf("/timeline/%s?max_id=%s", timelineType, pg.MaxID) + nextLink = fmt.Sprintf("/timeline/%s?max_id=%s", tType, pg.MaxID) } postContext := model.PostContext{ @@ -323,7 +261,7 @@ func (svc *service) ServeTimelinePage(ctx context.Context, client io.Writer, Formats: svc.postFormats, } - commonData, err := svc.getCommonData(ctx, client, c, timelineType+" timeline ") + commonData, err := svc.getCommonData(ctx, c, tType+" timeline ") if err != nil { return } @@ -331,24 +269,21 @@ func (svc *service) ServeTimelinePage(ctx context.Context, client io.Writer, data := &renderer.TimelineData{ Title: title, Statuses: statuses, - HasNext: hasNext, NextLink: nextLink, - HasPrev: hasPrev, PrevLink: prevLink, PostContext: postContext, CommonData: commonData, } + rCtx := getRendererContext(c) - - err = svc.renderer.RenderTimelinePage(rCtx, client, data) - if err != nil { - return - } - - return + return svc.renderer.RenderTimelinePage(rCtx, c.Writer, data) } -func (svc *service) ServeThreadPage(ctx context.Context, client io.Writer, c *model.Client, id string, reply bool) (err error) { +func (svc *service) ServeThreadPage(ctx context.Context, c *model.Client, + id string, reply bool) (err error) { + + var postContext model.PostContext + status, err := c.GetStatus(ctx, id) if err != nil { return @@ -359,19 +294,19 @@ func (svc *service) ServeThreadPage(ctx context.Context, client io.Writer, c *mo return } - var postContext model.PostContext if reply { var content string + var visibility string if u.ID != status.Account.ID { content += "@" + status.Account.Acct + " " } for i := range status.Mentions { - if status.Mentions[i].ID != u.ID && status.Mentions[i].ID != status.Account.ID { + if status.Mentions[i].ID != u.ID && + status.Mentions[i].ID != status.Account.ID { content += "@" + status.Mentions[i].Acct + " " } } - var visibility string if c.Session.Settings.CopyScope { s, err := c.GetStatus(ctx, id) if err != nil { @@ -400,16 +335,15 @@ func (svc *service) ServeThreadPage(ctx context.Context, client io.Writer, c *mo } statuses := append(append(context.Ancestors, status), context.Descendants...) - - replyMap := make(map[string][]mastodon.ReplyInfo) + replies := make(map[string][]mastodon.ReplyInfo) for i := range statuses { statuses[i].ShowReplies = true - statuses[i].ReplyMap = replyMap - addToReplyMap(replyMap, statuses[i].InReplyToID, statuses[i].ID, i+1) + statuses[i].ReplyMap = replies + addToReplyMap(replies, statuses[i].InReplyToID, statuses[i].ID, i+1) } - commonData, err := svc.getCommonData(ctx, client, c, "post by "+status.Account.DisplayName) + commonData, err := svc.getCommonData(ctx, c, "post by "+status.Account.DisplayName) if err != nil { return } @@ -417,23 +351,130 @@ func (svc *service) ServeThreadPage(ctx context.Context, client io.Writer, c *mo data := &renderer.ThreadData{ Statuses: statuses, PostContext: postContext, - ReplyMap: replyMap, + ReplyMap: replies, CommonData: commonData, } - rCtx := getRendererContext(c) - err = svc.renderer.RenderThreadPage(rCtx, client, data) + rCtx := getRendererContext(c) + return svc.renderer.RenderThreadPage(rCtx, c.Writer, data) +} + +func (svc *service) ServeLikedByPage(ctx context.Context, c *model.Client, + id string) (err error) { + + likers, err := c.GetFavouritedBy(ctx, id, nil) if err != nil { return } - return + commonData, err := svc.getCommonData(ctx, c, "likes") + if err != nil { + return + } + + data := &renderer.LikedByData{ + CommonData: commonData, + Users: likers, + } + + rCtx := getRendererContext(c) + return svc.renderer.RenderLikedByPage(rCtx, c.Writer, data) } -func (svc *service) ServeNotificationPage(ctx context.Context, client io.Writer, c *model.Client, maxID string, minID string) (err error) { - var hasNext bool - var nextLink string +func (svc *service) ServeRetweetedByPage(ctx context.Context, c *model.Client, + id string) (err error) { + retweeters, err := c.GetRebloggedBy(ctx, id, nil) + if err != nil { + return + } + + commonData, err := svc.getCommonData(ctx, c, "retweets") + if err != nil { + return + } + + data := &renderer.RetweetedByData{ + CommonData: commonData, + Users: retweeters, + } + + rCtx := getRendererContext(c) + return svc.renderer.RenderRetweetedByPage(rCtx, c.Writer, data) +} + +func (svc *service) ServeFollowingPage(ctx context.Context, c *model.Client, + id string, maxID string, minID string) (err error) { + + var nextLink string + var pg = mastodon.Pagination{ + MaxID: maxID, + MinID: minID, + Limit: 20, + } + + followings, err := c.GetAccountFollowing(ctx, id, &pg) + if err != nil { + return + } + + if len(followings) == 20 && len(pg.MaxID) > 0 { + nextLink = "/following/" + id + "?max_id=" + pg.MaxID + } + + commonData, err := svc.getCommonData(ctx, c, "following") + if err != nil { + return + } + + data := &renderer.FollowingData{ + CommonData: commonData, + Users: followings, + NextLink: nextLink, + } + + rCtx := getRendererContext(c) + return svc.renderer.RenderFollowingPage(rCtx, c.Writer, data) +} + +func (svc *service) ServeFollowersPage(ctx context.Context, c *model.Client, + id string, maxID string, minID string) (err error) { + + var nextLink string + var pg = mastodon.Pagination{ + MaxID: maxID, + MinID: minID, + Limit: 20, + } + + followers, err := c.GetAccountFollowers(ctx, id, &pg) + if err != nil { + return + } + + if len(followers) == 20 && len(pg.MaxID) > 0 { + nextLink = "/followers/" + id + "?max_id=" + pg.MaxID + } + + commonData, err := svc.getCommonData(ctx, c, "followers") + if err != nil { + return + } + + data := &renderer.FollowersData{ + CommonData: commonData, + Users: followers, + NextLink: nextLink, + } + rCtx := getRendererContext(c) + return svc.renderer.RenderFollowersPage(rCtx, c.Writer, data) +} + +func (svc *service) ServeNotificationPage(ctx context.Context, c *model.Client, + maxID string, minID string) (err error) { + + var nextLink string + var unreadCount int var pg = mastodon.Pagination{ MaxID: maxID, MinID: minID, @@ -445,7 +486,6 @@ func (svc *service) ServeNotificationPage(ctx context.Context, client io.Writer, return } - var unreadCount int for i := range notifications { if notifications[i].Status != nil { notifications[i].Status.CreatedAt = notifications[i].CreatedAt @@ -467,38 +507,26 @@ func (svc *service) ServeNotificationPage(ctx context.Context, client io.Writer, } if len(pg.MaxID) > 0 { - hasNext = true nextLink = "/notifications?max_id=" + pg.MaxID } - commonData, err := svc.getCommonData(ctx, client, c, "notifications") + commonData, err := svc.getCommonData(ctx, c, "notifications") if err != nil { return } data := &renderer.NotificationData{ Notifications: notifications, - HasNext: hasNext, NextLink: nextLink, CommonData: commonData, } rCtx := getRendererContext(c) - - err = svc.renderer.RenderNotificationPage(rCtx, client, data) - if err != nil { - return - } - - return + return svc.renderer.RenderNotificationPage(rCtx, c.Writer, data) } -func (svc *service) ServeUserPage(ctx context.Context, client io.Writer, c *model.Client, id string, maxID string, minID string) (err error) { - user, err := c.GetAccount(ctx, id) - if err != nil { - return - } +func (svc *service) ServeUserPage(ctx context.Context, c *model.Client, + id string, maxID string, minID string) (err error) { - var hasNext bool var nextLink string var pg = mastodon.Pagination{ @@ -507,17 +535,21 @@ func (svc *service) ServeUserPage(ctx context.Context, client io.Writer, c *mode Limit: 20, } + user, err := c.GetAccount(ctx, id) + if err != nil { + return + } + statuses, err := c.GetAccountStatuses(ctx, id, &pg) if err != nil { return } if len(pg.MaxID) > 0 { - hasNext = true nextLink = "/user/" + id + "?max_id=" + pg.MaxID } - commonData, err := svc.getCommonData(ctx, client, c, user.DisplayName) + commonData, err := svc.getCommonData(ctx, c, user.DisplayName) if err != nil { return } @@ -525,22 +557,15 @@ func (svc *service) ServeUserPage(ctx context.Context, client io.Writer, c *mode data := &renderer.UserData{ User: user, Statuses: statuses, - HasNext: hasNext, NextLink: nextLink, CommonData: commonData, } rCtx := getRendererContext(c) - - err = svc.renderer.RenderUserPage(rCtx, client, data) - if err != nil { - return - } - - return + return svc.renderer.RenderUserPage(rCtx, c.Writer, data) } -func (svc *service) ServeAboutPage(ctx context.Context, client io.Writer, c *model.Client) (err error) { - commonData, err := svc.getCommonData(ctx, client, c, "about") +func (svc *service) ServeAboutPage(ctx context.Context, c *model.Client) (err error) { + commonData, err := svc.getCommonData(ctx, c, "about") if err != nil { return } @@ -548,18 +573,13 @@ func (svc *service) ServeAboutPage(ctx context.Context, client io.Writer, c *mod data := &renderer.AboutData{ CommonData: commonData, } + rCtx := getRendererContext(c) - - err = svc.renderer.RenderAboutPage(rCtx, client, data) - if err != nil { - return - } - - return + return svc.renderer.RenderAboutPage(rCtx, c.Writer, data) } -func (svc *service) ServeEmojiPage(ctx context.Context, client io.Writer, c *model.Client) (err error) { - commonData, err := svc.getCommonData(ctx, client, c, "emojis") +func (svc *service) ServeEmojiPage(ctx context.Context, c *model.Client) (err error) { + commonData, err := svc.getCommonData(ctx, c, "emojis") if err != nil { return } @@ -573,174 +593,33 @@ func (svc *service) ServeEmojiPage(ctx context.Context, client io.Writer, c *mod Emojis: emojis, CommonData: commonData, } + rCtx := getRendererContext(c) - - err = svc.renderer.RenderEmojiPage(rCtx, client, data) - if err != nil { - return - } - - return + return svc.renderer.RenderEmojiPage(rCtx, c.Writer, data) } -func (svc *service) ServeLikedByPage(ctx context.Context, client io.Writer, c *model.Client, id string) (err error) { - likers, err := c.GetFavouritedBy(ctx, id, nil) - if err != nil { - return - } +func (svc *service) ServeSearchPage(ctx context.Context, c *model.Client, + q string, qType string, offset int) (err error) { - commonData, err := svc.getCommonData(ctx, client, c, "likes") - if err != nil { - return - } - - data := &renderer.LikedByData{ - CommonData: commonData, - Users: likers, - } - rCtx := getRendererContext(c) - - err = svc.renderer.RenderLikedByPage(rCtx, client, data) - if err != nil { - return - } - - return -} - -func (svc *service) ServeRetweetedByPage(ctx context.Context, client io.Writer, c *model.Client, id string) (err error) { - retweeters, err := c.GetRebloggedBy(ctx, id, nil) - if err != nil { - return - } - - commonData, err := svc.getCommonData(ctx, client, c, "retweets") - if err != nil { - return - } - - data := &renderer.RetweetedByData{ - CommonData: commonData, - Users: retweeters, - } - rCtx := getRendererContext(c) - - err = svc.renderer.RenderRetweetedByPage(rCtx, client, data) - if err != nil { - return - } - - return -} - -func (svc *service) ServeFollowingPage(ctx context.Context, client io.Writer, c *model.Client, id string, maxID string, minID string) (err error) { - var hasNext bool - var nextLink string - - var pg = mastodon.Pagination{ - MaxID: maxID, - MinID: minID, - Limit: 20, - } - - followings, err := c.GetAccountFollowing(ctx, id, &pg) - if err != nil { - return - } - - if len(followings) == 20 && len(pg.MaxID) > 0 { - hasNext = true - nextLink = "/following/" + id + "?max_id=" + pg.MaxID - } - - commonData, err := svc.getCommonData(ctx, client, c, "following") - if err != nil { - return - } - - data := &renderer.FollowingData{ - CommonData: commonData, - Users: followings, - HasNext: hasNext, - NextLink: nextLink, - } - rCtx := getRendererContext(c) - - err = svc.renderer.RenderFollowingPage(rCtx, client, data) - if err != nil { - return - } - - return -} - -func (svc *service) ServeFollowersPage(ctx context.Context, client io.Writer, c *model.Client, id string, maxID string, minID string) (err error) { - var hasNext bool - var nextLink string - - var pg = mastodon.Pagination{ - MaxID: maxID, - MinID: minID, - Limit: 20, - } - - followers, err := c.GetAccountFollowers(ctx, id, &pg) - if err != nil { - return - } - - if len(followers) == 20 && len(pg.MaxID) > 0 { - hasNext = true - nextLink = "/followers/" + id + "?max_id=" + pg.MaxID - } - - commonData, err := svc.getCommonData(ctx, client, c, "followers") - if err != nil { - return - } - - data := &renderer.FollowersData{ - CommonData: commonData, - Users: followers, - HasNext: hasNext, - NextLink: nextLink, - } - rCtx := getRendererContext(c) - - err = svc.renderer.RenderFollowersPage(rCtx, client, data) - if err != nil { - return - } - - return -} - -func (svc *service) ServeSearchPage(ctx context.Context, client io.Writer, c *model.Client, q string, qType string, offset int) (err error) { - var hasNext bool var nextLink string + var title = "search" results, err := c.Search(ctx, q, qType, 20, true, offset) if err != nil { return } - switch qType { - case "accounts": - hasNext = len(results.Accounts) == 20 - case "statuses": - hasNext = len(results.Statuses) == 20 - } - - if hasNext { + if (qType == "accounts" && len(results.Accounts) == 20) || + (qType == "statuses" && len(results.Statuses) == 20) { offset += 20 nextLink = fmt.Sprintf("/search?q=%s&type=%s&offset=%d", q, qType, offset) } - var title = "search" if len(q) > 0 { title += " \"" + q + "\"" } - commonData, err := svc.getCommonData(ctx, client, c, title) + + commonData, err := svc.getCommonData(ctx, c, title) if err != nil { return } @@ -751,21 +630,15 @@ func (svc *service) ServeSearchPage(ctx context.Context, client io.Writer, c *mo Type: qType, Users: results.Accounts, Statuses: results.Statuses, - HasNext: hasNext, NextLink: nextLink, } + rCtx := getRendererContext(c) - - err = svc.renderer.RenderSearchPage(rCtx, client, data) - if err != nil { - return - } - - return + return svc.renderer.RenderSearchPage(rCtx, c.Writer, data) } -func (svc *service) ServeSettingsPage(ctx context.Context, client io.Writer, c *model.Client) (err error) { - commonData, err := svc.getCommonData(ctx, client, c, "settings") +func (svc *service) ServeSettingsPage(ctx context.Context, c *model.Client) (err error) { + commonData, err := svc.getCommonData(ctx, c, "settings") if err != nil { return } @@ -774,122 +647,125 @@ func (svc *service) ServeSettingsPage(ctx context.Context, client io.Writer, c * CommonData: commonData, Settings: &c.Session.Settings, } + rCtx := getRendererContext(c) - - err = svc.renderer.RenderSettingsPage(rCtx, client, data) - if err != nil { - return - } - - return + return svc.renderer.RenderSettingsPage(rCtx, c.Writer, data) } -func (svc *service) SaveSettings(ctx context.Context, client io.Writer, c *model.Client, settings *model.Settings) (err error) { - session, err := svc.sessionRepo.Get(c.Session.ID) +func (svc *service) NewSession(ctx context.Context, instance string) ( + redirectUrl string, sessionID string, err error) { + + var instanceURL string + if strings.HasPrefix(instance, "https://") { + instanceURL = instance + instance = strings.TrimPrefix(instance, "https://") + } else { + instanceURL = "https://" + instance + } + + sessionID, err = util.NewSessionID() if err != nil { return } - session.Settings = *settings + csrfToken, err := util.NewCSRFToken() + if err != nil { + return + } + + session := model.Session{ + ID: sessionID, + InstanceDomain: instance, + CSRFToken: csrfToken, + Settings: *model.NewSettings(), + } + err = svc.sessionRepo.Add(session) if err != nil { return } - return -} + app, err := svc.appRepo.Get(instance) + if err != nil { + if err != model.ErrAppNotFound { + return + } -func (svc *service) getCommonData(ctx context.Context, client io.Writer, c *model.Client, title string) (data *renderer.CommonData, err error) { - data = new(renderer.CommonData) - - data.HeaderData = &renderer.HeaderData{ - Title: title + " - " + svc.clientName, - NotificationCount: 0, - CustomCSS: svc.customCSS, - } - - if c != nil && c.Session.IsLoggedIn() { - notifications, err := c.GetNotifications(ctx, nil) + mastoApp, err := mastodon.RegisterApp(ctx, &mastodon.AppConfig{ + Server: instanceURL, + ClientName: svc.clientName, + Scopes: svc.clientScope, + Website: svc.clientWebsite, + RedirectURIs: svc.clientWebsite + "/oauth_callback", + }) if err != nil { - return nil, err + return "", "", err } - var notificationCount int - for i := range notifications { - if notifications[i].Pleroma != nil && !notifications[i].Pleroma.IsSeen { - notificationCount++ - } + app = model.App{ + InstanceDomain: instance, + InstanceURL: instanceURL, + ClientID: mastoApp.ClientID, + ClientSecret: mastoApp.ClientSecret, } - u, err := c.GetAccountCurrentUser(ctx) + err = svc.appRepo.Add(app) if err != nil { - return nil, err + return "", "", err } - - data.NavbarData = &renderer.NavbarData{ - User: u, - NotificationCount: notificationCount, - } - - data.HeaderData.NotificationCount = notificationCount - data.HeaderData.CSRFToken = c.Session.CSRFToken } - return -} - -func (svc *service) Like(ctx context.Context, client io.Writer, c *model.Client, id string) (count int64, err error) { - s, err := c.Favourite(ctx, id) + u, err := url.Parse("/oauth/authorize") if err != nil { return } - count = s.FavouritesCount + + q := make(url.Values) + q.Set("scope", "read write follow") + q.Set("client_id", app.ClientID) + q.Set("response_type", "code") + q.Set("redirect_uri", svc.clientWebsite+"/oauth_callback") + u.RawQuery = q.Encode() + + redirectUrl = instanceURL + u.String() + return } -func (svc *service) UnLike(ctx context.Context, client io.Writer, c *model.Client, id string) (count int64, err error) { - s, err := c.Unfavourite(ctx, id) +func (svc *service) Signin(ctx context.Context, c *model.Client, + sessionID string, code string) (token string, err error) { + + if len(code) < 1 { + err = errInvalidArgument + return + } + + err = c.AuthenticateToken(ctx, code, svc.clientWebsite+"/oauth_callback") if err != nil { return } - count = s.FavouritesCount + token = c.GetAccessToken(ctx) + return } -func (svc *service) Retweet(ctx context.Context, client io.Writer, c *model.Client, id string) (count int64, err error) { - s, err := c.Reblog(ctx, id) - if err != nil { - return - } - if s.Reblog != nil { - count = s.Reblog.ReblogsCount - } - return -} +func (svc *service) Post(ctx context.Context, c *model.Client, content string, + replyToID string, format string, visibility string, isNSFW bool, + files []*multipart.FileHeader) (id string, err error) { -func (svc *service) UnRetweet(ctx context.Context, client io.Writer, c *model.Client, id string) (count int64, err error) { - s, err := c.Unreblog(ctx, id) - if err != nil { - return - } - count = s.ReblogsCount - return -} - -func (svc *service) PostTweet(ctx context.Context, client io.Writer, c *model.Client, content string, replyToID string, format string, visibility string, isNSFW bool, files []*multipart.FileHeader) (id string, err error) { - var mediaIds []string + var mediaIDs []string for _, f := range files { a, err := c.UploadMediaFromMultipartFileHeader(ctx, f) if err != nil { return "", err } - mediaIds = append(mediaIds, a.ID) + mediaIDs = append(mediaIDs, a.ID) } tweet := &mastodon.Toot{ Status: content, InReplyToID: replyToID, - MediaIDs: mediaIds, + MediaIDs: mediaIDs, ContentType: format, Visibility: visibility, Sensitive: isNSFW, @@ -903,29 +779,66 @@ func (svc *service) PostTweet(ctx context.Context, client io.Writer, c *model.Cl return s.ID, nil } -func (svc *service) Follow(ctx context.Context, client io.Writer, c *model.Client, id string) (err error) { +func (svc *service) Like(ctx context.Context, c *model.Client, id string) ( + count int64, err error) { + s, err := c.Favourite(ctx, id) + if err != nil { + return + } + count = s.FavouritesCount + return +} + +func (svc *service) UnLike(ctx context.Context, c *model.Client, id string) ( + count int64, err error) { + s, err := c.Unfavourite(ctx, id) + if err != nil { + return + } + count = s.FavouritesCount + return +} + +func (svc *service) Retweet(ctx context.Context, c *model.Client, id string) ( + count int64, err error) { + s, err := c.Reblog(ctx, id) + if err != nil { + return + } + if s.Reblog != nil { + count = s.Reblog.ReblogsCount + } + return +} + +func (svc *service) UnRetweet(ctx context.Context, c *model.Client, id string) ( + count int64, err error) { + s, err := c.Unreblog(ctx, id) + if err != nil { + return + } + count = s.ReblogsCount + return +} + +func (svc *service) Follow(ctx context.Context, c *model.Client, id string) (err error) { _, err = c.AccountFollow(ctx, id) return } -func (svc *service) UnFollow(ctx context.Context, client io.Writer, c *model.Client, id string) (err error) { +func (svc *service) UnFollow(ctx context.Context, c *model.Client, id string) (err error) { _, err = c.AccountUnfollow(ctx, id) return } -func addToReplyMap(m map[string][]mastodon.ReplyInfo, key interface{}, val string, number int) { - if key == nil { +func (svc *service) SaveSettings(ctx context.Context, c *model.Client, + settings *model.Settings) (err error) { + + session, err := svc.sessionRepo.Get(c.Session.ID) + if err != nil { return } - keyStr, ok := key.(string) - if !ok { - return - } - _, ok = m[keyStr] - if !ok { - m[keyStr] = []mastodon.ReplyInfo{} - } - - m[keyStr] = append(m[keyStr], mastodon.ReplyInfo{val, number}) + session.Settings = *settings + return svc.sessionRepo.Add(session) } diff --git a/service/transport.go b/service/transport.go index e878f8d..fbab2e5 100644 --- a/service/transport.go +++ b/service/transport.go @@ -15,495 +15,14 @@ import ( "github.com/gorilla/mux" ) -var ( - ctx = context.Background() - cookieAge = "31536000" -) - -func NewHandler(s Service, staticDir string) http.Handler { - r := mux.NewRouter() - - r.PathPrefix("/static").Handler(http.StripPrefix("/static", - http.FileServer(http.Dir(path.Join(".", staticDir))))) - - r.HandleFunc("/", func(w http.ResponseWriter, req *http.Request) { - location := "/signin" - - sessionID, _ := req.Cookie("session_id") - if sessionID != nil && len(sessionID.Value) > 0 { - location = "/timeline/home" - } - - w.Header().Add("Location", location) - w.WriteHeader(http.StatusFound) - }).Methods(http.MethodGet) - - r.HandleFunc("/signin", func(w http.ResponseWriter, req *http.Request) { - err := s.ServeSigninPage(ctx, w) - if err != nil { - s.ServeErrorPage(ctx, w, nil, err) - return - } - }).Methods(http.MethodGet) - - r.HandleFunc("/signin", func(w http.ResponseWriter, req *http.Request) { - instance := req.FormValue("instance") - url, sessionID, err := s.GetAuthUrl(ctx, instance) - if err != nil { - s.ServeErrorPage(ctx, w, nil, err) - return - } - - http.SetCookie(w, &http.Cookie{ - Name: "session_id", - Value: sessionID, - Expires: time.Now().Add(365 * 24 * time.Hour), - }) - - w.Header().Add("Location", url) - w.WriteHeader(http.StatusFound) - }).Methods(http.MethodPost) - - r.HandleFunc("/oauth_callback", func(w http.ResponseWriter, req *http.Request) { - ctx := getContextWithSession(context.Background(), req) - token := req.URL.Query().Get("code") - _, err := s.GetUserToken(ctx, "", nil, token) - if err != nil { - s.ServeErrorPage(ctx, w, nil, err) - return - } - - w.Header().Add("Location", "/timeline/home") - w.WriteHeader(http.StatusFound) - }).Methods(http.MethodGet) - - r.HandleFunc("/timeline", func(w http.ResponseWriter, req *http.Request) { - w.Header().Add("Location", "/timeline/home") - w.WriteHeader(http.StatusFound) - }).Methods(http.MethodGet) - - r.HandleFunc("/timeline/{type}", func(w http.ResponseWriter, req *http.Request) { - ctx := getContextWithSession(context.Background(), req) - - timelineType, _ := mux.Vars(req)["type"] - maxID := req.URL.Query().Get("max_id") - sinceID := req.URL.Query().Get("since_id") - minID := req.URL.Query().Get("min_id") - - err := s.ServeTimelinePage(ctx, w, nil, timelineType, maxID, sinceID, minID) - if err != nil { - s.ServeErrorPage(ctx, w, nil, err) - return - } - }).Methods(http.MethodGet) - - r.HandleFunc("/thread/{id}", func(w http.ResponseWriter, req *http.Request) { - ctx := getContextWithSession(context.Background(), req) - id, _ := mux.Vars(req)["id"] - reply := req.URL.Query().Get("reply") - err := s.ServeThreadPage(ctx, w, nil, id, len(reply) > 1) - if err != nil { - s.ServeErrorPage(ctx, w, nil, err) - return - } - }).Methods(http.MethodGet) - - r.HandleFunc("/likedby/{id}", func(w http.ResponseWriter, req *http.Request) { - ctx := getContextWithSession(context.Background(), req) - id, _ := mux.Vars(req)["id"] - - err := s.ServeLikedByPage(ctx, w, nil, id) - if err != nil { - s.ServeErrorPage(ctx, w, nil, err) - return - } - }).Methods(http.MethodGet) - - r.HandleFunc("/retweetedby/{id}", func(w http.ResponseWriter, req *http.Request) { - ctx := getContextWithSession(context.Background(), req) - id, _ := mux.Vars(req)["id"] - - err := s.ServeRetweetedByPage(ctx, w, nil, id) - if err != nil { - s.ServeErrorPage(ctx, w, nil, err) - return - } - }).Methods(http.MethodGet) - - r.HandleFunc("/following/{id}", func(w http.ResponseWriter, req *http.Request) { - ctx := getContextWithSession(context.Background(), req) - - id, _ := mux.Vars(req)["id"] - maxID := req.URL.Query().Get("max_id") - minID := req.URL.Query().Get("min_id") - - err := s.ServeFollowingPage(ctx, w, nil, id, maxID, minID) - if err != nil { - s.ServeErrorPage(ctx, w, nil, err) - return - } - }).Methods(http.MethodGet) - - r.HandleFunc("/followers/{id}", func(w http.ResponseWriter, req *http.Request) { - ctx := getContextWithSession(context.Background(), req) - - id, _ := mux.Vars(req)["id"] - maxID := req.URL.Query().Get("max_id") - minID := req.URL.Query().Get("min_id") - - err := s.ServeFollowersPage(ctx, w, nil, id, maxID, minID) - if err != nil { - s.ServeErrorPage(ctx, w, nil, err) - return - } - }).Methods(http.MethodGet) - - r.HandleFunc("/like/{id}", func(w http.ResponseWriter, req *http.Request) { - ctx := getContextWithSession(context.Background(), req) - ctx = context.WithValue(ctx, "csrf_token", req.FormValue("csrf_token")) - - id, _ := mux.Vars(req)["id"] - retweetedByID := req.FormValue("retweeted_by_id") - - _, err := s.Like(ctx, w, nil, id) - if err != nil { - s.ServeErrorPage(ctx, w, nil, err) - return - } - - rID := id - if len(retweetedByID) > 0 { - rID = retweetedByID - } - w.Header().Add("Location", req.Header.Get("Referer")+"#status-"+rID) - w.WriteHeader(http.StatusFound) - }).Methods(http.MethodPost) - - r.HandleFunc("/unlike/{id}", func(w http.ResponseWriter, req *http.Request) { - ctx := getContextWithSession(context.Background(), req) - ctx = context.WithValue(ctx, "csrf_token", req.FormValue("csrf_token")) - - id, _ := mux.Vars(req)["id"] - retweetedByID := req.FormValue("retweeted_by_id") - - _, err := s.UnLike(ctx, w, nil, id) - if err != nil { - s.ServeErrorPage(ctx, w, nil, err) - return - } - - rID := id - if len(retweetedByID) > 0 { - rID = retweetedByID - } - w.Header().Add("Location", req.Header.Get("Referer")+"#status-"+rID) - w.WriteHeader(http.StatusFound) - }).Methods(http.MethodPost) - - r.HandleFunc("/retweet/{id}", func(w http.ResponseWriter, req *http.Request) { - ctx := getContextWithSession(context.Background(), req) - ctx = context.WithValue(ctx, "csrf_token", req.FormValue("csrf_token")) - - id, _ := mux.Vars(req)["id"] - retweetedByID := req.FormValue("retweeted_by_id") - - _, err := s.Retweet(ctx, w, nil, id) - if err != nil { - s.ServeErrorPage(ctx, w, nil, err) - return - } - - rID := id - if len(retweetedByID) > 0 { - rID = retweetedByID - } - w.Header().Add("Location", req.Header.Get("Referer")+"#status-"+rID) - w.WriteHeader(http.StatusFound) - }).Methods(http.MethodPost) - - r.HandleFunc("/unretweet/{id}", func(w http.ResponseWriter, req *http.Request) { - ctx := getContextWithSession(context.Background(), req) - ctx = context.WithValue(ctx, "csrf_token", req.FormValue("csrf_token")) - - id, _ := mux.Vars(req)["id"] - retweetedByID := req.FormValue("retweeted_by_id") - - _, err := s.UnRetweet(ctx, w, nil, id) - if err != nil { - s.ServeErrorPage(ctx, w, nil, err) - return - } - - rID := id - if len(retweetedByID) > 0 { - rID = retweetedByID - } - w.Header().Add("Location", req.Header.Get("Referer")+"#status-"+rID) - w.WriteHeader(http.StatusFound) - }).Methods(http.MethodPost) - - r.HandleFunc("/fluoride/like/{id}", func(w http.ResponseWriter, req *http.Request) { - ctx := getContextWithSession(context.Background(), req) - ctx = context.WithValue(ctx, "csrf_token", req.FormValue("csrf_token")) - - id, _ := mux.Vars(req)["id"] - count, err := s.Like(ctx, w, nil, id) - if err != nil { - s.ServeErrorPage(ctx, w, nil, err) - return - } - - err = serveJson(w, count) - if err != nil { - s.ServeErrorPage(ctx, w, nil, err) - return - } - }).Methods(http.MethodPost) - - r.HandleFunc("/fluoride/unlike/{id}", func(w http.ResponseWriter, req *http.Request) { - ctx := getContextWithSession(context.Background(), req) - ctx = context.WithValue(ctx, "csrf_token", req.FormValue("csrf_token")) - - id, _ := mux.Vars(req)["id"] - count, err := s.UnLike(ctx, w, nil, id) - if err != nil { - s.ServeErrorPage(ctx, w, nil, err) - return - } - - err = serveJson(w, count) - if err != nil { - s.ServeErrorPage(ctx, w, nil, err) - return - } - }).Methods(http.MethodPost) - - r.HandleFunc("/fluoride/retweet/{id}", func(w http.ResponseWriter, req *http.Request) { - ctx := getContextWithSession(context.Background(), req) - ctx = context.WithValue(ctx, "csrf_token", req.FormValue("csrf_token")) - - id, _ := mux.Vars(req)["id"] - count, err := s.Retweet(ctx, w, nil, id) - if err != nil { - s.ServeErrorPage(ctx, w, nil, err) - return - } - - err = serveJson(w, count) - if err != nil { - s.ServeErrorPage(ctx, w, nil, err) - return - } - }).Methods(http.MethodPost) - - r.HandleFunc("/fluoride/unretweet/{id}", func(w http.ResponseWriter, req *http.Request) { - ctx := getContextWithSession(context.Background(), req) - ctx = context.WithValue(ctx, "csrf_token", req.FormValue("csrf_token")) - - id, _ := mux.Vars(req)["id"] - count, err := s.UnRetweet(ctx, w, nil, id) - if err != nil { - s.ServeErrorPage(ctx, w, nil, err) - return - } - - err = serveJson(w, count) - if err != nil { - s.ServeErrorPage(ctx, w, nil, err) - return - } - }).Methods(http.MethodPost) - - r.HandleFunc("/post", func(w http.ResponseWriter, req *http.Request) { - err := req.ParseMultipartForm(4 << 20) - if err != nil { - s.ServeErrorPage(ctx, w, nil, err) - return - } - - ctx := getContextWithSession(context.Background(), req) - ctx = context.WithValue(ctx, "csrf_token", - getMultipartFormValue(req.MultipartForm, "csrf_token")) - - content := getMultipartFormValue(req.MultipartForm, "content") - replyToID := getMultipartFormValue(req.MultipartForm, "reply_to_id") - format := getMultipartFormValue(req.MultipartForm, "format") - visibility := getMultipartFormValue(req.MultipartForm, "visibility") - isNSFW := "on" == getMultipartFormValue(req.MultipartForm, "is_nsfw") - - files := req.MultipartForm.File["attachments"] - - id, err := s.PostTweet(ctx, w, nil, content, replyToID, format, visibility, isNSFW, files) - if err != nil { - s.ServeErrorPage(ctx, w, nil, err) - return - } - - location := "/timeline/home" + "#status-" + id - if len(replyToID) > 0 { - location = "/thread/" + replyToID + "#status-" + id - } - w.Header().Add("Location", location) - w.WriteHeader(http.StatusFound) - }).Methods(http.MethodPost) - - r.HandleFunc("/notifications", func(w http.ResponseWriter, req *http.Request) { - ctx := getContextWithSession(context.Background(), req) - - maxID := req.URL.Query().Get("max_id") - minID := req.URL.Query().Get("min_id") - - err := s.ServeNotificationPage(ctx, w, nil, maxID, minID) - if err != nil { - s.ServeErrorPage(ctx, w, nil, err) - return - } - }).Methods(http.MethodGet) - - r.HandleFunc("/user/{id}", func(w http.ResponseWriter, req *http.Request) { - ctx := getContextWithSession(context.Background(), req) - - id, _ := mux.Vars(req)["id"] - maxID := req.URL.Query().Get("max_id") - minID := req.URL.Query().Get("min_id") - - err := s.ServeUserPage(ctx, w, nil, id, maxID, minID) - if err != nil { - s.ServeErrorPage(ctx, w, nil, err) - return - } - }).Methods(http.MethodGet) - - r.HandleFunc("/follow/{id}", func(w http.ResponseWriter, req *http.Request) { - ctx := getContextWithSession(context.Background(), req) - ctx = context.WithValue(ctx, "csrf_token", req.FormValue("csrf_token")) - - id, _ := mux.Vars(req)["id"] - - err := s.Follow(ctx, w, nil, id) - if err != nil { - s.ServeErrorPage(ctx, w, nil, err) - return - } - - w.Header().Add("Location", req.Header.Get("Referer")) - w.WriteHeader(http.StatusFound) - }).Methods(http.MethodPost) - - r.HandleFunc("/unfollow/{id}", func(w http.ResponseWriter, req *http.Request) { - ctx := getContextWithSession(context.Background(), req) - ctx = context.WithValue(ctx, "csrf_token", req.FormValue("csrf_token")) - - id, _ := mux.Vars(req)["id"] - - err := s.UnFollow(ctx, w, nil, id) - if err != nil { - s.ServeErrorPage(ctx, w, nil, err) - return - } - - w.Header().Add("Location", req.Header.Get("Referer")) - w.WriteHeader(http.StatusFound) - }).Methods(http.MethodPost) - - r.HandleFunc("/about", func(w http.ResponseWriter, req *http.Request) { - ctx := getContextWithSession(context.Background(), req) - - err := s.ServeAboutPage(ctx, w, nil) - if err != nil { - s.ServeErrorPage(ctx, w, nil, err) - return - } - }).Methods(http.MethodGet) - - r.HandleFunc("/emojis", func(w http.ResponseWriter, req *http.Request) { - ctx := getContextWithSession(context.Background(), req) - - err := s.ServeEmojiPage(ctx, w, nil) - if err != nil { - s.ServeErrorPage(ctx, w, nil, err) - return - } - }).Methods(http.MethodGet) - - r.HandleFunc("/search", func(w http.ResponseWriter, req *http.Request) { - ctx := getContextWithSession(context.Background(), req) - - q := req.URL.Query().Get("q") - qType := req.URL.Query().Get("type") - offsetStr := req.URL.Query().Get("offset") - - var offset int - var err error - if len(offsetStr) > 1 { - offset, err = strconv.Atoi(offsetStr) - if err != nil { - s.ServeErrorPage(ctx, w, nil, err) - return - } - } - - err = s.ServeSearchPage(ctx, w, nil, q, qType, offset) - if err != nil { - s.ServeErrorPage(ctx, w, nil, err) - return - } - }).Methods(http.MethodGet) - - r.HandleFunc("/settings", func(w http.ResponseWriter, req *http.Request) { - ctx := getContextWithSession(context.Background(), req) - - err := s.ServeSettingsPage(ctx, w, nil) - if err != nil { - s.ServeErrorPage(ctx, w, nil, err) - return - } - }).Methods(http.MethodGet) - - r.HandleFunc("/settings", func(w http.ResponseWriter, req *http.Request) { - ctx := getContextWithSession(context.Background(), req) - ctx = context.WithValue(ctx, "csrf_token", req.FormValue("csrf_token")) - - visibility := req.FormValue("visibility") - copyScope := req.FormValue("copy_scope") == "true" - threadInNewTab := req.FormValue("thread_in_new_tab") == "true" - maskNSFW := req.FormValue("mask_nsfw") == "true" - fluorideMode := req.FormValue("fluoride_mode") == "true" - darkMode := req.FormValue("dark_mode") == "true" - settings := &model.Settings{ - DefaultVisibility: visibility, - CopyScope: copyScope, - ThreadInNewTab: threadInNewTab, - MaskNSFW: maskNSFW, - FluorideMode: fluorideMode, - DarkMode: darkMode, - } - - err := s.SaveSettings(ctx, w, nil, settings) - if err != nil { - s.ServeErrorPage(ctx, w, nil, err) - return - } - - w.Header().Add("Location", req.Header.Get("Referer")) - w.WriteHeader(http.StatusFound) - }).Methods(http.MethodPost) - - r.HandleFunc("/signout", func(w http.ResponseWriter, req *http.Request) { - // TODO remove session from database - http.SetCookie(w, &http.Cookie{ - Name: "session_id", - Value: "", - Expires: time.Now(), - }) - w.Header().Add("Location", "/") - w.WriteHeader(http.StatusFound) - }).Methods(http.MethodGet) - - return r +func newClient(w io.Writer) *model.Client { + return &model.Client{ + Writer: w, + } } -func getContextWithSession(ctx context.Context, req *http.Request) context.Context { +func newCtxWithSesion(req *http.Request) context.Context { + ctx := context.Background() sessionID, err := req.Cookie("session_id") if err != nil { return ctx @@ -511,6 +30,11 @@ func getContextWithSession(ctx context.Context, req *http.Request) context.Conte return context.WithValue(ctx, "session_id", sessionID.Value) } +func newCtxWithSesionCSRF(req *http.Request, csrfToken string) context.Context { + ctx := newCtxWithSesion(req) + return context.WithValue(ctx, "csrf_token", csrfToken) +} + func getMultipartFormValue(mf *multipart.Form, key string) (val string) { vals, ok := mf.Value[key] if !ok { @@ -527,3 +51,521 @@ func serveJson(w io.Writer, data interface{}) (err error) { d["data"] = data return json.NewEncoder(w).Encode(d) } + +func NewHandler(s Service, staticDir string) http.Handler { + r := mux.NewRouter() + + rootPage := func(w http.ResponseWriter, req *http.Request) { + sessionID, _ := req.Cookie("session_id") + + location := "/signin" + if sessionID != nil && len(sessionID.Value) > 0 { + location = "/timeline/home" + } + + w.Header().Add("Location", location) + w.WriteHeader(http.StatusFound) + } + + signinPage := func(w http.ResponseWriter, req *http.Request) { + c := newClient(w) + ctx := context.Background() + err := s.ServeSigninPage(ctx, c) + if err != nil { + s.ServeErrorPage(ctx, c, err) + return + } + } + + timelinePage := func(w http.ResponseWriter, req *http.Request) { + c := newClient(w) + ctx := newCtxWithSesion(req) + tType, _ := mux.Vars(req)["type"] + maxID := req.URL.Query().Get("max_id") + minID := req.URL.Query().Get("min_id") + + err := s.ServeTimelinePage(ctx, c, tType, maxID, minID) + if err != nil { + s.ServeErrorPage(ctx, c, err) + return + } + } + + timelineOldPage := func(w http.ResponseWriter, req *http.Request) { + w.Header().Add("Location", "/timeline/home") + w.WriteHeader(http.StatusFound) + } + + threadPage := func(w http.ResponseWriter, req *http.Request) { + c := newClient(w) + ctx := newCtxWithSesion(req) + id, _ := mux.Vars(req)["id"] + reply := req.URL.Query().Get("reply") + + err := s.ServeThreadPage(ctx, c, id, len(reply) > 1) + if err != nil { + s.ServeErrorPage(ctx, c, err) + return + } + } + + likedByPage := func(w http.ResponseWriter, req *http.Request) { + c := newClient(w) + ctx := newCtxWithSesion(req) + id, _ := mux.Vars(req)["id"] + + err := s.ServeLikedByPage(ctx, c, id) + if err != nil { + s.ServeErrorPage(ctx, c, err) + return + } + } + + retweetedByPage := func(w http.ResponseWriter, req *http.Request) { + c := newClient(w) + ctx := newCtxWithSesion(req) + id, _ := mux.Vars(req)["id"] + + err := s.ServeRetweetedByPage(ctx, c, id) + if err != nil { + s.ServeErrorPage(ctx, c, err) + return + } + } + + followingPage := func(w http.ResponseWriter, req *http.Request) { + c := newClient(w) + ctx := newCtxWithSesion(req) + id, _ := mux.Vars(req)["id"] + maxID := req.URL.Query().Get("max_id") + minID := req.URL.Query().Get("min_id") + + err := s.ServeFollowingPage(ctx, c, id, maxID, minID) + if err != nil { + s.ServeErrorPage(ctx, c, err) + return + } + } + + followersPage := func(w http.ResponseWriter, req *http.Request) { + c := newClient(w) + ctx := newCtxWithSesion(req) + id, _ := mux.Vars(req)["id"] + maxID := req.URL.Query().Get("max_id") + minID := req.URL.Query().Get("min_id") + + err := s.ServeFollowersPage(ctx, c, id, maxID, minID) + if err != nil { + s.ServeErrorPage(ctx, c, err) + return + } + } + + notificationsPage := func(w http.ResponseWriter, req *http.Request) { + c := newClient(w) + ctx := newCtxWithSesion(req) + maxID := req.URL.Query().Get("max_id") + minID := req.URL.Query().Get("min_id") + + err := s.ServeNotificationPage(ctx, c, maxID, minID) + if err != nil { + s.ServeErrorPage(ctx, c, err) + return + } + } + + userPage := func(w http.ResponseWriter, req *http.Request) { + c := newClient(w) + ctx := newCtxWithSesion(req) + id, _ := mux.Vars(req)["id"] + maxID := req.URL.Query().Get("max_id") + minID := req.URL.Query().Get("min_id") + + err := s.ServeUserPage(ctx, c, id, maxID, minID) + if err != nil { + s.ServeErrorPage(ctx, c, err) + return + } + } + + aboutPage := func(w http.ResponseWriter, req *http.Request) { + c := newClient(w) + ctx := newCtxWithSesion(req) + + err := s.ServeAboutPage(ctx, c) + if err != nil { + s.ServeErrorPage(ctx, c, err) + return + } + } + + emojisPage := func(w http.ResponseWriter, req *http.Request) { + c := newClient(w) + ctx := newCtxWithSesion(req) + + err := s.ServeEmojiPage(ctx, c) + if err != nil { + s.ServeErrorPage(ctx, c, err) + return + } + } + + searchPage := func(w http.ResponseWriter, req *http.Request) { + c := newClient(w) + ctx := newCtxWithSesion(req) + q := req.URL.Query().Get("q") + qType := req.URL.Query().Get("type") + offsetStr := req.URL.Query().Get("offset") + + var offset int + var err error + if len(offsetStr) > 1 { + offset, err = strconv.Atoi(offsetStr) + if err != nil { + s.ServeErrorPage(ctx, c, err) + return + } + } + + err = s.ServeSearchPage(ctx, c, q, qType, offset) + if err != nil { + s.ServeErrorPage(ctx, c, err) + return + } + } + + settingsPage := func(w http.ResponseWriter, req *http.Request) { + c := newClient(w) + ctx := newCtxWithSesion(req) + + err := s.ServeSettingsPage(ctx, c) + if err != nil { + s.ServeErrorPage(ctx, c, err) + return + } + } + + signin := func(w http.ResponseWriter, req *http.Request) { + c := newClient(w) + ctx := context.Background() + instance := req.FormValue("instance") + + url, sessionID, err := s.NewSession(ctx, instance) + if err != nil { + s.ServeErrorPage(ctx, c, err) + return + } + + http.SetCookie(w, &http.Cookie{ + Name: "session_id", + Value: sessionID, + Expires: time.Now().Add(365 * 24 * time.Hour), + }) + + w.Header().Add("Location", url) + w.WriteHeader(http.StatusFound) + } + + oauthCallback := func(w http.ResponseWriter, req *http.Request) { + c := newClient(w) + ctx := newCtxWithSesion(req) + token := req.URL.Query().Get("code") + + _, err := s.Signin(ctx, c, "", token) + if err != nil { + s.ServeErrorPage(ctx, c, err) + return + } + + w.Header().Add("Location", "/timeline/home") + w.WriteHeader(http.StatusFound) + } + + post := func(w http.ResponseWriter, req *http.Request) { + c := newClient(w) + err := req.ParseMultipartForm(4 << 20) + if err != nil { + s.ServeErrorPage(context.Background(), c, err) + return + } + + ctx := newCtxWithSesionCSRF(req, + getMultipartFormValue(req.MultipartForm, "csrf_token")) + content := getMultipartFormValue(req.MultipartForm, "content") + replyToID := getMultipartFormValue(req.MultipartForm, "reply_to_id") + format := getMultipartFormValue(req.MultipartForm, "format") + visibility := getMultipartFormValue(req.MultipartForm, "visibility") + isNSFW := "on" == getMultipartFormValue(req.MultipartForm, "is_nsfw") + files := req.MultipartForm.File["attachments"] + + id, err := s.Post(ctx, c, content, replyToID, format, visibility, isNSFW, files) + if err != nil { + s.ServeErrorPage(ctx, c, err) + return + } + + location := "/timeline/home" + "#status-" + id + if len(replyToID) > 0 { + location = "/thread/" + replyToID + "#status-" + id + } + w.Header().Add("Location", location) + w.WriteHeader(http.StatusFound) + } + + like := func(w http.ResponseWriter, req *http.Request) { + c := newClient(w) + ctx := newCtxWithSesionCSRF(req, req.FormValue("csrf_token")) + id, _ := mux.Vars(req)["id"] + retweetedByID := req.FormValue("retweeted_by_id") + + _, err := s.Like(ctx, c, id) + if err != nil { + s.ServeErrorPage(ctx, c, err) + return + } + + rID := id + if len(retweetedByID) > 0 { + rID = retweetedByID + } + w.Header().Add("Location", req.Header.Get("Referer")+"#status-"+rID) + w.WriteHeader(http.StatusFound) + } + + unlike := func(w http.ResponseWriter, req *http.Request) { + c := newClient(w) + ctx := newCtxWithSesionCSRF(req, req.FormValue("csrf_token")) + id, _ := mux.Vars(req)["id"] + retweetedByID := req.FormValue("retweeted_by_id") + + _, err := s.UnLike(ctx, c, id) + if err != nil { + s.ServeErrorPage(ctx, c, err) + return + } + + rID := id + if len(retweetedByID) > 0 { + rID = retweetedByID + } + w.Header().Add("Location", req.Header.Get("Referer")+"#status-"+rID) + w.WriteHeader(http.StatusFound) + } + + retweet := func(w http.ResponseWriter, req *http.Request) { + c := newClient(w) + ctx := newCtxWithSesionCSRF(req, req.FormValue("csrf_token")) + id, _ := mux.Vars(req)["id"] + retweetedByID := req.FormValue("retweeted_by_id") + + _, err := s.Retweet(ctx, c, id) + if err != nil { + s.ServeErrorPage(ctx, c, err) + return + } + + rID := id + if len(retweetedByID) > 0 { + rID = retweetedByID + } + w.Header().Add("Location", req.Header.Get("Referer")+"#status-"+rID) + w.WriteHeader(http.StatusFound) + } + + unretweet := func(w http.ResponseWriter, req *http.Request) { + c := newClient(w) + ctx := newCtxWithSesionCSRF(req, req.FormValue("csrf_token")) + id, _ := mux.Vars(req)["id"] + retweetedByID := req.FormValue("retweeted_by_id") + + _, err := s.UnRetweet(ctx, c, id) + if err != nil { + s.ServeErrorPage(ctx, c, err) + return + } + + rID := id + if len(retweetedByID) > 0 { + rID = retweetedByID + } + + w.Header().Add("Location", req.Header.Get("Referer")+"#status-"+rID) + w.WriteHeader(http.StatusFound) + } + + follow := func(w http.ResponseWriter, req *http.Request) { + c := newClient(w) + ctx := newCtxWithSesionCSRF(req, req.FormValue("csrf_token")) + id, _ := mux.Vars(req)["id"] + + err := s.Follow(ctx, c, id) + if err != nil { + s.ServeErrorPage(ctx, c, err) + return + } + + w.Header().Add("Location", req.Header.Get("Referer")) + w.WriteHeader(http.StatusFound) + } + + unfollow := func(w http.ResponseWriter, req *http.Request) { + c := newClient(w) + ctx := newCtxWithSesionCSRF(req, req.FormValue("csrf_token")) + id, _ := mux.Vars(req)["id"] + + err := s.UnFollow(ctx, c, id) + if err != nil { + s.ServeErrorPage(ctx, c, err) + return + } + + w.Header().Add("Location", req.Header.Get("Referer")) + w.WriteHeader(http.StatusFound) + } + + settings := func(w http.ResponseWriter, req *http.Request) { + c := newClient(w) + ctx := newCtxWithSesionCSRF(req, req.FormValue("csrf_token")) + visibility := req.FormValue("visibility") + copyScope := req.FormValue("copy_scope") == "true" + threadInNewTab := req.FormValue("thread_in_new_tab") == "true" + maskNSFW := req.FormValue("mask_nsfw") == "true" + fluorideMode := req.FormValue("fluoride_mode") == "true" + darkMode := req.FormValue("dark_mode") == "true" + + settings := &model.Settings{ + DefaultVisibility: visibility, + CopyScope: copyScope, + ThreadInNewTab: threadInNewTab, + MaskNSFW: maskNSFW, + FluorideMode: fluorideMode, + DarkMode: darkMode, + } + + err := s.SaveSettings(ctx, c, settings) + if err != nil { + s.ServeErrorPage(ctx, c, err) + return + } + + w.Header().Add("Location", req.Header.Get("Referer")) + w.WriteHeader(http.StatusFound) + } + + signout := func(w http.ResponseWriter, req *http.Request) { + // TODO remove session from database + http.SetCookie(w, &http.Cookie{ + Name: "session_id", + Value: "", + Expires: time.Now(), + }) + w.Header().Add("Location", "/") + w.WriteHeader(http.StatusFound) + } + + fLike := func(w http.ResponseWriter, req *http.Request) { + c := newClient(w) + ctx := newCtxWithSesionCSRF(req, req.FormValue("csrf_token")) + id, _ := mux.Vars(req)["id"] + + count, err := s.Like(ctx, c, id) + if err != nil { + s.ServeErrorPage(ctx, c, err) + return + } + + err = serveJson(w, count) + if err != nil { + s.ServeErrorPage(ctx, c, err) + return + } + } + + fUnlike := func(w http.ResponseWriter, req *http.Request) { + c := newClient(w) + ctx := newCtxWithSesionCSRF(req, req.FormValue("csrf_token")) + id, _ := mux.Vars(req)["id"] + count, err := s.UnLike(ctx, c, id) + if err != nil { + s.ServeErrorPage(ctx, c, err) + return + } + + err = serveJson(w, count) + if err != nil { + s.ServeErrorPage(ctx, c, err) + return + } + } + + fRetweet := func(w http.ResponseWriter, req *http.Request) { + c := newClient(w) + ctx := newCtxWithSesionCSRF(req, req.FormValue("csrf_token")) + id, _ := mux.Vars(req)["id"] + + count, err := s.Retweet(ctx, c, id) + if err != nil { + s.ServeErrorPage(ctx, c, err) + return + } + + err = serveJson(w, count) + if err != nil { + s.ServeErrorPage(ctx, c, err) + return + } + } + + fUnretweet := func(w http.ResponseWriter, req *http.Request) { + c := newClient(w) + ctx := newCtxWithSesionCSRF(req, req.FormValue("csrf_token")) + id, _ := mux.Vars(req)["id"] + + count, err := s.UnRetweet(ctx, c, id) + if err != nil { + s.ServeErrorPage(ctx, c, err) + return + } + + err = serveJson(w, count) + if err != nil { + s.ServeErrorPage(ctx, c, err) + return + } + } + + r.HandleFunc("/", rootPage).Methods(http.MethodGet) + r.HandleFunc("/signin", signinPage).Methods(http.MethodGet) + r.HandleFunc("/timeline/{type}", timelinePage).Methods(http.MethodGet) + r.HandleFunc("/timeline", timelineOldPage).Methods(http.MethodGet) + r.HandleFunc("/thread/{id}", threadPage).Methods(http.MethodGet) + r.HandleFunc("/likedby/{id}", likedByPage).Methods(http.MethodGet) + r.HandleFunc("/retweetedby/{id}", retweetedByPage).Methods(http.MethodGet) + r.HandleFunc("/following/{id}", followingPage).Methods(http.MethodGet) + r.HandleFunc("/followers/{id}", followersPage).Methods(http.MethodGet) + r.HandleFunc("/notifications", notificationsPage).Methods(http.MethodGet) + r.HandleFunc("/user/{id}", userPage).Methods(http.MethodGet) + r.HandleFunc("/about", aboutPage).Methods(http.MethodGet) + r.HandleFunc("/emojis", emojisPage).Methods(http.MethodGet) + r.HandleFunc("/search", searchPage).Methods(http.MethodGet) + r.HandleFunc("/settings", settingsPage).Methods(http.MethodGet) + r.HandleFunc("/signin", signin).Methods(http.MethodPost) + r.HandleFunc("/oauth_callback", oauthCallback).Methods(http.MethodGet) + r.HandleFunc("/post", post).Methods(http.MethodPost) + r.HandleFunc("/like/{id}", like).Methods(http.MethodPost) + r.HandleFunc("/unlike/{id}", unlike).Methods(http.MethodPost) + r.HandleFunc("/retweet/{id}", retweet).Methods(http.MethodPost) + r.HandleFunc("/unretweet/{id}", unretweet).Methods(http.MethodPost) + r.HandleFunc("/follow/{id}", follow).Methods(http.MethodPost) + r.HandleFunc("/unfollow/{id}", unfollow).Methods(http.MethodPost) + r.HandleFunc("/settings", settings).Methods(http.MethodPost) + r.HandleFunc("/signout", signout).Methods(http.MethodGet) + r.HandleFunc("/fluoride/like/{id}", fLike).Methods(http.MethodPost) + r.HandleFunc("/fluoride/unlike/{id}", fUnlike).Methods(http.MethodPost) + r.HandleFunc("/fluoride/retweet/{id}", fRetweet).Methods(http.MethodPost) + r.HandleFunc("/fluoride/unretweet/{id}", fUnretweet).Methods(http.MethodPost) + r.PathPrefix("/static").Handler(http.StripPrefix("/static", + http.FileServer(http.Dir(path.Join(".", staticDir))))) + + return r +} diff --git a/static/custom.css b/static/custom.css deleted file mode 100644 index a1c192a..0000000 --- a/static/custom.css +++ /dev/null @@ -1,3 +0,0 @@ -html { - background: #000000; -} diff --git a/static/main.css b/static/style.css similarity index 100% rename from static/main.css rename to static/style.css diff --git a/templates/followers.tmpl b/templates/followers.tmpl index 8102b26..44a303b 100644 --- a/templates/followers.tmpl +++ b/templates/followers.tmpl @@ -5,7 +5,7 @@ {{template "userlist.tmpl" (WithContext .Users $.Ctx)}} diff --git a/templates/following.tmpl b/templates/following.tmpl index 43a60dc..50413d5 100644 --- a/templates/following.tmpl +++ b/templates/following.tmpl @@ -5,7 +5,7 @@ {{template "userlist.tmpl" (WithContext .Users $.Ctx)}} diff --git a/templates/header.tmpl b/templates/header.tmpl index e6e7f0d..2889ead 100644 --- a/templates/header.tmpl +++ b/templates/header.tmpl @@ -8,7 +8,7 @@ {{end}} {{if gt .NotificationCount 0}}({{.NotificationCount}}) {{end}}{{.Title}} - + {{if .CustomCSS}} {{end}} diff --git a/templates/notification.tmpl b/templates/notification.tmpl index 86134ac..7d0e67c 100644 --- a/templates/notification.tmpl +++ b/templates/notification.tmpl @@ -65,7 +65,7 @@ {{end}} diff --git a/templates/search.tmpl b/templates/search.tmpl index b4cd744..acbfbdd 100644 --- a/templates/search.tmpl +++ b/templates/search.tmpl @@ -31,7 +31,7 @@ {{end}} diff --git a/templates/timeline.tmpl b/templates/timeline.tmpl index aa951fc..0321c7f 100644 --- a/templates/timeline.tmpl +++ b/templates/timeline.tmpl @@ -10,10 +10,10 @@ {{end}} diff --git a/templates/user.tmpl b/templates/user.tmpl index abf22ec..bab24b2 100644 --- a/templates/user.tmpl +++ b/templates/user.tmpl @@ -56,7 +56,7 @@ {{end}} diff --git a/util/rand.go b/util/rand.go index ffe97a0..1e4ec95 100644 --- a/util/rand.go +++ b/util/rand.go @@ -10,7 +10,7 @@ var ( runes_length = len(runes) ) -func NewRandId(n int) (string, error) { +func NewRandID(n int) (string, error) { data := make([]rune, n) for i := range data { num, err := rand.Int(rand.Reader, big.NewInt(int64(runes_length))) @@ -22,10 +22,10 @@ func NewRandId(n int) (string, error) { return string(data), nil } -func NewSessionId() (string, error) { - return NewRandId(24) +func NewSessionID() (string, error) { + return NewRandID(24) } func NewCSRFToken() (string, error) { - return NewRandId(24) + return NewRandID(24) }