diff --git a/internal/database/two_factor.go b/internal/database/two_factor.go index 470652d29..90eef9b92 100644 --- a/internal/database/two_factor.go +++ b/internal/database/two_factor.go @@ -106,21 +106,3 @@ func IsTwoFactorRecoveryCodeNotFound(err error) bool { func (err ErrTwoFactorRecoveryCodeNotFound) Error() string { return fmt.Sprintf("two-factor recovery code does not found [code: %s]", err.Code) } - -// UseRecoveryCode validates recovery code of given user and marks it is used if valid. -func UseRecoveryCode(_ int64, code string) error { - recoveryCode := new(TwoFactorRecoveryCode) - has, err := x.Where("code = ?", code).And("is_used = ?", false).Get(recoveryCode) - if err != nil { - return errors.Newf("get unused code: %v", err) - } else if !has { - return ErrTwoFactorRecoveryCodeNotFound{Code: code} - } - - recoveryCode.IsUsed = true - if _, err = x.Id(recoveryCode.ID).Cols("is_used").Update(recoveryCode); err != nil { - return errors.Newf("mark code as used: %v", err) - } - - return nil -} diff --git a/internal/database/two_factors.go b/internal/database/two_factors.go index f4a140bfe..9469dc4d0 100644 --- a/internal/database/two_factors.go +++ b/internal/database/two_factors.go @@ -110,6 +110,28 @@ func (s *TwoFactorsStore) IsEnabled(ctx context.Context, userID int64) bool { return count > 0 } +// UseRecoveryCode validates a recovery code of given user and marks it as used +// if valid. It returns ErrTwoFactorRecoveryCodeNotFound if the code is invalid +// or already used. +func (s *TwoFactorsStore) UseRecoveryCode(ctx context.Context, userID int64, code string) error { + return s.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error { + var recoveryCode TwoFactorRecoveryCode + err := tx.Where("user_id = ? AND code = ? AND is_used = ?", userID, code, false).First(&recoveryCode).Error + if err != nil { + if err == gorm.ErrRecordNotFound { + return ErrTwoFactorRecoveryCodeNotFound{Code: code} + } + return errors.Wrap(err, "get unused recovery code") + } + + err = tx.Model(&recoveryCode).Update("is_used", true).Error + if err != nil { + return errors.Wrap(err, "mark recovery code as used") + } + return nil + }) +} + // generateRecoveryCodes generates N number of recovery codes for 2FA. func generateRecoveryCodes(userID int64, n int) ([]*TwoFactorRecoveryCode, error) { recoveryCodes := make([]*TwoFactorRecoveryCode, n) diff --git a/internal/database/two_factors_test.go b/internal/database/two_factors_test.go index 925e1d0e8..88db77e14 100644 --- a/internal/database/two_factors_test.go +++ b/internal/database/two_factors_test.go @@ -74,6 +74,7 @@ func TestTwoFactors(t *testing.T) { {"Create", twoFactorsCreate}, {"GetByUserID", twoFactorsGetByUserID}, {"IsEnabled", twoFactorsIsEnabled}, + {"UseRecoveryCode", twoFactorsUseRecoveryCode}, } { t.Run(tc.name, func(t *testing.T) { t.Cleanup(func() { @@ -128,3 +129,55 @@ func twoFactorsIsEnabled(t *testing.T, ctx context.Context, s *TwoFactorsStore) assert.True(t, s.IsEnabled(ctx, 1)) assert.False(t, s.IsEnabled(ctx, 2)) } + +func twoFactorsUseRecoveryCode(t *testing.T, ctx context.Context, s *TwoFactorsStore) { + // Create 2FA tokens for two users + err := s.Create(ctx, 1, "secure-key", "secure-secret") + require.NoError(t, err) + err = s.Create(ctx, 2, "secure-key", "secure-secret") + require.NoError(t, err) + + // Get recovery codes for both users + var user1Codes []TwoFactorRecoveryCode + err = s.db.Where("user_id = ?", 1).Find(&user1Codes).Error + require.NoError(t, err) + require.NotEmpty(t, user1Codes) + + var user2Codes []TwoFactorRecoveryCode + err = s.db.Where("user_id = ?", 2).Find(&user2Codes).Error + require.NoError(t, err) + require.NotEmpty(t, user2Codes) + + // User 1 should be able to use their own recovery code + err = s.UseRecoveryCode(ctx, 1, user1Codes[0].Code) + require.NoError(t, err) + + // Verify the code is now marked as used + var usedCode TwoFactorRecoveryCode + err = s.db.Where("id = ?", user1Codes[0].ID).First(&usedCode).Error + require.NoError(t, err) + assert.True(t, usedCode.IsUsed) + + // User 1 should NOT be able to use user 2's recovery code + // This is the key security test - recovery codes must be scoped by user + err = s.UseRecoveryCode(ctx, 1, user2Codes[0].Code) + assert.True(t, IsTwoFactorRecoveryCodeNotFound(err), "expected recovery code not found error when using another user's code") + + // User 2's code should still be unused + var user2Code TwoFactorRecoveryCode + err = s.db.Where("id = ?", user2Codes[0].ID).First(&user2Code).Error + require.NoError(t, err) + assert.False(t, user2Code.IsUsed, "user 2's recovery code should not be marked as used") + + // User 2 should be able to use their own code + err = s.UseRecoveryCode(ctx, 2, user2Codes[0].Code) + require.NoError(t, err) + + // Using an already-used code should fail + err = s.UseRecoveryCode(ctx, 1, user1Codes[0].Code) + assert.True(t, IsTwoFactorRecoveryCodeNotFound(err), "expected error when reusing a recovery code") + + // Using a non-existent code should fail + err = s.UseRecoveryCode(ctx, 1, "invalid-code") + assert.True(t, IsTwoFactorRecoveryCodeNotFound(err), "expected error for invalid recovery code") +} diff --git a/internal/route/user/auth.go b/internal/route/user/auth.go index 117ab23c0..b1f0b24e6 100644 --- a/internal/route/user/auth.go +++ b/internal/route/user/auth.go @@ -263,7 +263,7 @@ func LoginTwoFactorRecoveryCodePost(c *context.Context) { return } - if err := database.UseRecoveryCode(userID, c.Query("recovery_code")); err != nil { + if err := database.Handle.TwoFactors().UseRecoveryCode(c.Req.Context(), userID, c.Query("recovery_code")); err != nil { if database.IsTwoFactorRecoveryCodeNotFound(err) { c.Flash.Error(c.Tr("auth.login_two_factor_invalid_recovery_code")) c.RedirectSubpath("/user/login/two_factor_recovery_code")