Pollers + config

This commit is contained in:
Cat /dev/Nulo 2023-01-11 20:38:19 -03:00
parent c28859eb84
commit 59ff8f2a23
4 changed files with 156 additions and 19 deletions

72
config.go Normal file
View file

@ -0,0 +1,72 @@
package main
import (
"encoding/json"
"errors"
"net/http"
"os"
"nulo.in/ddnser/nameservers"
)
type config struct {
Ip string `json:"ip,omitempty"`
Every int `json:"every,omitempty"`
Domains []domain `json:"domains"`
}
type domain struct {
Type string `json:"type"`
Name string `json:"name"`
// TODO: lograr que esto sea un coso de propiedades arbitrario
Key string `json:"key"`
}
func parseConfig(path string) (config config, err error) {
bytes, err := os.ReadFile(path)
if err != nil {
return
}
err = json.Unmarshal(bytes, &config)
return
}
type State struct {
HTTPClient http.Client
Ip string
// Every defines how often (in seconds) poll for DDNS.
// -1 means never poll.
Every int
Domains []Domain
}
type Domain struct {
Name string
NameServer nameservers.NameServer
}
func LoadConfig(path string) (state State, err error) {
parsed, err := parseConfig(path)
if err != nil {
return
}
state.Ip = parsed.Ip
state.Every = parsed.Every
// if not defined or 0, set to default
if state.Every == 0 {
state.Every = 300
}
for _, d := range parsed.Domains {
switch d.Type {
case "njalla ddns":
state.Domains = append(state.Domains, Domain{
Name: d.Name,
NameServer: &nameservers.Njalla{HTTPClient: &state.HTTPClient, Key: d.Key},
})
default:
err = errors.New("I don't know the service type " + d.Type)
return
}
}
return
}

74
main.go
View file

@ -1,22 +1,80 @@
package main
import (
"context"
"log"
"nulo.in/ddnser/nameservers"
"os"
"os/exec"
"strings"
"time"
)
type refreshChan chan string
func main() {
njalla := nameservers.Njalla{Key: "yourkey"}
record, err := njalla.SetRecord("estoesprueba.nulo.in", "")
if len(os.Args) < 2 {
log.Fatal("Must provide config file path as first argument")
}
config, err := LoadConfig(os.Args[1])
if err != nil {
log.Fatal(err)
}
log.Println(record)
record, err = njalla.SetRecord("estoesprueba.nulo.in", "1.1.1.1")
refr := make(refreshChan)
errch := make(chan error)
go poller(config.Every, refr)
go ipPoller(refr, errch)
for {
select {
case reason := <-refr:
ctx, cancel := context.WithTimeout(context.Background(), time.Duration(10)*time.Second)
defer cancel()
log.Printf("Running because of %s", reason)
for _, d := range config.Domains {
record, err := d.NameServer.SetRecord(ctx, d.Name, config.Ip)
if err != nil {
log.Fatal(err)
log.Println(err)
continue
}
log.Println(record)
log.Printf("[%s] Set to %s", d.Name, record)
}
case err := <-errch:
log.Fatal(err)
}
}
}
func poller(every int, refr refreshChan) {
for {
time.Sleep(time.Duration(every) * time.Minute)
refr <- "poll"
}
}
func ipPoller(refr refreshChan, errch chan error) {
var lastWatched string
for {
cmd := exec.Command("ip", "address")
out, err := cmd.Output()
if err != nil {
errch <- err
return
}
var addrs string
lines := strings.Split(string(out), "\n")
for _, line := range lines {
prefix := " inet "
if strings.Index(line, prefix) == 0 {
last := strings.Index(line[len(prefix):], "/")
ip := line[len(prefix) : len(prefix)+last]
addrs = addrs + ip + "/"
}
}
if addrs != lastWatched {
refr <- "ip changed"
lastWatched = addrs
}
time.Sleep(time.Duration(2) * time.Second)
}
}

View file

@ -1,5 +1,7 @@
package nameservers
import "context"
type NameServer interface {
SetRecord(domain string, overrideIp string) (string, error)
SetRecord(ctx context.Context, domain string, overrideIp string) (string, error)
}

View file

@ -1,6 +1,7 @@
package nameservers
import (
"context"
"encoding/json"
"errors"
"io"
@ -11,7 +12,7 @@ import (
)
type Njalla struct {
httpClient http.Client
HTTPClient *http.Client
Key string
}
@ -25,7 +26,7 @@ type njallaResponse struct {
Value njallaValue `json:"value"`
}
func (n *Njalla) SetRecord(domain string, overrideIp string) (string, error) {
func (n *Njalla) SetRecord(ctx context.Context, domain string, overrideIp string) (string, error) {
u, _ := url.Parse("https://njal.la/update/")
values := url.Values{
"h": {domain},
@ -39,7 +40,11 @@ func (n *Njalla) SetRecord(domain string, overrideIp string) (string, error) {
}
u.RawQuery = values.Encode()
resp, err := n.httpClient.Get(u.String())
req, err := http.NewRequestWithContext(ctx, "GET", u.String(), nil)
if err != nil {
return "", err
}
resp, err := n.HTTPClient.Do(req)
if err != nil {
return "", err
}
@ -52,7 +57,7 @@ func (n *Njalla) SetRecord(domain string, overrideIp string) (string, error) {
if err != nil {
return "", err
}
log.Println(string(body))
log.Printf("[njalla ddns] Response: %s", string(body))
if resp.StatusCode < 200 || resp.StatusCode > 299 {
return "", errors.New("Not nice status code: " + strconv.Itoa(resp.StatusCode) + " with body: " + string(body))