提交 675d4c4e 编写于 作者: C colynn

fix: casbin beego orm adapter

上级 3adf3e3c
......@@ -93,7 +93,6 @@ func (a *AuthController) Authenticate() {
http.Error(a.Ctx.ResponseWriter, err.Error(), http.StatusInternalServerError)
return
}
e, err := mycasbin.NewCasbin()
if err != nil {
log.Log.Error("add user role, new casbin instance error: %s", err.Error())
......
......@@ -39,7 +39,7 @@ func NewProvider() auth.Provider {
func (p *Provider) Authenticate(loginUser, password string) (*auth.ExternalAccount, error) {
userModel, err := dao.GetUser(loginUser)
if err != nil {
log.Log.Error("get user error: %v")
log.Log.Error("get user error: %v", err.Error())
return nil, fmt.Errorf("用户不存在或密码错误")
}
......
......@@ -18,6 +18,7 @@ package dao
import (
"fmt"
mycasbin "github.com/go-atomci/atomci/internal/middleware/casbin"
"github.com/go-atomci/atomci/internal/middleware/log"
"github.com/go-atomci/atomci/internal/models"
......
......@@ -23,7 +23,7 @@ import (
"github.com/go-atomci/atomci/utils/errors"
)
func init() {
func Init() {
// 注册/更新资源
initResource()
......
......@@ -27,7 +27,7 @@ import (
func Authorization(c *context.Context, username string) (bool, error) {
e, err := mycasbin.NewCasbin()
if err != nil {
log.Log.Error("casbin new occur error: %v", err)
log.Log.Error("casbin new occur error: %v", err.Error())
return false, err
}
urlPath := c.Request.URL.Path
......
......@@ -18,24 +18,32 @@ package mycasbin
import (
glog "log"
"sync"
"github.com/astaxie/beego"
"github.com/go-atomci/atomci/internal/middleware/log"
beegoormadapter "github.com/casbin/beego-orm-adapter/v2"
"github.com/casbin/casbin/v2"
"github.com/casbin/casbin/v2/model"
beegoormadapter "github.com/go-atomci/atomci/pkg/beego-orm-adapter"
_ "github.com/go-sql-driver/mysql"
)
var casbinObj *casbin.Enforcer
var casbinadapter *beegoormadapter.Adapter
var casbinadapterOnce sync.Once
var casbinErr error
// GetOrmer :set ormer singleton
func GetAdapter() (*beegoormadapter.Adapter, error) {
casbinadapterOnce.Do(func() {
casbinadapter, casbinErr = initAdapter()
})
return casbinadapter, casbinErr
}
// NewCasbin ..
func NewCasbin() (*casbin.Enforcer, error) {
if casbinObj == nil {
rbacModel, err := model.NewModelFromString(`
rbacModel, err := model.NewModelFromString(`
[request_definition]
r = sub, obj, act
......@@ -53,25 +61,31 @@ e = some(where (p.eft == allow))
# m = g(r.sub, p.sub) && r.obj == p.obj && (r.act == p.act || p.act == "*") || r.sub == "admin"
m = g(r.sub, p.sub) && keyMatch2(r.obj,p.obj) && (r.act == p.act || p.act == "*") || r.sub == "admin"
`)
if err != nil {
glog.Fatalf("error: model: %s", err)
}
dsn := beego.AppConfig.String("DB::url")
rbacPolicy, _ := beegoormadapter.NewAdapter("casbin", "mysql", dsn)
e, err := casbin.NewEnforcer(rbacModel, rbacPolicy)
if err != nil {
log.Log.Error("casbin new enforcer error: %s", err.Error())
return nil, err
}
if err := e.LoadPolicy(); err == nil {
casbinObj = e
return e, err
}
log.Log.Error("casbin rbac_model or policy init error, message: %v", err)
if err != nil {
glog.Fatalf("error: model: %s", err)
}
a, err := GetAdapter()
if err != nil {
return nil, err
}
e, err := casbin.NewEnforcer(rbacModel, a)
if err != nil {
log.Log.Error("casbin new enforcer error: %s", err.Error())
return nil, err
}
if err := e.LoadPolicy(); err != nil {
log.Log.Error("casbin rbac_model or policy init error, message: %v", err.Error())
return e, err
}
return e, nil
}
return casbinObj, nil
func initAdapter() (*beegoormadapter.Adapter, error) {
dsn := beego.AppConfig.String("DB::url")
a, err := beegoormadapter.NewAdapter("casbin", "mysql", dsn)
if err != nil {
log.Log.Error("beego orm adapter error: %s", err.Error())
return nil, err
}
return a, nil
}
......@@ -102,7 +102,7 @@ func sureCreateTable(ormer orm.Ormer) {
ormer.Raw(ddl).Exec()
}
func init() {
func Migrate() {
if len(os.Args) > 1 && os.Args[1][:5] == "-test" {
return
}
......
......@@ -150,7 +150,7 @@ func initOrm() {
}
// Init ...
func init() {
func InitDB() {
if len(os.Args) > 1 && os.Args[1][:5] == "-test" {
return
}
......
// Copyright 2017 The casbin Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package beegoormadapter
import (
"fmt"
"runtime"
"github.com/astaxie/beego/orm"
"github.com/casbin/casbin/v2/model"
"github.com/casbin/casbin/v2/persist"
)
type CasbinRule struct {
Id int
Ptype string
V0 string
V1 string
V2 string
V3 string
V4 string
V5 string
}
func init() {
orm.RegisterModel(new(CasbinRule))
}
const (
defaultTableName = "casbin_rule"
)
// Adapter represents the Xorm adapter for policy storage.
type Adapter struct {
driverName string
dataSourceName string
dataSourceAlias string
tableName string
dbSpecified bool
o orm.Ormer
}
// finalizer is the destructor for Adapter.
func finalizer(a *Adapter) {
}
// NewAdapter is the constructor for Adapter.
// dataSourceAlias: Database alias. ORM will use it to switch database.
// driverName: database driverName.
// dataSourceName: connection string
func NewAdapter(dataSourceAlias, driverName, dataSourceName string) (*Adapter, error) {
a := &Adapter{}
a.driverName = driverName
a.dataSourceName = dataSourceName
a.dataSourceAlias = dataSourceAlias
a.tableName = defaultTableName
err := a.open()
if err != nil {
return nil, err
}
// Call the destructor when the object is released.
runtime.SetFinalizer(a, finalizer)
return a, nil
}
func (a *Adapter) registerDataBase(aliasName, driverName, dataSource string, params ...int) error {
err := orm.RegisterDataBase(aliasName, driverName, dataSource, params...)
return err
}
func (a *Adapter) open() error {
var err error
err = a.registerDataBase(a.dataSourceAlias, a.driverName, a.dataSourceName)
if err != nil {
return err
}
a.o = orm.NewOrm()
err = a.o.Using(a.dataSourceAlias)
if err != nil {
return err
}
err = a.createTable()
if err != nil {
return err
}
return nil
}
func (a *Adapter) close() {
a.o = nil
}
func (a *Adapter) createTable() error {
return orm.RunSyncdb(a.dataSourceAlias, false, true)
}
func (a *Adapter) dropTable() error {
_, err := a.o.Raw(fmt.Sprintf("DROP TABLE IF EXISTS %v;", a.tableName)).Exec()
return err
}
func loadPolicyLine(line CasbinRule, model model.Model) {
lineText := line.Ptype
if line.V0 != "" {
lineText += ", " + line.V0
}
if line.V1 != "" {
lineText += ", " + line.V1
}
if line.V2 != "" {
lineText += ", " + line.V2
}
if line.V3 != "" {
lineText += ", " + line.V3
}
if line.V4 != "" {
lineText += ", " + line.V4
}
if line.V5 != "" {
lineText += ", " + line.V5
}
persist.LoadPolicyLine(lineText, model)
}
// LoadPolicy loads policy from database.
func (a *Adapter) LoadPolicy(model model.Model) error {
var lines []CasbinRule
_, err := a.o.QueryTable("casbin_rule").All(&lines)
if err != nil {
return err
}
for _, line := range lines {
loadPolicyLine(line, model)
}
return nil
}
func savePolicyLine(ptype string, rule []string) CasbinRule {
line := CasbinRule{}
line.Ptype = ptype
if len(rule) > 0 {
line.V0 = rule[0]
}
if len(rule) > 1 {
line.V1 = rule[1]
}
if len(rule) > 2 {
line.V2 = rule[2]
}
if len(rule) > 3 {
line.V3 = rule[3]
}
if len(rule) > 4 {
line.V4 = rule[4]
}
if len(rule) > 5 {
line.V5 = rule[5]
}
return line
}
// SavePolicy saves policy to database.
func (a *Adapter) SavePolicy(model model.Model) error {
err := a.dropTable()
if err != nil {
return err
}
err = a.createTable()
if err != nil {
return err
}
var lines []CasbinRule
for ptype, ast := range model["p"] {
for _, rule := range ast.Policy {
line := savePolicyLine(ptype, rule)
lines = append(lines, line)
}
}
for ptype, ast := range model["g"] {
for _, rule := range ast.Policy {
line := savePolicyLine(ptype, rule)
lines = append(lines, line)
}
}
_, err = a.o.InsertMulti(len(lines), lines)
return err
}
// AddPolicy adds a policy rule to the storage.
func (a *Adapter) AddPolicy(sec string, ptype string, rule []string) error {
line := savePolicyLine(ptype, rule)
_, err := a.o.Insert(&line)
return err
}
// RemovePolicy removes a policy rule from the storage.
func (a *Adapter) RemovePolicy(sec string, ptype string, rule []string) error {
line := savePolicyLine(ptype, rule)
_, err := a.o.Delete(&line, "ptype", "v0", "v1", "v2", "v3", "v4", "v5")
return err
}
// RemoveFilteredPolicy removes policy rules that match the filter from the storage.
func (a *Adapter) RemoveFilteredPolicy(sec string, ptype string, fieldIndex int, fieldValues ...string) error {
line := CasbinRule{}
line.Ptype = ptype
filter := []string{}
filter = append(filter, "ptype")
if fieldIndex <= 0 && 0 < fieldIndex+len(fieldValues) {
line.V0 = fieldValues[0-fieldIndex]
filter = append(filter, "v0")
}
if fieldIndex <= 1 && 1 < fieldIndex+len(fieldValues) {
line.V1 = fieldValues[1-fieldIndex]
filter = append(filter, "v1")
}
if fieldIndex <= 2 && 2 < fieldIndex+len(fieldValues) {
line.V2 = fieldValues[2-fieldIndex]
filter = append(filter, "v2")
}
if fieldIndex <= 3 && 3 < fieldIndex+len(fieldValues) {
line.V3 = fieldValues[3-fieldIndex]
filter = append(filter, "v3")
}
if fieldIndex <= 4 && 4 < fieldIndex+len(fieldValues) {
line.V4 = fieldValues[4-fieldIndex]
filter = append(filter, "v4")
}
if fieldIndex <= 5 && 5 < fieldIndex+len(fieldValues) {
line.V5 = fieldValues[5-fieldIndex]
filter = append(filter, "v5")
}
_, err := a.o.Delete(&line, filter...)
return err
}
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册