Correct work with cookies

This commit is contained in:
Mikhail Klementyev 2016-11-20 17:57:27 +03:00
parent f01134eaa4
commit dc8b2f1b26
3 changed files with 23 additions and 111 deletions

View file

@ -21,6 +21,7 @@ import (
"github.com/PuerkitoBio/goquery" "github.com/PuerkitoBio/goquery"
"github.com/jaytaylor/html2text" "github.com/jaytaylor/html2text"
cookiejar "github.com/juju/persistent-cookiejar"
"golang.org/x/net/html/charset" "golang.org/x/net/html/charset"
) )
@ -95,8 +96,8 @@ func fixForms(db *sql.DB, doc *goquery.Document, pageUrl *url.URL) (err error) {
return return
} }
func Get(db *sql.DB, linkUrl string) { func Get(db *sql.DB, jar *cookiejar.Jar, linkUrl string) {
client := &http.Client{} client := &http.Client{Jar: jar}
var lastUrl *url.URL var lastUrl *url.URL
@ -132,13 +133,6 @@ func Get(db *sql.DB, linkUrl string) {
storage.AddHistoryURL(db, linkUrl) storage.AddHistoryURL(db, linkUrl)
if len(resp.Cookies()) != 0 {
err = storage.AddCookies(db, lastUrl.Host, resp.Cookies())
if err != nil {
log.Fatalln("Add cookies:", err)
}
}
defer resp.Body.Close() defer resp.Body.Close()
utf8, err := charset.NewReader(resp.Body, resp.Header.Get("Content-Type")) utf8, err := charset.NewReader(resp.Body, resp.Header.Get("Content-Type"))
@ -175,7 +169,7 @@ func Get(db *sql.DB, linkUrl string) {
fmt.Println(text) fmt.Println(text)
} }
func Form(db *sql.DB, formID int64, formArgs []string) { func Form(db *sql.DB, jar *cookiejar.Jar, formID int64, formArgs []string) {
fields, formUrl, post, err := storage.GetForm(db, formID) fields, formUrl, post, err := storage.GetForm(db, formID)
if err != nil { if err != nil {
log.Fatalln("Get form:", err) log.Fatalln("Get form:", err)
@ -215,7 +209,7 @@ func Form(db *sql.DB, formID int64, formArgs []string) {
urlData.Set(name, value) urlData.Set(name, value)
} }
client := &http.Client{} client := &http.Client{Jar: jar}
var lastUrl *url.URL var lastUrl *url.URL
@ -238,19 +232,12 @@ func Form(db *sql.DB, formID int64, formArgs []string) {
var status int64 var status int64
fmt.Sscanf(resp.Status, "%d", &status) fmt.Sscanf(resp.Status, "%d", &status)
if status < 400 && len(resp.Cookies()) != 0 {
err = storage.AddCookies(db, lastUrl.Host, resp.Cookies())
if err != nil {
log.Fatalln("Add cookies:", err)
}
}
if status >= 300 && status < 400 { if status >= 300 && status < 400 {
Get(db, lastUrl.String()) Get(db, jar, lastUrl.String())
} }
} }
func Link(db *sql.DB, linkID int64, fromHistory bool) { func Link(db *sql.DB, jar *cookiejar.Jar, linkID int64, fromHistory bool) {
var linkUrl string var linkUrl string
var err error var err error
@ -265,7 +252,7 @@ func Link(db *sql.DB, linkID int64, fromHistory bool) {
log.Fatalln("Get link/history url error:", err) log.Fatalln("Get link/history url error:", err)
} }
Get(db, linkUrl) Get(db, jar, linkUrl)
} }
func History(db *sql.DB, argAmount, defaultAmount int64, all bool) { func History(db *sql.DB, argAmount, defaultAmount int64, all bool) {

19
main.go
View file

@ -9,11 +9,13 @@
package main package main
import ( import (
"os"
"strings" "strings"
"github.com/jollheef/wi/commands" "github.com/jollheef/wi/commands"
"github.com/jollheef/wi/storage" "github.com/jollheef/wi/storage"
cookiejar "github.com/juju/persistent-cookiejar"
kingpin "gopkg.in/alecthomas/kingpin.v2" kingpin "gopkg.in/alecthomas/kingpin.v2"
) )
@ -65,17 +67,26 @@ func main() {
} }
defer db.Close() defer db.Close()
os.Setenv("GOCOOKIES", "/tmp/wi.jar")
jar, err := cookiejar.New(nil)
if err != nil {
panic(err)
}
defer jar.Save()
switch kingpin.Parse() { switch kingpin.Parse() {
case "get": case "get":
commands.Get(db, *getUrl) commands.Get(db, jar, *getUrl)
case "form": case "form":
commands.Form(db, *formID, *formArgs) commands.Form(db, jar, *formID, *formArgs)
case "link": case "link":
commands.Link(db, *linkNo, *linkFromHistory) commands.Link(db, jar, *linkNo, *linkFromHistory)
case "history": case "history":
commands.History(db, *historyListItems, 20, *historyListAll) commands.History(db, *historyListItems, 20, *historyListAll)
case "search": case "search":
// FIXME: currenlty supports only Google // FIXME: currenlty supports only Google
commands.Get(db, "https://google.com/search?q="+strings.Join(*searchArgs, "+")) commands.Get(db, jar, "https://google.com/search?q="+strings.Join(*searchArgs, "+"))
} }
} }

View file

@ -9,61 +9,14 @@
package storage package storage
import ( import (
"bytes"
"database/sql" "database/sql"
"encoding/base64"
"encoding/gob"
"errors" "errors"
"log"
"net/http"
"reflect" "reflect"
"strings" "strings"
_ "github.com/mattn/go-sqlite3" _ "github.com/mattn/go-sqlite3"
) )
func toGOB64(m http.Cookie) string {
b := bytes.Buffer{}
e := gob.NewEncoder(&b)
err := e.Encode(m)
if err != nil {
panic(err)
}
return base64.StdEncoding.EncodeToString(b.Bytes())
}
func fromGOB64(str string) http.Cookie {
m := http.Cookie{}
by, err := base64.StdEncoding.DecodeString(str)
if err != nil {
panic(err)
}
b := bytes.Buffer{}
b.Write(by)
d := gob.NewDecoder(&b)
err = d.Decode(&m)
if err != nil {
panic(err)
}
return m
}
func serializeCookies(cookies []*http.Cookie) (s string) {
for _, c := range cookies {
s += toGOB64(*c) + " "
}
return
}
func deserializeCookies(s string) (cookies []*http.Cookie) {
gob64Objects := strings.Split(s, " ")
for _, g := range gob64Objects {
c := fromGOB64(g)
cookies = append(cookies, &c)
}
return
}
func OpenDB(path string) (db *sql.DB, err error) { func OpenDB(path string) (db *sql.DB, err error) {
db, err = sql.Open("sqlite3", path) db, err = sql.Open("sqlite3", path)
if err != nil { if err != nil {
@ -96,45 +49,6 @@ func OpenDB(path string) (db *sql.DB, err error) {
"( `id` INTEGER PRIMARY KEY AUTOINCREMENT, " + "( `id` INTEGER PRIMARY KEY AUTOINCREMENT, " +
" `post` BOOLEAN, " + " `post` BOOLEAN, " +
" `url` TEXT );") " `url` TEXT );")
if err != nil {
return
}
_, err = db.Exec("CREATE TABLE IF NOT EXISTS `cookies` " +
"( `id` INTEGER PRIMARY KEY AUTOINCREMENT, " +
" `url` BOOLEAN, " +
" `cookies` TEXT );")
return
}
func AddCookies(db *sql.DB, url string, cookies []*http.Cookie) (err error) {
log.Println("Add cookies", url, cookies)
stmt, err := db.Prepare("INSERT INTO `cookies` " +
"(`url`, `cookies`) VALUES ($1, $2);")
if err != nil {
return
}
defer stmt.Close()
_, err = stmt.Exec(url, serializeCookies(cookies))
return
}
func GetCookies(db *sql.DB, url string) (cookies []*http.Cookie, err error) {
stmt, err := db.Prepare("SELECT `cookies` FROM `cookies` WHERE url=$1;")
if err != nil {
return
}
defer stmt.Close()
var rawCookies string
err = stmt.QueryRow(url).Scan(&rawCookies)
if err != nil {
return
}
cookies = deserializeCookies(rawCookies)
return return
} }