Added middleware types and support; refactored names and route execution

This commit is contained in:
mitchelljfs 2018-10-16 22:08:23 -07:00
parent ecdb0bed84
commit 9626cf65d4
2 changed files with 116 additions and 92 deletions

136
router.go
View File

@ -3,7 +3,6 @@ package lambdarouter
import ( import (
"context" "context"
"encoding/json" "encoding/json"
"log"
"net/http" "net/http"
"strings" "strings"
@ -34,18 +33,23 @@ type APIGContext struct {
Err error Err error
} }
// APIGHandler is the interface a handler function must implement to be used // APIGHandler is the type a handler function must implement to be used
// with Get, Post, Put, Patch, and Delete. // with Get, Post, Put, Patch, and Delete.
type APIGHandler func(ctx *APIGContext) type APIGHandler func(ctx *APIGContext)
// APIGMiddleware is the function type that must me implemented to be appended
// to a route or to the APIGRouterConfig.Middleware attribute.
type APIGMiddleware func(APIGHandler) APIGHandler
// APIGRouter is the object that handlers build upon and is used in the end to respond. // APIGRouter is the object that handlers build upon and is used in the end to respond.
type APIGRouter struct { type APIGRouter struct {
request *events.APIGatewayProxyRequest request *events.APIGatewayProxyRequest
endpoints map[string]*radix.Tree routes map[string]*radix.Tree
params map[string]interface{} params map[string]interface{}
prefix string prefix string
headers map[string]string headers map[string]string
context context.Context context context.Context
middleware []APIGMiddleware
} }
// APIGRouterConfig is used as the input to NewAPIGRouter, request is your incoming // APIGRouterConfig is used as the input to NewAPIGRouter, request is your incoming
@ -56,6 +60,7 @@ type APIGRouterConfig struct {
Request *events.APIGatewayProxyRequest Request *events.APIGatewayProxyRequest
Prefix string Prefix string
Headers map[string]string Headers map[string]string
Middleware []APIGMiddleware
} }
// NOTE: Begin router methods. // NOTE: Begin router methods.
@ -64,7 +69,7 @@ type APIGRouterConfig struct {
func NewAPIGRouter(cfg *APIGRouterConfig) *APIGRouter { func NewAPIGRouter(cfg *APIGRouterConfig) *APIGRouter {
return &APIGRouter{ return &APIGRouter{
request: cfg.Request, request: cfg.Request,
endpoints: map[string]*radix.Tree{ routes: map[string]*radix.Tree{
post: radix.New(), post: radix.New(),
get: radix.New(), get: radix.New(),
put: radix.New(), put: radix.New(),
@ -75,49 +80,77 @@ func NewAPIGRouter(cfg *APIGRouterConfig) *APIGRouter {
prefix: cfg.Prefix, prefix: cfg.Prefix,
headers: cfg.Headers, headers: cfg.Headers,
context: cfg.Context, context: cfg.Context,
middleware: cfg.Middleware,
} }
} }
// Get creates a new get endpoint. // Get creates a new get endpoint.
func (r *APIGRouter) Get(route string, handlers ...APIGHandler) { func (r *APIGRouter) Get(path string, handler APIGHandler, middleware ...APIGMiddleware) {
r.addEndpoint(get, route, handlers) functions := routeFunctions{
handler: handler,
middleware: middleware,
}
r.addRoute(get, path, functions)
} }
// Post creates a new post endpoint. // Post creates a new post endpoint.
func (r *APIGRouter) Post(route string, handlers ...APIGHandler) { func (r *APIGRouter) Post(path string, handler APIGHandler, middleware ...APIGMiddleware) {
r.addEndpoint(post, route, handlers) functions := routeFunctions{
handler: handler,
middleware: middleware,
}
r.addRoute(post, path, functions)
} }
// Put creates a new put endpoint. // Put creates a new put endpoint.
func (r *APIGRouter) Put(route string, handlers ...APIGHandler) { func (r *APIGRouter) Put(path string, handler APIGHandler, middleware ...APIGMiddleware) {
r.addEndpoint(put, route, handlers) functions := routeFunctions{
handler: handler,
middleware: middleware,
}
r.addRoute(put, path, functions)
} }
// Patch creates a new patch endpoint // Patch creates a new patch endpoint
func (r *APIGRouter) Patch(route string, handlers ...APIGHandler) { func (r *APIGRouter) Patch(path string, handler APIGHandler, middleware ...APIGMiddleware) {
r.addEndpoint(patch, route, handlers) functions := routeFunctions{
handler: handler,
middleware: middleware,
}
r.addRoute(patch, path, functions)
} }
// Delete creates a new delete endpoint. // Delete creates a new delete endpoint.
func (r *APIGRouter) Delete(route string, handlers ...APIGHandler) { func (r *APIGRouter) Delete(path string, handler APIGHandler, middleware ...APIGMiddleware) {
r.addEndpoint(delete, route, handlers) functions := routeFunctions{
handler: handler,
middleware: middleware,
}
r.addRoute(delete, path, functions)
} }
// Respond returns an APIGatewayProxyResponse to respond to the lambda request. // Respond returns an APIGatewayProxyResponse to respond to the lambda request.
func (r *APIGRouter) Respond() events.APIGatewayProxyResponse { func (r *APIGRouter) Respond() events.APIGatewayProxyResponse {
var ( var (
handlersInterface interface{}
ok bool ok bool
status int respbytes []byte
respbody []byte response events.APIGatewayProxyResponse
err error routeInterface interface{}
endpointTree = r.endpoints[r.request.HTTPMethod] routeTrie = r.routes[r.request.HTTPMethod]
path = strings.TrimPrefix(r.request.Path, r.prefix) path = strings.TrimPrefix(r.request.Path, r.prefix)
inPath = path
response = events.APIGatewayProxyResponse{}
splitPath = stripSlashesAndSplit(path) splitPath = stripSlashesAndSplit(path)
ctx = &APIGContext{
Body: []byte(r.request.Body),
Path: r.request.PathParameters,
QryStr: r.request.QueryStringParameters,
Request: r.request,
Context: r.context,
}
) )
if r.request.RequestContext.Authorizer["claims"] != nil {
ctx.Claims = r.request.RequestContext.Authorizer["claims"].(map[string]interface{})
}
for p := range r.params { for p := range r.params {
if r.request.PathParameters[p] != "" { if r.request.PathParameters[p] != "" {
@ -132,72 +165,49 @@ func (r *APIGRouter) Respond() events.APIGatewayProxyResponse {
} }
path = "/" + strings.Join(splitPath, "/") path = "/" + strings.Join(splitPath, "/")
if handlersInterface, ok = endpointTree.Get(path); !ok { if routeInterface, ok = routeTrie.Get(path); !ok {
respbody, _ = json.Marshal(map[string]string{"error": "no route matching path found"}) respbytes, _ = json.Marshal(map[string]string{"error": "no route matching path found"})
response.StatusCode = http.StatusNotFound response.StatusCode = http.StatusNotFound
response.Body = string(respbody) response.Body = string(respbytes)
response.Headers = r.headers response.Headers = r.headers
return response return response
} }
handlers := handlersInterface.([]APIGHandler) functions := routeInterface.(routeFunctions)
for _, handler := range handlers { for _, m := range functions.middleware {
ctx := &APIGContext{ functions.handler = m(functions.handler)
Body: []byte(r.request.Body),
Path: r.request.PathParameters,
QryStr: r.request.QueryStringParameters,
Request: r.request,
Context: r.context,
} }
if r.request.RequestContext.Authorizer["claims"] != nil { for _, m := range r.middleware {
ctx.Claims = r.request.RequestContext.Authorizer["claims"].(map[string]interface{}) functions.handler = m(functions.handler)
} }
functions.handler(ctx)
handler(ctx) response.StatusCode = ctx.Status
status, respbody, err = ctx.respDeconstruct() response.Body = string(ctx.Body)
if err != nil {
respbody, _ = json.Marshal(map[string]string{"error": err.Error()})
if strings.Contains(err.Error(), "record not found") {
status = 204
} else if status != 204 && status < 400 {
status = 400
}
log.Printf("%v %v %v error: %v \n", r.request.HTTPMethod, inPath, status, err.Error())
log.Println("error causing body: " + r.request.Body)
response.StatusCode = status
response.Body = string(respbody)
response.Headers = r.headers
return response
}
}
response.StatusCode = status
response.Body = string(respbody)
response.Headers = r.headers response.Headers = r.headers
return response return response
} }
// NOTE: Begin helper functions. // NOTE: Begin helper functions.
type routeFunctions struct {
handler APIGHandler
middleware []APIGMiddleware
}
func stripSlashesAndSplit(s string) []string { func stripSlashesAndSplit(s string) []string {
s = strings.TrimPrefix(s, "/") s = strings.TrimPrefix(s, "/")
s = strings.TrimSuffix(s, "/") s = strings.TrimSuffix(s, "/")
return strings.Split(s, "/") return strings.Split(s, "/")
} }
func (ctx *APIGContext) respDeconstruct() (int, []byte, error) { func (r *APIGRouter) addRoute(method string, path string, functions routeFunctions) {
return ctx.Status, ctx.Body, ctx.Err if _, overwrite := r.routes[method].Insert(path, functions); overwrite {
}
func (r *APIGRouter) addEndpoint(method string, route string, handlers []APIGHandler) {
if _, overwrite := r.endpoints[method].Insert(route, handlers); overwrite {
panic("endpoint already existent") panic("endpoint already existent")
} }
rtearr := stripSlashesAndSplit(route) rtearr := stripSlashesAndSplit(path)
for _, v := range rtearr { for _, v := range rtearr {
if strings.HasPrefix(v, "{") { if strings.HasPrefix(v, "{") {
v = strings.TrimPrefix(v, "{") v = strings.TrimPrefix(v, "{")

View File

@ -1,8 +1,10 @@
package lambdarouter package lambdarouter
import ( import (
"encoding/json"
"errors" "errors"
"net/http" "net/http"
"strings"
"testing" "testing"
"github.com/aws/aws-lambda-go/events" "github.com/aws/aws-lambda-go/events"
@ -11,8 +13,21 @@ import (
func TestRouterSpec(t *testing.T) { func TestRouterSpec(t *testing.T) {
Convey("Given an instantiated router", t, func() { Convey("Given an instantiated router with an error reporting middleware, headers, and prefix", t, func() {
request := events.APIGatewayProxyRequest{} request := events.APIGatewayProxyRequest{}
errorReporter := func(handler APIGHandler) APIGHandler {
return func(ctx *APIGContext) {
handler(ctx)
if ctx.Err != nil {
ctx.Body, _ = json.Marshal(map[string]string{"error": ctx.Err.Error()})
if strings.Contains(ctx.Err.Error(), "record not found") {
ctx.Status = 204
} else if ctx.Status != 204 && ctx.Status < 400 {
ctx.Status = 400
}
}
}
}
rtr := NewAPIGRouter(&APIGRouterConfig{ rtr := NewAPIGRouter(&APIGRouterConfig{
Request: &request, Request: &request,
Prefix: "/shipping", Prefix: "/shipping",
@ -20,6 +35,9 @@ func TestRouterSpec(t *testing.T) {
"Access-Control-Allow-Origin": "*", "Access-Control-Allow-Origin": "*",
"Access-Control-Allow-Credentials": "true", "Access-Control-Allow-Credentials": "true",
}, },
Middleware: []APIGMiddleware{
errorReporter,
},
}) })
Convey("When the handler func does NOT return an error", func() { Convey("When the handler func does NOT return an error", func() {
@ -144,24 +162,20 @@ func TestRouterSpec(t *testing.T) {
}) })
Convey("When the handler func does return a status < 400", func() { Convey("When the handler func does return a status < 400", func() {
middlefunc1 := func(ctx *APIGContext) { middlefunc1 := func(handler APIGHandler) APIGHandler {
return func(ctx *APIGContext) {
ctx.Status = http.StatusOK ctx.Status = http.StatusOK
ctx.Body = []byte("hello")
ctx.Err = nil
}
middlefunc2 := func(ctx *APIGContext) {
ctx.Status = http.StatusOK
ctx.Body = []byte("hello")
ctx.Err = errors.New("bad request") ctx.Err = errors.New("bad request")
handler(ctx)
}
} }
hdlrfunc := func(ctx *APIGContext) { hdlrfunc := func(ctx *APIGContext) {
ctx.Status = http.StatusOK ctx.Status = http.StatusOK
ctx.Body = []byte("hello") ctx.Body = []byte("hello")
ctx.Err = nil
} }
Convey("And a Get handler expecting the pattern /listings/{id}/state/{event} is defined", func() { Convey("And a Get handler expecting the pattern /listings/{id}/state/{event} is defined", func() {
rtr.Get("/listings/{id}/state/{event}", middlefunc1, middlefunc2, hdlrfunc) rtr.Get("/listings/{id}/state/{event}", hdlrfunc, middlefunc1)
Convey("And the request matches the pattern and the path params are filled", func() { Convey("And the request matches the pattern and the path params are filled", func() {
request.HTTPMethod = http.MethodGet request.HTTPMethod = http.MethodGet