WIP: agent: websocket

This commit is contained in:
Cat /dev/Nulo 2023-06-03 12:16:34 -03:00
parent 87e1a509a8
commit e3995c1926

View file

@ -1,7 +1,6 @@
package main
import (
"bytes"
"errors"
"io"
"io/ioutil"
@ -10,9 +9,13 @@ import (
"os"
"os/exec"
"strings"
"sync"
"github.com/jaevor/go-nanoid"
"github.com/labstack/echo/v4"
"github.com/labstack/echo/v4/middleware"
gommon "github.com/labstack/gommon/log"
"golang.org/x/net/websocket"
)
func main() {
@ -24,11 +27,15 @@ func main() {
e := echo.New()
e.Use(middleware.Logger())
e.Use(middleware.Recover())
e.Use(auth{secret: sec}.middleware)
e.Use(middleware.RecoverWithConfig(middleware.RecoverConfig{
Skipper: middleware.DefaultSkipper,
LogLevel: gommon.ERROR,
}))
au := auth{secret: sec}
e.GET("/hello", hello)
e.POST("/run", run)
e.POST("/run", runHandler, au.headerMiddleware)
e.GET("/run/:id", runLogHandler, au.queryMiddleware)
e.Logger.Fatal(e.Start(":8080"))
}
@ -38,56 +45,190 @@ func hello(c echo.Context) error {
}
type runResp struct {
Stdout string
Stderr string
Error string
RunId string
}
func run(c echo.Context) error {
type run struct {
mutex sync.Mutex
clients []client
prev []byte
msgChan chan []byte
}
func newRun() *run {
r := &run{
msgChan: make(chan []byte),
}
go r.pump()
return r
}
func (r *run) newClient(ws *websocket.Conn) {
cl := client{conn: ws, msgChan: make(chan []byte)}
prev := func() []byte {
r.mutex.Lock()
defer r.mutex.Unlock()
r.clients = append(r.clients, cl)
return r.prev
}()
cl.pump(prev)
}
func (r *run) pump() {
for {
select {
case msg, ok := <-r.msgChan:
func() {
r.mutex.Lock()
defer r.mutex.Unlock()
if !ok {
for _, c := range r.clients {
close(c.msgChan)
}
return
}
r.prev = append(r.prev, msg...)
for _, c := range r.clients {
c.msgChan <- msg
}
}()
if !ok {
return
}
}
}
}
// Write implements io.Writer
func (r *run) Write(p []byte) (n int, err error) {
r.msgChan <- p
return len(p), nil
}
// Close implements io.Closer
func (r *run) Close() error {
close(r.msgChan)
return nil
}
type client struct {
conn *websocket.Conn
msgChan chan []byte
}
func (c *client) pump(prev []byte) {
_, err := c.conn.Write(prev)
if err != nil {
close(c.msgChan)
return
}
for {
select {
case msg, ok := <-c.msgChan:
if !ok {
return
}
_, err := c.conn.Write(msg)
if err != nil {
close(c.msgChan)
return
}
}
}
}
var runs = struct {
runs map[string]*run
sync.Mutex
}{runs: make(map[string]*run)}
func runHandler(c echo.Context) error {
f, err := os.CreateTemp("", "fireactions-agent-*")
if err != nil {
panic(err)
log.Panicln(err)
}
if _, err = io.Copy(f, c.Request().Body); err != nil {
panic(err)
log.Panicln(err)
}
if err = f.Close(); err != nil {
panic(err)
log.Panicln(err)
}
if err = os.Chmod(f.Name(), 0700); err != nil {
panic(err)
log.Panicln(err)
}
nanid, err := nanoid.Standard(21)
if err != nil {
log.Panicln(err)
}
id := nanid()
run := newRun()
func() {
runs.Lock()
defer runs.Unlock()
runs.runs[id] = run
}()
cmd := exec.Command(f.Name())
var stderr bytes.Buffer
var stdout bytes.Buffer
cmd.Stdout = &stdout
cmd.Stderr = &stderr
err = cmd.Run()
errorr := ""
cmd.Stdout = run
cmd.Stderr = run
err = cmd.Start()
go func() {
defer run.Close()
err := cmd.Wait()
if err != nil {
log.Println(err)
}
}()
if err != nil {
log.Println(err)
errorr = err.Error()
return c.String(http.StatusInternalServerError, "error running command")
}
return c.JSON(http.StatusOK, runResp{
Stdout: string(stdout.Bytes()),
Stderr: string(stderr.Bytes()),
Error: errorr,
RunId: id,
})
}
func runLogHandler(c echo.Context) error {
id := c.Param("id")
r, exists := func() (*run, bool) {
runs.Lock()
defer runs.Unlock()
r, ok := runs.runs[id]
return r, ok
}()
if !exists {
return c.String(http.StatusNotFound, "run not found")
}
websocket.Handler(func(ws *websocket.Conn) {
r.newClient(ws)
}).ServeHTTP(c.Response(), c.Request())
return nil
}
type auth struct {
secret string
}
func (a auth) middleware(next echo.HandlerFunc) echo.HandlerFunc {
func (a auth) headerMiddleware(next echo.HandlerFunc) echo.HandlerFunc {
return func(c echo.Context) error {
s := strings.Split(c.Request().Header.Get("Authorization"), " ")
if len(s) < 2 {
return c.String(http.StatusBadRequest, "fuck no")
}
sec := s[1]
log.Println(len(sec), len(a.secret))
if sec == a.secret {
return next(c)
} else {
return c.String(http.StatusUnauthorized, "wrong secret")
}
}
}
func (a auth) queryMiddleware(next echo.HandlerFunc) echo.HandlerFunc {
return func(c echo.Context) error {
sec := c.Request().URL.Query().Get("token")
if sec == a.secret {
return next(c)
} else {
@ -104,6 +245,9 @@ func parseSecret() (string, error) {
opts := strings.Split(string(byt), " ")
for _, opt := range opts {
s := strings.Split(opt, "=")
if len(s) != 2 {
continue
}
key := s[0]
val := s[1]
if key == "fireactions.secret" {