Changed handlers to veriadic parameter inputs, allowing for multiple

middleware functions
This commit is contained in:
mitchelljfs 2018-07-16 16:44:38 -07:00
parent 2005f453c6
commit 0a73e3f069
2 changed files with 55 additions and 42 deletions

View File

@ -66,35 +66,38 @@ func NewAPIGRouter(r *events.APIGatewayProxyRequest, svcprefix string) *APIGRout
} }
// Get creates a new get endpoint. // Get creates a new get endpoint.
func (r *APIGRouter) Get(route string, handler APIGHandler) { func (r *APIGRouter) Get(route string, handlers ...APIGHandler) {
r.addEndpoint(get, route, handler) r.addEndpoint(get, route, handlers)
} }
// Post creates a new post endpoint. // Post creates a new post endpoint.
func (r *APIGRouter) Post(route string, handler APIGHandler) { func (r *APIGRouter) Post(route string, handlers ...APIGHandler) {
r.addEndpoint(post, route, handler) r.addEndpoint(post, route, handlers)
} }
// Put creates a new put endpoint. // Put creates a new put endpoint.
func (r *APIGRouter) Put(route string, handler APIGHandler) { func (r *APIGRouter) Put(route string, handlers ...APIGHandler) {
r.addEndpoint(put, route, handler) r.addEndpoint(put, route, handlers)
} }
// Patch creates a new patch endpoint // Patch creates a new patch endpoint
func (r *APIGRouter) Patch(route string, handler APIGHandler) { func (r *APIGRouter) Patch(route string, handlers ...APIGHandler) {
r.addEndpoint(patch, route, handler) r.addEndpoint(patch, route, handlers)
} }
// Delete creates a new delete endpoint. // Delete creates a new delete endpoint.
func (r *APIGRouter) Delete(route string, handler APIGHandler) { func (r *APIGRouter) Delete(route string, handlers ...APIGHandler) {
r.addEndpoint(delete, route, handler) r.addEndpoint(delete, route, handlers)
} }
// Respond returns an APIGatewayProxyResponse to respond to the lambda request. // Respond returns an APIGatewayProxyResponse to respond to the lambda request.
func (r *APIGRouter) Respond() events.APIGatewayProxyResponse { func (r *APIGRouter) Respond() events.APIGatewayProxyResponse {
var ( var (
handlerInterface interface{} handlersInterface interface{}
ok bool ok bool
status int
respbody []byte
err error
endpointTree = r.endpoints[r.request.HTTPMethod] endpointTree = r.endpoints[r.request.HTTPMethod]
path = strings.TrimPrefix(r.request.Path, "/"+r.svcprefix) path = strings.TrimPrefix(r.request.Path, "/"+r.svcprefix)
@ -109,7 +112,7 @@ func (r *APIGRouter) Respond() events.APIGatewayProxyResponse {
} }
} }
if handlerInterface, ok = endpointTree.Get(path); !ok { if handlersInterface, ok = endpointTree.Get(path); !ok {
respbody, _ := json.Marshal(map[string]string{"error": "no route matching path found"}) respbody, _ := json.Marshal(map[string]string{"error": "no route matching path found"})
response.StatusCode = http.StatusNotFound response.StatusCode = http.StatusNotFound
@ -117,32 +120,34 @@ func (r *APIGRouter) Respond() events.APIGatewayProxyResponse {
return response return response
} }
handler := handlerInterface.(APIGHandler) handlers := handlersInterface.([]APIGHandler)
req := &APIGRequest{ for _, handler := range handlers {
Path: r.request.PathParameters, req := &APIGRequest{
QryStr: r.request.QueryStringParameters, Path: r.request.PathParameters,
Request: r.request, QryStr: r.request.QueryStringParameters,
} Request: r.request,
if r.request.RequestContext.Authorizer["claims"] != nil {
req.Claims = r.request.RequestContext.Authorizer["claims"].(map[string]interface{})
}
res := &APIGResponse{}
handler(req, res)
status, respbody, err := res.deconstruct()
if err != nil {
respbody, _ := json.Marshal(map[string]string{"error": err.Error()})
if strings.Contains(err.Error(), "record not found") {
status = 204
} else if status < 400 {
status = 400
} }
if r.request.RequestContext.Authorizer["claims"] != nil {
req.Claims = r.request.RequestContext.Authorizer["claims"].(map[string]interface{})
}
res := &APIGResponse{}
response.StatusCode = status handler(req, res)
response.Body = string(respbody) status, respbody, err = res.deconstruct()
return response
if err != nil {
respbody, _ := json.Marshal(map[string]string{"error": err.Error()})
if strings.Contains(err.Error(), "record not found") {
status = 204
} else if status < 400 {
status = 400
}
response.StatusCode = status
response.Body = string(respbody)
return response
}
} }
response.StatusCode = status response.StatusCode = status
@ -161,8 +166,8 @@ func (res *APIGResponse) deconstruct() (int, []byte, error) {
return res.Status, res.Body, res.Err return res.Status, res.Body, res.Err
} }
func (r *APIGRouter) addEndpoint(method string, route string, handler APIGHandler) { func (r *APIGRouter) addEndpoint(method string, route string, handlers []APIGHandler) {
if _, overwrite := r.endpoints[method].Insert(route, handler); overwrite { if _, overwrite := r.endpoints[method].Insert(route, handlers); overwrite {
panic("endpoint already existent") panic("endpoint already existent")
} }

View File

@ -20,7 +20,6 @@ func TestRouterSpec(t *testing.T) {
res.Status = http.StatusOK res.Status = http.StatusOK
res.Body = []byte("hello") res.Body = []byte("hello")
res.Err = nil res.Err = nil
} }
Convey("And a Get handler expecting the pattern /orders/filter/by_user/{id} is defined", func() { Convey("And a Get handler expecting the pattern /orders/filter/by_user/{id} is defined", func() {
@ -128,15 +127,24 @@ func TestRouterSpec(t *testing.T) {
}) })
Convey("When the handler func does return a status < 400", func() { Convey("When the handler func does return a status < 400", func() {
hdlrfunc := func(req *APIGRequest, res *APIGResponse) { middlefunc1 := func(req *APIGRequest, res *APIGResponse) {
res.Status = http.StatusOK
res.Body = []byte("hello")
res.Err = nil
}
middlefunc2 := func(req *APIGRequest, res *APIGResponse) {
res.Status = http.StatusOK res.Status = http.StatusOK
res.Body = []byte("hello") res.Body = []byte("hello")
res.Err = errors.New("bad request") res.Err = errors.New("bad request")
}
hdlrfunc := func(req *APIGRequest, res *APIGResponse) {
res.Status = http.StatusOK
res.Body = []byte("hello")
res.Err = nil
} }
Convey("And a Get handler expecting the pattern /orders/filter/by_user/{id} is defined", func() { Convey("And a Get handler expecting the pattern /orders/filter/by_user/{id} is defined", func() {
rtr.Get("/orders/filter/by_user/{id}", hdlrfunc) rtr.Get("/orders/filter/by_user/{id}", middlefunc1, middlefunc2, hdlrfunc)
Convey("And the request matches the pattern and the path params are filled", func() { Convey("And the request matches the pattern and the path params are filled", func() {
request.HTTPMethod = http.MethodGet request.HTTPMethod = http.MethodGet