diff --git a/agent/main.go b/agent/main.go index 1f1a876..a51939a 100644 --- a/agent/main.go +++ b/agent/main.go @@ -1,7 +1,6 @@ package main import ( - "bytes" "errors" "io" "io/ioutil" @@ -10,9 +9,13 @@ import ( "os" "os/exec" "strings" + "sync" + "github.com/jaevor/go-nanoid" "github.com/labstack/echo/v4" "github.com/labstack/echo/v4/middleware" + gommon "github.com/labstack/gommon/log" + "golang.org/x/net/websocket" ) func main() { @@ -24,11 +27,15 @@ func main() { e := echo.New() e.Use(middleware.Logger()) - e.Use(middleware.Recover()) - e.Use(auth{secret: sec}.middleware) + e.Use(middleware.RecoverWithConfig(middleware.RecoverConfig{ + Skipper: middleware.DefaultSkipper, + LogLevel: gommon.ERROR, + })) + au := auth{secret: sec} e.GET("/hello", hello) - e.POST("/run", run) + e.POST("/run", runHandler, au.headerMiddleware) + e.GET("/run/:id", runLogHandler, au.queryMiddleware) e.Logger.Fatal(e.Start(":8080")) } @@ -38,56 +45,190 @@ func hello(c echo.Context) error { } type runResp struct { - Stdout string - Stderr string - Error string + RunId string } -func run(c echo.Context) error { +type run struct { + mutex sync.Mutex + clients []client + prev []byte + msgChan chan []byte +} + +func newRun() *run { + r := &run{ + msgChan: make(chan []byte), + } + go r.pump() + return r +} +func (r *run) newClient(ws *websocket.Conn) { + cl := client{conn: ws, msgChan: make(chan []byte)} + prev := func() []byte { + r.mutex.Lock() + defer r.mutex.Unlock() + r.clients = append(r.clients, cl) + return r.prev + }() + cl.pump(prev) +} +func (r *run) pump() { + for { + select { + case msg, ok := <-r.msgChan: + func() { + r.mutex.Lock() + defer r.mutex.Unlock() + if !ok { + for _, c := range r.clients { + close(c.msgChan) + } + return + } + r.prev = append(r.prev, msg...) + for _, c := range r.clients { + c.msgChan <- msg + } + }() + if !ok { + return + } + + } + } +} + +// Write implements io.Writer +func (r *run) Write(p []byte) (n int, err error) { + r.msgChan <- p + return len(p), nil +} + +// Close implements io.Closer +func (r *run) Close() error { + close(r.msgChan) + return nil +} + +type client struct { + conn *websocket.Conn + msgChan chan []byte +} + +func (c *client) pump(prev []byte) { + _, err := c.conn.Write(prev) + if err != nil { + close(c.msgChan) + return + } + for { + select { + case msg, ok := <-c.msgChan: + if !ok { + return + } + _, err := c.conn.Write(msg) + if err != nil { + close(c.msgChan) + return + } + } + } +} + +var runs = struct { + runs map[string]*run + sync.Mutex +}{runs: make(map[string]*run)} + +func runHandler(c echo.Context) error { f, err := os.CreateTemp("", "fireactions-agent-*") if err != nil { - panic(err) + log.Panicln(err) } if _, err = io.Copy(f, c.Request().Body); err != nil { - panic(err) + log.Panicln(err) } if err = f.Close(); err != nil { - panic(err) + log.Panicln(err) } if err = os.Chmod(f.Name(), 0700); err != nil { - panic(err) + log.Panicln(err) } + nanid, err := nanoid.Standard(21) + if err != nil { + log.Panicln(err) + } + + id := nanid() + run := newRun() + func() { + runs.Lock() + defer runs.Unlock() + runs.runs[id] = run + }() + cmd := exec.Command(f.Name()) - var stderr bytes.Buffer - var stdout bytes.Buffer - cmd.Stdout = &stdout - cmd.Stderr = &stderr - err = cmd.Run() - errorr := "" + cmd.Stdout = run + cmd.Stderr = run + err = cmd.Start() + go func() { + defer run.Close() + err := cmd.Wait() + if err != nil { + log.Println(err) + } + }() if err != nil { log.Println(err) - errorr = err.Error() + return c.String(http.StatusInternalServerError, "error running command") } return c.JSON(http.StatusOK, runResp{ - Stdout: string(stdout.Bytes()), - Stderr: string(stderr.Bytes()), - Error: errorr, + RunId: id, }) } +func runLogHandler(c echo.Context) error { + id := c.Param("id") + r, exists := func() (*run, bool) { + runs.Lock() + defer runs.Unlock() + r, ok := runs.runs[id] + return r, ok + }() + if !exists { + return c.String(http.StatusNotFound, "run not found") + } + websocket.Handler(func(ws *websocket.Conn) { + r.newClient(ws) + }).ServeHTTP(c.Response(), c.Request()) + return nil +} + type auth struct { secret string } -func (a auth) middleware(next echo.HandlerFunc) echo.HandlerFunc { +func (a auth) headerMiddleware(next echo.HandlerFunc) echo.HandlerFunc { return func(c echo.Context) error { s := strings.Split(c.Request().Header.Get("Authorization"), " ") if len(s) < 2 { return c.String(http.StatusBadRequest, "fuck no") } sec := s[1] + log.Println(len(sec), len(a.secret)) + if sec == a.secret { + return next(c) + } else { + return c.String(http.StatusUnauthorized, "wrong secret") + } + } +} +func (a auth) queryMiddleware(next echo.HandlerFunc) echo.HandlerFunc { + return func(c echo.Context) error { + sec := c.Request().URL.Query().Get("token") if sec == a.secret { return next(c) } else { @@ -104,6 +245,9 @@ func parseSecret() (string, error) { opts := strings.Split(string(byt), " ") for _, opt := range opts { s := strings.Split(opt, "=") + if len(s) != 2 { + continue + } key := s[0] val := s[1] if key == "fireactions.secret" {