diff --git a/db.go b/db.go index fc539cc3..72386ca8 100644 --- a/db.go +++ b/db.go @@ -18,8 +18,10 @@ import ( ) const ( - dbVersion = "1" - errValueNotFound = Error("not found") + dbVersion = "1" + + errValueNotFound = Error("not found") + ErrCannotParsePrefix = Error("cannot parse prefix") ) // KV is a key-value store in a psql table. For future use... @@ -79,6 +81,65 @@ func (h *Headscale) initDB() error { } } + err = db.AutoMigrate(&Route{}) + if err != nil { + return err + } + + if db.Migrator().HasColumn(&Machine{}, "enabled_routes") { + log.Info().Msgf("Database has legacy enabled_routes column in machine, migrating...") + + type MachineAux struct { + ID uint64 + EnabledRoutes IPPrefixes + } + + machinesAux := []MachineAux{} + err := db.Table("machines").Select("id, enabled_routes").Scan(&machinesAux).Error + if err != nil { + log.Fatal().Err(err).Msg("Error accessing db") + } + for _, machine := range machinesAux { + for _, prefix := range machine.EnabledRoutes { + if err != nil { + log.Error(). + Err(err). + Str("enabled_route", prefix.String()). + Msg("Error parsing enabled_route") + continue + } + + err = db.Preload("Machine").Where("machine_id = ? AND prefix = ?", machine.ID, IPPrefix(prefix)).First(&Route{}).Error + if err == nil { + log.Info(). + Str("enabled_route", prefix.String()). + Msg("Route already migrated to new table, skipping") + continue + } + + route := Route{ + MachineID: machine.ID, + Advertised: true, + Enabled: true, + Prefix: IPPrefix(prefix), + } + if err := h.db.Create(&route).Error; err != nil { + log.Error().Err(err).Msg("Error creating route") + } else { + log.Info(). + Uint64("machine_id", route.MachineID). + Str("prefix", prefix.String()). + Msg("Route migrated") + } + } + } + + err = db.Migrator().DropColumn(&Machine{}, "enabled_routes") + if err != nil { + log.Error().Err(err).Msg("Error dropping enabled_routes column") + } + } + err = db.AutoMigrate(&Machine{}) if err != nil { return err @@ -264,6 +325,28 @@ func (hi HostInfo) Value() (driver.Value, error) { return string(bytes), err } +type IPPrefix netip.Prefix + +func (i *IPPrefix) Scan(destination interface{}) error { + switch value := destination.(type) { + case string: + prefix, err := netip.ParsePrefix(value) + if err != nil { + return err + } + *i = IPPrefix(prefix) + return nil + default: + return fmt.Errorf("%w: unexpected data type %T", ErrCannotParsePrefix, destination) + } +} + +// Value return json value, implement driver.Valuer interface. +func (i IPPrefix) Value() (driver.Value, error) { + prefixStr := netip.Prefix(i).String() + return prefixStr, nil +} + type IPPrefixes []netip.Prefix func (i *IPPrefixes) Scan(destination interface{}) error {