Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion docs/stackit_auth_login.md
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,9 @@ stackit auth login [flags]
### Options

```
-h, --help Help for "stackit auth login"
-h, --help Help for "stackit auth login"
--port int The port on which the callback server will listen to. By default, it tries to bind a port between 8000 and 8020.
When a value is specified, it will only try to use the specified port. Valid values are within the range of 8000 to 8020.
```

### Options inherited from parent commands
Expand Down
48 changes: 44 additions & 4 deletions internal/cmd/auth/login/login.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,24 @@ package login
import (
"fmt"

"github.com/stackitcloud/stackit-cli/internal/pkg/types"

"github.com/stackitcloud/stackit-cli/internal/pkg/args"
"github.com/stackitcloud/stackit-cli/internal/pkg/auth"
"github.com/stackitcloud/stackit-cli/internal/pkg/examples"
"github.com/stackitcloud/stackit-cli/internal/pkg/flags"
"github.com/stackitcloud/stackit-cli/internal/pkg/print"
"github.com/stackitcloud/stackit-cli/internal/pkg/types"

"github.com/spf13/cobra"
)

const (
portFlag = "port"
)

type inputModel struct {
Port *int
}

func NewCmd(params *types.CmdParams) *cobra.Command {
cmd := &cobra.Command{
Use: "login",
Expand All @@ -25,8 +34,16 @@ func NewCmd(params *types.CmdParams) *cobra.Command {
`Login to the STACKIT CLI. This command will open a browser window where you can login to your STACKIT account`,
"$ stackit auth login"),
),
RunE: func(_ *cobra.Command, _ []string) error {
err := auth.AuthorizeUser(params.Printer, false)
RunE: func(cmd *cobra.Command, args []string) error {
model, err := parseInput(params.Printer, cmd, args)
if err != nil {
return err
}

err = auth.AuthorizeUser(params.Printer, auth.UserAuthConfig{
IsReauthentication: false,
Port: model.Port,
})
if err != nil {
return fmt.Errorf("authorization failed: %w", err)
}
Expand All @@ -36,5 +53,28 @@ func NewCmd(params *types.CmdParams) *cobra.Command {
return nil
},
}
configureFlags(cmd)
return cmd
}

func configureFlags(cmd *cobra.Command) {
cmd.Flags().Int(portFlag, 0,
"The port on which the callback server will listen to. By default, it tries to bind a port between 8000 and 8020.\n"+
"When a value is specified, it will only try to use the specified port. Valid values are within the range of 8000 to 8020.",
)
}

func parseInput(p *print.Printer, cmd *cobra.Command, _ []string) (*inputModel, error) {
port := flags.FlagToIntPointer(p, cmd, portFlag)
// For the CLI client only callback URLs with localhost:[8000-8020] are valid. Additional callbacks must be enabled in the backend.
if port != nil && (*port < 8000 || *port > 8020) {
return nil, fmt.Errorf("port must be between 8000 and 8020")
}

model := inputModel{
Port: port,
}

p.DebugInputModel(model)
return &model, nil
}
93 changes: 93 additions & 0 deletions internal/cmd/auth/login/login_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
package login

import (
"testing"

"github.com/stackitcloud/stackit-cli/internal/pkg/testutils"
"github.com/stackitcloud/stackit-cli/internal/pkg/utils"
)

func fixtureFlagValues(mods ...func(flagValues map[string]string)) map[string]string {
flagValues := map[string]string{
portFlag: "8010",
}
for _, mod := range mods {
mod(flagValues)
}
return flagValues
}

func fixtureInputModel(mods ...func(model *inputModel)) *inputModel {
model := &inputModel{
Port: utils.Ptr(8010),
}
for _, mod := range mods {
mod(model)
}
return model
}

func TestParseInput(t *testing.T) {
tests := []struct {
description string
flagValues map[string]string
argValues []string
isValid bool
expectedModel *inputModel
}{
{
description: "base",
flagValues: fixtureFlagValues(),
isValid: true,
expectedModel: fixtureInputModel(),
},
{
description: "no values",
flagValues: map[string]string{},
isValid: true,
expectedModel: &inputModel{
Port: nil,
},
},
{
description: "lower limit",
flagValues: map[string]string{
portFlag: "8000",
},
isValid: true,
expectedModel: &inputModel{
Port: utils.Ptr(8000),
},
},
{
description: "below lower limit is not valid ",
flagValues: map[string]string{
portFlag: "7999",
},
isValid: false,
},
{
description: "upper limit",
flagValues: map[string]string{
portFlag: "8020",
},
isValid: true,
expectedModel: &inputModel{
Port: utils.Ptr(8020),
},
},
{
description: "above upper limit is not valid ",
flagValues: map[string]string{
portFlag: "8021",
},
isValid: false,
},
}

for _, tt := range tests {
t.Run(tt.description, func(t *testing.T) {
testutils.TestParseInput(t, NewCmd, parseInput, tt.expectedModel, tt.argValues, tt.flagValues, tt.isValid)
})
}
}
7 changes: 5 additions & 2 deletions internal/pkg/auth/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ type tokenClaims struct {
//
// If the user was logged in and the user session expired, reauthorizeUserRoutine is called to reauthenticate the user again.
// If the environment variable STACKIT_ACCESS_TOKEN is set this token is used instead.
func AuthenticationConfig(p *print.Printer, reauthorizeUserRoutine func(p *print.Printer, _ bool) error) (authCfgOption sdkConfig.ConfigurationOption, err error) {
func AuthenticationConfig(p *print.Printer, reauthorizeUserRoutine func(p *print.Printer, _ UserAuthConfig) error) (authCfgOption sdkConfig.ConfigurationOption, err error) {
// Get access token from env and use this if present
accessToken := os.Getenv(envAccessTokenName)
if accessToken != "" {
Expand Down Expand Up @@ -70,7 +70,10 @@ func AuthenticationConfig(p *print.Printer, reauthorizeUserRoutine func(p *print
case AUTH_FLOW_USER_TOKEN:
p.Debug(print.DebugLevel, "authenticating using user token")
if userSessionExpired {
err = reauthorizeUserRoutine(p, true)
err = reauthorizeUserRoutine(p, UserAuthConfig{
IsReauthentication: true,
Port: nil,
})
if err != nil {
return nil, fmt.Errorf("user login: %w", err)
}
Expand Down
2 changes: 1 addition & 1 deletion internal/pkg/auth/auth_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -191,7 +191,7 @@ func TestAuthenticationConfig(t *testing.T) {
}

reauthorizeUserCalled := false
reauthenticateUser := func(_ *print.Printer, _ bool) error {
reauthenticateUser := func(_ *print.Printer, _ UserAuthConfig) error {
if reauthorizeUserCalled {
t.Errorf("user reauthorized more than once")
}
Expand Down
23 changes: 18 additions & 5 deletions internal/pkg/auth/user_login.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,12 +45,19 @@ type InputValues struct {
Logo string
}

type UserAuthConfig struct {
// IsReauthentication defines if an expired user session should be renewed
IsReauthentication bool
// Port defines which port should be used for the UserAuthFlow callback
Port *int
}

type apiClient interface {
Do(req *http.Request) (*http.Response, error)
}

// AuthorizeUser implements the PKCE OAuth2 flow.
func AuthorizeUser(p *print.Printer, isReauthentication bool) error {
func AuthorizeUser(p *print.Printer, authConfig UserAuthConfig) error {
idpWellKnownConfig, err := retrieveIDPWellKnownConfig(p)
if err != nil {
return err
Expand All @@ -68,7 +75,7 @@ func AuthorizeUser(p *print.Printer, isReauthentication bool) error {
}
}

if isReauthentication {
if authConfig.IsReauthentication {
err := p.PromptForEnter("Your session has expired, press Enter to login again...")
if err != nil {
return err
Expand All @@ -79,8 +86,14 @@ func AuthorizeUser(p *print.Printer, isReauthentication bool) error {
var listener net.Listener
var listenerErr error
var port int
for i := range configuredPortRange {
port = defaultPort + i
startingPort := defaultPort
portRange := configuredPortRange
if authConfig.Port != nil {
startingPort = *authConfig.Port
portRange = 1
}
for i := range portRange {
port = startingPort + i
portString := fmt.Sprintf(":%s", strconv.Itoa(port))
p.Debug(print.DebugLevel, "trying to bind port %d for login redirect", port)
listener, listenerErr = net.Listen("tcp", portString)
Expand All @@ -92,7 +105,7 @@ func AuthorizeUser(p *print.Printer, isReauthentication bool) error {
p.Debug(print.DebugLevel, "unable to bind port %d for login redirect: %s", port, listenerErr)
}
if listenerErr != nil {
return fmt.Errorf("unable to bind port for login redirect, tried from port %d to %d: %w", defaultPort, port, err)
return fmt.Errorf("unable to bind port for login redirect, tried from port %d to %d: %w", defaultPort, port, listenerErr)
}

conf := &oauth2.Config{
Expand Down
9 changes: 7 additions & 2 deletions internal/pkg/auth/user_token_flow.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ import (

type userTokenFlow struct {
printer *print.Printer
reauthorizeUserRoutine func(p *print.Printer, isReauthentication bool) error // Called if the user needs to login again
reauthorizeUserRoutine func(p *print.Printer, isReauthentication UserAuthConfig) error // Called if the user needs to login again
client *http.Client
authFlow AuthFlow
accessToken string
Expand Down Expand Up @@ -95,7 +95,12 @@ func loadVarsFromStorage(utf *userTokenFlow) error {
}

func reauthenticateUser(utf *userTokenFlow) error {
err := utf.reauthorizeUserRoutine(utf.printer, true)
err := utf.reauthorizeUserRoutine(
utf.printer,
UserAuthConfig{
IsReauthentication: true,
},
)
if err != nil {
return fmt.Errorf("authenticate user: %w", err)
}
Expand Down
2 changes: 1 addition & 1 deletion internal/pkg/auth/user_token_flow_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -278,7 +278,7 @@ func TestRoundTrip(t *testing.T) {
authorizeUserCalled: &authorizeUserCalled,
tokensRefreshed: &tokensRefreshed,
}
authorizeUserRoutine := func(_ *print.Printer, _ bool) error {
authorizeUserRoutine := func(_ *print.Printer, _ UserAuthConfig) error {
return reauthorizeUser(authorizeUserContext)
}

Expand Down
14 changes: 14 additions & 0 deletions internal/pkg/flags/flag_to_value.go
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,20 @@ func FlagToStringToStringPointer(p *print.Printer, cmd *cobra.Command, flag stri
return nil
}

// Returns a pointer to the flag's value.
// Returns nil if the flag is not set, if its value can not be converted to int, or if the flag does not exist.
func FlagToIntPointer(p *print.Printer, cmd *cobra.Command, flag string) *int {
value, err := cmd.Flags().GetInt(flag)
if err != nil {
p.Debug(print.ErrorLevel, "convert flag to Uint64 pointer: %v", err)
return nil
}
if cmd.Flag(flag).Changed {
return &value
}
return nil
}

// Returns a pointer to the flag's value.
// Returns nil if the flag is not set, if its value can not be converted to int64, or if the flag does not exist.
func FlagToInt64Pointer(p *print.Printer, cmd *cobra.Command, flag string) *int64 {
Expand Down
Loading