// Copyright 2017 Vector Creations Ltd
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
//
// Modifications copyright (C) 2020 Finogeeks Co., Ltd

package common

import (
	"context"
	"database/sql"
	"errors"
	"fmt"
	"time"

	log "github.com/finogeeks/ligase/skunkworks/log"
	"github.com/lib/pq"
)

// Transaction is the minimal handle needed to finish a unit of work: it can
// either be committed or rolled back. *sql.Tx satisfies it.
type Transaction interface {
	// Rollback aborts the transaction.
	Rollback() error
	// Commit completes the transaction.
	Commit() error
}

// EndTransaction ends a transaction.
// If the transaction succeeded then it is committed, otherwise it is rolledback.
func EndTransaction(txn Transaction, succeeded *bool) {
T
tanggen 已提交
41
	//last := time.Now()
42
	if *succeeded {
E
Erik Johnston 已提交
43
		txn.Commit() // nolint: errcheck
44
	} else {
E
Erik Johnston 已提交
45
		txn.Rollback() // nolint: errcheck
46
	}
T
tanggen 已提交
47
	//fmt.Printf("------------------------EndTransaction use %v", time.Now().Sub(last))
48 49
}

// WithTransaction runs fn inside a new SQL transaction on db.
//
// If fn returns an error or panics, the transaction is rolled back. The error
// from fn is returned, and a panic continues to propagate after the rollback.
// Otherwise the transaction is committed and the result of Commit is returned,
// so a failed commit is reported to the caller instead of being dropped.
func WithTransaction(db *sql.DB, fn func(txn *sql.Tx) error) (err error) {
	txn, err := db.Begin()
	if err != nil {
		return err
	}

	// committing flips to true only once fn has succeeded. Until then, any exit
	// path (error return or panic) rolls the transaction back.
	committing := false
	defer func() {
		if !committing {
			// The rollback error is ignored: fn's error (or panic) is the more
			// useful thing to surface to the caller.
			txn.Rollback() // nolint: errcheck
		}
	}()

	if err = fn(txn); err != nil {
		return err
	}

	committing = true
	return txn.Commit()
}

// TxStmt binds a prepared statement to an optional transaction.
//
// With a nil transaction the statement is handed back unchanged, so it runs
// outside any transaction. Otherwise the result is the transaction-specific
// copy produced by (*sql.Tx).Stmt.
func TxStmt(transaction *sql.Tx, statement *sql.Stmt) *sql.Stmt {
	if transaction == nil {
		return statement
	}
	return transaction.Stmt(statement)
}
T
tanggen 已提交
80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125

// IsUniqueConstraintViolationErr returns true if the error is a postgresql unique_violation error
func IsUniqueConstraintViolationErr(err error) bool {
	pqErr, ok := err.(*pq.Error)
	return ok && pqErr.Code == "23505"
}

// hooksBeginKey is the context key under which Hooks.Before stores the query
// start time. An unexported struct type cannot collide with keys set by other
// packages, unlike a plain string key such as "begin".
type hooksBeginKey struct{}

// Hooks satisfies the sqlhook.Hooks interface.
// It logs each query with its arguments and how long the query took.
type Hooks struct{}

// Before logs the query with its args and returns a context carrying the
// current time, which After uses to compute the elapsed time.
func (h *Hooks) Before(ctx context.Context, query string, args ...interface{}) (context.Context, error) {
	log.Infof("> %s %q", query, args)
	return context.WithValue(ctx, hooksBeginKey{}, time.Now()), nil
}

// Error logs a failed query together with its error and args.
//
// It logs rather than exiting the process. Some SQL errors are expected
// outcomes that callers handle (for example unique-constraint violations, see
// IsUniqueConstraintViolationErr), so they must not terminate the server.
func (h *Hooks) Error(ctx context.Context, errstr error, query string, args ...interface{}) {
	log.Errorf("> %s %v %q", query, errstr, args)
}

// After reads the start time stored by Before and logs the elapsed time.
// If the context carries no start time (Before did not run on it), nothing is
// logged, instead of panicking on the type assertion.
func (h *Hooks) After(ctx context.Context, query string, args ...interface{}) (context.Context, error) {
	if begin, ok := ctx.Value(hooksBeginKey{}).(time.Time); ok {
		log.Infof(". took: %v", time.Since(begin))
	}
	return ctx, nil
}

func CreateDatabase(driver, addr, name string) error {
	db, err := sql.Open(driver, addr)
	if err != nil {
		return err
	}

	dbName := ""
	sqlStr := fmt.Sprintf("SELECT datname FROM pg_catalog.pg_database WHERE datname = '%s'", name)
	db.QueryRow(sqlStr).Scan(&dbName)
	if err == sql.ErrNoRows || dbName == "" {
		sqlStr := fmt.Sprintf("CREATE DATABASE %s", name)
		db.Exec(sqlStr)
		time.Sleep(time.Second * 10) //wait master&slave sync
	}
	db.Close()

	return nil
}