diff --git a/router.go b/router.go index 7ef6a31..b2e1a08 100644 --- a/router.go +++ b/router.go @@ -3,7 +3,6 @@ package lambdarouter import ( "context" "encoding/json" - "log" "net/http" "strings" @@ -34,28 +33,34 @@ type APIGContext struct { Err error } -// APIGHandler is the interface a handler function must implement to be used +// APIGHandler is the type a handler function must implement to be used // with Get, Post, Put, Patch, and Delete. type APIGHandler func(ctx *APIGContext) +// APIGMiddleware is the function type that must me implemented to be appended +// to a route or to the APIGRouterConfig.Middleware attribute. +type APIGMiddleware func(APIGHandler) APIGHandler + // APIGRouter is the object that handlers build upon and is used in the end to respond. type APIGRouter struct { - request *events.APIGatewayProxyRequest - endpoints map[string]*radix.Tree - params map[string]interface{} - prefix string - headers map[string]string - context context.Context + request *events.APIGatewayProxyRequest + routes map[string]*radix.Tree + params map[string]interface{} + prefix string + headers map[string]string + context context.Context + middleware []APIGMiddleware } // APIGRouterConfig is used as the input to NewAPIGRouter, request is your incoming // apig request and prefix will be stripped of all incoming request paths. Headers // will be sent with all responses. type APIGRouterConfig struct { - Context context.Context - Request *events.APIGatewayProxyRequest - Prefix string - Headers map[string]string + Context context.Context + Request *events.APIGatewayProxyRequest + Prefix string + Headers map[string]string + Middleware []APIGMiddleware } // NOTE: Begin router methods. @@ -64,60 +69,88 @@ type APIGRouterConfig struct { func NewAPIGRouter(cfg *APIGRouterConfig) *APIGRouter { return &APIGRouter{ request: cfg.Request, - endpoints: map[string]*radix.Tree{ + routes: map[string]*radix.Tree{ post: radix.New(), get: radix.New(), put: radix.New(), patch: radix.New(), delete: radix.New(), }, - params: map[string]interface{}{}, - prefix: cfg.Prefix, - headers: cfg.Headers, - context: cfg.Context, + params: map[string]interface{}{}, + prefix: cfg.Prefix, + headers: cfg.Headers, + context: cfg.Context, + middleware: cfg.Middleware, } } // Get creates a new get endpoint. -func (r *APIGRouter) Get(route string, handlers ...APIGHandler) { - r.addEndpoint(get, route, handlers) +func (r *APIGRouter) Get(path string, handler APIGHandler, middleware ...APIGMiddleware) { + functions := routeFunctions{ + handler: handler, + middleware: middleware, + } + r.addRoute(get, path, functions) } // Post creates a new post endpoint. -func (r *APIGRouter) Post(route string, handlers ...APIGHandler) { - r.addEndpoint(post, route, handlers) +func (r *APIGRouter) Post(path string, handler APIGHandler, middleware ...APIGMiddleware) { + functions := routeFunctions{ + handler: handler, + middleware: middleware, + } + r.addRoute(post, path, functions) } // Put creates a new put endpoint. -func (r *APIGRouter) Put(route string, handlers ...APIGHandler) { - r.addEndpoint(put, route, handlers) +func (r *APIGRouter) Put(path string, handler APIGHandler, middleware ...APIGMiddleware) { + functions := routeFunctions{ + handler: handler, + middleware: middleware, + } + r.addRoute(put, path, functions) } // Patch creates a new patch endpoint -func (r *APIGRouter) Patch(route string, handlers ...APIGHandler) { - r.addEndpoint(patch, route, handlers) +func (r *APIGRouter) Patch(path string, handler APIGHandler, middleware ...APIGMiddleware) { + functions := routeFunctions{ + handler: handler, + middleware: middleware, + } + r.addRoute(patch, path, functions) } // Delete creates a new delete endpoint. -func (r *APIGRouter) Delete(route string, handlers ...APIGHandler) { - r.addEndpoint(delete, route, handlers) +func (r *APIGRouter) Delete(path string, handler APIGHandler, middleware ...APIGMiddleware) { + functions := routeFunctions{ + handler: handler, + middleware: middleware, + } + r.addRoute(delete, path, functions) } // Respond returns an APIGatewayProxyResponse to respond to the lambda request. func (r *APIGRouter) Respond() events.APIGatewayProxyResponse { var ( - handlersInterface interface{} - ok bool - status int - respbody []byte - err error + ok bool + respbytes []byte + response events.APIGatewayProxyResponse + routeInterface interface{} - endpointTree = r.endpoints[r.request.HTTPMethod] - path = strings.TrimPrefix(r.request.Path, r.prefix) - inPath = path - response = events.APIGatewayProxyResponse{} - splitPath = stripSlashesAndSplit(path) + routeTrie = r.routes[r.request.HTTPMethod] + path = strings.TrimPrefix(r.request.Path, r.prefix) + splitPath = stripSlashesAndSplit(path) + ctx = &APIGContext{ + Body: []byte(r.request.Body), + Path: r.request.PathParameters, + QryStr: r.request.QueryStringParameters, + Request: r.request, + Context: r.context, + } ) + if r.request.RequestContext.Authorizer["claims"] != nil { + ctx.Claims = r.request.RequestContext.Authorizer["claims"].(map[string]interface{}) + } for p := range r.params { if r.request.PathParameters[p] != "" { @@ -132,72 +165,49 @@ func (r *APIGRouter) Respond() events.APIGatewayProxyResponse { } path = "/" + strings.Join(splitPath, "/") - if handlersInterface, ok = endpointTree.Get(path); !ok { - respbody, _ = json.Marshal(map[string]string{"error": "no route matching path found"}) + if routeInterface, ok = routeTrie.Get(path); !ok { + respbytes, _ = json.Marshal(map[string]string{"error": "no route matching path found"}) response.StatusCode = http.StatusNotFound - response.Body = string(respbody) + response.Body = string(respbytes) response.Headers = r.headers return response } - handlers := handlersInterface.([]APIGHandler) + functions := routeInterface.(routeFunctions) - for _, handler := range handlers { - ctx := &APIGContext{ - Body: []byte(r.request.Body), - Path: r.request.PathParameters, - QryStr: r.request.QueryStringParameters, - Request: r.request, - Context: r.context, - } - if r.request.RequestContext.Authorizer["claims"] != nil { - ctx.Claims = r.request.RequestContext.Authorizer["claims"].(map[string]interface{}) - } - - handler(ctx) - status, respbody, err = ctx.respDeconstruct() - - if err != nil { - respbody, _ = json.Marshal(map[string]string{"error": err.Error()}) - if strings.Contains(err.Error(), "record not found") { - status = 204 - } else if status != 204 && status < 400 { - status = 400 - } - - log.Printf("%v %v %v error: %v \n", r.request.HTTPMethod, inPath, status, err.Error()) - log.Println("error causing body: " + r.request.Body) - response.StatusCode = status - response.Body = string(respbody) - response.Headers = r.headers - return response - } + for _, m := range functions.middleware { + functions.handler = m(functions.handler) } + for _, m := range r.middleware { + functions.handler = m(functions.handler) + } + functions.handler(ctx) - response.StatusCode = status - response.Body = string(respbody) + response.StatusCode = ctx.Status + response.Body = string(ctx.Body) response.Headers = r.headers return response } // NOTE: Begin helper functions. +type routeFunctions struct { + handler APIGHandler + middleware []APIGMiddleware +} + func stripSlashesAndSplit(s string) []string { s = strings.TrimPrefix(s, "/") s = strings.TrimSuffix(s, "/") return strings.Split(s, "/") } -func (ctx *APIGContext) respDeconstruct() (int, []byte, error) { - return ctx.Status, ctx.Body, ctx.Err -} - -func (r *APIGRouter) addEndpoint(method string, route string, handlers []APIGHandler) { - if _, overwrite := r.endpoints[method].Insert(route, handlers); overwrite { +func (r *APIGRouter) addRoute(method string, path string, functions routeFunctions) { + if _, overwrite := r.routes[method].Insert(path, functions); overwrite { panic("endpoint already existent") } - rtearr := stripSlashesAndSplit(route) + rtearr := stripSlashesAndSplit(path) for _, v := range rtearr { if strings.HasPrefix(v, "{") { v = strings.TrimPrefix(v, "{") diff --git a/router_test.go b/router_test.go index fe4e567..e75227c 100644 --- a/router_test.go +++ b/router_test.go @@ -1,8 +1,10 @@ package lambdarouter import ( + "encoding/json" "errors" "net/http" + "strings" "testing" "github.com/aws/aws-lambda-go/events" @@ -11,8 +13,21 @@ import ( func TestRouterSpec(t *testing.T) { - Convey("Given an instantiated router", t, func() { + Convey("Given an instantiated router with an error reporting middleware, headers, and prefix", t, func() { request := events.APIGatewayProxyRequest{} + errorReporter := func(handler APIGHandler) APIGHandler { + return func(ctx *APIGContext) { + handler(ctx) + if ctx.Err != nil { + ctx.Body, _ = json.Marshal(map[string]string{"error": ctx.Err.Error()}) + if strings.Contains(ctx.Err.Error(), "record not found") { + ctx.Status = 204 + } else if ctx.Status != 204 && ctx.Status < 400 { + ctx.Status = 400 + } + } + } + } rtr := NewAPIGRouter(&APIGRouterConfig{ Request: &request, Prefix: "/shipping", @@ -20,6 +35,9 @@ func TestRouterSpec(t *testing.T) { "Access-Control-Allow-Origin": "*", "Access-Control-Allow-Credentials": "true", }, + Middleware: []APIGMiddleware{ + errorReporter, + }, }) Convey("When the handler func does NOT return an error", func() { @@ -144,24 +162,20 @@ func TestRouterSpec(t *testing.T) { }) Convey("When the handler func does return a status < 400", func() { - middlefunc1 := func(ctx *APIGContext) { - ctx.Status = http.StatusOK - ctx.Body = []byte("hello") - ctx.Err = nil - } - middlefunc2 := func(ctx *APIGContext) { - ctx.Status = http.StatusOK - ctx.Body = []byte("hello") - ctx.Err = errors.New("bad request") + middlefunc1 := func(handler APIGHandler) APIGHandler { + return func(ctx *APIGContext) { + ctx.Status = http.StatusOK + ctx.Err = errors.New("bad request") + handler(ctx) + } } hdlrfunc := func(ctx *APIGContext) { ctx.Status = http.StatusOK ctx.Body = []byte("hello") - ctx.Err = nil } Convey("And a Get handler expecting the pattern /listings/{id}/state/{event} is defined", func() { - rtr.Get("/listings/{id}/state/{event}", middlefunc1, middlefunc2, hdlrfunc) + rtr.Get("/listings/{id}/state/{event}", hdlrfunc, middlefunc1) Convey("And the request matches the pattern and the path params are filled", func() { request.HTTPMethod = http.MethodGet