Add migration from Machine.EnabledRoutes to the new Route table

This commit is contained in:
Juan Font 2022-11-23 18:52:09 +00:00
parent c78403691c
commit 28ef3145c5

87
db.go
View file

@ -18,8 +18,10 @@ import (
) )
const ( const (
dbVersion = "1" dbVersion = "1"
errValueNotFound = Error("not found")
errValueNotFound = Error("not found")
ErrCannotParsePrefix = Error("cannot parse prefix")
) )
// KV is a key-value store in a psql table. For future use... // KV is a key-value store in a psql table. For future use...
@ -79,6 +81,65 @@ func (h *Headscale) initDB() error {
} }
} }
err = db.AutoMigrate(&Route{})
if err != nil {
return err
}
if db.Migrator().HasColumn(&Machine{}, "enabled_routes") {
log.Info().Msgf("Database has legacy enabled_routes column in machine, migrating...")
type MachineAux struct {
ID uint64
EnabledRoutes IPPrefixes
}
machinesAux := []MachineAux{}
err := db.Table("machines").Select("id, enabled_routes").Scan(&machinesAux).Error
if err != nil {
log.Fatal().Err(err).Msg("Error accessing db")
}
for _, machine := range machinesAux {
for _, prefix := range machine.EnabledRoutes {
if err != nil {
log.Error().
Err(err).
Str("enabled_route", prefix.String()).
Msg("Error parsing enabled_route")
continue
}
err = db.Preload("Machine").Where("machine_id = ? AND prefix = ?", machine.ID, IPPrefix(prefix)).First(&Route{}).Error
if err == nil {
log.Info().
Str("enabled_route", prefix.String()).
Msg("Route already migrated to new table, skipping")
continue
}
route := Route{
MachineID: machine.ID,
Advertised: true,
Enabled: true,
Prefix: IPPrefix(prefix),
}
if err := h.db.Create(&route).Error; err != nil {
log.Error().Err(err).Msg("Error creating route")
} else {
log.Info().
Uint64("machine_id", route.MachineID).
Str("prefix", prefix.String()).
Msg("Route migrated")
}
}
}
err = db.Migrator().DropColumn(&Machine{}, "enabled_routes")
if err != nil {
log.Error().Err(err).Msg("Error dropping enabled_routes column")
}
}
err = db.AutoMigrate(&Machine{}) err = db.AutoMigrate(&Machine{})
if err != nil { if err != nil {
return err return err
@ -264,6 +325,28 @@ func (hi HostInfo) Value() (driver.Value, error) {
return string(bytes), err return string(bytes), err
} }
type IPPrefix netip.Prefix
func (i *IPPrefix) Scan(destination interface{}) error {
switch value := destination.(type) {
case string:
prefix, err := netip.ParsePrefix(value)
if err != nil {
return err
}
*i = IPPrefix(prefix)
return nil
default:
return fmt.Errorf("%w: unexpected data type %T", ErrCannotParsePrefix, destination)
}
}
// Value return json value, implement driver.Valuer interface.
func (i IPPrefix) Value() (driver.Value, error) {
prefixStr := netip.Prefix(i).String()
return prefixStr, nil
}
type IPPrefixes []netip.Prefix type IPPrefixes []netip.Prefix
func (i *IPPrefixes) Scan(destination interface{}) error { func (i *IPPrefixes) Scan(destination interface{}) error {