提交 1798bd6e 编写于 作者: H hanxinke

atune: add input validation for analysis/define/train/update/upgrade command...

atune: add input validation for analysis/define/train/update/upgrade command and restrict collection/train command to be executed remotely
上级 c49ddcef
...@@ -31,11 +31,15 @@ class Script(Configurator): ...@@ -31,11 +31,15 @@ class Script(Configurator):
_module = "SCRIPT" _module = "SCRIPT"
_submod = "SCRIPT" _submod = "SCRIPT"
cmd_delimiter = "|" cmd_delimiter = "|"
scripts_path = "/usr/libexec/atuned/scripts"
def __init__(self, user=None): def __init__(self, user=None):
Configurator.__init__(self, user) Configurator.__init__(self, user)
def _set(self, key, value): def _set(self, key, value):
script_dir = os.path.dirname(key)
if self.scripts_path != os.path.realpath(script_dir):
raise SetConfigError("key:{} is invalid".format(key))
name = os.path.basename(key) name = os.path.basename(key)
script = "{}/set.sh".format(key) script = "{}/set.sh".format(key)
if not os.path.exists(script): if not os.path.exists(script):
...@@ -48,6 +52,9 @@ class Script(Configurator): ...@@ -48,6 +52,9 @@ class Script(Configurator):
return 0 return 0
def _get(self, key, value): def _get(self, key, value):
script_dir = os.path.dirname(key)
if self.scripts_path != os.path.realpath(script_dir):
raise GetConfigError("key:{} is invalid".format(key))
name = os.path.basename(key) name = os.path.basename(key)
script = "{}/get.sh".format(key) script = "{}/get.sh".format(key)
if not os.path.exists(script): if not os.path.exists(script):
......
...@@ -72,6 +72,9 @@ func profileAnalysis(ctx *cli.Context) error { ...@@ -72,6 +72,9 @@ func profileAnalysis(ctx *cli.Context) error {
} }
if ctx.NArg() == 1 { if ctx.NArg() == 1 {
appname = ctx.Args().Get(0) appname = ctx.Args().Get(0)
if !utils.IsInputStringValid(appname) {
return fmt.Errorf("input:%s is invalid", appname)
}
} }
c, err := client.NewClientFromContext(ctx) c, err := client.NewClientFromContext(ctx)
...@@ -81,6 +84,10 @@ func profileAnalysis(ctx *cli.Context) error { ...@@ -81,6 +84,10 @@ func profileAnalysis(ctx *cli.Context) error {
defer c.Close() defer c.Close()
modelFile := ctx.String("model") modelFile := ctx.String("model")
if modelFile != "" && !utils.IsInputStringValid(modelFile) {
return fmt.Errorf("input:%s is invalid", modelFile)
}
svc := PB.NewProfileMgrClient(c.Connection()) svc := PB.NewProfileMgrClient(c.Connection())
stream, _ := svc.Analysis(CTX.Background(), &PB.AnalysisMessage{Name: appname, Model: modelFile}) stream, _ := svc.Analysis(CTX.Background(), &PB.AnalysisMessage{Name: appname, Model: modelFile})
......
...@@ -64,6 +64,10 @@ func profileDefineCheck(ctx *cli.Context) error { ...@@ -64,6 +64,10 @@ func profileDefineCheck(ctx *cli.Context) error {
} }
file := ctx.Args().Get(2) file := ctx.Args().Get(2)
if !utils.IsInputStringValid(file) {
return fmt.Errorf("input:%s is invalid", file)
}
exist, err := utils.PathExist(file) exist, err := utils.PathExist(file)
if err != nil { if err != nil {
return err return err
...@@ -85,7 +89,14 @@ func profileDefined(ctx *cli.Context) error { ...@@ -85,7 +89,14 @@ func profileDefined(ctx *cli.Context) error {
return err return err
} }
workloadType := ctx.Args().Get(0) workloadType := ctx.Args().Get(0)
if !utils.IsInputStringValid(workloadType) {
return fmt.Errorf("input:%s is invalid", workloadType)
}
profileName := ctx.Args().Get(1) profileName := ctx.Args().Get(1)
if !utils.IsInputStringValid(profileName) {
return fmt.Errorf("input:%s is invalid", profileName)
}
data, err := ioutil.ReadFile(ctx.Args().Get(2)) data, err := ioutil.ReadFile(ctx.Args().Get(2))
if err != nil { if err != nil {
......
...@@ -77,12 +77,18 @@ func checkTrainCtx(ctx *cli.Context) error { ...@@ -77,12 +77,18 @@ func checkTrainCtx(ctx *cli.Context) error {
_ = cli.ShowCommandHelp(ctx, "train") _ = cli.ShowCommandHelp(ctx, "train")
return fmt.Errorf("error: data_path must be specified") return fmt.Errorf("error: data_path must be specified")
} }
if !utils.IsInputStringValid(dataPath) {
return fmt.Errorf("input:%s is invalid", dataPath)
}
outputPath := ctx.String("output_file") outputPath := ctx.String("output_file")
if outputPath == "" { if outputPath == "" {
_ = cli.ShowCommandHelp(ctx, "train") _ = cli.ShowCommandHelp(ctx, "train")
return fmt.Errorf("error: output_file must be specified") return fmt.Errorf("error: output_file must be specified")
} }
if !utils.IsInputStringValid(outputPath) {
return fmt.Errorf("input:%s is invalid", outputPath)
}
return nil return nil
} }
......
...@@ -61,6 +61,10 @@ func profileUpdateCheck(ctx *cli.Context) error { ...@@ -61,6 +61,10 @@ func profileUpdateCheck(ctx *cli.Context) error {
} }
file := ctx.Args().Get(2) file := ctx.Args().Get(2)
if !utils.IsInputStringValid(file) {
return fmt.Errorf("input:%s is invalid", file)
}
exist, err := utils.PathExist(file) exist, err := utils.PathExist(file)
if err != nil { if err != nil {
return err return err
...@@ -77,7 +81,14 @@ func profileUpdate(ctx *cli.Context) error { ...@@ -77,7 +81,14 @@ func profileUpdate(ctx *cli.Context) error {
return err return err
} }
workloadType := ctx.Args().Get(0) workloadType := ctx.Args().Get(0)
if !utils.IsInputStringValid(workloadType) {
return fmt.Errorf("input:%s is invalid", workloadType)
}
profileName := ctx.Args().Get(1) profileName := ctx.Args().Get(1)
if !utils.IsInputStringValid(profileName) {
return fmt.Errorf("input:%s is invalid", profileName)
}
data, err := ioutil.ReadFile(ctx.Args().Get(2)) data, err := ioutil.ReadFile(ctx.Args().Get(2))
if err != nil { if err != nil {
......
...@@ -60,6 +60,10 @@ func profileUpgrade(ctx *cli.Context) error { ...@@ -60,6 +60,10 @@ func profileUpgrade(ctx *cli.Context) error {
} }
dbPath := ctx.Args().Get(0) dbPath := ctx.Args().Get(0)
if !utils.IsInputStringValid(dbPath) {
return fmt.Errorf("input:%s is invalid", dbPath)
}
exist, err := utils.PathExist(dbPath) exist, err := utils.PathExist(dbPath)
if err != nil { if err != nil {
return err return err
......
...@@ -677,6 +677,14 @@ func (s *ProfileServer) ProfileRollback(profileInfo *PB.ProfileInfo, stream PB.P ...@@ -677,6 +677,14 @@ func (s *ProfileServer) ProfileRollback(profileInfo *PB.ProfileInfo, stream PB.P
Collection method call collection script to collect system data. Collection method call collection script to collect system data.
*/ */
func (s *ProfileServer) Collection(message *PB.CollectFlag, stream PB.ProfileMgr_CollectionServer) error { func (s *ProfileServer) Collection(message *PB.CollectFlag, stream PB.ProfileMgr_CollectionServer) error {
isLocalAddr, err := SVC.CheckRpcIsLocalAddr(stream.Context())
if err != nil {
return err
}
if !isLocalAddr {
return fmt.Errorf("the collection command can not be remotely operated")
}
if valid := utils.IsInputStringValid(message.GetWorkload()); !valid { if valid := utils.IsInputStringValid(message.GetWorkload()); !valid {
return fmt.Errorf("input:%s is invalid", message.GetWorkload()) return fmt.Errorf("input:%s is invalid", message.GetWorkload())
} }
...@@ -698,7 +706,7 @@ func (s *ProfileServer) Collection(message *PB.CollectFlag, stream PB.ProfileMgr ...@@ -698,7 +706,7 @@ func (s *ProfileServer) Collection(message *PB.CollectFlag, stream PB.ProfileMgr
} }
classApps := &sqlstore.GetClassApp{Class: message.GetType()} classApps := &sqlstore.GetClassApp{Class: message.GetType()}
err := sqlstore.GetClassApps(classApps) err = sqlstore.GetClassApps(classApps)
if err != nil { if err != nil {
return err return err
} }
...@@ -785,6 +793,14 @@ func (s *ProfileServer) Collection(message *PB.CollectFlag, stream PB.ProfileMgr ...@@ -785,6 +793,14 @@ func (s *ProfileServer) Collection(message *PB.CollectFlag, stream PB.ProfileMgr
Training method train the collected data to generate the model Training method train the collected data to generate the model
*/ */
func (s *ProfileServer) Training(message *PB.TrainMessage, stream PB.ProfileMgr_TrainingServer) error { func (s *ProfileServer) Training(message *PB.TrainMessage, stream PB.ProfileMgr_TrainingServer) error {
isLocalAddr, err := SVC.CheckRpcIsLocalAddr(stream.Context())
if err != nil {
return err
}
if !isLocalAddr {
return fmt.Errorf("the train command can not be remotely operated")
}
DataPath := message.GetDataPath() DataPath := message.GetDataPath()
OutputPath := message.GetOutputPath() OutputPath := message.GetOutputPath()
......
...@@ -20,7 +20,7 @@ from analysis.plugin.configurator.script.script import Script ...@@ -20,7 +20,7 @@ from analysis.plugin.configurator.script.script import Script
class TestScript: class TestScript:
""" test script""" """ test script"""
user = "UT" user = "UT"
path = "scripts/hugepage" path = "/usr/libexec/atuned/scripts/hugepage"
def test_get_script_with_hugepage(self): def test_get_script_with_hugepage(self):
"""test get script result with hugepage""" """test get script result with hugepage"""
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册