mirror of
https://github.com/gogs/gogs.git
synced 2026-03-04 11:11:03 +01:00
db: use context for backup and restore (#7044)
This commit is contained in:
@@ -5,6 +5,7 @@
|
||||
package cmd
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"os"
|
||||
@@ -94,7 +95,7 @@ func runBackup(c *cli.Context) error {
|
||||
|
||||
// Database
|
||||
dbDir := filepath.Join(rootDir, "db")
|
||||
if err = db.DumpDatabase(conn, dbDir, c.Bool("verbose")); err != nil {
|
||||
if err = db.DumpDatabase(context.Background(), conn, dbDir, c.Bool("verbose")); err != nil {
|
||||
log.Fatal("Failed to dump database: %v", err)
|
||||
}
|
||||
if err = z.AddDir(archiveRootDir+"/db", dbDir); err != nil {
|
||||
|
||||
@@ -5,6 +5,7 @@
|
||||
package cmd
|
||||
|
||||
import (
|
||||
"context"
|
||||
"os"
|
||||
"path"
|
||||
"path/filepath"
|
||||
@@ -114,7 +115,7 @@ func runRestore(c *cli.Context) error {
|
||||
|
||||
// Database
|
||||
dbDir := path.Join(archivePath, "db")
|
||||
if err = db.ImportDatabase(conn, dbDir, c.Bool("verbose")); err != nil {
|
||||
if err = db.ImportDatabase(context.Background(), conn, dbDir, c.Bool("verbose")); err != nil {
|
||||
log.Fatal("Failed to import database: %v", err)
|
||||
}
|
||||
|
||||
|
||||
@@ -3,6 +3,7 @@ package db
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"os"
|
||||
@@ -30,18 +31,24 @@ func getTableType(t interface{}) string {
|
||||
}
|
||||
|
||||
// DumpDatabase dumps all data from database to file system in JSON Lines format.
|
||||
func DumpDatabase(db *gorm.DB, dirPath string, verbose bool) error {
|
||||
func DumpDatabase(ctx context.Context, db *gorm.DB, dirPath string, verbose bool) error {
|
||||
err := os.MkdirAll(dirPath, os.ModePerm)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err = dumpLegacyTables(dirPath, verbose)
|
||||
err = dumpLegacyTables(ctx, dirPath, verbose)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "dump legacy tables")
|
||||
}
|
||||
|
||||
for _, table := range Tables {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
default:
|
||||
}
|
||||
|
||||
tableName := getTableType(table)
|
||||
if verbose {
|
||||
log.Trace("Dumping table %q...", tableName)
|
||||
@@ -55,7 +62,7 @@ func DumpDatabase(db *gorm.DB, dirPath string, verbose bool) error {
|
||||
}
|
||||
defer func() { _ = f.Close() }()
|
||||
|
||||
return dumpTable(db, table, f)
|
||||
return dumpTable(ctx, db, table, f)
|
||||
}()
|
||||
if err != nil {
|
||||
return errors.Wrapf(err, "dump table %q", tableName)
|
||||
@@ -65,11 +72,13 @@ func DumpDatabase(db *gorm.DB, dirPath string, verbose bool) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func dumpTable(db *gorm.DB, table interface{}, w io.Writer) error {
|
||||
query := db.Model(table).Order("id ASC")
|
||||
func dumpTable(ctx context.Context, db *gorm.DB, table interface{}, w io.Writer) error {
|
||||
query := db.WithContext(ctx).Model(table)
|
||||
switch table.(type) {
|
||||
case *LFSObject:
|
||||
query = db.Model(table).Order("repo_id, oid ASC")
|
||||
query = query.Order("repo_id, oid ASC")
|
||||
default:
|
||||
query = query.Order("id ASC")
|
||||
}
|
||||
|
||||
rows, err := query.Rows()
|
||||
@@ -98,10 +107,16 @@ func dumpTable(db *gorm.DB, table interface{}, w io.Writer) error {
|
||||
return rows.Err()
|
||||
}
|
||||
|
||||
func dumpLegacyTables(dirPath string, verbose bool) error {
|
||||
func dumpLegacyTables(ctx context.Context, dirPath string, verbose bool) error {
|
||||
// Purposely create a local variable to not modify global variable
|
||||
legacyTables := append(legacyTables, new(Version))
|
||||
for _, table := range legacyTables {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
default:
|
||||
}
|
||||
|
||||
tableName := getTableType(table)
|
||||
if verbose {
|
||||
log.Trace("Dumping table %q...", tableName)
|
||||
@@ -113,7 +128,7 @@ func dumpLegacyTables(dirPath string, verbose bool) error {
|
||||
return fmt.Errorf("create JSON file: %v", err)
|
||||
}
|
||||
|
||||
if err = x.Asc("id").Iterate(table, func(idx int, bean interface{}) (err error) {
|
||||
if err = x.Context(ctx).Asc("id").Iterate(table, func(idx int, bean interface{}) (err error) {
|
||||
return jsoniter.NewEncoder(f).Encode(bean)
|
||||
}); err != nil {
|
||||
_ = f.Close()
|
||||
@@ -125,13 +140,19 @@ func dumpLegacyTables(dirPath string, verbose bool) error {
|
||||
}
|
||||
|
||||
// ImportDatabase imports data from backup archive in JSON Lines format.
|
||||
func ImportDatabase(db *gorm.DB, dirPath string, verbose bool) error {
|
||||
err := importLegacyTables(dirPath, verbose)
|
||||
func ImportDatabase(ctx context.Context, db *gorm.DB, dirPath string, verbose bool) error {
|
||||
err := importLegacyTables(ctx, dirPath, verbose)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "import legacy tables")
|
||||
}
|
||||
|
||||
for _, table := range Tables {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
default:
|
||||
}
|
||||
|
||||
tableName := strings.TrimPrefix(fmt.Sprintf("%T", table), "*db.")
|
||||
err := func() error {
|
||||
tableFile := filepath.Join(dirPath, tableName+".json")
|
||||
@@ -150,7 +171,7 @@ func ImportDatabase(db *gorm.DB, dirPath string, verbose bool) error {
|
||||
}
|
||||
defer func() { _ = f.Close() }()
|
||||
|
||||
return importTable(db, table, f)
|
||||
return importTable(ctx, db, table, f)
|
||||
}()
|
||||
if err != nil {
|
||||
return errors.Wrapf(err, "import table %q", tableName)
|
||||
@@ -160,13 +181,13 @@ func ImportDatabase(db *gorm.DB, dirPath string, verbose bool) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func importTable(db *gorm.DB, table interface{}, r io.Reader) error {
|
||||
err := db.Migrator().DropTable(table)
|
||||
func importTable(ctx context.Context, db *gorm.DB, table interface{}, r io.Reader) error {
|
||||
err := db.WithContext(ctx).Migrator().DropTable(table)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "drop table")
|
||||
}
|
||||
|
||||
err = db.Migrator().AutoMigrate(table)
|
||||
err = db.WithContext(ctx).Migrator().AutoMigrate(table)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "auto migrate")
|
||||
}
|
||||
@@ -191,7 +212,7 @@ func importTable(db *gorm.DB, table interface{}, r io.Reader) error {
|
||||
return errors.Wrap(err, "unmarshal JSON to struct")
|
||||
}
|
||||
|
||||
err = db.Create(elem).Error
|
||||
err = db.WithContext(ctx).Create(elem).Error
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "create row")
|
||||
}
|
||||
@@ -200,14 +221,14 @@ func importTable(db *gorm.DB, table interface{}, r io.Reader) error {
|
||||
// PostgreSQL needs manually reset table sequence for auto increment keys
|
||||
if conf.UsePostgreSQL && !skipResetIDSeq[rawTableName] {
|
||||
seqName := rawTableName + "_id_seq"
|
||||
if _, err = x.Exec(fmt.Sprintf(`SELECT setval('%s', COALESCE((SELECT MAX(id)+1 FROM "%s"), 1), false);`, seqName, rawTableName)); err != nil {
|
||||
if _, err = x.Context(ctx).Exec(fmt.Sprintf(`SELECT setval('%s', COALESCE((SELECT MAX(id)+1 FROM "%s"), 1), false);`, seqName, rawTableName)); err != nil {
|
||||
return errors.Wrapf(err, "reset table %q.%q", rawTableName, seqName)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func importLegacyTables(dirPath string, verbose bool) error {
|
||||
func importLegacyTables(ctx context.Context, dirPath string, verbose bool) error {
|
||||
snakeMapper := core.SnakeMapper{}
|
||||
|
||||
skipInsertProcessors := map[string]bool{
|
||||
@@ -218,6 +239,12 @@ func importLegacyTables(dirPath string, verbose bool) error {
|
||||
// Purposely create a local variable to not modify global variable
|
||||
legacyTables := append(legacyTables, new(Version))
|
||||
for _, table := range legacyTables {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
default:
|
||||
}
|
||||
|
||||
tableName := strings.TrimPrefix(fmt.Sprintf("%T", table), "*db.")
|
||||
tableFile := filepath.Join(dirPath, tableName+".json")
|
||||
if !osutil.IsFile(tableFile) {
|
||||
|
||||
@@ -6,12 +6,14 @@ package db
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
"github.com/stretchr/testify/require"
|
||||
"gorm.io/gorm"
|
||||
|
||||
"gogs.io/gogs/internal/auth"
|
||||
@@ -22,7 +24,7 @@ import (
|
||||
"gogs.io/gogs/internal/testutil"
|
||||
)
|
||||
|
||||
func Test_dumpAndImport(t *testing.T) {
|
||||
func TestDumpAndImport(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip()
|
||||
}
|
||||
@@ -43,8 +45,6 @@ func Test_dumpAndImport(t *testing.T) {
|
||||
}
|
||||
|
||||
func setupDBToDump(t *testing.T, db *gorm.DB) {
|
||||
t.Helper()
|
||||
|
||||
vals := []interface{}{
|
||||
&Access{
|
||||
ID: 1,
|
||||
@@ -126,31 +126,29 @@ func setupDBToDump(t *testing.T, db *gorm.DB) {
|
||||
}
|
||||
for _, val := range vals {
|
||||
err := db.Create(val).Error
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
require.NoError(t, err)
|
||||
}
|
||||
}
|
||||
|
||||
func dumpTables(t *testing.T, db *gorm.DB) {
|
||||
t.Helper()
|
||||
ctx := context.Background()
|
||||
|
||||
for _, table := range Tables {
|
||||
tableName := getTableType(table)
|
||||
|
||||
var buf bytes.Buffer
|
||||
err := dumpTable(db, table, &buf)
|
||||
err := dumpTable(ctx, db, table, &buf)
|
||||
if err != nil {
|
||||
t.Fatalf("%s: %v", tableName, err)
|
||||
}
|
||||
|
||||
golden := filepath.Join("testdata", "backup", tableName+".golden.json")
|
||||
testutil.AssertGolden(t, golden, testutil.Update("Test_dumpAndImport"), buf.String())
|
||||
testutil.AssertGolden(t, golden, testutil.Update("TestDumpAndImport"), buf.String())
|
||||
}
|
||||
}
|
||||
|
||||
func importTables(t *testing.T, db *gorm.DB) {
|
||||
t.Helper()
|
||||
ctx := context.Background()
|
||||
|
||||
for _, table := range Tables {
|
||||
tableName := getTableType(table)
|
||||
@@ -163,7 +161,7 @@ func importTables(t *testing.T, db *gorm.DB) {
|
||||
}
|
||||
defer func() { _ = f.Close() }()
|
||||
|
||||
return importTable(db, table, f)
|
||||
return importTable(ctx, db, table, f)
|
||||
}()
|
||||
if err != nil {
|
||||
t.Fatalf("%s: %v", tableName, err)
|
||||
|
||||
Reference in New Issue
Block a user