diff --git a/services/credentials/repositories/bolt.go b/services/credentials/repositories/bolt.go index 66dc230..7ed090d 100644 --- a/services/credentials/repositories/bolt.go +++ b/services/credentials/repositories/bolt.go @@ -4,6 +4,7 @@ import ( "bytes" "context" "encoding/gob" + "fmt" "os" "sync" @@ -32,35 +33,25 @@ func (db BoltDB) GetAllMetadata(ctx context.Context, sourceHost string, errch ch defer close(mdch) err := db.bolt.View(func(tx *bbolt.Tx) error { - bkt := tx.Bucket([]byte(credentialsBkt)) - if bkt == nil { + bkt := getCredentialsBucket(tx) + if bkt.isEmpty { return nil } var wg sync.WaitGroup - err := bkt.ForEach(func(_, value []byte) error { - wg.Add(1) + c := bkt.hostPrimaryIndex.Cursor() - go func(value []byte) { - defer wg.Done() - - var cred types.Credential - - err := gobUnmarshal(value, &cred) - if err != nil { - errch <- err - return - } - - if sourceHost == "" || sourceHost == cred.SourceHost { - mdch <- cred.Metadata - } - }(value) - - return nil - }) - if err != nil { - return err + if sourceHost == "" { + for key, value := c.First(); key != nil; key, value = c.Next() { + wg.Add(1) + unmarshalAndSendCred(value, mdch, errch, &wg) + } + } else { + hostBytes := []byte(sourceHost) + for key, value := c.Seek(hostBytes); bytes.HasPrefix(key, hostBytes); key, value = c.Next() { + wg.Add(1) + unmarshalAndSendCred(value, mdch, errch, &wg) + } } wg.Wait() @@ -76,10 +67,24 @@ func (db BoltDB) GetAllMetadata(ctx context.Context, sourceHost string, errch ch return mdch } +func unmarshalAndSendCred(value []byte, mdch chan<- types.Metadata, errch chan<- error, wg *sync.WaitGroup) { + defer wg.Done() + + var cred types.Credential + + err := gobUnmarshal(value, &cred) + if err != nil { + errch <- err + return + } + + mdch <- cred.Metadata +} + func (db BoltDB) Get(ctx context.Context, id string) (output types.Credential, err error) { err = db.bolt.View(func(tx *bbolt.Tx) error { - bkt := tx.Bucket([]byte(credentialsBkt)) - if bkt == nil { + bkt := getCredentialsBucket(tx) + if bkt.isEmpty { return nil } @@ -95,10 +100,23 @@ func (db BoltDB) Get(ctx context.Context, id string) (output types.Credential, e } func (db BoltDB) Put(ctx context.Context, c types.Credential) (err error) { - err = db.bolt.Update(func(tx *bbolt.Tx) error { - bkt, err := tx.CreateBucketIfNotExists([]byte(credentialsBkt)) - if err != nil { - return err + return db.bolt.Update(func(tx *bbolt.Tx) error { + bkt := getCredentialsBucket(tx) + bkt.createIfNotExists() + + value := bkt.Get([]byte(c.ID)) + if value != nil { + var cred types.Credential + if err = gobUnmarshal(value, &cred); err != nil { + return err + } + + if err = bkt.Delete([]byte(c.ID)); err != nil { + return err + } + if err = bkt.hostPrimaryIndex.Delete([]byte(genHostPrimaryIdxKey(cred))); err != nil { + return err + } } value, err := gobMarshal(c) @@ -106,26 +124,74 @@ func (db BoltDB) Put(ctx context.Context, c types.Credential) (err error) { return err } + if err = bkt.hostPrimaryIndex.Put([]byte(genHostPrimaryIdxKey(c)), value); err != nil { + return err + } + return bkt.Put([]byte(c.ID), value) }) - - return err } func (db BoltDB) Delete(ctx context.Context, id string) (err error) { - err = db.bolt.Update(func(tx *bbolt.Tx) error { - bkt := tx.Bucket([]byte(credentialsBkt)) - if bkt == nil { + return db.bolt.Update(func(tx *bbolt.Tx) error { + bkt := getCredentialsBucket(tx) + if bkt.isEmpty { return nil } + value := bkt.Get([]byte(id)) + if value == nil { + return nil + } + + var cred types.Credential + if err = gobUnmarshal(value, &cred); err != nil { + return err + } + + if err = bkt.hostPrimaryIndex.Delete([]byte(genHostPrimaryIdxKey(cred))); err != nil { + return err + } + return bkt.Delete([]byte(id)) }) - - return err } -const credentialsBkt = "credentials" +const keyCredentialsBkt = "credentials" +const keyHostAndPrimaryIdx = "sourceHost-primary" + +func getCredentialsBucket(tx *bbolt.Tx) credentialsBucket { + bkt := credentialsBucket{ + Bucket: tx.Bucket([]byte(keyCredentialsBkt)), + tx: tx, + } + bkt.isEmpty = bkt.Bucket == nil + + if !bkt.isEmpty { + bkt.hostPrimaryIndex = bkt.Bucket.Bucket([]byte(keyHostAndPrimaryIdx)) + } + + return bkt +} + +type credentialsBucket struct { + *bbolt.Bucket + tx *bbolt.Tx + hostPrimaryIndex *bbolt.Bucket + isEmpty bool +} + +func (bkt *credentialsBucket) createIfNotExists() { + if bkt.isEmpty { + bkt.Bucket, _ = bkt.tx.CreateBucket([]byte(keyCredentialsBkt)) + bkt.hostPrimaryIndex, _ = bkt.CreateBucket([]byte(keyHostAndPrimaryIdx)) + bkt.isEmpty = false + } +} + +func genHostPrimaryIdxKey(cred types.Credential) string { + return fmt.Sprintf("%s-%s-%s", cred.SourceHost, cred.Primary, cred.ID) +} func gobMarshal(v interface{}) (bs []byte, err error) { buf := bytes.NewBuffer(nil) diff --git a/services/credentials/transport/grpc_server.go b/services/credentials/transport/grpc_server.go index 7d64647..56e354a 100644 --- a/services/credentials/transport/grpc_server.go +++ b/services/credentials/transport/grpc_server.go @@ -5,6 +5,7 @@ import ( "strings" "github.com/go-kit/kit/log" + "github.com/go-kit/kit/transport" "github.com/go-kit/kit/transport/grpc" "google.golang.org/grpc/codes" "google.golang.org/grpc/status" @@ -20,31 +21,31 @@ func NewGRPCServer(svc types.Service, logger log.Logger) GRPCServer { endpoints.MakeGetAllMetadataEndpoint(svc), decodeSourceHostRequest, encodeMetadataStreamResponse, - grpc.ServerErrorLogger(logger), + grpc.ServerErrorHandler(transport.NewLogErrorHandler(logger)), ), get: grpc.NewServer( endpoints.MakeGetEndpoint(svc), decodeIdRequest, encodeCredentialResponse, - grpc.ServerErrorLogger(logger), + grpc.ServerErrorHandler(transport.NewLogErrorHandler(logger)), ), create: grpc.NewServer( endpoints.MakeCreateEndpoint(svc), decodeCredentialRequest, encodeCredentialResponse, - grpc.ServerErrorLogger(logger), + grpc.ServerErrorHandler(transport.NewLogErrorHandler(logger)), ), update: grpc.NewServer( endpoints.MakeUpdateEndpoint(svc), decodeUpdateRequest, encodeCredentialResponse, - grpc.ServerErrorLogger(logger), + grpc.ServerErrorHandler(transport.NewLogErrorHandler(logger)), ), delete: grpc.NewServer( endpoints.MakeDeleteEndpoint(svc), decodeIdRequest, noOp, - grpc.ServerErrorLogger(logger), + grpc.ServerErrorHandler(transport.NewLogErrorHandler(logger)), ), } } @@ -58,13 +59,12 @@ type GRPCServer struct { } func (s GRPCServer) GetAllMetadata(r *protobuf.SourceHostRequest, srv protobuf.Credentials_GetAllMetadataServer) (err error) { - defer func() { err = handlerGRPCError(err) }() - var i interface{} ctx := srv.Context() ctx, i, err = s.getAllMetadata.ServeGRPC(ctx, *r) if err != nil { + err = handlerGRPCError(err) return err } @@ -74,14 +74,19 @@ receiveLoop: for { select { case <-ctx.Done(): + err = ctx.Err() break receiveLoop case err = <-mds.Errors: - break receiveLoop + if err != nil { + err = handlerGRPCError(err) + break receiveLoop + } case md, ok := <-mds.Metadata: if !ok { break receiveLoop } if err = srv.Send(&md); err != nil { + err = handlerGRPCError(err) break receiveLoop } } diff --git a/services/migrations/201907191_redis_to_bolt.go b/services/migrations/201907191_redis_to_bolt.go index a3e3416..718211b 100644 --- a/services/migrations/201907191_redis_to_bolt.go +++ b/services/migrations/201907191_redis_to_bolt.go @@ -18,7 +18,7 @@ const keyCredentials = "credentials" func main() { redisHost := pflag.StringP("redis-host", "r", "127.0.0.1:6379", "specify the redis host") - boltFile := pflag.StringP("bolt-file", "b", "./data/bolt.db", "specify the bolt DB file") + boltFile := pflag.StringP("bolt-file", "f", "./data/bolt.db", "specify the bolt DB file") help := pflag.BoolP("help", "h", false, "see help") pflag.Parse() diff --git a/services/migrations/201907230_index_host_and_primary.go b/services/migrations/201907230_index_host_and_primary.go new file mode 100644 index 0000000..163d2d9 --- /dev/null +++ b/services/migrations/201907230_index_host_and_primary.go @@ -0,0 +1,110 @@ +package main + +import ( + "bytes" + "encoding/gob" + "errors" + "fmt" + "sync" + + "github.com/spf13/pflag" + "go.etcd.io/bbolt" + + "github.com/mitchell/selfpass/services/credentials/types" + "github.com/mitchell/selfpass/services/migrations/migration" +) + +const keyCredentialsBkt = "credentials" +const keyHostAndPrimaryIdx = "sourceHost-primary" + +func main() { + file := pflag.StringP("file", "f", "./data/bolt.db", "specify the bolt db file") + help := pflag.BoolP("help", "h", false, "see help") + pflag.Parse() + + if *help { + pflag.PrintDefaults() + return + } + + db, err := bbolt.Open(*file, 0600, nil) + migration.Check(err) + + fmt.Println("Beginning migration...") + + creds := make(chan types.Credential, 1) + errs := make(chan error, 1) + + go func() { + defer close(creds) + + var wg sync.WaitGroup + + errs <- db.View(func(tx *bbolt.Tx) error { + bkt := tx.Bucket([]byte(keyCredentialsBkt)) + if bkt == nil { + return errors.New("no credentials bucket") + } + + return bkt.ForEach(func(_, value []byte) error { + wg.Add(1) + + go func(value []byte) { + defer wg.Done() + + reader := bytes.NewReader(value) + + var cred types.Credential + errs <- gob.NewDecoder(reader).Decode(&cred) + + creds <- cred + }(value) + + return nil + }) + }) + + wg.Wait() + }() + + go func() { + defer close(errs) + + var wg sync.WaitGroup + + for cred := range creds { + key := fmt.Sprintf("%s-%s-%s", cred.SourceHost, cred.Primary, cred.ID) + + fmt.Printf("Adding credential %s to index as %s.\n", cred.ID, key) + + wg.Add(1) + go func(key string, cred types.Credential) { + defer wg.Done() + + buf := bytes.NewBuffer(nil) + migration.Check(gob.NewEncoder(buf).Encode(cred)) + + value := buf.Bytes() + + errs <- db.Batch(func(tx *bbolt.Tx) error { + credBkt := tx.Bucket([]byte(keyCredentialsBkt)) + + bkt, err := credBkt.CreateBucketIfNotExists([]byte(keyHostAndPrimaryIdx)) + if err != nil { + return err + } + + return bkt.Put([]byte(key), value) + }) + }(key, cred) + } + + wg.Wait() + }() + + for err = range errs { + migration.Check(err) + } + + fmt.Println("Migration done.") +}