clean up use of log.Error where errors could be wrapped

Replace a lot of occurences of log.Error with fmt.Errorf,
bubbling the error up the chain instead.
This commit is contained in:
Kristoffer Dalby 2024-04-10 14:49:34 +02:00 committed by Juan Font
parent 58c94d2bd3
commit bf4fd078fc

View file

@ -2,6 +2,7 @@ package db
import ( import (
"errors" "errors"
"fmt"
"net/netip" "net/netip"
"github.com/juanfont/headscale/hscontrol/policy" "github.com/juanfont/headscale/hscontrol/policy"
@ -252,20 +253,20 @@ func DeleteRoute(
func deleteNodeRoutes(tx *gorm.DB, node *types.Node, isConnected types.NodeConnectedMap) ([]types.NodeID, error) { func deleteNodeRoutes(tx *gorm.DB, node *types.Node, isConnected types.NodeConnectedMap) ([]types.NodeID, error) {
routes, err := GetNodeRoutes(tx, node) routes, err := GetNodeRoutes(tx, node)
if err != nil { if err != nil {
return nil, err return nil, fmt.Errorf("getting node routes: %w", err)
} }
var changed []types.NodeID var changed []types.NodeID
for i := range routes { for i := range routes {
if err := tx.Unscoped().Delete(&routes[i]).Error; err != nil { if err := tx.Unscoped().Delete(&routes[i]).Error; err != nil {
return nil, err return nil, fmt.Errorf("deleting route(%d): %w", &routes[i].ID, err)
} }
// TODO(kradalby): This is a bit too aggressive, we could probably // TODO(kradalby): This is a bit too aggressive, we could probably
// figure out which routes needs to be failed over rather than all. // figure out which routes needs to be failed over rather than all.
chn, err := failoverRouteTx(tx, isConnected, &routes[i]) chn, err := failoverRouteTx(tx, isConnected, &routes[i])
if err != nil { if err != nil {
return changed, err return changed, fmt.Errorf("failing over route after delete: %w", err)
} }
if chn != nil { if chn != nil {
@ -410,10 +411,8 @@ func FailoverRouteIfAvailable(
isConnected types.NodeConnectedMap, isConnected types.NodeConnectedMap,
node *types.Node, node *types.Node,
) (*types.StateUpdate, error) { ) (*types.StateUpdate, error) {
log.Debug().Caller().Uint64("node.id", node.ID.Uint64()).Msgf("ROUTE DEBUG ENTERED FAILOVER")
nodeRoutes, err := GetNodeRoutes(tx, node) nodeRoutes, err := GetNodeRoutes(tx, node)
if err != nil { if err != nil {
log.Debug().Caller().Uint64("node.id", node.ID.Uint64()).Interface("nodeRoutes", nodeRoutes).Msgf("ROUTE DEBUG NO ROUTES")
return nil, nil return nil, nil
} }
@ -421,34 +420,31 @@ func FailoverRouteIfAvailable(
for _, nodeRoute := range nodeRoutes { for _, nodeRoute := range nodeRoutes {
routes, err := getRoutesByPrefix(tx, netip.Prefix(nodeRoute.Prefix)) routes, err := getRoutesByPrefix(tx, netip.Prefix(nodeRoute.Prefix))
if err != nil { if err != nil {
return nil, err return nil, fmt.Errorf("getting routes by prefix: %w", err)
} }
for _, route := range routes { for _, route := range routes {
if route.IsPrimary { if route.IsPrimary {
// if we have a primary route, and the node is connected // if we have a primary route, and the node is connected
// nothing needs to be done. // nothing needs to be done.
log.Debug().Caller().Uint64("node.id", node.ID.Uint64()).Uint64("route.node.id", route.Node.ID.Uint64()).Msgf("ROUTE DEBUG CHECKING IF ONLINE")
if isConnected[route.Node.ID] { if isConnected[route.Node.ID] {
log.Debug().Caller().Uint64("node.id", node.ID.Uint64()).Uint64("route.node.id", route.Node.ID.Uint64()).Msgf("ROUTE DEBUG IS ONLINE")
return nil, nil return nil, nil
} }
log.Debug().Caller().Uint64("node.id", node.ID.Uint64()).Uint64("route.node.id", route.Node.ID.Uint64()).Msgf("ROUTE DEBUG NOT ONLINE, FAILING OVER")
// if not, we need to failover the route // if not, we need to failover the route
changedIDs, err := failoverRouteTx(tx, isConnected, &route) failover := failoverRoute(isConnected, &route, routes)
if failover != nil {
failover.save(tx)
if err != nil { if err != nil {
return nil, err return nil, fmt.Errorf("saving failover routes: %w", err)
} }
if changedIDs != nil { changedNodes = append(changedNodes, failover.old.Node.ID, failover.new.Node.ID)
changedNodes = append(changedNodes, changedIDs...)
} }
} }
} }
} }
log.Debug().Caller().Uint64("node.id", node.ID.Uint64()).Interface("changedNodes", changedNodes).Msgf("ROUTE DEBUG")
if len(changedNodes) != 0 { if len(changedNodes) != 0 {
return &types.StateUpdate{ return &types.StateUpdate{
Type: types.StatePeerChanged, Type: types.StatePeerChanged,
@ -490,7 +486,7 @@ func failoverRouteTx(
routes, err := getRoutesByPrefix(tx, netip.Prefix(r.Prefix)) routes, err := getRoutesByPrefix(tx, netip.Prefix(r.Prefix))
if err != nil { if err != nil {
return nil, err return nil, fmt.Errorf("getting routes by prefix: %w", err)
} }
fo := failoverRoute(isConnected, r, routes) fo := failoverRoute(isConnected, r, routes)
@ -498,18 +494,9 @@ func failoverRouteTx(
return nil, nil return nil, nil
} }
err = tx.Save(fo.old).Error err = fo.save(tx)
if err != nil { if err != nil {
log.Error().Err(err).Msg("disabling old primary route") return nil, fmt.Errorf("saving failover route: %w", err)
return nil, err
}
err = tx.Save(fo.new).Error
if err != nil {
log.Error().Err(err).Msg("saving new primary route")
return nil, err
} }
log.Trace(). log.Trace().
@ -525,6 +512,20 @@ type failover struct {
new *types.Route new *types.Route
} }
func (f *failover) save(tx *gorm.DB) error {
err := tx.Save(f.old).Error
if err != nil {
return fmt.Errorf("saving old primary: %w", err)
}
err = tx.Save(f.new).Error
if err != nil {
return fmt.Errorf("saving new primary: %w", err)
}
return nil
}
func failoverRoute( func failoverRoute(
isConnected types.NodeConnectedMap, isConnected types.NodeConnectedMap,
routeToReplace *types.Route, routeToReplace *types.Route,
@ -603,13 +604,7 @@ func EnableAutoApprovedRoutes(
routes, err := GetNodeAdvertisedRoutes(tx, node) routes, err := GetNodeAdvertisedRoutes(tx, node)
if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) { if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) {
log.Error(). return fmt.Errorf("getting advertised routes for node(%s %d): %w", node.Hostname, node.ID, err)
Caller().
Err(err).
Str("node", node.Hostname).
Msg("Could not get advertised routes for node")
return err
} }
log.Trace().Interface("routes", routes).Msg("routes for autoapproving") log.Trace().Interface("routes", routes).Msg("routes for autoapproving")
@ -625,12 +620,7 @@ func EnableAutoApprovedRoutes(
netip.Prefix(advertisedRoute.Prefix), netip.Prefix(advertisedRoute.Prefix),
) )
if err != nil { if err != nil {
log.Err(err). return fmt.Errorf("failed to resolve autoApprovers for route(%d) for node(%s %d): %w", advertisedRoute.ID, node.Hostname, node.ID, err)
Str("advertisedRoute", advertisedRoute.String()).
Uint64("nodeId", node.ID.Uint64()).
Msg("Failed to resolve autoApprovers for advertised route")
return err
} }
log.Trace(). log.Trace().
@ -647,11 +637,7 @@ func EnableAutoApprovedRoutes(
// TODO(kradalby): figure out how to get this to depend on less stuff // TODO(kradalby): figure out how to get this to depend on less stuff
approvedIps, err := aclPolicy.ExpandAlias(types.Nodes{node}, approvedAlias) approvedIps, err := aclPolicy.ExpandAlias(types.Nodes{node}, approvedAlias)
if err != nil { if err != nil {
log.Err(err). return fmt.Errorf("expanding alias %q for autoApprovers: %w", approvedAlias, err)
Str("alias", approvedAlias).
Msg("Failed to expand alias when processing autoApprovers policy")
return err
} }
// approvedIPs should contain all of node's IPs if it matches the rule, so check for first // approvedIPs should contain all of node's IPs if it matches the rule, so check for first
@ -665,12 +651,7 @@ func EnableAutoApprovedRoutes(
for _, approvedRoute := range approvedRoutes { for _, approvedRoute := range approvedRoutes {
_, err := EnableRoute(tx, uint64(approvedRoute.ID)) _, err := EnableRoute(tx, uint64(approvedRoute.ID))
if err != nil { if err != nil {
log.Err(err). return fmt.Errorf("enabling approved route(%d): %w", approvedRoute.ID, err)
Str("approvedRoute", approvedRoute.String()).
Uint64("nodeId", node.ID.Uint64()).
Msg("Failed to enable approved route")
return err
} }
} }