Преглед изворни кода

Refactoring work and tests.

Kelly Norton пре 10 година
родитељ
комит
334286dcab
4 измењених фајлова са 241 додато и 113 уклоњено
  1. 1 1
      Makefile
  2. 152 0
      context/context.go
  3. 77 0
      context/context_test.go
  4. 11 112
      main.go

+ 1 - 1
Makefile

@@ -3,5 +3,5 @@ ALL: bindata.go
 .build/bin/go-bindata:
 	GOPATH=$(shell pwd)/.build go get github.com/jteeuwen/go-bindata/...
 
-bindata.go: .build/bin/go-bindata $(wildcard pub/**)
+bindata.go: .build/bin/go-bindata $(wildcard pub/**/*)
 	$< -o $@ -pkg main -prefix pub -nomemcopy pub/...

+ 152 - 0
context/context.go

@@ -0,0 +1,152 @@
+package context
+
+import (
+	"bytes"
+	"encoding/binary"
+	"io"
+	"io/ioutil"
+	"os"
+	"path/filepath"
+	"sync"
+	"time"
+
+	"github.com/syndtr/goleveldb/leveldb"
+	"github.com/syndtr/goleveldb/leveldb/opt"
+)
+
+const (
+	routesDbFilename        = "routes.db"
+	idLogFilename           = "id"
+	idBatchSize      uint64 = 1000
+)
+
+// Route ...
+type Route struct {
+	URL  string
+	Time time.Time
+}
+
+//
+func (o *Route) write(w io.Writer) error {
+	if err := binary.Write(w, binary.LittleEndian, o.Time.UnixNano()); err != nil {
+		return err
+	}
+
+	if _, err := w.Write([]byte(o.URL)); err != nil {
+		return err
+	}
+
+	return nil
+}
+
+//
+func (o *Route) read(r io.Reader) error {
+	var t int64
+	if err := binary.Read(r, binary.LittleEndian, &t); err != nil {
+		return err
+	}
+
+	b, err := ioutil.ReadAll(r)
+	if err != nil {
+		return err
+	}
+
+	o.URL = string(b)
+	o.Time = time.Unix(0, t)
+	return nil
+}
+
+// Context ...
+type Context struct {
+	path string
+	db   *leveldb.DB
+	lck  sync.Mutex
+	id   uint64
+}
+
+// Open ...
+func Open(path string) (*Context, error) {
+	if _, err := os.Stat(path); err != nil {
+		if err := os.MkdirAll(path, os.ModePerm); err != nil {
+			return nil, err
+		}
+	}
+
+	// open the database
+	db, err := leveldb.OpenFile(filepath.Join(path, routesDbFilename), nil)
+	if err != nil {
+		return nil, err
+	}
+
+	c := &Context{
+		path: path,
+		db:   db,
+	}
+
+	// make sure we have an id log file
+	if _, err := os.Stat(filepath.Join(path, idLogFilename)); err != nil {
+		if err := c.commit(idBatchSize); err != nil {
+			return nil, err
+		}
+	}
+
+	return c, nil
+}
+
+// Get ...
+func (c *Context) Get(name string) (*Route, error) {
+	val, err := c.db.Get([]byte(name), nil)
+	if err != nil {
+		return nil, err
+	}
+
+	rt := &Route{}
+	if err := rt.read(bytes.NewBuffer(val)); err != nil {
+		return nil, err
+	}
+
+	return rt, nil
+}
+
+// Put ...
+func (c *Context) Put(key string, rt *Route) error {
+	var buf bytes.Buffer
+	if err := rt.write(&buf); err != nil {
+		return err
+	}
+
+	return c.db.Put([]byte(key), buf.Bytes(), &opt.WriteOptions{Sync: true})
+}
+
+func (c *Context) commit(id uint64) error {
+	w, err := os.Create(filepath.Join(c.path, idLogFilename))
+	if err != nil {
+		return err
+	}
+	defer w.Close()
+
+	if err := binary.Write(w, binary.LittleEndian, id); err != nil {
+		return err
+	}
+
+	return w.Sync()
+}
+
+// NextID ...
+func (c *Context) NextID() (uint64, error) {
+	c.lck.Lock()
+	defer c.lck.Unlock()
+
+	// when we hit a batch boundary, we will commit all ids until the next
+	// boundary. If we crash, we'll just throw away a batch of ids in the worst
+	// case.
+	if c.id%idBatchSize == 0 {
+		if err := c.commit(c.id + idBatchSize); err != nil {
+			return 0, err
+		}
+	}
+
+	c.id++
+
+	return c.id, nil
+}

+ 77 - 0
context/context_test.go

@@ -0,0 +1,77 @@
+package context
+
+import (
+	"io/ioutil"
+	"os"
+	"path/filepath"
+	"testing"
+	"time"
+
+	"github.com/syndtr/goleveldb/leveldb"
+)
+
+func TestGetPut(t *testing.T) {
+	tmp, err := ioutil.TempDir("", "")
+	if err != nil {
+		t.Fatal(err)
+	}
+	defer os.RemoveAll(tmp)
+
+	ctx, err := Open(filepath.Join(tmp, "data"))
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	if _, err := ctx.Get("not_found"); err != leveldb.ErrNotFound {
+		t.Fatalf("expected ErrNotFound, got \"%v\"", err)
+	}
+
+	a := &Route{
+		URL:  "http://www.kellegous.com/",
+		Time: time.Now(),
+	}
+
+	if err := ctx.Put("key", a); err != nil {
+		t.Fatal(err)
+	}
+
+	b, err := ctx.Get("key")
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	if b.URL != a.URL {
+		t.Fatalf("expected URL of %s, got %s", a.URL, b.URL)
+	}
+
+	if !b.Time.Equal(a.Time) {
+		t.Fatalf("expected Time of %s, got %s", a.Time, b.Time)
+	}
+}
+
+func TestNextID(t *testing.T) {
+	tmp, err := ioutil.TempDir("", "")
+	if err != nil {
+		t.Fatal(err)
+	}
+	defer os.RemoveAll(tmp)
+
+	ctx, err := Open(filepath.Join(tmp, "data"))
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	var e uint64 = 1
+	for i, n := 0, 2*int(idBatchSize)+5; i < n; i++ {
+		r, err := ctx.NextID()
+		if err != nil {
+			t.Fatal(err)
+		}
+
+		if r != e {
+			t.Fatalf("expected %d, got %d", e, r)
+		}
+
+		e++
+	}
+}

+ 11 - 112
main.go

@@ -7,115 +7,17 @@ import (
 	"encoding/json"
 	"flag"
 	"fmt"
-	"io"
-	"io/ioutil"
 	"log"
 	"math/rand"
 	"net/http"
-	"os"
-	"path/filepath"
 	"strings"
 	"time"
 
 	"github.com/syndtr/goleveldb/leveldb"
-)
 
-const (
-	dbFilename = "keys.db"
+	"github.com/kellegous/go/context"
 )
 
-type Route struct {
-	Url  string
-	Time time.Time
-}
-
-func (r *Route) Write(w io.Writer) error {
-	if err := binary.Write(w, binary.LittleEndian, r.Time.UnixNano()); err != nil {
-		return err
-	}
-
-	if _, err := w.Write([]byte(r.Url)); err != nil {
-		return err
-	}
-
-	return nil
-}
-
-func (o *Route) Read(r io.Reader) error {
-	var t int64
-	if err := binary.Read(r, binary.LittleEndian, &t); err != nil {
-		return err
-	}
-
-	b, err := ioutil.ReadAll(r)
-	if err != nil {
-		return err
-	}
-
-	o.Url = string(b)
-	o.Time = time.Unix(0, t)
-	return nil
-}
-
-type Context struct {
-	path string
-}
-
-func (c *Context) Init() error {
-	if _, err := os.Stat(c.path); err != nil {
-		if err := os.MkdirAll(c.path, os.ModePerm); err != nil {
-			return err
-		}
-	}
-
-	db, err := openDb(c.path)
-	if err != nil {
-		return err
-	}
-
-	return db.Close()
-}
-
-func openDb(path string) (*leveldb.DB, error) {
-	return leveldb.OpenFile(filepath.Join(path, dbFilename), nil)
-}
-
-func (c *Context) Get(key string) (*Route, error) {
-	db, err := openDb(c.path)
-	if err != nil {
-		return nil, err
-	}
-	defer db.Close()
-
-	val, err := db.Get([]byte(key), nil)
-	if err != nil {
-		return nil, err
-	}
-
-	r := &Route{}
-	if err := r.Read(bytes.NewBuffer(val)); err != nil {
-		return nil, err
-	}
-
-	return r, nil
-}
-
-func (c *Context) Put(key string, r *Route) error {
-	db, err := openDb(c.path)
-	if err != nil {
-		return err
-	}
-	defer db.Close()
-
-	var buf bytes.Buffer
-
-	if err := r.Write(&buf); err != nil {
-		return err
-	}
-
-	return db.Put([]byte(key), buf.Bytes(), nil)
-}
-
 func MakeName() string {
 	var buf bytes.Buffer
 	binary.Write(&buf, binary.LittleEndian, rand.Int63())
@@ -145,14 +47,14 @@ func WriteJsonError(w http.ResponseWriter, error string, status int) {
 	}, status)
 }
 
-func WriteJsonRoute(w http.ResponseWriter, name string, rt *Route) {
+func WriteJsonRoute(w http.ResponseWriter, name string, rt *context.Route) {
 	res := struct {
 		Name string    `json:"name"`
 		URL  string    `json:"url"`
 		Time time.Time `json:"time"`
 	}{
 		name,
-		rt.Url,
+		rt.URL,
 		rt.Time,
 	}
 
@@ -176,7 +78,7 @@ func ServeAsset(w http.ResponseWriter, r *http.Request, name string) {
 }
 
 type DefaultHandler struct {
-	ctx *Context
+	ctx *context.Context
 }
 
 func (h *DefaultHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
@@ -200,12 +102,12 @@ func (h *DefaultHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
 	}
 
 	http.Redirect(w, r,
-		rt.Url,
+		rt.URL,
 		http.StatusTemporaryRedirect)
 }
 
 type EditHandler struct {
-	ctx *Context
+	ctx *context.Context
 }
 
 func (h *EditHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
@@ -222,7 +124,7 @@ func (h *EditHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
 }
 
 type ApiHandler struct {
-	ctx *Context
+	ctx *context.Context
 }
 
 func (h *ApiHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
@@ -249,8 +151,8 @@ func (h *ApiHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
 			return
 		}
 
-		rt := Route{
-			Url:  req.URL,
+		rt := context.Route{
+			URL:  req.URL,
 			Time: time.Now(),
 		}
 
@@ -281,11 +183,8 @@ func main() {
 	flagAddr := flag.String("addr", ":8067", "addr")
 	flag.Parse()
 
-	ctx := &Context{
-		path: *flagData,
-	}
-
-	if err := ctx.Init(); err != nil {
+	ctx, err := context.Open(*flagData)
+	if err != nil {
 		log.Panic(err)
 	}