Rework getAvailableIp

This commit reworks getAvailableIp with a "simpler" version that will
look for the first available IP address in our IP Prefix.

There is a couple of ideas behind this:

* Make the host IPs reasonably predictable and in within similar
  subnets, which should simplify ACLs for subnets
* The code is not random, but deterministic so we can have tests
* The code is a bit more understandable (no bit shift magic)
This commit is contained in:
Kristoffer Dalby 2021-08-02 21:57:45 +01:00
parent 309f868a21
commit b5841c8a8b
4 changed files with 170 additions and 45 deletions

107
utils.go
View file

@ -7,18 +7,11 @@ package headscale
import (
"crypto/rand"
"encoding/binary"
"encoding/json"
"errors"
"fmt"
"io"
"net"
"time"
mathrand "math/rand"
"golang.org/x/crypto/nacl/box"
"gorm.io/gorm"
"inet.af/netaddr"
"tailscale.com/types/wgkey"
)
@ -78,47 +71,73 @@ func encodeMsg(b []byte, pubKey *wgkey.Key, privKey *wgkey.Private) ([]byte, err
return msg, nil
}
func (h *Headscale) getAvailableIP() (*net.IP, error) {
i := 0
func (h *Headscale) getAvailableIP() (*netaddr.IP, error) {
ipPrefix := h.cfg.IPPrefix
usedIps, err := h.getUsedIPs()
if err != nil {
return nil, err
}
// for _, ip := range usedIps {
// nextIP := ip.Next()
// if !containsIPs(usedIps, nextIP) && ipPrefix.Contains(nextIP) {
// return &nextIP, nil
// }
// }
// // If there are no IPs in use, we are starting fresh and
// // can issue IPs from the beginning of the prefix.
// ip := ipPrefix.IP()
// return &ip, nil
// return nil, fmt.Errorf("failed to find any available IP in %s", ipPrefix)
// Get the first IP in our prefix
ip := ipPrefix.IP()
for {
ip, err := getRandomIP(h.cfg.IPPrefix)
if !ipPrefix.Contains(ip) {
return nil, fmt.Errorf("could not find any suitable IP in %s", ipPrefix)
}
if ip.IsZero() &&
ip.IsLoopback() {
continue
}
if !containsIPs(usedIps, ip) {
return &ip, nil
}
ip = ip.Next()
}
}
func (h *Headscale) getUsedIPs() ([]netaddr.IP, error) {
var addresses []string
h.db.Model(&Machine{}).Pluck("ip_address", &addresses)
ips := make([]netaddr.IP, len(addresses))
for index, addr := range addresses {
ip, err := netaddr.ParseIP(addr)
if err != nil {
return nil, err
return nil, fmt.Errorf("failed to parse ip from database, %w", err)
}
m := Machine{}
if result := h.db.First(&m, "ip_address = ?", ip.String()); errors.Is(result.Error, gorm.ErrRecordNotFound) {
return ip, nil
}
i++
if i == 100 { // really random number
break
}
}
return nil, errors.New(fmt.Sprintf("Could not find an available IP address in %s", h.cfg.IPPrefix.String()))
}
func getRandomIP(ipPrefix netaddr.IPPrefix) (*net.IP, error) {
mathrand.Seed(time.Now().Unix())
ipo, ipnet, err := net.ParseCIDR(ipPrefix.String())
if err == nil {
ip := ipo.To4()
// fmt.Println("In Randomize IPAddr: IP ", ip, " IPNET: ", ipnet)
// fmt.Println("Final address is ", ip)
// fmt.Println("Broadcast address is ", ipb)
// fmt.Println("Network address is ", ipn)
r := mathrand.Uint32()
ipRaw := make([]byte, 4)
binary.LittleEndian.PutUint32(ipRaw, r)
// ipRaw[3] = 254
// fmt.Println("ipRaw is ", ipRaw)
for i, v := range ipRaw {
// fmt.Println("IP Before: ", ip[i], " v is ", v, " Mask is: ", ipnet.Mask[i])
ip[i] = ip[i] + (v &^ ipnet.Mask[i])
// fmt.Println("IP After: ", ip[i])
}
// fmt.Println("FINAL IP: ", ip.String())
return &ip, nil
ips[index] = ip
}
return nil, err
return ips, nil
}
func containsIPs(ips []netaddr.IP, ip netaddr.IP) bool {
for _, v := range ips {
if v == ip {
return true
}
}
return false
}