Compare commits

..

No commits in common. "5147897c6c8ba3428ea6998f77241182ee8caa24" and "5f688c6318191627775ae44221eb75cb6ed2e905" have entirely different histories.

20 changed files with 548 additions and 292 deletions

10
INSTALL
View File

@ -23,8 +23,16 @@ most cases, you only need to change the value of "client_website".
# cp bloat.gen.conf /etc/bloat.conf # cp bloat.gen.conf /etc/bloat.conf
# $EDITOR /etc/bloat.conf # $EDITOR /etc/bloat.conf
4. Create database directory
Create a directory to store session information. Optionally, create a user
to run bloat and change the ownership of the database directory accordingly.
# mkdir /var/bloat
# useradd _bloat
# chown -R _bloat:_bloat /var/bloat
Replace /var/bloat with the value you specified in the config file.
5. Run the binary 5. Run the binary
$ bloat # su _bloat -c bloat
Now you should create an init script to automatically start bloat at system Now you should create an init script to automatically start bloat at system
startup. startup.

View File

@ -10,6 +10,7 @@ SRC=main.go \
mastodon/*.go \ mastodon/*.go \
model/*.go \ model/*.go \
renderer/*.go \ renderer/*.go \
repo/*.go \
service/*.go \ service/*.go \
util/*.go \ util/*.go \
@ -17,7 +18,8 @@ all: bloat
bloat: $(SRC) $(TMPL) bloat: $(SRC) $(TMPL)
$(GO) build $(GOFLAGS) -o bloat main.go $(GO) build $(GOFLAGS) -o bloat main.go
sed -e "s%=templates%=$(SHAREPATH)/templates%g" \ sed -e "s%=database%=/var/bloat%g" \
-e "s%=templates%=$(SHAREPATH)/templates%g" \
-e "s%=static%=$(SHAREPATH)/static%g" \ -e "s%=static%=$(SHAREPATH)/static%g" \
< bloat.conf > bloat.gen.conf < bloat.conf > bloat.gen.conf

View File

@ -3,6 +3,10 @@
# - Key and Value are separated by a single '=' # - Key and Value are separated by a single '='
# - Leading and trailing white spaces in Key and Value are ignored # - Leading and trailing white spaces in Key and Value are ignored
# - Quoting and multi-line values are not supported # - Quoting and multi-line values are not supported
#
# Changing values of client_name, client_scope or client_website will cause
# previously generated access tokens and client tokens to be invalid. Issue the
# `rm -r database_path/*` command to clean the database afterwards.
# Address to listen to. Value can be of "HOSTNAME:PORT" or "IP:PORT" form. In # Address to listen to. Value can be of "HOSTNAME:PORT" or "IP:PORT" form. In
# case of empty HOSTNAME or IP, "0.0.0.0:PORT" is used. # case of empty HOSTNAME or IP, "0.0.0.0:PORT" is used.
@ -21,6 +25,9 @@ client_name=bloat
# See https://docs.joinmastodon.org/api/oauth-scopes/ # See https://docs.joinmastodon.org/api/oauth-scopes/
client_scope=read write follow client_scope=read write follow
# Path of database directory. It's used to store session information.
database_path=database
# Path of directory containing template files. # Path of directory containing template files.
templates_path=templates templates_path=templates

View File

@ -18,6 +18,7 @@ type config struct {
SingleInstance string SingleInstance string
StaticDirectory string StaticDirectory string
TemplatesPath string TemplatesPath string
DatabasePath string
CustomCSS string CustomCSS string
PostFormats []model.PostFormat PostFormats []model.PostFormat
LogFile string LogFile string
@ -29,7 +30,8 @@ func (c *config) IsValid() bool {
len(c.ClientScope) < 1 || len(c.ClientScope) < 1 ||
len(c.ClientWebsite) < 1 || len(c.ClientWebsite) < 1 ||
len(c.StaticDirectory) < 1 || len(c.StaticDirectory) < 1 ||
len(c.TemplatesPath) < 1 { len(c.TemplatesPath) < 1 ||
len(c.DatabasePath) < 1 {
return false return false
} }
return true return true
@ -73,10 +75,10 @@ func Parse(r io.Reader) (c *config, err error) {
c.StaticDirectory = val c.StaticDirectory = val
case "templates_path": case "templates_path":
c.TemplatesPath = val c.TemplatesPath = val
case "database_path":
c.DatabasePath = val
case "custom_css": case "custom_css":
c.CustomCSS = val c.CustomCSS = val
case "database_path":
// ignore
case "post_formats": case "post_formats":
vals := strings.Split(val, ",") vals := strings.Split(val, ",")
var formats []model.PostFormat var formats []model.PostFormat

24
main.go
View File

@ -12,7 +12,9 @@ import (
"bloat/config" "bloat/config"
"bloat/renderer" "bloat/renderer"
"bloat/repo"
"bloat/service" "bloat/service"
"bloat/util"
) )
var ( var (
@ -46,6 +48,26 @@ func main() {
errExit(err) errExit(err)
} }
err = os.Mkdir(config.DatabasePath, 0755)
if err != nil && !os.IsExist(err) {
errExit(err)
}
sessionDBPath := filepath.Join(config.DatabasePath, "session")
sessionDB, err := util.NewDatabse(sessionDBPath)
if err != nil {
errExit(err)
}
appDBPath := filepath.Join(config.DatabasePath, "app")
appDB, err := util.NewDatabse(appDBPath)
if err != nil {
errExit(err)
}
sessionRepo := repo.NewSessionRepo(sessionDB)
appRepo := repo.NewAppRepo(appDB)
customCSS := config.CustomCSS customCSS := config.CustomCSS
if len(customCSS) > 0 && !strings.HasPrefix(customCSS, "http://") && if len(customCSS) > 0 && !strings.HasPrefix(customCSS, "http://") &&
!strings.HasPrefix(customCSS, "https://") { !strings.HasPrefix(customCSS, "https://") {
@ -67,7 +89,7 @@ func main() {
s := service.NewService(config.ClientName, config.ClientScope, s := service.NewService(config.ClientName, config.ClientScope,
config.ClientWebsite, customCSS, config.SingleInstance, config.ClientWebsite, customCSS, config.SingleInstance,
config.PostFormats, renderer) config.PostFormats, renderer, sessionRepo, appRepo)
handler := service.NewHandler(s, logger, config.StaticDirectory) handler := service.NewHandler(s, logger, config.StaticDirectory)
logger.Println("listening on", config.ListenAddress) logger.Println("listening on", config.ListenAddress)

View File

@ -56,9 +56,7 @@ type AccountSource struct {
// GetAccount return Account. // GetAccount return Account.
func (c *Client) GetAccount(ctx context.Context, id string) (*Account, error) { func (c *Client) GetAccount(ctx context.Context, id string) (*Account, error) {
var account Account var account Account
params := url.Values{} err := c.doAPI(ctx, http.MethodGet, fmt.Sprintf("/api/v1/accounts/%s", url.PathEscape(string(id))), nil, &account, nil)
params.Set("with_relationships", strconv.FormatBool(true))
err := c.doAPI(ctx, http.MethodGet, fmt.Sprintf("/api/v1/accounts/%s", url.PathEscape(string(id))), params, &account, nil)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -246,10 +244,11 @@ func (c *Client) AccountUnblock(ctx context.Context, id string) (*Relationship,
} }
// AccountMute mute the account. // AccountMute mute the account.
func (c *Client) AccountMute(ctx context.Context, id string, notifications bool, duration int) (*Relationship, error) { func (c *Client) AccountMute(ctx context.Context, id string, notifications *bool) (*Relationship, error) {
params := url.Values{} params := url.Values{}
params.Set("notifications", strconv.FormatBool(notifications)) if notifications != nil {
params.Set("duration", strconv.Itoa(duration)) params.Set("notifications", strconv.FormatBool(*notifications))
}
var relationship Relationship var relationship Relationship
err := c.doAPI(ctx, http.MethodPost, fmt.Sprintf("/api/v1/accounts/%s/mute", url.PathEscape(string(id))), params, &relationship, nil) err := c.doAPI(ctx, http.MethodPost, fmt.Sprintf("/api/v1/accounts/%s/mute", url.PathEscape(string(id))), params, &relationship, nil)
if err != nil { if err != nil {

View File

@ -56,6 +56,7 @@ type Status struct {
MediaAttachments []Attachment `json:"media_attachments"` MediaAttachments []Attachment `json:"media_attachments"`
Mentions []Mention `json:"mentions"` Mentions []Mention `json:"mentions"`
Tags []Tag `json:"tags"` Tags []Tag `json:"tags"`
Card *Card `json:"card"`
Application Application `json:"application"` Application Application `json:"application"`
Language string `json:"language"` Language string `json:"language"`
Pinned interface{} `json:"pinned"` Pinned interface{} `json:"pinned"`
@ -76,6 +77,22 @@ type Context struct {
Descendants []*Status `json:"descendants"` Descendants []*Status `json:"descendants"`
} }
// Card hold information for mastodon card.
type Card struct {
URL string `json:"url"`
Title string `json:"title"`
Description string `json:"description"`
Image string `json:"image"`
Type string `json:"type"`
AuthorName string `json:"author_name"`
AuthorURL string `json:"author_url"`
ProviderName string `json:"provider_name"`
ProviderURL string `json:"provider_url"`
HTML string `json:"html"`
Width int64 `json:"width"`
Height int64 `json:"height"`
}
// GetFavourites return the favorite list of the current user. // GetFavourites return the favorite list of the current user.
func (c *Client) GetFavourites(ctx context.Context, pg *Pagination) ([]*Status, error) { func (c *Client) GetFavourites(ctx context.Context, pg *Pagination) ([]*Status, error) {
var statuses []*Status var statuses []*Status
@ -106,6 +123,16 @@ func (c *Client) GetStatusContext(ctx context.Context, id string) (*Context, err
return &context, nil return &context, nil
} }
// GetStatusCard return status specified by id.
func (c *Client) GetStatusCard(ctx context.Context, id string) (*Card, error) {
var card Card
err := c.doAPI(ctx, http.MethodGet, fmt.Sprintf("/api/v1/statuses/%s/card", id), nil, &card, nil)
if err != nil {
return nil, err
}
return &card, nil
}
// GetRebloggedBy returns the account list of the user who reblogged the toot of id. // GetRebloggedBy returns the account list of the user who reblogged the toot of id.
func (c *Client) GetRebloggedBy(ctx context.Context, id string, pg *Pagination) ([]*Account, error) { func (c *Client) GetRebloggedBy(ctx context.Context, id string, pg *Pagination) ([]*Account, error) {
var accounts []*Account var accounts []*Account

21
model/app.go Normal file
View File

@ -0,0 +1,21 @@
package model
import (
"errors"
)
var (
ErrAppNotFound = errors.New("app not found")
)
type App struct {
InstanceDomain string `json:"instance_domain"`
InstanceURL string `json:"instance_url"`
ClientID string `json:"client_id"`
ClientSecret string `json:"client_secret"`
}
type AppRepo interface {
Add(app App) (err error)
Get(instanceDomain string) (app App, err error)
}

View File

@ -1,48 +1,28 @@
package model package model
import (
"errors"
)
var (
ErrSessionNotFound = errors.New("session not found")
)
type Session struct { type Session struct {
ID string `json:"id,omitempty"` ID string `json:"id"`
UserID string `json:"uid,omitempty"` UserID string `json:"user_id"`
Instance string `json:"ins,omitempty"` InstanceDomain string `json:"instance_domain"`
ClientID string `json:"cid,omitempty"` AccessToken string `json:"access_token"`
ClientSecret string `json:"cs,omitempty"` CSRFToken string `json:"csrf_token"`
AccessToken string `json:"at,omitempty"` Settings Settings `json:"settings"`
CSRFToken string `json:"csrf,omitempty"` }
Settings Settings `json:"sett,omitempty"`
type SessionRepo interface {
Add(session Session) (err error)
Get(sessionID string) (session Session, err error)
Remove(sessionID string)
} }
func (s Session) IsLoggedIn() bool { func (s Session) IsLoggedIn() bool {
return len(s.AccessToken) > 0 return len(s.AccessToken) > 0
} }
type Settings struct {
DefaultVisibility string `json:"dv,omitempty"`
DefaultFormat string `json:"df,omitempty"`
CopyScope bool `json:"cs,omitempty"`
ThreadInNewTab bool `json:"tnt,omitempty"`
HideAttachments bool `json:"ha,omitempty"`
MaskNSFW bool `json:"mn,omitempty"`
NotificationInterval int `json:"ni,omitempty"`
FluorideMode bool `json:"fm,omitempty"`
DarkMode bool `json:"dm,omitempty"`
AntiDopamineMode bool `json:"adm,omitempty"`
HideUnsupportedNotifs bool `json:"hun,omitempty"`
CSS string `json:"css,omitempty"`
}
func NewSettings() *Settings {
return &Settings{
DefaultVisibility: "public",
DefaultFormat: "",
CopyScope: true,
ThreadInNewTab: false,
HideAttachments: false,
MaskNSFW: true,
NotificationInterval: 0,
FluorideMode: false,
DarkMode: false,
AntiDopamineMode: false,
HideUnsupportedNotifs: false,
CSS: "",
}
}

33
model/settings.go Normal file
View File

@ -0,0 +1,33 @@
package model
type Settings struct {
DefaultVisibility string `json:"default_visibility"`
DefaultFormat string `json:"default_format"`
CopyScope bool `json:"copy_scope"`
ThreadInNewTab bool `json:"thread_in_new_tab"`
HideAttachments bool `json:"hide_attachments"`
MaskNSFW bool `json:"mask_nfsw"`
NotificationInterval int `json:"notifications_interval"`
FluorideMode bool `json:"fluoride_mode"`
DarkMode bool `json:"dark_mode"`
AntiDopamineMode bool `json:"anti_dopamine_mode"`
HideUnsupportedNotifs bool `json:"hide_unsupported_notifs"`
CSS string `json:"css"`
}
func NewSettings() *Settings {
return &Settings{
DefaultVisibility: "public",
DefaultFormat: "",
CopyScope: true,
ThreadInNewTab: false,
HideAttachments: false,
MaskNSFW: true,
NotificationInterval: 0,
FluorideMode: false,
DarkMode: false,
AntiDopamineMode: false,
HideUnsupportedNotifs: false,
CSS: "",
}
}

View File

@ -155,8 +155,3 @@ type FiltersData struct {
*CommonData *CommonData
Filters []*mastodon.Filter Filters []*mastodon.Filter
} }
type MuteData struct {
*CommonData
User *mastodon.Account
}

View File

@ -33,7 +33,6 @@ const (
SearchPage = "search.tmpl" SearchPage = "search.tmpl"
SettingsPage = "settings.tmpl" SettingsPage = "settings.tmpl"
FiltersPage = "filters.tmpl" FiltersPage = "filters.tmpl"
MutePage = "mute.tmpl"
) )
type TemplateData struct { type TemplateData struct {

42
repo/appRepo.go Normal file
View File

@ -0,0 +1,42 @@
package repo
import (
"encoding/json"
"bloat/util"
"bloat/model"
)
type appRepo struct {
db *util.Database
}
func NewAppRepo(db *util.Database) *appRepo {
return &appRepo{
db: db,
}
}
func (repo *appRepo) Add(a model.App) (err error) {
data, err := json.Marshal(a)
if err != nil {
return
}
err = repo.db.Set(a.InstanceDomain, data)
return
}
func (repo *appRepo) Get(instanceDomain string) (a model.App, err error) {
data, err := repo.db.Get(instanceDomain)
if err != nil {
err = model.ErrAppNotFound
return
}
err = json.Unmarshal(data, &a)
if err != nil {
return
}
return
}

47
repo/sessionRepo.go Normal file
View File

@ -0,0 +1,47 @@
package repo
import (
"encoding/json"
"bloat/util"
"bloat/model"
)
type sessionRepo struct {
db *util.Database
}
func NewSessionRepo(db *util.Database) *sessionRepo {
return &sessionRepo{
db: db,
}
}
func (repo *sessionRepo) Add(s model.Session) (err error) {
data, err := json.Marshal(s)
if err != nil {
return
}
err = repo.db.Set(s.ID, data)
return
}
func (repo *sessionRepo) Get(id string) (s model.Session, err error) {
data, err := repo.db.Get(id)
if err != nil {
err = model.ErrSessionNotFound
return
}
err = json.Unmarshal(data, &s)
if err != nil {
return
}
return
}
func (repo *sessionRepo) Remove(id string) {
repo.db.Remove(id)
return
}

View File

@ -1,111 +0,0 @@
package service
import (
"context"
"encoding/base64"
"encoding/json"
"net/http"
"strings"
"time"
"bloat/mastodon"
"bloat/model"
"bloat/renderer"
)
type client struct {
*mastodon.Client
w http.ResponseWriter
r *http.Request
s *model.Session
csrf string
ctx context.Context
rctx *renderer.Context
}
func (c *client) setSession(sess *model.Session) error {
var sb strings.Builder
bw := base64.NewEncoder(base64.URLEncoding, &sb)
err := json.NewEncoder(bw).Encode(sess)
bw.Close()
if err != nil {
return err
}
http.SetCookie(c.w, &http.Cookie{
Name: "session",
Value: sb.String(),
Expires: time.Now().Add(365 * 24 * time.Hour),
})
return nil
}
func (c *client) getSession() (sess *model.Session, err error) {
cookie, _ := c.r.Cookie("session")
if cookie == nil {
return nil, errInvalidSession
}
br := base64.NewDecoder(base64.URLEncoding, strings.NewReader(cookie.Value))
err = json.NewDecoder(br).Decode(&sess)
return
}
func (c *client) unsetSession() {
http.SetCookie(c.w, &http.Cookie{
Name: "session",
Value: "",
Expires: time.Now(),
})
}
func (c *client) writeJson(data interface{}) error {
return json.NewEncoder(c.w).Encode(map[string]interface{}{
"data": data,
})
}
func (c *client) redirect(url string) {
c.w.Header().Add("Location", url)
c.w.WriteHeader(http.StatusFound)
}
func (c *client) authenticate(t int) (err error) {
csrf := c.r.FormValue("csrf_token")
ref := c.r.URL.RequestURI()
defer func() {
if c.s == nil {
c.s = &model.Session{
Settings: *model.NewSettings(),
}
}
c.rctx = &renderer.Context{
HideAttachments: c.s.Settings.HideAttachments,
MaskNSFW: c.s.Settings.MaskNSFW,
ThreadInNewTab: c.s.Settings.ThreadInNewTab,
FluorideMode: c.s.Settings.FluorideMode,
DarkMode: c.s.Settings.DarkMode,
CSRFToken: c.s.CSRFToken,
UserID: c.s.UserID,
AntiDopamineMode: c.s.Settings.AntiDopamineMode,
UserCSS: c.s.Settings.CSS,
Referrer: ref,
}
}()
if t < SESSION {
return
}
sess, err := c.getSession()
if err != nil {
return err
}
c.s = sess
c.Client = mastodon.NewClient(&mastodon.Config{
Server: "https://" + c.s.Instance,
ClientID: c.s.ClientID,
ClientSecret: c.s.ClientSecret,
AccessToken: c.s.AccessToken,
})
if t >= CSRF && (len(csrf) < 1 || csrf != c.s.CSRFToken) {
return errInvalidCSRFToken
}
return
}

View File

@ -27,11 +27,14 @@ type service struct {
instance string instance string
postFormats []model.PostFormat postFormats []model.PostFormat
renderer renderer.Renderer renderer renderer.Renderer
sessionRepo model.SessionRepo
appRepo model.AppRepo
} }
func NewService(cname string, cscope string, cwebsite string, func NewService(cname string, cscope string, cwebsite string,
css string, instance string, postFormats []model.PostFormat, css string, instance string, postFormats []model.PostFormat,
renderer renderer.Renderer) *service { renderer renderer.Renderer, sessionRepo model.SessionRepo,
appRepo model.AppRepo) *service {
return &service{ return &service{
cname: cname, cname: cname,
cscope: cscope, cscope: cscope,
@ -40,9 +43,57 @@ func NewService(cname string, cscope string, cwebsite string,
instance: instance, instance: instance,
postFormats: postFormats, postFormats: postFormats,
renderer: renderer, renderer: renderer,
sessionRepo: sessionRepo,
appRepo: appRepo,
} }
} }
func (s *service) authenticate(c *client, sid string, csrf string, ref string, t int) (err error) {
var sett *model.Settings
defer func() {
if sett == nil {
sett = model.NewSettings()
}
c.rctx = &renderer.Context{
HideAttachments: sett.HideAttachments,
MaskNSFW: sett.MaskNSFW,
ThreadInNewTab: sett.ThreadInNewTab,
FluorideMode: sett.FluorideMode,
DarkMode: sett.DarkMode,
CSRFToken: c.s.CSRFToken,
UserID: c.s.UserID,
AntiDopamineMode: sett.AntiDopamineMode,
UserCSS: sett.CSS,
Referrer: ref,
}
}()
if t < SESSION {
return
}
if len(sid) < 1 {
return errInvalidSession
}
c.s, err = s.sessionRepo.Get(sid)
if err != nil {
return errInvalidSession
}
sett = &c.s.Settings
app, err := s.appRepo.Get(c.s.InstanceDomain)
if err != nil {
return err
}
c.Client = mastodon.NewClient(&mastodon.Config{
Server: app.InstanceURL,
ClientID: app.ClientID,
ClientSecret: app.ClientSecret,
AccessToken: c.s.AccessToken,
})
if t >= CSRF && (len(csrf) < 1 || csrf != c.s.CSRFToken) {
return errInvalidCSRFToken
}
return
}
func (s *service) cdata(c *client, title string, count int, rinterval int, func (s *service) cdata(c *client, title string, count int, rinterval int,
target string) (data *renderer.CommonData) { target string) (data *renderer.CommonData) {
data = &renderer.CommonData{ data = &renderer.CommonData{
@ -678,19 +729,6 @@ func (s *service) UserSearchPage(c *client,
return s.renderer.Render(c.rctx, c.w, renderer.UserSearchPage, data) return s.renderer.Render(c.rctx, c.w, renderer.UserSearchPage, data)
} }
func (s *service) MutePage(c *client, id string) (err error) {
user, err := c.GetAccount(c.ctx, id)
if err != nil {
return
}
cdata := s.cdata(c, "Mute"+user.DisplayName+" @"+user.Acct, 0, 0, "")
data := &renderer.UserData{
User: user,
CommonData: cdata,
}
return s.renderer.Render(c.rctx, c.w, renderer.MutePage, data)
}
func (s *service) AboutPage(c *client) (err error) { func (s *service) AboutPage(c *client) (err error) {
cdata := s.cdata(c, "about", 0, 0, "") cdata := s.cdata(c, "about", 0, 0, "")
data := &renderer.AboutData{ data := &renderer.AboutData{
@ -782,7 +820,7 @@ func (s *service) SingleInstance() (instance string, ok bool) {
return return
} }
func (s *service) NewSession(c *client, instance string) (rurl string, sess *model.Session, err error) { func (s *service) NewSession(c *client, instance string) (rurl string, sid string, err error) {
var instanceURL string var instanceURL string
if strings.HasPrefix(instance, "https://") { if strings.HasPrefix(instance, "https://") {
instanceURL = instance instanceURL = instance
@ -791,7 +829,7 @@ func (s *service) NewSession(c *client, instance string) (rurl string, sess *mod
instanceURL = "https://" + instance instanceURL = "https://" + instance
} }
sid, err := util.NewSessionID() sid, err = util.NewSessionID()
if err != nil { if err != nil {
return return
} }
@ -800,7 +838,23 @@ func (s *service) NewSession(c *client, instance string) (rurl string, sess *mod
return return
} }
app, err := mastodon.RegisterApp(c.ctx, &mastodon.AppConfig{ sess := model.Session{
ID: sid,
InstanceDomain: instance,
CSRFToken: csrf,
Settings: *model.NewSettings(),
}
err = s.sessionRepo.Add(sess)
if err != nil {
return
}
app, err := s.appRepo.Get(instance)
if err != nil {
if err != model.ErrAppNotFound {
return
}
mastoApp, err := mastodon.RegisterApp(c.ctx, &mastodon.AppConfig{
Server: instanceURL, Server: instanceURL,
ClientName: s.cname, ClientName: s.cname,
Scopes: s.cscope, Scopes: s.cscope,
@ -808,15 +862,18 @@ func (s *service) NewSession(c *client, instance string) (rurl string, sess *mod
RedirectURIs: s.cwebsite + "/oauth_callback", RedirectURIs: s.cwebsite + "/oauth_callback",
}) })
if err != nil { if err != nil {
return return "", "", err
}
app = model.App{
InstanceDomain: instance,
InstanceURL: instanceURL,
ClientID: mastoApp.ClientID,
ClientSecret: mastoApp.ClientSecret,
}
err = s.appRepo.Add(app)
if err != nil {
return "", "", err
} }
sess = &model.Session{
ID: sid,
Instance: instance,
ClientID: app.ClientID,
ClientSecret: app.ClientSecret,
CSRFToken: csrf,
Settings: *model.NewSettings(),
} }
u, err := url.Parse("/oauth/authorize") u, err := url.Parse("/oauth/authorize")
@ -850,7 +907,12 @@ func (s *service) Signin(c *client, code string) (err error) {
} }
c.s.AccessToken = c.GetAccessToken(c.ctx) c.s.AccessToken = c.GetAccessToken(c.ctx)
c.s.UserID = u.ID c.s.UserID = u.ID
return c.setSession(c.s) return s.sessionRepo.Add(c.s)
}
func (s *service) Signout(c *client) (err error) {
s.sessionRepo.Remove(c.s.ID)
return
} }
func (s *service) Post(c *client, content string, replyToID string, func (s *service) Post(c *client, content string, replyToID string,
@ -943,8 +1005,8 @@ func (s *service) Reject(c *client, id string) (err error) {
return c.FollowRequestReject(c.ctx, id) return c.FollowRequestReject(c.ctx, id)
} }
func (s *service) Mute(c *client, id string, notifications bool, duration int) (err error) { func (s *service) Mute(c *client, id string, notifications *bool) (err error) {
_, err = c.AccountMute(c.ctx, id, notifications, duration) _, err = c.AccountMute(c.ctx, id, notifications)
return return
} }
@ -982,8 +1044,12 @@ func (s *service) SaveSettings(c *client, settings *model.Settings) (err error)
if len(settings.CSS) > 1<<20 { if len(settings.CSS) > 1<<20 {
return errInvalidArgument return errInvalidArgument
} }
c.s.Settings = *settings sess, err := s.sessionRepo.Get(c.s.ID)
return c.setSession(c.s) if err != nil {
return
}
sess.Settings = *settings
return s.sessionRepo.Add(sess)
} }
func (s *service) MuteConversation(c *client, id string) (err error) { func (s *service) MuteConversation(c *client, id string) (err error) {

View File

@ -1,17 +1,24 @@
package service package service
import ( import (
"context"
"encoding/json" "encoding/json"
"log" "log"
"net/http" "net/http"
"strconv" "strconv"
"time" "time"
"bloat/mastodon"
"bloat/model" "bloat/model"
"bloat/renderer"
"github.com/gorilla/mux" "github.com/gorilla/mux"
) )
const (
sessionExp = 365 * 24 * time.Hour
)
const ( const (
HTML int = iota HTML int = iota
JSON JSON
@ -23,6 +30,35 @@ const (
CSRF CSRF
) )
type client struct {
*mastodon.Client
w http.ResponseWriter
r *http.Request
s model.Session
csrf string
ctx context.Context
rctx *renderer.Context
}
func setSessionCookie(w http.ResponseWriter, sid string, exp time.Duration) {
http.SetCookie(w, &http.Cookie{
Name: "session_id",
Value: sid,
Expires: time.Now().Add(exp),
})
}
func writeJson(c *client, data interface{}) error {
return json.NewEncoder(c.w).Encode(map[string]interface{}{
"data": data,
})
}
func redirect(c *client, url string) {
c.w.Header().Add("Location", url)
c.w.WriteHeader(http.StatusFound)
}
func NewHandler(s *service, logger *log.Logger, staticDir string) http.Handler { func NewHandler(s *service, logger *log.Logger, staticDir string) http.Handler {
r := mux.NewRouter() r := mux.NewRouter()
@ -39,6 +75,16 @@ func NewHandler(s *service, logger *log.Logger, staticDir string) http.Handler {
} }
} }
authenticate := func(c *client, t int) error {
var sid string
if cookie, _ := c.r.Cookie("session_id"); cookie != nil {
sid = cookie.Value
}
csrf := c.r.FormValue("csrf_token")
ref := c.r.URL.RequestURI()
return s.authenticate(c, sid, csrf, ref, t)
}
handle := func(f func(c *client) error, at int, rt int) http.HandlerFunc { handle := func(f func(c *client) error, at int, rt int) http.HandlerFunc {
return func(w http.ResponseWriter, req *http.Request) { return func(w http.ResponseWriter, req *http.Request) {
var err error var err error
@ -62,7 +108,7 @@ func NewHandler(s *service, logger *log.Logger, staticDir string) http.Handler {
} }
c.w.Header().Add("Content-Type", ct) c.w.Header().Add("Content-Type", ct)
err = c.authenticate(at) err = authenticate(c, at)
if err != nil { if err != nil {
writeError(c, err, rt, req.Method == http.MethodGet) writeError(c, err, rt, req.Method == http.MethodGet)
return return
@ -77,16 +123,16 @@ func NewHandler(s *service, logger *log.Logger, staticDir string) http.Handler {
} }
rootPage := handle(func(c *client) error { rootPage := handle(func(c *client) error {
err := c.authenticate(SESSION) err := authenticate(c, SESSION)
if err != nil { if err != nil {
if err == errInvalidSession { if err == errInvalidSession {
c.redirect("/signin") redirect(c, "/signin")
return nil return nil
} }
return err return err
} }
if !c.s.IsLoggedIn() { if !c.s.IsLoggedIn() {
c.redirect("/signin") redirect(c, "/signin")
return nil return nil
} }
return s.RootPage(c) return s.RootPage(c)
@ -101,12 +147,12 @@ func NewHandler(s *service, logger *log.Logger, staticDir string) http.Handler {
if !ok { if !ok {
return s.SigninPage(c) return s.SigninPage(c)
} }
url, sess, err := s.NewSession(c, instance) url, sid, err := s.NewSession(c, instance)
if err != nil { if err != nil {
return err return err
} }
c.setSession(sess) setSessionCookie(c.w, sid, sessionExp)
c.redirect(url) redirect(c, url)
return nil return nil
}, NOAUTH, HTML) }, NOAUTH, HTML)
@ -121,7 +167,7 @@ func NewHandler(s *service, logger *log.Logger, staticDir string) http.Handler {
}, SESSION, HTML) }, SESSION, HTML)
defaultTimelinePage := handle(func(c *client) error { defaultTimelinePage := handle(func(c *client) error {
c.redirect("/timeline/home") redirect(c, "/timeline/home")
return nil return nil
}, SESSION, HTML) }, SESSION, HTML)
@ -171,11 +217,6 @@ func NewHandler(s *service, logger *log.Logger, staticDir string) http.Handler {
return s.UserSearchPage(c, id, sq, offset) return s.UserSearchPage(c, id, sq, offset)
}, SESSION, HTML) }, SESSION, HTML)
mutePage := handle(func(c *client) error {
id, _ := mux.Vars(c.r)["id"]
return s.MutePage(c, id)
}, SESSION, HTML)
aboutPage := handle(func(c *client) error { aboutPage := handle(func(c *client) error {
return s.AboutPage(c) return s.AboutPage(c)
}, SESSION, HTML) }, SESSION, HTML)
@ -202,12 +243,12 @@ func NewHandler(s *service, logger *log.Logger, staticDir string) http.Handler {
signin := handle(func(c *client) error { signin := handle(func(c *client) error {
instance := c.r.FormValue("instance") instance := c.r.FormValue("instance")
url, sess, err := s.NewSession(c, instance) url, sid, err := s.NewSession(c, instance)
if err != nil { if err != nil {
return err return err
} }
c.setSession(sess) setSessionCookie(c.w, sid, sessionExp)
c.redirect(url) redirect(c, url)
return nil return nil
}, NOAUTH, HTML) }, NOAUTH, HTML)
@ -218,7 +259,7 @@ func NewHandler(s *service, logger *log.Logger, staticDir string) http.Handler {
if err != nil { if err != nil {
return err return err
} }
c.redirect("/") redirect(c, "/")
return nil return nil
}, SESSION, HTML) }, SESSION, HTML)
@ -246,7 +287,7 @@ func NewHandler(s *service, logger *log.Logger, staticDir string) http.Handler {
} else { } else {
location = c.r.FormValue("referrer") location = c.r.FormValue("referrer")
} }
c.redirect(location) redirect(c, location)
return nil return nil
}, CSRF, HTML) }, CSRF, HTML)
@ -260,7 +301,7 @@ func NewHandler(s *service, logger *log.Logger, staticDir string) http.Handler {
if len(rid) > 0 { if len(rid) > 0 {
id = rid id = rid
} }
c.redirect(c.r.FormValue("referrer") + "#status-" + id) redirect(c, c.r.FormValue("referrer")+"#status-"+id)
return nil return nil
}, CSRF, HTML) }, CSRF, HTML)
@ -274,7 +315,7 @@ func NewHandler(s *service, logger *log.Logger, staticDir string) http.Handler {
if len(rid) > 0 { if len(rid) > 0 {
id = rid id = rid
} }
c.redirect(c.r.FormValue("referrer") + "#status-" + id) redirect(c, c.r.FormValue("referrer")+"#status-"+id)
return nil return nil
}, CSRF, HTML) }, CSRF, HTML)
@ -288,7 +329,7 @@ func NewHandler(s *service, logger *log.Logger, staticDir string) http.Handler {
if len(rid) > 0 { if len(rid) > 0 {
id = rid id = rid
} }
c.redirect(c.r.FormValue("referrer") + "#status-" + id) redirect(c, c.r.FormValue("referrer")+"#status-"+id)
return nil return nil
}, CSRF, HTML) }, CSRF, HTML)
@ -302,7 +343,7 @@ func NewHandler(s *service, logger *log.Logger, staticDir string) http.Handler {
if len(rid) > 0 { if len(rid) > 0 {
id = rid id = rid
} }
c.redirect(c.r.FormValue("referrer") + "#status-" + id) redirect(c, c.r.FormValue("referrer")+"#status-"+id)
return nil return nil
}, CSRF, HTML) }, CSRF, HTML)
@ -314,7 +355,7 @@ func NewHandler(s *service, logger *log.Logger, staticDir string) http.Handler {
if err != nil { if err != nil {
return err return err
} }
c.redirect(c.r.FormValue("referrer") + "#status-" + statusID) redirect(c, c.r.FormValue("referrer")+"#status-"+statusID)
return nil return nil
}, CSRF, HTML) }, CSRF, HTML)
@ -330,7 +371,7 @@ func NewHandler(s *service, logger *log.Logger, staticDir string) http.Handler {
if err != nil { if err != nil {
return err return err
} }
c.redirect(c.r.FormValue("referrer")) redirect(c, c.r.FormValue("referrer"))
return nil return nil
}, CSRF, HTML) }, CSRF, HTML)
@ -340,7 +381,7 @@ func NewHandler(s *service, logger *log.Logger, staticDir string) http.Handler {
if err != nil { if err != nil {
return err return err
} }
c.redirect(c.r.FormValue("referrer")) redirect(c, c.r.FormValue("referrer"))
return nil return nil
}, CSRF, HTML) }, CSRF, HTML)
@ -350,7 +391,7 @@ func NewHandler(s *service, logger *log.Logger, staticDir string) http.Handler {
if err != nil { if err != nil {
return err return err
} }
c.redirect(c.r.FormValue("referrer")) redirect(c, c.r.FormValue("referrer"))
return nil return nil
}, CSRF, HTML) }, CSRF, HTML)
@ -360,19 +401,23 @@ func NewHandler(s *service, logger *log.Logger, staticDir string) http.Handler {
if err != nil { if err != nil {
return err return err
} }
c.redirect(c.r.FormValue("referrer")) redirect(c, c.r.FormValue("referrer"))
return nil return nil
}, CSRF, HTML) }, CSRF, HTML)
mute := handle(func(c *client) error { mute := handle(func(c *client) error {
id, _ := mux.Vars(c.r)["id"] id, _ := mux.Vars(c.r)["id"]
notifications, _ := strconv.ParseBool(c.r.FormValue("notifications")) q := c.r.URL.Query()
duration, _ := strconv.Atoi(c.r.FormValue("duration")) var notifications *bool
err := s.Mute(c, id, notifications, duration) if r, ok := q["notifications"]; ok && len(r) > 0 {
notifications = new(bool)
*notifications = r[0] == "true"
}
err := s.Mute(c, id, notifications)
if err != nil { if err != nil {
return err return err
} }
c.redirect("/user/" + id) redirect(c, c.r.FormValue("referrer"))
return nil return nil
}, CSRF, HTML) }, CSRF, HTML)
@ -382,7 +427,7 @@ func NewHandler(s *service, logger *log.Logger, staticDir string) http.Handler {
if err != nil { if err != nil {
return err return err
} }
c.redirect(c.r.FormValue("referrer")) redirect(c, c.r.FormValue("referrer"))
return nil return nil
}, CSRF, HTML) }, CSRF, HTML)
@ -392,7 +437,7 @@ func NewHandler(s *service, logger *log.Logger, staticDir string) http.Handler {
if err != nil { if err != nil {
return err return err
} }
c.redirect(c.r.FormValue("referrer")) redirect(c, c.r.FormValue("referrer"))
return nil return nil
}, CSRF, HTML) }, CSRF, HTML)
@ -402,7 +447,7 @@ func NewHandler(s *service, logger *log.Logger, staticDir string) http.Handler {
if err != nil { if err != nil {
return err return err
} }
c.redirect(c.r.FormValue("referrer")) redirect(c, c.r.FormValue("referrer"))
return nil return nil
}, CSRF, HTML) }, CSRF, HTML)
@ -412,7 +457,7 @@ func NewHandler(s *service, logger *log.Logger, staticDir string) http.Handler {
if err != nil { if err != nil {
return err return err
} }
c.redirect(c.r.FormValue("referrer")) redirect(c, c.r.FormValue("referrer"))
return nil return nil
}, CSRF, HTML) }, CSRF, HTML)
@ -422,7 +467,7 @@ func NewHandler(s *service, logger *log.Logger, staticDir string) http.Handler {
if err != nil { if err != nil {
return err return err
} }
c.redirect(c.r.FormValue("referrer")) redirect(c, c.r.FormValue("referrer"))
return nil return nil
}, CSRF, HTML) }, CSRF, HTML)
@ -459,7 +504,7 @@ func NewHandler(s *service, logger *log.Logger, staticDir string) http.Handler {
if err != nil { if err != nil {
return err return err
} }
c.redirect("/") redirect(c, "/")
return nil return nil
}, CSRF, HTML) }, CSRF, HTML)
@ -469,7 +514,7 @@ func NewHandler(s *service, logger *log.Logger, staticDir string) http.Handler {
if err != nil { if err != nil {
return err return err
} }
c.redirect(c.r.FormValue("referrer")) redirect(c, c.r.FormValue("referrer"))
return nil return nil
}, CSRF, HTML) }, CSRF, HTML)
@ -479,7 +524,7 @@ func NewHandler(s *service, logger *log.Logger, staticDir string) http.Handler {
if err != nil { if err != nil {
return err return err
} }
c.redirect(c.r.FormValue("referrer")) redirect(c, c.r.FormValue("referrer"))
return nil return nil
}, CSRF, HTML) }, CSRF, HTML)
@ -489,7 +534,7 @@ func NewHandler(s *service, logger *log.Logger, staticDir string) http.Handler {
if err != nil { if err != nil {
return err return err
} }
c.redirect(c.r.FormValue("referrer")) redirect(c, c.r.FormValue("referrer"))
return nil return nil
}, CSRF, HTML) }, CSRF, HTML)
@ -500,7 +545,7 @@ func NewHandler(s *service, logger *log.Logger, staticDir string) http.Handler {
if err != nil { if err != nil {
return err return err
} }
c.redirect(c.r.FormValue("referrer")) redirect(c, c.r.FormValue("referrer"))
return nil return nil
}, CSRF, HTML) }, CSRF, HTML)
@ -514,7 +559,7 @@ func NewHandler(s *service, logger *log.Logger, staticDir string) http.Handler {
if len(rid) > 0 { if len(rid) > 0 {
id = rid id = rid
} }
c.redirect(c.r.FormValue("referrer") + "#status-" + id) redirect(c, c.r.FormValue("referrer")+"#status-"+id)
return nil return nil
}, CSRF, HTML) }, CSRF, HTML)
@ -528,7 +573,7 @@ func NewHandler(s *service, logger *log.Logger, staticDir string) http.Handler {
if len(rid) > 0 { if len(rid) > 0 {
id = rid id = rid
} }
c.redirect(c.r.FormValue("referrer") + "#status-" + id) redirect(c, c.r.FormValue("referrer")+"#status-"+id)
return nil return nil
}, CSRF, HTML) }, CSRF, HTML)
@ -539,7 +584,7 @@ func NewHandler(s *service, logger *log.Logger, staticDir string) http.Handler {
if err != nil { if err != nil {
return err return err
} }
c.redirect(c.r.FormValue("referrer")) redirect(c, c.r.FormValue("referrer"))
return nil return nil
}, CSRF, HTML) }, CSRF, HTML)
@ -549,7 +594,7 @@ func NewHandler(s *service, logger *log.Logger, staticDir string) http.Handler {
if err != nil { if err != nil {
return err return err
} }
c.redirect(c.r.FormValue("referrer")) redirect(c, c.r.FormValue("referrer"))
return nil return nil
}, CSRF, HTML) }, CSRF, HTML)
@ -563,7 +608,7 @@ func NewHandler(s *service, logger *log.Logger, staticDir string) http.Handler {
if err != nil { if err != nil {
return err return err
} }
c.redirect(c.r.FormValue("referrer")) redirect(c, c.r.FormValue("referrer"))
return nil return nil
}, CSRF, HTML) }, CSRF, HTML)
@ -573,7 +618,7 @@ func NewHandler(s *service, logger *log.Logger, staticDir string) http.Handler {
if err != nil { if err != nil {
return err return err
} }
c.redirect(c.r.FormValue("referrer")) redirect(c, c.r.FormValue("referrer"))
return nil return nil
}, CSRF, HTML) }, CSRF, HTML)
@ -584,7 +629,7 @@ func NewHandler(s *service, logger *log.Logger, staticDir string) http.Handler {
if err != nil { if err != nil {
return err return err
} }
c.redirect(c.r.FormValue("referrer")) redirect(c, c.r.FormValue("referrer"))
return nil return nil
}, CSRF, HTML) }, CSRF, HTML)
@ -603,7 +648,7 @@ func NewHandler(s *service, logger *log.Logger, staticDir string) http.Handler {
if err != nil { if err != nil {
return err return err
} }
c.redirect(c.r.FormValue("referrer")) redirect(c, c.r.FormValue("referrer"))
return nil return nil
}, CSRF, HTML) }, CSRF, HTML)
@ -615,13 +660,14 @@ func NewHandler(s *service, logger *log.Logger, staticDir string) http.Handler {
if err != nil { if err != nil {
return err return err
} }
c.redirect(c.r.FormValue("referrer")) redirect(c, c.r.FormValue("referrer"))
return nil return nil
}, CSRF, HTML) }, CSRF, HTML)
signout := handle(func(c *client) error { signout := handle(func(c *client) error {
c.unsetSession() s.Signout(c)
c.redirect("/") setSessionCookie(c.w, "", 0)
redirect(c, "/")
return nil return nil
}, CSRF, HTML) }, CSRF, HTML)
@ -631,7 +677,7 @@ func NewHandler(s *service, logger *log.Logger, staticDir string) http.Handler {
if err != nil { if err != nil {
return err return err
} }
return c.writeJson(count) return writeJson(c, count)
}, CSRF, JSON) }, CSRF, JSON)
fUnlike := handle(func(c *client) error { fUnlike := handle(func(c *client) error {
@ -640,7 +686,7 @@ func NewHandler(s *service, logger *log.Logger, staticDir string) http.Handler {
if err != nil { if err != nil {
return err return err
} }
return c.writeJson(count) return writeJson(c, count)
}, CSRF, JSON) }, CSRF, JSON)
fRetweet := handle(func(c *client) error { fRetweet := handle(func(c *client) error {
@ -649,7 +695,7 @@ func NewHandler(s *service, logger *log.Logger, staticDir string) http.Handler {
if err != nil { if err != nil {
return err return err
} }
return c.writeJson(count) return writeJson(c, count)
}, CSRF, JSON) }, CSRF, JSON)
fUnretweet := handle(func(c *client) error { fUnretweet := handle(func(c *client) error {
@ -658,7 +704,7 @@ func NewHandler(s *service, logger *log.Logger, staticDir string) http.Handler {
if err != nil { if err != nil {
return err return err
} }
return c.writeJson(count) return writeJson(c, count)
}, CSRF, JSON) }, CSRF, JSON)
r.HandleFunc("/", rootPage).Methods(http.MethodGet) r.HandleFunc("/", rootPage).Methods(http.MethodGet)
@ -674,7 +720,6 @@ func NewHandler(s *service, logger *log.Logger, staticDir string) http.Handler {
r.HandleFunc("/user/{id}", userPage).Methods(http.MethodGet) r.HandleFunc("/user/{id}", userPage).Methods(http.MethodGet)
r.HandleFunc("/user/{id}/{type}", userPage).Methods(http.MethodGet) r.HandleFunc("/user/{id}/{type}", userPage).Methods(http.MethodGet)
r.HandleFunc("/usersearch/{id}", userSearchPage).Methods(http.MethodGet) r.HandleFunc("/usersearch/{id}", userSearchPage).Methods(http.MethodGet)
r.HandleFunc("/mute/{id}", mutePage).Methods(http.MethodGet)
r.HandleFunc("/about", aboutPage).Methods(http.MethodGet) r.HandleFunc("/about", aboutPage).Methods(http.MethodGet)
r.HandleFunc("/emojis", emojisPage).Methods(http.MethodGet) r.HandleFunc("/emojis", emojisPage).Methods(http.MethodGet)
r.HandleFunc("/search", searchPage).Methods(http.MethodGet) r.HandleFunc("/search", searchPage).Methods(http.MethodGet)

View File

@ -1,29 +0,0 @@
{{with .Data}}
{{template "header.tmpl" (WithContext .CommonData $.Ctx)}}
<div class="page-title"> Mute {{.User.Acct}} </div>
<form action="/mute/{{.User.ID}}" method="POST">
<input type="hidden" name="csrf_token" value="{{$.Ctx.CSRFToken}}">
<input type="hidden" name="referrer" value="{{$.Ctx.Referrer}}">
<div class="settings-form-field">
<input id="notifications" name="notifications" type="checkbox" value="true" checked>
<label for="notifications"> Mute notifications </label>
</div>
<div class="settings-form-field">
<label for="duration"> Auto unmute </label>
<select id="duration" name="duration">
<option value="0" selected>Disabled</option>
<option value="300">After 5m</option>
<option value="1800">After 30m</option>
<option value="3600">After 1h</option>
<option value="21600">After 6h</option>
<option value="86400">After 1d</option>
<option value="259200">After 3d</option>
<option value="604800">After 7d</option>
</select>
</div>
<button type="submit"> Mute </button>
</form>
{{template "footer.tmpl"}}
{{end}}

View File

@ -79,7 +79,17 @@
<input type="submit" value="unmute" class="btn-link"> <input type="submit" value="unmute" class="btn-link">
</form> </form>
{{else}} {{else}}
<a href="/mute/{{.User.ID}}"> mute </a> <form class="d-inline" action="/mute/{{.User.ID}}" method="post">
<input type="hidden" name="csrf_token" value="{{$.Ctx.CSRFToken}}">
<input type="hidden" name="referrer" value="{{$.Ctx.Referrer}}">
<input type="submit" value="mute" class="btn-link">
</form>
-
<form class="d-inline" action="/mute/{{.User.ID}}?notifications=false" method="post">
<input type="hidden" name="csrf_token" value="{{$.Ctx.CSRFToken}}">
<input type="hidden" name="referrer" value="{{$.Ctx.Referrer}}">
<input type="submit" value="mute (keep notifications)" class="btn-link">
</form>
{{end}} {{end}}
{{if .User.Pleroma.Relationship.Following}} {{if .User.Pleroma.Relationship.Following}}
- -
@ -125,7 +135,7 @@
{{if .User.Fields}} {{if .User.Fields}}
<div class="user-fields"> <div class="user-fields">
{{range .User.Fields}} {{range .User.Fields}}
<div>{{EmojiFilter .Name $.Data.User.Emojis | Raw}} - {{EmojiFilter .Value $.Data.User.Emojis | Raw}}</div> <div>{{.Name}} - {{.Value | Raw}}</div>
{{end}} {{end}}
</div> </div>
{{end}} {{end}}

91
util/kv.go Normal file
View File

@ -0,0 +1,91 @@
package util
import (
"errors"
"io/ioutil"
"os"
"path/filepath"
"strings"
"sync"
)
var (
errInvalidKey = errors.New("invalid key")
errNoSuchKey = errors.New("no such key")
)
type Database struct {
cache map[string][]byte
basedir string
m sync.RWMutex
}
func NewDatabse(basedir string) (db *Database, err error) {
err = os.Mkdir(basedir, 0755)
if err != nil && !os.IsExist(err) {
return
}
return &Database{
cache: make(map[string][]byte),
basedir: basedir,
}, nil
}
func (db *Database) Set(key string, val []byte) (err error) {
if len(key) < 1 || strings.ContainsRune(key, os.PathSeparator) {
return errInvalidKey
}
err = ioutil.WriteFile(filepath.Join(db.basedir, key), val, 0644)
if err != nil {
return
}
db.m.Lock()
db.cache[key] = val
db.m.Unlock()
return
}
func (db *Database) Get(key string) (val []byte, err error) {
if len(key) < 1 || strings.ContainsRune(key, os.PathSeparator) {
return nil, errInvalidKey
}
db.m.RLock()
data, ok := db.cache[key]
db.m.RUnlock()
if !ok {
data, err = ioutil.ReadFile(filepath.Join(db.basedir, key))
if err != nil {
err = errNoSuchKey
return nil, err
}
db.m.Lock()
db.cache[key] = data
db.m.Unlock()
}
val = make([]byte, len(data))
copy(val, data)
return
}
func (db *Database) Remove(key string) {
if len(key) < 1 || strings.ContainsRune(key, os.PathSeparator) {
return
}
os.Remove(filepath.Join(db.basedir, key))
db.m.Lock()
delete(db.cache, key)
db.m.Unlock()
return
}