From df10b44fe7a45cc3270ba5fd14c775d5cb51b88a Mon Sep 17 00:00:00 2001 From: Mikhail Klementyev Date: Sun, 20 Nov 2016 13:22:00 +0300 Subject: [PATCH] Implements post query --- commands/commands.go | 183 ++++++++++++++++++++++++++++++++----------- main.go | 10 +-- storage/storage.go | 144 ++++++++++++++++++++++++++++++++++ 3 files changed, 287 insertions(+), 50 deletions(-) diff --git a/commands/commands.go b/commands/commands.go index f8fdb70..587edad 100644 --- a/commands/commands.go +++ b/commands/commands.go @@ -9,73 +9,88 @@ package commands import ( - "bytes" "database/sql" "strings" "fmt" - "io/ioutil" "log" "net/http" "net/url" "github.com/jollheef/wi/storage" + "github.com/PuerkitoBio/goquery" "github.com/jaytaylor/html2text" - "golang.org/x/net/html" "golang.org/x/net/html/charset" ) -func parseLink(db *sql.DB, oldPage, value string, lastUrl *url.URL) (htmlPage string, err error) { - linkUrl, err := lastUrl.Parse(value) - if err != nil { - return - } +func fixLinks(db *sql.DB, doc *goquery.Document, pageUrl *url.URL) (err error) { - linkNo, err := storage.GetLinkID(db, linkUrl.String()) - if err != nil { - linkNo, err = storage.AddLink(db, linkUrl.String()) + doc.Find("a").Each(func(i int, s *goquery.Selection) { + url, exists := s.Attr("href") + if !exists { + return + } + + linkUrl, err := pageUrl.Parse(url) if err != nil { return } - } - htmlPage = oldPage + linkNo, err := storage.GetLinkID(db, linkUrl.String()) + if err != nil { + linkNo, err = storage.AddLink(db, linkUrl.String()) + if err != nil { + log.Fatalln("Add link:", err) + } + } - for _, s := range []string{value, html.EscapeString(value)} { - htmlPage = strings.Replace(htmlPage, "\""+s+"\"", - "\""+fmt.Sprintf("%d", linkNo)+"\"", -1) - } + s.SetAttr("href", fmt.Sprintf("%d", linkNo)) + }) return } -func parseLinks(db *sql.DB, body []byte, lastUrl *url.URL) (htmlPage string, err error) { - htmlPage = string(body) +func fixForms(db *sql.DB, doc *goquery.Document, pageUrl *url.URL) (err error) { - z := html.NewTokenizer(bytes.NewReader(body)) - - for { - tt := z.Next() - if tt == html.ErrorToken { - break - } - - for { - key, value, moreAttr := z.TagAttr() - - if string(key) == "href" { - htmlPage, err = parseLink(db, htmlPage, string(value), lastUrl) - if err != nil { + doc.Find("form").Each(func(i int, s *goquery.Selection) { + var fields []storage.Field + s.Find("input").Map( + func(i int, s *goquery.Selection) (str string) { + f := storage.Field{} + var exists bool + f.Name, exists = s.Attr("name") + if !exists { return } - } + f.Value, _ = s.Attr("value") + hidden, _ := s.Attr("type") + if hidden == "hidden" { + f.Hidden = true + } + fields = append(fields, f) + return + }) - if !moreAttr { - break + action, _ := s.Attr("action") + + actionUrl, err := pageUrl.Parse(action) + if err != nil { + return + } + + method, _ := s.Attr("method") + + formNo, err := storage.GetFormID(db, fields, actionUrl.String(), method) + if err != nil { + formNo, err = storage.AddForm(db, fields, actionUrl.String(), method) + if err != nil { + log.Fatalln(err) } } - } + + s.AppendHtml(fmt.Sprintf("(%d %s)", formNo, strings.ToUpper(method))) + }) return } @@ -94,8 +109,12 @@ func Get(db *sql.DB, linkUrl string) { linkUrl = "https://" + linkUrl } - // TODO Full url encoding - req, err := http.NewRequest("GET", strings.Replace(linkUrl, " ", "%20", -1), nil) + u, err := url.Parse(linkUrl) + if err != nil { + log.Fatalln(err) + } + + req, err := http.NewRequest("GET", u.String(), nil) if err != nil { log.Fatalln(err) } @@ -120,14 +139,24 @@ func Get(db *sql.DB, linkUrl string) { log.Fatalln("Encoding error:", err) } - body, err := ioutil.ReadAll(utf8) + doc, err := goquery.NewDocumentFromReader(utf8) if err != nil { - log.Fatalln("IO error:", err) + log.Fatalln("Create document error:", err) } - htmlPage, err := parseLinks(db, body, lastUrl) + err = fixLinks(db, doc, lastUrl) if err != nil { - log.Fatalln("Parse links error:", err) + log.Fatalln("Fix links error:", err) + } + + err = fixForms(db, doc, lastUrl) + if err != nil { + log.Fatalln("Fix forms error", err) + } + + htmlPage, err := doc.Html() + if err != nil { + log.Fatalln("Convert to html error:", err) } text, err := html2text.FromString(htmlPage) @@ -139,8 +168,72 @@ func Get(db *sql.DB, linkUrl string) { fmt.Println(text) } -func Post(db *sql.DB, formID int64, formArgs []string) { - fmt.Println("Not implemented") +func Form(db *sql.DB, formID int64, formArgs []string) { + fields, formUrl, post, err := storage.GetForm(db, formID) + if err != nil { + log.Fatalln("Get form:", err) + } + + if len(formArgs) == 0 { + if post { + fmt.Print("POST ") + } + + fmt.Println(formUrl) + + fmt.Print("Values: ") + for i, f := range fields { + if i != 0 { + fmt.Print("\n\t") + } + fmt.Printf(`%s="%s"`, f.Name, f.Value) + } + fmt.Println() + + return + } + + urlData := url.Values{} + for _, f := range fields { + urlData.Set(f.Name, f.Value) + } + + for _, fa := range formArgs { + nameAndValue := strings.Split(fa, " ") + if len(nameAndValue) != 2 { + continue + } + name := nameAndValue[0] + value := nameAndValue[1] + urlData.Set(name, value) + } + + client := &http.Client{} + + var lastUrl *url.URL + + client.CheckRedirect = func(r *http.Request, via []*http.Request) (err error) { + lastUrl = r.URL + return + } + + resp, err := client.PostForm(formUrl, urlData) + if err != nil { + fmt.Println(err) + } + + if lastUrl == nil { + lastUrl, _ = resp.Location() + } + + log.Println(resp.Status) + + var status int64 + fmt.Sscanf(resp.Status, "%d", &status) + + if status >= 300 && status < 400 { + Get(db, lastUrl.String()) + } } func Link(db *sql.DB, linkID int64, fromHistory bool) { diff --git a/main.go b/main.go index 2bb6bb1..4ba479a 100644 --- a/main.go +++ b/main.go @@ -42,9 +42,9 @@ var ( get = kingpin.Command("get", "Get url") getUrl = get.Arg("url", "Url").Required().String() - post = kingpin.Command("post", "Fill post form") - postID = post.Arg("id", "Form ID").Required().Int64() - postArgs = SearchList(post.Arg("args", "Post form arguments")) + form = kingpin.Command("form", "Fill form") + formID = form.Arg("id", "Form ID").Required().Int64() + formArgs = SearchList(form.Arg("args", "Form arguments")) link = kingpin.Command("link", "Get link") linkNo = link.Arg("no", "Number").Required().Int64() @@ -68,8 +68,8 @@ func main() { switch kingpin.Parse() { case "get": commands.Get(db, *getUrl) - case "post": - commands.Post(db, *postID, *postArgs) + case "form": + commands.Form(db, *formID, *formArgs) case "link": commands.Link(db, *linkNo, *linkFromHistory) case "history": diff --git a/storage/storage.go b/storage/storage.go index f0d08f5..83536be 100644 --- a/storage/storage.go +++ b/storage/storage.go @@ -10,6 +10,9 @@ package storage import ( "database/sql" + "errors" + "reflect" + "strings" _ "github.com/mattn/go-sqlite3" ) @@ -28,6 +31,147 @@ func OpenDB(path string) (db *sql.DB, err error) { _, err = db.Exec("CREATE TABLE IF NOT EXISTS `history` " + "( `id` INTEGER PRIMARY KEY AUTOINCREMENT, `url` TEXT );") + if err != nil { + return + } + + _, err = db.Exec("CREATE TABLE IF NOT EXISTS `fields` " + + "( `id` INTEGER PRIMARY KEY AUTOINCREMENT, " + + " `form_id` INTEGER, " + + " `hidden` BOOLEAN, " + + " `value` TEXT, " + + " `name` TEXT );") + if err != nil { + return + } + + _, err = db.Exec("CREATE TABLE IF NOT EXISTS `forms` " + + "( `id` INTEGER PRIMARY KEY AUTOINCREMENT, " + + " `post` BOOLEAN, " + + " `url` TEXT );") + + return +} + +type Field struct { + Hidden bool + Value string + Name string +} + +func getFields(db *sql.DB, formNo int64) (fields []Field, err error) { + stmt, err := db.Prepare("SELECT `hidden`, `value`, `name` FROM `fields` WHERE form_id=$1;") + if err != nil { + return + } + defer stmt.Close() + + rows, err := stmt.Query(formNo) + if err != nil { + return + } + defer rows.Close() + + for rows.Next() { + var f Field + err = rows.Scan(&f.Hidden, &f.Value, &f.Name) + if err != nil { + return + } + + fields = append(fields, f) + } + + return +} + +func addField(db *sql.DB, name, value string, hidden bool, formNo int64) (err error) { + stmt, err := db.Prepare("INSERT INTO `fields` " + + "(`form_id`, `name`, `hidden`, `value`) VALUES ($1, $2, $3, $4);") + if err != nil { + return + } + defer stmt.Close() + + _, err = stmt.Exec(formNo, name, hidden, value) + return +} + +func AddForm(db *sql.DB, fields []Field, url, method string) (formNo int64, err error) { + stmt, err := db.Prepare("INSERT INTO `forms` (`url`, `post`) VALUES ($1, $2);") + if err != nil { + return + } + defer stmt.Close() + + post := false // GET + if strings.ToUpper(method) == "POST" { + post = true + } + + r, err := stmt.Exec(url, post) + if err != nil { + return + } + + formNo, err = r.LastInsertId() + if err != nil { + return + } + + for _, f := range fields { + addField(db, f.Name, f.Value, f.Hidden, formNo) + } + + return +} + +func GetFormID(db *sql.DB, fields []Field, url, method string) (formNo int64, err error) { + stmt, err := db.Prepare("SELECT `id` FROM `forms` WHERE url=$1 AND post=$2;") + if err != nil { + return + } + defer stmt.Close() + + post := false // GET + if strings.ToUpper(method) == "POST" { + post = true + } + + err = stmt.QueryRow(url, post).Scan(&formNo) + if err != nil { + return + } + + dbFields, err := getFields(db, formNo) + if err != nil { + return + } + + if !reflect.DeepEqual(fields, dbFields) { + err = errors.New("Fields not match") + return + } + + return +} + +func GetForm(db *sql.DB, formID int64) (fields []Field, url string, post bool, err error) { + stmt, err := db.Prepare("SELECT `post`, `url` FROM `forms` WHERE id=$1;") + if err != nil { + return + } + defer stmt.Close() + + err = stmt.QueryRow(formID).Scan(&post, &url) + if err != nil { + return + } + + fields, err = getFields(db, formID) + if err != nil { + return + } return }