mirror of
https://github.com/juanfont/headscale.git
synced 2024-12-01 19:23:05 +00:00
make parse destination string into a func
Signed-off-by: Kristoffer Dalby <kristoffer@tailscale.com>
This commit is contained in:
parent
717abe89c1
commit
2675ff4b94
2 changed files with 97 additions and 29 deletions
|
@ -375,9 +375,39 @@ func (pol *ACLPolicy) getNetPortRangeFromDestination(
|
|||
machines types.Machines,
|
||||
needsWildcard bool,
|
||||
) ([]tailcfg.NetPortRange, error) {
|
||||
var tokens []string
|
||||
alias, port, err := parseDestination(dest)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
log.Trace().Str("destination", dest).Msg("generating policy destination")
|
||||
expanded, err := pol.ExpandAlias(
|
||||
machines,
|
||||
alias,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
ports, err := expandPorts(port, needsWildcard)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
dests := []tailcfg.NetPortRange{}
|
||||
for _, dest := range expanded.Prefixes() {
|
||||
for _, port := range *ports {
|
||||
pr := tailcfg.NetPortRange{
|
||||
IP: dest.String(),
|
||||
Ports: port,
|
||||
}
|
||||
dests = append(dests, pr)
|
||||
}
|
||||
}
|
||||
|
||||
return dests, nil
|
||||
}
|
||||
|
||||
func parseDestination(dest string) (string, string, error) {
|
||||
var tokens []string
|
||||
|
||||
// Check if there is a IPv4/6:Port combination, IPv6 has more than
|
||||
// three ":".
|
||||
|
@ -397,7 +427,7 @@ func (pol *ACLPolicy) getNetPortRangeFromDestination(
|
|||
if maybeIPv6, err := netip.ParseAddr(filteredMaybeIPv6Str); err != nil && !maybeIPv6.Is6() {
|
||||
log.Trace().Err(err).Msg("trying to parse as IPv6")
|
||||
|
||||
return nil, fmt.Errorf(
|
||||
return "", "", fmt.Errorf(
|
||||
"failed to parse destination, tokens %v: %w",
|
||||
tokens,
|
||||
ErrInvalidPortFormat,
|
||||
|
@ -407,8 +437,6 @@ func (pol *ACLPolicy) getNetPortRangeFromDestination(
|
|||
}
|
||||
}
|
||||
|
||||
log.Trace().Strs("tokens", tokens).Msg("generating policy destination")
|
||||
|
||||
var alias string
|
||||
// We can have here stuff like:
|
||||
// git-server:*
|
||||
|
@ -424,30 +452,7 @@ func (pol *ACLPolicy) getNetPortRangeFromDestination(
|
|||
alias = fmt.Sprintf("%s:%s", tokens[0], tokens[1])
|
||||
}
|
||||
|
||||
expanded, err := pol.ExpandAlias(
|
||||
machines,
|
||||
alias,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
ports, err := expandPorts(tokens[len(tokens)-1], needsWildcard)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
dests := []tailcfg.NetPortRange{}
|
||||
for _, dest := range expanded.Prefixes() {
|
||||
for _, port := range *ports {
|
||||
pr := tailcfg.NetPortRange{
|
||||
IP: dest.String(),
|
||||
Ports: port,
|
||||
}
|
||||
dests = append(dests, pr)
|
||||
}
|
||||
}
|
||||
|
||||
return dests, nil
|
||||
return alias, tokens[len(tokens)-1], nil
|
||||
}
|
||||
|
||||
// parseProtocol reads the proto field of the ACL and generates a list of
|
||||
|
|
|
@ -2557,3 +2557,66 @@ func TestSSHRules(t *testing.T) {
|
|||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseDestination(t *testing.T) {
|
||||
tests := []struct {
|
||||
dest string
|
||||
wantAlias string
|
||||
wantPort string
|
||||
}{
|
||||
{
|
||||
dest: "git-server:*",
|
||||
wantAlias: "git-server",
|
||||
wantPort: "*",
|
||||
},
|
||||
{
|
||||
dest: "192.168.1.0/24:22",
|
||||
wantAlias: "192.168.1.0/24",
|
||||
wantPort: "22",
|
||||
},
|
||||
{
|
||||
dest: "192.168.1.1:22",
|
||||
wantAlias: "192.168.1.1",
|
||||
wantPort: "22",
|
||||
},
|
||||
{
|
||||
dest: "fd7a:115c:a1e0::2:22",
|
||||
wantAlias: "fd7a:115c:a1e0::2",
|
||||
wantPort: "22",
|
||||
},
|
||||
{
|
||||
dest: "fd7a:115c:a1e0::2/128:22",
|
||||
wantAlias: "fd7a:115c:a1e0::2/128",
|
||||
wantPort: "22",
|
||||
},
|
||||
{
|
||||
dest: "tag:montreal-webserver:80,443",
|
||||
wantAlias: "tag:montreal-webserver",
|
||||
wantPort: "80,443",
|
||||
},
|
||||
{
|
||||
dest: "tag:api-server:443",
|
||||
wantAlias: "tag:api-server",
|
||||
wantPort: "443",
|
||||
},
|
||||
{
|
||||
dest: "example-host-1:*",
|
||||
wantAlias: "example-host-1",
|
||||
wantPort: "*",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.dest, func(t *testing.T) {
|
||||
alias, port, _ := parseDestination(tt.dest)
|
||||
|
||||
if alias != tt.wantAlias {
|
||||
t.Errorf("unexpected alias: want(%s) != got(%s)", tt.wantAlias, alias)
|
||||
}
|
||||
|
||||
if port != tt.wantPort {
|
||||
t.Errorf("unexpected port: want(%s) != got(%s)", tt.wantPort, port)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in a new issue