privilege.go 3.0 KB
Newer Older
martianzhang's avatar
martianzhang 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76
/*
 * Copyright 2018 Xiaomi, Inc.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package database

import (
	"errors"
	"strings"

	"github.com/XiaoMi/soar/common"
)

// CurrentUser returns the account this connection is running as, obtained by
// executing `select current_user()` and splitting the "user@host" result into
// its user and host parts.
//
// The privilege checks in this file splice user and host into single-quoted
// SQL string literals, so values containing a single quote or a backslash are
// rejected with an error: either character could terminate or escape the
// surrounding quote and corrupt the generated query.
//
// An error is returned when the query fails, when it yields no rows, or when
// the result is not exactly one "@"-separated user/host pair.
func (db *Connector) CurrentUser() (string, string, error) {
	res, err := db.Query("select current_user()")
	if err != nil {
		return "", "", err
	}
	if len(res.Rows) == 0 {
		return "", "", errors.New("no privilege info")
	}

	cols := strings.Split(res.Rows[0].Str(0), "@")
	if len(cols) != 2 {
		return "", "", errors.New("user or host contains irregular character")
	}

	// Strip any surrounding quotes, then refuse characters that are unsafe
	// inside a '...' SQL literal. The backslash matters because, unless
	// NO_BACKSLASH_ESCAPES is enabled, MySQL reads \' as an escaped quote.
	user := strings.Trim(cols[0], "'")
	host := strings.Trim(cols[1], "'")
	if strings.ContainsAny(user, `'\`) || strings.ContainsAny(host, `'\`) {
		return "", "", errors.New("user or host contains irregular character")
	}
	return user, host, nil
}

// HasSelectPrivilege reports whether the current account's row in mysql.user
// has Select_priv set to 'Y'. Failing to resolve the current user or to run
// the lookup query is logged and reported as false, as is a missing row.
func (db *Connector) HasSelectPrivilege() bool {
	account, hostname, lookupErr := db.CurrentUser()
	if lookupErr != nil {
		common.Log.Error("User: %s, HasSelectPrivilege: %s", db.User, lookupErr.Error())
		return false
	}

	const stmt = "select Select_priv from mysql.user where user='%s' and host='%s'"
	result, queryErr := db.Query(stmt, account, hostname)
	if queryErr != nil {
		common.Log.Error("HasSelectPrivilege, DSN: %s, Error: %s", db.Addr, queryErr.Error())
		return false
	}

	// The first column of the first row carries the Select_priv flag.
	return len(result.Rows) != 0 && result.Rows[0].Str(0) == "Y"
}

// HasAllPrivilege if user has all privileges
func (db *Connector) HasAllPrivilege() bool {
	user, host, err := db.CurrentUser()
	if err != nil {
		common.Log.Error("User: %s, HasAllPrivilege: %s", db.User, err.Error())
		return false
	}

	// concat privilege columns
77
	res, err := db.Query("SELECT GROUP_CONCAT(COLUMN_NAME) from information_schema.COLUMNS where TABLE_SCHEMA='mysql' and TABLE_NAME='user' and COLUMN_NAME like '%%_priv'")
martianzhang's avatar
martianzhang 已提交
78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104
	if err != nil {
		common.Log.Error("HasAllPrivilege, DSN: %s, Error: %s", db.Addr, err.Error())
		return false
	}
	var priv string
	if len(res.Rows) > 0 {
		priv = res.Rows[0].Str(0)
	} else {
		common.Log.Error("HasAllPrivilege, DSN: %s, get privilege string error", db.Addr)
		return false
	}

	// get all privilege status
	res, err = db.Query("select concat("+priv+") from mysql.user where user='%s' and host='%s'", user, host)
	if err != nil {
		common.Log.Error("HasAllPrivilege, DSN: %s, Error: %s", db.Addr, err.Error())
		return false
	}

	// %_priv
	if len(res.Rows) > 0 {
		if strings.Replace(res.Rows[0].Str(0), "Y", "", -1) == "" {
			return true
		}
	}
	return false
}