提交 f1ef21e1 编写于 作者: H HFO4

Feat: add local policy

上级 c1d2b933
package middleware
import (
"github.com/HFO4/cloudreve/bootstrap/constant"
"github.com/HFO4/cloudreve/pkg/cache"
"github.com/HFO4/cloudreve/pkg/hashid"
"github.com/gin-gonic/gin"
......@@ -14,6 +15,7 @@ func TestHashID(t *testing.T) {
asserts := assert.New(t)
rec := httptest.NewRecorder()
TestFunc := HashID(hashid.FolderID)
constant.HashIDTable = []int{0, 1, 2, 3, 4, 5, 6}
// 未给定ID对象,跳过
{
......
......@@ -48,7 +48,11 @@ func (task *Download) AfterFind() (err error) {
// BeforeSave Save下载任务前的钩子
func (task *Download) BeforeSave() (err error) {
return task.AfterFind()
// 解析状态
if task.Attrs != "" {
err = json.Unmarshal([]byte(task.Attrs), &task.StatusInfo)
}
return err
}
// Create 创建离线下载记录
......
package model
import (
"github.com/HFO4/cloudreve/pkg/cache"
"github.com/HFO4/cloudreve/pkg/conf"
"github.com/HFO4/cloudreve/pkg/util"
"github.com/jinzhu/gorm"
......@@ -24,6 +25,11 @@ func migration() {
util.Log().Info("开始进行数据库初始化...")
// 清除所有缓存
if instance, ok := cache.Store.(*cache.RedisStore); ok {
instance.DeleteAll()
}
// 自动迁移模式
if conf.DatabaseConfig.Type == "mysql" {
DB = DB.Set("gorm:table_options", "ENGINE=InnoDB")
......@@ -54,9 +60,7 @@ func addDefaultPolicy() {
defaultPolicy := Policy{
Name: "默认存储策略",
Type: "local",
Server: "/api/v3/file/upload",
BaseURL: "http://cloudreve.org/public/uploads/",
MaxSize: 10 * 1024 * 1024 * 1024,
MaxSize: 0,
AutoRename: true,
DirNameRule: "uploads/{uid}/{path}",
FileNameRule: "{uid}_{randomkey8}_{originname}",
......@@ -137,7 +141,7 @@ Neue',Helvetica,Arial,sans-serif; box-sizing: border-box; font-size: 14px; verti
{Name: "shopid", Value: ``, Type: "payment"},
{Name: "hot_share_num", Value: `10`, Type: "share"},
{Name: "group_sell_data", Value: `[]`, Type: "group_sell"},
{Name: "gravatar_server", Value: `https://gravatar.loli.net/`, Type: "avatar"},
{Name: "gravatar_server", Value: `https://www.gravatar.com/`, Type: "avatar"},
{Name: "defaultTheme", Value: `#3f51b5`, Type: "basic"},
{Name: "themes", Value: `{"#3f51b5":{"palette":{"primary":{"main":"#3f51b5"},"secondary":{"main":"#f50057"}}},"#2196f3":{"palette":{"primary":{"main":"#2196f3"},"secondary":{"main":"#FFC107"}}},"#673AB7":{"palette":{"primary":{"main":"#673AB7"},"secondary":{"main":"#2196F3"}}},"#E91E63":{"palette":{"primary":{"main":"#E91E63"},"secondary":{"main":"#42A5F5","contrastText":"#fff"}}},"#FF5722":{"palette":{"primary":{"main":"#FF5722"},"secondary":{"main":"#3F51B5"}}},"#FFC107":{"palette":{"primary":{"main":"#FFC107"},"secondary":{"main":"#26C6DA"}}},"#8BC34A":{"palette":{"primary":{"main":"#8BC34A","contrastText":"#fff"},"secondary":{"main":"#FF8A65","contrastText":"#fff"}}},"#009688":{"palette":{"primary":{"main":"#009688"},"secondary":{"main":"#4DD0E1","contrastText":"#fff"}}},"#607D8B":{"palette":{"primary":{"main":"#607D8B"},"secondary":{"main":"#F06292"}}},"#795548":{"palette":{"primary":{"main":"#795548"},"secondary":{"main":"#4CAF50","contrastText":"#fff"}}}}`, Type: "basic"},
{Name: "aria2_token", Value: ``, Type: "aria2"},
......@@ -169,9 +173,9 @@ Neue',Helvetica,Arial,sans-serif; box-sizing: border-box; font-size: 14px; verti
{Name: "captcha_ComplexOfNoiseText", Value: "0", Type: "captcha"},
{Name: "captcha_ComplexOfNoiseDot", Value: "0", Type: "captcha"},
{Name: "captcha_IsShowHollowLine", Value: "0", Type: "captcha"},
{Name: "captcha_IsShowNoiseDot", Value: "0", Type: "captcha"},
{Name: "captcha_IsShowNoiseDot", Value: "1", Type: "captcha"},
{Name: "captcha_IsShowNoiseText", Value: "0", Type: "captcha"},
{Name: "captcha_IsShowSlimeLine", Value: "0", Type: "captcha"},
{Name: "captcha_IsShowSlimeLine", Value: "1", Type: "captcha"},
{Name: "captcha_IsShowSineLine", Value: "0", Type: "captcha"},
{Name: "captcha_CaptchaLen", Value: "6", Type: "captcha"},
{Name: "thumb_width", Value: "400", Type: "thumb"},
......
......@@ -102,13 +102,20 @@ func (policy *Policy) SerializeOptions() (err error) {
func (policy *Policy) GeneratePath(uid uint, origin string) string {
dirRule := policy.DirNameRule
replaceTable := map[string]string{
"{randomkey16}": util.RandStringRunes(16),
"{randomkey8}": util.RandStringRunes(8),
"{timestamp}": strconv.FormatInt(time.Now().Unix(), 10),
"{uid}": strconv.Itoa(int(uid)),
"{datetime}": time.Now().Format("20060102150405"),
"{date}": time.Now().Format("20060102"),
"{path}": origin + "/",
"{randomkey16}": util.RandStringRunes(16),
"{randomkey8}": util.RandStringRunes(8),
"{timestamp}": strconv.FormatInt(time.Now().Unix(), 10),
"{timestamp_nano}": strconv.FormatInt(time.Now().UnixNano(), 10),
"{uid}": strconv.Itoa(int(uid)),
"{datetime}": time.Now().Format("20060102150405"),
"{date}": time.Now().Format("20060102"),
"{year}": time.Now().Format("2006"),
"{month}": time.Now().Format("01"),
"{day}": time.Now().Format("02"),
"{hour}": time.Now().Format("15"),
"{minute}": time.Now().Format("04"),
"{second}": time.Now().Format("05"),
"{path}": origin + "/",
}
dirRule = util.Replace(replaceTable, dirRule)
return path.Clean(dirRule)
......@@ -124,12 +131,19 @@ func (policy *Policy) GenerateFileName(uid uint, origin string) string {
fileRule := policy.FileNameRule
replaceTable := map[string]string{
"{randomkey16}": util.RandStringRunes(16),
"{randomkey8}": util.RandStringRunes(8),
"{timestamp}": strconv.FormatInt(time.Now().Unix(), 10),
"{uid}": strconv.Itoa(int(uid)),
"{datetime}": time.Now().Format("20060102150405"),
"{date}": time.Now().Format("20060102"),
"{randomkey16}": util.RandStringRunes(16),
"{randomkey8}": util.RandStringRunes(8),
"{timestamp}": strconv.FormatInt(time.Now().Unix(), 10),
"{timestamp_nano}": strconv.FormatInt(time.Now().UnixNano(), 10),
"{uid}": strconv.Itoa(int(uid)),
"{datetime}": time.Now().Format("20060102150405"),
"{date}": time.Now().Format("20060102"),
"{year}": time.Now().Format("2006"),
"{month}": time.Now().Format("01"),
"{day}": time.Now().Format("02"),
"{hour}": time.Now().Format("15"),
"{minute}": time.Now().Format("04"),
"{second}": time.Now().Format("05"),
}
replaceTable["{originname}"] = policy.getOriginNameRule(origin)
......
......@@ -93,8 +93,10 @@ func Init(isReload bool) {
// 关闭上个初始连接
if previousClient, ok := Instance.(*RPCService); ok {
util.Log().Debug("关闭上个 aria2 连接")
previousClient.caller.Close()
if previousClient.Caller != nil {
util.Log().Debug("关闭上个 aria2 连接")
previousClient.Caller.Close()
}
}
options := model.GetSettingByNames("aria2_rpcurl", "aria2_token", "aria2_options")
......
......@@ -14,7 +14,7 @@ import (
// RPCService 通过RPC服务的Aria2任务管理器
type RPCService struct {
options *clientOptions
caller rpc.Client
Caller rpc.Client
}
type clientOptions struct {
......@@ -24,8 +24,8 @@ type clientOptions struct {
// Init 初始化
func (client *RPCService) Init(server, secret string, timeout int, options map[string]interface{}) error {
// 客户端已存在,则关闭先前连接
if client.caller != nil {
client.caller.Close()
if client.Caller != nil {
client.Caller.Close()
}
client.options = &clientOptions{
......@@ -33,18 +33,18 @@ func (client *RPCService) Init(server, secret string, timeout int, options map[s
}
caller, err := rpc.New(context.Background(), server, secret, time.Duration(timeout)*time.Second,
EventNotifier)
client.caller = caller
client.Caller = caller
return err
}
// Status 查询下载状态
func (client *RPCService) Status(task *model.Download) (rpc.StatusInfo, error) {
res, err := client.caller.TellStatus(task.GID)
res, err := client.Caller.TellStatus(task.GID)
if err != nil {
// 失败后重试
util.Log().Debug("无法获取离线下载状态,%s,10秒钟后重试", err)
time.Sleep(time.Duration(10) * time.Second)
res, err = client.caller.TellStatus(task.GID)
res, err = client.Caller.TellStatus(task.GID)
}
return res, err
......@@ -53,7 +53,7 @@ func (client *RPCService) Status(task *model.Download) (rpc.StatusInfo, error) {
// Cancel 取消下载
func (client *RPCService) Cancel(task *model.Download) error {
// 取消下载任务
_, err := client.caller.Remove(task.GID)
_, err := client.Caller.Remove(task.GID)
if err != nil {
util.Log().Warning("无法取消离线下载任务[%s], %s", task.GID, err)
}
......@@ -79,7 +79,7 @@ func (client *RPCService) Select(task *model.Download, files []int) error {
for i := 0; i < len(files); i++ {
selected[i] = strconv.Itoa(files[i])
}
_, err := client.caller.ChangeOption(task.GID, map[string]interface{}{"select-file": strings.Join(selected, ",")})
_, err := client.Caller.ChangeOption(task.GID, map[string]interface{}{"select-file": strings.Join(selected, ",")})
return err
}
......@@ -103,7 +103,7 @@ func (client *RPCService) CreateTask(task *model.Download, groupOptions map[stri
options[k] = v
}
gid, err := client.caller.AddURI(task.Source, options)
gid, err := client.Caller.AddURI(task.Source, options)
if err != nil || gid == "" {
return err
}
......
......@@ -20,7 +20,7 @@ type InstanceMock struct {
testMock.Mock
}
func (m InstanceMock) CreateTask(task *model.Download, options []interface{}) error {
func (m InstanceMock) CreateTask(task *model.Download, options map[string]interface{}) error {
args := m.Called(task, options)
return args.Error(0)
}
......@@ -307,13 +307,16 @@ func TestMonitor_Complete(t *testing.T) {
}
cache.Set("setting_max_worker_num", "1", 0)
mock.ExpectQuery("SELECT(.+)tasks").WillReturnRows(sqlmock.NewRows([]string{"id"}))
task.Init()
mock.ExpectQuery("SELECT(.+)").WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(1))
mock.ExpectQuery("SELECT(.+)users").WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(1))
mock.ExpectQuery("SELECT(.+)policies").WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(1))
mock.ExpectBegin()
mock.ExpectExec("INSERT(.+)").WillReturnResult(sqlmock.NewResult(1, 1))
mock.ExpectExec("INSERT(.+)tasks").WillReturnResult(sqlmock.NewResult(1, 1))
mock.ExpectCommit()
mock.ExpectBegin()
mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, 1))
mock.ExpectExec("UPDATE(.+)downloads").WillReturnResult(sqlmock.NewResult(1, 1))
mock.ExpectCommit()
asserts.True(monitor.Complete(rpc.StatusInfo{}))
asserts.NoError(mock.ExpectationsWereMet())
......
......@@ -201,3 +201,16 @@ func (store *RedisStore) Delete(keys []string, prefix string) error {
}
return nil
}
// DeleteAll 批量所有键
func (store *RedisStore) DeleteAll() error {
rc := store.pool.Get()
defer rc.Close()
if rc.Err() != nil {
return rc.Err()
}
_, err := rc.Do("FLUSHDB")
return err
}
......@@ -124,6 +124,15 @@ func (handler Driver) Source(
return "", errors.New("无法获取文件记录上下文")
}
// 是否启用了CDN
if handler.Policy.BaseURL != "" {
cdnURL, err := url.Parse(handler.Policy.BaseURL)
if err != nil {
return "", err
}
baseURL = *cdnURL
}
var (
signedURI *url.URL
err error
......
......@@ -117,3 +117,47 @@ func AdminDeleteRedeem(c *gin.Context) {
c.JSON(200, ErrorResponse(err))
}
}
// AdminTestAria2 测试aria2连接
func AdminTestAria2(c *gin.Context) {
var service admin.Aria2TestService
if err := c.ShouldBindJSON(&service); err == nil {
res := service.Test()
c.JSON(200, res)
} else {
c.JSON(200, ErrorResponse(err))
}
}
// AdminListPolicy 列出存储策略
func AdminListPolicy(c *gin.Context) {
var service admin.AdminListService
if err := c.ShouldBindJSON(&service); err == nil {
res := service.Policies()
c.JSON(200, res)
} else {
c.JSON(200, ErrorResponse(err))
}
}
// AdminTestPath 测试本地路径可用性
func AdminTestPath(c *gin.Context) {
var service admin.PathTestService
if err := c.ShouldBindJSON(&service); err == nil {
res := service.Test()
c.JSON(200, res)
} else {
c.JSON(200, ErrorResponse(err))
}
}
// AdminAddPolicy 新建存储策略
func AdminAddPolicy(c *gin.Context) {
var service admin.AddPolicyService
if err := c.ShouldBindJSON(&service); err == nil {
res := service.Add()
c.JSON(200, res)
} else {
c.JSON(200, ErrorResponse(err))
}
}
......@@ -315,6 +315,23 @@ func InitMasterRouter() *gin.Engine {
redeem.DELETE(":id", controllers.AdminDeleteRedeem)
}
// 离线下载相关
aria2 := admin.Group("aria2")
{
// 测试连接配置
aria2.POST("test", controllers.AdminTestAria2)
}
policy := admin.Group("policy")
{
// 列出存储策略
policy.POST("list", controllers.AdminListPolicy)
// 测试本地路径可用性
policy.POST("test/path", controllers.AdminTestPath)
// 创建存储策略
policy.POST("", controllers.AdminAddPolicy)
}
}
// 用户
......
package admin
import (
"github.com/HFO4/cloudreve/pkg/aria2"
"github.com/HFO4/cloudreve/pkg/serializer"
"net/url"
)
// Aria2TestService aria2连接测试服务
type Aria2TestService struct {
Server string `json:"server" binding:"required"`
Token string `json:"token"`
}
// Test 测试aria2连接
func (service *Aria2TestService) Test() serializer.Response {
testRPC := aria2.RPCService{}
// 解析RPC服务地址
server, err := url.Parse(service.Server)
if err != nil {
return serializer.ParamErr("无法解析 aria2 RPC 服务地址, "+err.Error(), nil)
}
server.Path = "/jsonrpc"
if err := testRPC.Init(server.String(), service.Token, 5, map[string]interface{}{}); err != nil {
return serializer.ParamErr("无法初始化连接, "+err.Error(), nil)
}
defer testRPC.Caller.Close()
info, err := testRPC.Caller.GetVersion()
if err != nil {
return serializer.ParamErr("无法请求 RPC 服务, "+err.Error(), nil)
}
if info.Version == "" {
return serializer.ParamErr("RPC 服务返回非预期响应", nil)
}
return serializer.Response{Data: info.Version}
}
package admin
import (
"fmt"
model "github.com/HFO4/cloudreve/models"
"github.com/HFO4/cloudreve/pkg/serializer"
"github.com/HFO4/cloudreve/pkg/util"
"os"
"path/filepath"
)
// PathTestService 本地路径测试服务
type PathTestService struct {
Path string `json:"path" binding:"required"`
}
// AddPolicyService 存储策略添加服务
type AddPolicyService struct {
Policy model.Policy `json:"policy" binding:"required"`
}
// Add 添加存储策略
func (service *AddPolicyService) Add() serializer.Response {
if err := model.DB.Create(&service.Policy).Error; err != nil {
return serializer.ParamErr("存储策略添加失败", err)
}
return serializer.Response{}
}
// Test 测试本地路径
func (service *PathTestService) Test() serializer.Response {
policy := model.Policy{DirNameRule: service.Path}
path := policy.GeneratePath(1, "/My File")
path = filepath.Join(path, "test.txt")
file, err := util.CreatNestedFile(path)
if err != nil {
return serializer.ParamErr(fmt.Sprintf("无法创建路径 %s , %s", path, err.Error()), nil)
}
file.Close()
os.Remove(path)
return serializer.Response{}
}
// Policies 列出存储策略
func (service *AdminListService) Policies() serializer.Response {
var res []model.Policy
total := 0
tx := model.DB.Model(&model.Policy{})
if service.OrderBy != "" {
tx = tx.Order(service.OrderBy)
}
for k, v := range service.Conditions {
tx = tx.Where("? = ?", k, v)
}
// 计算总数用于分页
tx.Count(&total)
// 查询记录
tx.Limit(service.PageSize).Offset((service.Page - 1) * service.PageSize).Find(&res)
// 统计每个策略的文件使用
statics := make(map[uint][2]int, len(res))
for i := 0; i < len(res); i++ {
total := [2]int{}
row := model.DB.Model(&model.File{}).Where("policy_id = ?", res[i].ID).
Select("count(id),sum(size)").Row()
row.Scan(&total[0], &total[1])
statics[res[i].ID] = total
}
return serializer.Response{Data: map[string]interface{}{
"total": total,
"items": res,
"statics": statics,
}}
}
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册