From d568e048315dc9729c8518d8085cab7dbbfac80f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E1=B4=8A=E1=B4=8F=E1=B4=87=20=E1=B4=84=CA=9C=E1=B4=87?= =?UTF-8?q?=C9=B4?= Date: Thu, 22 Jan 2026 22:30:27 -0500 Subject: [PATCH] two_factor: verify recovery code ownership upon using (#8100) --- internal/db/two_factor.go | 18 ----------- internal/db/two_factors.go | 23 ++++++++++++++ internal/db/two_factors_test.go | 53 +++++++++++++++++++++++++++++++++ internal/route/user/auth.go | 2 +- 4 files changed, 77 insertions(+), 19 deletions(-) diff --git a/internal/db/two_factor.go b/internal/db/two_factor.go index 177f38f38..3cf41ab12 100644 --- a/internal/db/two_factor.go +++ b/internal/db/two_factor.go @@ -109,21 +109,3 @@ func IsTwoFactorRecoveryCodeNotFound(err error) bool { func (err ErrTwoFactorRecoveryCodeNotFound) Error() string { return fmt.Sprintf("two-factor recovery code does not found [code: %s]", err.Code) } - -// UseRecoveryCode validates recovery code of given user and marks it is used if valid. -func UseRecoveryCode(_ int64, code string) error { - recoveryCode := new(TwoFactorRecoveryCode) - has, err := x.Where("code = ?", code).And("is_used = ?", false).Get(recoveryCode) - if err != nil { - return fmt.Errorf("get unused code: %v", err) - } else if !has { - return ErrTwoFactorRecoveryCodeNotFound{Code: code} - } - - recoveryCode.IsUsed = true - if _, err = x.Id(recoveryCode.ID).Cols("is_used").Update(recoveryCode); err != nil { - return fmt.Errorf("mark code as used: %v", err) - } - - return nil -} diff --git a/internal/db/two_factors.go b/internal/db/two_factors.go index 741a2ff79..8648a6095 100644 --- a/internal/db/two_factors.go +++ b/internal/db/two_factors.go @@ -32,6 +32,7 @@ type TwoFactorsStore interface { GetByUserID(ctx context.Context, userID int64) (*TwoFactor, error) // IsEnabled returns true if the user has enabled 2FA. IsEnabled(ctx context.Context, userID int64) bool + UseRecoveryCode(ctx context.Context, userID int64, code string) error } var TwoFactors TwoFactorsStore @@ -121,6 +122,28 @@ func (db *twoFactors) IsEnabled(ctx context.Context, userID int64) bool { return count > 0 } +// UseRecoveryCode validates a recovery code of given user and marks it as used +// if valid. It returns ErrTwoFactorRecoveryCodeNotFound if the code is invalid +// or already used. +func (db *twoFactors) UseRecoveryCode(ctx context.Context, userID int64, code string) error { + return db.WithContext(ctx).Transaction(func(tx *gorm.DB) error { + var recoveryCode TwoFactorRecoveryCode + err := tx.Where("user_id = ? AND code = ? AND is_used = ?", userID, code, false).First(&recoveryCode).Error + if err != nil { + if err == gorm.ErrRecordNotFound { + return ErrTwoFactorRecoveryCodeNotFound{Code: code} + } + return errors.Wrap(err, "get unused recovery code") + } + + err = tx.Model(&recoveryCode).Update("is_used", true).Error + if err != nil { + return errors.Wrap(err, "mark recovery code as used") + } + return nil + }) +} + // generateRecoveryCodes generates N number of recovery codes for 2FA. func generateRecoveryCodes(userID int64, n int) ([]*TwoFactorRecoveryCode, error) { recoveryCodes := make([]*TwoFactorRecoveryCode, n) diff --git a/internal/db/two_factors_test.go b/internal/db/two_factors_test.go index 64e253bd4..0471c9e6c 100644 --- a/internal/db/two_factors_test.go +++ b/internal/db/two_factors_test.go @@ -79,6 +79,7 @@ func TestTwoFactors(t *testing.T) { {"Create", twoFactorsCreate}, {"GetByUserID", twoFactorsGetByUserID}, {"IsEnabled", twoFactorsIsEnabled}, + {"UseRecoveryCode", twoFactorsUseRecoveryCode}, } { t.Run(tc.name, func(t *testing.T) { t.Cleanup(func() { @@ -139,3 +140,55 @@ func twoFactorsIsEnabled(t *testing.T, db *twoFactors) { assert.True(t, db.IsEnabled(ctx, 1)) assert.False(t, db.IsEnabled(ctx, 2)) } + +func twoFactorsUseRecoveryCode(t *testing.T, ctx context.Context, s *TwoFactorsStore) { + // Create 2FA tokens for two users + err := s.Create(ctx, 1, "secure-key", "secure-secret") + require.NoError(t, err) + err = s.Create(ctx, 2, "secure-key", "secure-secret") + require.NoError(t, err) + + // Get recovery codes for both users + var user1Codes []TwoFactorRecoveryCode + err = s.db.Where("user_id = ?", 1).Find(&user1Codes).Error + require.NoError(t, err) + require.NotEmpty(t, user1Codes) + + var user2Codes []TwoFactorRecoveryCode + err = s.db.Where("user_id = ?", 2).Find(&user2Codes).Error + require.NoError(t, err) + require.NotEmpty(t, user2Codes) + + // User 1 should be able to use their own recovery code + err = s.UseRecoveryCode(ctx, 1, user1Codes[0].Code) + require.NoError(t, err) + + // Verify the code is now marked as used + var usedCode TwoFactorRecoveryCode + err = s.db.Where("id = ?", user1Codes[0].ID).First(&usedCode).Error + require.NoError(t, err) + assert.True(t, usedCode.IsUsed) + + // User 1 should NOT be able to use user 2's recovery code + // This is the key security test - recovery codes must be scoped by user + err = s.UseRecoveryCode(ctx, 1, user2Codes[0].Code) + assert.True(t, IsTwoFactorRecoveryCodeNotFound(err), "expected recovery code not found error when using another user's code") + + // User 2's code should still be unused + var user2Code TwoFactorRecoveryCode + err = s.db.Where("id = ?", user2Codes[0].ID).First(&user2Code).Error + require.NoError(t, err) + assert.False(t, user2Code.IsUsed, "user 2's recovery code should not be marked as used") + + // User 2 should be able to use their own code + err = s.UseRecoveryCode(ctx, 2, user2Codes[0].Code) + require.NoError(t, err) + + // Using an already-used code should fail + err = s.UseRecoveryCode(ctx, 1, user1Codes[0].Code) + assert.True(t, IsTwoFactorRecoveryCodeNotFound(err), "expected error when reusing a recovery code") + + // Using a non-existent code should fail + err = s.UseRecoveryCode(ctx, 1, "invalid-code") + assert.True(t, IsTwoFactorRecoveryCodeNotFound(err), "expected error for invalid recovery code") +} diff --git a/internal/route/user/auth.go b/internal/route/user/auth.go index ff0febb9c..4ba8da4c9 100644 --- a/internal/route/user/auth.go +++ b/internal/route/user/auth.go @@ -267,7 +267,7 @@ func LoginTwoFactorRecoveryCodePost(c *context.Context) { return } - if err := db.UseRecoveryCode(userID, c.Query("recovery_code")); err != nil { + if err := db.TwoFactors.UseRecoveryCode(c.Req.Context(), userID, c.Query("recovery_code")); err != nil { if db.IsTwoFactorRecoveryCodeNotFound(err) { c.Flash.Error(c.Tr("auth.login_two_factor_invalid_recovery_code")) c.RedirectSubpath("/user/login/two_factor_recovery_code")